chiark / gitweb /
t/t-algorithms.py: Add tests for other HSalsa20 and HChaCha key sizes.
[catacomb-python] / t / t-algorithms.py
CommitLineData
553d59fe
MW
1### -*- mode: python, coding: utf-8 -*-
2###
3### Test symmetric algorithms
4###
5### (c) 2019 Straylight/Edgeware
6###
7
8###----- Licensing notice ---------------------------------------------------
9###
10### This file is part of the Python interface to Catacomb.
11###
12### Catacomb/Python is free software: you can redistribute it and/or
13### modify it under the terms of the GNU General Public License as
14### published by the Free Software Foundation; either version 2 of the
15### License, or (at your option) any later version.
16###
17### Catacomb/Python is distributed in the hope that it will be useful, but
18### WITHOUT ANY WARRANTY; without even the implied warranty of
19### MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20### General Public License for more details.
21###
22### You should have received a copy of the GNU General Public License
23### along with Catacomb/Python. If not, write to the Free Software
24### Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307,
25### USA.
26
27###--------------------------------------------------------------------------
28### Imported modules.
29
30import catacomb as C
31import unittest as U
32import testutils as T
33
34###--------------------------------------------------------------------------
35### Utilities.
36
def bad_key_size(ksz):
  """
  Return a key length which the key-size specification KSZ would reject.

  Returns None when no such length is found: for a `KeySZAny', for a
  `KeySZRange' with unit granularity, no maximum and a minimum of zero, or
  for a specification of some unrecognized kind.
  """
  if isinstance(ksz, C.KeySZAny):
    return None

  if isinstance(ksz, C.KeySZRange):
    ## An off-granularity length is easiest; otherwise step just outside the
    ## permitted range, at whichever end is actually bounded.
    if ksz.mod != 1: return ksz.min + 1
    if ksz.max != 0: return ksz.max + 1
    if ksz.min != 0: return ksz.min - 1
    return None

  if isinstance(ksz, C.KeySZSet):
    ## Find the first gap immediately above some member of the set.
    allowed = ksz.set
    gap = next((sz + 1 for sz in sorted(allowed) if sz + 1 not in allowed),
               None)
    assert gap is not None, "That should have worked."
    return gap

  return None
50
def different_key_size(ksz, sz):
  """
  Return a key length other than SZ which the specification KSZ accepts.

  SZ is presumably itself acceptable to KSZ (callers pass a default size).
  Returns None if KSZ admits no other length, or is of an unrecognized kind.
  """
  if isinstance(ksz, C.KeySZAny):
    return sz + 1

  if isinstance(ksz, C.KeySZRange):
    ## Prefer stepping down; otherwise step up, unless that would pass the
    ## maximum (a maximum of zero is treated here as no upper bound).
    if sz > ksz.min: return sz - ksz.mod
    if ksz.max == 0 or sz < ksz.max: return sz + ksz.mod
    return None

  if isinstance(ksz, C.KeySZSet):
    ## The smallest member of the set which isn't SZ, if there is one.
    return next((alt for alt in sorted(ksz.set) if alt != sz), None)

  return None
