chiark / gitweb /
catacomb/__init__.py (KeySZRange.pad): Return correct value.
[catacomb-python] / catacomb / __init__.py
index 1a2b5f44dae692137e73a0edb9162a88d351a452..65695bb89bc2dfa4af84f30f0b647254a5c0d2c9 100644 (file)
 
 from __future__ import with_statement
 
-import _base
-import types as _types
 from binascii import hexlify as _hexify, unhexlify as _unhexify
 from contextlib import contextmanager as _ctxmgr
-from sys import argv as _argv
+try: import DLFCN as _dlfcn
+except ImportError: _dlfcn = None
+import os as _os
 from struct import pack as _pack
+import sys as _sys
+import types as _types
+
+###--------------------------------------------------------------------------
+### Import the main C extension module.
+
+try:
+  _dlflags = _odlflags = _sys.getdlopenflags()
+except AttributeError:
+  _dlflags = _odlflags = -1
+
+## Set the `deep binding' flag.  Python has its own different MD5
+## implementation, and some distributions export `md5_init' and friends so
+## they override our versions, which doesn't end well.  Figure out how to
+## turn this flag on so we don't have the problem.
+if _dlflags >= 0:
+  try: _dlflags |= _dlfcn.RTLD_DEEPBIND
+  except AttributeError:
+    try: _dlflags |= _os.RTLD_DEEPBIND
+    except AttributeError:
+      if _os.uname()[0] == 'Linux': _dlflags |= 8 # magic knowledge
+      else: pass # can't do this.
+  _sys.setdlopenflags(_dlflags)
+
+import _base
+
+if _odlflags >= 0:
+  _sys.setdlopenflags(_odlflags)
+
+del _dlflags, _odlflags
 
 ###--------------------------------------------------------------------------
 ### Basic stuff.
 
-## For the benefit of the default keyreporter, we need the program na,e.
-_base._ego(_argv[0])
+## For the benefit of the default keyreporter, we need the program name.
+_base._ego(_sys.argv[0])
+
+## Register our module.
+_base._set_home_module(_sys.modules[__name__])
+def default_lostexchook(why, ty, val, tb):
+  """`catacomb.lostexchook(WHY, TY, VAL, TB)' reports lost exceptions."""
+  _sys.stderr.write("\n\n!!! LOST EXCEPTION: %s\n" % why)
+  _sys.excepthook(ty, val, tb)
+  _sys.stderr.write("\n")
+lostexchook = default_lostexchook
 
 ## How to fix a name back into the right identifier.  Alas, the rules are not
 ## consistent.
@@ -46,7 +85,7 @@ def _fixname(name):
   name = name.replace('-', '_')
 
   ## But slashes might become underscores or just vanish.
-  if name.startswith('salsa20'): name = name.translate(None, '/')
+  if name.startswith('salsa20'): name = name.replace('/', '')
   else: name = name.replace('/', '_')
 
   ## Done.
@@ -73,7 +112,7 @@ def _init():
     for j in b:
       if j[:plen] == pre:
         setattr(c, j[plen:], classmethod(b[j]))
-  for i in [gcciphers, gchashes, gcmacs, gcprps]:
+  for i in [gcciphers, gcaeads, gchashes, gcmacs, gcprps]:
     for c in i.itervalues():
       d[_fixname(c.name)] = c
   for c in gccrands.itervalues():
@@ -154,6 +193,27 @@ _augment(ByteString, _tmp)
 ByteString.__hash__ = str.__hash__
 bytes = ByteString.fromhex
 
+###--------------------------------------------------------------------------
+### Symmetric encryption.
+
+class _tmp:
+  def encrypt(me, n, m, tsz = None, h = ByteString('')):
+    if tsz is None: tsz = me.__class__.tagsz.default
+    e = me.enc(n, len(h), len(m), tsz)
+    if not len(h): a = None
+    else: a = e.aad().hash(h)
+    c0 = e.encrypt(m)
+    c1, t = e.done(aad = a)
+    return c0 + c1, t
+  def decrypt(me, n, c, t, h = ByteString('')):
+    d = me.dec(n, len(h), len(c), len(t))
+    if not len(h): a = None
+    else: a = d.aad().hash(h)
+    m = d.decrypt(c)
+    m += d.done(t, aad = a)
+    return m
+_augment(GAEKey, _tmp)
+
 ###--------------------------------------------------------------------------
 ### Hashing.
 
