chiark / gitweb /
base/asm-common.h: Improve conditional instruction notation.
[catacomb] / pub / rsa-test.c
1 /* -*-c-*-
2  *
3  * Testing RSA padding operations
4  *
5  * (c) 2004 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Catacomb.
11  *
12  * Catacomb is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU Library General Public License as
14  * published by the Free Software Foundation; either version 2 of the
15  * License, or (at your option) any later version.
16  *
17  * Catacomb is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU Library General Public License for more details.
21  *
22  * You should have received a copy of the GNU Library General Public
23  * License along with Catacomb; if not, write to the Free
24  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25  * MA 02111-1307, USA.
26  */
27
28 /*----- Header files ------------------------------------------------------*/
29
30 #include <mLib/macros.h>
31
32 #include "fibrand.h"
33 #include "rsa.h"
34
35 /*----- Main code ---------------------------------------------------------*/
36
37 static int tencpad(int nbits,
38                    dstr *p, int rc, mp *c,
39                    const char *ename, dstr *eparam, rsa_pad *e, void *earg)
40 {
41   size_t n = (nbits + 7)/8;
42   void *q = xmalloc(n);
43   mp *d;
44   int ok = 1;
45
46   d = e(MP_NEW, p->buf, p->len, q, n, nbits, earg);
47   if (!d == !rc || (!rc && !MP_EQ(d, c))) {
48     ok = 0;
49     fprintf(stderr, "*** %s padding failed!\n", ename);
50     fprintf(stderr, "*** padding bits = %d\n", nbits);
51     if (eparam) {
52       fprintf(stderr, "*** encoding parameters = ");
53       type_hex.dump(eparam, stderr);
54       fputc('\n', stderr);
55     }
56     fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
57     if (rc)
58       fprintf(stderr, "\n*** expected failure\n");
59     else {
60       MP_EPRINTX("\n*** expected", c);
61       MP_EPRINTX("*** computed", d);
62     }
63   }
64   mp_drop(d);
65   mp_drop(c);
66   xfree(q);
67   assert(mparena_count(MPARENA_GLOBAL) == 0);
68   return (ok);
69 }
70
71 #define tsigpad tencpad
72
/* Nonzero if the dynamic strings at @x@ and @y@ have the same length and
 * @MEMCMP@ finds their first @len@ bytes equal.  Both arguments are
 * evaluated more than once.
 */

#define DSTR_EQ(x, y)                                                   \
  ((x)->len == (y)->len &&                                              \
   MEMCMP((x)->buf, ==, (y)->buf, (x)->len))