63
class HashBufferTestMixin (U.TestCase):
  """Mixin class for testing all of the various `hash...' methods."""

  def check_hashbuffer_hashn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashuN' method, which hashes a W-byte integer.

    HASHFN(H, X) invokes the method under test on hashing object H with
    integer X, returning H; BIGENDP says whether the integer is expected to
    be encoded big-endian.  MAKEFN is as described for `check_hashbuffer'.
    """

    ## An integer sandwiched between two single bytes should hash the same
    ## as hashing `T.span(W + 2)' in one go.
    subject, finish_subject = makefn(w + 2)
    hashfn(subject.hashu8(0x00), T.bytes_as_int(w, bigendp)).hashu8(w + 1)
    reference, finish_reference = makefn(w + 2)
    reference.hash(T.span(w + 2))
    me.assertEqual(finish_subject(), finish_reference())

    ## An integer too wide for W bytes must be refused.
    toobig, _ = makefn(w)
    me.assertRaises((OverflowError, ValueError), hashfn, toobig, 1 << 8*w)

  def check_hashbuffer_bufn(me, w, bigendp, makefn, hashfn):
    """
    Check a `hashbufN' method, which hashes a buffer with a W-byte length
    prefix.

    HASHFN(H, X) invokes the method under test on hashing object H with
    buffer X, returning H; BIGENDP says whether the length prefix is
    expected to be big-endian.  MAKEFN is as described for
    `check_hashbuffer'.
    """

    ## Try a selection of buffer sizes, skipping any whose length won't fit
    ## in the prefix.
    limit = 1 << 8*w
    for n in [0, 1, 7, 8, 19, 255, 12345, 65535, 123456]:
      if n >= limit: continue
      total = 2 + w + n
      subject, finish_subject = makefn(total)
      hashfn(subject.hashu8(0x00), T.span(n)).hashu8(0xff)
      reference, finish_reference = makefn(total)
      reference.hash(T.prep_lenseq(w, n, bigendp, True))
      me.assertEqual(finish_subject(), finish_reference())

    ## A buffer one byte longer than the prefix can describe must be
    ## refused.  (This is only practical for narrow prefixes.)
    if w <= 3:
      toobig, _ = makefn(w + limit)
      me.assertRaises((ValueError, OverflowError, TypeError),
                      hashfn, toobig, C.ByteString.zero(limit))

  def check_hashbuffer(me, makefn):
    """
    Test the various `hash...' methods.

    MAKEFN(N) must return a pair (H, DONEFN): H is a fresh hashing object
    which will be fed N bytes in total, and DONEFN is a function of no
    arguments whose result summarizes what H has hashed; results from
    different objects are compared for equality.  The 24- and 64-bit
    methods are only tested if H actually provides them.
    """

    def widths(prefix):
      ## Generate the widths, in bytes, to test for methods named PREFIX;
      ## the 3- and 8-byte widths are optional, so probe for them lazily.
      for w in [1, 2, 3, 4, 8]:
        if w in (3, 8) and \
           not hasattr(makefn(0)[0], "%s%d" % (prefix, 8*w)):
          continue
        yield w

    for prefix, check in [("hashu", me.check_hashbuffer_hashn),
                          ("hashbuf", me.check_hashbuffer_bufn)]:
      for w in widths(prefix):
        base = "%s%d" % (prefix, 8*w)
        ## Single-byte methods have no byte-order variants; the unsuffixed
        ## wider methods are big-endian.
        if w == 1: variants = [("", True)]
        else: variants = [("", True), ("b", True), ("l", False)]
        for suffix, bigendp in variants:
          check(w, bigendp, makefn,
                lambda h, x, meth = base + suffix: getattr(h, meth)(x))
137
138###--------------------------------------------------------------------------
class TestKeysize (U.TestCase):
  """
  Test the key-size specification classes `KeySZAny', `KeySZSet' and
  `KeySZRange', and the `KeySZ' security-level conversions.
  """

  def _check_unconstrained(me, ksz, default):
    ## KSZ should be a `KeySZAny' with the given DEFAULT, accepting every
    ## size unchanged.
    me.assertEqual(type(ksz), C.KeySZAny)
    me.assertEqual(ksz.default, default)
    me.assertEqual(ksz.min, 0)
    me.assertEqual(ksz.max, 0)
    for sz in [0, 12, 20, 5000]:
      me.assertTrue(ksz.check(sz))
      me.assertEqual(ksz.best(sz), sz)
      me.assertEqual(ksz.pad(sz), sz)

  def _check_best_and_pad(me, ksz, cases):
    ## Each case is a triple (X, BEST, PAD) giving the results expected from
    ## `best' and `pad' applied to X, where None means `ValueError'.  X
    ## should pass `check' exactly when BEST == PAD == X.
    for x, best, pad in cases:
      if x == best == pad: me.assertTrue(ksz.check(x))
      else: me.assertFalse(ksz.check(x))
      for op, want in [(ksz.best, best), (ksz.pad, pad)]:
        if want is None: me.assertRaises(ValueError, op, x)
        else: me.assertEqual(op(x), want)

  def test_any(me):

    ## A typical one-byte spec.
    me._check_unconstrained(C.seal.keysz, 20)

    ## A typical two-byte spec.  (No published algorithms actually /need/ a
    ## two-byte key-size spec, but all of the HMAC variants use one anyway.)
    me._check_unconstrained(C.sha256_hmac.keysz, 32)

    ## Check construction.
    spec = C.KeySZAny(15)
    me.assertEqual(spec.default, 15)
    me.assertEqual(spec.min, 0)
    me.assertEqual(spec.max, 0)
    me.assertRaises(ValueError, C.KeySZAny, -8)
    me.assertEqual(C.KeySZAny(0).default, 0)

  def test_set(me):
    ## Note that no published algorithm uses a 16-bit `set' spec.

    ## A typical spec.
    spec = C.salsa20.keysz
    me.assertEqual(type(spec), C.KeySZSet)
    me.assertEqual(spec.default, 32)
    me.assertEqual(spec.min, 10)
    me.assertEqual(spec.max, 32)
    me.assertEqual(set(spec.set), set([10, 16, 32]))
    me._check_best_and_pad(spec,
                           [(9, None, 10), (10, 10, 10), (11, 10, 16),
                            (15, 10, 16), (16, 16, 16), (17, 16, 32),
                            (31, 16, 32), (32, 32, 32), (33, 32, None)])

    ## Check construction: the default is always a member of the set.
    spec = C.KeySZSet(7)
    me.assertEqual(spec.default, 7)
    me.assertEqual(set(spec.set), set([7]))
    me.assertEqual(spec.min, 7)
    me.assertEqual(spec.max, 7)
    spec = C.KeySZSet(7, [3, 6, 9])
    me.assertEqual(spec.default, 7)
    me.assertEqual(set(spec.set), set([3, 6, 7, 9]))
    me.assertEqual(spec.min, 3)
    me.assertEqual(spec.max, 9)

  def test_range(me):
    ## Note that no published algorithm uses a 16-bit `range' spec, or an
    ## unbounded `range'.

    ## A typical spec.
    spec = C.rijndael.keysz
    me.assertEqual(type(spec), C.KeySZRange)
    me.assertEqual(spec.default, 32)
    me.assertEqual(spec.min, 4)
    me.assertEqual(spec.max, 32)
    me.assertEqual(spec.mod, 4)
    me._check_best_and_pad(spec,
                           [(3, None, 4), (4, 4, 4), (5, 4, 8),
                            (15, 12, 16), (16, 16, 16), (17, 16, 20),
                            (31, 28, 32), (32, 32, 32), (33, 32, None)])

    ## Check construction, with arguments (DEFAULT, MIN, MAX, MOD).
    spec = C.KeySZRange(28, 21, 35, 7)
    me.assertEqual(spec.default, 28)
    me.assertEqual(spec.min, 21)
    me.assertEqual(spec.max, 35)
    me.assertEqual(spec.mod, 7)

    ## Inconsistent argument combinations must be refused.
    for args in [(29, 21, 35, 7), (28, 20, 35, 7), (28, 21, 34, 7),
                 (28, -7, 35, 7), (28, 35, 21, 7), (35, 21, 28, 7),
                 (21, 28, 35, 7)]:
      me.assertRaises(ValueError, C.KeySZRange, *args)

  def test_conversions(me):
    ## Conversions from other kinds of key size to a symmetric one.
    for conv in [C.KeySZ.fromec, C.KeySZ.fromschnorr]:
      me.assertEqual(conv(256), 128)
    for conv in [C.KeySZ.fromdl, C.KeySZ.fromif]:
      me.assertEqual(round(conv(2958.6875)), 128)

    ## And back again.
    for conv in [C.KeySZ.toec, C.KeySZ.toschnorr]:
      me.assertEqual(conv(128), 256)
    for conv in [C.KeySZ.todl, C.KeySZ.toif]:
      me.assertEqual(conv(128), 2958.6875)
