chiark / gitweb /
@@@ crypto-test
[secnet] / crypto-test.c
1 /*
2  * crypto-test.c: common test vector processing
3  */
4 /*
5  * This file is Free Software.  It was originally written for secnet.
6  *
7  * Copyright 2017 Mark Wooding
8  *
9  * You may redistribute secnet as a whole and/or modify it under the
10  * terms of the GNU General Public License as published by the Free
11  * Software Foundation; either version 3, or (at your option) any
12  * later version.
13  *
14  * You may redistribute this file and/or modify it under the terms of
15  * the GNU General Public License as published by the Free Software
16  * Foundation; either version 2, or (at your option) any later
17  * version.
18  *
19  * This software is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU General Public License for more details.
23  *
24  * You should have received a copy of the GNU General Public License
25  * along with this software; if not, see
26  * https://www.gnu.org/licenses/gpl.html.
27  */
28
29 #include <assert.h>
30 #include <errno.h>
31 #include <ctype.h>
32 #include <stdarg.h>
33 #include <stdio.h>
34 #include <stdlib.h>
35 #include <string.h>
36
37 #include "secnet.h"
38 #include "util.h"
39
40 #include "crypto-test.h"
41
42 /*----- Utilities ---------------------------------------------------------*/
43
44 static void *xmalloc(size_t sz)
45 {
46     void *p;
47
48     if (!sz) return 0;
49     p = malloc(sz);
50     if (!p) {
51         fprintf(stderr, "out of memory!\n");
52         exit(2);
53     }
54     return p;
55 }
56
57 static void *xrealloc(void *p, size_t sz)
58 {
59     void *q;
60
61     if (!sz) { free(p); return 0; }
62     else if (!p) return xmalloc(sz);
63     q = realloc(p, sz);
64     if (!q) {
65         fprintf(stderr, "out of memory!\n");
66         exit(2);
67     }
68     return q;
69 }
70
71 static int lno;
72
73 void bail(const char *msg, ...)
74 {
75     va_list ap;
76     va_start(ap, msg);
77     fprintf(stderr, "unexpected error (line %d): ", lno);
78     vfprintf(stderr, msg, ap);
79     va_end(ap);
80     fputc('\n', stderr);
81     exit(2);
82 }
83
84 struct linebuf {
85     char *p;
86     size_t sz;
87 };
88 #define LINEBUF_INIT { 0, 0 };
89
90 static int read_line(struct linebuf *b, FILE *fp)
91 {
92     size_t n = 0;
93     int ch;
94
95     ch = getc(fp); if (ch == EOF) return EOF;
96     for (;;) {
97         if (n >= b->sz) {
98             b->sz = b->sz ? 2*b->sz : 64;
99             b->p = xrealloc(b->p, b->sz);
100         }
101         if (ch == EOF || ch == '\n') { b->p[n++] = 0; return 0; }
102         b->p[n++] = ch;
103         ch = getc(fp);
104     }
105 }
106
107 void parse_hex(uint8_t *b, size_t sz, char *p)
108 {
109     size_t n = strlen(p);
110     unsigned i;
111     char bb[3];
112
113     if (n%2) bail("bad hex (odd number of nibbles)");
114     else if (n/2 != sz) bail("bad hex (want %zu bytes, found %zu)", sz, n/2);
115     while (sz) {
116         for (i = 0; i < 2; i++) {
117             if (!isxdigit((unsigned char)p[i]))
118                 bail("bad hex digit `%c'", p[i]);
119             bb[i] = p[i];
120         }
121         bb[2] = 0;
122         p += 2;
123         *b++ = strtoul(bb, 0, 16); sz--;
124     }
125 }
126
127 void dump_hex(FILE *fp, const uint8_t *b, size_t sz)
128     { while (sz--) fprintf(fp, "%02x", *b++); fputc('\n', fp); }
129
130 void trivial_regty_init(union regval *v) { ; }
131 void trivial_regty_release(union regval *v) { ; }
132
133 /* Define some global variables we shouldn't need.
134  *
135  * Annoyingly, `secnet.h' declares static pointers and initializes them to
136  * point to some external variables.  At `-O0', GCC doesn't optimize these
137  * away, so there's a link-time dependency on these variables.  Define them
138  * here, so that `f25519.c' and `f448.c' can find them.
139  *
140  * (Later GCC has `-Og', which optimizes without making debugging a
141  * nightmare, but I'm not running that version here.  Note that `serpent.c'
142  * doesn't have this problem because it defines its own word load and store
143  * operations to cope with its endian weirdness, whereas the field arithmetic
144  * uses `unaligned.h' which manages to include `secnet.h'.)
145  */
146 uint64_t now_global;
147 struct timeval tv_now_global;
148
149 /* Bletch.  util.c is a mess of layers. */
150 int consttime_memeq(const void *s1in, const void *s2in, size_t n)
151 {
152     const uint8_t *s1=s1in, *s2=s2in;
153     register volatile uint8_t accumulator=0;
154
155     while (n-- > 0) {
156         accumulator |= (*s1++ ^ *s2++);
157     }
158     accumulator |= accumulator >> 4; /* constant-time             */
159     accumulator |= accumulator >> 2; /*  boolean canonicalisation */
160     accumulator |= accumulator >> 1;
161     accumulator &= 1;
162     accumulator ^= 1;
163     return accumulator;
164 }
165
166 /*----- Built-in types ----------------------------------------------------*/
167
168 /* Signed integer. */
169
170 static void parse_int(union regval *v, char *p)
171 {
172     char *q;
173
174     errno = 0;
175     v->i = strtol(p, &q, 0);
176     if (*q || errno) bail("bad integer `%s'", p);
177 }
178
179 static void dump_int(FILE *fp, const union regval *v)
180     { fprintf(fp, "%ld\n", v->i); }
181
182 static int eq_int(const union regval *v0, const union regval *v1)
183     { return (v0->i == v1->i); }
184
185 const struct regty regty_int = {
186     trivial_regty_init,
187     parse_int,
188     dump_int,
189     eq_int,
190     trivial_regty_release
191 };
192
193 /* Unsigned integer. */
194
195 static void parse_uint(union regval *v, char *p)
196 {
197     char *q;
198
199     errno = 0;
200     v->u = strtoul(p, &q, 0);
201     if (*q || errno) bail("bad integer `%s'", p);
202 }
203
204 static void dump_uint(FILE *fp, const union regval *v)
205     { fprintf(fp, "%lu\n", v->u); }
206
207 static int eq_uint(const union regval *v0, const union regval *v1)
208     { return (v0->u == v1->u); }
209
210 const struct regty regty_uint = {
211     trivial_regty_init,
212     parse_uint,
213     dump_uint,
214     eq_uint,
215     trivial_regty_release
216 };
217
218 /* Byte string, as hex. */
219
220 void allocate_bytes(union regval *v, size_t sz)
221     { v->bytes.p = xmalloc(sz); v->bytes.sz = sz; }
222
223 static void init_bytes(union regval *v) { v->bytes.p = 0; v->bytes.sz = 0; }
224
225 static void parse_bytes(union regval *v, char *p)
226 {
227     size_t n = strlen(p);
228
229     allocate_bytes(v, n/2);
230     parse_hex(v->bytes.p, v->bytes.sz, p);
231 }
232
233 static void dump_bytes(FILE *fp, const union regval *v)
234     { dump_hex(fp, v->bytes.p, v->bytes.sz); }
235
236 static int eq_bytes(const union regval *v0, const union regval *v1)
237 {
238     return v0->bytes.sz == v1->bytes.sz &&
239         !memcmp(v0->bytes.p, v1->bytes.p, v0->bytes.sz);
240 }
241
242 static void release_bytes(union regval *v) { free(v->bytes.p); }
243
244 const struct regty regty_bytes = {
245     init_bytes,
246     parse_bytes,
247     dump_bytes,
248     eq_bytes,
249     release_bytes
250 };
251
252 /*----- Core test machinery -----------------------------------------------*/
253
254 /* Say that a register is `reset' by releasing and then re-initializing it.
255  * While there is a current test, all of that test's registers are
256  * initialized.  The input registers are reset at the end of `check', ready
257  * for the next test to load new values.  The output registers are reset at
258  * the end of `check_test_output', so that a test runner can run a test
259  * multiple times against the same test input, but with different context
260  * data.
261  */
262
263 #define REG(rvec, i)                                                    \
264     ((struct reg *)((unsigned char *)state->rvec + (i)*state->regsz))
265
266 void check_test_output(struct test_state *state, const struct test *test)
267 {
268     const struct regdef *def;
269     struct reg *reg, *in, *out;
270     int ok = 1;
271     int match;
272
273     for (def = test->regs; def->name; def++) {
274         if (def->i >= state->nrout) continue;
275         in = REG(in, def->i); out = REG(out, def->i);
276         if (!def->ty->eq(&in->v, &out->v)) ok = 0;
277     }
278     if (ok)
279         state->win++;
280     else {
281         printf("failed test `%s'\n", test->name);
282         for (def = test->regs; def->name; def++) {
283             in = REG(in, def->i);
284             if (def->i >= state->nrout) {
285                 printf("\t   input `%s' = ", def->name);
286                 def->ty->dump(stdout, &in->v);
287             } else {
288                 out = REG(out, def->i);
289                 match = def->ty->eq(&in->v, &out->v);
290                 printf("\t%s `%s' = ",
291                        match ? "  output" : "expected", def->name);
292                 def->ty->dump(stdout, &in->v);
293                 if (!match) {
294                     printf("\tcomputed `%s' = ", def->name);
295                     def->ty->dump(stdout, &out->v);
296                 }
297             }
298         }
299         state->lose++;
300     }
301
302     for (def = test->regs; def->name; def++) {
303         if (def->i >= state->nrout) continue;
304         reg = REG(out, def->i);
305         def->ty->release(&reg->v); def->ty->init(&reg->v);
306     }
307 }
308
309 void run_test(struct test_state *state, const struct test *test)
310 {
311     test->fn(state->out, state->in, 0);
312     check_test_output(state, test);
313 }
314
315 static void check(struct test_state *state, const struct test *test)
316 {
317     const struct regdef *def, *miss = 0;
318     struct reg *reg;
319     int any = 0;
320
321     if (!test) return;
322     for (def = test->regs; def->name; def++) {
323         reg = REG(in, def->i);
324         if (reg->f&REGF_LIVE) any = 1;
325         else if (!miss && !(def->f&REGF_OPT)) miss = def;
326     }
327     if (!any) return;
328     if (miss)
329         bail("register `%s' not set in test `%s'", def->name, test->name);
330
331     test->run(state, test);
332
333     for (def = test->regs; def->name; def++) {
334         reg = REG(in, def->i);
335         reg->f = 0; def->ty->release(&reg->v); def->ty->init(&reg->v);
336     }
337 }
338
339 int run_test_suite(unsigned nrout, unsigned nreg, size_t regsz,
340                    const struct test *tests, FILE *fp)
341 {
342     struct linebuf buf = LINEBUF_INIT;
343     struct test_state state[1];
344     const struct test *test;
345     const struct regdef *def;
346     struct reg *reg;
347     char *p;
348     const char *q;
349     int total;
350     size_t n;
351
352     for (test = tests; test->name; test++)
353         for (def = test->regs; def->name; def++)
354             assert(def->i < nreg);
355
356     state->in = xmalloc(nreg*regsz);
357     state->out = xmalloc(nrout*regsz);
358     state->nrout = nrout;
359     state->nreg = nreg;
360     state->regsz = regsz;
361     state->win = state->lose = 0;
362
363     test = 0;
364     lno = 0;
365     while (!read_line(&buf, fp)) {
366         lno++;
367         p = buf.p; n = strlen(buf.p);
368
369         while (isspace((unsigned char)*p)) p++;
370         if (*p == '#') continue;
371         if (!*p) { check(state, test); continue; }
372
373         q = p;
374         while (*p && !isspace((unsigned char)*p)) p++;
375         if (*p) *p++ = 0;
376
377         if (!strcmp(q, "test")) {
378             if (!*p) bail("missing argument");
379             check(state, test);
380             if (test) {
381                 for (def = test->regs; def->name; def++) {
382                     def->ty->release(&REG(in, def->i)->v);
383                     if (def->i < state->nrout)
384                         def->ty->release(&REG(out, def->i)->v);
385                 }
386             }
387             for (test = tests; test->name; test++)
388                 if (!strcmp(p, test->name)) goto found_test;
389             bail("unknown test `%s'", p);
390         found_test:
391             for (def = test->regs; def->name; def++) {
392                 reg = REG(in, def->i);
393                 reg->f = 0; def->ty->init(&reg->v);
394                 if (def->i < state->nrout) {
395                     reg = REG(out, def->i);
396                     reg->f = 0; def->ty->init(&reg->v);
397                 }
398             }
399             continue;
400         }
401
402         if (!test) bail("no current test");
403         for (def = test->regs; def->name; def++)
404             if (!strcmp(q, def->name)) goto found_reg;
405         bail("unknown register `%s' in test `%s'", q, test->name);
406     found_reg:
407         reg = REG(in, def->i);
408         if (reg->f&REGF_LIVE) bail("register `%s' already set", def->name);
409         def->ty->parse(&reg->v, p); reg->f |= REGF_LIVE;
410     }
411     check(state, test);
412
413     total = state->win + state->lose;
414     if (!state->lose)
415         printf("PASSED all %d test%s\n", state->win, total == 1 ? "" : "s");
416     else
417         printf("FAILED %d of %d test%s\n", state->lose, total,
418                total == 1 ? "" : "s");
419     return state->lose ? 1 : 0;
420 }