chiark / gitweb /
rand/rand-x86ish.S: Hoist argument register allocation outside.
[catacomb] / utils / advmodes
1 #! /usr/bin/python
2
3 from sys import argv, exit
4 from struct import unpack, pack
5 from itertools import izip
6 import catacomb as C
7
8 R = C.FibRand(0)
9
10 ###--------------------------------------------------------------------------
11 ### Utilities.
12
13 def combs(things, k):
14   ii = range(k)
15   n = len(things)
16   while True:
17     yield [things[i] for i in ii]
18     for j in xrange(k):
19       if j == k - 1: lim = n
20       else: lim = ii[j + 1]
21       i = ii[j] + 1
22       if i < lim:
23         ii[j] = i
24         break
25       ii[j] = j
26     else:
27       return
28
29 POLYMAP = {}
30
31 def poly(nbits):
32   try: return POLYMAP[nbits]
33   except KeyError: pass
34   base = C.GF(0).setbit(nbits).setbit(0)
35   for k in xrange(1, nbits, 2):
36     for cc in combs(range(1, nbits), k):
37       p = base + sum((C.GF(0).setbit(c) for c in cc), C.GF(0))
38       if p.irreduciblep(): POLYMAP[nbits] = p; return p
39   raise ValueError, nbits
40
41 def prim(nbits):
42   ## No fancy way to do this: I'd need a much cleverer factoring algorithm
43   ## than I have in my pockets.
44   if nbits == 64: cc = [64, 4, 3, 1, 0]
45   elif nbits == 96: cc = [96, 10, 9, 6, 0]
46   elif nbits == 128: cc = [128, 7, 2, 1, 0]
47   elif nbits == 192: cc = [192, 15, 11, 5, 0]
48   elif nbits == 256: cc = [256, 10, 5, 2, 0]
49   else: raise ValueError, 'no field for %d bits' % nbits
50   p = C.GF(0)
51   for c in cc: p = p.setbit(c)
52   return p
53
54 def Z(n):
55   return C.ByteString.zero(n)
56
57 def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8)
58
59 def with_lastp(it):
60   it = iter(it)
61   try: j = next(it)
62   except StopIteration: raise ValueError, 'empty iter'
63   lastp = False
64   while not lastp:
65     i = j
66     try: j = next(it)
67     except StopIteration: lastp = True
68     yield i, lastp
69
70 def safehex(x):
71   if len(x): return hex(x)
72   else: return '""'
73
74 def keylens(ksz):
75   sel = []
76   if isinstance(ksz, C.KeySZSet): kk = ksz.set
77   elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
78   elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
79   kk = list(kk); kk = kk[:]
80   n = len(kk)
81   while n and len(sel) < 4:
82     i = R.range(n)
83     n -= 1
84     kk[i], kk[n] = kk[n], kk[i]
85     sel.append(kk[n])
86   return sel
87
88 def pad0star(m, w):
89   n = len(m)
90   if not n: r = w
91   else: r = (-len(m))%w
92   if r: m += Z(r)
93   return C.ByteString(m)
94
95 def pad10star(m, w):
96   r = w - len(m)%w
97   if r: m += '\x80' + Z(r - 1)
98   return C.ByteString(m)
99
100 def ntz(i):
101   j = 0
102   while (i&1) == 0: i >>= 1; j += 1
103   return j
104
105 def blocks(x, w):
106   v, i, n = [], 0, len(x)
107   while n - i > w:
108     v.append(C.ByteString(x[i:i + w]))
109     i += w
110   return v, C.ByteString(x[i:])
111
112 EMPTY = C.bytes('')
113
114 def blocks0(x, w):
115   v, tl = blocks(x, w)
116   if len(tl) == w: v.append(tl); tl = EMPTY
117   return v, tl
118
119 def dummygen(bc): return []
120
121 CUSTOM = {}
122
123 ###--------------------------------------------------------------------------
124 ### RC6.
125
126 class RC6Cipher (type):
127   def __new__(cls, w, r):
128     name = 'rc6-%d/%d' % (w, r)
129     me = type(name, (RC6Base,), {})
130     me.name = name
131     me.r = r
132     me.w = w
133     me.blksz = w/2
134     me.keysz = C.KeySZRange(me.blksz, 1, 255, 1)
135     return me
136
137 def rotw(w):
138   return w.bit_length() - 1
139
140 def rol(w, x, n):
141   m0, m1 = C.MP(0).setbit(w - n) - 1, C.MP(0).setbit(n) - 1
142   return ((x&m0) << n) | (x >> (w - n))&m1
143
144 def ror(w, x, n):
145   m0, m1 = C.MP(0).setbit(n) - 1, C.MP(0).setbit(w - n) - 1
146   return ((x&m0) << (w - n)) | (x >> n)&m1
147
148 class RC6Base (object):
149
150   ## Magic constants.
151   P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
152   Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)
153
154   def __init__(me, k):
155
156     ## Build the magic numbers.
157     P = me.P400 >> (400 - me.w)
158     if P%2 == 0: P += 1
159     Q = me.Q400 >> (400 - me.w)
160     if Q%2 == 0: Q += 1
161     M = C.MP(0).setbit(me.w) - 1
162
163     ## Convert the key into words.
164     wb = me.w/8
165     c = (len(k) + wb - 1)/wb
166     kb, ktl = blocks(k, me.w/8)
167     L = map(C.MP.loadl, kb + [ktl])
168     assert c == len(L)
169
170     ## Build the subkey table.
171     me.d = rotw(me.w)
172     n = 2*me.r + 4
173     S = [(P + i*Q)&M for i in xrange(n)]
174
175     ##for j in xrange(c):
176     ##  print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
177     ##for i in xrange(n):
178     ##  print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
179
180     i = j = 0
181     A = B = C.MP(0)
182
183     for s in xrange(3*max(c, n)):
184       A = S[i] = rol(me.w, S[i] + A + B, 3)
185       B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
186       ##print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
187       ##print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
188       i = (i + 1)%n
189       j = (j + 1)%c
190
191     ## Done.
192     me.s = S
193
194   def encrypt(me, x):
195     M = C.MP(0).setbit(me.w) - 1
196     a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0])
197     b = (b + me.s[0])&M
198     d = (d + me.s[1])&M
199     ##print 'B = %s' % (hex(b).upper()[2:].rjust(me.w/4, '0'))
200     ##print 'D = %s' % (hex(d).upper()[2:].rjust(me.w/4, '0'))
201     for i in xrange(2, 2*me.r + 2, 2):
202       t = rol(me.w, 2*b*b + b, me.d)
203       u = rol(me.w, 2*d*d + d, me.d)
204       a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
205       c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
206       ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
207       ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
208       a, b, c, d = b, c, d, a
209     a = (a + me.s[2*me.r + 2])&M
210     c = (c + me.s[2*me.r + 3])&M
211     ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
212     ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
213     return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
214                         c.storel(me.blksz/4) + d.storel(me.blksz/4))
215
216   def decrypt(me, x):
217     M = C.MP(0).setbit(me.w) - 1
218     a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4))
219     c = (c - me.s[2*me.r + 3])&M
220     a = (a - me.s[2*me.r + 2])&M
221     for i in xrange(2*me.r + 1, 1, -2):
222       a, b, c, d = d, a, b, c
223       u = rol(me.w, 2*d*d + d, me.d)
224       t = rol(me.w, 2*b*b + b, me.d)
225       c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
226       a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t
227     a = (a + s[2*me.r + 2])&M
228     c = (c + s[2*me.r + 3])&M
229     return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
230                         c.storel(me.blksz/4) + d.storel(me.blksz/4))
231
232 for (w, r) in [(8, 16), (16, 16), (24, 16), (32, 16),
233                (32, 20), (48, 16), (64, 16), (96, 16), (128, 16),
234                (192, 16), (256, 16), (400, 16)]:
235   CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)
236
237 ###--------------------------------------------------------------------------
238 ### OMAC (or CMAC).
239
240 def omac_masks(E):
241   blksz = E.__class__.blksz
242   p = poly(8*blksz)
243   z = Z(blksz)
244   L = E.encrypt(z)
245   m0 = mul_blk_gf(L, C.GF(2), p)
246   m1 = mul_blk_gf(m0, C.GF(2), p)
247   return m0, m1
248
249 def dump_omac(E):
250   blksz = E.__class__.blksz
251   m0, m1 = omac_masks(E)
252   print 'L = %s' % hex(E.encrypt(Z(blksz)))
253   print 'm0 = %s' % hex(m0)
254   print 'm1 = %s' % hex(m1)
255   for t in xrange(3):
256     print 'v%d = %s' % (t, hex(E.encrypt(C.MP(t).storeb(blksz))))
257     print 'z%d = %s' % (t, hex(omac(E, t, '')))
258
259 def omac(E, t, m):
260   blksz = E.__class__.blksz
261   m0, m1 = omac_masks(E)
262   a = Z(blksz)
263   if t is not None: m = C.MP(t).storeb(blksz) + m
264   v, tl = blocks(m, blksz)
265   for x in v: a = E.encrypt(a ^ x)
266   r = blksz - len(tl)
267   if r == 0:
268     a = E.encrypt(a ^ tl ^ m0)
269   else:
270     pad = pad10star(tl, blksz)
271     a = E.encrypt(a ^ pad ^ m1)
272   return a
273
274 def cmac(E, m):
275   if VERBOSE: dump_omac(E)
276   return omac(E, None, m),
277
278 def cmacgen(bc):
279   return [(0,), (1,),
280           (3*bc.blksz,),
281           (3*bc.blksz - 5,)]
282
283 ###--------------------------------------------------------------------------
284 ### Counter mode.
285
286 def ctr(E, m, c0):
287   blksz = E.__class__.blksz
288   y = C.WriteBuffer()
289   c = C.MP.loadb(c0)
290   while y.size < len(m):
291     y.put(E.encrypt(c.storeb(blksz)))
292     c += 1
293   return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
294
295 ###--------------------------------------------------------------------------
296 ### GCM.
297
298 def gcm_mangle(x):
299   y = C.WriteBuffer()
300   for b in x:
301     b = ord(b)
302     bb = 0
303     for i in xrange(8):
304       bb <<= 1
305       if b&1: bb |= 1
306       b >>= 1
307     y.putu8(bb)
308   return C.ByteString(y)
309
310 def gcm_mul(x, y):
311   w = len(x)
312   p = poly(8*w)
313   u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y))
314   z = (u*v)%p
315   return gcm_mangle(z.storel(w))
316
317 def gcm_pow(x, n):
318   w = len(x)
319   p = poly(8*w)
320   u = C.GF.loadl(gcm_mangle(x))
321   z = pow(u, n, p)
322   return gcm_mangle(z.storel(w))
323
324 def gcm_ctr(E, m, c0):
325   y = C.WriteBuffer()
326   pre = c0[:-4]
327   c, = unpack('>L', c0[-4:])
328   while y.size < len(m):
329     c += 1
330     y.put(E.encrypt(pre + pack('>L', c)))
331   return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
332
333 def g(what, x, m, a0 = None):
334   n = len(x)
335   if a0 is None: a = Z(n)
336   else: a = a0
337   i = 0
338   for b in blocks0(m, n)[0]:
339     a = gcm_mul(a ^ b, x)
340     if VERBOSE: print '%s[%d] = %s -> %s' % (what, i, hex(b), hex(a))
341     i += 1
342   return a
343
344 def gcm_pad(w, x):
345   return C.ByteString(x + Z(-len(x)%w))
346
347 def gcm_lens(w, a, b):
348   if w < 12: n = w
349   else: n = w/2
350   return C.ByteString(C.MP(a).storeb(n) + C.MP(b).storeb(n))
351
352 def ghash(whata, whatb, x, a, b):
353   w = len(x)
354   ha = g(whata, x, gcm_pad(w, a))
355   hb = g(whatb, x, gcm_pad(w, b))
356   if a:
357     hc = gcm_mul(ha, gcm_pow(x, (len(b) + w - 1)/w)) ^ hb
358     if VERBOSE: print '%s || %s -> %s' % (whata, whatb, hex(hc))
359   else:
360     hc = hb
361   return g(whatb, x, gcm_lens(w, 8*len(a), 8*len(b)), hc)
362
363 def gcmenc(E, n, h, m, tsz = None):
364   w = E.__class__.blksz
365   x = E.encrypt(Z(w))
366   if VERBOSE: print 'x = %s' % hex(x)
367   if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
368   else: c0 = ghash('?', 'n', x, EMPTY, n)
369   if VERBOSE: print 'c0 = %s' % hex(c0)
370   y = gcm_ctr(E, m, c0)
371   t = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
372   return y, t
373
374 def gcmdec(E, n, h, y, t):
375   w = E.__class__.blksz
376   x = E.encrypt(Z(w))
377   if VERBOSE: print 'x = %s' % hex(x)
378   if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
379   else: c0 = ghash('?', 'n', x, EMPTY, n)
380   if VERBOSE: print 'c0 = %s' % hex(c0)
381   m = gcm_ctr(E, y, c0)
382   tt = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
383   if t == tt: return m,
384   else: return None,
385
386 def gcmgen(bc):
387   return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
388           (bc.blksz, 3*bc.blksz, 3*bc.blksz),
389           (bc.blksz - 4, bc.blksz + 3, 3*bc.blksz + 9),
390           (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
391
392 def gcm_mul_tests(nbits):
393   print 'gcm-mul%d {' % nbits
394   for i in xrange(64):
395     x = R.block(nbits/8)
396     y = R.block(nbits/8)
397     z = gcm_mul(x, y)
398     print '  %s\n    %s\n    %s;' % (hex(x), hex(y), hex(z))
399   print '}'
400
401 ###--------------------------------------------------------------------------
402 ### CCM.
403
404 def stbe(n, w): return C.MP(n).storeb(w)
405
406 def ccm_fmthdr(blksz, n, hsz, msz, tsz):
407   b = C.WriteBuffer()
408   if blksz == 8:
409     q = blksz - len(n) - 1
410     f = 0
411     if hsz: f |= 0x40
412     f |= (tsz - 1) << 3
413     f |= q - 1
414     b.putu8(f).put(n).put(stbe(msz, q))
415   elif blksz == 16:
416     q = blksz - len(n) - 1
417     f = 0
418     if hsz: f |= 0x40
419     f |= (tsz - 2)/2 << 3
420     f |= q - 1
421     b.putu8(f).put(n).put(stbe(msz, q))
422   else:
423     q = blksz - len(n) - 2
424     f0 = f1 = 0
425     if hsz: f1 |= 0x80
426     f0 |= tsz
427     f1 |= q
428     b.putu8(f0).putu8(f1).put(n).put(stbe(msz, q))
429   b = C.ByteString(b)
430   if VERBOSE: print 'hdr = %s' % hex(b)
431   return b
432
433 def ccm_fmtctr(blksz, n, i = 0):
434   b = C.WriteBuffer()
435   if blksz == 8 or blksz == 16:
436     q = blksz - len(n) - 1
437     b.putu8(q - 1).put(n).put(stbe(i, q))
438   else:
439     q = blksz - len(n) - 2
440     b.putu8(0).putu8(q).put(n).put(stbe(i, q))
441   b = C.ByteString(b)
442   if VERBOSE: print 'ctr = %s' % hex(b)
443   return b
444
445 def ccmaad(b, h, blksz):
446   hsz = len(h)
447   if not hsz: pass
448   elif hsz < 0xfffe: b.putu16(hsz)
449   elif hsz <= 0xffffffff: b.putu16(0xfffe).putu32(hsz)
450   else: b.putu16(0xffff).putu64(hsz)
451   b.put(h); b.zero((-b.size)%blksz)
452
453 def ccmenc(E, n, h, m, tsz = None):
454   blksz = E.__class__.blksz
455   if tsz is None: tsz = blksz
456   b = C.WriteBuffer()
457   b.put(ccm_fmthdr(blksz, n, len(h), len(m), tsz))
458   ccmaad(b, h, blksz)
459   b.put(m); b.zero((-b.size)%blksz)
460   b = C.ByteString(b)
461   a = Z(blksz)
462   v, _ = blocks0(b, blksz)
463   i = 0
464   for x in v:
465     a = E.encrypt(a ^ x)
466     if VERBOSE:
467       print 'b[%d] = %s' % (i, hex(x))
468       print 'a[%d] = %s' % (i + 1, hex(a))
469     i += 1
470   y = ctr(E, a + m, ccm_fmtctr(blksz, n))
471   return C.ByteString(y[blksz:]), C.ByteString(y[0:tsz])
472
473 def ccmdec(E, n, h, y, t):
474   blksz = E.__class__.blksz
475   tsz = len(t)
476   b = C.WriteBuffer()
477   b.put(ccm_fmthdr(blksz, n, len(h), len(y), tsz))
478   ccmaad(b, h, blksz)
479   mm = ctr(E, t + Z(blksz - tsz) + y, ccm_fmtctr(blksz, n))
480   u, m = C.ByteString(mm[0:tsz]), C.ByteString(mm[blksz:])
481   b.put(m); b.zero((-b.size)%blksz)
482   b = C.ByteString(b)
483   a = Z(blksz)
484   v, _ = blocks0(b, blksz)
485   i = 0
486   for x in v:
487     a = E.encrypt(a ^ x)
488     if VERBOSE:
489       print 'b[%d] = %s' % (i, hex(x))
490       print 'a[%d] = %s' % (i + 1, hex(a))
491     i += 1
492   if u == a[:tsz]: return m,
493   else: return None,
494
495 def ccmgen(bc):
496   bsz = bc.blksz
497   return [(bsz - 5, 0, 0, 4), (bsz - 5, 1, 0, 4), (bsz - 5, 0, 1, 4),
498           (bsz/2 + 1, 3*bc.blksz, 3*bc.blksz),
499           (bsz/2 + 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
500
501 ###--------------------------------------------------------------------------
502 ### EAX.
503
504 def eaxenc(E, n, h, m, tsz = None):
505   if VERBOSE:
506     print 'k = %s' % hex(k)
507     print 'n = %s' % hex(n)
508     print 'h = %s' % hex(h)
509     print 'm = %s' % hex(m)
510     dump_omac(E)
511   if tsz is None: tsz = E.__class__.blksz
512   c0 = omac(E, 0, n)
513   y = ctr(E, m, c0)
514   ht = omac(E, 1, h)
515   yt = omac(E, 2, y)
516   if VERBOSE:
517     print 'c0 = %s' % hex(c0)
518     print 'ht = %s' % hex(ht)
519     print 'yt = %s' % hex(yt)
520   return y, C.ByteString((c0 ^ ht ^ yt)[:tsz])
521
522 def eaxdec(E, n, h, y, t):
523   if VERBOSE:
524     print 'k = %s' % hex(k)
525     print 'n = %s' % hex(n)
526     print 'h = %s' % hex(h)
527     print 'y = %s' % hex(y)
528     print 't = %s' % hex(t)
529     dump_omac(E)
530   c0 = omac(E, 0, n)
531   m = ctr(E, y, c0)
532   ht = omac(E, 1, h)
533   yt = omac(E, 2, y)
534   if VERBOSE:
535     print 'c0 = %s' % hex(c0)
536     print 'ht = %s' % hex(ht)
537     print 'yt = %s' % hex(yt)
538   if t == (c0 ^ ht ^ yt)[:len(t)]: return m,
539   else: return None,
540
541 def eaxgen(bc):
542   return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
543           (bc.blksz, 3*bc.blksz, 3*bc.blksz),
544           (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
545
546 ###--------------------------------------------------------------------------
547 ### PMAC.
548
549 def ocb_masks(E):
550   blksz = E.__class__.blksz
551   p = poly(8*blksz)
552   x = C.GF(2); xinv = p.modinv(x)
553   z = Z(blksz)
554   L = E.encrypt(z)
555   Lxinv = mul_blk_gf(L, xinv, p)
556   Lgamma = 66*[L]
557   for i in xrange(1, len(Lgamma)):
558     Lgamma[i] = mul_blk_gf(Lgamma[i - 1], x, p)
559   return Lgamma, Lxinv
560
561 def dump_ocb(E):
562   Lgamma, Lxinv = ocb_masks(E)
563   print 'L x^-1 = %s' % hex(Lxinv)
564   for i, lg in enumerate(Lgamma[:16]):
565     print 'L x^%d = %s' % (i, hex(lg))
566
567 def pmac1(E, m):
568   blksz = E.__class__.blksz
569   Lgamma, Lxinv = ocb_masks(E)
570   a = o = Z(blksz)
571   i = 0
572   v, tl = blocks(m, blksz)
573   for x in v:
574     i += 1
575     b = ntz(i)
576     o ^= Lgamma[b]
577     a ^= E.encrypt(x ^ o)
578     if VERBOSE:
579       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
580       print 'A[%d]: %s' % (i, hex(a))
581   if len(tl) == blksz: a ^= tl ^ Lxinv
582   else: a ^= pad10star(tl, blksz)
583   return E.encrypt(a)
584
585 def pmac2(E, m):
586   blksz = E.__class__.blksz
587   p = prim(8*blksz)
588   L = E.encrypt(Z(blksz))
589   o = mul_blk_gf(L, C.GF(10), p)
590   a = Z(blksz)
591   v, tl = blocks(m, blksz)
592   for x in v:
593     a ^= E.encrypt(x ^ o)
594     o = mul_blk_gf(o, C.GF(2), p)
595   if len(tl) == blksz: a ^= tl ^ mul_blk_gf(o, C.GF(3), p)
596   else: a ^= pad10star(tl, blksz) ^ mul_blk_gf(o, C.GF(5), p)
597   return E.encrypt(a)
598
599 def ocb3_masks(E):
600   Lgamma, _ = ocb_masks(E)
601   Lstar = Lgamma[0]
602   Ldollar = Lgamma[1]
603   return Lstar, Ldollar, Lgamma[2:]
604
605 def dump_ocb3(E):
606   Lstar, Ldollar, Lgamma = ocb3_masks(E)
607   print 'L_* = %s' % hex(Lstar)
608   print 'L_$ = %s' % hex(Ldollar)
609   for i, lg in enumerate(Lgamma[:16]):
610     print 'L x^%d = %s' % (i, hex(lg))
611
612 def pmac3(E, m):
613   ## Note that `PMAC3' is /not/ a secure MAC.  It depends on other parts of
614   ## OCB3 to prevent a rather easy linear-algebra attack.
615   blksz = E.__class__.blksz
616   Lstar, Ldollar, Lgamma = ocb3_masks(E)
617   a = o = Z(blksz)
618   i = 0
619   v, tl = blocks0(m, blksz)
620   for x in v:
621     i += 1
622     b = ntz(i)
623     o ^= Lgamma[b]
624     a ^= E.encrypt(x ^ o)
625     if VERBOSE:
626       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
627       print 'A[%d]: %s' % (i, hex(a))
628   if tl:
629     o ^= Lstar
630     a ^= E.encrypt(pad10star(tl, blksz) ^ o)
631     if VERBOSE:
632       print 'Z[%d]: * -> %s' % (i, hex(o))
633       print 'A[%d]: %s' % (i, hex(a))
634   return a
635
636 def pmac1_pub(E, m):
637   if VERBOSE: dump_ocb(E)
638   return pmac1(E, m),
639
640 def pmacgen(bc):
641   return [(0,), (1,),
642           (3*bc.blksz,),
643           (3*bc.blksz - 5,)]
644
645 ###--------------------------------------------------------------------------
646 ### OCB.
647
648 def ocb1enc(E, n, h, m, tsz = None):
649   ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
650   ## Associated-Data'.
651   blksz = E.__class__.blksz
652   if VERBOSE: dump_ocb(E)
653   Lgamma, Lxinv = ocb_masks(E)
654   if tsz is None: tsz = blksz
655   a = Z(blksz)
656   o = E.encrypt(n ^ Lgamma[0])
657   if VERBOSE: print 'R = %s' % hex(o)
658   i = 0
659   y = C.WriteBuffer()
660   v, tl = blocks(m, blksz)
661   for x in v:
662     i += 1
663     b = ntz(i)
664     o ^= Lgamma[b]
665     a ^= x
666     if VERBOSE:
667       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
668       print 'A[%d]: %s' % (i, hex(a))
669     y.put(E.encrypt(x ^ o) ^ o)
670   i += 1
671   b = ntz(i)
672   o ^= Lgamma[b]
673   n = len(tl)
674   if VERBOSE:
675     print 'Z[%d]: %d -> %s' % (i, b, hex(o))
676     print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
677   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
678   cfinal = tl ^ yfinal[:n]
679   a ^= o ^ (tl + yfinal[n:])
680   y.put(cfinal)
681   t = E.encrypt(a)
682   if h: t ^= pmac1(E, h)
683   return C.ByteString(y), C.ByteString(t[:tsz])
684
685 def ocb1dec(E, n, h, y, t):
686   ## This is OCB1.PMAC1 from Rogaway's `Authenticated-Encryption with
687   ## Associated-Data'.
688   blksz = E.__class__.blksz
689   if VERBOSE: dump_ocb(E)
690   Lgamma, Lxinv = ocb_masks(E)
691   a = Z(blksz)
692   o = E.encrypt(n ^ Lgamma[0])
693   if VERBOSE: print 'R = %s' % hex(o)
694   i = 0
695   m = C.WriteBuffer()
696   v, tl = blocks(y, blksz)
697   for x in v:
698     i += 1
699     b = ntz(i)
700     o ^= Lgamma[b]
701     if VERBOSE:
702       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
703       print 'A[%d]: %s' % (i, hex(a))
704     u = E.decrypt(x ^ o) ^ o
705     m.put(u)
706     a ^= u
707   i += 1
708   b = ntz(i)
709   o ^= Lgamma[b]
710   n = len(tl)
711   if VERBOSE:
712     print 'Z[%d]: %d -> %s' % (i, b, hex(o))
713     print 'LEN = %s' % hex(C.MP(8*n).storeb(blksz))
714   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ Lxinv ^ o)
715   mfinal = tl ^ yfinal[:n]
716   a ^= o ^ (mfinal + yfinal[n:])
717   m.put(mfinal)
718   u = E.encrypt(a)
719   if h: u ^= pmac1(E, h)
720   if t == u[:len(t)]: return C.ByteString(m),
721   else: return None,
722
723 def ocb2enc(E, n, h, m, tsz = None):
724   ## For OCB2, it's important for security that n = log_x (x + 1) is large in
725   ## the field representations of GF(2^w) used -- in fact, we need more, that
726   ## i n (mod 2^w - 1) is large for i in {4, -3, -2, -1, 1, 2, 3, 4}.  The
727   ## original paper lists the values for 64 and 128, but we support other
728   ## block sizes, so here's the result of the (rather large, in some cases)
729   ## computation.
730   ##
731   ## Block size           log_x (x + 1)
732   ##
733   ##       64             9686038906114705801
734   ##       96             63214690573408919568138788065
735   ##      128             338793687469689340204974836150077311399
736   ##      192             161110085006042185925119981866940491651092686475226538785
737   ##      256             22928580326165511958494515843249267194111962539778797914076675796261938307298
738
739   blksz = E.__class__.blksz
740   if tsz is None: tsz = blksz
741   p = prim(8*blksz)
742   L = E.encrypt(n)
743   o = mul_blk_gf(L, C.GF(2), p)
744   a = Z(blksz)
745   v, tl = blocks(m, blksz)
746   y = C.WriteBuffer()
747   for x in v:
748     a ^= x
749     y.put(E.encrypt(x ^ o) ^ o)
750     o = mul_blk_gf(o, C.GF(2), p)
751   n = len(tl)
752   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o)
753   cfinal = tl ^ yfinal[:n]
754   a ^= (tl + yfinal[n:]) ^ mul_blk_gf(o, C.GF(3), p)
755   y.put(cfinal)
756   t = E.encrypt(a)
757   if h: t ^= pmac2(E, h)
758   return C.ByteString(y), C.ByteString(t[:tsz])
759
760 def ocb2dec(E, n, h, y, t):
761   blksz = E.__class__.blksz
762   p = prim(8*blksz)
763   L = E.encrypt(n)
764   o = mul_blk_gf(L, C.GF(2), p)
765   a = Z(blksz)
766   v, tl = blocks(y, blksz)
767   m = C.WriteBuffer()
768   for x in v:
769     u = E.encrypt(x ^ o) ^ o
770     y.put(u)
771     a ^= u
772     o = mul_blk_gf(o, C.GF(2), p)
773   n = len(tl)
774   yfinal = E.encrypt(C.MP(8*n).storeb(blksz) ^ o)
775   mfinal = tl ^ yfinal[:n]
776   a ^= (mfinal + yfinal[n:]) ^ mul_blk_gf(o, C.GF(3), p)
777   m.put(mfinal)
778   u = E.encrypt(a)
779   if h: u ^= pmac2(E, h)
780   if t == u[:len(t)]: return C.ByteString(m),
781   else: return None,
782
783 OCB3_STRETCH = {  4: ( 4,  17),
784                   8: ( 5,  25),
785                  12: ( 6,  33),
786                  16: ( 6,   8),
787                  24: ( 7,  40),
788                  32: ( 8,   1),
789                  48: ( 8,  80),
790                  64: ( 8, 176),
791                  96: ( 9, 160),
792                 128: ( 9, 352),
793                 200: (10, 192) }
794
795 def ocb3nonce(E, n, tsz):
796
797   ## Figure out how much we need to glue onto the nonce.  This ends up being
798   ## [t mod w]_v || 0^p || 1 || N, where w is the block size in bits, t is
799   ## the tag length in bits, v = floor(log_2(w - 1)) + 1, and p = w - l(N) -
800   ## v - 1.  But this is an annoying way to think about it because of the
801   ## byte misalignment.  Instead, think of it as a byte-aligned prefix
802   ## encoding the tag and an `is the nonce full-length' flag, followed by
803   ## optional padding, and then the nonce:
804   ##
805   ##    F || N                  if l(N) = w - f
806   ##    F || 0^p || 1 || N      otherwise
807   ##
808   ## where F is [t mod w]_v || 0^{f-v-1} || b; f = floor(log_2(w - 1)) + 2;
809   ## b is 1 if l(N) = w - f, or 0 otherwise; and p = w - f - l(N) - 1.
810   blksz = E.__class__.blksz
811   tszbits = min(C.MP(8*blksz - 1).nbits, 8)
812   fwd = tszbits/8 + 1
813   f = 8*(tsz%blksz) << + 8*fwd - tszbits
814
815   ## Form the augmented nonce.
816   nb = C.WriteBuffer()
817   nsz, nwd = len(n), blksz - fwd
818   if nsz == nwd: f |= 1
819   nb.put(C.MP(f).storeb(fwd))
820   if nsz < nwd: nb.zero(nwd - nsz - 1).putu8(1)
821   nb.put(n)
822   nn = C.ByteString(nb)
823   if VERBOSE: print 'aug-nonce = %s' % hex(nn)
824
825   ## Calculate the initial offset.
826   split, shift = OCB3_STRETCH[blksz]
827   t2pw = C.MP(0).setbit(8*blksz) - 1
828   lomask = (C.MP(0).setbit(split) - 1)
829   himask = ~lomask
830   top, bottom = nn&himask.storeb2c(blksz), C.MP.loadb(nn)&lomask
831   ktop = C.MP.loadb(E.encrypt(top))
832   stretch = (ktop << 8*blksz) | (ktop ^ (ktop << shift)&t2pw)
833   o = (stretch >> 8*blksz - bottom).storeb(blksz)
834   if VERBOSE:
835     print 'stretch = %s' % hex(stretch.storeb(2*blksz))
836     print 'Z[0] = %s' % hex(o)
837
838   return o
839
840 def ocb3enc(E, n, h, m, tsz = None):
841   blksz = E.__class__.blksz
842   if tsz is None: tsz = blksz
843   Lstar, Ldollar, Lgamma = ocb3_masks(E)
844   if VERBOSE: dump_ocb3(E)
845
846   ## Set things up.
847   o = ocb3nonce(E, n, tsz)
848   a = C.ByteString.zero(blksz)
849
850   ## Split the message into blocks.
851   i = 0
852   y = C.WriteBuffer()
853   v, tl = blocks0(m, blksz)
854   for x in v:
855     i += 1
856     b = ntz(i)
857     o ^= Lgamma[b]
858     a ^= x
859     if VERBOSE:
860       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
861       print 'A[%d]: %s' % (i, hex(a))
862     y.put(E.encrypt(x ^ o) ^ o)
863   if tl:
864     o ^= Lstar
865     n = len(tl)
866     pad = E.encrypt(o)
867     a ^= pad10star(tl, blksz)
868     if VERBOSE:
869       print 'Z[%d]: * -> %s' % (i, hex(o))
870       print 'A[%d]: %s' % (i, hex(a))
871     y.put(tl ^ pad[0:n])
872   o ^= Ldollar
873   t = E.encrypt(a ^ o) ^ pmac3(E, h)
874   return C.ByteString(y), C.ByteString(t[:tsz])
875
876 def ocb3dec(E, n, h, y, t):
877   blksz = E.__class__.blksz
878   tsz = len(t)
879   Lstar, Ldollar, Lgamma = ocb3_masks(E)
880   if VERBOSE: dump_ocb3(E)
881
882   ## Set things up.
883   o = ocb3nonce(E, n, tsz)
884   a = C.ByteString.zero(blksz)
885
886   ## Split the message into blocks.
887   i = 0
888   m = C.WriteBuffer()
889   v, tl = blocks0(y, blksz)
890   for x in v:
891     i += 1
892     b = ntz(i)
893     o ^= Lgamma[b]
894     if VERBOSE:
895       print 'Z[%d]: %d -> %s' % (i, b, hex(o))
896       print 'A[%d]: %s' % (i, hex(a))
897     u = E.encrypt(x ^ o) ^ o
898     m.put(u)
899     a ^= u
900   if tl:
901     o ^= Lstar
902     n = len(tl)
903     pad = E.encrypt(o)
904     if VERBOSE:
905       print 'Z[%d]: * -> %s' % (i, hex(o))
906       print 'A[%d]: %s' % (i, hex(a))
907     u = tl ^ pad[0:n]
908     m.put(u)
909     a ^= pad10star(u, blksz)
910   o ^= Ldollar
911   u = E.encrypt(a ^ o) ^ pmac3(E, h)
912   if t == u[:tsz]: return C.ByteString(m),
913   else: return None,
914
915 def ocbgen(bc):
916   w = bc.blksz
917   return [(w, 0, 0), (w, 1, 0), (w, 0, 1),
918           (w, 0, 3*w),
919           (w, 3*w, 3*w),
920           (w, 0, 3*w + 5),
921           (w, 3*w - 5, 3*w + 5)]
922
923 def ocb3gen(bc):
924   w = bc.blksz
925   return [(w - 2, 0, 0), (w - 2, 1, 0), (w - 2, 0, 1),
926           (w - 5, 0, 3*w),
927           (w - 3, 3*w, 3*w),
928           (w - 2, 0, 3*w + 5),
929           (w - 2, 3*w - 5, 3*w + 5)]
930
931 def ocb3_mct(bc, ksz, tsz):
932   k = C.ByteString(C.WriteBuffer().zero(ksz - 4).putu32(8*tsz))
933   E = bc(k)
934   n = C.MP(1)
935   nw = bc.blksz - 4
936   cbuf = C.WriteBuffer()
937   for i in xrange(128):
938     s = C.ByteString.zero(i)
939     y, t = ocb3enc(E, n.storeb(nw), s, s, tsz); n += 1; cbuf.put(y).put(t)
940     y, t = ocb3enc(E, n.storeb(nw), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
941     y, t = ocb3enc(E, n.storeb(nw), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
942   _, t = ocb3enc(E, n.storeb(nw), C.ByteString(cbuf), EMPTY, tsz)
943   print hex(t)
944
945 def ocb3_mct2(bc):
946   k = C.bytes('000102030405060708090a0b0c0d0e0f')
947   E = bc(k)
948   tsz = min(E.blksz, 32)
949   n = C.MP(1)
950   cbuf = C.WriteBuffer()
951   for i in xrange(128):
952     sbuf = C.WriteBuffer()
953     for j in xrange(i): sbuf.putu8(j)
954     s = C.ByteString(sbuf)
955     y, t = ocb3enc(E, n.storeb(2), s, s, tsz); n += 1; cbuf.put(y).put(t)
956     y, t = ocb3enc(E, n.storeb(2), EMPTY, s, tsz); n += 1; cbuf.put(y).put(t)
957     y, t = ocb3enc(E, n.storeb(2), s, EMPTY, tsz); n += 1; cbuf.put(y).put(t)
958   _, t = ocb3enc(E, n.storeb(2), C.ByteString(cbuf), EMPTY, tsz)
959   print hex(t)
960
961 ###--------------------------------------------------------------------------
962 ### Main program.
963
964 class struct (object):
965   def __init__(me, **kw):
966     me.__dict__.update(kw)
967
968 binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
969 intarg = struct(mk = lambda x: x, parse = int, show = None)
970
971 MODEMAP = { 'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc),
972             'eax-dec': (dummygen, 4*[binarg], eaxdec),
973             'ccm-enc': (ccmgen, 3*[binarg] + [intarg], ccmenc),
974             'ccm-dec': (dummygen, 4*[binarg], ccmdec),
975             'cmac': (cmacgen, [binarg], cmac),
976             'gcm-enc': (gcmgen, 3*[binarg] + [intarg], gcmenc),
977             'gcm-dec': (dummygen, 4*[binarg], gcmdec),
978             'ocb1-enc': (ocbgen, 3*[binarg] + [intarg], ocb1enc),
979             'ocb1-dec': (dummygen, 4*[binarg], ocb1dec),
980             'ocb2-enc': (ocbgen, 3*[binarg] + [intarg], ocb2enc),
981             'ocb2-dec': (dummygen, 4*[binarg], ocb2dec),
982             'ocb3-enc': (ocb3gen, 3*[binarg] + [intarg], ocb3enc),
983             'ocb3-dec': (dummygen, 4*[binarg], ocb3dec),
984             'pmac1': (pmacgen, [binarg], pmac1_pub) }
985
986 mode = argv[1]
987 if len(argv) == 3 and mode == 'gcm-mul':
988   VERBOSE = False
989   nbits = int(argv[2])
990   gcm_mul_tests(nbits)
991   exit(0)
992 bc = None
993 for d in CUSTOM, C.gcprps:
994   try: bc = d[argv[2]]
995   except KeyError: pass
996   else: break
997 if bc is None: raise KeyError, argv[2]
998 if len(argv) == 5 and mode == 'ocb3-mct':
999   VERBOSE = False
1000   ksz, tsz = int(argv[3]), int(argv[4])
1001   ocb3_mct(bc, ksz, tsz)
1002   exit(0)
1003 if len(argv) == 3 and mode == 'ocb3-mct2':
1004   VERBOSE = False
1005   ocb3_mct2(bc)
1006   exit(0)
1007 if len(argv) == 3:
1008   VERBOSE = False
1009   gen, argty, func = MODEMAP[mode]
1010   if mode.endswith('-enc'): mode = mode[:-4]
1011   print '%s-%s {' % (bc.name, mode)
1012   for ksz in keylens(bc.keysz):
1013     for argvals in gen(bc):
1014       k = R.block(ksz)
1015       args = [t.mk(a) for t, a in izip(argty, argvals)]
1016       rets = func(bc(k), *args)
1017       print '  %s' % safehex(k)
1018       for t, a in izip(argty, args):
1019         if t.show: print '    %s' % t.show(a)
1020       for r, lastp in with_lastp(rets):
1021         print '    %s%s' % (safehex(r), lastp and ';' or '')
1022   print '}'
1023 else:
1024   VERBOSE = True
1025   k = C.bytes(argv[3])
1026   gen, argty, func = MODEMAP[mode]
1027   args = [t.parse(a) for t, a in izip(argty, argv[4:])]
1028   rets = func(bc(k), *args)
1029   for r in rets:
1030     if r is None: print "X"
1031     else: print hex(r)
1032
1033 ###----- That's all, folks --------------------------------------------------