250
251###--------------------------------------------------------------------------
class TestCipher (T.GenericTestMixin):
  """Test basic symmetric ciphers."""

  def _test_cipher(me, ccls):

    ## Check the class properties.
    me.assertEqual(type(ccls.name), str)
    me.assertTrue(isinstance(ccls.keysz, C.KeySZ))
    me.assertEqual(type(ccls.blksz), int)

    ## Set up a pair of contexts, one for each direction.  Not every cipher
    ## accepts an IV, so note whether this one does.
    k = T.span(ccls.keysz.default)
    iv = T.span(ccls.blksz)
    msg = T.span(253)
    enc = ccls(k)
    dec = ccls(k)
    try: enc.setiv(iv)
    except ValueError: can_setiv = False
    else: can_setiv = True
    if can_setiv: dec.setiv(iv)

    ## Round-trip the message in three pieces, placing a boundary (if the
    ## cipher supports one) before the last piece.
    ct0 = enc.encrypt(msg[0:57]); pt0 = dec.decrypt(ct0)
    ct1 = enc.encrypt(msg[57:189]); pt1 = dec.decrypt(ct1)
    try: enc.bdry()
    except ValueError: can_bdry = False
    else: dec.bdry(); can_bdry = True
    ct2 = enc.encrypt(msg[189:253]); pt2 = dec.decrypt(ct2)
    me.assertEqual(len(ct0) + len(ct1) + len(ct2), len(msg))
    me.assertEqual(pt0, msg[0:57])
    me.assertEqual(pt1, msg[57:189])
    me.assertEqual(pt2, msg[189:253])

    ## `enczero' and `deczero' should round-trip against zero bytes.
    zeros = C.ByteString.zero(32)
    me.assertEqual(dec.decrypt(enc.enczero(32)), zeros)
    me.assertEqual(enc.encrypt(dec.deczero(32)), zeros)

    ## If the cipher has a boundary operation, make sure it matters: without
    ## it, the first two pieces still decrypt correctly, but the whole
    ## message doesn't.
    if can_bdry:
      def fresh_dec():
        ## A new decryption context in the same initial state as `dec'.
        d = ccls(k)
        if can_setiv: d.setiv(iv)
        return d
      me.assertEqual(fresh_dec().decrypt(ct0 + ct1), msg[0:189])
      me.assertNotEqual(fresh_dec().decrypt(ct0 + ct1 + ct2), msg)

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(ccls.keysz)
    if badlen is not None:
      me.assertRaises(ValueError, ccls, T.span(badlen))

