chiark / gitweb /
ff771b9a051c0b0bbe76ac91c8d1e27bc7c9db54
[catacomb-python] / catacomb / __init__.py
1 ### -*-python-*-
2 ###
3 ### Setup for Catacomb/Python bindings
4 ###
5 ### (c) 2004 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software; you can redistribute it and/or modify
13 ### it under the terms of the GNU General Public License as published by
14 ### the Free Software Foundation; either version 2 of the License, or
15 ### (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful,
18 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20 ### GNU General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python; if not, write to the Free Software Foundation,
24 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25
26 from __future__ import with_statement
27
28 from binascii import hexlify as _hexify, unhexlify as _unhexify
29 from contextlib import contextmanager as _ctxmgr
30 try: import DLFCN as _dlfcn
31 except ImportError: _dlfcn = None
32 import os as _os
33 from struct import pack as _pack
34 import sys as _sys
35 import types as _types
36
37 ###--------------------------------------------------------------------------
38 ### Import the main C extension module.
39
40 try:
41   _dlflags = _odlflags = _sys.getdlopenflags()
42 except AttributeError:
43   _dlflags = _odlflags = -1
44
45 ## Set the `deep binding' flag.  Python has its own different MD5
46 ## implementation, and some distributions export `md5_init' and friends so
47 ## they override our versions, which doesn't end well.  Figure out how to
48 ## turn this flag on so we don't have the problem.
49 if _dlflags >= 0:
50   try: _dlflags |= _dlfcn.RTLD_DEEPBIND
51   except AttributeError:
52     try: _dlflags |= _os.RTLD_DEEPBIND
53     except AttributeError:
54       if _os.uname()[0] == 'Linux': _dlflags |= 8 # magic knowledge
55       else: pass # can't do this.
56   _sys.setdlopenflags(_dlflags)
57
58 import _base
59
60 if _odlflags >= 0:
61   _sys.setdlopenflags(_odlflags)
62
63 del _dlflags, _odlflags
64
65 ###--------------------------------------------------------------------------
66 ### Basic stuff.
67
68 ## For the benefit of the default keyreporter, we need the program name.
69 _base._ego(_sys.argv[0])
70
71 ## Register our module.
72 _base._set_home_module(_sys.modules[__name__])
73 def default_lostexchook(why, ty, val, tb):
74   """`catacomb.lostexchook(WHY, TY, VAL, TB)' reports lost exceptions."""
75   _sys.stderr.write("\n\n!!! LOST EXCEPTION: %s\n" % why)
76   _sys.excepthook(ty, val, tb)
77   _sys.stderr.write("\n")
78 lostexchook = default_lostexchook
79
80 ## How to fix a name back into the right identifier.  Alas, the rules are not
81 ## consistent.
82 def _fixname(name):
83
84   ## Hyphens consistently become underscores.
85   name = name.replace('-', '_')
86
87   ## But slashes might become underscores or just vanish.
88   if name.startswith('salsa20'): name = name.replace('/', '')
89   else: name = name.replace('/', '_')
90
91   ## Done.
92   return name
93
94 ## Initialize the module.  Drag in the static methods of the various
95 ## classes; create names for the various known crypto algorithms.
96 def _init():
97   d = globals()
98   b = _base.__dict__;
99   for i in b:
100     if i[0] != '_':
101       d[i] = b[i];
102   for i in ['ByteString',
103             'MP', 'GF', 'Field',
104             'ECPt', 'ECPtCurve', 'ECCurve', 'ECInfo',
105             'DHInfo', 'BinDHInfo', 'RSAPriv', 'BBSPriv',
106             'PrimeFilter', 'RabinMiller',
107             'Group', 'GE',
108             'KeySZ', 'KeyData']:
109     c = d[i]
110     pre = '_' + i + '_'
111     plen = len(pre)
112     for j in b:
113       if j[:plen] == pre:
114         setattr(c, j[plen:], classmethod(b[j]))
115   for i in [gcciphers, gchashes, gcmacs, gcprps]:
116     for c in i.itervalues():
117       d[_fixname(c.name)] = c
118   for c in gccrands.itervalues():
119     d[_fixname(c.name + 'rand')] = c
120 _init()
121
122 ## A handy function for our work: add the methods of a named class to an
123 ## existing class.  This is how we write the Python-implemented parts of our
124 ## mostly-C types.
125 def _augment(c, cc):
126   for i in cc.__dict__:
127     a = cc.__dict__[i]
128     if type(a) is _types.MethodType:
129       a = a.im_func
130     elif type(a) not in (_types.FunctionType, staticmethod, classmethod):
131       continue
132     setattr(c, i, a)
133
134 ## Parsing functions tend to return the object parsed and the remainder of
135 ## the input.  This checks that the remainder is input and, if so, returns
136 ## just the object.
137 def _checkend(r):
138   x, rest = r
139   if rest != '':
140     raise SyntaxError, 'junk at end of string'
141   return x
142
143 ## Some pretty-printing utilities.
144 PRINT_SECRETS = False
145 def _clsname(me): return type(me).__name__
146 def _repr_secret(thing, secretp = True):
147   if not secretp or PRINT_SECRETS: return repr(thing)
148   else: return '#<SECRET>'
149 def _pp_str(me, pp, cyclep): pp.text(cyclep and '...' or str(me))
150 def _pp_secret(pp, thing, secretp = True):
151   if not secretp or PRINT_SECRETS: pp.pretty(thing)
152   else: pp.text('#<SECRET>')
153 def _pp_bgroup(pp, text):
154   ind = len(text)
155   pp.begin_group(ind, text)
156   return ind
157 def _pp_bgroup_tyname(pp, obj, open = '('):
158   return _pp_bgroup(pp, _clsname(obj) + open)
159 def _pp_kv(pp, k, v, secretp = False):
160   ind = _pp_bgroup(pp, k + ' = ')
161   _pp_secret(pp, v, secretp)
162   pp.end_group(ind, '')
163 def _pp_commas(pp, printfn, items):
164   firstp = True
165   for i in items:
166     if firstp: firstp = False
167     else: pp.text(','); pp.breakable()
168     printfn(i)
169 def _pp_dict(pp, items):
170   def p((k, v)):
171     pp.begin_group(0)
172     pp.pretty(k)
173     pp.text(':')
174     pp.begin_group(2)
175     pp.breakable()
176     pp.pretty(v)
177     pp.end_group(2)
178     pp.end_group(0)
179   _pp_commas(pp, p, items)
180
181 ###--------------------------------------------------------------------------
182 ### Bytestrings.
183
184 class _tmp:
185   def fromhex(x):
186     return ByteString(_unhexify(x))
187   fromhex = staticmethod(fromhex)
188   def __hex__(me):
189     return _hexify(me)
190   def __repr__(me):
191     return 'bytes(%r)' % hex(me)
192 _augment(ByteString, _tmp)
193 ByteString.__hash__ = str.__hash__
194 bytes = ByteString.fromhex
195
196 ###--------------------------------------------------------------------------
197 ### Hashing.
198
199 class _tmp:
200   def check(me, h):
201     hh = me.done()
202     return ctstreq(h, hh)
203 _augment(GHash, _tmp)
204 _augment(Poly1305Hash, _tmp)
205
206 class _HashBase (object):
207   ## The standard hash methods.  Assume that `hash' is defined and returns
208   ## the receiver.
209   def hashu8(me, n): return me.hash(_pack('B', n))
210   def hashu16l(me, n): return me.hash(_pack('<H', n))
211   def hashu16b(me, n): return me.hash(_pack('>H', n))
212   hashu16 = hashu16b
213   def hashu32l(me, n): return me.hash(_pack('<L', n))
214   def hashu32b(me, n): return me.hash(_pack('>L', n))
215   hashu32 = hashu32b
216   def hashu64l(me, n): return me.hash(_pack('<Q', n))
217   def hashu64b(me, n): return me.hash(_pack('>Q', n))
218   hashu64 = hashu64b
219   def hashbuf8(me, s): return me.hashu8(len(s)).hash(s)
220   def hashbuf16l(me, s): return me.hashu16l(len(s)).hash(s)
221   def hashbuf16b(me, s): return me.hashu16b(len(s)).hash(s)
222   hashbuf16 = hashbuf16b
223   def hashbuf32l(me, s): return me.hashu32l(len(s)).hash(s)
224   def hashbuf32b(me, s): return me.hashu32b(len(s)).hash(s)
225   hashbuf32 = hashbuf32b
226   def hashbuf64l(me, s): return me.hashu64l(len(s)).hash(s)
227   def hashbuf64b(me, s): return me.hashu64b(len(s)).hash(s)
228   hashbuf64 = hashbuf64b
229   def hashstrz(me, s): return me.hash(s).hashu8(0)
230
231 class _ShakeBase (_HashBase):
232
233   ## Python gets really confused if I try to augment `__new__' on native
234   ## classes, so wrap and delegate.  Sorry.
235   def __init__(me, perso = '', *args, **kw):
236     super(_ShakeBase, me).__init__(*args, **kw)
237     me._h = me._SHAKE(perso = perso, func = me._FUNC)
238
239   ## Delegate methods...
240   def copy(me): new = me.__class__(); new._copy(me)
241   def _copy(me, other): me._h = other._h.copy()
242   def hash(me, m): me._h.hash(m); return me
243   def xof(me): me._h.xof(); return me
244   def get(me, n): return me._h.get(n)
245   def mask(me, m): return me._h.mask(m)
246   def done(me, n): return me._h.done(n)
247   def check(me, h): return ctstreq(h, me.done(len(h)))
248   @property
249   def state(me): return me._h.state
250   @property
251   def buffered(me): return me._h.buffered
252   @property
253   def rate(me): return me._h.rate
254
255 class _tmp:
256   def check(me, h):
257     return ctstreq(h, me.done(len(h)))
258   def leftenc(me, n):
259     nn = MP(n).storeb()
260     return me.hashu8(len(nn)).hash(nn)
261   def rightenc(me, n):
262     nn = MP(n).storeb()
263     return me.hash(nn).hashu8(len(nn))
264   def stringenc(me, str):
265     return me.leftenc(8*len(str)).hash(str)
266   def bytepad_before(me):
267     return me.leftenc(me.rate)
268   def bytepad_after(me):
269     if me.buffered: me.hash(me._Z[:me.rate - me.buffered])
270     return me
271   @_ctxmgr
272   def bytepad(me):
273     me.bytepad_before()
274     yield me
275     me.bytepad_after()
276 _augment(Shake, _tmp)
277 _augment(_ShakeBase, _tmp)
278 Shake._Z = _ShakeBase._Z = ByteString(200*'\0')
279
280 class KMAC (_ShakeBase):
281   _FUNC = 'KMAC'
282   def __init__(me, k, *arg, **kw):
283     super(KMAC, me).__init__(*arg, **kw)
284     with me.bytepad(): me.stringenc(k)
285   def done(me, n = -1):
286     if n < 0: n = me._TAGSZ
287     me.rightenc(8*n)
288     return super(KMAC, me).done(n)
289   def xof(me):
290     me.rightenc(0)
291     return super(KMAC, me).xof()
292
293 class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16
294 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
295
296 ###--------------------------------------------------------------------------
297 ### NaCl `secretbox'.
298
299 def secret_box(k, n, m):
300   E = xsalsa20(k).setiv(n)
301   r = E.enczero(poly1305.keysz.default)
302   s = E.enczero(poly1305.masksz)
303   y = E.encrypt(m)
304   t = poly1305(r)(s).hash(y).done()
305   return ByteString(t + y)
306
307 def secret_unbox(k, n, c):
308   E = xsalsa20(k).setiv(n)
309   r = E.enczero(poly1305.keysz.default)
310   s = E.enczero(poly1305.masksz)
311   y = c[poly1305.tagsz:]
312   if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]):
313     raise ValueError, 'decryption failed'
314   return E.decrypt(c[poly1305.tagsz:])
315
316 ###--------------------------------------------------------------------------
317 ### Multiprecision integers and binary polynomials.
318
319 def _split_rat(x):
320   if isinstance(x, BaseRat): return x._n, x._d
321   else: return x, 1
322 class BaseRat (object):
323   """Base class implementing fields of fractions over Euclidean domains."""
324   def __new__(cls, a, b):
325     a, b = cls.RING(a), cls.RING(b)
326     q, r = divmod(a, b)
327     if r == 0: return q
328     g = b.gcd(r)
329     me = super(BaseRat, cls).__new__(cls)
330     me._n = a//g
331     me._d = b//g
332     return me
333   @property
334   def numer(me): return me._n
335   @property
336   def denom(me): return me._d
337   def __str__(me): return '%s/%s' % (me._n, me._d)
338   def __repr__(me): return '%s(%s, %s)' % (_clsname(me), me._n, me._d)
339   _repr_pretty_ = _pp_str
340
341   def __add__(me, you):
342     n, d = _split_rat(you)
343     return type(me)(me._n*d + n*me._d, d*me._d)
344   __radd__ = __add__
345   def __sub__(me, you):
346     n, d = _split_rat(you)
347     return type(me)(me._n*d - n*me._d, d*me._d)
348   def __rsub__(me, you):
349     n, d = _split_rat(you)
350     return type(me)(n*me._d - me._n*d, d*me._d)
351   def __mul__(me, you):
352     n, d = _split_rat(you)
353     return type(me)(me._n*n, me._d*d)
354   __rmul__ = __mul__
355   def __truediv__(me, you):
356     n, d = _split_rat(you)
357     return type(me)(me._n*d, me._d*n)
358   def __rtruediv__(me, you):
359     n, d = _split_rat(you)
360     return type(me)(me._d*n, me._n*d)
361   __div__ = __truediv__
362   __rdiv__ = __rtruediv__
363   def __cmp__(me, you):
364     n, d = _split_rat(you)
365     return cmp(me._n*d, n*me._d)
366   def __rcmp__(me, you):
367     n, d = _split_rat(you)
368     return cmp(n*me._d, me._n*d)
369
370 class IntRat (BaseRat):
371   RING = MP
372
373 class GFRat (BaseRat):
374   RING = GF
375
376 class _tmp:
377   def negp(x): return x < 0
378   def posp(x): return x > 0
379   def zerop(x): return x == 0
380   def oddp(x): return x.testbit(0)
381   def evenp(x): return not x.testbit(0)
382   def mont(x): return MPMont(x)
383   def barrett(x): return MPBarrett(x)
384   def reduce(x): return MPReduce(x)
385   def __truediv__(me, you): return IntRat(me, you)
386   def __rtruediv__(me, you): return IntRat(you, me)
387   __div__ = __truediv__
388   __rdiv__ = __rtruediv__
389   _repr_pretty_ = _pp_str
390 _augment(MP, _tmp)
391
392 class _tmp:
393   def zerop(x): return x == 0
394   def reduce(x): return GFReduce(x)
395   def trace(x, y): return x.reduce().trace(y)
396   def halftrace(x, y): return x.reduce().halftrace(y)
397   def modsqrt(x, y): return x.reduce().sqrt(y)
398   def quadsolve(x, y): return x.reduce().quadsolve(y)
399   def __truediv__(me, you): return GFRat(me, you)
400   def __rtruediv__(me, you): return GFRat(you, me)
401   __div__ = __truediv__
402   __rdiv__ = __rtruediv__
403   _repr_pretty_ = _pp_str
404 _augment(GF, _tmp)
405
406 class _tmp:
407   def product(*arg):
408     'product(ITERABLE) or product(I, ...) -> PRODUCT'
409     return MPMul(*arg).done()
410   product = staticmethod(product)
411 _augment(MPMul, _tmp)
412
413 ###--------------------------------------------------------------------------
414 ### Abstract fields.
415
416 class _tmp:
417   def fromstring(str): return _checkend(Field.parse(str))
418   fromstring = staticmethod(fromstring)
419 _augment(Field, _tmp)
420
421 class _tmp:
422   def __repr__(me): return '%s(%sL)' % (_clsname(me), me.p)
423   def __hash__(me): return 0x114401de ^ hash(me.p)
424   def _repr_pretty_(me, pp, cyclep):
425     ind = _pp_bgroup_tyname(pp, me)
426     if cyclep: pp.text('...')
427     else: pp.pretty(me.p)
428     pp.end_group(ind, ')')
429   def ec(me, a, b): return ECPrimeProjCurve(me, a, b)
430 _augment(PrimeField, _tmp)
431
432 class _tmp:
433   def __repr__(me): return '%s(%#xL)' % (_clsname(me), me.p)
434   def ec(me, a, b): return ECBinProjCurve(me, a, b)
435   def _repr_pretty_(me, pp, cyclep):
436     ind = _pp_bgroup_tyname(pp, me)
437     if cyclep: pp.text('...')
438     else: pp.text('%#x' % me.p)
439     pp.end_group(ind, ')')
440 _augment(BinField, _tmp)
441
442 class _tmp:
443   def __hash__(me): return 0x23e4701c ^ hash(me.p)
444 _augment(BinPolyField, _tmp)
445
446 class _tmp:
447   def __hash__(me):
448     h = 0x9a7d6240
449     h ^=   hash(me.p)
450     h ^= 2*hash(me.beta) & 0xffffffff
451     return h
452 _augment(BinNormField, _tmp)
453
454 class _tmp:
455   def __str__(me): return str(me.value)
456   def __repr__(me): return '%s(%s)' % (repr(me.field), repr(me.value))
457   _repr_pretty_ = _pp_str
458 _augment(FE, _tmp)
459
460 ###--------------------------------------------------------------------------
461 ### Elliptic curves.
462
463 class _tmp:
464   def __repr__(me):
465     return '%s(%r, %s, %s)' % (_clsname(me), me.field, me.a, me.b)
466   def _repr_pretty_(me, pp, cyclep):
467     ind = _pp_bgroup_tyname(pp, me)
468     if cyclep:
469       pp.text('...')
470     else:
471       pp.pretty(me.field); pp.text(','); pp.breakable()
472       pp.pretty(me.a); pp.text(','); pp.breakable()
473       pp.pretty(me.b)
474     pp.end_group(ind, ')')
475   def frombuf(me, s):
476     return ecpt.frombuf(me, s)
477   def fromraw(me, s):
478     return ecpt.fromraw(me, s)
479   def pt(me, *args):
480     return me(*args)
481 _augment(ECCurve, _tmp)
482
483 class _tmp:
484   def __hash__(me):
485     h = 0x6751d341
486     h ^=   hash(me.field)
487     h ^= 2*hash(me.a) ^ 0xffffffff
488     h ^= 5*hash(me.b) ^ 0xffffffff
489     return h
490 _augment(ECPrimeCurve, _tmp)
491
492 class _tmp:
493   def __hash__(me):
494     h = 0x2ac203c5
495     h ^=   hash(me.field)
496     h ^= 2*hash(me.a) ^ 0xffffffff
497     h ^= 5*hash(me.b) ^ 0xffffffff
498     return h
499 _augment(ECBinCurve, _tmp)
500
501 class _tmp:
502   def __repr__(me):
503     if not me: return '%s()' % _clsname(me)
504     return '%s(%s, %s)' % (_clsname(me), me.ix, me.iy)
505   def __str__(me):
506     if not me: return 'inf'
507     return '(%s, %s)' % (me.ix, me.iy)
508   def _repr_pretty_(me, pp, cyclep):
509     if cyclep:
510       pp.text('...')
511     elif not me:
512       pp.text('inf')
513     else:
514       ind = _pp_bgroup(pp, '(')
515       pp.pretty(me.ix); pp.text(','); pp.breakable()
516       pp.pretty(me.iy)
517       pp.end_group(ind, ')')
518 _augment(ECPt, _tmp)
519
520 class _tmp:
521   def __repr__(me):
522     return '%s(curve = %r, G = %r, r = %s, h = %s)' % \
523            (_clsname(me), me.curve, me.G, me.r, me.h)
524   def _repr_pretty_(me, pp, cyclep):
525     ind = _pp_bgroup_tyname(pp, me)
526     if cyclep:
527       pp.text('...')
528     else:
529       _pp_kv(pp, 'curve', me.curve); pp.text(','); pp.breakable()
530       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
531       _pp_kv(pp, 'r', me.r); pp.text(','); pp.breakable()
532       _pp_kv(pp, 'h', me.h)
533     pp.end_group(ind, ')')
534   def __hash__(me):
535     h = 0x9bedb8de
536     h ^=   hash(me.curve)
537     h ^= 2*hash(me.G) & 0xffffffff
538     return h
539   def group(me):
540     return ECGroup(me)
541 _augment(ECInfo, _tmp)
542
543 class _tmp:
544   def __repr__(me):
545     if not me: return '%r()' % (me.curve)
546     return '%r(%s, %s)' % (me.curve, me.x, me.y)
547   def __str__(me):
548     if not me: return 'inf'
549     return '(%s, %s)' % (me.x, me.y)
550   def _repr_pretty_(me, pp, cyclep):
551     if cyclep:
552       pp.text('...')
553     elif not me:
554       pp.text('inf')
555     else:
556       ind = _pp_bgroup(pp, '(')
557       pp.pretty(me.x); pp.text(','); pp.breakable()
558       pp.pretty(me.y)
559       pp.end_group(ind, ')')
560 _augment(ECPtCurve, _tmp)
561
562 ###--------------------------------------------------------------------------
563 ### Key sizes.
564
565 class _tmp:
566   def __repr__(me): return '%s(%d)' % (_clsname(me), me.default)
567   def check(me, sz): return True
568   def best(me, sz): return sz
569 _augment(KeySZAny, _tmp)
570
571 class _tmp:
572   def __repr__(me):
573     return '%s(%d, %d, %d, %d)' % \
574            (_clsname(me), me.default, me.min, me.max, me.mod)
575   def _repr_pretty_(me, pp, cyclep):
576     ind = _pp_bgroup_tyname(pp, me)
577     if cyclep:
578       pp.text('...')
579     else:
580       pp.pretty(me.default); pp.text(','); pp.breakable()
581       pp.pretty(me.min); pp.text(','); pp.breakable()
582       pp.pretty(me.max); pp.text(','); pp.breakable()
583       pp.pretty(me.mod)
584     pp.end_group(ind, ')')
585   def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0
586   def best(me, sz):
587     if sz < me.min: raise ValueError, 'key too small'
588     elif sz > me.max: return me.max
589     else: return sz - (sz % me.mod)
590 _augment(KeySZRange, _tmp)
591
592 class _tmp:
593   def __repr__(me): return '%s(%d, %s)' % (_clsname(me), me.default, me.set)
594   def _repr_pretty_(me, pp, cyclep):
595     ind = _pp_bgroup_tyname(pp, me)
596     if cyclep:
597       pp.text('...')
598     else:
599       pp.pretty(me.default); pp.text(','); pp.breakable()
600       ind1 = _pp_bgroup(pp, '{')
601       _pp_commas(pp, pp.pretty, me.set)
602       pp.end_group(ind1, '}')
603     pp.end_group(ind, ')')
604   def check(me, sz): return sz in me.set
605   def best(me, sz):
606     found = -1
607     for i in me.set:
608       if found < i <= sz: found = i
609     if found < 0: raise ValueError, 'key too small'
610     return found
611 _augment(KeySZSet, _tmp)
612
613 ###--------------------------------------------------------------------------
614 ### Key data objects.
615
616 class _tmp:
617   def __repr__(me): return '%s(%r)' % (_clsname(me), me.name)
618 _augment(KeyFile, _tmp)
619
620 class _tmp:
621   def __repr__(me): return '%s(%r)' % (_clsname(me), me.fulltag)
622 _augment(Key, _tmp)
623
624 class _tmp:
625   def __repr__(me):
626     return '%s({%s})' % (_clsname(me),
627                          ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
628   def _repr_pretty_(me, pp, cyclep):
629     ind = _pp_bgroup_tyname(pp, me)
630     if cyclep: pp.text('...')
631     else: _pp_dict(pp, me.iteritems())
632     pp.end_group(ind, ')')
633 _augment(KeyAttributes, _tmp)
634
635 class _tmp:
636   def __repr__(me):
637     return '%s(%s, %r)' % (_clsname(me),
638                            _repr_secret(me._guts(),
639                                         not (me.flags & KF_NONSECRET)),
640                            me.writeflags(me.flags))
641   def _repr_pretty_(me, pp, cyclep):
642     ind = _pp_bgroup_tyname(pp, me)
643     if cyclep:
644       pp.text('...')
645     else:
646       _pp_secret(pp, me._guts(), not (me.flags & KF_NONSECRET))
647       pp.text(','); pp.breakable()
648       pp.pretty(me.writeflags(me.flags))
649     pp.end_group(ind, ')')
650 _augment(KeyData, _tmp)
651
652 class _tmp:
653   def _guts(me): return me.bin
654 _augment(KeyDataBinary, _tmp)
655
656 class _tmp:
657   def _guts(me): return me.ct
658 _augment(KeyDataEncrypted, _tmp)
659
660 class _tmp:
661   def _guts(me): return me.mp
662 _augment(KeyDataMP, _tmp)
663
664 class _tmp:
665   def _guts(me): return me.str
666 _augment(KeyDataString, _tmp)
667
668 class _tmp:
669   def _guts(me): return me.ecpt
670 _augment(KeyDataECPt, _tmp)
671
672 class _tmp:
673   def __repr__(me):
674     return '%s({%s})' % (_clsname(me),
675                          ', '.join(['%r: %r' % kv for kv in me.iteritems()]))
676   def _repr_pretty_(me, pp, cyclep):
677     ind = _pp_bgroup_tyname(pp, me, '({ ')
678     if cyclep: pp.text('...')
679     else: _pp_dict(pp, me.iteritems())
680     pp.end_group(ind, ' })')
681 _augment(KeyDataStructured, _tmp)
682
683 ###--------------------------------------------------------------------------
684 ### Abstract groups.
685
686 class _tmp:
687   def __repr__(me):
688     return '%s(p = %s, r = %s, g = %s)' % (_clsname(me), me.p, me.r, me.g)
689   def _repr_pretty_(me, pp, cyclep):
690     ind = _pp_bgroup_tyname(pp, me)
691     if cyclep:
692       pp.text('...')
693     else:
694       _pp_kv(pp, 'p', me.p); pp.text(','); pp.breakable()
695       _pp_kv(pp, 'r', me.r); pp.text(','); pp.breakable()
696       _pp_kv(pp, 'g', me.g)
697     pp.end_group(ind, ')')
698 _augment(FGInfo, _tmp)
699
700 class _tmp:
701   def group(me): return PrimeGroup(me)
702 _augment(DHInfo, _tmp)
703
704 class _tmp:
705   def group(me): return BinGroup(me)
706 _augment(BinDHInfo, _tmp)
707
708 class _tmp:
709   def __repr__(me):
710     return '%s(%r)' % (_clsname(me), me.info)
711   def _repr_pretty_(me, pp, cyclep):
712     ind = _pp_bgroup_tyname(pp, me)
713     if cyclep: pp.text('...')
714     else: pp.pretty(me.info)
715     pp.end_group(ind, ')')
716 _augment(Group, _tmp)
717
718 class _tmp:
719   def __hash__(me):
720     info = me.info
721     h = 0xbce3cfe6
722     h ^=   hash(info.p)
723     h ^= 2*hash(info.r) & 0xffffffff
724     h ^= 5*hash(info.g) & 0xffffffff
725     return h
726   def _get_geval(me, x): return MP(x)
727 _augment(PrimeGroup, _tmp)
728
729 class _tmp:
730   def __hash__(me):
731     info = me.info
732     h = 0x80695949
733     h ^=   hash(info.p)
734     h ^= 2*hash(info.r) & 0xffffffff
735     h ^= 5*hash(info.g) & 0xffffffff
736     return h
737   def _get_geval(me, x): return GF(x)
738 _augment(BinGroup, _tmp)
739
740 class _tmp:
741   def __hash__(me): return 0x0ec23dab ^ hash(me.info)
742   def _get_geval(me, x): return x.toec()
743 _augment(ECGroup, _tmp)
744
745 class _tmp:
746   def __repr__(me):
747     return '%r(%r)' % (me.group, str(me))
748   def _repr_pretty_(me, pp, cyclep):
749     pp.pretty(type(me)._get_geval(me))
750 _augment(GE, _tmp)
751
752 ###--------------------------------------------------------------------------
753 ### RSA encoding techniques.
754
755 class PKCS1Crypt (object):
756   def __init__(me, ep = '', rng = rand):
757     me.ep = ep
758     me.rng = rng
759   def encode(me, msg, nbits):
760     return _base._p1crypt_encode(msg, nbits, me.ep, me.rng)
761   def decode(me, ct, nbits):
762     return _base._p1crypt_decode(ct, nbits, me.ep, me.rng)
763
764 class PKCS1Sig (object):
765   def __init__(me, ep = '', rng = rand):
766     me.ep = ep
767     me.rng = rng
768   def encode(me, msg, nbits):
769     return _base._p1sig_encode(msg, nbits, me.ep, me.rng)
770   def decode(me, msg, sig, nbits):
771     return _base._p1sig_decode(msg, sig, nbits, me.ep, me.rng)
772
773 class OAEP (object):
774   def __init__(me, mgf = sha_mgf, hash = sha, ep = '', rng = rand):
775     me.mgf = mgf
776     me.hash = hash
777     me.ep = ep
778     me.rng = rng
779   def encode(me, msg, nbits):
780     return _base._oaep_encode(msg, nbits, me.mgf, me.hash, me.ep, me.rng)
781   def decode(me, ct, nbits):
782     return _base._oaep_decode(ct, nbits, me.mgf, me.hash, me.ep, me.rng)
783
784 class PSS (object):
785   def __init__(me, mgf = sha_mgf, hash = sha, saltsz = None, rng = rand):
786     me.mgf = mgf
787     me.hash = hash
788     if saltsz is None:
789       saltsz = hash.hashsz
790     me.saltsz = saltsz
791     me.rng = rng
792   def encode(me, msg, nbits):
793     return _base._pss_encode(msg, nbits, me.mgf, me.hash, me.saltsz, me.rng)
794   def decode(me, msg, sig, nbits):
795     return _base._pss_decode(msg, sig, nbits,
796                              me.mgf, me.hash, me.saltsz, me.rng)
797
798 class _tmp:
799   def encrypt(me, msg, enc):
800     return me.pubop(enc.encode(msg, me.n.nbits))
801   def verify(me, msg, sig, enc):
802     if msg is None: return enc.decode(msg, me.pubop(sig), me.n.nbits)
803     try:
804       x = enc.decode(msg, me.pubop(sig), me.n.nbits)
805       return x is None or x == msg
806     except ValueError:
807       return False
808   def __repr__(me):
809     return '%s(n = %r, e = %r)' % (_clsname(me), me.n, me.e)
810   def _repr_pretty_(me, pp, cyclep):
811     ind = _pp_bgroup_tyname(pp, me)
812     if cyclep:
813       pp.text('...')
814     else:
815       _pp_kv(pp, 'n', me.n); pp.text(','); pp.breakable()
816       _pp_kv(pp, 'e', me.e)
817     pp.end_group(ind, ')')
818 _augment(RSAPub, _tmp)
819
820 class _tmp:
821   def decrypt(me, ct, enc): return enc.decode(me.privop(ct), me.n.nbits)
822   def sign(me, msg, enc): return me.privop(enc.encode(msg, me.n.nbits))
823   def __repr__(me):
824     return '%s(n = %r, e = %r, d = %s, ' \
825       'p = %s, q = %s, dp = %s, dq = %s, q_inv = %s)' % \
826       (_clsname(me), me.n, me.e,
827        _repr_secret(me.d), _repr_secret(me.p), _repr_secret(me.q),
828        _repr_secret(me.dp), _repr_secret(me.dq), _repr_secret(me.q_inv))
829   def _repr_pretty_(me, pp, cyclep):
830     ind = _pp_bgroup_tyname(pp, me)
831     if cyclep:
832       pp.text('...')
833     else:
834       _pp_kv(pp, 'n', me.n); pp.text(','); pp.breakable()
835       _pp_kv(pp, 'e', me.e); pp.text(','); pp.breakable()
836       _pp_kv(pp, 'd', me.d, secretp = True); pp.text(','); pp.breakable()
837       _pp_kv(pp, 'p', me.p, secretp = True); pp.text(','); pp.breakable()
838       _pp_kv(pp, 'q', me.q, secretp = True); pp.text(','); pp.breakable()
839       _pp_kv(pp, 'dp', me.dp, secretp = True); pp.text(','); pp.breakable()
840       _pp_kv(pp, 'dq', me.dq, secretp = True); pp.text(','); pp.breakable()
841       _pp_kv(pp, 'q_inv', me.q_inv, secretp = True)
842     pp.end_group(ind, ')')
843 _augment(RSAPriv, _tmp)
844
845 ###--------------------------------------------------------------------------
846 ### DSA and related schemes.
847
848 class _tmp:
849   def __repr__(me): return '%s(G = %r, p = %r)' % (_clsname(me), me.G, me.p)
850   def _repr_pretty_(me, pp, cyclep):
851     ind = _pp_bgroup_tyname(pp, me)
852     if cyclep:
853       pp.text('...')
854     else:
855       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
856       _pp_kv(pp, 'p', me.p)
857     pp.end_group(ind, ')')
858 _augment(DSAPub, _tmp)
859 _augment(KCDSAPub, _tmp)
860
861 class _tmp:
862   def __repr__(me): return '%s(G = %r, u = %s, p = %r)' % \
863       (_clsname(me), me.G, _repr_secret(me.u), me.p)
864   def _repr_pretty_(me, pp, cyclep):
865     ind = _pp_bgroup_tyname(pp, me)
866     if cyclep:
867       pp.text('...')
868     else:
869       _pp_kv(pp, 'G', me.G); pp.text(','); pp.breakable()
870       _pp_kv(pp, 'u', me.u, True); pp.text(','); pp.breakable()
871       _pp_kv(pp, 'p', me.p)
872     pp.end_group(ind, ')')
873 _augment(DSAPriv, _tmp)
874 _augment(KCDSAPriv, _tmp)
875
876 ###--------------------------------------------------------------------------
877 ### Bernstein's elliptic curve crypto and related schemes.
878
879 X25519_BASE = MP(9).storel(32)
880 X448_BASE = MP(5).storel(56)
881
882 Z128 = ByteString.zero(16)
883
884 class _BasePub (object):
885   def __init__(me, pub, *args, **kw):
886     if not me._PUBSZ.check(len(pub)): raise ValueError, 'bad public key'
887     super(_BasePub, me).__init__(*args, **kw)
888     me.pub = pub
889   def __repr__(me): return '%s(pub = %r)' % (_clsname(me), me.pub)
890   def _pp(me, pp): _pp_kv(pp, 'pub', me.pub)
891   def _repr_pretty_(me, pp, cyclep):
892     ind = _pp_bgroup_tyname(pp, me)
893     if cyclep: pp.text('...')
894     else: me._pp(pp)
895     pp.end_group(ind, ')')
896
897 class _BasePriv (object):
898   def __init__(me, priv, pub = None, *args, **kw):
899     if not me._KEYSZ.check(len(priv)): raise ValueError, 'bad private key'
900     if pub is None: pub = me._pubkey(priv)
901     super(_BasePriv, me).__init__(pub = pub, *args, **kw)
902     me.priv = priv
903   @classmethod
904   def generate(cls, rng = rand):
905     return cls(rng.block(cls._KEYSZ.default))
906   def __repr__(me):
907     return '%s(priv = %d, pub = %r)' % \
908         (_clsname(me), _repr_secret(me.priv), me.pub)
909   def _pp(me, pp):
910     _pp_kv(pp, 'priv', me.priv, secretp = True); pp.text(','); pp.breakable()
911     super(_BasePriv, me)._pp(pp)
912
913 class _XDHPub (_BasePub):  pass
914
915 class _XDHPriv (_BasePriv):
916   def _pubkey(me, priv): return me._op(priv, me._BASE)
917   def agree(me, you): return me._op(me.priv, you.pub)
918   def boxkey(me, recip): return me._hashkey(me.agree(recip))
919   def box(me, recip, n, m): return secret_box(me.boxkey(recip), n, m)
920   def unbox(me, recip, n, c): return secret_unbox(me.boxkey(recip), n, c)
921
922 class X25519Pub (_XDHPub):
923   _PUBSZ = KeySZSet(X25519_PUBSZ)
924   _BASE = X25519_BASE
925
926 class X25519Priv (_XDHPriv, X25519Pub):
927   _KEYSZ = KeySZSet(X25519_KEYSZ)
928   def _op(me, k, X): return x25519(k, X)
929   def _hashkey(me, z): return hsalsa20_prf(z, Z128)
930
931 class X448Pub (_XDHPub):
932   _PUBSZ = KeySZSet(X448_PUBSZ)
933   _BASE = X448_BASE
934
935 class X448Priv (_XDHPriv, X448Pub):
936   _KEYSZ = KeySZSet(X448_KEYSZ)
937   def _op(me, k, X): return x448(k, X)
938   def _hashkey(me, z): return Shake256().hash(z).done(salsa20.keysz.default)
939
940 class _EdDSAPub (_BasePub):
941   def beginhash(me): return me._HASH()
942   def endhash(me, h): return h.done()
943
944 class _EdDSAPriv (_BasePriv, _EdDSAPub):
945   pass
946
947 class Ed25519Pub (_EdDSAPub):
948   _PUBSZ = KeySZSet(ED25519_PUBSZ)
949   _HASH = sha512
950   def verify(me, msg, sig, **kw):
951     return ed25519_verify(me.pub, msg, sig, **kw)
952
953 class Ed25519Priv (_EdDSAPriv, Ed25519Pub):
954   _KEYSZ = KeySZAny(ED25519_KEYSZ)
955   def _pubkey(me, priv): return ed25519_pubkey(priv)
956   def sign(me, msg, **kw):
957     return ed25519_sign(me.priv, msg, pub = me.pub, **kw)
958
959 class Ed448Pub (_EdDSAPub):
960   _PUBSZ = KeySZSet(ED448_PUBSZ)
961   _HASH = shake256
962   def verify(me, msg, sig, **kw):
963     return ed448_verify(me.pub, msg, sig, **kw)
964
965 class Ed448Priv (_EdDSAPriv, Ed448Pub):
966   _KEYSZ = KeySZAny(ED448_KEYSZ)
967   def _pubkey(me, priv): return ed448_pubkey(priv)
968   def sign(me, msg, **kw):
969     return ed448_sign(me.priv, msg, pub = me.pub, **kw)
970
971 ###--------------------------------------------------------------------------
972 ### Built-in named curves and prime groups.
973
974 class _groupmap (object):
975   def __init__(me, map, nth):
976     me.map = map
977     me.nth = nth
978     me._n = max(map.values()) + 1
979     me.i = me._n*[None]
980   def __repr__(me):
981     return '{%s}' % ', '.join(['%r: %r' % kv for kv in me.iteritems()])
982   def _repr_pretty_(me, pp, cyclep):
983     ind = _pp_bgroup(pp, '{ ')
984     if cyclep: pp.text('...')
985     else: _pp_dict(pp, me.iteritems())
986     pp.end_group(ind, ' }')
987   def __len__(me):
988     return me._n
989   def __contains__(me, k):
990     return k in me.map
991   def __getitem__(me, k):
992     i = me.map[k]
993     if me.i[i] is None:
994       me.i[i] = me.nth(i)
995     return me.i[i]
996   def __setitem__(me, k, v):
997     raise TypeError, "immutable object"
998   def __iter__(me):
999     return iter(me.map)
1000   def iterkeys(me):
1001     return iter(me.map)
1002   def itervalues(me):
1003     for k in me:
1004       yield me[k]
1005   def iteritems(me):
1006     for k in me:
1007       yield k, me[k]
1008   def keys(me):
1009     return [k for k in me]
1010   def values(me):
1011     return [me[k] for k in me]
1012   def items(me):
1013     return [(k, me[k]) for k in me]
1014 eccurves = _groupmap(_base._eccurves, ECInfo._curven)
1015 primegroups = _groupmap(_base._pgroups, DHInfo._groupn)
1016 bingroups = _groupmap(_base._bingroups, BinDHInfo._groupn)
1017
1018 ###--------------------------------------------------------------------------
1019 ### Prime number generation.
1020
1021 class PrimeGenEventHandler (object):
1022   def pg_begin(me, ev):
1023     return me.pg_try(ev)
1024   def pg_done(me, ev):
1025     return PGEN_DONE
1026   def pg_abort(me, ev):
1027     return PGEN_TRY
1028   def pg_fail(me, ev):
1029     return PGEN_TRY
1030   def pg_pass(me, ev):
1031     return PGEN_TRY
1032
1033 class SophieGermainStepJump (object):
1034   def pg_begin(me, ev):
1035     me.lf = PrimeFilter(ev.x)
1036     me.hf = me.lf.muladd(2, 1)
1037     return me.cont(ev)
1038   def pg_try(me, ev):
1039     me.step()
1040     return me.cont(ev)
1041   def cont(me, ev):
1042     while me.lf.status == PGEN_FAIL or me.hf.status == PGEN_FAIL:
1043       me.step()
1044     if me.lf.status == PGEN_ABORT or me.hf.status == PGEN_ABORT:
1045       return PGEN_ABORT
1046     ev.x = me.lf.x
1047     if me.lf.status == PGEN_DONE and me.hf.status == PGEN_DONE:
1048       return PGEN_DONE
1049     return PGEN_TRY
1050   def pg_done(me, ev):
1051     del me.lf
1052     del me.hf
1053
1054 class SophieGermainStepper (SophieGermainStepJump):
1055   def __init__(me, step):
1056     me.lstep = step;
1057     me.hstep = 2 * step
1058   def step(me):
1059     me.lf.step(me.lstep)
1060     me.hf.step(me.hstep)
1061
1062 class SophieGermainJumper (SophieGermainStepJump):
1063   def __init__(me, jump):
1064     me.ljump = PrimeFilter(jump);
1065     me.hjump = me.ljump.muladd(2, 0)
1066   def step(me):
1067     me.lf.jump(me.ljump)
1068     me.hf.jump(me.hjump)
1069   def pg_done(me, ev):
1070     del me.ljump
1071     del me.hjump
1072     SophieGermainStepJump.pg_done(me, ev)
1073
1074 class SophieGermainTester (object):
1075   def __init__(me):
1076     pass
1077   def pg_begin(me, ev):
1078     me.lr = RabinMiller(ev.x)
1079     me.hr = RabinMiller(2 * ev.x + 1)
1080   def pg_try(me, ev):
1081     lst = me.lr.test(ev.rng.range(me.lr.x))
1082     if lst != PGEN_PASS and lst != PGEN_DONE:
1083       return lst
1084     rst = me.hr.test(ev.rng.range(me.hr.x))
1085     if rst != PGEN_PASS and rst != PGEN_DONE:
1086       return rst
1087     if lst == PGEN_DONE and rst == PGEN_DONE:
1088       return PGEN_DONE
1089     return PGEN_PASS
1090   def pg_done(me, ev):
1091     del me.lr
1092     del me.hr
1093
1094 class PrimitiveStepper (PrimeGenEventHandler):
1095   def __init__(me):
1096     pass
1097   def pg_try(me, ev):
1098     ev.x = me.i.next()
1099     return PGEN_TRY
1100   def pg_begin(me, ev):
1101     me.i = iter(smallprimes)
1102     return me.pg_try(ev)
1103
1104 class PrimitiveTester (PrimeGenEventHandler):
1105   def __init__(me, mod, hh = [], exp = None):
1106     me.mod = MPMont(mod)
1107     me.exp = exp
1108     me.hh = hh
1109   def pg_try(me, ev):
1110     x = ev.x
1111     if me.exp is not None:
1112       x = me.mod.exp(x, me.exp)
1113       if x == 1: return PGEN_FAIL
1114     for h in me.hh:
1115       if me.mod.exp(x, h) == 1: return PGEN_FAIL
1116     ev.x = x
1117     return PGEN_DONE
1118
1119 class SimulStepper (PrimeGenEventHandler):
1120   def __init__(me, mul = 2, add = 1, step = 2):
1121     me.step = step
1122     me.mul = mul
1123     me.add = add
1124   def _stepfn(me, step):
1125     if step <= 0:
1126       raise ValueError, 'step must be positive'
1127     if step <= MPW_MAX:
1128       return lambda f: f.step(step)
1129     j = PrimeFilter(step)
1130     return lambda f: f.jump(j)
1131   def pg_begin(me, ev):
1132     x = ev.x
1133     me.lf = PrimeFilter(x)
1134     me.hf = PrimeFilter(x * me.mul + me.add)
1135     me.lstep = me._stepfn(me.step)
1136     me.hstep = me._stepfn(me.step * me.mul)
1137     SimulStepper._cont(me, ev)
1138   def pg_try(me, ev):
1139     me._step()
1140     me._cont(ev)
1141   def _step(me):
1142     me.lstep(me.lf)
1143     me.hstep(me.hf)
1144   def _cont(me, ev):
1145     while me.lf.status == PGEN_FAIL or me.hf.status == PGEN_FAIL:
1146       me._step()
1147     if me.lf.status == PGEN_ABORT or me.hf.status == PGEN_ABORT:
1148       return PGEN_ABORT
1149     ev.x = me.lf.x
1150     if me.lf.status == PGEN_DONE and me.hf.status == PGEN_DONE:
1151       return PGEN_DONE
1152     return PGEN_TRY
1153   def pg_done(me, ev):
1154     del me.lf
1155     del me.hf
1156     del me.lstep
1157     del me.hstep
1158
1159 class SimulTester (PrimeGenEventHandler):
1160   def __init__(me, mul = 2, add = 1):
1161     me.mul = mul
1162     me.add = add
1163   def pg_begin(me, ev):
1164     x = ev.x
1165     me.lr = RabinMiller(x)
1166     me.hr = RabinMiller(x * me.mul + me.add)
1167   def pg_try(me, ev):
1168     lst = me.lr.test(ev.rng.range(me.lr.x))
1169     if lst != PGEN_PASS and lst != PGEN_DONE:
1170       return lst
1171     rst = me.hr.test(ev.rng.range(me.hr.x))
1172     if rst != PGEN_PASS and rst != PGEN_DONE:
1173       return rst
1174     if lst == PGEN_DONE and rst == PGEN_DONE:
1175       return PGEN_DONE
1176     return PGEN_PASS
1177   def pg_done(me, ev):
1178     del me.lr
1179     del me.hr
1180
1181 def sgprime(start, step = 2, name = 'p', event = pgen_nullev, nsteps = 0):
1182   start = MP(start)
1183   return pgen(start, name, SimulStepper(step = step), SimulTester(), event,
1184               nsteps, RabinMiller.iters(start.nbits))
1185
1186 def findprimitive(mod, hh = [], exp = None, name = 'g', event = pgen_nullev):
1187   return pgen(0, name, PrimitiveStepper(), PrimitiveTester(mod, hh, exp),
1188               event, 0, 1)
1189
1190 def kcdsaprime(pbits, qbits, rng = rand,
1191                event = pgen_nullev, name = 'p', nsteps = 0):
1192   hbits = pbits - qbits
1193   h = pgen(rng.mp(hbits, 1), name + ' [h]',
1194            PrimeGenStepper(2), PrimeGenTester(),
1195            event, nsteps, RabinMiller.iters(hbits))
1196   q = pgen(rng.mp(qbits, 1), name, SimulStepper(2 * h, 1, 2),
1197            SimulTester(2 * h, 1), event, nsteps, RabinMiller.iters(qbits))
1198   p = 2 * q * h + 1
1199   return p, q, h
1200
1201 #----- That's all, folks ----------------------------------------------------