chiark / gitweb /
algorithms.c: Implement KMAC in C.
[catacomb-python] / t / t-algorithms.py
1 ### -*- mode: python, coding: utf-8 -*-
2 ###
3 ### Test symmetric algorithms
4 ###
5 ### (c) 2019 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software: you can redistribute it and/or
13 ### modify it under the terms of the GNU General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful, but
18 ### WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20 ### General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python.  If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imported modules.
29
30 import catacomb as C
31 import unittest as U
32 import testutils as T
33
34 ###--------------------------------------------------------------------------
35 ### Utilities.
36
37 def bad_key_size(ksz):
38   if isinstance(ksz, C.KeySZAny): return None
39   elif isinstance(ksz, C.KeySZRange):
40     if ksz.mod != 1: return ksz.min + 1
41     elif ksz.max is not None: return ksz.max + 1
42     elif ksz.min != 0: return ksz.min - 1
43     else: return None
44   elif isinstance(ksz, C.KeySZSet):
45     for sz in sorted(ksz.set):
46       if sz + 1 not in ksz.set: return sz + 1
47     assert False, "That should have worked."
48   else:
49     return None
50
51 def different_key_size(ksz, sz):
52   if isinstance(ksz, C.KeySZAny): return sz + 1
53   elif isinstance(ksz, C.KeySZRange):
54     if sz > ksz.min: return sz - ksz.mod
55     elif ksz.max is None or sz < ksz.max: return sz + ksz.mod
56     else: return None
57   elif isinstance(ksz, C.KeySZSet):
58     for sz1 in sorted(ksz.set):
59       if sz != sz1: return sz1
60     return None
61   else:
62     return None
63
64 class HashBufferTestMixin (U.TestCase):
65   """Mixin class for testing all of the various `hash...' methods."""
66
67   def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
68     """Check `hashuN'."""
69
70     ## Check encoding an integer.
71     h0, donefn0 = makefn(w + 2)
72     hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
73     h1, donefn1 = makefn(w + 2)
74     h1.hash(T.span(w + 2))
75     me.assertEqual(donefn0(), donefn1())
76
77     ## Check overflow detection.
78     h0, _ = makefn(w)
79     me.assertRaises((OverflowError, ValueError),
80                     hashfn, h0, 1 << 8*w)
81
82   def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
83     """Check `hashbufN'."""
84
85     ## Go through a number of different sizes.
86     for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
87       if n >= 1 << 8*w: continue
88       h0, donefn0 = makefn(2 + w + n)
89       hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
90       h1, donefn1 = makefn(2 + w + n)
91       h1.hash(T.prep_lenseq(w, n, bigendp, True))
92       me.assertEqual(donefn0(), donefn1())
93
94     ## Check blocks which are too large for the length prefix.
95     if w <= 3:
96       n = 1 << 8*w
97       h0, _ = makefn(w + n)
98       me.assertRaises((ValueError, TypeError),
99                       hashfn, h0, C.ByteString.zero(n))
100
101   def check_hashbuffer(me, makefn):
102     """Test the various `hash...' methods."""
103
104     ## Check `hashuN'.
105     me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
106     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
107     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
108     me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
109     me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
110     me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
111     me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
112     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
113     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
114     me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
115     if hasattr(makefn(0)[0], "hashu64"):
116       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
117       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
118       me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))
119
120     ## Check `hashbufN'.
121     me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
122     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
123     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
124     me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
125     me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
126     me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
127     me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
128     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
129     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
130     me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
131     if hasattr(makefn(0)[0], "hashbuf64"):
132       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
133       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
134       me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x))
135
136 ###--------------------------------------------------------------------------
137 class TestKeysize (U.TestCase):
138
139   def test_any(me):
140
141     ## A typical one-byte spec.
142     ksz = C.seal.keysz
143     me.assertEqual(type(ksz), C.KeySZAny)
144     me.assertEqual(ksz.default, 20)
145     me.assertEqual(ksz.min, 0)
146     me.assertEqual(ksz.max, None)
147     for n in [0, 12, 20, 5000]:
148       me.assertTrue(ksz.check(n))
149       me.assertEqual(ksz.best(n), n)
150       me.assertEqual(ksz.pad(n), n)
151
152     ## A typical two-byte spec.  (No published algorithms actually /need/ a
153     ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
154     ksz = C.sha256_hmac.keysz
155     me.assertEqual(type(ksz), C.KeySZAny)
156     me.assertEqual(ksz.default, 32)
157     me.assertEqual(ksz.min, 0)
158     me.assertEqual(ksz.max, None)
159     for n in [0, 12, 20, 5000]:
160       me.assertTrue(ksz.check(n))
161       me.assertEqual(ksz.best(n), n)
162       me.assertEqual(ksz.pad(n), n)
163
164     ## Check construction.
165     ksz = C.KeySZAny(15)
166     me.assertEqual(ksz.default, 15)
167     me.assertEqual(ksz.min, 0)
168     me.assertEqual(ksz.max, None)
169     me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
170     me.assertEqual(C.KeySZAny(0).default, 0)
171
172   def test_set(me):
173     ## Note that no published algorithm uses a 16-bit `set' spec.
174
175     ## A typical spec.
176     ksz = C.salsa20.keysz
177     me.assertEqual(type(ksz), C.KeySZSet)
178     me.assertEqual(ksz.default, 32)
179     me.assertEqual(ksz.min, 10)
180     me.assertEqual(ksz.max, 32)
181     me.assertEqual(ksz.set, set([10, 16, 32]))
182     for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
183                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
184                          (31, 16, 32), (32, 32, 32), (33, 32, None)]:
185       if x == best == pad: me.assertTrue(ksz.check(x))
186       else: me.assertFalse(ksz.check(x))
187       if best is None: me.assertRaises(ValueError, ksz.best, x)
188       else: me.assertEqual(ksz.best(x), best)
189       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
190       else: me.assertEqual(ksz.pad(x), pad)
191
192     ## Check construction.
193     ksz = C.KeySZSet(7)
194     me.assertEqual(ksz.default, 7)
195     me.assertEqual(ksz.set, set([7]))
196     me.assertEqual(ksz.min, 7)
197     me.assertEqual(ksz.max, 7)
198     ksz = C.KeySZSet(7, iter([3, 6, 9]))
199     me.assertEqual(ksz.default, 7)
200     me.assertEqual(ksz.set, set([3, 6, 7, 9]))
201     me.assertEqual(ksz.min, 3)
202     me.assertEqual(ksz.max, 9)
203
204   def test_range(me):
205     ## Note that no published algorithm uses a 16-bit `range' spec, or an
206     ## unbounded `range'.
207
208     ## A typical spec.
209     ksz = C.rijndael.keysz
210     me.assertEqual(type(ksz), C.KeySZRange)
211     me.assertEqual(ksz.default, 32)
212     me.assertEqual(ksz.min, 4)
213     me.assertEqual(ksz.max, 32)
214     me.assertEqual(ksz.mod, 4)
215     for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
216                          (15, 12, 16), (16, 16, 16), (17, 16, 20),
217                          (31, 28, 32), (32, 32, 32), (33, 32, None)]:
218       if x == best == pad: me.assertTrue(ksz.check(x))
219       else: me.assertFalse(ksz.check(x))
220       if best is None: me.assertRaises(ValueError, ksz.best, x)
221       else: me.assertEqual(ksz.best(x), best)
222       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
223       else: me.assertEqual(ksz.pad(x), pad)
224
225     ## Check construction.
226     ksz = C.KeySZRange(28, 21, 35, 7)
227     me.assertEqual(ksz.default, 28)
228     me.assertEqual(ksz.min, 21)
229     me.assertEqual(ksz.max, 35)
230     me.assertEqual(ksz.mod, 7)
231     ksz = C.KeySZRange(28, 21, None, 7)
232     me.assertEqual(ksz.min, 21)
233     me.assertEqual(ksz.max, None)
234     me.assertEqual(ksz.mod, 7)
235     me.assertEqual(ksz.pad(36), 42)
236     me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
237     me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
238     me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
239     me.assertRaises(ValueError, C.KeySZRange, 28, -7, 35, 7)
240     me.assertRaises(ValueError, C.KeySZRange, 28, 35, 21, 7)
241     me.assertRaises(ValueError, C.KeySZRange, 35, 21, 28, 7)
242     me.assertRaises(ValueError, C.KeySZRange, 21, 28, 35, 7)
243
244   def test_conversions(me):
245     me.assertEqual(C.KeySZ.fromec(256), 128)
246     me.assertEqual(C.KeySZ.fromschnorr(256), 128)
247     me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
248     me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
249     me.assertEqual(C.KeySZ.toec(128), 256)
250     me.assertEqual(C.KeySZ.toschnorr(128), 256)
251     me.assertEqual(C.KeySZ.todl(128), 2958.6875)
252     me.assertEqual(C.KeySZ.toif(128), 2958.6875)
253
254 ###--------------------------------------------------------------------------
255 class TestCipher (T.GenericTestMixin):
256   """Test basic symmetric ciphers."""
257
258   def _test_cipher(me, ccls):
259
260     ## Check the class properties.
261     me.assertEqual(type(ccls.name), str)
262     me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
263     me.assertEqual(type(ccls.blksz), int)
264
265     ## Check round-tripping.
266     k = T.span(ccls.keysz.default)
267     iv = T.span(ccls.blksz)
268     m = T.span(253)
269     enc = ccls(k)
270     dec = ccls(k)
271     try: enc.setiv(iv)
272     except ValueError: can_setiv = False
273     else:
274       can_setiv = True
275       dec.setiv(iv)
276     c0 = enc.encrypt(m[0:57])
277     m0 = dec.decrypt(c0)
278     c1 = enc.encrypt(m[57:189])
279     m1 = dec.decrypt(c1)
280     try: enc.bdry()
281     except ValueError: can_bdry = False
282     else:
283       dec.bdry()
284       can_bdry = True
285     c2 = enc.encrypt(m[189:253])
286     m2 = dec.decrypt(c2)
287     me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
288     me.assertEqual(m0, m[0:57])
289     me.assertEqual(m1, m[57:189])
290     me.assertEqual(m2, m[189:253])
291
292     ## Check the `enczero' and `deczero' methods.
293     c3 = enc.enczero(32)
294     me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
295     m4 = dec.deczero(32)
296     me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))
297
298     ## Check that ciphers which support a `boundary' operation actually
299     ## need it.
300     if can_bdry:
301       dec = ccls(k)
302       if can_setiv: dec.setiv(iv)
303       m01 = dec.decrypt(c0 + c1)
304       me.assertEqual(m01, m[0:189])
305
306     ## Check that the boundary actually does something.
307     if can_bdry:
308       dec = ccls(k)
309       if can_setiv: dec.setiv(iv)
310       m012 = dec.decrypt(c0 + c1 + c2)
311       me.assertNotEqual(m012, m)
312
313     ## Check that bad key lengths are rejected.
314     badlen = bad_key_size(ccls.keysz)
315     if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
316
317 TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
318   ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
319    "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])
320
321 ###--------------------------------------------------------------------------
322 class TestAuthenticatedEncryption \
323         (HashBufferTestMixin, T.GenericTestMixin):
324   """Test authenticated encryption schemes."""
325
326   def _test_aead(me, aecls):
327
328     ## Check the class properties.
329     me.assertEqual(type(aecls.name), str)
330     me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
331     me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
332     me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
333     me.assertEqual(type(aecls.blksz), int)
334     me.assertEqual(type(aecls.bufsz), int)
335     me.assertEqual(type(aecls.ohd), int)
336     me.assertEqual(type(aecls.flags), int)
337
338     ## Check round-tripping, with full precommitment.  First, select some
339     ## parameters.  (It's conceivable that some AEAD schemes are more
340     ## restrictive than advertised by the various properties, but this works
341     ## out OK in practice.)
342     k = T.span(aecls.keysz.default)
343     n = T.span(aecls.noncesz.default)
344     if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
345     else: h = T.span(131)
346     m = T.span(253)
347     tsz = aecls.tagsz.default
348     key = aecls(k)
349
350     ## Next, encrypt a message, checking that things are proper as we go.
351     enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
352     me.assertEqual(enc.hsz, len(h))
353     me.assertEqual(enc.msz, len(m))
354     me.assertEqual(enc.mlen, 0)
355     me.assertEqual(enc.tsz, tsz)
356     aad = enc.aad()
357     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
358     else: me.assertEqual(aad.hsz, None)
359     me.assertEqual(aad.hlen, 0)
360     if not aecls.flags&C.AEADF_NOAAD:
361       aad.hash(h[0:83])
362       me.assertEqual(aad.hlen, 83)
363       aad.hash(h[83:131])
364       me.assertEqual(aad.hlen, 131)
365     c0 = enc.encrypt(m[0:57])
366     me.assertEqual(enc.mlen, 57)
367     me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
368     c1 = enc.encrypt(m[57:189])
369     me.assertEqual(enc.mlen, 189)
370     me.assertTrue(132 - aecls.bufsz <= len(c1) <=
371                   132 + aecls.bufsz + aecls.ohd)
372     c2 = enc.encrypt(m[189:253])
373     me.assertEqual(enc.mlen, 253)
374     me.assertTrue(64 - aecls.bufsz <= len(c2) <=
375                   64 + aecls.bufsz + aecls.ohd)
376     c3, t = enc.done(aad = aad)
377     me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
378     c = c0 + c1 + c2 + c3
379     me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
380     me.assertEqual(len(t), tsz)
381
382     ## And now decrypt it again, with different record boundaries.
383     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
384     me.assertEqual(dec.hsz, len(h))
385     me.assertEqual(dec.csz, len(c))
386     me.assertEqual(dec.clen, 0)
387     me.assertEqual(dec.tsz, tsz)
388     aad = dec.aad()
389     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
390     else: me.assertEqual(aad.hsz, None)
391     me.assertEqual(aad.hlen, 0)
392     aad.hash(h)
393     m0 = dec.decrypt(c[0:156])
394     me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
395     m1 = dec.decrypt(c[156:])
396     me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
397                   len(c) - 156 + aecls.bufsz)
398     m2 = dec.done(tag = t, aad = aad)
399     me.assertEqual(m0 + m1 + m2, m)
400
401     ## And again, with the wrong tag.
402     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
403     aad = dec.aad(); aad.hash(h)
404     _ = dec.decrypt(c)
405     me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))
406
407     ## Check that the all-in-one methods work.
408     me.assertEqual((c, t),
409                    key.encrypt(n = n, h = h, m = m, tsz = tsz))
410     me.assertEqual(m,
411                    key.decrypt(n = n, h = h, c = c, t = t))
412
413     ## Check that bad key, nonce, and tag lengths are rejected.
414     badlen = bad_key_size(aecls.keysz)
415     if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
416     badlen = bad_key_size(aecls.noncesz)
417     if badlen is not None:
418       me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
419                       hsz = len(h), msz = len(m), tsz = tsz)
420       me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
421                       hsz = len(h), csz = len(c), tsz = tsz)
422       if not aecls.flags&C.AEADF_PCTSZ:
423         enc = key.enc(nonce = n, hsz = 0, msz = len(m))
424         _ = enc.encrypt(m)
425         me.assertRaises(ValueError, enc.done, tsz = badlen)
426     badlen = bad_key_size(aecls.tagsz)
427     if badlen is not None:
428       me.assertRaises(ValueError, key.enc, nonce = n,
429                       hsz = len(h), msz = len(m), tsz = badlen)
430       me.assertRaises(ValueError, key.dec, nonce = n,
431                       hsz = len(h), csz = len(c), tsz = badlen)
432
433     ## Check that we can't get a loose `aad' object from a scheme which has
434     ## nonce-dependent AAD processing.
435     if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)
436
437     ## Check the menagerie of AAD hashing methods.
438     if not aecls.flags&C.AEADF_NOAAD:
439       def mkhash(hsz):
440         enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
441         aad = enc.aad()
442         return aad, lambda: enc.done(aad = aad)[1]
443       me.check_hashbuffer(mkhash)
444
445     ## Check that encryption/decryption works with the given precommitments.
446     def quick_enc_check(**kw):
447       enc = key.enc(**kw)
448       aad = enc.aad().hash(h)
449       c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
450       me.assertEqual((c, t), (c0 + c1, tt))
451     def quick_dec_check(**kw):
452       dec = key.dec(**kw)
453       aad = dec.aad().hash(h)
454       m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
455       me.assertEqual(m, m0 + m1)
456
457     ## Check that we can get away without precommitting to the header length
458     ## if and only if the AEAD scheme says it will let us.
459     if aecls.flags&C.AEADF_PCHSZ:
460       me.assertRaises(ValueError, key.enc, nonce = n,
461                       msz = len(m), tsz = tsz)
462       me.assertRaises(ValueError, key.dec, nonce = n,
463                       csz = len(c), tsz = tsz)
464     else:
465       quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
466       quick_dec_check(nonce = n, csz = len(c), tsz = tsz)
467
468     ## Check that we can get away without precommitting to the message/
469     ## ciphertext length if and only if the AEAD scheme says it will let us.
470     if aecls.flags&C.AEADF_PCMSZ:
471       me.assertRaises(ValueError, key.enc, nonce = n,
472                       hsz = len(h), tsz = tsz)
473       me.assertRaises(ValueError, key.dec, nonce = n,
474                       hsz = len(h), tsz = tsz)
475     else:
476       quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
477       quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)
478
479     ## Check that we can get away without precommitting to the tag length if
480     ## and only if the AEAD scheme says it will let us.
481     if aecls.flags&C.AEADF_PCTSZ:
482       me.assertRaises(ValueError, key.enc, nonce = n,
483                       hsz = len(h), msz = len(m))
484       me.assertRaises(ValueError, key.dec, nonce = n,
485                       hsz = len(h), csz = len(c))
486     else:
487       quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
488       quick_dec_check(nonce = n, hsz = len(h), csz = len(c))
489
490     ## Check that if we precommit to the header length, we're properly held
491     ## to the commitment.
492     if not aecls.flags&C.AEADF_NOAAD:
493
494       ## First, check encryption with underrun.  If we must supply AAD first,
495       ## then the underrun will be reported when we start trying to encrypt;
496       ## otherwise, checking is delayed until `done'.
497       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
498       aad = enc.aad().hash(h[0:83])
499       if aecls.flags&C.AEADF_AADFIRST:
500         me.assertRaises(ValueError, enc.encrypt, m)
501       else:
502         _ = enc.encrypt(m)
503         me.assertRaises(ValueError, enc.done, aad = aad)
504
505       ## Next, check decryption with underrun.  If we must supply AAD first,
506       ## then the underrun will be reported when we start trying to encrypt;
507       ## otherwise, checking is delayed until `done'.
508       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
509       aad = dec.aad().hash(h[0:83])
510       if aecls.flags&C.AEADF_AADFIRST:
511         me.assertRaises(ValueError, dec.decrypt, c)
512       else:
513         _ = dec.decrypt(c)
514         me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
515
516       ## If AAD processing is nonce-dependent then an overrun will be
517       ## detected imediately.
518       if aecls.flags&C.AEADF_AADNDEP:
519         enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
520         aad = enc.aad().hash(h[0:83])
521         me.assertRaises(ValueError, aad.hash, h[82:131])
522         dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
523         aad = dec.aad().hash(h[0:83])
524         me.assertRaises(ValueError, aad.hash, h[82:131])
525
526     ## Some additional tests for nonce-dependent `aad' objects.
527     if aecls.flags&C.AEADF_AADNDEP:
528
529       ## Check that `aad' objects can't be used once their parents are gone.
530       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
531       aad = enc.aad()
532       del enc
533       me.assertRaises(ValueError, aad.hash, h)
534
535       ## Check that they can't be crossed over.
536       enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
537       enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
538       enc0.aad().hash(h)
539       aad1 = enc1.aad().hash(h)
540       _ = enc0.encrypt(m)
541       me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)
542
543     ## Test copying AAD.
544     if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
545       aad0 = key.aad()
546       aad0.hash(h[0:83])
547       aad1 = aad0.copy()
548       aad2 = aad1.copy()
549       aad0.hash(h[83:131])
550       aad1.hash(h[83:131])
551       aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
552       me.assertEqual(key.enc(nonce = n, hsz = len(h),
553                              msz = 0, tsz = tsz).done(aad = aad0),
554                      key.enc(nonce = n, hsz = len(h),
555                              msz = 0, tsz = tsz).done(aad = aad1))
556       me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
557                                 msz = 0, tsz = tsz).done(aad = aad0),
558                         key.enc(nonce = n, hsz = len(h),
559                                 msz = 0, tsz = tsz).done(aad = aad2))
560
561     ## Check that if we precommit to the message length, we're properly held
562     ## to the commitment.  (Fortunately, this is way simpler than the AAD
563     ## case above.)  First, try an underrun.
564     enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
565     _ = enc.encrypt(m[0:183])
566     me.assertRaises(ValueError, enc.done, tsz = tsz)
567     dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
568     _ = dec.decrypt(c[0:183])
569     me.assertRaises(ValueError, dec.done, tag = t)
570
571     ## And now an overrun.
572     enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
573     me.assertRaises(ValueError, enc.encrypt, m)
574     dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
575     me.assertRaises(ValueError, dec.decrypt, c)
576
577     ## Finally, check that if we precommit to a tag length, we're properly
578     ## held to the commitment.  This depends on being able to find a tag size
579     ## which isn't the default.
580     tsz1 = different_key_size(aecls.tagsz, tsz)
581     if tsz1 is not None:
582       enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
583       _ = enc.encrypt(m)
584       me.assertRaises(ValueError, enc.done, tsz = tsz)
585       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
586       aad = dec.aad().hash(h)
587       _ = dec.decrypt(c)
588       me.assertRaises(ValueError, enc.done, tsz = tsz, aad = aad)
589
590 TestAuthenticatedEncryption.generate_testcases \
591   ((name, C.gcaeads[name]) for name in
592    ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
593     "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
594
595 ###--------------------------------------------------------------------------
596 class BaseTestHash (HashBufferTestMixin):
597   """Base class for testing hash functions."""
598
599   def check_hash(me, hcls, need_bufsz = True):
600     """
601     Check hash class HCLS.
602
603     If NEED_BUFSZ is false, then don't insist that HCLS has a working `bufsz'
604     attribute.  This test is mostly reused for MACs, which don't have this
605     attribute.
606     """
607     ## Check the class properties.
608     me.assertEqual(type(hcls.name), str)
609     if need_bufsz: me.assertEqual(type(hcls.bufsz), int)
610     me.assertEqual(type(hcls.hashsz), int)
611
612     ## Set some initial values.
613     m = T.span(131)
614     h = hcls().hash(m).done()
615
616     ## Check that hash length comes out right.
617     me.assertEqual(len(h), hcls.hashsz)
618
619     ## Check that we get the same answer if we split the message up.
620     me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
621
622     ## Check the `check' method.
623     me.assertTrue(hcls().hash(m).check(h))
624     me.assertFalse(hcls().hash(m).check(h ^ hcls.hashsz*C.bytes("aa")))
625
626     ## Check the menagerie of random hashing methods.
627     def mkhash(_):
628       h = hcls()
629       return h, h.done
630     me.check_hashbuffer(mkhash)
631
632 class TestHash (BaseTestHash, T.GenericTestMixin):
633   """Test hash functions."""
634   def _test_hash(me, hcls):    me.check_hash(hcls, need_bufsz = True)
635
636 TestHash.generate_testcases((name, C.gchashes[name]) for name in
637   ["md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384", "shake256",
638    "crc32"])
639
640 ###--------------------------------------------------------------------------
641 class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
642   """Test message authentication codes."""
643
644   def _test_mac(me, mcls):
645
646     ## Check the MAC properties.
647     me.assertEqual(type(mcls.name), str)
648     me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
649     me.assertEqual(type(mcls.tagsz), int)
650
651     ## Test hashing.
652     k = T.span(mcls.keysz.default)
653     key = mcls(k)
654     me.assertEqual(key.hashsz, key.tagsz)
655     me.check_hash(key, need_bufsz = False)
656
657     ## Check that bad key lengths are rejected.
658     badlen = bad_key_size(mcls.keysz)
659     if badlen is not None: me.assertRaises(ValueError, mcls, T.span(badlen))
660
661 TestMessageAuthentication.generate_testcases \
662   ((name, C.gcmacs[name]) for name in
663    ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
664
665 class TestPoly1305 (HashBufferTestMixin):
666   """Check the Poly1305 one-time message authentication function."""
667
668   def test_poly1305(me):
669
670     ## Check the MAC properties.
671     me.assertEqual(C.poly1305.name, "poly1305")
672     me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
673     me.assertEqual(C.poly1305.keysz.default, 16)
674     me.assertEqual(C.poly1305.keysz.set, set([16]))
675     me.assertEqual(C.poly1305.tagsz, 16)
676     me.assertEqual(C.poly1305.masksz, 16)
677
678     ## Set some initial values.
679     k = T.span(16)
680     u = T.span(64)[-16:]
681     m = T.span(149)
682     key = C.poly1305(k)
683     t = key(u).hash(m).done()
684
685     ## Check the key properties.
686     me.assertEqual(key.name, "poly1305")
687     me.assertEqual(key.tagsz, 16)
688     me.assertEqual(key.tagsz, 16)
689     me.assertEqual(len(t), 16)
690
691     ## Check that we get the same answer if we split the message up.
692     me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())
693
694     ## Check the `check' method.
695     me.assertTrue(key(u).hash(m).check(t))
696     me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))
697
698     ## Check the menagerie of random hashing methods.
699     def mkhash(_):
700       h = key(u)
701       return h, h.done
702     me.check_hashbuffer(mkhash)
703
704     ## Check that we can't complete hashing without a mask.
705     me.assertRaises(ValueError, key().hash(m).done)
706
707     ## Check `concat'.
708     h0 = key().hash(m[0:96])
709     h1 = key().hash(m[96:117])
710     me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
711     key1 = C.poly1305(k)
712     me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
713     me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
714     me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
715
716 ###--------------------------------------------------------------------------
717 class TestHLatin (U.TestCase):
718   """Test the `hsalsa20' and `hchacha20' functions."""
719
720   def test_hlatin(me):
721     kk = [T.span(sz) for sz in [10, 16, 32]]
722     n = T.span(16)
723     bad_k = T.span(18)
724     bad_n = T.span(13)
725     for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
726                C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
727       for k in kk:
728         h = fn(k, n)
729         me.assertEqual(len(h), 32)
730       me.assertRaises(ValueError, fn, bad_k, n)
731       me.assertRaises(ValueError, fn, k, bad_n)
732
733 ###--------------------------------------------------------------------------
734 class TestKeccak (HashBufferTestMixin):
735   """Test the Keccak-p[1600, n] sponge function."""
736
737   def test_keccak(me):
738
739     ## Make a state and feed some stuff into it.
740     m0 = T.bin("some initial string")
741     m1 = T.bin("awesome follow-up string")
742     st0 = C.Keccak1600()
743     me.assertEqual(st0.nround, 24)
744     st0.mix(m0).step()
745
746     ## Make another step with a different round count.
747     st1 = C.Keccak1600(23)
748     st1.mix(m0).step()
749     me.assertNotEqual(st0.extract(32), st1.extract(32))
750
751     ## Check state copying.
752     st1 = st0.copy()
753     mask = st1.extract(len(m1))
754     st0.mix(m1)
755     st1.mix(m1)
756     me.assertEqual(st0.extract(32), st1.extract(32))
757
758     ## Check error conditions.
759     _ = st0.extract(200)
760     me.assertRaises(ValueError, st0.extract, 201)
761     st0.mix(T.span(200))
762     me.assertRaises(ValueError, st0.mix, T.span(201))
763
764   def check_shake(me, xcls, c, done_matches_xof = True):
765     """
766     Test the SHAKE and cSHAKE XOFs.
767
768     This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
769     to indicate that the XOF output is range-separated from the fixed-length
770     outputs (unlike the basic SHAKE functions).
771     """
772
773     ## Check the hash attributes.
774     x = xcls()
775     me.assertEqual(x.rate, 200 - c)
776     me.assertEqual(x.buffered, 0)
777     me.assertEqual(x.state, "absorb")
778
779     ## Set some initial values.
780     func = T.bin("TESTXOF")
781     perso = T.bin("catacomb-python test")
782     m = T.span(167)
783     h0 = xcls().hash(m).done(193)
784     me.assertEqual(len(h0), 193)
785     h1 = xcls(func = func, perso = perso).hash(m).done(193)
786     me.assertEqual(len(h1), 193)
787     me.assertNotEqual(h0, h1)
788
789     ## Check input and output in pieces, and the state machine.
790     if done_matches_xof: h = h0
791     else: h = xcls().hash(m).xof().get(len(h0))
792     x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
793     me.assertEqual(h, x.get(98) + x.get(95))
794
795     ## Check masking.
796     x = xcls().hash(m).xof()
797     me.assertEqual(x.mask(m), m ^ h[0:len(m)])
798
799     ## Check the `check' method.
800     me.assertTrue(xcls().hash(m).check(h0))
801     me.assertFalse(xcls().hash(m).check(h1))
802
803     ## Check the menagerie of random hashing methods.
804     def mkhash(_):
805       x = xcls(func = func, perso = perso)
806       return x, x.done
807     me.check_hashbuffer(mkhash)
808
809     ## Check the state machine tracking.
810     x = xcls(); me.assertEqual(x.state, "absorb")
811     x.hash(m); me.assertEqual(x.state, "absorb")
812     xx = x.copy()
813     h = xx.done(); me.assertEqual(len(h), 100 - x.rate//2)
814     me.assertEqual(xx.state, "dead")
815     me.assertRaises(ValueError, xx.done, 1)
816     me.assertRaises(ValueError, xx.get, 1)
817     me.assertEqual(x.state, "absorb")
818     me.assertRaises(ValueError, x.get, 1)
819     x.xof(); me.assertEqual(x.state, "squeeze")
820     me.assertRaises(ValueError, x.done, 1)
821     _ = x.get(1)
822     yy = x.copy(); me.assertEqual(yy.state, "squeeze")
823
824   def test_shake128(me): me.check_shake(C.Shake128, 32)
825   def test_shake256(me): me.check_shake(C.Shake256, 64)
826
827   def check_kmac(me, mcls, c):
828     k = T.span(32)
829     me.check_shake(lambda func = None, perso = None:
830                      mcls(k, perso = perso),
831                    c, done_matches_xof = False)
832
833   def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
834   def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
835
836 ###--------------------------------------------------------------------------
837 class TestPRP (T.GenericTestMixin):
838   """Test pseudorandom permutations (PRPs)."""
839
840   def _test_prp(me, pcls):
841
842     ## Check the PRP properties.
843     me.assertEqual(type(pcls.name), str)
844     me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
845     me.assertEqual(type(pcls.blksz), int)
846
847     ## Check round-tripping.
848     k = T.span(pcls.keysz.default)
849     key = pcls(k)
850     m = T.span(pcls.blksz)
851     c = key.encrypt(m)
852     me.assertEqual(len(c), pcls.blksz)
853     me.assertEqual(m, key.decrypt(c))
854
855     ## Check that bad key lengths are rejected.
856     badlen = bad_key_size(pcls.keysz)
857     if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))
858
859     ## Check that bad blocks are rejected.
860     badblk = T.span(pcls.blksz + 1)
861     me.assertRaises(ValueError, key.encrypt, badblk)
862     me.assertRaises(ValueError, key.decrypt, badblk)
863
864 TestPRP.generate_testcases((name, C.gcprps[name]) for name in
865   ["desx", "blowfish", "rijndael"])
866
867 ###----- That's all, folks --------------------------------------------------
868
869 if __name__ == "__main__": U.main()