TestCipher.generate_testcases(
  (name, C.gcciphers[name])
  for name in ["des-ecb", "rijndael-cbc", "twofish-cfb", "serpent-ofb",
               "blowfish-counter", "rc4", "seal", "salsa20/8",
               "shake128-xof"])
317
10f3f611
MW
318###--------------------------------------------------------------------------
class TestAuthenticatedEncryption \
      (HashBufferTestMixin, T.GenericTestMixin):
  """Test authenticated encryption schemes."""

  def _test_aead(me, aecls):
    """
    Exercise the AEAD scheme class AECLS.

    This covers the class properties; round-tripping with full and partial
    precommitment; the AAD hashing methods; and rejection of bad key, nonce
    and tag lengths, and of broken length precommitments.
    """

    ## Check the class properties.
    me.assertEqual(type(aecls.name), str)
    me.assertTrue(isinstance(aecls.keysz, C.KeySZ))
    me.assertTrue(isinstance(aecls.noncesz, C.KeySZ))
    me.assertTrue(isinstance(aecls.tagsz, C.KeySZ))
    me.assertEqual(type(aecls.blksz), int)
    me.assertEqual(type(aecls.bufsz), int)
    me.assertEqual(type(aecls.ohd), int)
    me.assertEqual(type(aecls.flags), int)

    ## Check round-tripping, with full precommitment.  First, select some
    ## parameters.  (It's conceivable that some AEAD schemes are more
    ## restrictive than advertised by the various properties, but this works
    ## out OK in practice.)
    k = T.span(aecls.keysz.default)
    n = T.span(aecls.noncesz.default)
    if aecls.flags&C.AEADF_NOAAD: h = T.span(0)
    else: h = T.span(131)
    m = T.span(253)
    tsz = aecls.tagsz.default
    key = aecls(k)

    ## Next, encrypt a message, checking that things are proper as we go.
    enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
    me.assertEqual(enc.hsz, len(h))
    me.assertEqual(enc.msz, len(m))
    me.assertEqual(enc.mlen, 0)
    me.assertEqual(enc.tsz, tsz)
    aad = enc.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    if not aecls.flags&C.AEADF_NOAAD:
      aad.hash(h[0:83])
      me.assertEqual(aad.hlen, 83)
      aad.hash(h[83:131])
      me.assertEqual(aad.hlen, 131)
    c0 = enc.encrypt(m[0:57])
    me.assertEqual(enc.mlen, 57)
    me.assertTrue(57 - aecls.bufsz <= len(c0) <= 57 + aecls.ohd)
    c1 = enc.encrypt(m[57:189])
    me.assertEqual(enc.mlen, 189)
    me.assertTrue(132 - aecls.bufsz <= len(c1) <=
                  132 + aecls.bufsz + aecls.ohd)
    c2 = enc.encrypt(m[189:253])
    me.assertEqual(enc.mlen, 253)
    me.assertTrue(64 - aecls.bufsz <= len(c2) <=
                  64 + aecls.bufsz + aecls.ohd)
    c3, t = enc.done(aad = aad)
    me.assertTrue(len(c3) <= aecls.bufsz + aecls.ohd)
    c = c0 + c1 + c2 + c3
    me.assertTrue(len(m) <= len(c) <= len(m) + aecls.ohd)
    me.assertEqual(len(t), tsz)

    ## And now decrypt it again, with different record boundaries.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    me.assertEqual(dec.hsz, len(h))
    me.assertEqual(dec.csz, len(c))
    me.assertEqual(dec.clen, 0)
    me.assertEqual(dec.tsz, tsz)
    aad = dec.aad()
    if aecls.flags&C.AEADF_AADNDEP: me.assertEqual(aad.hsz, len(h))
    else: me.assertEqual(aad.hsz, None)
    me.assertEqual(aad.hlen, 0)
    aad.hash(h)
    m0 = dec.decrypt(c[0:156])
    me.assertTrue(156 - aecls.bufsz <= len(m0) <= 156)
    m1 = dec.decrypt(c[156:])
    me.assertTrue(len(c) - 156 - aecls.bufsz <= len(m1) <=
                  len(c) - 156 + aecls.bufsz)
    m2 = dec.done(tag = t, aad = aad)
    me.assertEqual(m0 + m1 + m2, m)

    ## And again, with the wrong tag.  Supply the correct AAD, so that the
    ## only possible reason for failure is the tag itself.
    dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
    aad = dec.aad(); aad.hash(h)
    _ = dec.decrypt(c)
    me.assertRaises(ValueError, dec.done,
                    tag = t ^ tsz*C.bytes("55"), aad = aad)

    ## Check that the all-in-one methods work.
    me.assertEqual((c, t),
                   key.encrypt(n = n, h = h, m = m, tsz = tsz))
    me.assertEqual(m,
                   key.decrypt(n = n, h = h, c = c, t = t))

    ## Check that bad key, nonce, and tag lengths are rejected.
    badlen = bad_key_size(aecls.keysz)
    if badlen is not None: me.assertRaises(ValueError, aecls, T.span(badlen))
    badlen = bad_key_size(aecls.noncesz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = T.span(badlen),
                      hsz = len(h), msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = T.span(badlen),
                      hsz = len(h), csz = len(c), tsz = tsz)
    badlen = bad_key_size(aecls.tagsz)
    if badlen is not None:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m), tsz = badlen)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c), tsz = badlen)

      ## If the tag length needn't be given in advance, then a bad one must
      ## still be caught when it's finally supplied to `done'.  (This must
      ## use the bad /tag/ length, not the bad nonce length.)
      if not aecls.flags&C.AEADF_PCTSZ:
        enc = key.enc(nonce = n, hsz = 0, msz = len(m))
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, tsz = badlen)

    ## Check that we can't get a loose `aad' object from a scheme which has
    ## nonce-dependent AAD processing.
    if aecls.flags&C.AEADF_AADNDEP: me.assertRaises(ValueError, key.aad)

    ## Check the menagerie of AAD hashing methods.
    if not aecls.flags&C.AEADF_NOAAD:
      def mkhash(hsz):
        enc = key.enc(nonce = n, hsz = hsz, msz = 0, tsz = tsz)
        aad = enc.aad()
        return aad, lambda: enc.done(aad = aad)[1]
      me.check_hashbuffer(mkhash)

    ## Check that encryption/decryption works with the given precommitments.
    def quick_enc_check(**kw):
      enc = key.enc(**kw)
      aad = enc.aad().hash(h)
      c0 = enc.encrypt(m); c1, tt = enc.done(aad = aad, tsz = tsz)
      me.assertEqual((c, t), (c0 + c1, tt))
    def quick_dec_check(**kw):
      dec = key.dec(**kw)
      aad = dec.aad().hash(h)
      m0 = dec.decrypt(c); m1 = dec.done(aad = aad, tag = t)
      me.assertEqual(m, m0 + m1)

    ## Check that we can get away without precommitting to the header length
    ## if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCHSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      msz = len(m), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      csz = len(c), tsz = tsz)
    else:
      quick_enc_check(nonce = n, msz = len(m), tsz = tsz)
      quick_dec_check(nonce = n, csz = len(c), tsz = tsz)

    ## Check that we can get away without precommitting to the message/
    ## ciphertext length if and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCMSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), tsz = tsz)
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), tsz = tsz)
    else:
      quick_enc_check(nonce = n, hsz = len(h), tsz = tsz)
      quick_dec_check(nonce = n, hsz = len(h), tsz = tsz)

    ## Check that we can get away without precommitting to the tag length if
    ## and only if the AEAD scheme says it will let us.
    if aecls.flags&C.AEADF_PCTSZ:
      me.assertRaises(ValueError, key.enc, nonce = n,
                      hsz = len(h), msz = len(m))
      me.assertRaises(ValueError, key.dec, nonce = n,
                      hsz = len(h), csz = len(c))
    else:
      quick_enc_check(nonce = n, hsz = len(h), msz = len(m))
      quick_dec_check(nonce = n, hsz = len(h), csz = len(c))

    ## Check that if we precommit to the header length, we're properly held
    ## to the commitment.
    if not aecls.flags&C.AEADF_NOAAD:

      ## First, check encryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to encrypt;
      ## otherwise, checking is delayed until `done'.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, enc.encrypt, m)
      else:
        _ = enc.encrypt(m)
        me.assertRaises(ValueError, enc.done, aad = aad)

      ## Next, check decryption with underrun.  If we must supply AAD first,
      ## then the underrun will be reported when we start trying to decrypt;
      ## otherwise, checking is delayed until `done'.
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
      aad = dec.aad().hash(h[0:83])
      if aecls.flags&C.AEADF_AADFIRST:
        me.assertRaises(ValueError, dec.decrypt, c)
      else:
        _ = dec.decrypt(c)
        me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

      ## If AAD processing is nonce-dependent then an overrun will be
      ## detected immediately.
      if aecls.flags&C.AEADF_AADNDEP:
        enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
        aad = enc.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])
        dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz)
        aad = dec.aad().hash(h[0:83])
        me.assertRaises(ValueError, aad.hash, h[82:131])

    ## Some additional tests for nonce-dependent `aad' objects.
    if aecls.flags&C.AEADF_AADNDEP:

      ## Check that `aad' objects can't be used once their parents are gone.
      enc = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      aad = enc.aad()
      del enc
      me.assertRaises(ValueError, aad.hash, h)

      ## Check that they can't be crossed over.
      enc0 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc1 = key.enc(nonce = n, hsz = len(h), msz = len(m), tsz = tsz)
      enc0.aad().hash(h)
      aad1 = enc1.aad().hash(h)
      _ = enc0.encrypt(m)
      me.assertRaises(ValueError, enc0.done, tsz = tsz, aad = aad1)

    ## Test copying AAD: copies continue independently, so extending a copy
    ## with the same data matches the original, and with different data
    ## doesn't.
    if not aecls.flags&C.AEADF_AADNDEP and not aecls.flags&C.AEADF_NOAAD:
      aad0 = key.aad()
      aad0.hash(h[0:83])
      aad1 = aad0.copy()
      aad2 = aad1.copy()
      aad0.hash(h[83:131])
      aad1.hash(h[83:131])
      aad2.hash(h[83:131] ^ 48*C.bytes("ff"))
      me.assertEqual(key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad0),
                     key.enc(nonce = n, hsz = len(h),
                             msz = 0, tsz = tsz).done(aad = aad1))
      me.assertNotEqual(key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad0),
                        key.enc(nonce = n, hsz = len(h),
                                msz = 0, tsz = tsz).done(aad = aad2))

    ## Check that if we precommit to the message length, we're properly held
    ## to the commitment.  (Fortunately, this is way simpler than the AAD
    ## case above.)  First, try an underrun.
    enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz)
    _ = enc.encrypt(m[0:183])
    me.assertRaises(ValueError, enc.done, tsz = tsz)
    dec = key.dec(nonce = n, hsz = 0, csz = len(c), tsz = tsz)
    _ = dec.decrypt(c[0:183])
    me.assertRaises(ValueError, dec.done, tag = t)

    ## And now an overrun.
    enc = key.enc(nonce = n, hsz = 0, msz = 183, tsz = tsz)
    me.assertRaises(ValueError, enc.encrypt, m)
    dec = key.dec(nonce = n, hsz = 0, csz = 183, tsz = tsz)
    me.assertRaises(ValueError, dec.decrypt, c)

    ## Finally, check that if we precommit to a tag length, we're properly
    ## held to the commitment, on both the encryption and the decryption
    ## side.  This depends on being able to find a tag size which isn't the
    ## default.
    tsz1 = different_key_size(aecls.tagsz, tsz)
    if tsz1 is not None:
      enc = key.enc(nonce = n, hsz = 0, msz = len(m), tsz = tsz1)
      _ = enc.encrypt(m)
      me.assertRaises(ValueError, enc.done, tsz = tsz)
      dec = key.dec(nonce = n, hsz = len(h), csz = len(c), tsz = tsz1)
      aad = dec.aad().hash(h)
      _ = dec.decrypt(c)
      ## The tag T has length TSZ, which contradicts the commitment to TSZ1.
      me.assertRaises(ValueError, dec.done, tag = t, aad = aad)

