chiark / gitweb /
catacomb/pwsafe.py, pwsafe: Make GDBM support conditional.
[catacomb-python] / catacomb / __init__.py
1 ### -*-python-*-
2 ###
3 ### Setup for Catacomb/Python bindings
4 ###
5 ### (c) 2004 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software; you can redistribute it and/or modify
13 ### it under the terms of the GNU General Public License as published by
14 ### the Free Software Foundation; either version 2 of the License, or
15 ### (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful,
18 ### but WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20 ### GNU General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python; if not, write to the Free Software Foundation,
24 ### Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25
26 import _base
27 import types as _types
28 from binascii import hexlify as _hexify, unhexlify as _unhexify
29 from sys import argv as _argv
30
31 ###--------------------------------------------------------------------------
32 ### Basic stuff.
33
34 ## For the benefit of the default keyreporter, we need the program na,e.
35 _base._ego(_argv[0])
36
37 ## Initialize the module.  Drag in the static methods of the various
38 ## classes; create names for the various known crypto algorithms.
39 def _init():
40   d = globals()
41   b = _base.__dict__;
42   for i in b:
43     if i[0] != '_':
44       d[i] = b[i];
45   for i in ['MP', 'GF', 'Field',
46             'ECPt', 'ECPtCurve', 'ECCurve', 'ECInfo',
47             'DHInfo', 'BinDHInfo', 'RSAPriv', 'BBSPriv',
48             'PrimeFilter', 'RabinMiller',
49             'Group', 'GE',
50             'KeyData']:
51     c = d[i]
52     pre = '_' + i + '_'
53     plen = len(pre)
54     for j in b:
55       if j[:plen] == pre:
56         setattr(c, j[plen:], classmethod(b[j]))
57   for i in [gcciphers, gchashes, gcmacs, gcprps]:
58     for c in i.itervalues():
59       d[c.name.replace('-', '_')] = c
60   for c in gccrands.itervalues():
61     d[c.name.replace('-', '_') + 'rand'] = c
62 _init()
63
64 ## A handy function for our work: add the methods of a named class to an
65 ## existing class.  This is how we write the Python-implemented parts of our
66 ## mostly-C types.
67 def _augment(c, cc):
68   for i in cc.__dict__:
69     a = cc.__dict__[i]
70     if type(a) is _types.MethodType:
71       a = a.im_func
72     elif type(a) not in (_types.FunctionType, staticmethod, classmethod):
73       continue
74     setattr(c, i, a)
75
76 ## Parsing functions tend to return the object parsed and the remainder of
77 ## the input.  This checks that the remainder is input and, if so, returns
78 ## just the object.
79 def _checkend(r):
80   x, rest = r
81   if rest != '':
82     raise SyntaxError, 'junk at end of string'
83   return x
84
85 ###--------------------------------------------------------------------------
86 ### Bytestrings.
87
88 class _tmp:
89   def fromhex(x):
90     return ByteString(_unhexify(x))
91   fromhex = staticmethod(fromhex)
92   def __hex__(me):
93     return _hexify(me)
94   def __repr__(me):
95     return 'bytes(%r)' % hex(me)
96 _augment(ByteString, _tmp)
97 bytes = ByteString.fromhex
98
99 ###--------------------------------------------------------------------------
100 ### Multiprecision integers and binary polynomials.
101
102 class _tmp:
103   def negp(x): return x < 0
104   def posp(x): return x > 0
105   def zerop(x): return x == 0
106   def oddp(x): return x.testbit(0)
107   def evenp(x): return not x.testbit(0)
108   def mont(x): return MPMont(x)
109   def barrett(x): return MPBarrett(x)
110   def reduce(x): return MPReduce(x)
111   def factorial(x):
112     'factorial(X) -> X!'
113     if x < 0: raise ValueError, 'factorial argument must be > 0'
114     return MPMul.product(xrange(1, x + 1))
115   factorial = staticmethod(factorial)
116 _augment(MP, _tmp)
117
118 class _tmp:
119   def zerop(x): return x == 0
120   def reduce(x): return GFReduce(x)
121   def trace(x, y): return x.reduce().trace(y)
122   def halftrace(x, y): return x.reduce().halftrace(y)
123   def modsqrt(x, y): return x.reduce().sqrt(y)
124   def quadsolve(x, y): return x.reduce().quadsolve(y)
125 _augment(GF, _tmp)
126
127 class _tmp:
128   def product(*arg):
129     'product(ITERABLE) or product(I, ...) -> PRODUCT'
130     return MPMul(*arg).done()
131   product = staticmethod(product)
132 _augment(MPMul, _tmp)
133
134 ###--------------------------------------------------------------------------
135 ### Abstract fields.
136
137 class _tmp:
138   def fromstring(str): return _checkend(Field.parse(str))
139   fromstring = staticmethod(fromstring)
140 _augment(Field, _tmp)
141
142 class _tmp:
143   def __repr__(me): return '%s(%sL)' % (type(me).__name__, me.p)
144   def ec(me, a, b): return ECPrimeProjCurve(me, a, b)
145 _augment(PrimeField, _tmp)
146
147 class _tmp:
148   def __repr__(me): return '%s(%sL)' % (type(me).__name__, hex(me.p))
149   def ec(me, a, b): return ECBinProjCurve(me, a, b)
150 _augment(BinField, _tmp)
151
152 class _tmp:
153   def __str__(me): return str(me.value)
154   def __repr__(me): return '%s(%s)' % (repr(me.field), repr(me.value))
155 _augment(FE, _tmp)
156
157 ###--------------------------------------------------------------------------
158 ### Elliptic curves.
159
160 class _tmp:
161   def __repr__(me):
162     return '%s(%r, %s, %s)' % (type(me).__name__, me.field, me.a, me.b)
163   def frombuf(me, s):
164     return ecpt.frombuf(me, s)
165   def fromraw(me, s):
166     return ecpt.fromraw(me, s)
167   def pt(me, *args):
168     return me(*args)
169 _augment(ECCurve, _tmp)
170
171 class _tmp:
172   def __repr__(me):
173     if not me: return 'ECPt()'
174     return 'ECPt(%s, %s)' % (me.ix, me.iy)
175   def __str__(me):
176     if not me: return 'inf'
177     return '(%s, %s)' % (me.ix, me.iy)
178 _augment(ECPt, _tmp)
179
180 class _tmp:
181   def __repr__(me):
182     return 'ECInfo(curve = %r, G = %r, r = %s, h = %s)' % \
183            (me.curve, me.G, me.r, me.h)
184   def group(me):
185     return ECGroup(me)
186 _augment(ECInfo, _tmp)
187
188 class _tmp:
189   def __repr__(me):
190     if not me: return '%r()' % (me.curve)
191     return '%r(%s, %s)' % (me.curve, me.x, me.y)
192   def __str__(me):
193     if not me: return 'inf'
194     return '(%s, %s)' % (me.x, me.y)
195 _augment(ECPtCurve, _tmp)
196
197 ###--------------------------------------------------------------------------
198 ### Key sizes.
199
200 class _tmp:
201   def __repr__(me): return 'KeySZAny(%d)' % me.default
202   def check(me, sz): return True
203   def best(me, sz): return sz
204 _augment(KeySZAny, _tmp)
205
206 class _tmp:
207   def __repr__(me):
208     return 'KeySZRange(%d, %d, %d, %d)' % \
209            (me.default, me.min, me.max, me.mod)
210   def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0
211   def best(me, sz):
212     if sz < me.min: raise ValueError, 'key too small'
213     elif sz > me.max: return me.max
214     else: return sz - (sz % me.mod)
215 _augment(KeySZRange, _tmp)
216
217 class _tmp:
218   def __repr__(me): return 'KeySZSet(%d, %s)' % (me.default, me.set)
219   def check(me, sz): return sz in me.set
220   def best(me, sz):
221     found = -1
222     for i in me.set:
223       if found < i <= sz: found = i
224     if found < 0: raise ValueError, 'key too small'
225     return found
226 _augment(KeySZSet, _tmp)
227
228 ###--------------------------------------------------------------------------
229 ### Abstract groups.
230
231 class _tmp:
232   def __repr__(me):
233     return '%s(p = %s, r = %s, g = %s)' % \
234            (type(me).__name__, me.p, me.r, me.g)
235 _augment(FGInfo, _tmp)
236
237 class _tmp:
238   def group(me): return PrimeGroup(me)
239 _augment(DHInfo, _tmp)
240
241 class _tmp:
242   def group(me): return BinGroup(me)
243 _augment(BinDHInfo, _tmp)
244
245 class _tmp:
246   def __repr__(me):
247     return '%s(%r)' % (type(me).__name__, me.info)
248 _augment(Group, _tmp)
249
250 class _tmp:
251   def __repr__(me):
252     return '%r(%r)' % (me.group, str(me))
253 _augment(GE, _tmp)
254
255 ###--------------------------------------------------------------------------
256 ### RSA encoding techniques.
257
258 class PKCS1Crypt (object):
259   def __init__(me, ep = '', rng = rand):
260     me.ep = ep
261     me.rng = rng
262   def encode(me, msg, nbits):
263     return _base._p1crypt_encode(msg, nbits, me.ep, me.rng)
264   def decode(me, ct, nbits):
265     return _base._p1crypt_decode(ct, nbits, me.ep, me.rng)
266
267 class PKCS1Sig (object):
268   def __init__(me, ep = '', rng = rand):
269     me.ep = ep
270     me.rng = rng
271   def encode(me, msg, nbits):
272     return _base._p1sig_encode(msg, nbits, me.ep, me.rng)
273   def decode(me, msg, sig, nbits):
274     return _base._p1sig_decode(msg, sig, nbits, me.ep, me.rng)
275
276 class OAEP (object):
277   def __init__(me, mgf = sha_mgf, hash = sha, ep = '', rng = rand):
278     me.mgf = mgf
279     me.hash = hash
280     me.ep = ep
281     me.rng = rng
282   def encode(me, msg, nbits):
283     return _base._oaep_encode(msg, nbits, me.mgf, me.hash, me.ep, me.rng)
284   def decode(me, ct, nbits):
285     return _base._oaep_decode(ct, nbits, me.mgf, me.hash, me.ep, me.rng)
286
287 class PSS (object):
288   def __init__(me, mgf = sha_mgf, hash = sha, saltsz = None, rng = rand):
289     me.mgf = mgf
290     me.hash = hash
291     if saltsz is None:
292       saltsz = hash.hashsz
293     me.saltsz = saltsz
294     me.rng = rng
295   def encode(me, msg, nbits):
296     return _base._pss_encode(msg, nbits, me.mgf, me.hash, me.saltsz, me.rng)
297   def decode(me, msg, sig, nbits):
298     return _base._pss_decode(msg, sig, nbits,
299                              me.mgf, me.hash, me.saltsz, me.rng)
300
301 class _tmp:
302   def encrypt(me, msg, enc):
303     return me.pubop(enc.encode(msg, me.n.nbits))
304   def verify(me, msg, sig, enc):
305     if msg is None: return enc.decode(msg, me.pubop(sig), me.n.nbits)
306     try:
307       x = enc.decode(msg, me.pubop(sig), me.n.nbits)
308       return x is None or x == msg
309     except ValueError:
310       return False
311 _augment(RSAPub, _tmp)
312
313 class _tmp:
314   def decrypt(me, ct, enc): return enc.decode(me.privop(ct), me.n.nbits)
315   def sign(me, msg, enc): return me.privop(enc.encode(msg, me.n.nbits))
316 _augment(RSAPriv, _tmp)
317
318 ###--------------------------------------------------------------------------
319 ### Built-in named curves and prime groups.
320
321 class _groupmap (object):
322   def __init__(me, map, nth):
323     me.map = map
324     me.nth = nth
325     me.i = [None] * (max(map.values()) + 1)
326   def __repr__(me):
327     return '{%s}' % ', '.join(['%r: %r' % (k, me[k]) for k in me])
328   def __contains__(me, k):
329     return k in me.map
330   def __getitem__(me, k):
331     i = me.map[k]
332     if me.i[i] is None:
333       me.i[i] = me.nth(i)
334     return me.i[i]
335   def __setitem__(me, k, v):
336     raise TypeError, "immutable object"
337   def __iter__(me):
338     return iter(me.map)
339   def iterkeys(me):
340     return iter(me.map)
341   def itervalues(me):
342     for k in me:
343       yield me[k]
344   def iteritems(me):
345     for k in me:
346       yield k, me[k]
347   def keys(me):
348     return [k for k in me]
349   def values(me):
350     return [me[k] for k in me]
351   def items(me):
352     return [(k, me[k]) for k in me]
353 eccurves = _groupmap(_base._eccurves, ECInfo._curven)
354 primegroups = _groupmap(_base._pgroups, DHInfo._groupn)
355 bingroups = _groupmap(_base._bingroups, BinDHInfo._groupn)
356
357 ###--------------------------------------------------------------------------
358 ### Prime number generation.
359
360 class PrimeGenEventHandler (object):
361   def pg_begin(me, ev):
362     return me.pg_try(ev)
363   def pg_done(me, ev):
364     return PGEN_DONE
365   def pg_abort(me, ev):
366     return PGEN_TRY
367   def pg_fail(me, ev):
368     return PGEN_TRY
369   def pg_pass(me, ev):
370     return PGEN_TRY
371
372 class SophieGermainStepJump (object):
373   def pg_begin(me, ev):
374     me.lf = PrimeFilter(ev.x)
375     me.hf = me.lf.muladd(2, 1)
376     return me.cont(ev)
377   def pg_try(me, ev):
378     me.step()
379     return me.cont(ev)
380   def cont(me, ev):
381     while me.lf.status == PGEN_FAIL or me.hf.status == PGEN_FAIL:
382       me.step()
383     if me.lf.status == PGEN_ABORT or me.hf.status == PGEN_ABORT:
384       return PGEN_ABORT
385     ev.x = me.lf.x
386     if me.lf.status == PGEN_DONE and me.hf.status == PGEN_DONE:
387       return PGEN_DONE
388     return PGEN_TRY
389   def pg_done(me, ev):
390     del me.lf
391     del me.hf
392
393 class SophieGermainStepper (SophieGermainStepJump):
394   def __init__(me, step):
395     me.lstep = step;
396     me.hstep = 2 * step
397   def step(me):
398     me.lf.step(me.lstep)
399     me.hf.step(me.hstep)
400
401 class SophieGermainJumper (SophieGermainStepJump):
402   def __init__(me, jump):
403     me.ljump = PrimeFilter(jump);
404     me.hjump = me.ljump.muladd(2, 0)
405   def step(me):
406     me.lf.jump(me.ljump)
407     me.hf.jump(me.hjump)
408   def pg_done(me, ev):
409     del me.ljump
410     del me.hjump
411     SophieGermainStepJump.pg_done(me, ev)
412
413 class SophieGermainTester (object):
414   def __init__(me):
415     pass
416   def pg_begin(me, ev):
417     me.lr = RabinMiller(ev.x)
418     me.hr = RabinMiller(2 * ev.x + 1)
419   def pg_try(me, ev):
420     lst = me.lr.test(ev.rng.range(me.lr.x))
421     if lst != PGEN_PASS and lst != PGEN_DONE:
422       return lst
423     rst = me.hr.test(ev.rng.range(me.hr.x))
424     if rst != PGEN_PASS and rst != PGEN_DONE:
425       return rst
426     if lst == PGEN_DONE and rst == PGEN_DONE:
427       return PGEN_DONE
428     return PGEN_PASS
429   def pg_done(me, ev):
430     del me.lr
431     del me.hr
432
433 class PrimitiveStepper (PrimeGenEventHandler):
434   def __init__(me):
435     pass
436   def pg_try(me, ev):
437     ev.x = me.i.next()
438     return PGEN_TRY
439   def pg_begin(me, ev):
440     me.i = iter(smallprimes)
441     return me.pg_try(ev)
442
443 class PrimitiveTester (PrimeGenEventHandler):
444   def __init__(me, mod, hh = [], exp = None):
445     me.mod = MPMont(mod)
446     me.exp = exp
447     me.hh = hh
448   def pg_try(me, ev):
449     x = ev.x
450     if me.exp is not None:
451       x = me.mod.exp(x, me.exp)
452       if x == 1: return PGEN_FAIL
453     for h in me.hh:
454       if me.mod.exp(x, h) == 1: return PGEN_FAIL
455     ev.x = x
456     return PGEN_DONE
457
458 class SimulStepper (PrimeGenEventHandler):
459   def __init__(me, mul = 2, add = 1, step = 2):
460     me.step = step
461     me.mul = mul
462     me.add = add
463   def _stepfn(me, step):
464     if step <= 0:
465       raise ValueError, 'step must be positive'
466     if step <= MPW_MAX:
467       return lambda f: f.step(step)
468     j = PrimeFilter(step)
469     return lambda f: f.jump(j)
470   def pg_begin(me, ev):
471     x = ev.x
472     me.lf = PrimeFilter(x)
473     me.hf = PrimeFilter(x * me.mul + me.add)
474     me.lstep = me._stepfn(me.step)
475     me.hstep = me._stepfn(me.step * me.mul)
476     SimulStepper._cont(me, ev)
477   def pg_try(me, ev):
478     me._step()
479     me._cont(ev)
480   def _step(me):
481     me.lstep(me.lf)
482     me.hstep(me.hf)
483   def _cont(me, ev):
484     while me.lf.status == PGEN_FAIL or me.hf.status == PGEN_FAIL:
485       me._step()
486     if me.lf.status == PGEN_ABORT or me.hf.status == PGEN_ABORT:
487       return PGEN_ABORT
488     ev.x = me.lf.x
489     if me.lf.status == PGEN_DONE and me.hf.status == PGEN_DONE:
490       return PGEN_DONE
491     return PGEN_TRY
492   def pg_done(me, ev):
493     del me.lf
494     del me.hf
495     del me.lstep
496     del me.hstep
497
498 class SimulTester (PrimeGenEventHandler):
499   def __init__(me, mul = 2, add = 1):
500     me.mul = mul
501     me.add = add
502   def pg_begin(me, ev):
503     x = ev.x
504     me.lr = RabinMiller(x)
505     me.hr = RabinMiller(x * me.mul + me.add)
506   def pg_try(me, ev):
507     lst = me.lr.test(ev.rng.range(me.lr.x))
508     if lst != PGEN_PASS and lst != PGEN_DONE:
509       return lst
510     rst = me.hr.test(ev.rng.range(me.hr.x))
511     if rst != PGEN_PASS and rst != PGEN_DONE:
512       return rst
513     if lst == PGEN_DONE and rst == PGEN_DONE:
514       return PGEN_DONE
515     return PGEN_PASS
516   def pg_done(me, ev):
517     del me.lr
518     del me.hr
519
520 def sgprime(start, step = 2, name = 'p', event = pgen_nullev, nsteps = 0):
521   start = MP(start)
522   return pgen(start, name, SimulStepper(step = step), SimulTester(), event,
523               nsteps, RabinMiller.iters(start.nbits))
524
525 def findprimitive(mod, hh = [], exp = None, name = 'g', event = pgen_nullev):
526   return pgen(0, name, PrimitiveStepper(), PrimitiveTester(mod, hh, exp),
527               event, 0, 1)
528
529 def kcdsaprime(pbits, qbits, rng = rand,
530                event = pgen_nullev, name = 'p', nsteps = 0):
531   hbits = pbits - qbits
532   h = pgen(rng.mp(hbits, 1), name + ' [h]',
533            PrimeGenStepper(2), PrimeGenTester(),
534            event, nsteps, RabinMiller.iters(hbits))
535   q = pgen(rng.mp(qbits, 1), name, SimulStepper(2 * h, 1, 2),
536            SimulTester(2 * h, 1), event, nsteps, RabinMiller.iters(qbits))
537   p = 2 * q * h + 1
538   return p, q, h
539
540 #----- That's all, folks ----------------------------------------------------