chiark / gitweb /
symm/gcm.h, symm/gcm-def.h: Implement the GCM authenticated encryption mode.
[catacomb] / utils / advmodes
1 #! /usr/bin/python
2
3 from sys import argv
4 from struct import unpack, pack
5 from itertools import izip
6 import catacomb as C
7
8 R = C.FibRand(0)
9
10 ###--------------------------------------------------------------------------
11 ### Utilities.
12
13 def combs(things, k):
14   ii = range(k)
15   n = len(things)
16   while True:
17     yield [things[i] for i in ii]
18     for j in xrange(k):
19       if j == k - 1: lim = n
20       else: lim = ii[j + 1]
21       i = ii[j] + 1
22       if i < lim:
23         ii[j] = i
24         break
25       ii[j] = j
26     else:
27       return
28
29 POLYMAP = {}
30
31 def poly(nbits):
32   try: return POLYMAP[nbits]
33   except KeyError: pass
34   base = C.GF(0).setbit(nbits).setbit(0)
35   for k in xrange(1, nbits, 2):
36     for cc in combs(range(1, nbits), k):
37       p = base + sum(C.GF(0).setbit(c) for c in cc)
38       if p.irreduciblep(): POLYMAP[nbits] = p; return p
39   raise ValueError, nbits
40
41 def Z(n):
42   return C.ByteString.zero(n)
43
44 def mul_blk_gf(m, x, p): return ((C.GF.loadb(m)*x)%p).storeb((p.nbits + 6)/8)
45
46 def with_lastp(it):
47   it = iter(it)
48   try: j = next(it)
49   except StopIteration: raise ValueError, 'empty iter'
50   lastp = False
51   while not lastp:
52     i = j
53     try: j = next(it)
54     except StopIteration: lastp = True
55     yield i, lastp
56
57 def safehex(x):
58   if len(x): return hex(x)
59   else: return '""'
60
61 def keylens(ksz):
62   sel = []
63   if isinstance(ksz, C.KeySZSet): kk = ksz.set
64   elif isinstance(ksz, C.KeySZRange): kk = range(ksz.min, ksz.max, ksz.mod)
65   elif isinstance(ksz, C.KeySZAny): kk = range(64); sel = [0]
66   kk = list(kk); kk = kk[:]
67   n = len(kk)
68   while n and len(sel) < 4:
69     i = R.range(n)
70     n -= 1
71     kk[i], kk[n] = kk[n], kk[i]
72     sel.append(kk[n])
73   return sel
74
75 def pad0star(m, w):
76   n = len(m)
77   if not n: r = w
78   else: r = (-len(m))%w
79   if r: m += Z(r)
80   return C.ByteString(m)
81
82 def pad10star(m, w):
83   r = w - len(m)%w
84   if r: m += '\x80' + Z(r - 1)
85   return C.ByteString(m)
86
87 def ntz(i):
88   j = 0
89   while (i&1) == 0: i >>= 1; j += 1
90   return j
91
92 def blocks(x, w):
93   v, i, n = [], 0, len(x)
94   while n - i > w:
95     v.append(C.ByteString(x[i:i + w]))
96     i += w
97   return v, C.ByteString(x[i:])
98
99 EMPTY = C.bytes('')
100
101 def blocks0(x, w):
102   v, tl = blocks(x, w)
103   if len(tl) == w: v.append(tl); tl = EMPTY
104   return v, tl
105
106 def dummygen(bc): return []
107
108 CUSTOM = {}
109
110 ###--------------------------------------------------------------------------
111 ### RC6.
112
113 class RC6Cipher (type):
114   def __new__(cls, w, r):
115     name = 'rc6-%d/%d' % (w, r)
116     me = type(name, (RC6Base,), {})
117     me.name = name
118     me.r = r
119     me.w = w
120     me.blksz = w/2
121     me.keysz = C.KeySZRange(me.blksz, 1, 255, 1)
122     return me
123
124 def rotw(w):
125   return w.bit_length() - 1
126
127 def rol(w, x, n):
128   m0, m1 = C.MP(0).setbit(w - n) - 1, C.MP(0).setbit(n) - 1
129   return ((x&m0) << n) | (x >> (w - n))&m1
130
131 def ror(w, x, n):
132   m0, m1 = C.MP(0).setbit(n) - 1, C.MP(0).setbit(w - n) - 1
133   return ((x&m0) << (w - n)) | (x >> n)&m1
134
135 class RC6Base (object):
136
137   ## Magic constants.
138   P400 = C.MP(0xb7e151628aed2a6abf7158809cf4f3c762e7160f38b4da56a784d9045190cfef324e7738926cfbe5f4bf8d8d8c31d763da06)
139   Q400 = C.MP(0x9e3779b97f4a7c15f39cc0605cedc8341082276bf3a27251f86c6a11d0c18e952767f0b153d27b7f0347045b5bf1827f0188)
140
141   def __init__(me, k):
142
143     ## Build the magic numbers.
144     P = me.P400 >> (400 - me.w)
145     if P%2 == 0: P += 1
146     Q = me.Q400 >> (400 - me.w)
147     if Q%2 == 0: Q += 1
148     M = C.MP(0).setbit(me.w) - 1
149
150     ## Convert the key into words.
151     wb = me.w/8
152     c = (len(k) + wb - 1)/wb
153     kb, ktl = blocks(k, me.w/8)
154     L = map(C.MP.loadl, kb + [ktl])
155     assert c == len(L)
156
157     ## Build the subkey table.
158     me.d = rotw(me.w)
159     n = 2*me.r + 4
160     S = [(P + i*Q)&M for i in xrange(n)]
161
162     ##for j in xrange(c):
163     ##  print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
164     ##for i in xrange(n):
165     ##  print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
166
167     i = j = 0
168     A = B = C.MP(0)
169
170     for s in xrange(3*max(c, n)):
171       A = S[i] = rol(me.w, S[i] + A + B, 3)
172       B = L[j] = rol(me.w, L[j] + A + B, (A + B)%(1 << me.d))
173       ##print 'S[%3d] = %s' % (i, hex(S[i]).upper()[2:].rjust(2*wb, '0'))
174       ##print 'L[%3d] = %s' % (j, hex(L[j]).upper()[2:].rjust(2*wb, '0'))
175       i = (i + 1)%n
176       j = (j + 1)%c
177
178     ## Done.
179     me.s = S
180
181   def encrypt(me, x):
182     M = C.MP(0).setbit(me.w) - 1
183     a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4)[0])
184     b = (b + me.s[0])&M
185     d = (d + me.s[1])&M
186     ##print 'B = %s' % (hex(b).upper()[2:].rjust(me.w/4, '0'))
187     ##print 'D = %s' % (hex(d).upper()[2:].rjust(me.w/4, '0'))
188     for i in xrange(2, 2*me.r + 2, 2):
189       t = rol(me.w, 2*b*b + b, me.d)
190       u = rol(me.w, 2*d*d + d, me.d)
191       a = (rol(me.w, a ^ t, u%(1 << me.d)) + me.s[i + 0])&M
192       c = (rol(me.w, c ^ u, t%(1 << me.d)) + me.s[i + 1])&M
193       ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
194       ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
195       a, b, c, d = b, c, d, a
196     a = (a + me.s[2*me.r + 2])&M
197     c = (c + me.s[2*me.r + 3])&M
198     ##print 'A = %s' % (hex(a).upper()[2:].rjust(me.w/4, '0'))
199     ##print 'C = %s' % (hex(c).upper()[2:].rjust(me.w/4, '0'))
200     return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
201                         c.storel(me.blksz/4) + d.storel(me.blksz/4))
202
203   def decrypt(me, x):
204     M = C.MP(0).setbit(me.w) - 1
205     a, b, c, d = map(C.MP.loadl, blocks0(x, me.blksz/4))
206     c = (c - me.s[2*me.r + 3])&M
207     a = (a - me.s[2*me.r + 2])&M
208     for i in xrange(2*me.r + 1, 1, -2):
209       a, b, c, d = d, a, b, c
210       u = rol(me.w, 2*d*d + d, me.d)
211       t = rol(me.w, 2*b*b + b, me.d)
212       c = ror(me.w, (c - me.s[i + 1])&M, t%(1 << me.d)) ^ u
213       a = ror(me.w, (a - me.s[i + 0])&M, u%(1 << me.d)) ^ t
214     a = (a + s[2*me.r + 2])&M
215     c = (c + s[2*me.r + 3])&M
216     return C.ByteString(a.storel(me.blksz/4) + b.storel(me.blksz/4) +
217                         c.storel(me.blksz/4) + d.storel(me.blksz/4))
218
219 for (w, r) in [(8, 16), (16, 16), (24, 16), (32, 16),
220                (32, 20), (48, 16), (64, 16), (96, 16), (128, 16),
221                (192, 16), (256, 16), (400, 16)]:
222   CUSTOM['rc6-%d/%d' % (w, r)] = RC6Cipher(w, r)
223
224 ###--------------------------------------------------------------------------
225 ### OMAC (or CMAC).
226
227 def omac_masks(E):
228   blksz = E.__class__.blksz
229   p = poly(8*blksz)
230   z = Z(blksz)
231   L = E.encrypt(z)
232   m0 = mul_blk_gf(L, 2, p)
233   m1 = mul_blk_gf(m0, 2, p)
234   return m0, m1
235
236 def dump_omac(E):
237   blksz = E.__class__.blksz
238   m0, m1 = omac_masks(E)
239   print 'L = %s' % hex(E.encrypt(Z(blksz)))
240   print 'm0 = %s' % hex(m0)
241   print 'm1 = %s' % hex(m1)
242   for t in xrange(3):
243     print 'v%d = %s' % (t, hex(E.encrypt(C.MP(t).storeb(blksz))))
244     print 'z%d = %s' % (t, hex(omac(E, t, '')))
245
246 def omac(E, t, m):
247   blksz = E.__class__.blksz
248   m0, m1 = omac_masks(E)
249   a = Z(blksz)
250   if t is not None: m = C.MP(t).storeb(blksz) + m
251   v, tl = blocks(m, blksz)
252   for x in v: a = E.encrypt(a ^ x)
253   r = blksz - len(tl)
254   if r == 0:
255     a = E.encrypt(a ^ tl ^ m0)
256   else:
257     pad = pad10star(tl, blksz)
258     a = E.encrypt(a ^ pad ^ m1)
259   return a
260
261 def cmac(E, m):
262   if VERBOSE: dump_omac(E)
263   return omac(E, None, m),
264
265 def cmacgen(bc):
266   return [(0,), (1,),
267           (3*bc.blksz,),
268           (3*bc.blksz - 5,)]
269
270 ###--------------------------------------------------------------------------
271 ### Counter mode.
272
273 def ctr(E, m, c0):
274   blksz = E.__class__.blksz
275   y = C.WriteBuffer()
276   c = C.MP.loadb(c0)
277   while y.size < len(m):
278     y.put(E.encrypt(c.storeb(blksz)))
279     c += 1
280   return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
281
282 ###--------------------------------------------------------------------------
283 ### GCM.
284
285 def gcm_mangle(x):
286   y = C.WriteBuffer()
287   for b in x:
288     b = ord(b)
289     bb = 0
290     for i in xrange(8):
291       bb <<= 1
292       if b&1: bb |= 1
293       b >>= 1
294     y.putu8(bb)
295   return C.ByteString(y)
296
297 def gcm_mul(x, y):
298   w = len(x)
299   p = poly(8*w)
300   u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y))
301   z = (u*v)%p
302   return gcm_mangle(z.storel(w))
303
304 def gcm_pow(x, n):
305   w = len(x)
306   p = poly(8*w)
307   u = C.GF.loadl(gcm_mangle(x))
308   z = pow(u, n, p)
309   return gcm_mangle(z.storel(w))
310
311 def gcm_ctr(E, m, c0):
312   y = C.WriteBuffer()
313   pre = c0[:-4]
314   c, = unpack('>L', c0[-4:])
315   while y.size < len(m):
316     c += 1
317     y.put(E.encrypt(pre + pack('>L', c)))
318   return C.ByteString(m) ^ C.ByteString(y)[:len(m)]
319
320 def g(what, x, m, a0 = None):
321   n = len(x)
322   if a0 is None: a = Z(n)
323   else: a = a0
324   i = 0
325   for b in blocks0(m, n)[0]:
326     a = gcm_mul(a ^ b, x)
327     if VERBOSE: print '%s[%d] = %s -> %s' % (what, i, hex(b), hex(a))
328     i += 1
329   return a
330
331 def gcm_pad(w, x):
332   return C.ByteString(x + Z(-len(x)%w))
333
334 def gcm_lens(w, a, b):
335   if w < 12: n = w
336   else: n = w/2
337   return C.ByteString(C.MP(a).storeb(n) + C.MP(b).storeb(n))
338
339 def ghash(whata, whatb, x, a, b):
340   w = len(x)
341   ha = g(whata, x, gcm_pad(w, a))
342   hb = g(whatb, x, gcm_pad(w, b))
343   if a:
344     hc = gcm_mul(ha, gcm_pow(x, (len(b) + w - 1)/w)) ^ hb
345     if VERBOSE: print '%s || %s -> %s' % (whata, whatb, hex(hc))
346   else:
347     hc = hb
348   return g(whatb, x, gcm_lens(w, 8*len(a), 8*len(b)), hc)
349
350 def gcmenc(E, n, h, m, tsz = None):
351   w = E.__class__.blksz
352   x = E.encrypt(Z(w))
353   if VERBOSE: print 'x = %s' % hex(x)
354   if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
355   else: c0 = ghash('?', 'n', x, EMPTY, n)
356   if VERBOSE: print 'c0 = %s' % hex(c0)
357   y = gcm_ctr(E, m, c0)
358   t = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
359   return y, t
360
361 def gcmdec(E, n, h, y, t):
362   w = E.__class__.blksz
363   x = E.encrypt(Z(w))
364   if VERBOSE: print 'x = %s' % hex(x)
365   if len(n) + 4 == w: c0 = C.ByteString(n + pack('>L', 1))
366   else: c0 = ghash('?', 'n', x, EMPTY, n)
367   if VERBOSE: print 'c0 = %s' % hex(c0)
368   m = gcm_ctr(E, y, c0)
369   tt = ghash('h', 'y', x, h, y) ^ E.encrypt(c0)
370   if t == tt: return m,
371   else: return None,
372
373 def gcmgen(bc):
374   return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
375           (bc.blksz, 3*bc.blksz, 3*bc.blksz),
376           (bc.blksz - 4, bc.blksz + 3, 3*bc.blksz + 9),
377           (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
378
379 ###--------------------------------------------------------------------------
380 ### EAX.
381
382 def eaxenc(E, n, h, m, tsz = None):
383   if VERBOSE:
384     print 'k = %s' % hex(k)
385     print 'n = %s' % hex(n)
386     print 'h = %s' % hex(h)
387     print 'm = %s' % hex(m)
388     dump_omac(E)
389   if tsz is None: tsz = E.__class__.blksz
390   c0 = omac(E, 0, n)
391   y = ctr(E, m, c0)
392   ht = omac(E, 1, h)
393   yt = omac(E, 2, y)
394   if VERBOSE:
395     print 'c0 = %s' % hex(c0)
396     print 'ht = %s' % hex(ht)
397     print 'yt = %s' % hex(yt)
398   return y, C.ByteString((c0 ^ ht ^ yt)[:tsz])
399
400 def eaxdec(E, n, h, y, t):
401   if VERBOSE:
402     print 'k = %s' % hex(k)
403     print 'n = %s' % hex(n)
404     print 'h = %s' % hex(h)
405     print 'y = %s' % hex(y)
406     print 't = %s' % hex(t)
407     dump_omac(E)
408   c0 = omac(E, 0, n)
409   m = ctr(E, y, c0)
410   ht = omac(E, 1, h)
411   yt = omac(E, 2, y)
412   if VERBOSE:
413     print 'c0 = %s' % hex(c0)
414     print 'ht = %s' % hex(ht)
415     print 'yt = %s' % hex(yt)
416   if t == (c0 ^ ht ^ yt)[:len(t)]: return m,
417   else: return None,
418
419 def eaxgen(bc):
420   return [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
421           (bc.blksz, 3*bc.blksz, 3*bc.blksz),
422           (bc.blksz - 1, 3*bc.blksz - 5, 3*bc.blksz + 5)]
423
424 ###--------------------------------------------------------------------------
425 ### Main program.
426
427 class struct (object):
428   def __init__(me, **kw):
429     me.__dict__.update(kw)
430
431 binarg = struct(mk = R.block, parse = C.bytes, show = safehex)
432 intarg = struct(mk = lambda x: x, parse = int, show = None)
433
434 MODEMAP = { 'eax-enc': (eaxgen, 3*[binarg] + [intarg], eaxenc),
435             'eax-dec': (dummygen, 4*[binarg], eaxdec),
436             'cmac': (cmacgen, [binarg], cmac),
437             'gcm-enc': (gcmgen, 3*[binarg] + [intarg], gcmenc),
438             'gcm-dec': (dummygen, 4*[binarg], gcmdec) }
439
440 mode = argv[1]
441 bc = None
442 for d in CUSTOM, C.gcprps:
443   try: bc = d[argv[2]]
444   except KeyError: pass
445   else: break
446 if bc is None: raise KeyError, argv[2]
447 if len(argv) == 3:
448   VERBOSE = False
449   gen, argty, func = MODEMAP[mode]
450   if mode.endswith('-enc'): mode = mode[:-4]
451   print '%s-%s {' % (bc.name, mode)
452   for ksz in keylens(bc.keysz):
453     for argvals in gen(bc):
454       k = R.block(ksz)
455       args = [t.mk(a) for t, a in izip(argty, argvals)]
456       rets = func(bc(k), *args)
457       print '  %s' % safehex(k)
458       for t, a in izip(argty, args):
459         if t.show: print '    %s' % t.show(a)
460       for r, lastp in with_lastp(rets):
461         print '    %s%s' % (safehex(r), lastp and ';' or '')
462   print '}'
463 else:
464   VERBOSE = True
465   k = C.bytes(argv[3])
466   gen, argty, func = MODEMAP[mode]
467   args = [t.parse(a) for t, a in izip(argty, argv[4:])]
468   rets = func(bc(k), *args)
469   for r in rets:
470     if r is None: print "X"
471     else: print hex(r)
472
473 ###----- That's all, folks --------------------------------------------------