chiark / gitweb /
Initial version.
[ocb-tv] / ocbgen
1 #! /usr/bin/python
2 ### -*-python-*-
3 ###
4 ### Generalization of OCB mode for other block sizes
5 ###
6 ### (c) 2017 Mark Wooding
7 ###
8
9 ###----- Licensing notice ---------------------------------------------------
10 ###
11 ### This program is free software; you can redistribute it and/or modify
12 ### it under the terms of the GNU General Public License as published by
13 ### the Free Software Foundation; either version 2 of the License, or
14 ### (at your option) any later version.
15 ###
16 ### This program is distributed in the hope that it will be useful,
17 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
18 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 ### GNU General Public License for more details.
20 ###
21 ### You should have received a copy of the GNU General Public License
22 ### along with this program; if not, write to the Free Software Foundation,
23 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24
25 from sys import argv, stderr
26 from struct import pack
27 from itertools import izip
28 import catacomb as C
29
30 R = C.FibRand(0)
31
32 ###--------------------------------------------------------------------------
33 ### Utilities.
34
35 def combs(things, k):
36   ii = range(k)
37   n = len(things)
38   while True:
39     yield [things[i] for i in ii]
40     for j in xrange(k):
41       if j == k - 1: lim = n
42       else: lim = ii[j + 1]
43       i = ii[j] + 1
44       if i < lim:
45         ii[j] = i
46         break
47       ii[j] = j
48     else:
49       return
50
51 POLYMAP = {}
52
53 def poly(nbits):
54   try: return POLYMAP[nbits]
55   except KeyError: pass
56   base = C.GF(0).setbit(nbits).setbit(0)
57   for k in xrange(1, nbits, 2):
58     for cc in combs(range(1, nbits), k):
59       p = base + sum(C.GF(0).setbit(c) for c in cc)
60       if p.irreduciblep(): POLYMAP[nbits] = p; return p
61   raise ValueError, nbits
62
63 def prim(nbits):
64   ## No fancy way to do this: I'd need a much cleverer factoring algorithm
65   ## than I have in my pockets.
66   if nbits == 64: cc = [64, 4, 3, 1, 0]
67   elif nbits == 96: cc = [96, 10, 9, 6, 0]
68   elif nbits == 128: cc = [128, 7, 2, 1, 0]
69   elif nbits == 192: cc = [192, 15, 11, 5, 0]
70   elif nbits == 256: cc = [256, 10, 5, 2, 0]
71   else: raise ValueError, 'no field for %d bits' % nbits
72   p = C.GF(0)
73   for c in cc: p = p.setbit(c)
74   return p
75
76 def Z(n):
77   return C.ByteString.zero(n)
78
79 def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8)
80
81 def with_lastp(it):
82   it = iter(it)
83   try: j = next(it)
84   except StopIteration: raise ValueError, 'empty iter'
85   lastp = False
86   while not lastp:
87     i = j
88     try: j = next(it)
89     except StopIteration: lastp = True
90     yield i, lastp
91
92 def safehex(x):
93   if len(x): return hex(x)
94   else: return '""'
95
96 def keylens(ksz):
97   sel = []
98   if isinstance(ksz, C.KeySZSet): kk = ksz.set
99   elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
100   elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
101   kk = list(kk); kk = kk[:]
102   n = len(kk)
103   while n and len(sel) < 4:
104     i = R.range(n)
105     n -= 1
106     kk[i], kk[n] = kk[n], kk[i]
107     sel.append(kk[n])
108   return sel
109
110 def pad0star(m, w):
111   n = len(m)
112   if not n: r = w
113   else: r = (-len(m))%w
114   if r: m += Z(r)
115   return C.ByteString(m)
116
117 def pad10star(m, w):
118   r = w - len(m)%w
119   if r: m += '\x80' + Z(r - 1)
120   return C.ByteString(m)
121
122 def ntz(i):
123   j = 0
124   while (i&1) == 0: i >>= 1; j += 1
125   return j
126
127 def blocks(x, w):
128   v, i, n = [], 0, len(x)
129   while n - i > w:
130     v.append(C.ByteString(x[i:i + w]))
131     i += w
132   return v, C.ByteString(x[i:])
133
134 EMPTY = C.bytes('')
135
136 def blocks0(x, w):
137   v, tl = blocks(x, w)
138   if len(tl) == w: v.append(tl); tl = EMPTY
139   return v, tl
140
141 ###--------------------------------------------------------------------------
142 ### Luby--Rackoff large-block ciphers.
143
144 class LubyRackoffCipher (type):
145   def __new__(cls, bc, blksz):
146     assert blksz%2 == 0
147     assert blksz <= 2*bc.blksz
148     name = '%s-lr[%d]' % (bc.name, 8*blksz)
149     me = type(name, (LubyRackoffBase,), {})
150     me.name = name
151     me.blksz = blksz
152     me.keysz = bc.keysz
153     me.bc = bc
154     return me
155
156 class LubyRackoffBase (object):
157   NR = 4 # for strong-PRP security
158   def __init__(me, k):
159     if LRVERBOSE: print 'K = %s' % hex(k)
160     bc, blksz = me.__class__.bc, me.__class__.blksz
161     E = bc(k)
162     me.f = []
163     ksz = len(k)
164     i = C.MP(0)
165     for j in xrange(me.NR):
166       b = C.WriteBuffer()
167       while b.size < ksz:
168         x = E.encrypt(i.storeb(bc.blksz))
169         b.put(x)
170         if LRVERBOSE: print 'E(K; [%d]) = %s' % (i, hex(x))
171         i += 1
172       kj = C.ByteString(C.ByteString(b)[0:ksz])
173       if LRVERBOSE: print 'K_%d = %s' % (j, hex(kj))
174       me.f.append(bc(kj))
175   def encrypt(me, m):
176     bc, blksz = me.__class__.bc, me.__class__.blksz
177     assert len(m) == blksz
178     l, r = C.ByteString(m[:blksz/2]), C.ByteString(m[blksz/2:])
179     if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
180     for j in xrange(me.NR):
181       l0 = pad0star(l, bc.blksz)
182       t = me.f[j].encrypt(l0)
183       l, r = r ^ t[:blksz/2], l
184       if LRVERBOSE:
185         print 'E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t))
186         print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
187     return C.ByteString(r + l)
188   def decrypt(me, c):
189     bc, blksz = me.__class__.bc, me.__class__.blksz
190     assert len(c) == blksz
191     l, r = C.ByteString(c[:blksz/2]), C.ByteString(c[blksz/2:])
192     for j in xrange(me.NR - 1, -1, -1):
193       l0 = pad0star(l, bc.blksz)
194       t = me.f[j].encrypt(l0)
195       if LRVERBOSE:
196         print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
197         print 'E(K_%d; L_%d || 0^*) = %s' % (j + 1, j + 1, hex(t))
198       l, r = r ^ t[:blksz/2], l
199     if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
200     return C.ByteString(r + l)
201
202 LRAES = {}
203 for i in [8, 12, 16, 24, 32]:
204   LRAES['lraes%d' % (8*i)] = LubyRackoffCipher(C.rijndael, i)
205
206 ###--------------------------------------------------------------------------
207 ### PMAC.
208
209 def ocb_masks(E):
210   blksz = E.__class__.blksz
211   p = poly(8*blksz)
212   x = C.GF(2); xinv = p.modinv(x)
213   z = Z(blksz)
214   L = E.encrypt(z)
215   Lxinv = mul_blk_gf(L, xinv, p)
216   Lgamma = 66*[L]
217   for i in xrange(1, len(Lgamma)):
218     Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p)
219   return Lgamma, Lxinv
220
221 def dump_ocb(E):
222   Lgamma, Lxinv = ocb_masks(E)
223   print 'L x^-1 = %s' % hex(Lxinv)
224   for i, lg in enumerate(Lgamma):
225     print 'L x^%d = %s' % (i, hex(lg))
226
227 def pmac1(E, m):
228   blksz = E.__class__.blksz
229   Lgamma, Lxinv = ocb_masks(E)
230   a = o = Z(blksz)
231   i = 1
232   v, tl = blocks(m, blksz)
233   for x in v:
234     b = ntz(i)
235     o ^= Lgamma[b]
236     a ^= E.encrypt(x ^ o)
237     if VERBOSE:
238       print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
239       print 'A[%d]: %s' % (i - 1, hex(a))
240     i += 1
241   if len(tl) == blksz: a ^= tl ^ Lxinv
242   else: a ^= pad10star(tl, blksz)
243   return E.encrypt(a)
244
245 def pmac2(E, m):
246   blksz = E.__class__.blksz
247   p = prim(8*blksz)
248   L = E.encrypt(Z(blksz))
249   o = mul_blk_gf(L, 10, p)
250   a = Z(blksz)
251   v, tl = blocks(m, blksz)
252   for x in v:
253     a ^= E.encrypt(x ^ o)
254     o = mul_blk_gf(o, 2, p)
255   if len(tl) == blksz: a ^= tl ^ mul_blk_gf(o, 3, p)
256   else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, 5, p)
257   return E.encrypt(a)
258
259 def ocb3_masks(E):
260   Lgamma, _ = ocb_masks(E)
261   Lstar = Lgamma[0]
262   Ldollar = Lgamma[1]
263   return Lstar, Ldollar, Lgamma[2:]
264
265 def dump_ocb3(E):
266   Lstar, Ldollar, Lgamma = ocb3_masks(E)
267   print 'L_*       : %s' % hex(Lstar)
268   print 'L_$       : %s' % hex(Ldollar)
269   for i, lg in enumerate(Lgamma[:4]):
270     print 'L_%-8d: %s' % (i, hex(lg))
271
272 def pmac3(E, m):
273   blksz = E.__class__.blksz
274   Lstar, Ldollar, Lgamma = ocb3_masks(E)
275   a = o = Z(blksz)
276   i = 1
277   v, tl = blocks0(m, blksz)
278   for x in v:
279     b = ntz(i)
280     o ^= Lgamma[b]
281     a ^= E.encrypt(x ^ o)
282     if VERBOSE:
283       print 'Offset\'_%-2d: %s' % (i, hex(o))
284       print 'AuthSum_%-2d: %s' % (i, hex(a))
285     i += 1
286   if tl:
287     o ^= Lstar
288     a ^= E.encrypt(pad10star(tl, blksz) ^ o)
289     if VERBOSE:
290       print 'Offset\'_* : %s' % hex(o)
291       print 'AuthSum_* : %s' % hex(a)
292   return a
293
294 def pmac1_pub(E, m):
295   if VERBOSE: dump_ocb(E)
296   return pmac1(E, m),
297
298 def pmac2_pub(E, m):
299   return pmac2(E, m),
300
301 def pmac3_pub(E, m):
302   return pmac3(E, m),
303
304 def pmacgen(bc):
305   return [(0,), (1,),
306           (3*bc.blksz,),
307           (3*bc.blksz - 5,)]
308
309 ###--------------------------------------------------------------------------
310 ### OCB.
311
312 ## For OCB2, it's important for security that n = log_x (x + 1) is large in
313 ## the field representations of GF(2^w) used -- in fact, we need more, that
314 ## i n (mod 2^w - 1) is large for i in {4, -3, -2, -1, 1, 2, 3, 4}.  The
315 ## original paper lists the values for 64 and 128, but we support other block
316 ## sizes, so here's the result of the (rather large, in some cases)
317 ## computation.
318 ##
319 ## Block size           log_x (x + 1)
320 ##
321 ##       64             9686038906114705801
322 ##       96             63214690573408919568138788065
323 ##      128             338793687469689340204974836150077311399
324 ##      192             161110085006042185925119981866940491651092686475226538785
325 ##      256             22928580326165511958494515843249267194111962539778797914076675796261938307298
326
327 def ocb1(E, n, h, m, tsz = None):
328   ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
329   ## Associated-Data'.
330   blksz = E.__class__.blksz
331   if VERBOSE: dump_ocb(E)
332   Lgamma, Lxinv = ocb_masks(E)
333   if tsz is None: tsz = blksz
334   a = Z(blksz)
335   o = E.encrypt(n ^ Lgamma[0])
336   if VERBOSE: print 'R = %s' % hex(o)
337   i = 1
338   y = C.WriteBuffer()
339   v, tl = blocks(m, blksz)
340   for x in v:
341     b = ntz(i)
342     o ^= Lgamma[b]
343     a ^= x
344     if VERBOSE:
345       print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
346       print 'A[%d]: %s' % (i - 1, hex(a))
347     y.put(E.encrypt(x ^ o) ^ o)
348     i += 1
349   b = ntz(i)
350   o ^= Lgamma[b]
351   n = len(tl)
352   if VERBOSE:
353     print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
354     print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
355   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
356   cfinal = tl ^ yfinal[:n]
357   a ^= o ^ (tl + yfinal[n:])
358   y.put(cfinal)
359   t = E.encrypt(a)
360   if h: t ^= pmac1(E, h)
361   return C.ByteString(y), C.ByteString(t[:tsz])
362
363 def ocb2(E, n, h, m, tsz = None):
364   blksz = E.__class__.blksz
365   if tsz is None: tsz = blksz
366   p = prim(8*blksz)
367   L = E.encrypt(n)
368   o = mul_blk_gf(L, 2, p)
369   a = Z(blksz)
370   v, tl = blocks(m, blksz)
371   y = C.WriteBuffer()
372   for x in v:
373     a ^= x
374     y.put(E.encrypt(x ^ o) ^ o)
375     o = mul_blk_gf(o, 2, p)
376   n = len(tl)
377   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o)
378   cfinal = tl ^ yfinal[:n]
379   a ^= (tl + yfinal[n:]) ^ mul_blk_gf(o, 3, p)
380   y.put(cfinal)
381   t = E.encrypt(a)
382   if h: t ^= pmac2(E, h)
383   return C.ByteString(y), C.ByteString(t[:tsz])
384
385 OCB3_STRETCH = { 8: (5, 25),
386                  12: (6, 33),
387                  16: (6, 8),
388                  24: (7, 40),
389                  32: (7, 120) }
390
391 def ocb3(E, n, h, m, tsz = None):
392   blksz = E.__class__.blksz
393   if tsz is None: tsz = blksz
394   Lstar, Ldollar, Lgamma = ocb3_masks(E)
395   if VERBOSE: dump_ocb3(E)
396
397   ## Figure out how much we need to glue onto the nonce.  This ends up being
398   ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
399   ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
400   ## v - 1.  But this is an annoying way to think about it because of the
401   ## byte misalignment.  Instead, think of it as a byte-aligned prefix
402   ## encoding the tag and an `is the nonce full-length' flag, followed by
403   ## optional padding, and then the nonce:
404   ##
405   ##    F || N                  if l(N) = w - f
406   ##    F || 0^p || 1 || N      otherwise
407   ##
408   ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
409   ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
410   tszbits = C.MP(8*blksz - 1).nbits
411   fwd = tszbits/8 + 1
412   f = tsz << 3 + 8*fwd - tszbits
413
414   ## Form the augmented nonce.
415   nb = C.WriteBuffer()
416   nsz, nwd = len(n), blksz - fwd
417   if nsz == nwd: f |= 1
418   nb.put(C.MP(f).storeb(fwd))
419   if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
420   nb.put(n)
421   nn = C.ByteString(nb)
422   if VERBOSE: print 'N\'        : %s' % hex(nn)
423
424   ## Calculate the initial offset.
425   split, shift = OCB3_STRETCH[blksz]
426   splitbits = 1 << split
427   t2ps = C.MP(0).setbit(splitbits)
428   lomask = (C.MP(0).setbit(split) - 1)
429   himask = ~lomask
430   top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
431   ktop = C.MP.loadb(E.encrypt(top))
432   stretch = (ktop << splitbits) | \
433       (((ktop ^ (ktop << shift)) >> (8*blksz - splitbits))%t2ps)
434   o = (stretch >> splitbits - bottom).storeb(blksz)
435   a = C.ByteString.zero(blksz)
436   if VERBOSE:
437     print 'bottom    : %d' % bottom
438     print 'Ktop      : %s' % hex(ktop.storeb(blksz))
439     print 'Stretch   : %s' % hex(stretch.storeb(blksz + (1 << split - 3)))
440     print 'Offset_0  : %s' % hex(o)
441
442   ## Split the message into blocks.
443   i = 1
444   y = C.WriteBuffer()
445   v, tl = blocks0(m, blksz)
446   for x in v:
447     b = ntz(i)
448     o ^= Lgamma[b]
449     a ^= x
450     if VERBOSE:
451       print 'Offset_%-3d: %s' % (i, hex(o))
452       print 'Checksum_%d: %s' % (i, hex(a))
453     y.put(E.encrypt(x ^ o) ^ o)
454     i += 1
455   if tl:
456     o ^= Lstar
457     n = len(tl)
458     pad = E.encrypt(o)
459     a ^= pad10star(tl, blksz)
460     if VERBOSE:
461       print 'Offset_*  : %s' % hex(o)
462       print 'Checksum_*: %s' % hex(a)
463     y.put(tl ^ pad[0:n])
464   o ^= Ldollar
465   t = E.encrypt(a ^ o) ^ pmac3(E, h)
466   return C.ByteString(y), C.ByteString(t[:tsz])
467
468 def ocbgen(bc):
469   w = bc.blksz
470   return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
471           (w, 0, 3*w),
472           (w, 3*w, 3*w),
473           (w, 0, 3*w + 5),
474           (w, 3*w - 5, 3*w + 5)]
475
476 def ocb3gen(bc):
477   w = bc.blksz
478   return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
479           (w - 5, 0, 3*w),
480           (w - 3, 3*w, 3*w),
481           (w - 2, 0, 3*w + 5),
482           (w - 2, 3*w - 5, 3*w + 5)]
483
484 ###--------------------------------------------------------------------------
485 ### Main program.
486
487 VERBOSE = LRVERBOSE = False
488
489 class struct (object):
490   def __init__(me, **kw):
491     me.__dict__.update(kw)
492
493 def mct(ocb, bc, ksz, nsz, tsz):
494   k = C.MP(8*tsz).storeb(ksz)
495   E = bc(k)
496   e = C.ByteString('')
497   n = C.MP(1)
498   cbuf = C.WriteBuffer()
499   for i in xrange(128):
500     s = C.ByteString.zero(i)
501     y, t = ocb(E, n.storeb(nsz), s, s, tsz); n += 1; cbuf.put(y).put(t)
502     y, t = ocb(E, n.storeb(nsz), e, s, tsz); n += 1; cbuf.put(y).put(t)
503     y, t = ocb(E, n.storeb(nsz), s, e, tsz); n += 1; cbuf.put(y).put(t)
504   _, t = ocb(E, n.storeb(nsz), C.ByteString(cbuf), e, tsz)
505   print hex(t)
506
507 argc = len(argv)
508 argi = 1
509
510 def usage():
511   print >>stderr, """\
512 usage: %s [-v] OCB BLKC OP ARGS...
513         mct KSZ NSZ TSZ
514         kat K N0 TSZ HSZ,MSZ ...
515         lraes W K M""" % argv[0]
516   exit(2)
517
518 def arg(must = True, default = None):
519   global argi
520   if argi < argc: argi += 1; return argv[argi - 1]
521   elif not must: return default
522   else: usage()
523
524 MODEMAP = { 'ocb1': ocb1,
525             'ocb2': ocb2,
526             'ocb3': ocb3 }
527
528 def pat(sz):
529   b = C.WriteBuffer()
530   for i in xrange(sz): b.putu8(i%256)
531   return C.ByteString(b)
532
533 opt = arg()
534 if opt == '-v': VERBOSE = True; opt = arg()
535 ocb = MODEMAP[opt]
536
537 bcname = arg()
538 bc = None
539 for d in LRAES, C.gcprps:
540   try: bc = d[bcname]
541   except KeyError: pass
542   else: break
543 if bc is None: raise KeyError, bcname
544
545 mode = arg()
546 if mode == 'mct':
547   ksz = int(arg()); nsz = int(arg()); tsz = int(arg())
548   mct(ocb, bc, ksz, nsz, tsz)
549   exit(0)
550
551 elif mode == 'kat':
552   k = C.bytes(arg())
553   E = bc(k)
554   nspec = arg()
555   if nspec.endswith('+'): ninc = 1; nspec = nspec[:-1]
556   else: ninc = 0
557   n0 = C.bytes(nspec)
558   nz = C.MP.loadb(n0)
559   nsz = len(n0)
560   tsz = int(arg())
561
562   print 'K: %s' % hex(k)
563
564   while True:
565     hmsz = arg(must = False)
566     if hmsz is None: break
567     hsz, msz = map(int, hmsz.split(','))
568     n = nz.storeb(nsz)
569     h = pat(hsz)
570     m = pat(msz)
571     y, t = ocb(E, n, h, m, tsz)
572     print
573     print 'N: %s' % hex(n)
574     print 'A: %s' % hex(h)
575     print 'P: %s' % hex(m)
576     print 'C: %s%s' % (hex(y), hex(t))
577     nz += ninc
578
579 elif mode == 'lraes':
580   w = int(arg())
581   k = C.bytes(arg())
582   m = C.bytes(arg())
583   LRVERBOSE = True
584   lr = LubyRackoffCipher(bc, w)
585   E = lr(k)
586   print
587   c = E.encrypt(m)
588   print 'E\'(K, m) = %s' % hex(c)
589
590 else:
591   usage()
592
593 ###----- That's all, folks --------------------------------------------------