75
76 static int tdecpad(int nbits,
77                    mp *c, int rc, dstr *p,
78                    const char *ename, dstr *eparam,
79                    rsa_decunpad *e, void *earg)
80 {
81   dstr d = DSTR_INIT;
82   int n = (nbits + 7)/8;
83   int ok = 1;
84
85   dstr_ensure(&d, n);
86   n = e(c, (octet *)d.buf, n, nbits, earg);
87   if (n >= 0)
88     d.len += n;
89   if (n != rc || (rc >= 0 && !DSTR_EQ(&d, p))) {
90     ok = 0;
91     fprintf(stderr, "*** %s encryption unpadding failed!\n", ename);
92     fprintf(stderr, "*** padding bits = %d\n", nbits);
93     if (eparam) {
94       fprintf(stderr, "*** encoding parameters = ");
95       type_hex.dump(eparam, stderr);
96       fputc('\n', stderr);
97     }
98     MP_EPRINTX("*** input", c);
99     if (rc < 0)
100       fprintf(stderr, "*** expected failure\n");
101     else {
102       fprintf(stderr, "*** expected: %d = ", rc); type_hex.dump(p, stderr);
103       fprintf(stderr, "\n*** computed: %d = ", n); type_hex.dump(&d, stderr);
104       fprintf(stderr, "\n");
105     }
106   }
107   mp_drop(c);
108   dstr_destroy(&d);
109   assert(mparena_count(MPARENA_GLOBAL) == 0);
110   return (ok);
111 }
112
113 static int tvrfpad(int nbits,
114                    mp *c, dstr *m, int rc, dstr *p,
115                    const char *ename, dstr *eparam,
116                    rsa_vrfunpad *e, void *earg)
117 {
118   dstr d = DSTR_INIT;
119   int n = (nbits + 7)/8;
120   int ok = 1;
121
122   dstr_ensure(&d, n);
123   n = e(c, m->len ? (octet *)m->buf : 0, m->len,
124         (octet *)d.buf, n, nbits, earg);
125   if (n >= 0)
126     d.len += n;
127   if (n != rc || (rc >= 0 && !DSTR_EQ(&d, p))) {
128     ok = 0;
129     fprintf(stderr, "*** %s signature unpadding failed!\n", ename);
130     fprintf(stderr, "*** padding bits = %d\n", nbits);
131     MP_EPRINTX("*** input", c);
132     if (eparam) {
133       fprintf(stderr, "*** encoding parameters = ");
134       type_hex.dump(eparam, stderr);
135       fputc('\n', stderr);
136     }
137     fprintf(stderr, "*** message = "); type_hex.dump(m, stderr);
138     if (rc < 0)
139       fprintf(stderr, "\n*** expected failure\n");
140     else {
141       fprintf(stderr, "\n*** expected = %d: ", rc); type_hex.dump(p, stderr);
142       fprintf(stderr, "\n*** computed = %d: ", n); type_hex.dump(&d, stderr);
143       fprintf(stderr, "\n");
144     }
145   }
146   mp_drop(c);
147   dstr_destroy(&d);
148   assert(mparena_count(MPARENA_GLOBAL) == 0);
149   return (ok);
150 }
151
152 static int tencpub(rsa_pub *rp,
153                    dstr *p, int rc, mp *c,
154                    const char *ename, dstr *eparam, rsa_pad *e, void *earg)
155 {
156   mp *d;
157   rsa_pubctx rpc;
158   int ok = 1;
159
160   rsa_pubcreate(&rpc, rp);
161   d = rsa_encrypt(&rpc, MP_NEW, p->buf, p->len, e, earg);
162   if (!d == !rc || (!rc && !MP_EQ(d, c))) {
163     ok = 0;
164     fprintf(stderr, "*** encrypt with %s padding failed!\n", ename);
165     MP_EPRINTX("*** key.n", rp->n);
166     MP_EPRINTX("*** key.e", rp->e);
167     if (eparam) {
168       fprintf(stderr, "*** encoding parameters = ");
169       type_hex.dump(eparam, stderr);
170       fputc('\n', stderr);
171     }
172     fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
173     if (rc)
174       fprintf(stderr, "\n*** expected failure\n");
175     else {
176       MP_EPRINTX("\n*** expected", c);
177       MP_EPRINTX("*** computed", d);
178     }
179   }
180   rsa_pubdestroy(&rpc);
181   rsa_pubfree(rp);
182   mp_drop(d);
183   mp_drop(c);
184   assert(mparena_count(MPARENA_GLOBAL) == 0);
185   return (ok);
186 }
187
188 static int tsigpriv(rsa_priv *rp,
189                     dstr *p, int rc, mp *c,
190                     const char *ename, dstr *eparam, rsa_pad *e, void *earg)
191 {
192   mp *d;
193   grand *r = fibrand_create(0);
194   rsa_privctx rpc;
195   int ok = 1;
196
197   rsa_privcreate(&rpc, rp, r);
198   d = rsa_sign(&rpc, MP_NEW, p->buf, p->len, e, earg);
199   if (!d == !rc || (!rc && !MP_EQ(d, c))) {
200     ok = 0;
201     fprintf(stderr, "*** sign with %s padding failed!\n", ename);
202     MP_EPRINTX("*** key.n", rp->n);
203     MP_EPRINTX("*** key.d", rp->d);
204     MP_EPRINTX("*** key.e", rp->e);
205     if (eparam) {
206       fprintf(stderr, "*** encoding parameters = ");
207       type_hex.dump(eparam, stderr);
208       fputc('\n', stderr);
209     }
210     fprintf(stderr, "*** input message = "); type_hex.dump(p, stderr);
211     if (rc)
212       fprintf(stderr, "\n*** expected failure\n");
213     else {
214       MP_EPRINTX("\n*** expected", c);
215       MP_EPRINTX("\n*** computed", d);
216     }
217   }
218   rsa_privdestroy(&rpc);
219   rsa_privfree(rp);
220   mp_drop(d);
221   mp_drop(c);
222   GR_DESTROY(r);
223   assert(mparena_count(MPARENA_GLOBAL) == 0);
224   return (ok);
225 }
226
227 static int tdecpriv(rsa_priv *rp,
228                     mp *c, int rc, dstr *p,
229                     const char *ename, dstr *eparam,
230                     rsa_decunpad *e, void *earg)
231 {
232   rsa_privctx rpc;
233   dstr d = DSTR_INIT;
234   grand *r = fibrand_create(0);
235   int n;
236   int ok = 1;
237
238   rsa_privcreate(&rpc, rp, r);
239   n = rsa_decrypt(&rpc, c, &d, e, earg);
240   if (n != rc || (rc >= 0 && !DSTR_EQ(&d, p))) {
241     ok = 0;
242     fprintf(stderr, "*** decryption with %s padding failed!\n", ename);
243     MP_EPRINTX("*** key.n", rp->n);
244     MP_EPRINTX("*** key.d", rp->d);
245     MP_EPRINTX("*** key.e", rp->e);
246     if (eparam) {
247       fprintf(stderr, "*** encoding parameters = ");
248       type_hex.dump(eparam, stderr);
249       fputc('\n', stderr);
250     }
251     MP_EPRINTX("*** input", c);
252     if (rc < 0)
253       fprintf(stderr, "*** expected failure\n");
254     else {
255       fprintf(stderr, "*** expected = %d: ", rc); type_hex.dump(p, stderr);
256       fprintf(stderr, "\n*** computed = %d: ", n); type_hex.dump(&d, stderr);
257       fprintf(stderr, "\n");
258     }
259   }
260   rsa_privdestroy(&rpc);
261   rsa_privfree(rp);
262   mp_drop(c);
263   dstr_destroy(&d);
264   GR_DESTROY(r);
265   assert(mparena_count(MPARENA_GLOBAL) == 0);
266   return (ok);
267 }
268
269 static int tvrfpub(rsa_pub *rp,
270                    mp *c, dstr *m, int rc, dstr *p,
271                    const char *ename, dstr *eparam,
272                    rsa_vrfunpad *e, void *earg)
273 {
274   rsa_pubctx rpc;
275   dstr d = DSTR_INIT;
276   int n;
277   int ok = 1;
278
279   rsa_pubcreate(&rpc, rp);
280   n = rsa_verify(&rpc, c, m->len ? m->buf : 0, m->len, &d, e, earg);
281   if (n != rc || (rc >= 0 && !DSTR_EQ(&d, p))) {
282     ok = 0;
283     fprintf(stderr, "*** verification with %s padding failed!\n", ename);
284     MP_EPRINTX("*** key.n", rp->n);
285     MP_EPRINTX("*** key.e", rp->e);
286     if (eparam) {
287       fprintf(stderr, "*** encoding parameters = ");
288       type_hex.dump(eparam, stderr);
289       fputc('\n', stderr);
290     }
291     MP_EPRINTX("*** input", c);
292     fprintf(stderr, "*** message = "); type_hex.dump(m, stderr);
293     if (rc < 0)
294       fprintf(stderr, "\n*** expected failure\n");
295     else {
296       fprintf(stderr, "\n*** expected = %d: ", rc); type_hex.dump(p, stderr);
297       fprintf(stderr, "\n*** computed = %d: ", n); type_hex.dump(&d, stderr);
298       fprintf(stderr, "\n");
299     }
300   }
301   rsa_pubdestroy(&rpc);
302   rsa_pubfree(rp);
303   mp_drop(c);
304   dstr_destroy(&d);
305   assert(mparena_count(MPARENA_GLOBAL) == 0);
306   return (ok);
307 }
308
309 /*----- Deep magic --------------------------------------------------------*
310  *
311  * Wahey!  Whacko macro programming on curry and lager.  There's nothing like
312  * it.
313  */
314
/* Key selector `priv': the test is driven with an RSA private key.
 *
 * The key structure starts out zeroed.  Three integers are taken from the
 * test vector, in the order n, e, d, and then @rsa_recover@ is called on the
 * structure (presumably to derive the remaining components -- confirm
 * against rsa.h).  @ARG_priv@ and @TAB_priv@ end in commas because the
 * operation's entries always follow them where @FUNCS@ and @TAB@ paste the
 * pieces together.
 */

