chiark / gitweb /
a0c360f2233c5768a28c550f8d446f2b8e0d55cb
[ocb-tv] / ocbgen
1 #! /usr/bin/python
2 ### -*-python-*-
3 ###
4 ### Generalization of OCB mode for other block sizes
5 ###
6 ### (c) 2017 Mark Wooding
7 ###
8
9 ###----- Licensing notice ---------------------------------------------------
10 ###
11 ### This program is free software; you can redistribute it and/or modify
12 ### it under the terms of the GNU General Public License as published by
13 ### the Free Software Foundation; either version 2 of the License, or
14 ### (at your option) any later version.
15 ###
16 ### This program is distributed in the hope that it will be useful,
17 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
18 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 ### GNU General Public License for more details.
20 ###
21 ### You should have received a copy of the GNU General Public License
22 ### along with this program; if not, write to the Free Software Foundation,
23 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
24
25 from sys import argv, stderr
26 from struct import pack
27 from itertools import izip
28 from contextlib import contextmanager
29 import catacomb as C
30
31 R = C.FibRand(0)
32
33 ###--------------------------------------------------------------------------
34 ### Utilities.
35
36 def combs(things, k):
37   ii = range(k)
38   n = len(things)
39   while True:
40     yield [things[i] for i in ii]
41     for j in xrange(k):
42       if j == k - 1: lim = n
43       else: lim = ii[j + 1]
44       i = ii[j] + 1
45       if i < lim:
46         ii[j] = i
47         break
48       ii[j] = j
49     else:
50       return
51
52 POLYMAP = {}
53
54 def poly(nbits):
55   try: return POLYMAP[nbits]
56   except KeyError: pass
57   base = C.GF(0).setbit(nbits).setbit(0)
58   for k in xrange(1, nbits, 2):
59     for cc in combs(range(1, nbits), k):
60       p = base + sum(C.GF(0).setbit(c) for c in cc)
61       if p.irreduciblep(): POLYMAP[nbits] = p; return p
62   raise ValueError, nbits
63
64 def prim(nbits):
65   ## No fancy way to do this: I'd need a much cleverer factoring algorithm
66   ## than I have in my pockets.
67   if nbits == 64: cc = [64, 4, 3, 1, 0]
68   elif nbits == 96: cc = [96, 10, 9, 6, 0]
69   elif nbits == 128: cc = [128, 7, 2, 1, 0]
70   elif nbits == 192: cc = [192, 15, 11, 5, 0]
71   elif nbits == 256: cc = [256, 10, 5, 2, 0]
72   else: raise ValueError, 'no field for %d bits' % nbits
73   p = C.GF(0)
74   for c in cc: p = p.setbit(c)
75   return p
76
77 def Z(n):
78   return C.ByteString.zero(n)
79
80 def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8)
81
82 def with_lastp(it):
83   it = iter(it)
84   try: j = next(it)
85   except StopIteration: raise ValueError, 'empty iter'
86   lastp = False
87   while not lastp:
88     i = j
89     try: j = next(it)
90     except StopIteration: lastp = True
91     yield i, lastp
92
93 def safehex(x):
94   if len(x): return hex(x)
95   else: return '""'
96
97 def keylens(ksz):
98   sel = []
99   if isinstance(ksz, C.KeySZSet): kk = ksz.set
100   elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
101   elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
102   kk = list(kk); kk = kk[:]
103   n = len(kk)
104   while n and len(sel) < 4:
105     i = R.range(n)
106     n -= 1
107     kk[i], kk[n] = kk[n], kk[i]
108     sel.append(kk[n])
109   return sel
110
111 def pad0star(m, w):
112   n = len(m)
113   if not n: r = w
114   else: r = (-len(m))%w
115   if r: m += Z(r)
116   return C.ByteString(m)
117
118 def pad10star(m, w):
119   r = w - len(m)%w
120   if r: m += '\x80' + Z(r - 1)
121   return C.ByteString(m)
122
123 def ntz(i):
124   j = 0
125   while (i&1) == 0: i >>= 1; j += 1
126   return j
127
128 def blocks(x, w):
129   v, i, n = [], 0, len(x)
130   while n - i > w:
131     v.append(C.ByteString(x[i:i + w]))
132     i += w
133   return v, C.ByteString(x[i:])
134
135 EMPTY = C.bytes('')
136
137 def blocks0(x, w):
138   v, tl = blocks(x, w)
139   if len(tl) == w: v.append(tl); tl = EMPTY
140   return v, tl
141
142 ###--------------------------------------------------------------------------
143 ### Luby--Rackoff large-block ciphers.
144
145 class LubyRackoffCipher (type):
146   def __new__(cls, bc, blksz):
147     assert blksz%2 == 0
148     assert blksz <= 2*bc.blksz
149     name = '%s-lr[%d]' % (bc.name, 8*blksz)
150     me = type(name, (LubyRackoffBase,), {})
151     me.name = name
152     me.blksz = blksz
153     me.keysz = bc.keysz
154     me.bc = bc
155     return me
156
157 @contextmanager
158 def muffle():
159   global VERBOSE, LRVERBOSE
160   _v, _lrv = VERBOSE, LRVERBOSE
161   try:
162     VERBOSE = LRVERBOSE = False
163     yield None
164   finally:
165     VERBOSE, LRVERBOSE = _v, _lrv
166
167 class LubyRackoffBase (object):
168   NR = 4 # for strong-PRP security
169   def __init__(me, k):
170     if LRVERBOSE: print 'K = %s' % hex(k)
171     bc, blksz = me.__class__.bc, me.__class__.blksz
172     with muffle(): E = bc(k)
173     me.f = []
174     ksz = len(k)
175     i = C.MP(0)
176     for j in xrange(me.NR):
177       b = C.WriteBuffer()
178       while b.size < ksz:
179         with muffle(): x = E.encrypt(i.storeb(bc.blksz))
180         b.put(x)
181         if LRVERBOSE: print 'E(K; [%d]) = %s' % (i, hex(x))
182         i += 1
183       kj = C.ByteString(C.ByteString(b)[0:ksz])
184       if LRVERBOSE: print 'K_%d = %s' % (j, hex(kj))
185       with muffle(): me.f.append(bc(kj))
186   def encrypt(me, m):
187     bc, blksz = me.__class__.bc, me.__class__.blksz
188     assert len(m) == blksz
189     l, r = C.ByteString(m[:blksz/2]), C.ByteString(m[blksz/2:])
190     if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
191     for j in xrange(me.NR):
192       l0 = pad0star(l, bc.blksz)
193       with muffle(): t = me.f[j].encrypt(l0)
194       l, r = r ^ t[:blksz/2], l
195       if LRVERBOSE:
196         print 'E(K_%d; L_%d || 0^*) = %s' % (j, j, hex(t))
197         print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
198     return C.ByteString(r + l)
199   def decrypt(me, c):
200     bc, blksz = me.__class__.bc, me.__class__.blksz
201     assert len(c) == blksz
202     l, r = C.ByteString(c[:blksz/2]), C.ByteString(c[blksz/2:])
203     for j in xrange(me.NR - 1, -1, -1):
204       l0 = pad0star(l, bc.blksz)
205       with muffle(): t = me.f[j].encrypt(l0)
206       if LRVERBOSE:
207         print 'L_%d, R_%d = %s, %s' % (j + 1, j + 1, hex(l), hex(r))
208         print 'E(K_%d; L_%d || 0^*) = %s' % (j + 1, j + 1, hex(t))
209       l, r = r ^ t[:blksz/2], l
210     if LRVERBOSE: print 'L_0, R_0 = %s, %s' % (hex(l), hex(r))
211     return C.ByteString(r + l)
212
213 LRAES = {}
214 for i in [8, 12, 16, 24, 32]:
215   LRAES['lraes%d' % (8*i)] = LubyRackoffCipher(C.rijndael, i)
216 LRAES['dlraes512'] = LubyRackoffCipher(LubyRackoffCipher(C.rijndael, 32), 64)
217
218 ###--------------------------------------------------------------------------
219 ### PMAC.
220
221 def ocb_masks(E):
222   blksz = E.__class__.blksz
223   p = poly(8*blksz)
224   x = C.GF(2); xinv = p.modinv(x)
225   z = Z(blksz)
226   L = E.encrypt(z)
227   Lxinv = mul_blk_gf(L, xinv, p)
228   Lgamma = 66*[L]
229   for i in xrange(1, len(Lgamma)):
230     Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p)
231   return Lgamma, Lxinv
232
233 def dump_ocb(E):
234   Lgamma, Lxinv = ocb_masks(E)
235   print 'L x^-1 = %s' % hex(Lxinv)
236   for i, lg in enumerate(Lgamma):
237     print 'L x^%d = %s' % (i, hex(lg))
238
239 def pmac1(E, m):
240   blksz = E.__class__.blksz
241   Lgamma, Lxinv = ocb_masks(E)
242   a = o = Z(blksz)
243   i = 1
244   v, tl = blocks(m, blksz)
245   for x in v:
246     b = ntz(i)
247     o ^= Lgamma[b]
248     a ^= E.encrypt(x ^ o)
249     if VERBOSE:
250       print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
251       print 'A[%d]: %s' % (i - 1, hex(a))
252     i += 1
253   if len(tl) == blksz: a ^= tl ^ Lxinv
254   else: a ^= pad10star(tl, blksz)
255   return E.encrypt(a)
256
257 def pmac2(E, m):
258   blksz = E.__class__.blksz
259   p = prim(8*blksz)
260   L = E.encrypt(Z(blksz))
261   o = mul_blk_gf(L, 10, p)
262   a = Z(blksz)
263   v, tl = blocks(m, blksz)
264   for x in v:
265     a ^= E.encrypt(x ^ o)
266     o = mul_blk_gf(o, 2, p)
267   if len(tl) == blksz: a ^= tl ^ mul_blk_gf(o, 3, p)
268   else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, 5, p)
269   return E.encrypt(a)
270
271 def ocb3_masks(E):
272   Lgamma, _ = ocb_masks(E)
273   Lstar = Lgamma[0]
274   Ldollar = Lgamma[1]
275   return Lstar, Ldollar, Lgamma[2:]
276
277 def dump_ocb3(E):
278   Lstar, Ldollar, Lgamma = ocb3_masks(E)
279   print 'L_*       : %s' % hex(Lstar)
280   print 'L_$       : %s' % hex(Ldollar)
281   for i, lg in enumerate(Lgamma[:4]):
282     print 'L_%-8d: %s' % (i, hex(lg))
283
284 def pmac3(E, m):
285   blksz = E.__class__.blksz
286   Lstar, Ldollar, Lgamma = ocb3_masks(E)
287   a = o = Z(blksz)
288   i = 1
289   v, tl = blocks0(m, blksz)
290   for x in v:
291     b = ntz(i)
292     o ^= Lgamma[b]
293     a ^= E.encrypt(x ^ o)
294     if VERBOSE:
295       print 'Offset\'_%-2d: %s' % (i, hex(o))
296       print 'AuthSum_%-2d: %s' % (i, hex(a))
297     i += 1
298   if tl:
299     o ^= Lstar
300     a ^= E.encrypt(pad10star(tl, blksz) ^ o)
301     if VERBOSE:
302       print 'Offset\'_* : %s' % hex(o)
303       print 'AuthSum_* : %s' % hex(a)
304   return a
305
306 def pmac1_pub(E, m):
307   if VERBOSE: dump_ocb(E)
308   return pmac1(E, m),
309
310 def pmac2_pub(E, m):
311   return pmac2(E, m),
312
313 def pmac3_pub(E, m):
314   return pmac3(E, m),
315
316 def pmacgen(bc):
317   return [(0,), (1,),
318           (3*bc.blksz,),
319           (3*bc.blksz - 5,)]
320
321 ###--------------------------------------------------------------------------
322 ### OCB.
323
324 ## For OCB2, it's important for security that n = log_x (x + 1) is large in
325 ## the field representations of GF(2^w) used -- in fact, we need more, that
326 ## i n (mod 2^w - 1) is large for i in {4, -3, -2, -1, 1, 2, 3, 4}.  The
327 ## original paper lists the values for 64 and 128, but we support other block
328 ## sizes, so here's the result of the (rather large, in some cases)
329 ## computation.
330 ##
331 ## Block size           log_x (x + 1)
332 ##
333 ##       64             9686038906114705801
334 ##       96             63214690573408919568138788065
335 ##      128             338793687469689340204974836150077311399
336 ##      192             161110085006042185925119981866940491651092686475226538785
337 ##      256             22928580326165511958494515843249267194111962539778797914076675796261938307298
338
339 def ocb1(E, n, h, m, tsz = None):
340   ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
341   ## Associated-Data'.
342   blksz = E.__class__.blksz
343   if VERBOSE: dump_ocb(E)
344   Lgamma, Lxinv = ocb_masks(E)
345   if tsz is None: tsz = blksz
346   a = Z(blksz)
347   o = E.encrypt(n ^ Lgamma[0])
348   if VERBOSE: print 'R = %s' % hex(o)
349   i = 1
350   y = C.WriteBuffer()
351   v, tl = blocks(m, blksz)
352   for x in v:
353     b = ntz(i)
354     o ^= Lgamma[b]
355     a ^= x
356     if VERBOSE:
357       print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
358       print 'A[%d]: %s' % (i - 1, hex(a))
359     y.put(E.encrypt(x ^ o) ^ o)
360     i += 1
361   b = ntz(i)
362   o ^= Lgamma[b]
363   n = len(tl)
364   if VERBOSE:
365     print 'Z[%d]: %d -> %s' % (i - 1, b, hex(o))
366     print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
367   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
368   cfinal = tl ^ yfinal[:n]
369   a ^= o ^ (tl + yfinal[n:])
370   y.put(cfinal)
371   t = E.encrypt(a)
372   if h: t ^= pmac1(E, h)
373   return C.ByteString(y), C.ByteString(t[:tsz])
374
375 def ocb2(E, n, h, m, tsz = None):
376   blksz = E.__class__.blksz
377   if tsz is None: tsz = blksz
378   p = prim(8*blksz)
379   L = E.encrypt(n)
380   o = mul_blk_gf(L, 2, p)
381   a = Z(blksz)
382   v, tl = blocks(m, blksz)
383   y = C.WriteBuffer()
384   for x in v:
385     a ^= x
386     y.put(E.encrypt(x ^ o) ^ o)
387     o = mul_blk_gf(o, 2, p)
388   n = len(tl)
389   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o)
390   cfinal = tl ^ yfinal[:n]
391   a ^= (tl + yfinal[n:]) ^ mul_blk_gf(o, 3, p)
392   y.put(cfinal)
393   t = E.encrypt(a)
394   if h: t ^= pmac2(E, h)
395   return C.ByteString(y), C.ByteString(t[:tsz])
396
397 OCB3_STRETCH = { 8: (5, 25),
398                  12: (6, 33),
399                  16: (6, 8),
400                  24: (7, 40),
401                  32: (7, 120),
402                  64: (8, 240) }
403
404 def ocb3(E, n, h, m, tsz = None):
405   blksz = E.__class__.blksz
406   if tsz is None: tsz = blksz
407   Lstar, Ldollar, Lgamma = ocb3_masks(E)
408   if VERBOSE: dump_ocb3(E)
409
410   ## Figure out how much we need to glue onto the nonce.  This ends up being
411   ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
412   ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
413   ## v - 1.  But this is an annoying way to think about it because of the
414   ## byte misalignment.  Instead, think of it as a byte-aligned prefix
415   ## encoding the tag and an `is the nonce full-length' flag, followed by
416   ## optional padding, and then the nonce:
417   ##
418   ##    F || N                  if l(N) = w - f
419   ##    F || 0^p || 1 || N      otherwise
420   ##
421   ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
422   ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
423   tszbits = C.MP(8*blksz - 1).nbits
424   fwd = tszbits/8 + 1
425   f = tsz << 3 + 8*fwd - tszbits
426
427   ## Form the augmented nonce.
428   nb = C.WriteBuffer()
429   nsz, nwd = len(n), blksz - fwd
430   if nsz == nwd: f |= 1
431   nb.put(C.MP(f).storeb(fwd))
432   if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
433   nb.put(n)
434   nn = C.ByteString(nb)
435   if VERBOSE: print 'N\'        : %s' % hex(nn)
436
437   ## Calculate the initial offset.
438   split, shift = OCB3_STRETCH[blksz]
439   splitbits = 1 << split
440   t2ps = C.MP(0).setbit(splitbits)
441   lomask = (C.MP(0).setbit(split) - 1)
442   himask = ~lomask
443   top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
444   ktop = C.MP.loadb(E.encrypt(top))
445   stretch = (ktop << splitbits) | \
446       (((ktop ^ (ktop << shift)) >> (8*blksz - splitbits))%t2ps)
447   o = (stretch >> splitbits - bottom).storeb(blksz)
448   a = C.ByteString.zero(blksz)
449   if VERBOSE:
450     print 'bottom    : %d' % bottom
451     print 'Ktop      : %s' % hex(ktop.storeb(blksz))
452     print 'Stretch   : %s' % hex(stretch.storeb(blksz + (1 << split - 3)))
453     print 'Offset_0  : %s' % hex(o)
454
455   ## Split the message into blocks.
456   i = 1
457   y = C.WriteBuffer()
458   v, tl = blocks0(m, blksz)
459   for x in v:
460     b = ntz(i)
461     o ^= Lgamma[b]
462     a ^= x
463     if VERBOSE:
464       print 'Offset_%-3d: %s' % (i, hex(o))
465       print 'Checksum_%d: %s' % (i, hex(a))
466     y.put(E.encrypt(x ^ o) ^ o)
467     i += 1
468   if tl:
469     o ^= Lstar
470     n = len(tl)
471     pad = E.encrypt(o)
472     a ^= pad10star(tl, blksz)
473     if VERBOSE:
474       print 'Offset_*  : %s' % hex(o)
475       print 'Checksum_*: %s' % hex(a)
476     y.put(tl ^ pad[0:n])
477   o ^= Ldollar
478   t = E.encrypt(a ^ o) ^ pmac3(E, h)
479   return C.ByteString(y), C.ByteString(t[:tsz])
480
481 def ocbgen(bc):
482   w = bc.blksz
483   return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
484           (w, 0, 3*w),
485           (w, 3*w, 3*w),
486           (w, 0, 3*w + 5),
487           (w, 3*w - 5, 3*w + 5)]
488
489 def ocb3gen(bc):
490   w = bc.blksz
491   return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
492           (w - 5, 0, 3*w),
493           (w - 3, 3*w, 3*w),
494           (w - 2, 0, 3*w + 5),
495           (w - 2, 3*w - 5, 3*w + 5)]
496
497 ###--------------------------------------------------------------------------
498 ### Main program.
499
500 VERBOSE = LRVERBOSE = False
501
502 class struct (object):
503   def __init__(me, **kw):
504     me.__dict__.update(kw)
505
506 def mct(ocb, bc, ksz, nsz, tsz):
507   k = C.MP(8*tsz).storeb(ksz)
508   E = bc(k)
509   e = C.ByteString('')
510   n = C.MP(1)
511   cbuf = C.WriteBuffer()
512   for i in xrange(128):
513     s = C.ByteString.zero(i)
514     y, t = ocb(E, n.storeb(nsz), s, s, tsz); n += 1; cbuf.put(y).put(t)
515     y, t = ocb(E, n.storeb(nsz), e, s, tsz); n += 1; cbuf.put(y).put(t)
516     y, t = ocb(E, n.storeb(nsz), s, e, tsz); n += 1; cbuf.put(y).put(t)
517   _, t = ocb(E, n.storeb(nsz), C.ByteString(cbuf), e, tsz)
518   print hex(t)
519
520 argc = len(argv)
521 argi = 1
522
523 def usage():
524   print >>stderr, """\
525 usage: %s [-v] OCB BLKC OP ARGS...
526         mct KSZ NSZ TSZ
527         kat K N0 TSZ HSZ,MSZ ...
528         lraes W K M""" % argv[0]
529   exit(2)
530
531 def arg(must = True, default = None):
532   global argi
533   if argi < argc: argi += 1; return argv[argi - 1]
534   elif not must: return default
535   else: usage()
536
537 MODEMAP = { 'ocb1': ocb1,
538             'ocb2': ocb2,
539             'ocb3': ocb3 }
540
541 def pat(sz):
542   b = C.WriteBuffer()
543   for i in xrange(sz): b.putu8(i%256)
544   return C.ByteString(b)
545
546 opt = arg()
547 if opt == '-v': VERBOSE = True; opt = arg()
548 ocb = MODEMAP[opt]
549
550 bcname = arg()
551 bc = None
552 for d in LRAES, C.gcprps:
553   try: bc = d[bcname]
554   except KeyError: pass
555   else: break
556 if bc is None: raise KeyError, bcname
557
558 mode = arg()
559 if mode == 'mct':
560   ksz = int(arg()); nsz = int(arg()); tsz = int(arg())
561   mct(ocb, bc, ksz, nsz, tsz)
562   exit(0)
563
564 elif mode == 'kat':
565   k = C.bytes(arg())
566   E = bc(k)
567   nspec = arg()
568   if nspec.endswith('+'): ninc = 1; nspec = nspec[:-1]
569   else: ninc = 0
570   n0 = C.bytes(nspec)
571   nz = C.MP.loadb(n0)
572   nsz = len(n0)
573   tsz = int(arg())
574
575   print 'K: %s' % hex(k)
576
577   while True:
578     hmsz = arg(must = False)
579     if hmsz is None: break
580     hsz, msz = map(int, hmsz.split(','))
581     n = nz.storeb(nsz)
582     h = pat(hsz)
583     m = pat(msz)
584     y, t = ocb(E, n, h, m, tsz)
585     print
586     print 'N: %s' % hex(n)
587     print 'A: %s' % hex(h)
588     print 'P: %s' % hex(m)
589     print 'C: %s%s' % (hex(y), hex(t))
590     nz += ninc
591
592 elif mode == 'lraes':
593   w = int(arg())
594   k = C.bytes(arg())
595   m = C.bytes(arg())
596   LRVERBOSE = True
597   lr = LubyRackoffCipher(bc, w)
598   E = lr(k)
599   print
600   c = E.encrypt(m)
601   print 'E\'(K, m) = %s' % hex(c)
602
603 else:
604   usage()
605
606 ###----- That's all, folks --------------------------------------------------