chiark / gitweb /
catacomb/__init__.py: Implement equality and hashing for `KeyData' objects.
authorMark Wooding <mdw@distorted.org.uk>
Mon, 25 Nov 2019 12:07:16 +0000 (12:07 +0000)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 11 Apr 2020 11:44:21 +0000 (12:44 +0100)
Equality is determined by value, so don't use `KeyData' objects as
hashtable keys and then mutate them.

catacomb/__init__.py
t/t-key.py

index bd8aa57e2bbb37f6f5ae4957feae8a26046902e2..94efb6fa6897b5ea7a6c89857a0887bc5fe86272 100644 (file)
@@ -701,27 +701,41 @@ class _tmp:
       pp.text(','); pp.breakable()
       pp.pretty(me.writeflags(me.flags))
     pp.end_group(ind, ')')
+  def __hash__(me): return me._HASHBASE ^ hash(me._guts())
+  def __eq__(me, kd):
+    return type(me) == type(kd) and \
+      me._guts() == kd._guts() and \
+      me.flags == kd.flags
+  def __ne__(me, kd):
+    return not me == kd
 _augment(KeyData, _tmp)
 
 class _tmp:
   def _guts(me): return me.bin
+  def __eq__(me, kd):
+    return isinstance(kd, KeyDataBinary) and me.bin == kd.bin
 _augment(KeyDataBinary, _tmp)
+KeyDataBinary._HASHBASE = 0x961755c3
 
 class _tmp:
   def _guts(me): return me.ct
 _augment(KeyDataEncrypted, _tmp)
+KeyDataEncrypted._HASHBASE = 0xffe000d4
 
 class _tmp:
   def _guts(me): return me.mp
 _augment(KeyDataMP, _tmp)
+KeyDataMP._HASHBASE = 0x1cb64d69
 
 class _tmp:
   def _guts(me): return me.str
 _augment(KeyDataString, _tmp)
+KeyDataString._HASHBASE = 0x349c33ea
 
 class _tmp:
   def _guts(me): return me.ecpt
 _augment(KeyDataECPt, _tmp)
+KeyDataECPt._HASHBASE = 0x2509718b
 
 class _tmp:
   def __repr__(me):
@@ -732,7 +746,21 @@ class _tmp:
     if cyclep: pp.text('...')
     else: _pp_dict(pp, _iteritems(me))
     pp.end_group(ind, ' })')
+  def __hash__(me):
+    h = me._HASHBASE
+    for k, v in _iteritems(me):
+      h = ((h << 1) ^ 3*hash(k) ^ 5*hash(v))&0xffffffff
+    return h
+  def __eq__(me, kd):
+    if type(me) != type(kd) or me.flags != kd.flags or len(me) != len(kd):
+      return False
+    for k, v in _iteritems(me):
+      try: vv = kd[k]
+      except KeyError: return False
+      if v != vv: return False
+    return True
 _augment(KeyDataStructured, _tmp)
+KeyDataStructured._HASHBASE = 0x85851b21
 
 ###--------------------------------------------------------------------------
 ### Abstract groups.
index 20cadb1681d2da6c54bca1c523e2f7486f8e22eb..bee107b460411da229479e82ee0eb4458ebc843a 100644 (file)
@@ -188,7 +188,7 @@ class TestKeyFile (U.TestCase):
     k = kf.newkey(0x11111111, "first", exp)
     me.assertEqual(kf.modifiedp, True)
 
-    me.assertEqual(kf[0x11111111].id, 0x11111111)
+    me.assertEqual(k, kf[0x11111111])
     me.assertEqual(k.exptime, exp)
     me.assertEqual(k.deltime, exp)
     me.assertRaises(ValueError, setattr, k, "deltime", C.KEXP_FOREVER)
@@ -212,24 +212,6 @@ class TestKeyFile (U.TestCase):
                  "22222222:test integer,public:32519164 forever forever -")
 
 ###--------------------------------------------------------------------------
-
-def keydata_equalp(kd0, kd1):
-  if type(kd0) is not type(kd1): return False
-  elif type(kd0) is C.KeyDataBinary: return kd0.bin == kd1.bin
-  elif type(kd0) is C.KeyDataMP: return kd0.mp == kd1.mp
-  elif type(kd0) is C.KeyDataEncrypted: return kd0.ct == kd1.ct
-  elif type(kd0) is C.KeyDataECPt: return kd0.ecpt == kd1.ecpt
-  elif type(kd0) is C.KeyDataString: return kd0.str == kd1.str
-  elif type(kd0) is C.KeyDataStructured:
-    if len(kd0) != len(kd1): return False
-    for t, v0 in T.iteritems(kd0):
-      try: v1 = kd1[t]
-      except KeyError: return False
-      if not keydata_equalp(v0, v1): return False
-    return True
-  else:
-    raise SystemError("unexpected keydata type")
-
 class TestKeyData (U.TestCase):
 
   def test_flags(me):
@@ -266,10 +248,8 @@ class TestKeyData (U.TestCase):
     me.assertEqual(set(T.iterkeys(kd2)), set(["b"]))
 
   def check_encode(me, kd):
-    me.assertTrue(keydata_equalp(C.KeyData.decode(kd.encode()), kd))
-    kd1, tail = C.KeyData.read(kd.write())
-    me.assertEqual(tail, "")
-    me.assertTrue(keydata_equalp(kd, kd1))
+    me.assertEqual(C.KeyData.decode(kd.encode()), kd)
+    me.assertEqual(C.KeyData.read(kd.write()), (kd, ""))
 
   def test_bin(me):
     rng = T.detrand("kd-bin")
@@ -335,6 +315,13 @@ class TestKeyFileMapping (T.ImmutableMappingTextMixin):
 
     me.check_immutable_mapping(kf, model)
 
+class TestKeyStructMapping (T.MutableMappingTestMixin):
+  def _mkvalue(me, i): return C.KeyDataMP(i)
+  def _getvalue(me, v): return v.mp
+
+  def test_keystructmap(me):
+    me.check_mapping(C.KeyDataStructured)
+
 class TestKeyAttrMapping (T.MutableMappingTestMixin):
 
   def test_attrmap(me):