#define DECL_priv rsa_priv rp = { 0 };

#define FUNC_priv                                                       \
  rp.n = *(mp **)(v++)->buf;                                            \
  rp.e = *(mp **)(v++)->buf;                                            \
  rp.d = *(mp **)(v++)->buf;                                            \
  rsa_recover(&rp);

#define ARG_priv &rp,

#define TAB_priv &type_mp, &type_mp, &type_mp,
326
/* Key selector `pub': the test is driven with an RSA public key.
 *
 * Two integers are taken from the test vector, in the order n, e.  The
 * trailing commas in @ARG_pub@ and @TAB_pub@ are deliberate: the
 * operation's entries always follow.
 */

#define DECL_pub rsa_pub rp;

#define FUNC_pub                                                        \
  rp.n = *(mp **)(v++)->buf;                                            \
  rp.e = *(mp **)(v++)->buf;

#define ARG_pub &rp,

#define TAB_pub &type_mp, &type_mp,
336
/* Key selector `pad': no key at all, just a block size in bits, so the
 * padding function is exercised directly.  The trailing commas in
 * @ARG_pad@ and @TAB_pad@ are deliberate: the operation's entries always
 * follow.
 */

#define DECL_pad int nbits;

#define FUNC_pad nbits = *(int *)(v++)->buf;

