chiark / gitweb /
algorithms.c: Set `KSZ.max' to `None' to indicate no bound.
[catacomb-python] / t / t-algorithms.py
1 ### -*- mode: python, coding: utf-8 -*-
2 ###
3 ### Test symmetric algorithms
4 ###
5 ### (c) 2019 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software: you can redistribute it and/or
13 ### modify it under the terms of the GNU General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful, but
18 ### WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20 ### General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python.  If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imported modules.
29
30 import catacomb as C
31 import unittest as U
32 import testutils as T
33
34 ###--------------------------------------------------------------------------
35 ### Utilities.
36
37 def bad_key_size(ksz):
38   if isinstance(ksz, C.KeySZAny): return None
39   elif isinstance(ksz, C.KeySZRange):
40     if ksz.mod != 1: return ksz.min + 1
41     elif ksz.max is not None: return ksz.max + 1
42     elif ksz.min != 0: return ksz.min - 1
43     else: return None
44   elif isinstance(ksz, C.KeySZSet):
45     for sz in sorted(ksz.set):
46       if sz + 1 not in ksz.set: return sz + 1
47     assert False, "That should have worked."
48   else:
49     return None
50
51 def different_key_size(ksz, sz):
52   if isinstance(ksz, C.KeySZAny): return sz + 1
53   elif isinstance(ksz, C.KeySZRange):
54     if sz > ksz.min: return sz - ksz.mod
55     elif ksz.max is None or sz < ksz.max: return sz + ksz.mod
56     else: return None
57   elif isinstance(ksz, C.KeySZSet):
58     for sz1 in sorted(ksz.set):
59       if sz != sz1: return sz1
60     return None
61   else:
62     return None
63
64 class HashBufferTestMixin (U.TestCase):
65   """Mixin class for testing all of the various `hash...' methods."""
66
67   def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
68     """Check `hashuN'."""
69
70     ## Check encoding an integer.
71     h0, donefn0 = makefn(w + 2)
72     hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
73     h1, donefn1 = makefn(w + 2)
74     h1.hash(T.span(w + 2))
75     me.assertEqual(donefn0(), donefn1())
76
77     ## Check overflow detection.
78     h0, _ = makefn(w)
79     me.assertRaises((OverflowError, ValueError),
80                     hashfn, h0, 1 << 8*w)
81
82   def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
83     """Check `hashbufN'."""
84
85     ## Go through a number of different sizes.
86     for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
87       if n >= 1 << 8*w: continue
88       h0, donefn0 = makefn(2 + w + n)
89       hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
90       h1, donefn1 = makefn(2 + w + n)
91       h1.hash(T.prep_lenseq(w, n, bigendp, True))
92       me.assertEqual(donefn0(), donefn1())
93
94     ## Check blocks which are too large for the length prefix.
95     if w <= 3:
96       n = 1 << 8*w
97       h0, _ = makefn(w + n)
98       me.assertRaises((ValueError, OverflowError, TypeError),
99                       hashfn, h0, C.ByteString.zero(n))
100
101   def check_hashbuffer(me, makefn):
102     """Test the various `hash...' methods."""
103
104     ## Check `hashuN'.
105     me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
106     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
107     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
108     me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
109     if hasattr(makefn(0)[0], "hashu24"):
110       me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
111       me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
112       me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
113     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
114     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
115     me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
116     if hasattr(makefn(0)[0], "hashu64"):
117       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
118       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
119       me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))
120
121     ## Check `hashbufN'.
122     me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
123     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
124     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
125     me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
126     if hasattr(makefn(0)[0], "hashbuf24"):
127       me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
128       me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
129       me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
130     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
131     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
132     me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
133     if hasattr(makefn(0)[0], "hashbuf64"):
134       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
135       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
136       me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x))
137
138 ###--------------------------------------------------------------------------
139 class TestKeysize (U.TestCase):
140
141   def test_any(me):
142
143     ## A typical one-byte spec.
144     ksz = C.seal.keysz
145     me.assertEqual(type(ksz), C.KeySZAny)
146     me.assertEqual(ksz.default, 20)
147     me.assertEqual(ksz.min, 0)
148     me.assertEqual(ksz.max, None)
149     for n in [0, 12, 20, 5000]:
150       me.assertTrue(ksz.check(n))
151       me.assertEqual(ksz.best(n), n)
152       me.assertEqual(ksz.pad(n), n)
153
154     ## A typical two-byte spec.  (No published algorithms actually /need/ a
155     ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
156     ksz = C.sha256_hmac.keysz
157     me.assertEqual(type(ksz), C.KeySZAny)
158     me.assertEqual(ksz.default, 32)
159     me.assertEqual(ksz.min, 0)
160     me.assertEqual(ksz.max, None)
161     for n in [0, 12, 20, 5000]:
162       me.assertTrue(ksz.check(n))
163       me.assertEqual(ksz.best(n), n)
164       me.assertEqual(ksz.pad(n), n)
165
166     ## Check construction.
167     ksz = C.KeySZAny(15)
168     me.assertEqual(ksz.default, 15)
169     me.assertEqual(ksz.min, 0)
170     me.assertEqual(ksz.max, None)
171     me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
172     me.assertEqual(C.KeySZAny(0).default, 0)
173
174   def test_set(me):
175     ## Note that no published algorithm uses a 16-bit `set' spec.
176
177     ## A typical spec.
178     ksz = C.salsa20.keysz
179     me.assertEqual(type(ksz), C.KeySZSet)
180     me.assertEqual(ksz.default, 32)
181     me.assertEqual(ksz.min, 10)
182     me.assertEqual(ksz.max, 32)
183     me.assertEqual(set(ksz.set), set([10, 16, 32]))
184     for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
185                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
186                          (31, 16, 32), (32, 32, 32), (33, 32, None)]:
187       if x == best == pad: me.assertTrue(ksz.check(x))
188       else: me.assertFalse(ksz.check(x))
189       if best is None: me.assertRaises(ValueError, ksz.best, x)
190       else: me.assertEqual(ksz.best(x), best)
191       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
192       else: me.assertEqual(ksz.pad(x), pad)
193
194     ## Check construction.
195     ksz = C.KeySZSet(7)
196     me.assertEqual(ksz.default, 7)
197     me.assertEqual(set(ksz.set), set([7]))
198     me.assertEqual(ksz.min, 7)
199     me.assertEqual(ksz.max, 7)
200     ksz = C.KeySZSet(7, [3, 6, 9])
201     me.assertEqual(ksz.default, 7)
202     me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
203     me.assertEqual(ksz.min, 3)
204     me.assertEqual(ksz.max, 9)
205
206   def test_range(me):
207     ## Note that no published algorithm uses a 16-bit `range' spec, or an
208     ## unbounded `range'.
209
210     ## A typical spec.
211     ksz = C.rijndael.keysz
212     me.assertEqual(type(ksz), C.KeySZRange)
213     me.assertEqual(ksz.default, 32)
214     me.assertEqual(ksz.min, 4)
215     me.assertEqual(ksz.max, 32)
216     me.assertEqual(ksz.mod, 4)
217     for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
218                          (15, 12, 16), (16, 16, 16), (17, 16, 20),
219                          (31, 28, 32), (32, 32, 32), (33, 32, None)]:
220       if x == best == pad: me.assertTrue(ksz.check(x))
221       else: me.assertFalse(ksz.check(x))
222       if best is None: me.assertRaises(ValueError, ksz.best, x)
223       else: me.assertEqual(ksz.best(x), best)
224       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
225       else: me.assertEqual(ksz.pad(x), pad)
226
227     ## Check construction.
228     ksz = C.KeySZRange(28, 21, 35, 7)
229     me.assertEqual(ksz.default, 28)
230     me.assertEqual(ksz.min, 21)
231     me.assertEqual(ksz.max, 35)
232     me.assertEqual(ksz.mod, 7)
233     ksz = C.KeySZRange(28, 21, None, 7)
234     me.assertEqual(ksz.min, 21)
235     me.assertEqual(ksz.max, None)
236     me.assertEqual(ksz.mod, 7)
237     me.assertEqual(ksz.pad(36), 42)
238     me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
239     me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
240     me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
241     me.assertRaises(ValueError, C.KeySZRange, 28, -7, 35, 7)
242     me.assertRaises(ValueError, C.KeySZRange, 28, 35, 21, 7)
243     me.assertRaises(ValueError, C.KeySZRange, 35, 21, 28, 7)
244     me.assertRaises(ValueError, C.KeySZRange, 21, 28, 35, 7)
245
246   def test_conversions(me):
247     me.assertEqual(C.KeySZ.fromec(256), 128)
248     me.assertEqual(C.KeySZ.fromschnorr(256), 128)
249     me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
250     me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
251     me.assertEqual(C.KeySZ.toec(128), 256)
252     me.assertEqual(C.KeySZ.toschnorr(128), 256)
253     me.assertEqual(C.KeySZ.todl(128), 2958.6875)
254     me.assertEqual(C.KeySZ.toif(128), 2958.6875)
255
256 ###--------------------------------------------------------------------------
257 class TestCipher (T.GenericTestMixin):
258   """Test basic symmetric ciphers."""
259
260   def _test_cipher(me, ccls):
261
262     ## Check the class properties.
263     me.assertEqual(type(ccls.name), str)
264     me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
265     me.assertEqual(type(ccls.blksz), int)
266
267     ## Check round-tripping.
268     k = T.span(ccls.keysz.default)
269     iv = T.span(ccls.blksz)
270     m = T.span(253)
271     enc = ccls(k)
272     dec = ccls(k)
273     try: enc.setiv(iv)
274     except ValueError: can_setiv = False
275     else:
276       can_setiv = True
277       dec.setiv(iv)
278     c0 = enc.encrypt(m[0:57])
279     m0 = dec.decrypt(c0)
280     c1 = enc.encrypt(m[57:189])
281     m1 = dec.decrypt(c1)
282     try: enc.bdry()
283     except ValueError: can_bdry = False
284     else:
285       dec.bdry()
286       can_bdry = True
287     c2 = enc.encrypt(m[189:253])
288     m2 = dec.decrypt(c2)
289     me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
290     me.assertEqual(m0, m[0:57])
291     me.assertEqual(m1, m[57:189])
292     me.assertEqual(m2, m[189:253])
293
294     ## Check the `enczero' and `deczero' methods.
295     c3 = enc.enczero(32)
296     me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
297     m4 = dec.deczero(32)
298     me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))
299
300     ## Check that ciphers which support a `boundary' operation actually
301     ## need it.
302     if can_bdry:
303       dec = ccls(k)
304       if can_setiv: dec.setiv(iv)
305       m01 = dec.decrypt(c0 + c1)
306       me.assertEqual(m01, m[0:189])
307
308     ## Check that the boundary actually does something.
309     if can_bdry:
310       dec = ccls(k)
311       if can_setiv: dec.setiv(iv)
312       m012 = dec.decrypt(c0 + c1 + c2)
313       me.assertNotEqual(m012, m)
314
315     ## Check that bad key lengths are rejected.
316     badlen = bad_key_size(ccls.keysz)
317     if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
318
319 TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
320   ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
321    "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])
322
323 ###--------------------------------------------------------------------------
324 class TestAuthenticatedEncryption \
325         (HashBufferTestMixin, T.GenericTestMixin):
326   """Test authenticated encryption schemes."""
327
328   def _test_aead(me, aecls):
329
330     ## Check the class properties.
331     me.assertEqual(type(aecls.name), str)
332     me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
333     me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
334     me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
335     me.assertEqual(type(aecls.blksz), int)
336     me.assertEqual(type(aecls.bufsz), int)
337     me.assertEqual(type(aecls.ohd), int)
338     me.assertEqual(type(aecls.flags), int)
339
340     ## Check round-tripping, with full precommitment.  First, select some
341     ## parameters.  (It's conceivable that some AEAD schemes are more
342     ## restrictive than advertised by the various properties, but this works
343     ## out OK in practice.)
344     k = T.span(aecls.keysz.default)
345     n = T.span(aecls.noncesz.default)
346     if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
347     else: h = T.span(131)
348     m = T.span(253)
349     tsz = aecls.tagsz.default
350     key = aecls(k)
351
352     ## Next, encrypt a message, checking that things are proper as we go.
353     enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
354     me.assertEqual(enc.hsz, len(h))
355     me.assertEqual(enc.msz, len(m))
356     me.assertEqual(enc.mlen, 0)
357     me.assertEqual(enc.tsz, tsz)
358     aad = enc.aad()
359     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
360     else: me.assertEqual(aad.hsz, None)
361     me.assertEqual(aad.hlen, 0)
362     if not aecls.flags&C.AEADF_NOAAD:
363       aad.hash(h[0:83])
364       me.assertEqual(aad.hlen, 83)
365       aad.hash(h[83:131])
366       me.assertEqual(aad.hlen, 131)
367     c0 = enc.encrypt(m[0:57])
368     me.assertEqual(enc.mlen, 57)
369     me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
370     c1 = enc.encrypt(m[57:189])
371     me.assertEqual(enc.mlen, 189)
372     me.assertTrue(132 - aecls.bufsz <= len(c1) <=
373                   132 + aecls.bufsz + aecls.ohd)
374     c2 = enc.encrypt(m[189:253])
375     me.assertEqual(enc.mlen, 253)
376     me.assertTrue(64 - aecls.bufsz <= len(c2) <=
377                   64 + aecls.bufsz + aecls.ohd)
378     c3, t = enc.done(aad = aad)
379     me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
380     c = c0 + c1 + c2 + c3
381     me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
382     me.assertEqual(len(t), tsz)
383
384     ## And now decrypt it again, with different record boundaries.
385     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
386     me.assertEqual(dec.hsz, len(h))
387     me.assertEqual(dec.csz, len(c))
388     me.assertEqual(dec.clen, 0)
389     me.assertEqual(dec.tsz, tsz)
390     aad = dec.aad()
391     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
392     else: me.assertEqual(aad.hsz, None)
393     me.assertEqual(aad.hlen, 0)
394     aad.hash(h)
395     m0 = dec.decrypt(c[0:156])
396     me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
397     m1 = dec.decrypt(c[156:])
398     me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
399                   len(c) - 156 + aecls.bufsz)
400     m2 = dec.done(tag = t, aad = aad)
401     me.assertEqual(m0 + m1 + m2, m)
402
403     ## And again, with the wrong tag.
404     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
405     aad = dec.aad(); aad.hash(h)
406     _ = dec.decrypt(c)
407     me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))
408
409     ## Check that the all-in-one methods work.
410     me.assertEqual((c, t),
411                    key.encrypt(n = n, h = h, m = m, tsz = tsz))
412     me.assertEqual(m,
413                    key.decrypt(n = n, h = h, c = c, t = t))
414
415     ## Check that bad key, nonce, and tag lengths are rejected.
416     badlen = bad_key_size(aecls.keysz)
417     if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
418     badlen = bad_key_size(aecls.noncesz)
419     if badlen is not None:
420       me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
421                       hsz = len(h), msz = len(m), tsz = tsz)
422       me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
423                       hsz = len(h), csz = len(c), tsz = tsz)
424       if not aecls.flags&C.AEADF_PCTSZ:
425         enc = key.enc(nonce = n, hsz = 0, msz = len(m))
426         _ = enc.encrypt(m)
427         me.assertRaises(ValueError, enc.done, tsz = badlen)
428     badlen = bad_key_size(aecls.tagsz)
429     if badlen is not None:
430       me.assertRaises(ValueError, key.enc, nonce = n,
431                       hsz = len(h), msz = len(m), tsz = badlen)
432       me.assertRaises(ValueError, key.dec, nonce = n,
433                       hsz = len(h), csz = len(c), tsz = badlen)
434
435     ## Check that we can't get a loose `aad' object from a scheme which has
436     ## nonce-dependent AAD processing.
437     if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)
438
439     ## Check the menagerie of AAD hashing methods.
440     if not aecls.flags&C.AEADF_NOAAD:
441       def mkhash(hsz):
442         enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
443         aad = enc.aad()
444         return aad, lambda: enc.done(aad = aad)[1]
445       me.check_hashbuffer(mkhash)
446
447     ## Check that encryption/decryption works with the given precommitments.
448     def quick_enc_check(**kw):
449       enc = key.enc(**kw)
450       aad = enc.aad().hash(h)
451       c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
452       me.assertEqual((c, t), (c0 + c1, tt))
453     def quick_dec_check(**kw):
454       dec = key.dec(**kw)
455       aad = dec.aad().hash(h)
456       m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
457       me.assertEqual(m, m0 + m1)
458
459     ## Check that we can get away without precommitting to the header length
460     ## if and only if the AEAD scheme says it will let us.
461     if aecls.flags&C.AEADF_PCHSZ:
462       me.assertRaises(ValueError, key.enc, nonce = n,
463                       msz = len(m), tsz = tsz)
464       me.assertRaises(ValueError, key.dec, nonce = n,
465                       csz = len(c), tsz = tsz)
466     else:
467       quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
468       quick_dec_check(nonce = n, csz = len(c), tsz = tsz)
469
470     ## Check that we can get away without precommitting to the message/
471     ## ciphertext length if and only if the AEAD scheme says it will let us.
472     if aecls.flags&C.AEADF_PCMSZ:
473       me.assertRaises(ValueError, key.enc, nonce = n,
474                       hsz = len(h), tsz = tsz)
475       me.assertRaises(ValueError, key.dec, nonce = n,
476                       hsz = len(h), tsz = tsz)
477     else:
478       quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
479       quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)
480
481     ## Check that we can get away without precommitting to the tag length if
482     ## and only if the AEAD scheme says it will let us.
483     if aecls.flags&C.AEADF_PCTSZ:
484       me.assertRaises(ValueError, key.enc, nonce = n,
485                       hsz = len(h), msz = len(m))
486       me.assertRaises(ValueError, key.dec, nonce = n,
487                       hsz = len(h), csz = len(c))
488     else:
489       quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
490       quick_dec_check(nonce = n, hsz = len(h), csz = len(c))
491
492     ## Check that if we precommit to the header length, we're properly held
493     ## to the commitment.
494     if not aecls.flags&C.AEADF_NOAAD:
495
496       ## First, check encryption with underrun.  If we must supply AAD first,
497       ## then the underrun will be reported when we start trying to encrypt;
498       ## otherwise, checking is delayed until `done'.
499       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
500       aad = enc.aad().hash(h[0:83])
501       if aecls.flags&C.AEADF_AADFIRST:
502         me.assertRaises(ValueError, enc.encrypt, m)
503       else:
504         _ = enc.encrypt(m)
505         me.assertRaises(ValueError, enc.done, aad = aad)
506
507       ## Next, check decryption with underrun.  If we must supply AAD first,
508       ## then the underrun will be reported when we start trying to encrypt;
509       ## otherwise, checking is delayed until `done'.
510       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
511       aad = dec.aad().hash(h[0:83])
512       if aecls.flags&C.AEADF_AADFIRST:
513         me.assertRaises(ValueError, dec.decrypt, c)
514       else:
515         _ = dec.decrypt(c)
516         me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
517
518       ## If AAD processing is nonce-dependent then an overrun will be
519       ## detected imediately.
520       if aecls.flags&C.AEADF_AADNDEP:
521         enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
522         aad = enc.aad().hash(h[0:83])
523         me.assertRaises(ValueError, aad.hash, h[82:131])
524         dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
525         aad = dec.aad().hash(h[0:83])
526         me.assertRaises(ValueError, aad.hash, h[82:131])
527
528     ## Some additional tests for nonce-dependent `aad' objects.
529     if aecls.flags&C.AEADF_AADNDEP:
530
531       ## Check that `aad' objects can't be used once their parents are gone.
532       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
533       aad = enc.aad()
534       del enc
535       me.assertRaises(ValueError, aad.hash, h)
536
537       ## Check that they can't be crossed over.
538       enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
539       enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
540       enc0.aad().hash(h)
541       aad1 = enc1.aad().hash(h)
542       _ = enc0.encrypt(m)
543       me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)
544
545     ## Test copying AAD.
546     if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
547       aad0 = key.aad()
548       aad0.hash(h[0:83])
549       aad1 = aad0.copy()
550       aad2 = aad1.copy()
551       aad0.hash(h[83:131])
552       aad1.hash(h[83:131])
553       aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
554       me.assertEqual(key.enc(nonce = n, hsz = len(h),
555                              msz = 0, tsz = tsz).done(aad = aad0),
556                      key.enc(nonce = n, hsz = len(h),
557                              msz = 0, tsz = tsz).done(aad = aad1))
558       me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
559                                 msz = 0, tsz = tsz).done(aad = aad0),
560                         key.enc(nonce = n, hsz = len(h),
561                                 msz = 0, tsz = tsz).done(aad = aad2))
562
563     ## Check that if we precommit to the message length, we're properly held
564     ## to the commitment.  (Fortunately, this is way simpler than the AAD
565     ## case above.)  First, try an underrun.
566     enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
567     _ = enc.encrypt(m[0:183])
568     me.assertRaises(ValueError, enc.done, tsz = tsz)
569     dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
570     _ = dec.decrypt(c[0:183])
571     me.assertRaises(ValueError, dec.done, tag = t)
572
573     ## And now an overrun.
574     enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
575     me.assertRaises(ValueError, enc.encrypt, m)
576     dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
577     me.assertRaises(ValueError, dec.decrypt, c)
578
579     ## Finally, check that if we precommit to a tag length, we're properly
580     ## held to the commitment.  This depends on being able to find a tag size
581     ## which isn't the default.
582     tsz1 = different_key_size(aecls.tagsz, tsz)
583     if tsz1 is not None:
584       enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
585       _ = enc.encrypt(m)
586       me.assertRaises(ValueError, enc.done, tsz = tsz)
587       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
588       aad = dec.aad().hash(h)
589       _ = dec.decrypt(c)
590       me.assertRaises(ValueError, enc.done, tsz = tsz, aad = aad)
591
592 TestAuthenticatedEncryption.generate_testcases \
593   ((name, C.gcaeads[name]) for name in
594    ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
595     "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
596
597 ###--------------------------------------------------------------------------
598 class BaseTestHash (HashBufferTestMixin):
599   """Base class for testing hash functions."""
600
601   def check_hash(me, hcls, need_bufsz = True):
602     """
603     Check hash class HCLS.
604
605     If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
606     `name', or `hashsz' attributes.  This test is mostly reused for MACs,
607     which don't have these attributes.
608     """
609     ## Check the class properties.
610     if need_bufsz:
611       me.assertEqual(type(hcls.name), str)
612       me.assertEqual(type(hcls.bufsz), int)
613       me.assertEqual(type(hcls.hashsz), int)
614
615     ## Set some initial values.
616     m = T.span(131)
617     h = hcls().hash(m).done()
618
619     ## Check that hash length comes out right.
620     if need_bufsz: me.assertEqual(len(h), hcls.hashsz)
621
622     ## Check that we get the same answer if we split the message up.
623     me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
624
625     ## Check the `check' method.
626     me.assertTrue(hcls().hash(m).check(h))
627     me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))
628
629     ## Check the menagerie of random hashing methods.
630     def mkhash(_):
631       h = hcls()
632       return h, h.done
633     me.check_hashbuffer(mkhash)
634
635 class TestHash (BaseTestHash, T.GenericTestMixin):
636   """Test hash functions."""
637   def _test_hash(me, hcls):    me.check_hash(hcls, need_bufsz = True)
638
639 TestHash.generate_testcases((name, C.gchashes[name]) for name in
640   ["md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384", "shake256",
641    "crc32"])
642
643 ###--------------------------------------------------------------------------
644 class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
645   """Test message authentication codes."""
646
647   def _test_mac(me, mcls):
648
649     ## Check the MAC properties.
650     me.assertEqual(type(mcls.name), str)
651     me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
652     me.assertEqual(type(mcls.tagsz), int)
653
654     ## Test hashing.
655     k = T.span(mcls.keysz.default)
656     key = mcls(k)
657     me.check_hash(key, need_bufsz = False)
658
659     ## Check that bad key lengths are rejected.
660     badlen = bad_key_size(mcls.keysz)
661     if badlen is not None: me.assertRaises(ValueError, mcls, T.span(badlen))
662
663 TestMessageAuthentication.generate_testcases \
664   ((name, C.gcmacs[name]) for name in
665    ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
666
667 class TestPoly1305 (HashBufferTestMixin):
668   """Check the Poly1305 one-time message authentication function."""
669
670   def test_poly1305(me):
671
672     ## Check the MAC properties.
673     me.assertEqual(C.poly1305.name, "poly1305")
674     me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
675     me.assertEqual(C.poly1305.keysz.default, 16)
676     me.assertEqual(set(C.poly1305.keysz.set), set([16]))
677     me.assertEqual(C.poly1305.tagsz, 16)
678     me.assertEqual(C.poly1305.masksz, 16)
679
680     ## Set some initial values.
681     k = T.span(16)
682     u = T.span(64)[-16:]
683     m = T.span(149)
684     key = C.poly1305(k)
685     t = key(u).hash(m).done()
686
687     ## Check the key properties.
688     me.assertEqual(len(t), 16)
689
690     ## Check that we get the same answer if we split the message up.
691     me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())
692
693     ## Check the `check' method.
694     me.assertTrue(key(u).hash(m).check(t))
695     me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))
696
697     ## Check the menagerie of random hashing methods.
698     def mkhash(_):
699       h = key(u)
700       return h, h.done
701     me.check_hashbuffer(mkhash)
702
703     ## Check that we can't complete hashing without a mask.
704     me.assertRaises(ValueError, key().hash(m).done)
705
706     ## Check `concat'.
707     h0 = key().hash(m[0:96])
708     h1 = key().hash(m[96:117])
709     me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
710     key1 = C.poly1305(k)
711     me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
712     me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
713     me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
714
715 ###--------------------------------------------------------------------------
716 class TestHLatin (U.TestCase):
717   """Test the `hsalsa20' and `hchacha20' functions."""
718
719   def test_hlatin(me):
720     kk = [T.span(sz) for sz in [10, 16, 32]]
721     n = T.span(16)
722     bad_k = T.span(18)
723     bad_n = T.span(13)
724     for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
725                C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
726       for k in kk:
727         h = fn(k, n)
728         me.assertEqual(len(h), 32)
729       me.assertRaises(ValueError, fn, bad_k, n)
730       me.assertRaises(ValueError, fn, k, bad_n)
731
732 ###--------------------------------------------------------------------------
733 class TestKeccak (HashBufferTestMixin):
734   """Test the Keccak-p[1600, n] sponge function."""
735
736   def test_keccak(me):
737
738     ## Make a state and feed some stuff into it.
739     m0 = T.bin("some initial string")
740     m1 = T.bin("awesome follow-up string")
741     st0 = C.Keccak1600()
742     me.assertEqual(st0.nround, 24)
743     st0.mix(m0).step()
744
745     ## Make another step with a different round count.
746     st1 = C.Keccak1600(23)
747     st1.mix(m0).step()
748     me.assertNotEqual(st0.extract(32), st1.extract(32))
749
750     ## Check state copying.
751     st1 = st0.copy()
752     mask = st1.extract(len(m1))
753     st0.mix(m1)
754     st1.mix(m1)
755     me.assertEqual(st0.extract(32), st1.extract(32))
756
757     ## Check error conditions.
758     _ = st0.extract(200)
759     me.assertRaises(ValueError, st0.extract, 201)
760     st0.mix(T.span(200))
761     me.assertRaises(ValueError, st0.mix, T.span(201))
762
763   def check_shake(me, xcls, c, done_matches_xof = True):
764     """
765     Test the SHAKE and cSHAKE XOFs.
766
767     This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
768     to indicate that the XOF output is range-separated from the fixed-length
769     outputs (unlike the basic SHAKE functions).
770     """
771
772     ## Check the hash attributes.
773     x = xcls()
774     me.assertEqual(x.rate, 200 - c)
775     me.assertEqual(x.buffered, 0)
776     me.assertEqual(x.state, "absorb")
777
778     ## Set some initial values.
779     func = T.bin("TESTXOF")
780     perso = T.bin("catacomb-python test")
781     m = T.span(167)
782     h0 = xcls().hash(m).done(193)
783     me.assertEqual(len(h0), 193)
784     h1 = xcls(func = func, perso = perso).hash(m).done(193)
785     me.assertEqual(len(h1), 193)
786     me.assertNotEqual(h0, h1)
787
788     ## Check input and output in pieces, and the state machine.
789     if done_matches_xof: h = h0
790     else: h = xcls().hash(m).xof().get(len(h0))
791     x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
792     me.assertEqual(h, x.get(98) + x.get(95))
793
794     ## Check masking.
795     x = xcls().hash(m).xof()
796     me.assertEqual(x.mask(m), m ^ h[0:len(m)])
797
798     ## Check the `check' method.
799     me.assertTrue(xcls().hash(m).check(h0))
800     me.assertFalse(xcls().hash(m).check(h1))
801
802     ## Check the menagerie of random hashing methods.
803     def mkhash(_):
804       x = xcls(func = func, perso = perso)
805       return x, x.done
806     me.check_hashbuffer(mkhash)
807
808     ## Check the state machine tracking.
809     x = xcls(); me.assertEqual(x.state, "absorb")
810     x.hash(m); me.assertEqual(x.state, "absorb")
811     xx = x.copy()
812     h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2)
813     me.assertEqual(xx.state, "dead")
814     me.assertRaises(ValueError, xx.done, 1)
815     me.assertRaises(ValueError, xx.get, 1)
816     me.assertEqual(x.state, "absorb")
817     me.assertRaises(ValueError, x.get, 1)
818     x.xof(); me.assertEqual(x.state, "squeeze")
819     me.assertRaises(ValueError, x.done, 1)
820     _ = x.get(1)
821     yy = x.copy(); me.assertEqual(yy.state, "squeeze")
822
823   def test_shake128(me): me.check_shake(C.Shake128, 32)
824   def test_shake256(me): me.check_shake(C.Shake256, 64)
825
826   def check_kmac(me, mcls, c):
827     k = T.span(32)
828     me.check_shake(lambda func = None, perso = None:
829                      mcls(k, perso = perso),
830                    c, done_matches_xof = False)
831
832   def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
833   def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
834
835 ###--------------------------------------------------------------------------
836 class TestPRP (T.GenericTestMixin):
837   """Test pseudorandom permutations (PRPs)."""
838
839   def _test_prp(me, pcls):
840
841     ## Check the PRP properties.
842     me.assertEqual(type(pcls.name), str)
843     me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
844     me.assertEqual(type(pcls.blksz), int)
845
846     ## Check round-tripping.
847     k = T.span(pcls.keysz.default)
848     key = pcls(k)
849     m = T.span(pcls.blksz)
850     c = key.encrypt(m)
851     me.assertEqual(len(c), pcls.blksz)
852     me.assertEqual(m, key.decrypt(c))
853
854     ## Check that bad key lengths are rejected.
855     badlen = bad_key_size(pcls.keysz)
856     if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))
857
858     ## Check that bad blocks are rejected.
859     badblk = T.span(pcls.blksz + 1)
860     me.assertRaises(ValueError, key.encrypt, badblk)
861     me.assertRaises(ValueError, key.decrypt, badblk)
862
863 TestPRP.generate_testcases((name, C.gcprps[name]) for name in
864   ["desx", "blowfish", "rijndael"])
865
866 ###----- That's all, folks --------------------------------------------------
867
868 if __name__ == "__main__": U.main()