chiark / gitweb /
progs/perftest.c: Use from Glibc syscall numbers.
[catacomb] / utils / gcm-ref
1 #! /usr/bin/python
2 ### -*- coding: utf-8 -*-
3
4 from sys import argv, exit
5
6 import catacomb as C
7
8 ###--------------------------------------------------------------------------
9 ### Random utilities.
10
11 def words(s):
12   """Split S into 32-bit pieces and report their values as hex."""
13   return ' '.join('%08x' % C.MP.loadb(s[i:i + 4])
14                   for i in xrange(0, len(s), 4))
15
16 def words_64(s):
17   """Split S into 64-bit pieces and report their values as hex."""
18   return ' '.join('%016x' % C.MP.loadb(s[i:i + 8])
19                   for i in xrange(0, len(s), 8))
20
21 def repmask(val, wd, n):
22   """Return a mask consisting of N copies of the WD-bit value VAL."""
23   v = C.GF(val)
24   a = C.GF(0)
25   for i in xrange(n): a = (a << wd) | v
26   return a
27
28 def combs(things, k):
29   """Iterate over all possible combinations of K of the THINGS."""
30   ii = range(k)
31   n = len(things)
32   while True:
33     yield [things[i] for i in ii]
34     for j in xrange(k):
35       if j == k - 1: lim = n
36       else: lim = ii[j + 1]
37       i = ii[j] + 1
38       if i < lim:
39         ii[j] = i
40         break
41       ii[j] = j
42     else:
43       return
44
45 POLYMAP = {}
46
47 def poly(nbits):
48   """
49   Return the lexically first irreducible polynomial of degree NBITS of lowest
50   weight.
51   """
52   try: return POLYMAP[nbits]
53   except KeyError: pass
54   base = C.GF(0).setbit(nbits).setbit(0)
55   for k in xrange(1, nbits, 2):
56     for cc in combs(range(1, nbits), k):
57       p = base + sum((C.GF(0).setbit(c) for c in cc), C.GF(0))
58       if p.irreduciblep(): POLYMAP[nbits] = p; return p
59   raise ValueError, nbits
60
61 def gcm_mangle(x):
62   """Flip the bits within each byte according to GCM's insane convention."""
63   y = C.WriteBuffer()
64   for b in x:
65     b = ord(b)
66     bb = 0
67     for i in xrange(8):
68       bb <<= 1
69       if b&1: bb |= 1
70       b >>= 1
71     y.putu8(bb)
72   return y.contents
73
74 def endswap_words_32(x):
75   """End-swap each 32-bit word of X."""
76   x = C.ReadBuffer(x)
77   y = C.WriteBuffer()
78   while x.left: y.putu32l(x.getu32b())
79   return y.contents
80
81 def endswap_words_64(x):
82   """End-swap each 64-bit word of X."""
83   x = C.ReadBuffer(x)
84   y = C.WriteBuffer()
85   while x.left: y.putu64l(x.getu64b())
86   return y.contents
87
88 def endswap_bytes(x):
89   """End-swap X by bytes."""
90   y = C.WriteBuffer()
91   for ch in reversed(x): y.put(ch)
92   return y.contents
93
94 def gfmask(n):
95   return C.GF(C.MP(0).setbit(n) - 1)
96
97 def gcm_mul(x, y):
98   """Multiply X and Y according to the GCM rules."""
99   w = len(x)
100   p = poly(8*w)
101   u, v = C.GF.loadl(gcm_mangle(x)), C.GF.loadl(gcm_mangle(y))
102   z = (u*v)%p
103   return gcm_mangle(z.storel(w))
104
105 DEMOMAP = {}
106 def demo(func):
107   name = func.func_name
108   assert(name.startswith('demo_'))
109   DEMOMAP[name[5:].replace('_', '-')] = func
110   return func
111
112 def iota(i = 0):
113   vi = [i]
114   def next(): vi[0] += 1; return vi[0] - 1
115   return next
116
117 ###--------------------------------------------------------------------------
118 ### Portable table-driven implementation.
119
120 def shift_left(x):
121   """Given a field element X (in external format), return X t."""
122   w = len(x)
123   p = poly(8*w)
124   return gcm_mangle(C.GF.storel((C.GF.loadl(gcm_mangle(x)) << 1)%p))
125
126 def shift_right(x):
127   """Given a field element X (in external format), return X/t."""
128   w = len(x)
129   p = poly(8*w)
130   return gcm_mangle(C.GF.storel((C.GF.loadl(gcm_mangle(x))*p.modinv(2))%p))
131
132 def table_common(u, v, flip, getword, ixmask):
133   """
134   Multiply U by V using table lookup; common for `table-b' and `table-l'.
135
136   This matches the `simple_mulk_...' implementation in `gcm.c'.  One entry
137   per bit is the best we can manage if we want a constant-time
138   implementation: processing n bits at a time means we need to scan
139   (2^n - 1)/n times as much memory.
140
141     * FLIP is a function (assumed to be an involution) on one argument X to
142       convert X from external format to table-entry format or back again.
143
144     * GETWORD is a function on one argument B to retrieve the next 32-bit
145       chunk of a field element held in a `ReadBuffer'.  Bits within a word
146       are processed most-significant first.
147
148     * IXMASK is a mask XORed into table indices to permute the table so that
149       its order matches that induced by GETWORD.
150
151   The table is built such that tab[i XOR IXMASK] = U t^i.
152   """
153   w = len(u); assert(w == len(v))
154   a = C.ByteString.zero(w)
155   tab = [None]*(8*w)
156   for i in xrange(8*w):
157     print ';; %9s = %7s = %s' % ('utab[%d]' % i, 'u t^%d' % i, words(u))
158     tab[i ^ ixmask] = flip(u)
159     u = shift_left(u)
160   v = C.ReadBuffer(v)
161   i = 0
162   while v.left:
163     t = getword(v)
164     for j in xrange(32):
165       bit = (t >> 31)&1
166       if bit: a ^= tab[i]
167       print ';; %6s = %d: a <- %s [%9s = %s]' % \
168         ('v[%d]' % (i ^ ixmask), bit, words(a),
169          'utab[%d]' % (i ^ ixmask), words(tab[i]))
170       i += 1; t <<= 1
171   return flip(a)
172
173 @demo
174 def demo_table_b(u, v):
175   """Big-endian table lookup."""
176   return table_common(u, v, lambda x: x, lambda b: b.getu32b(), 0)
177
178 @demo
179 def demo_table_l(u, v):
180   """Little-endian table lookup."""
181   return table_common(u, v, endswap_words_32, lambda b: b.getu32l(), 0x18)
182
183 ###--------------------------------------------------------------------------
184 ### Implementation using 64×64->128-bit binary polynomial multiplication.
185
186 _i = iota()
187 TAG_INPUT_U = _i()
188 TAG_INPUT_V = _i()
189 TAG_SHIFTED_V = _i()
190 TAG_KPIECE_U = _i()
191 TAG_KPIECE_V = _i()
192 TAG_PRODPIECE = _i()
193 TAG_PRODSUM = _i()
194 TAG_PRODUCT = _i()
195 TAG_REDCBITS = _i()
196 TAG_REDCFULL = _i()
197 TAG_REDCMIX = _i()
198 TAG_OUTPUT = _i()
199
200 def split_gf(x, n):
201   n /= 8
202   return [C.GF.loadb(x[i:i + n]) for i in xrange(0, len(x), n)]
203
204 def join_gf(xx, n):
205   x = C.GF(0)
206   for i in xrange(len(xx)): x = (x << n) | xx[i]
207   return x
208
209 def present_gf(x, w, n, what):
210   firstp = True
211   m = gfmask(n)
212   for i in xrange(0, w, 128):
213     print ';; %12s%c         =%s' % \
214       (firstp and what or '',
215        firstp and ':' or ' ',
216        ''.join([j < w
217                 and '          0x%s' % hex(((x >> j)&m).storeb(n/8))
218                 or ''
219                 for j in xrange(i, i + 128, n)]))
220     firstp = False
221
222 def present_gf_pclmul(tag, wd, x, w, n, what):
223   if tag != TAG_PRODPIECE: present_gf(x, w, n, what)
224
225 def reverse(x, w):
226   return C.GF.loadl(x.storeb(w/8))
227
228 def rev32(x):
229   w = x.noctets
230   m_ffff = repmask(0xffff, 32, w/4)
231   m_ff = repmask(0xff, 16, w/2)
232   x = ((x&m_ffff) << 16) | ((x >> 16)&m_ffff)
233   x = ((x&m_ff) << 8) | ((x >> 8)&m_ff)
234   return x
235
236 def rev8(x):
237   w = x.noctets
238   m_0f = repmask(0x0f, 8, w)
239   m_33 = repmask(0x33, 8, w)
240   m_55 = repmask(0x55, 8, w)
241   x = ((x&m_0f) << 4) | ((x >> 4)&m_0f)
242   x = ((x&m_33) << 2) | ((x >> 2)&m_33)
243   x = ((x&m_55) << 1) | ((x >> 1)&m_55)
244   return x
245
246 def present_gf_vmullp64(tag, wd, x, w, n, what):
247   if tag == TAG_PRODPIECE or tag == TAG_REDCFULL:
248     return
249   elif (wd == 128 or wd == 64) and TAG_PRODSUM <= tag <= TAG_PRODUCT:
250     y = x
251   elif (wd == 96 or wd == 192 or wd == 256) and \
252        TAG_PRODSUM <= tag < TAG_OUTPUT:
253     y = x
254   else:
255     xx = x.storeb(w/8)
256     extra = len(xx)%8
257     if extra: xx += C.ByteString.zero(8 - extra)
258     yb = C.WriteBuffer()
259     for i in xrange(len(xx), 0, -8): yb.put(xx[i - 8:i])
260     y = C.GF.loadb(yb.contents)
261   present_gf(y, (w + 63)&~63, n, what)
262
263 def present_gf_pmull(tag, wd, x, w, n, what):
264   if tag == TAG_PRODPIECE or tag == TAG_REDCFULL:
265     return
266   elif tag == TAG_INPUT_V or tag == TAG_SHIFTED_V or tag == TAG_KPIECE_V:
267     w = (w + 63)&~63
268     bx = C.ReadBuffer(x.storeb(w/8))
269     by = C.WriteBuffer()
270     while bx.left: chunk = bx.get(8); by.put(chunk).put(chunk)
271     x = C.GF.loadb(by.contents)
272     w *= 2
273   elif TAG_PRODSUM <= tag <= TAG_PRODUCT:
274     x <<= 1
275   y = reverse(rev8(x), w)
276   present_gf(y, w, n, what)
277
278 def poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat):
279   """
280   Multiply U by V, returning the product.
281
282   This is the fallback long multiplication.
283   """
284
285   uw, vw = 8*len(u), 8*len(v)
286
287   ## We start by carving the operands into 64-bit pieces.  This is
288   ## straightforward except for the 96-bit case, where we end up with two
289   ## short pieces which we pad at the beginning.
290   upad = (-uw)%mulwd; u += C.ByteString.zero(upad); uw += upad
291   vpad = (-vw)%mulwd; v += C.ByteString.zero(vpad); vw += vpad
292   uu = split_gf(u, mulwd); vv = split_gf(v, mulwd)
293
294   ## Report and accumulate the individual product pieces.
295   x = C.GF(0)
296   ulim, vlim = uw/mulwd, vw/mulwd
297   for i in xrange(ulim + vlim - 2, -1, -1):
298     t = C.GF(0)
299     for j in xrange(max(0, i - vlim + 1), min(vlim, i + 1)):
300       s = uu[ulim - 1 - i + j]*vv[vlim - 1 - j]
301       presfn(TAG_PRODPIECE, wd, s, 2*mulwd, dispwd,
302              '%s_%d %s_%d' % (uwhat, i - j, vwhat, j))
303       t += s
304     presfn(TAG_PRODSUM, wd, t, 2*mulwd, dispwd,
305            '(%s %s)_%d' % (uwhat, vwhat, ulim + vlim - 2 - i))
306     x += t << (mulwd*i)
307   presfn(TAG_PRODUCT, wd, x, uw + vw, dispwd, '%s %s' % (uwhat, vwhat))
308
309   return x >> (upad + vpad)
310
311 def poly64_mul_karatsuba(u, v, klimit, presfn, wd,
312                          dispwd, mulwd, uwhat, vwhat):
313   """
314   Multiply U by V, returning the product.
315
316   If the length of U and V is at least KLIMIT, and the operands are otherwise
317   suitable, then do Karatsuba--Ofman multiplication; otherwise, delegate to
318   `poly64_mul_simple'.
319   """
320   w = 8*len(u)
321
322   if w < klimit or w != 8*len(v) or w%(2*mulwd) != 0:
323     return poly64_mul_simple(u, v, presfn, wd, dispwd, mulwd, uwhat, vwhat)
324
325   hw = w/2
326   u0, u1 = u[:hw/8], u[hw/8:]
327   v0, v1 = v[:hw/8], v[hw/8:]
328   uu, vv = u0 ^ u1, v0 ^ v1
329
330   presfn(TAG_KPIECE_U, wd, C.GF.loadb(uu), hw, dispwd, '%s*' % uwhat)
331   presfn(TAG_KPIECE_V, wd, C.GF.loadb(vv), hw, dispwd, '%s*' % vwhat)
332   uuvv = poly64_mul_karatsuba(uu, vv, klimit, presfn, wd, dispwd, mulwd,
333                               '%s*' % uwhat, '%s*' % vwhat)
334
335   presfn(TAG_KPIECE_U, wd, C.GF.loadb(u0), hw, dispwd, '%s0' % uwhat)
336   presfn(TAG_KPIECE_V, wd, C.GF.loadb(v0), hw, dispwd, '%s0' % vwhat)
337   u0v0 = poly64_mul_karatsuba(u0, v0, klimit, presfn, wd, dispwd, mulwd,
338                               '%s0' % uwhat, '%s0' % vwhat)
339
340   presfn(TAG_KPIECE_U, wd, C.GF.loadb(u1), hw, dispwd, '%s1' % uwhat)
341   presfn(TAG_KPIECE_V, wd, C.GF.loadb(v1), hw, dispwd, '%s1' % vwhat)
342   u1v1 = poly64_mul_karatsuba(u1, v1, klimit, presfn, wd, dispwd, mulwd,
343                               '%s1' % uwhat, '%s1' % vwhat)
344
345   uvuv = uuvv + u0v0 + u1v1
346   presfn(TAG_PRODSUM, wd, uvuv, w, dispwd, '%s!%s' % (uwhat, vwhat))
347
348   x = u1v1 + (uvuv << hw) + (u0v0 << w)
349   presfn(TAG_PRODUCT, wd, x, 2*w, dispwd, '%s %s' % (uwhat, vwhat))
350   return x
351
352 def poly64_mul(u, v, presfn, dispwd, mulwd, klimit, uwhat, vwhat):
353   """
354   Multiply U by V using a primitive 64-bit binary polynomial mutliplier.
355
356   Such a multiplier exists as the appallingly-named `pclmul[lh]q[lh]qdq' on
357   x86, and as `vmull.p64'/`pmull' on ARM.
358
359   Operands arrive in a `register format', which is a byte-swapped variant of
360   the external format.  Implementations differ on the precise details,
361   though.  Returns the double-precision product.
362   """
363
364   w = 8*len(u); assert(w == 8*len(v))
365   x = poly64_mul_karatsuba(u, v, klimit, presfn,
366                            w, dispwd, mulwd, uwhat, vwhat)
367
368   return x.storeb(w/4)
369
370 def poly64_redc(y, presfn, dispwd, redcwd):
371   """
372   Reduce a double-precision product X modulo the appropriate polynomial.
373
374   The operand arrives in a `register format', which is a byte-swapped variant
375   of the external format.  Implementations differ on the precise details,
376   though.  Returns the single-precision reduced value.
377   """
378
379   w = 4*len(y)
380   p = poly(w)
381
382   ## Our polynomial has the form p = t^d + r where r = SUM_{0<=i<d} r_i t^i,
383   ## with each r_i either 0 or 1.  Because we choose the lexically earliest
384   ## irreducible polynomial with the necessary degree, r_i = 1 happens only
385   ## for a small number of tiny i.  In our field, we have t^d = r.
386   ##
387   ## We carve the product into convenient n-bit pieces, for some n dividing d
388   ## -- typically n = 32 or 64.  Let d = m n, and write y = SUM_{0<=i<2m} y_i
389   ## t^{ni}.  The upper portion, the y_i with i >= m, needs reduction; but
390   ## y_i t^{ni} = y_i r t^{n(i-m)}, so we just multiply the top half by r and
391   ## add it to the bottom half.  This all depends on r_i = 0 for all i >=
392   ## n/2.  We process each nonzero coefficient of r separately, in two
393   ## passes.
394   ##
395   ## Multiplying a chunk y_i by some t^j is the same as shifting it left by j
396   ## bits (or would be if GCM weren't backwards, but let's not worry about
397   ## that right now).  The high j bits will spill over into the next chunk,
398   ## while the low n - j bits will stay where they are.  It's these high bits
399   ## which cause trouble -- particularly the high bits of the top chunk,
400   ## since we'll add them on to y_m, which will need further reduction.  But
401   ## only the topmost j bits will do this.
402   ##
403   ## The trick is that we do all of the bits which spill over first -- all of
404   ## the top j bits in each chunk, for each j -- in one pass, and then a
405   ## second pass of all the bits which don't.  Because j, j' < n/2 for any
406   ## two nonzero coefficient degrees j and j', we have j + j' < n whence j <
407   ## n - j' -- so all of the bits contributed to y_m will be handled in the
408   ## second pass when we handle the bits that don't spill over.
409   rr = [i for i in xrange(1, w) if p.testbit(i)]
410   m = gfmask(redcwd)
411
412   ## Handle the spilling bits.
413   yy = split_gf(y, redcwd)
414   b = C.GF(0)
415   for rj in rr:
416     br = [(yi << (redcwd - rj))&m for yi in yy[w/redcwd:]]
417     presfn(TAG_REDCBITS, w, join_gf(br, redcwd), w, dispwd, 'b(%d)' % rj)
418     b += join_gf(br, redcwd) << (w - redcwd)
419   presfn(TAG_REDCFULL, w, b, 2*w, dispwd, 'b')
420   s = C.GF.loadb(y) + b
421   presfn(TAG_REDCMIX, w, s, 2*w, dispwd, 's')
422
423   ## Handle the nonspilling bits.
424   ss = split_gf(s.storeb(w/4), redcwd)
425   a = C.GF(0)
426   for rj in rr:
427     ar = [si >> rj for si in ss[w/redcwd:]]
428     presfn(TAG_REDCBITS, w, join_gf(ar, redcwd), w, dispwd, 'a(%d)' % rj)
429     a += join_gf(ar, redcwd)
430   presfn(TAG_REDCFULL, w, a, w, dispwd, 'a')
431
432   ## Mix everything together.
433   m = gfmask(w)
434   z = (s&m) + (s >> w) + a
435   presfn(TAG_OUTPUT, w, z, w, dispwd, 'z')
436
437   ## And we're done.
438   return z.storeb(w/8)
439
440 def poly64_shiftcommon(u, v, presfn, dispwd = 32, mulwd = 64,
441                        redcwd = 32, klimit = 256):
442   w = 8*len(u)
443   presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u')
444   presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v')
445   vv = shift_right(v)
446   presfn(TAG_SHIFTED_V, w, C.GF.loadb(vv), w, dispwd, "v'")
447   y = poly64_mul(u, vv, presfn, dispwd, mulwd, klimit, "u", "v'")
448   z = poly64_redc(y, presfn, dispwd, redcwd)
449   return z
450
451 def poly64_directcommon(u, v, presfn, dispwd = 32, mulwd = 64,
452                         redcwd = 32, klimit = 256):
453   w = 8*len(u)
454   presfn(TAG_INPUT_U, w, C.GF.loadb(u), w, dispwd, 'u')
455   presfn(TAG_INPUT_V, w, C.GF.loadb(v), w, dispwd, 'v')
456   y = poly64_mul(u, v, presfn, dispwd, mulwd, klimit, "u", "v")
457   y = (C.GF.loadb(y) << 1).storeb(w/4)
458   z = poly64_redc(y, presfn, dispwd, redcwd)
459   return z
460
461 @demo
462 def demo_pclmul(u, v):
463   return poly64_shiftcommon(u, v, presfn = present_gf_pclmul)
464
465 @demo
466 def demo_vmullp64(u, v):
467   w = 8*len(u)
468   return poly64_shiftcommon(u, v, presfn = present_gf_vmullp64,
469                             redcwd = w%64 == 32 and 32 or 64)
470
471 @demo
472 def demo_pmull(u, v):
473   w = 8*len(u)
474   return poly64_directcommon(u, v, presfn = present_gf_pmull,
475                              redcwd = w%64 == 32 and 32 or 64)
476
477 ###--------------------------------------------------------------------------
478 ### @@@ Random debris to be deleted. @@@
479
480 def cutting_room_floor():
481
482   x = C.bytes('cde4bef260d7bcda163547d348b7551195e77022907dd1df')
483   y = C.bytes('f7dac5c9941d26d0c6eb14ad568f86edd1dc9268eeee5332')
484
485   u, v = C.GF.loadb(x), C.GF.loadb(y)
486
487   g = u*v << 1
488   print 'y = %s' % words(g.storeb(48))
489   b1 = (g&repmask(0x01, 32, 6)) << 191
490   b2 = (g&repmask(0x03, 32, 6)) << 190
491   b7 = (g&repmask(0x7f, 32, 6)) << 185
492   b = b1 + b2 + b7
493   print 'b = %s' % words(b.storeb(48)[0:28])
494   h = g + b
495   print 'w = %s' % words(h.storeb(48))
496
497   a0 = (h&repmask(0xffffffff, 32, 6)) << 192
498   a1 = (h&repmask(0xfffffffe, 32, 6)) << 191
499   a2 = (h&repmask(0xfffffffc, 32, 6)) << 190
500   a7 = (h&repmask(0xffffff80, 32, 6)) << 185
501   a = a0 + a1 + a2 + a7
502
503   print '     a_1 = %s' % words(a1.storeb(48)[0:24])
504   print '     a_2 = %s' % words(a2.storeb(48)[0:24])
505   print '     a_7 = %s' % words(a7.storeb(48)[0:24])
506
507   print 'low+unit = %s' % words((h + a0).storeb(48)[0:24])
508   print ' low+0,2 = %s' % words((h + a0 + a2).storeb(48)[0:24])
509   print '     1,7 = %s' % words((a1 + a7).storeb(48)[0:24])
510
511   print 'a = %s' % words(a.storeb(48)[0:24])
512   z = h + a
513   print 'z = %s' % words(z.storeb(48))
514
515   z = gcm_mul(x, y)
516   print 'u v mod p = %s' % words(z)
517
518 ###--------------------------------------------------------------------------
519 ### Main program.
520
521 style = argv[1]
522 u = C.bytes(argv[2])
523 v = C.bytes(argv[3])
524 zz = DEMOMAP[style](u, v)
525 assert zz == gcm_mul(u, v)
526
527 ###----- That's all, folks --------------------------------------------------