#define ARG_pad nbits,

#define TAB_pad &type_int,
345
/* Operation `enc': fields are read in the order input message (hex),
 * expected result code, expected output integer.  The trailing commas in
 * @ARG_enc@ and @TAB_enc@ are deliberate: the encoding's entries always
 * follow.
 */

#define DECL_enc                                                        \
  dstr *p;                                                              \
  int rc;                                                               \
  mp *c;

#define FUNC_enc                                                        \
  p = v++;                                                              \
  rc = *(int *)(v++)->buf;                                              \
  c = *(mp **)(v++)->buf;

#define ARG_enc p, rc, c,

#define TAB_enc &type_hex, &type_int, &type_mp,

/* Operation `sig' takes exactly the same fields as `enc'. */

#define DECL_sig DECL_enc
#define FUNC_sig FUNC_enc
#define ARG_sig ARG_enc
#define TAB_sig TAB_enc
363
/* Operation `dec': fields are read in the order input integer, expected
 * result code, expected output message (hex).  The trailing commas in
 * @ARG_dec@ and @TAB_dec@ are deliberate: the encoding's entries always
 * follow.
 */

#define DECL_dec                                                        \
  mp *c;                                                                \
  int rc;                                                               \
  dstr *p;

#define FUNC_dec                                                        \
  c = *(mp **)(v++)->buf;                                               \
  rc = *(int *)(v++)->buf;                                              \
  p = v++;