TestAuthenticatedEncryption.generate_testcases \
  ((name, C.gcaeads[name]) for name in
   ["des3-ccm", "blowfish-ocb1", "square-ocb3", "rijndael-gcm",
    "serpent-eax", "salsa20-naclbox", "chacha20-poly1305"])
591
553d59fe
MW
592###--------------------------------------------------------------------------
class BaseTestHash (HashBufferTestMixin):
  """Base class for testing hash functions."""

  def check_hash(me, hcls, need_bufsz = True):
    """
    Check hash class HCLS.

    HCLS is called with no arguments to make each fresh hashing context.
    If NEED_BUFSZ is false, then don't insist that HCLS have working
    `bufsz', `name', or `hashsz' attributes.  This test is mostly reused for
    MACs, which don't have these attributes.
    """

    ## Check the class properties, where we expect to have them.
    if need_bufsz:
      for attr, ty in [("name", str), ("bufsz", int), ("hashsz", int)]:
        me.assertEqual(type(getattr(hcls, attr)), ty)

    ## Compute a reference digest, and check its length if we can.
    msg = T.span(131)
    digest = hcls().hash(msg).done()
    if need_bufsz: me.assertEqual(len(digest), hcls.hashsz)

    ## Hashing the message in two pieces must give the same answer.
    me.assertEqual(digest, hcls().hash(msg[0:73]).hash(msg[73:131]).done())

    ## `check' should accept the right digest and reject a corrupted one.
    me.assertTrue(hcls().hash(msg).check(digest))
    corrupted = digest ^ len(digest)*C.bytes("aa")
    me.assertFalse(hcls().hash(msg).check(corrupted))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      ctx = hcls()
      return ctx, ctx.done
    me.check_hashbuffer(mkhash)
