chiark / gitweb /
t/: Add a test suite.
[catacomb-python] / t / testutils.py
diff --git a/t/testutils.py b/t/testutils.py
new file mode 100644 (file)
index 0000000..42bbd76
--- /dev/null
@@ -0,0 +1,275 @@
+### -*- mode: python, coding: utf-8 -*-
+###
+### Test utilities
+###
+### (c) 2019 Straylight/Edgeware
+###
+
+###----- Licensing notice ---------------------------------------------------
+###
+### This file is part of the Python interface to Catacomb.
+###
+### Catacomb/Python is free software: you can redistribute it and/or
+### modify it under the terms of the GNU General Public License as
+### published by the Free Software Foundation; either version 2 of the
+### License, or (at your option) any later version.
+###
+### Catacomb/Python is distributed in the hope that it will be useful, but
+### WITHOUT ANY WARRANTY; without even the implied warranty of
+### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+### General Public License for more details.
+###
+### You should have received a copy of the GNU General Public License
+### along with Catacomb/Python.  If not, write to the Free Software
+### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
+### USA.
+
+###--------------------------------------------------------------------------
+### Imported modules.
+
+import catacomb as C
+import sys as SYS
+if SYS.version_info >= (3,): import builtins as B
+else: import __builtin__ as B
+import unittest as U
+
+###--------------------------------------------------------------------------
+### Main code.
+
+## Some compatibility hacks.
+import itertools as I
+def bin(x): return x
+range = xrange
+long = long
+imap = I.imap
+def byteseq(seq): return "".join(map(chr, seq))
+def iterkeys(m): return m.iterkeys()
+def itervalues(m): return m.itervalues()
+def iteritems(m): return m.iteritems()
+from cStringIO import StringIO
+MAXFIXNUM = SYS.maxint
+
+DEBUGP = hasattr(SYS, "gettotalrefcount")
+
+FULLSPAN = byteseq(range(256))
+def span(n):
+  """A string `00 01 .. NN'."""
+  return (n >> 8)*FULLSPAN + FULLSPAN[:n&255]
+
+def bytes_as_int(w, bigendp):
+  """Convert the byte-sequence `01 02 ... WW' to an integer."""
+  x = 0
+  if bigendp:
+    for i in range(w): x = x << 8 | i + 1
+  else:
+    for i in range(w): x |= i + 1 << 8*i
+  return x
+
+def prep_lenseq(w, n, bigendp, goodp):
+  """
+  Return a reference buffer containing `00 LL .. LL 00 01 02 .. NN ff'.
+
+  Here, LL .. LL is the length of following sequence, not including the final
+  `ff', as a W-byte integer.  If GOODP is false, then the most significant
+  bit of LL .. LL is set, to provoke an overflow.
+  """
+  if goodp: l = n
+  else: l = n + (1 << 8*w - 1)
+  lenbyte = bigendp \
+    and (lambda i: (l >> 8*(w - i - 1))&0xff) \
+    or (lambda i: (l >> 8*i)&0xff)
+  return byteseq([0x00]) + \
+    byteseq([lenbyte(i) for i in range(w)]) + \
+    span(n) + \
+    byteseq([0xff])
+
+Z64 = C.ByteString.zero(8)
+def detrand(seed):
+  """Return a fast deterministic random generator with the given SEED."""
+  return C.chacha8rand(C.sha256().hash(bin(seed)).done(), Z64)
+
+class GenericTestMixin (U.TestCase):
+  """
+  A mixin class to generate test-case functions for all similar things.
+  """
+
+  @classmethod
+  def generate_testcases(cls, things):
+    testfns = dict()
+    checkfns = []
+    for k, v in iteritems(cls.__dict__):
+      if k.startswith("_test_"): checkfns.append((k[6:], v))
+    for name, thing in things:
+      for test, checkfn in checkfns:
+        testfn = lambda me, thing = thing: checkfn(me, thing)
+        doc = getattr(checkfn, "__doc__", None)
+        if doc is not None: testfn.__doc__ = doc % name
+        testfns["test_%s%%%s" % (test, name)] = testfn
+    tmpcls =  type("_tmp", (cls,), testfns)
+    for k, v in iteritems(tmpcls.__dict__):
+      if k.startswith("test_"): setattr(cls, k, v)
+
+class ImmutableMappingTextMixin (U.TestCase):
+
+  ## Subclass stubs.
+  def _mkkey(me, i): return "k#%d" % i
+  def _getkey(me, k): return int(k[2:])
+  def _getvalue(me, v): return int(v[2:])
+  def _getitem(me, it): k, v = it; return me._getkey(k), me._getvalue(v)
+
+  def check_immutable_mapping(me, map, model):
+
+    ## Lookup.
+    limk = 0
+    any = False
+    me.assertEqual(len(map), len(model))
+    for k, v in iteritems(model):
+      any = True
+      if k >= limk: limk = k + 1
+      me.assertTrue(me._mkkey(k) in map)
+      me.assertTrue(map.has_key(me._mkkey(k)))
+      me.assertEqual(me._getvalue(map[me._mkkey(k)]), v)
+      me.assertEqual(me._getvalue(map.get(me._mkkey(k))), v)
+    if any: me.assertTrue(me._mkkey(k) in map)
+    me.assertFalse(map.has_key(me._mkkey(limk)))
+    me.assertRaises(KeyError, lambda: map[me._mkkey(limk)])
+    me.assertEqual(map.get(me._mkkey(limk)), None)
+    for listfn, getfn in [(lambda x: x.keys(), me._getkey),
+                          (lambda x: x.values(), me._getvalue),
+                          (lambda x: x.items(), me._getitem)]:
+      rlist, mlist = listfn(map), listfn(model)
+      me.assertEqual(type(rlist), list)
+      rlist = B.map(getfn, rlist)
+      rlist.sort(); mlist.sort(); me.assertEqual(rlist, mlist)
+    for iterfn, getfn in [(lambda x: x.iterkeys(), me._getkey),
+                          (lambda x: x.itervalues(), me._getvalue),
+                          (lambda x: x.iteritems(), me._getitem)]:
+      me.assertEqual(set(imap(getfn, iterfn(map))), set(iterfn(model)))
+
+class MutableMappingTestMixin (ImmutableMappingTextMixin):
+
+  ## Subclass stubs.
+  def _mkvalue(me, i): return "v#%d" % i
+
+  def check_mapping(me, emptymapfn):
+
+    map = emptymapfn()
+    me.assertEqual(len(map), 0)
+
+    def check_views():
+      me.check_immutable_mapping(map, model)
+
+    model = { 1: 101, 2: 202, 4: 404 }
+    for k, v in iteritems(model): map[me._mkkey(k)] = me._mkvalue(v)
+    check_views()
+
+    model.update({ 2: 212, 6: 606, 7: 707 })
+    map.update({ me._mkkey(2): me._mkvalue(212),
+                 me._mkkey(6): me._mkvalue(606),
+                 me._mkkey(7): me._mkvalue(707) })
+    check_views()
+
+    model[9] = 909
+    map[me._mkkey(9)] = me._mkvalue(909)
+    check_views()
+
+    model[9] = 919
+    map[me._mkkey(9)] = me._mkvalue(919)
+    check_views()
+
+    map.setdefault(me._mkkey(9), me._mkvalue(929))
+    check_views()
+
+    model[8] = 808
+    map.setdefault(me._mkkey(8), me._mkvalue(808))
+    check_views()
+
+    me.assertRaises(KeyError, map.pop, me._mkkey(5))
+    obj = object()
+    me.assertEqual(map.pop(me._mkkey(5), obj), obj)
+    me.assertEqual(me._getvalue(map.pop(me._mkkey(8))), 808)
+    del model[8]
+    check_views()
+
+    del model[9]
+    del map[me._mkkey(9)]
+    check_views()
+
+    k, v = map.popitem()
+    mk, mv = me._getkey(k), me._getvalue(v)
+    me.assertEqual(model[mk], mv)
+    del model[mk]
+    check_views()
+
+    map.clear()
+    model = {}
+    check_views()
+
+class Explosion (Exception): pass
+
+class EventRecorder (C.PrimeGenEventHandler):
+  def __init__(me, parent = None, explode_after = None, *args, **kw):
+    super(EventRecorder, me).__init__(*args, **kw)
+    me._streak = 0
+    me._op = None
+    me._parent = parent
+    me._countdown = explode_after
+    me.rng = None
+    if parent is None: me._buf = StringIO()
+    else: me._buf = parent._buf
+  def _event_common(me, ev):
+    if me.rng is None: me.rng = ev.rng
+    if me._countdown is None: pass
+    elif me._countdown == 0: raise Explosion()
+    else: me._countdown -= 1
+  def _put(me, op):
+    if op == me._op:
+      me._streak += 1
+    else:
+      if me._op is not None: me._buf.write("%s%d/" % (me._op, me._streak))
+      me._op = op
+      me._streak = 1
+  def pg_begin(me, ev):
+    me._event_common(ev)
+    me._buf.write("[%s:" % ev.name)
+  def pg_try(me, ev):
+    me._event_common(ev)
+  def pg_fail(me, ev):
+    me._event_common(ev)
+    me._put("F")
+  def pg_pass(me, ev):
+    me._event_common(ev)
+    me._put("P")
+  def pg_done(me, ev):
+    me._event_common(ev)
+    me._put(None); me._buf.write("D]")
+  def pg_abort(me, ev):
+    me._event_common(ev)
+    me._put(None); me._buf.write("A]")
+  @property
+  def events(me):
+    return me._buf.getvalue()
+
+## Functions for operators.
+neg = lambda x: -x
+pos = lambda x: +x
+add = lambda x, y: x + y
+sub = lambda x, y: x - y
+mul = lambda x, y: x*y
+div = lambda x, y: x/y
+mod = lambda x, y: x%y
+floordiv = lambda x, y: x//y
+bitand = lambda x, y: x&y
+bitor = lambda x, y: x | y
+bitxor = lambda x, y: x ^ y
+bitnot = lambda x: ~x
+lsl = lambda x, y: x << y
+lsr = lambda x, y: x >> y
+eq = lambda x, y: x == y
+ne = lambda x, y: x != y
+lt = lambda x, y: x < y
+le = lambda x, y: x <= y
+ge = lambda x, y: x >= y
+gt = lambda x, y: x > y
+
+###----- That's all, folks --------------------------------------------------