@@ -167,15 +227,31 @@ _augment(Poly1305Hash, _tmp)
 class _HashBase (object):
   ## The standard hash methods.  Assume that `hash' is defined and returns
   ## the receiver.
-  def hashu8(me, n): return me.hash(_pack('B', n))
-  def hashu16l(me, n): return me.hash(_pack('<H', n))
-  def hashu16b(me, n): return me.hash(_pack('>H', n))
+  def _check_range(me, n, max):
+    if not (0 <= n <= max): raise OverflowError("out of range")
+  def hashu8(me, n):
+    me._check_range(n, 0xff)
+    return me.hash(_pack('B', n))
+  def hashu16l(me, n):
+    me._check_range(n, 0xffff)
+    return me.hash(_pack('<H', n))
+  def hashu16b(me, n):
+    me._check_range(n, 0xffff)
+    return me.hash(_pack('>H', n))
   hashu16 = hashu16b
-  def hashu32l(me, n): return me.hash(_pack('<L', n))
-  def hashu32b(me, n): return me.hash(_pack('>L', n))
+  def hashu32l(me, n):
+    me._check_range(n, 0xffffffff)
+    return me.hash(_pack('<L', n))
+  def hashu32b(me, n):
+    me._check_range(n, 0xffffffff)
+    return me.hash(_pack('>L', n))
   hashu32 = hashu32b
-  def hashu64l(me, n): return me.hash(_pack('<Q', n))
-  def hashu64b(me, n): return me.hash(_pack('>Q', n))
+  def hashu64l(me, n):
+    me._check_range(n, 0xffffffffffffffff)
+    return me.hash(_pack('<Q', n))
+  def hashu64b(me, n):
+    me._check_range(n, 0xffffffffffffffff)
+    return me.hash(_pack('>Q', n))
   hashu64 = hashu64b
   def hashbuf8(me, s): return me.hashu8(len(s)).hash(s)
   def hashbuf16l(me, s): return me.hashu16l(len(s)).hash(s)
@@ -198,8 +274,8 @@ class _ShakeBase (_HashBase):
     me._h = me._SHAKE(perso = perso, func = me._FUNC)
 
   ## Delegate methods...
-  def copy(me): new = me.__class__(); new._copy(me)
-  def _copy(me, other): me._h = other._h
+  def copy(me): new = me.__class__._bare_new(); new._copy(me); return new
+  def _copy(me, other): me._h = other._h.copy()
   def hash(me, m): me._h.hash(m); return me
   def xof(me): me._h.xof(); return me
   def get(me, n): return me._h.get(n)
@@ -212,6 +288,8 @@ class _ShakeBase (_HashBase):
   def buffered(me): return me._h.buffered
   @property
   def rate(me): return me._h.rate
+  @classmethod
+  def _bare_new(cls): return cls()
 
 class _tmp:
   def check(me, h):
@@ -236,7 +314,7 @@ class _tmp:
     me.bytepad_after()
 _augment(Shake, _tmp)
 _augment(_ShakeBase, _tmp)
-Shake._Z = _ShakeBase._Z = ByteString(200*'\0')
+Shake._Z = _ShakeBase._Z = ByteString.zero(200)
 
 class KMAC (_ShakeBase):
   _FUNC = 'KMAC'
@@ -250,6 +328,8 @@ class KMAC (_ShakeBase):
   def xof(me):
     me.rightenc(0)
     return super(KMAC, me).xof()
+  @classmethod
+  def _bare_new(cls): return cls("")
 
 class KMAC128 (KMAC): _SHAKE = Shake128; _TAGSZ = 16
 class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
@@ -258,21 +338,12 @@ class KMAC256 (KMAC): _SHAKE = Shake256; _TAGSZ = 32
 ### NaCl `secretbox'.
 
 def secret_box(k, n, m):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = E.encrypt(m)
-  t = poly1305(r)(s).hash(y).done()
-  return ByteString(t + y)
+  y, t = salsa20_naclbox(k).encrypt(n, m)
+  return t + y
 
 def secret_unbox(k, n, c):