629
class TestHash (BaseTestHash, T.GenericTestMixin):
  """Test hash functions."""

  def _test_hash(me, hcls):
    ## Plain hash classes must provide all of the usual attributes.
    me.check_hash(hcls, need_bufsz = True)

TestHash.generate_testcases(
  (name, C.gchashes[name])
  for name in ["md5", "sha", "whirlpool", "sha256", "sha512/224",
               "sha3-384", "shake256", "crc32"])
637
638###--------------------------------------------------------------------------
class TestMessageAuthentication (BaseTestHash, T.GenericTestMixin):
  """Test message authentication codes."""

  def _test_mac(me, mcls):

    ## Check the MAC class properties.
    me.assertEqual(type(mcls.name), str)
    me.assertTrue(isinstance(mcls.keysz, C.KeySZ))
    me.assertEqual(type(mcls.tagsz), int)

    ## A keyed MAC is exercised like a hash class, except that it lacks
    ## some of the attributes which hash classes have.
    me.check_hash(mcls(T.span(mcls.keysz.default)), need_bufsz = False)

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(mcls.keysz)
    if badlen is not None:
      me.assertRaises(ValueError, mcls, T.span(badlen))

TestMessageAuthentication.generate_testcases(
  (name, C.gcmacs[name])
  for name in ["sha-hmac", "rijndael-cmac", "twofish-pmac1", "kmac128"])
