chiark / gitweb /
t/t-algorithms.py: Add a simple test for `Keccak1600.copy'.
[catacomb-python] / t / t-algorithms.py
1 ### -*- mode: python, coding: utf-8 -*-
2 ###
3 ### Test symmetric algorithms
4 ###
5 ### (c) 2019 Straylight/Edgeware
6 ###
7
8 ###----- Licensing notice ---------------------------------------------------
9 ###
10 ### This file is part of the Python interface to Catacomb.
11 ###
12 ### Catacomb/Python is free software: you can redistribute it and/or
13 ### modify it under the terms of the GNU General Public License as
14 ### published by the Free Software Foundation; either version 2 of the
15 ### License, or (at your option) any later version.
16 ###
17 ### Catacomb/Python is distributed in the hope that it will be useful, but
18 ### WITHOUT ANY WARRANTY; without even the implied warranty of
19 ### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20 ### General Public License for more details.
21 ###
22 ### You should have received a copy of the GNU General Public License
23 ### along with Catacomb/Python.  If not, write to the Free Software
24 ### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25 ### USA.
26
27 ###--------------------------------------------------------------------------
28 ### Imported modules.
29
30 import catacomb as C
31 import unittest as U
32 import testutils as T
33
34 ###--------------------------------------------------------------------------
35 ### Utilities.
36
37 def bad_key_size(ksz):
38   if isinstance(ksz, C.KeySZAny): return None
39   elif isinstance(ksz, C.KeySZRange):
40     if ksz.mod != 1: return ksz.min + 1
41     elif ksz.max != 0: return ksz.max + 1
42     elif ksz.min != 0: return ksz.min - 1
43     else: return None
44   elif isinstance(ksz, C.KeySZSet):
45     for sz in sorted(ksz.set):
46       if sz + 1 not in ksz.set: return sz + 1
47     assert False, "That should have worked."
48   else:
49     return None
50
51 def different_key_size(ksz, sz):
52   if isinstance(ksz, C.KeySZAny): return sz + 1
53   elif isinstance(ksz, C.KeySZRange):
54     if sz > ksz.min: return sz - ksz.mod
55     elif ksz.max == 0 or sz < ksz.max: return sz + ksz.mod
56     else: return None
57   elif isinstance(ksz, C.KeySZSet):
58     for sz1 in sorted(ksz.set):
59       if sz != sz1: return sz1
60     return None
61   else:
62     return None
63
64 class HashBufferTestMixin (U.TestCase):
65   """Mixin class for testing all of the various `hash...' methods."""
66
67   def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
68     """Check `hashuN'."""
69
70     ## Check encoding an integer.
71     h0, donefn0 = makefn(w + 2)
72     hashfn(h0.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
73     h1, donefn1 = makefn(w + 2)
74     h1.hash(T.span(w + 2))
75     me.assertEqual(donefn0(), donefn1())
76
77     ## Check overflow detection.
78     h0, _ = makefn(w)
79     me.assertRaises((OverflowError, ValueError),
80                     hashfn, h0, 1 << 8*w)
81
82   def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
83     """Check `hashbufN'."""
84
85     ## Go through a number of different sizes.
86     for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
87       if n >= 1 << 8*w: continue
88       h0, donefn0 = makefn(2 + w + n)
89       hashfn(h0.hashu8(0x00), T.span(n)).hashu8(0xff)
90       h1, donefn1 = makefn(2 + w + n)
91       h1.hash(T.prep_lenseq(w, n, bigendp, True))
92       me.assertEqual(donefn0(), donefn1())
93
94     ## Check blocks which are too large for the length prefix.
95     if w <= 3:
96       n = 1 << 8*w
97       h0, _ = makefn(w + n)
98       me.assertRaises((ValueError, OverflowError, TypeError),
99                       hashfn, h0, C.ByteString.zero(n))
100
101   def check_hashbuffer(me, makefn):
102     """Test the various `hash...' methods."""
103
104     ## Check `hashuN'.
105     me.check_hashbuffer_hashn(1, True, makefn, lambda h, n: h.hashu8(n))
106     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16(n))
107     me.check_hashbuffer_hashn(2, True, makefn, lambda h, n: h.hashu16b(n))
108     me.check_hashbuffer_hashn(2, False, makefn, lambda h, n: h.hashu16l(n))
109     if hasattr(makefn(0)[0], "hashu24"):
110       me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24(n))
111       me.check_hashbuffer_hashn(3, True, makefn, lambda h, n: h.hashu24b(n))
112       me.check_hashbuffer_hashn(3, False, makefn, lambda h, n: h.hashu24l(n))
113     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32(n))
114     me.check_hashbuffer_hashn(4, True, makefn, lambda h, n: h.hashu32b(n))
115     me.check_hashbuffer_hashn(4, False, makefn, lambda h, n: h.hashu32l(n))
116     if hasattr(makefn(0)[0], "hashu64"):
117       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64(n))
118       me.check_hashbuffer_hashn(8, True, makefn, lambda h, n: h.hashu64b(n))
119       me.check_hashbuffer_hashn(8, False, makefn, lambda h, n: h.hashu64l(n))
120
121     ## Check `hashbufN'.
122     me.check_hashbuffer_bufn(1, True, makefn, lambda h, x: h.hashbuf8(x))
123     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16(x))
124     me.check_hashbuffer_bufn(2, True, makefn, lambda h, x: h.hashbuf16b(x))
125     me.check_hashbuffer_bufn(2, False, makefn, lambda h, x: h.hashbuf16l(x))
126     if hasattr(makefn(0)[0], "hashbuf24"):
127       me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24(x))
128       me.check_hashbuffer_bufn(3, True, makefn, lambda h, x: h.hashbuf24b(x))
129       me.check_hashbuffer_bufn(3, False, makefn, lambda h, x: h.hashbuf24l(x))
130     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32(x))
131     me.check_hashbuffer_bufn(4, True, makefn, lambda h, x: h.hashbuf32b(x))
132     me.check_hashbuffer_bufn(4, False, makefn, lambda h, x: h.hashbuf32l(x))
133     if hasattr(makefn(0)[0], "hashbuf64"):
134       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64(x))
135       me.check_hashbuffer_bufn(8, True, makefn, lambda h, x: h.hashbuf64b(x))
136       me.check_hashbuffer_bufn(8, False, makefn, lambda h, x: h.hashbuf64l(x))
137
138 ###--------------------------------------------------------------------------
139 class TestKeysize (U.TestCase):
140
141   def test_any(me):
142
143     ## A typical one-byte spec.
144     ksz = C.seal.keysz
145     me.assertEqual(type(ksz), C.KeySZAny)
146     me.assertEqual(ksz.default, 20)
147     me.assertEqual(ksz.min, 0)
148     me.assertEqual(ksz.max, 0)
149     for n in [0, 12, 20, 5000]:
150       me.assertTrue(ksz.check(n))
151       me.assertEqual(ksz.best(n), n)
152       me.assertEqual(ksz.pad(n), n)
153
154     ## A typical two-byte spec.  (No published algorithms actually /need/ a
155     ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
156     ksz = C.sha256_hmac.keysz
157     me.assertEqual(type(ksz), C.KeySZAny)
158     me.assertEqual(ksz.default, 32)
159     me.assertEqual(ksz.min, 0)
160     me.assertEqual(ksz.max, 0)
161     for n in [0, 12, 20, 5000]:
162       me.assertTrue(ksz.check(n))
163       me.assertEqual(ksz.best(n), n)
164       me.assertEqual(ksz.pad(n), n)
165
166     ## Check construction.
167     ksz = C.KeySZAny(15)
168     me.assertEqual(ksz.default, 15)
169     me.assertEqual(ksz.min, 0)
170     me.assertEqual(ksz.max, 0)
171     me.assertRaises(ValueError, lambda: C.KeySZAny(-8))
172     me.assertEqual(C.KeySZAny(0).default, 0)
173
174   def test_set(me):
175     ## Note that no published algorithm uses a 16-bit `set' spec.
176
177     ## A typical spec.
178     ksz = C.salsa20.keysz
179     me.assertEqual(type(ksz), C.KeySZSet)
180     me.assertEqual(ksz.default, 32)
181     me.assertEqual(ksz.min, 10)
182     me.assertEqual(ksz.max, 32)
183     me.assertEqual(set(ksz.set), set([10, 16, 32]))
184     for x, best, pad in [(9, None, 10), (10, 10, 10), (11, 10, 16),
185                          (15, 10, 16), (16, 16, 16), (17, 16, 32),
186                          (31, 16, 32), (32, 32, 32), (33, 32, None)]:
187       if x == best == pad: me.assertTrue(ksz.check(x))
188       else: me.assertFalse(ksz.check(x))
189       if best is None: me.assertRaises(ValueError, ksz.best, x)
190       else: me.assertEqual(ksz.best(x), best)
191       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
192       else: me.assertEqual(ksz.pad(x), pad)
193
194     ## Check construction.
195     ksz = C.KeySZSet(7)
196     me.assertEqual(ksz.default, 7)
197     me.assertEqual(set(ksz.set), set([7]))
198     me.assertEqual(ksz.min, 7)
199     me.assertEqual(ksz.max, 7)
200     ksz = C.KeySZSet(7, [3, 6, 9])
201     me.assertEqual(ksz.default, 7)
202     me.assertEqual(set(ksz.set), set([3, 6, 7, 9]))
203     me.assertEqual(ksz.min, 3)
204     me.assertEqual(ksz.max, 9)
205
206   def test_range(me):
207     ## Note that no published algorithm uses a 16-bit `range' spec, or an
208     ## unbounded `range'.
209
210     ## A typical spec.
211     ksz = C.rijndael.keysz
212     me.assertEqual(type(ksz), C.KeySZRange)
213     me.assertEqual(ksz.default, 32)
214     me.assertEqual(ksz.min, 4)
215     me.assertEqual(ksz.max, 32)
216     me.assertEqual(ksz.mod, 4)
217     for x, best, pad in [(3, None, 4), (4, 4, 4), (5, 4, 8),
218                          (15, 12, 16), (16, 16, 16), (17, 16, 20),
219                          (31, 28, 32), (32, 32, 32), (33, 32, None)]:
220       if x == best == pad: me.assertTrue(ksz.check(x))
221       else: me.assertFalse(ksz.check(x))
222       if best is None: me.assertRaises(ValueError, ksz.best, x)
223       else: me.assertEqual(ksz.best(x), best)
224       if pad is None: me.assertRaises(ValueError, ksz.pad, x)
225       else: me.assertEqual(ksz.pad(x), pad)
226
227     ## Check construction.
228     ksz = C.KeySZRange(28, 21, 35, 7)
229     me.assertEqual(ksz.default, 28)
230     me.assertEqual(ksz.min, 21)
231     me.assertEqual(ksz.max, 35)
232     me.assertEqual(ksz.mod, 7)
233     me.assertRaises(ValueError, C.KeySZRange, 29, 21, 35, 7)
234     me.assertRaises(ValueError, C.KeySZRange, 28, 20, 35, 7)
235     me.assertRaises(ValueError, C.KeySZRange, 28, 21, 34, 7)
236     me.assertRaises(ValueError, C.KeySZRange, 28, -7, 35, 7)
237     me.assertRaises(ValueError, C.KeySZRange, 28, 35, 21, 7)
238     me.assertRaises(ValueError, C.KeySZRange, 35, 21, 28, 7)
239     me.assertRaises(ValueError, C.KeySZRange, 21, 28, 35, 7)
240
241   def test_conversions(me):
242     me.assertEqual(C.KeySZ.fromec(256), 128)
243     me.assertEqual(C.KeySZ.fromschnorr(256), 128)
244     me.assertEqual(round(C.KeySZ.fromdl(2958.6875)), 128)
245     me.assertEqual(round(C.KeySZ.fromif(2958.6875)), 128)
246     me.assertEqual(C.KeySZ.toec(128), 256)
247     me.assertEqual(C.KeySZ.toschnorr(128), 256)
248     me.assertEqual(C.KeySZ.todl(128), 2958.6875)
249     me.assertEqual(C.KeySZ.toif(128), 2958.6875)
250
251 ###--------------------------------------------------------------------------
252 class TestCipher (T.GenericTestMixin):
253   """Test basic symmetric ciphers."""
254
255   def _test_cipher(me, ccls):
256
257     ## Check the class properties.
258     me.assertEqual(type(ccls.name), str)
259     me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
260     me.assertEqual(type(ccls.blksz), int)
261
262     ## Check round-tripping.
263     k = T.span(ccls.keysz.default)
264     iv = T.span(ccls.blksz)
265     m = T.span(253)
266     enc = ccls(k)
267     dec = ccls(k)
268     try: enc.setiv(iv)
269     except ValueError: can_setiv = False
270     else:
271       can_setiv = True
272       dec.setiv(iv)
273     c0 = enc.encrypt(m[0:57])
274     m0 = dec.decrypt(c0)
275     c1 = enc.encrypt(m[57:189])
276     m1 = dec.decrypt(c1)
277     try: enc.bdry()
278     except ValueError: can_bdry = False
279     else:
280       dec.bdry()
281       can_bdry = True
282     c2 = enc.encrypt(m[189:253])
283     m2 = dec.decrypt(c2)
284     me.assertEqual(len(c0) + len(c1) + len(c2), len(m))
285     me.assertEqual(m0, m[0:57])
286     me.assertEqual(m1, m[57:189])
287     me.assertEqual(m2, m[189:253])
288
289     ## Check the `enczero' and `deczero' methods.
290     c3 = enc.enczero(32)
291     me.assertEqual(dec.decrypt(c3), C.ByteString.zero(32))
292     m4 = dec.deczero(32)
293     me.assertEqual(enc.encrypt(m4), C.ByteString.zero(32))
294
295     ## Check that ciphers which support a `boundary' operation actually
296     ## need it.
297     if can_bdry:
298       dec = ccls(k)
299       if can_setiv: dec.setiv(iv)
300       m01 = dec.decrypt(c0 + c1)
301       me.assertEqual(m01, m[0:189])
302
303     ## Check that the boundary actually does something.
304     if can_bdry:
305       dec = ccls(k)
306       if can_setiv: dec.setiv(iv)
307       m012 = dec.decrypt(c0 + c1 + c2)
308       me.assertNotEqual(m012, m)
309
310     ## Check that bad key lengths are rejected.
311     badlen = bad_key_size(ccls.keysz)
312     if badlen is not None: me.assertRaises(ValueError, ccls, T.span(badlen))
313
314 TestCipher.generate_testcases((name, C.gcciphers[name]) for name in
315   ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
316    "blowfish-counter", "rc4", "seal", "salsa20/8", "shake128-xof"])
317
318 ###--------------------------------------------------------------------------
319 class TestAuthenticatedEncryption \
320         (HashBufferTestMixin, T.GenericTestMixin):
321   """Test authenticated encryption schemes."""
322
323   def _test_aead(me, aecls):
324
325     ## Check the class properties.
326     me.assertEqual(type(aecls.name), str)
327     me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
328     me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
329     me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
330     me.assertEqual(type(aecls.blksz), int)
331     me.assertEqual(type(aecls.bufsz), int)
332     me.assertEqual(type(aecls.ohd), int)
333     me.assertEqual(type(aecls.flags), int)
334
335     ## Check round-tripping, with full precommitment.  First, select some
336     ## parameters.  (It's conceivable that some AEAD schemes are more
337     ## restrictive than advertised by the various properties, but this works
338     ## out OK in practice.)
339     k = T.span(aecls.keysz.default)
340     n = T.span(aecls.noncesz.default)
341     if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
342     else: h = T.span(131)
343     m = T.span(253)
344     tsz = aecls.tagsz.default
345     key = aecls(k)
346
347     ## Next, encrypt a message, checking that things are proper as we go.
348     enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
349     me.assertEqual(enc.hsz, len(h))
350     me.assertEqual(enc.msz, len(m))
351     me.assertEqual(enc.mlen, 0)
352     me.assertEqual(enc.tsz, tsz)
353     aad = enc.aad()
354     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
355     else: me.assertEqual(aad.hsz, None)
356     me.assertEqual(aad.hlen, 0)
357     if not aecls.flags&C.AEADF_NOAAD:
358       aad.hash(h[0:83])
359       me.assertEqual(aad.hlen, 83)
360       aad.hash(h[83:131])
361       me.assertEqual(aad.hlen, 131)
362     c0 = enc.encrypt(m[0:57])
363     me.assertEqual(enc.mlen, 57)
364     me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
365     c1 = enc.encrypt(m[57:189])
366     me.assertEqual(enc.mlen, 189)
367     me.assertTrue(132 - aecls.bufsz <= len(c1) <=
368                   132 + aecls.bufsz + aecls.ohd)
369     c2 = enc.encrypt(m[189:253])
370     me.assertEqual(enc.mlen, 253)
371     me.assertTrue(64 - aecls.bufsz <= len(c2) <=
372                   64 + aecls.bufsz + aecls.ohd)
373     c3, t = enc.done(aad = aad)
374     me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
375     c = c0 + c1 + c2 + c3
376     me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
377     me.assertEqual(len(t), tsz)
378
379     ## And now decrypt it again, with different record boundaries.
380     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
381     me.assertEqual(dec.hsz, len(h))
382     me.assertEqual(dec.csz, len(c))
383     me.assertEqual(dec.clen, 0)
384     me.assertEqual(dec.tsz, tsz)
385     aad = dec.aad()
386     if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
387     else: me.assertEqual(aad.hsz, None)
388     me.assertEqual(aad.hlen, 0)
389     aad.hash(h)
390     m0 = dec.decrypt(c[0:156])
391     me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
392     m1 = dec.decrypt(c[156:])
393     me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
394                   len(c) - 156 + aecls.bufsz)
395     m2 = dec.done(tag = t, aad = aad)
396     me.assertEqual(m0 + m1 + m2, m)
397
398     ## And again, with the wrong tag.
399     dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
400     aad = dec.aad(); aad.hash(h)
401     _ = dec.decrypt(c)
402     me.assertRaises(ValueError, dec.done, tag = t ^ tsz*C.bytes("55"))
403
404     ## Check that the all-in-one methods work.
405     me.assertEqual((c, t),
406                    key.encrypt(n = n, h = h, m = m, tsz = tsz))
407     me.assertEqual(m,
408                    key.decrypt(n = n, h = h, c = c, t = t))
409
410     ## Check that bad key, nonce, and tag lengths are rejected.
411     badlen = bad_key_size(aecls.keysz)
412     if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
413     badlen = bad_key_size(aecls.noncesz)
414     if badlen is not None:
415       me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
416                       hsz = len(h), msz = len(m), tsz = tsz)
417       me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
418                       hsz = len(h), csz = len(c), tsz = tsz)
419       if not aecls.flags&C.AEADF_PCTSZ:
420         enc = key.enc(nonce = n, hsz = 0, msz = len(m))
421         _ = enc.encrypt(m)
422         me.assertRaises(ValueError, enc.done, tsz = badlen)
423     badlen = bad_key_size(aecls.tagsz)
424     if badlen is not None:
425       me.assertRaises(ValueError, key.enc, nonce = n,
426                       hsz = len(h), msz = len(m), tsz = badlen)
427       me.assertRaises(ValueError, key.dec, nonce = n,
428                       hsz = len(h), csz = len(c), tsz = badlen)
429
430     ## Check that we can't get a loose `aad' object from a scheme which has
431     ## nonce-dependent AAD processing.
432     if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)
433
434     ## Check the menagerie of AAD hashing methods.
435     if not aecls.flags&C.AEADF_NOAAD:
436       def mkhash(hsz):
437         enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
438         aad = enc.aad()
439         return aad, lambda: enc.done(aad = aad)[1]
440       me.check_hashbuffer(mkhash)
441
442     ## Check that encryption/decryption works with the given precommitments.
443     def quick_enc_check(**kw):
444       enc = key.enc(**kw)
445       aad = enc.aad().hash(h)
446       c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
447       me.assertEqual((c, t), (c0 + c1, tt))
448     def quick_dec_check(**kw):
449       dec = key.dec(**kw)
450       aad = dec.aad().hash(h)
451       m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
452       me.assertEqual(m, m0 + m1)
453
454     ## Check that we can get away without precommitting to the header length
455     ## if and only if the AEAD scheme says it will let us.
456     if aecls.flags&C.AEADF_PCHSZ:
457       me.assertRaises(ValueError, key.enc, nonce = n,
458                       msz = len(m), tsz = tsz)
459       me.assertRaises(ValueError, key.dec, nonce = n,
460                       csz = len(c), tsz = tsz)
461     else:
462       quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
463       quick_dec_check(nonce = n, csz = len(c), tsz = tsz)
464
465     ## Check that we can get away without precommitting to the message/
466     ## ciphertext length if and only if the AEAD scheme says it will let us.
467     if aecls.flags&C.AEADF_PCMSZ:
468       me.assertRaises(ValueError, key.enc, nonce = n,
469                       hsz = len(h), tsz = tsz)
470       me.assertRaises(ValueError, key.dec, nonce = n,
471                       hsz = len(h), tsz = tsz)
472     else:
473       quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
474       quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)
475
476     ## Check that we can get away without precommitting to the tag length if
477     ## and only if the AEAD scheme says it will let us.
478     if aecls.flags&C.AEADF_PCTSZ:
479       me.assertRaises(ValueError, key.enc, nonce = n,
480                       hsz = len(h), msz = len(m))
481       me.assertRaises(ValueError, key.dec, nonce = n,
482                       hsz = len(h), csz = len(c))
483     else:
484       quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
485       quick_dec_check(nonce = n, hsz = len(h), csz = len(c))
486
487     ## Check that if we precommit to the header length, we're properly held
488     ## to the commitment.
489     if not aecls.flags&C.AEADF_NOAAD:
490
491       ## First, check encryption with underrun.  If we must supply AAD first,
492       ## then the underrun will be reported when we start trying to encrypt;
493       ## otherwise, checking is delayed until `done'.
494       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
495       aad = enc.aad().hash(h[0:83])
496       if aecls.flags&C.AEADF_AADFIRST:
497         me.assertRaises(ValueError, enc.encrypt, m)
498       else:
499         _ = enc.encrypt(m)
500         me.assertRaises(ValueError, enc.done, aad = aad)
501
502       ## Next, check decryption with underrun.  If we must supply AAD first,
503       ## then the underrun will be reported when we start trying to encrypt;
504       ## otherwise, checking is delayed until `done'.
505       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
506       aad = dec.aad().hash(h[0:83])
507       if aecls.flags&C.AEADF_AADFIRST:
508         me.assertRaises(ValueError, dec.decrypt, c)
509       else:
510         _ = dec.decrypt(c)
511         me.assertRaises(ValueError, dec.done, tag = t, aad = aad)
512
513       ## If AAD processing is nonce-dependent then an overrun will be
514       ## detected imediately.
515       if aecls.flags&C.AEADF_AADNDEP:
516         enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
517         aad = enc.aad().hash(h[0:83])
518         me.assertRaises(ValueError, aad.hash, h[82:131])
519         dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
520         aad = dec.aad().hash(h[0:83])
521         me.assertRaises(ValueError, aad.hash, h[82:131])
522
523     ## Some additional tests for nonce-dependent `aad' objects.
524     if aecls.flags&C.AEADF_AADNDEP:
525
526       ## Check that `aad' objects can't be used once their parents are gone.
527       enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
528       aad = enc.aad()
529       del enc
530       me.assertRaises(ValueError, aad.hash, h)
531
532       ## Check that they can't be crossed over.
533       enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
534       enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
535       enc0.aad().hash(h)
536       aad1 = enc1.aad().hash(h)
537       _ = enc0.encrypt(m)
538       me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)
539
540     ## Test copying AAD.
541     if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
542       aad0 = key.aad()
543       aad0.hash(h[0:83])
544       aad1 = aad0.copy()
545       aad2 = aad1.copy()
546       aad0.hash(h[83:131])
547       aad1.hash(h[83:131])
548       aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
549       me.assertEqual(key.enc(nonce = n, hsz = len(h),
550                              msz = 0, tsz = tsz).done(aad = aad0),
551                      key.enc(nonce = n, hsz = len(h),
552                              msz = 0, tsz = tsz).done(aad = aad1))
553       me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
554                                 msz = 0, tsz = tsz).done(aad = aad0),
555                         key.enc(nonce = n, hsz = len(h),
556                                 msz = 0, tsz = tsz).done(aad = aad2))
557
558     ## Check that if we precommit to the message length, we're properly held
559     ## to the commitment.  (Fortunately, this is way simpler than the AAD
560     ## case above.)  First, try an underrun.
561     enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
562     _ = enc.encrypt(m[0:183])
563     me.assertRaises(ValueError, enc.done, tsz = tsz)
564     dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
565     _ = dec.decrypt(c[0:183])
566     me.assertRaises(ValueError, dec.done, tag = t)
567
568     ## And now an overrun.
569     enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
570     me.assertRaises(ValueError, enc.encrypt, m)
571     dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
572     me.assertRaises(ValueError, dec.decrypt, c)
573
574     ## Finally, check that if we precommit to a tag length, we're properly
575     ## held to the commitment.  This depends on being able to find a tag size
576     ## which isn't the default.
577     tsz1 = different_key_size(aecls.tagsz, tsz)
578     if tsz1 is not None:
579       enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
580       _ = enc.encrypt(m)
581       me.assertRaises(ValueError, enc.done, tsz = tsz)
582       dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
583       aad = dec.aad().hash(h)
584       _ = dec.decrypt(c)
585       me.assertRaises(ValueError, enc.done, tsz = tsz, aad = aad)
586
587 TestAuthenticatedEncryption.generate_testcases \
588   ((name, C.gcaeads[name]) for name in
589    ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
590     "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
591
592 ###--------------------------------------------------------------------------
593 class BaseTestHash (HashBufferTestMixin):
594   """Base class for testing hash functions."""
595
596   def check_hash(me, hcls, need_bufsz = True):
597     """
598     Check hash class HCLS.
599
600     If NEED_BUFSZ is false, then don't insist that HCLS have working `bufsz',
601     `name', or `hashsz' attributes.  This test is mostly reused for MACs,
602     which don't have these attributes.
603     """
604     ## Check the class properties.
605     if need_bufsz:
606       me.assertEqual(type(hcls.name), str)
607       me.assertEqual(type(hcls.bufsz), int)
608       me.assertEqual(type(hcls.hashsz), int)
609
610     ## Set some initial values.
611     m = T.span(131)
612     h = hcls().hash(m).done()
613
614     ## Check that hash length comes out right.
615     if need_bufsz: me.assertEqual(len(h), hcls.hashsz)
616
617     ## Check that we get the same answer if we split the message up.
618     me.assertEqual(h, hcls().hash(m[0:73]).hash(m[73:131]).done())
619
620     ## Check the `check' method.
621     me.assertTrue(hcls().hash(m).check(h))
622     me.assertFalse(hcls().hash(m).check(h ^ len(h)*C.bytes("aa")))
623
624     ## Check the menagerie of random hashing methods.
625     def mkhash(_):
626       h = hcls()
627       return h, h.done
628     me.check_hashbuffer(mkhash)
629
630 class TestHash (BaseTestHash, T.GenericTestMixin):
631   """Test hash functions."""
632   def _test_hash(me, hcls):    me.check_hash(hcls, need_bufsz = True)
633
634 TestHash.generate_testcases((name, C.gchashes[name]) for name in
635   ["md5", "sha", "whirlpool", "sha256", "sha512/224", "sha3-384", "shake256",
636    "crc32"])
637
638 ###--------------------------------------------------------------------------
639 class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
640   """Test message authentication codes."""
641
642   def _test_mac(me, mcls):
643
644     ## Check the MAC properties.
645     me.assertEqual(type(mcls.name), str)
646     me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
647     me.assertEqual(type(mcls.tagsz), int)
648
649     ## Test hashing.
650     k = T.span(mcls.keysz.default)
651     key = mcls(k)
652     me.check_hash(key, need_bufsz = False)
653
654     ## Check that bad key lengths are rejected.
655     badlen = bad_key_size(mcls.keysz)
656     if badlen is not None: me.assertRaises(ValueError, mcls, T.span(badlen))
657
658 TestMessageAuthentication.generate_testcases \
659   ((name, C.gcmacs[name]) for name in
660    ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
661
662 class TestPoly1305 (HashBufferTestMixin):
663   """Check the Poly1305 one-time message authentication function."""
664
665   def test_poly1305(me):
666
667     ## Check the MAC properties.
668     me.assertEqual(C.poly1305.name, "poly1305")
669     me.assertEqual(type(C.poly1305.keysz), C.KeySZSet)
670     me.assertEqual(C.poly1305.keysz.default, 16)
671     me.assertEqual(set(C.poly1305.keysz.set), set([16]))
672     me.assertEqual(C.poly1305.tagsz, 16)
673     me.assertEqual(C.poly1305.masksz, 16)
674
675     ## Set some initial values.
676     k = T.span(16)
677     u = T.span(64)[-16:]
678     m = T.span(149)
679     key = C.poly1305(k)
680     t = key(u).hash(m).done()
681
682     ## Check the key properties.
683     me.assertEqual(len(t), 16)
684
685     ## Check that we get the same answer if we split the message up.
686     me.assertEqual(t, key(u).hash(m[0:86]).hash(m[86:149]).done())
687
688     ## Check the `check' method.
689     me.assertTrue(key(u).hash(m).check(t))
690     me.assertFalse(key(u).hash(m).check(t ^ 16*C.bytes("cc")))
691
692     ## Check the menagerie of random hashing methods.
693     def mkhash(_):
694       h = key(u)
695       return h, h.done
696     me.check_hashbuffer(mkhash)
697
698     ## Check that we can't complete hashing without a mask.
699     me.assertRaises(ValueError, key().hash(m).done)
700
701     ## Check `concat'.
702     h0 = key().hash(m[0:96])
703     h1 = key().hash(m[96:117])
704     me.assertEqual(t, key(u).concat(h0, h1).hash(m[117:149]).done())
705     key1 = C.poly1305(k)
706     me.assertRaises(TypeError, key().concat, key1().hash(m[0:96]), h1)
707     me.assertRaises(TypeError, key().concat, h0, key1().hash(m[96:117]))
708     me.assertRaises(ValueError, key().concat, key().hash(m[0:93]), h1)
709
710 ###--------------------------------------------------------------------------
711 class TestHLatin (U.TestCase):
712   """Test the `hsalsa20' and `hchacha20' functions."""
713
714   def test_hlatin(me):
715     kk = [T.span(sz) for sz in [10, 16, 32]]
716     n = T.span(16)
717     bad_k = T.span(18)
718     bad_n = T.span(13)
719     for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
720                C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:
721       for k in kk:
722         h = fn(k, n)
723         me.assertEqual(len(h), 32)
724       me.assertRaises(ValueError, fn, bad_k, n)
725       me.assertRaises(ValueError, fn, k, bad_n)
726
727 ###--------------------------------------------------------------------------
728 class TestKeccak (HashBufferTestMixin):
729   """Test the Keccak-p[1600, n] sponge function."""
730
731   def test_keccak(me):
732
733     ## Make a state and feed some stuff into it.
734     m0 = T.bin("some initial string")
735     m1 = T.bin("awesome follow-up string")
736     st0 = C.Keccak1600()
737     me.assertEqual(st0.nround, 24)
738     st0.mix(m0).step()
739
740     ## Make another step with a different round count.
741     st1 = C.Keccak1600(23)
742     st1.mix(m0).step()
743     me.assertNotEqual(st0.extract(32), st1.extract(32))
744
745     ## Check state copying.
746     st1 = st0.copy()
747     mask = st1.extract(len(m1))
748     st0.mix(m1)
749     st1.mix(m1)
750     me.assertEqual(st0.extract(32), st1.extract(32))
751
752     ## Check error conditions.
753     _ = st0.extract(200)
754     me.assertRaises(ValueError, st0.extract, 201)
755     st0.mix(T.span(200))
756     me.assertRaises(ValueError, st0.mix, T.span(201))
757
758   def check_shake(me, xcls, c, done_matches_xof = True):
759     """
760     Test the SHAKE and cSHAKE XOFs.
761
762     This is also used for testing KMAC, but that sets DONE_MATCHES_XOF false
763     to indicate that the XOF output is range-separated from the fixed-length
764     outputs (unlike the basic SHAKE functions).
765     """
766
767     ## Check the hash attributes.
768     x = xcls()
769     me.assertEqual(x.rate, 200 - c)
770     me.assertEqual(x.buffered, 0)
771     me.assertEqual(x.state, "absorb")
772
773     ## Set some initial values.
774     func = T.bin("TESTXOF")
775     perso = T.bin("catacomb-python test")
776     m = T.span(167)
777     h0 = xcls().hash(m).done(193)
778     me.assertEqual(len(h0), 193)
779     h1 = xcls(func = func, perso = perso).hash(m).done(193)
780     me.assertEqual(len(h1), 193)
781     me.assertNotEqual(h0, h1)
782
783     ## Check input and output in pieces, and the state machine.
784     if done_matches_xof: h = h0
785     else: h = xcls().hash(m).xof().get(len(h0))
786     x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
787     me.assertEqual(h, x.get(98) + x.get(95))
788
789     ## Check masking.
790     x = xcls().hash(m).xof()
791     me.assertEqual(x.mask(m), C.ByteString(m) ^ C.ByteString(h[0:len(m)]))
792
793     ## Check the `check' method.
794     me.assertTrue(xcls().hash(m).check(h0))
795     me.assertFalse(xcls().hash(m).check(h1))
796
797     ## Check the menagerie of random hashing methods.
798     def mkhash(_):
799       x = xcls(func = func, perso = perso)
800       return x, lambda: x.done(100 - x.rate//2)
801     me.check_hashbuffer(mkhash)
802
803     ## Check the state machine tracking.
804     x = xcls(); me.assertEqual(x.state, "absorb")
805     x.hash(m); me.assertEqual(x.state, "absorb")
806     xx = x.copy()
807     h = xx.done(100 - x.rate//2)
808     me.assertEqual(xx.state, "dead")
809     me.assertRaises(ValueError, xx.done, 1)
810     me.assertRaises(ValueError, xx.get, 1)
811     me.assertEqual(x.state, "absorb")
812     me.assertRaises(ValueError, x.get, 1)
813     x.xof(); me.assertEqual(x.state, "squeeze")
814     me.assertRaises(ValueError, x.done, 1)
815     _ = x.get(1)
816     yy = x.copy(); me.assertEqual(yy.state, "squeeze")
817
818   def test_shake128(me): me.check_shake(C.Shake128, 32)
819   def test_shake256(me): me.check_shake(C.Shake256, 64)
820
821   def check_kmac(me, mcls, c):
822     k = T.span(32)
823     me.check_shake(lambda func = None, perso = T.bin(""):
824                      mcls(k, perso = perso),
825                    c, done_matches_xof = False)
826
827   def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
828   def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
829
830 ###--------------------------------------------------------------------------
831 class TestPRP (T.GenericTestMixin):
832   """Test pseudorandom permutations (PRPs)."""
833
834   def _test_prp(me, pcls):
835
836     ## Check the PRP properties.
837     me.assertEqual(type(pcls.name), str)
838     me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
839     me.assertEqual(type(pcls.blksz), int)
840
841     ## Check round-tripping.
842     k = T.span(pcls.keysz.default)
843     key = pcls(k)
844     m = T.span(pcls.blksz)
845     c = key.encrypt(m)
846     me.assertEqual(len(c), pcls.blksz)
847     me.assertEqual(m, key.decrypt(c))
848
849     ## Check that bad key lengths are rejected.
850     badlen = bad_key_size(pcls.keysz)
851     if badlen is not None: me.assertRaises(ValueError, pcls, T.span(badlen))
852
853     ## Check that bad blocks are rejected.
854     badblk = T.span(pcls.blksz + 1)
855     me.assertRaises(ValueError, key.encrypt, badblk)
856     me.assertRaises(ValueError, key.decrypt, badblk)
857
858 TestPRP.generate_testcases((name, C.gcprps[name]) for name in
859   ["desx", "blowfish", "rijndael"])
860
861 ###----- That's all, folks --------------------------------------------------
862
863 if __name__ == "__main__": U.main()