-  E = xsalsa20(k).setiv(n)
-  r = E.enczero(poly1305.keysz.default)
-  s = E.enczero(poly1305.masksz)
-  y = c[poly1305.tagsz:]
-  if not poly1305(r)(s).hash(y).check(c[0:poly1305.tagsz]):
-    raise ValueError, 'decryption failed'
-  return E.decrypt(c[poly1305.tagsz:])
+  tsz = poly1305.tagsz
+  return salsa20_naclbox(k).decrypt(n, c[tsz:], c[0:tsz])
 
 ###--------------------------------------------------------------------------
 ### Multiprecision integers and binary polynomials.
@@ -312,15 +383,18 @@ class BaseRat (object):
   def __mul__(me, you):
     n, d = _split_rat(you)
     return type(me)(me._n*n, me._d*d)
-  def __div__(me, you):
+  __rmul__ = __mul__
+  def __truediv__(me, you):
     n, d = _split_rat(you)
     return type(me)(me._n*d, me._d*n)
-  def __rdiv__(me, you):
+  def __rtruediv__(me, you):
     n, d = _split_rat(you)
     return type(me)(me._d*n, me._n*d)
+  __div__ = __truediv__
+  __rdiv__ = __rtruediv__
   def __cmp__(me, you):
     n, d = _split_rat(you)
-    return type(me)(me._n*d, n*me._d)
+    return cmp(me._n*d, n*me._d)
   def __rcmp__(me, you):
     n, d = _split_rat(you)
     return cmp(n*me._d, me._n*d)
@@ -340,8 +414,10 @@ class _tmp:
   def mont(x): return MPMont(x)
   def barrett(x): return MPBarrett(x)
   def reduce(x): return MPReduce(x)
-  def __div__(me, you): return IntRat(me, you)
-  def __rdiv__(me, you): return IntRat(you, me)
+  def __truediv__(me, you): return IntRat(me, you)
+  def __rtruediv__(me, you): return IntRat(you, me)
+  __div__ = __truediv__
+  __rdiv__ = __rtruediv__
   _repr_pretty_ = _pp_str
 _augment(MP, _tmp)
 
@@ -352,8 +428,10 @@ class _tmp:
   def halftrace(x, y): return x.reduce().halftrace(y)
   def modsqrt(x, y): return x.reduce().sqrt(y)
   def quadsolve(x, y): return x.reduce().quadsolve(y)
-  def __div__(me, you): return GFRat(me, you)
-  def __rdiv__(me, you): return GFRat(you, me)
+  def __truediv__(me, you): return GFRat(me, you)
+  def __rtruediv__(me, you): return GFRat(you, me)
+  __div__ = __truediv__
+  __rdiv__ = __rtruediv__
   _repr_pretty_ = _pp_str
 _augment(GF, _tmp)
 
@@ -520,6 +598,7 @@ class _tmp:
   def __repr__(me): return '%s(%d)' % (_clsname(me), me.default)
   def check(me, sz): return True
   def best(me, sz): return sz
+  def pad(me, sz): return sz
 _augment(KeySZAny, _tmp)
 
 class _tmp:
@@ -536,11 +615,15 @@ class _tmp:
       pp.pretty(me.max); pp.text(','); pp.breakable()
       pp.pretty(me.mod)
     pp.end_group(ind, ')')
-  def check(me, sz): return me.min <= sz <= me.max and sz % me.mod == 0
+  def check(me, sz): return me.min <= sz <= me.max and sz%me.mod == 0
   def best(me, sz):
     if sz < me.min: raise ValueError, 'key too small'
     elif sz > me.max: return me.max
-    else: return sz - (sz % me.mod)
+    else: return sz - sz%me.mod
+  def pad(me, sz):
+    if sz > me.max: raise ValueError, 'key too large'
+    elif sz < me.min: return me.min
+    else: sz += me.mod - 1; return sz - sz%me.mod
 _augment(KeySZRange, _tmp)
 
 class _tmp:
@@ -562,6 +645,12 @@ class _tmp:
       if found < i <= sz: found = i
     if found < 0: raise ValueError, 'key too small'
     return found
+  def pad(me, sz):
+    found = -1
+    for i in me.set:
+      if sz <= i and (found == -1 or i < found): found = i
+    if found < 0: raise ValueError, 'key too large'
+    return found
 _augment(KeySZSet, _tmp)
 
 ###--------------------------------------------------------------------------