661
class TestPoly1305 (HashBufferTestMixin):
  """Check the Poly1305 one-time message authentication function."""

  def test_poly1305(me):
    P = C.poly1305

    ## Check the MAC class properties.
    me.assertEqual(P.name, "poly1305")
    me.assertEqual(type(P.keysz), C.KeySZSet)
    me.assertEqual(P.keysz.default, 16)
    me.assertEqual(set(P.keysz.set), set([16]))
    me.assertEqual(P.tagsz, 16)
    me.assertEqual(P.masksz, 16)

    ## Compute a reference tag from a key, a mask, and a message.
    k = T.span(16)
    mask = T.span(64)[-16:]
    msg = T.span(149)
    key = P(k)
    tag = key(mask).hash(msg).done()
    me.assertEqual(len(tag), 16)

    ## Hashing the message in pieces must give the same tag.
    me.assertEqual(tag, key(mask).hash(msg[0:86]).hash(msg[86:149]).done())

    ## `check' should accept the right tag and reject a corrupted one.
    me.assertTrue(key(mask).hash(msg).check(tag))
    me.assertFalse(key(mask).hash(msg).check(tag ^ 16*C.bytes("cc")))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      ctx = key(mask)
      return ctx, ctx.done
    me.check_hashbuffer(mkhash)

    ## Hashing can't be completed without a mask.
    me.assertRaises(ValueError, key().hash(msg).done)

    ## `concat' joins unmasked contexts from the same key object.  Contexts
    ## from a different key object (even one built from the same key bytes)
    ## are refused; so is a 93-byte leading piece (presumably because it
    ## isn't a whole number of 16-byte blocks).
    front = key().hash(msg[0:96])
    middle = key().hash(msg[96:117])
    me.assertEqual(tag,
                   key(mask).concat(front, middle).hash(msg[117:149]).done())
    other = P(k)
    me.assertRaises(TypeError, key().concat, other().hash(msg[0:96]), middle)
    me.assertRaises(TypeError, key().concat, front, other().hash(msg[96:117]))
    me.assertRaises(ValueError, key().concat, key().hash(msg[0:93]), middle)
709
710###--------------------------------------------------------------------------
class TestHLatin (U.TestCase):
  """Test the `hsalsa20' and `hchacha20' functions."""

  def test_hlatin(me):
    """
    Check each HSalsa20/HChaCha variant with every supported key size.

    Every supported key size must produce a 32-byte output and must reject a
    bad nonce.  An unsupported key size must be rejected.
    """
    keys = [T.span(sz) for sz in [10, 16, 32]]
    n = T.span(16)
    bad_k = T.span(18)
    bad_n = T.span(13)
    for fn in [C.hsalsa208_prf, C.hsalsa2012_prf, C.hsalsa20_prf,
               C.hchacha8_prf, C.hchacha12_prf, C.hchacha20_prf]:

      ## Check nonce validation with every key size explicitly, rather than
      ## relying on whichever key the loop variable was left holding.
      for k in keys:
        me.assertEqual(len(fn(k, n)), 32)
        me.assertRaises(ValueError, fn, k, bad_n)

      ## A key of an unsupported size must be refused.
      me.assertRaises(ValueError, fn, bad_k, n)