#define ARG_dec c, rc, p,

#define TAB_dec &type_mp, &type_int, &type_hex,
376
/* Operation `vrf': fields are read in the order input integer, message
 * (hex), expected result code, expected output (hex).  The trailing commas
 * in @ARG_vrf@ and @TAB_vrf@ are deliberate: the encoding's entries always
 * follow.
 */

#define DECL_vrf                                                        \
  mp *c;                                                                \
  dstr *m;                                                              \
  int rc;                                                               \
  dstr *p;

#define FUNC_vrf                                                        \
  c = *(mp **)(v++)->buf;                                               \
  m = v++;                                                              \
  rc = *(int *)(v++)->buf;                                              \
  p = v++;

#define ARG_vrf c, m, rc, p,

#define TAB_vrf &type_mp, &type_hex, &type_int, &type_hex,
391
/* Encoding family `pkcs1'.
 *
 * All four variants read a single hex field, the encoding parameters, and
 * fill in a @pkcs1@ structure which uses the shared generator @fib@.  They
 * differ only in the function handed to the test driver.  These come last
 * in the argument list and the field table, so there are no trailing
 * commas.
 */

#define DECL_p1enc                                                      \
  pkcs1 p1;                                                             \
  dstr *ep;

#define FUNC_p1enc                                                      \
  ep = v++;                                                             \
  p1.r = fib;                                                           \
  p1.ep = ep->buf;                                                      \
  p1.epsz = ep->len;

#define ARG_p1enc "pkcs1", ep, pkcs1_cryptencode, &p1

#define TAB_p1enc &type_hex

/* Signature encoding. */

#define DECL_p1sig DECL_p1enc
#define FUNC_p1sig FUNC_p1enc
#define ARG_p1sig "pkcs1", ep, pkcs1_sigencode, &p1
#define TAB_p1sig TAB_p1enc

/* Encryption decoding. */

#define DECL_p1dec DECL_p1enc
#define FUNC_p1dec FUNC_p1enc
#define ARG_p1dec "pkcs1", ep, pkcs1_cryptdecode, &p1
#define TAB_p1dec TAB_p1enc

/* Signature decoding. */

#define DECL_p1vrf DECL_p1enc
#define FUNC_p1vrf FUNC_p1enc
#define ARG_p1vrf "pkcs1", ep, pkcs1_sigdecode, &p1
#define TAB_p1vrf TAB_p1enc
422
/* Encoding family `oaep'.
 *
 * Fields are read in the order cipher name, hash name, encoding parameters
 * (hex).  The names are looked up with @gcipher_byname@ and
 * @ghash_byname@; the results are stored unchecked.  The structure uses the
 * shared generator @fib@.  These come last in the argument list and the
 * field table, so there are no trailing commas.
 */

#define DECL_oaepenc                                                    \
  oaep o;                                                               \
  dstr *ep;

#define FUNC_oaepenc                                                    \
  o.r = fib;                                                            \
  o.cc = gcipher_byname((v++)->buf);                                    \
  o.ch = ghash_byname((v++)->buf);                                      \
  ep = v++;                                                             \
  o.ep = ep->buf;                                                       \
  o.epsz = ep->len;

#define ARG_oaepenc "oaep", ep, oaep_encode, &o

#define TAB_oaepenc &type_string, &type_string, &type_hex

/* Decoding takes the same fields and differs only in the function. */

#define DECL_oaepdec DECL_oaepenc
#define FUNC_oaepdec FUNC_oaepenc
#define ARG_oaepdec "oaep", ep, oaep_decode, &o
#define TAB_oaepdec TAB_oaepenc
443
/* Encoding family `pss'.
 *
 * Fields are read in the order cipher name, hash name, and an integer
 * stored in @ssz@.  The names are looked up with @gcipher_byname@ and
 * @ghash_byname@; the results are stored unchecked.  There is no
 * encoding-parameter string, so a null pointer is passed in its place.
 * These come last in the argument list and the field table, so there are
 * no trailing commas.
 */