726
727###--------------------------------------------------------------------------
class TestKeccak (HashBufferTestMixin):
  """Test the Keccak-p[1600, n] sponge function."""

  def test_keccak(me):

    ## Mix a string into a state with the default number of rounds, and
    ## permute.
    m0 = T.bin("some initial string")
    st0 = C.Keccak1600()
    me.assertEqual(st0.nround, 24)
    st0.mix(m0).step()

    ## The same input, permuted with a different round count, must give a
    ## different result.
    st1 = C.Keccak1600(23)
    st1.mix(m0).step()
    me.assertNotEqual(st0.extract(32), st1.extract(32))

    ## At most 200 bytes can be extracted from or mixed into the state.
    _ = st0.extract(200)
    me.assertRaises(ValueError, st0.extract, 201)
    st0.mix(T.span(200))
    me.assertRaises(ValueError, st0.mix, T.span(201))

  def check_shake(me, xcls, c, done_matches_xof = True):
    """
    Test the SHAKE and cSHAKE XOFs.

    XCLS is called, optionally with `func' and `perso' keyword arguments, to
    make a fresh context; C is such that the context's rate should be
    200 - C.  This is also used for testing KMAC, but that sets
    DONE_MATCHES_XOF false to indicate that the XOF output is
    range-separated from the fixed-length outputs (unlike the basic SHAKE
    functions).
    """

    ## A fresh context is absorbing, with nothing buffered.
    fresh = xcls()
    me.assertEqual(fresh.rate, 200 - c)
    me.assertEqual(fresh.buffered, 0)
    me.assertEqual(fresh.state, "absorb")

    ## Reference outputs, with and without customization, which must differ.
    func = T.bin("TESTXOF")
    perso = T.bin("catacomb-python test")
    m = T.span(167)
    h0 = xcls().hash(m).done(193)
    me.assertEqual(len(h0), 193)
    h1 = xcls(func = func, perso = perso).hash(m).done(193)
    me.assertEqual(len(h1), 193)
    me.assertNotEqual(h0, h1)

    ## Absorbing and squeezing in pieces must match the expected XOF output.
    if done_matches_xof: want = h0
    else: want = xcls().hash(m).xof().get(len(h0))
    x = xcls().hash(m[0:76]).hash(m[76:167]).xof()
    me.assertEqual(want, x.get(98) + x.get(95))

    ## `mask' XORs the XOF output into its argument.
    x = xcls().hash(m).xof()
    me.assertEqual(x.mask(m),
                   C.ByteString(m) ^ C.ByteString(want[0:len(m)]))

    ## `check' accepts the uncustomized output and rejects the customized
    ## one.
    me.assertTrue(xcls().hash(m).check(h0))
    me.assertFalse(xcls().hash(m).check(h1))

    ## Check the menagerie of random hashing methods.
    def mkhash(_):
      ctx = xcls(func = func, perso = perso)
      return ctx, lambda: ctx.done(100 - ctx.rate//2)
    me.check_hashbuffer(mkhash)

    ## Check the state machine: absorb -> (copy) done -> dead, and
    ## absorb -> xof -> squeeze.
    x = xcls(); me.assertEqual(x.state, "absorb")
    x.hash(m); me.assertEqual(x.state, "absorb")
    finished = x.copy()
    _ = finished.done(100 - x.rate//2)
    me.assertEqual(finished.state, "dead")
    me.assertRaises(ValueError, finished.done, 1)
    me.assertRaises(ValueError, finished.get, 1)
    me.assertEqual(x.state, "absorb")
    me.assertRaises(ValueError, x.get, 1)
    x.xof(); me.assertEqual(x.state, "squeeze")
    me.assertRaises(ValueError, x.done, 1)
    _ = x.get(1)
    squeezing = x.copy(); me.assertEqual(squeezing.state, "squeeze")

  def test_shake128(me): me.check_shake(C.Shake128, 32)
  def test_shake256(me): me.check_shake(C.Shake256, 64)

  def check_kmac(me, mcls, c):
    ## KMAC takes a key and a customization string but no function name, so
    ## adapt its constructor to the interface `check_shake' expects.
    k = T.span(32)
    def mkmac(func = None, perso = T.bin("")):
      return mcls(k, perso = perso)
    me.check_shake(mkmac, c, done_matches_xof = False)

  def test_kmac128(me): me.check_kmac(C.KMAC128, 32)
  def test_kmac256(me): me.check_kmac(C.KMAC256, 64)
822
823###--------------------------------------------------------------------------
class TestPRP (T.GenericTestMixin):
  """Test pseudorandom permutations (PRPs)."""

  def _test_prp(me, pcls):

    ## Check the PRP class properties.
    me.assertEqual(type(pcls.name), str)
    me.assertTrue(isinstance(pcls.keysz, C.KeySZ))
    me.assertEqual(type(pcls.blksz), int)

    ## A single block must survive encryption and decryption, and the
    ## ciphertext must be exactly one block long.
    blksz = pcls.blksz
    prp = pcls(T.span(pcls.keysz.default))
    plain = T.span(blksz)
    cipher = prp.encrypt(plain)
    me.assertEqual(len(cipher), blksz)
    me.assertEqual(plain, prp.decrypt(cipher))

    ## Keys of an unacceptable length must be refused.
    badlen = bad_key_size(pcls.keysz)
    if badlen is not None:
      me.assertRaises(ValueError, pcls, T.span(badlen))

    ## So must a block which is one byte too long, in either direction.
    oversized = T.span(blksz + 1)
    for op in (prp.encrypt, prp.decrypt):
      me.assertRaises(ValueError, op, oversized)

TestPRP.generate_testcases(
  (name, C.gcprps[name]) for name in ["desx", "blowfish", "rijndael"])
853
854###----- That's all, folks --------------------------------------------------
855
## Run the tests when invoked as a script.
if __name__ == "__main__":
  U.main()