#define DECL_psssig pss pp;

#define FUNC_psssig                                                     \
  pp.r = fib;                                                           \
  pp.cc = gcipher_byname((v++)->buf);                                   \
  pp.ch = ghash_byname((v++)->buf);                                     \
  pp.ssz = *(int *)(v++)->buf;

#define ARG_psssig "pss", 0, pss_encode, &pp

#define TAB_psssig &type_string, &type_string, &type_int

/* Verification takes the same fields and differs only in the function. */

#define DECL_pssvrf DECL_psssig
#define FUNC_pssvrf FUNC_psssig
#define ARG_pssvrf "pss", 0, pss_decode, &pp
#define TAB_pssvrf TAB_psssig
461
/* The master list of tests.  @DO@ is applied to each (key, operation,
 * encoding) triple in turn; the order here is the order of the generated
 * functions and of the entries in the test table.
 */

#define TESTS(DO)                                                       \
  /* pkcs1: bare padding, then with real keys */                        \
  DO(pad, enc, p1enc)                                                   \
  DO(pad, dec, p1dec)                                                   \
  DO(pad, sig, p1sig)                                                   \
  DO(pad, vrf, p1vrf)                                                   \
  DO(pub, enc, p1enc)                                                   \
  DO(priv, dec, p1dec)                                                  \
  DO(priv, sig, p1sig)                                                  \
  DO(pub, vrf, p1vrf)                                                   \
  /* oaep: encryption only */                                           \
  DO(pad, enc, oaepenc)                                                 \
  DO(pad, dec, oaepdec)                                                 \
  DO(pub, enc, oaepenc)                                                 \
  DO(priv, dec, oaepdec)                                                \
  /* pss: signatures only */                                            \
  DO(pad, sig, psssig)                                                  \
  DO(pad, vrf, pssvrf)                                                  \
  DO(priv, sig, psssig)                                                 \
  DO(pub, vrf, pssvrf)
479
/* Define the test function @t_KEY_ENC@ for one entry of @TESTS@.
 *
 * The function declares the variables for all three pieces, passes
 * @GRAND_SEEDINT@ and the constant 14 to the shared generator's @misc@
 * operation (presumably reseeding it, so every test starts from the same
 * state -- confirm against grand.h), reads the fields from @v@ in key,
 * operation, encoding order, and hands everything to the matching
 * @tOPKEY@ driver.
 */

#define FUNCS(key, op, enc)                                             \
  int t_##key##_##enc(dstr *v)                                          \
  {                                                                     \
    DECL_##key DECL_##op DECL_##enc                                     \
                                                                        \
    fib->ops->misc(fib, GRAND_SEEDINT, 14);                             \
    FUNC_##key FUNC_##op FUNC_##enc                                     \
    return (t##op##key(ARG_##key                                        \
                       ARG_##op                                         \
                       ARG_##enc));                                     \
  }
492
/* Build the test-table entry for one entry of @TESTS@: the name is
 * "ENC-KEY", followed by the generated function and the field types in
 * key, operation, encoding order.
 */

#define TAB(key, op, enc)                                               \
  {                                                                     \
    #enc "-" #key,                                                      \
    t_##key##_##enc,                                                    \
    { TAB_##key TAB_##op TAB_##enc }                                    \
  },
495
496 static grand *fib;
497
498 TESTS(FUNCS)
499
500 static const test_chunk tests[] = {
501   TESTS(TAB)
502   { 0 }
503 };
504
505 int main(int argc, char *argv[])
506 {
507   sub_init();
508   fib = fibrand_create(0);
509   test_run(argc, argv, tests, SRCDIR "/t/rsa");
510   GR_DESTROY(fib);
511   return (0);
512 }
513
514 /*----- That's all, folks -------------------------------------------------*/