chiark / gitweb /
math/mpx-mul4-*-sse2.S (mpxmont_redc4): Fix end-of-outer-loop commentary.
[catacomb] / progs / cc-kem.c
1 /* -*-c-*-
2  *
3  * Catcrypt key-encapsulation
4  *
5  * (c) 2004 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Catacomb.
11  *
12  * Catacomb is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU Library General Public License as
14  * published by the Free Software Foundation; either version 2 of the
15  * License, or (at your option) any later version.
16  *
17  * Catacomb is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU Library General Public License for more details.
21  *
22  * You should have received a copy of the GNU Library General Public
23  * License along with Catacomb; if not, write to the Free
24  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25  * MA 02111-1307, USA.
26  */
27
28 /*----- Header files ------------------------------------------------------*/
29
30 #define _FILE_OFFSET_BITS 64
31
32 #include <stdlib.h>
33
34 #include <mLib/alloc.h>
35 #include <mLib/dstr.h>
36 #include <mLib/macros.h>
37 #include <mLib/report.h>
38 #include <mLib/sub.h>
39
40 #include "gaead.h"
41 #include "mprand.h"
42 #include "rand.h"
43
44 #include "ec.h"
45 #include "ec-keys.h"
46 #include "dh.h"
47 #include "rsa.h"
48 #include "x25519.h"
49 #include "x448.h"
50
51 #include "rmd160.h"
52 #include "blowfish-cbc.h"
53 #include "chacha20-poly1305.h"
54 #include "poly1305.h"
55 #include "salsa20.h"
56 #include "chacha.h"
57
58 #include "cc.h"
59
60 /*----- Bulk crypto -------------------------------------------------------*/
61
62 /* --- Authenticated encryption schemes --- */
63
64 typedef struct aead_encctx {
65   bulk b;
66   const gcaead *aec;
67   gaead_key *key;
68   union { gaead_enc *enc; gaead_dec *dec; } ed;
69   octet *t;
70   size_t nsz, tsz;
71 } aead_encctx;
72
73 static bulk *aead_internalinit(key *k, const gcaead *aec)
74 {
75   aead_encctx *ctx = CREATE(aead_encctx);
76
77   ctx->key = 0;
78   ctx->aec = aec;
79   if ((ctx->nsz = keysz_pad(4, aec->noncesz)) == 0)
80     die(EXIT_FAILURE, "no suitable nonce size for `%s'", aec->name);
81   ctx->tsz = keysz(0, ctx->aec->tagsz);
82
83   return (&ctx->b);
84 }
85
86 static bulk *aead_init(key *k, const char *calg, const char *halg)
87 {
88   const gcaead *aec;
89   const char *q;
90   dstr t = DSTR_INIT;
91
92   key_fulltag(k, &t);
93
94   if ((q = key_getattr(0, k, "cipher")) != 0) calg = q;
95   if (!calg) aec = &chacha20_poly1305;
96   else if ((aec = gaead_byname(calg)) == 0)
97     die(EXIT_FAILURE, "AEAD scheme `%s' not found in key `%s'",
98         calg, t.buf);
99
100   dstr_destroy(&t);
101   return (aead_internalinit(k, aec));
102 }
103
104 static int aead_commonsetup(aead_encctx *ctx, gcipher *cx)
105 {
106   size_t ksz, n;
107
108   n = ksz = keysz(0, ctx->aec->keysz);
109   if (n < ctx->nsz) n = ctx->nsz;
110   if (n < ctx->tsz) n = ctx->tsz;
111   ctx->t = xmalloc(n);
112
113   GC_ENCRYPT(cx, 0, ctx->t, ksz);
114   ctx->key = GAEAD_KEY(ctx->aec, ctx->t, ksz);
115   return (0);
116 }
117
118 static size_t aead_overhead(bulk *b)
119   { aead_encctx *ctx = (aead_encctx *)b; return (ctx->aec->ohd + ctx->tsz); }
120
121 static void aead_commondestroy(aead_encctx *ctx)
122 {
123   if (ctx->key) GAEAD_DESTROY(ctx->key);
124   xfree(ctx->t);
125   DESTROY(ctx);
126 }
127
128 static int aead_encsetup(bulk *b, gcipher *cx)
129 {
130   aead_encctx *ctx = (aead_encctx *)b;
131   ctx->ed.enc = 0; return (aead_commonsetup(ctx, cx));
132 }
133
134 static const char *aead_encdoit(bulk *b, uint32 seq, buf *bb,
135                                 const void *p, size_t sz)
136 {
137   aead_encctx *ctx = (aead_encctx *)b;
138   octet *t;
139   int rc;
140
141   memset(ctx->t + 4, 0, ctx->nsz - 4); STORE32_B(ctx->t, seq);
142   if (!ctx->ed.enc)
143     ctx->ed.enc = GAEAD_ENC(ctx->key, ctx->t, ctx->nsz, 0, sz, ctx->tsz);
144   else
145     GAEAD_REINIT(ctx->ed.enc, ctx->t, ctx->nsz, 0, sz, ctx->tsz);
146   t = buf_get(bb, ctx->tsz); assert(t);
147   rc = GAEAD_ENCRYPT(ctx->ed.enc, p, sz, bb); assert(rc >= 0);
148   rc = GAEAD_DONE(ctx->ed.enc, 0, bb, t, ctx->tsz); assert(rc >= 0);
149   return (0);
150 }
151
152 static void aead_encdestroy(bulk *b)
153 {
154   aead_encctx *ctx = (aead_encctx *)b;
155   if (ctx->ed.enc) GAEAD_DESTROY(ctx->ed.enc);
156   aead_commondestroy(ctx);
157 }
158
159 static int aead_decsetup(bulk *b, gcipher *cx)
160 {
161   aead_encctx *ctx = (aead_encctx *)b;
162   ctx->ed.dec = 0; return (aead_commonsetup(ctx, cx));
163 }
164
165 static const char *aead_decdoit(bulk *b, uint32 seq, buf *bb,
166                                 const void *p, size_t sz)
167 {
168   aead_encctx *ctx = (aead_encctx *)b;
169   buf bin;
170   const octet *t;
171   int rc;
172
173   memset(ctx->t + 4, 0, ctx->nsz - 4); STORE32_B(ctx->t, seq);
174   if (!ctx->ed.dec)
175     ctx->ed.dec = GAEAD_DEC(ctx->key, ctx->t, ctx->nsz, 0, sz, ctx->tsz);
176   else
177     GAEAD_REINIT(ctx->ed.enc, ctx->t, ctx->nsz, 0, sz, ctx->tsz);
178
179   buf_init(&bin, (/*unconst*/ void *)p, sz);
180   t = buf_get(&bin, ctx->tsz); if (!t) return ("no tag");
181   rc = GAEAD_DECRYPT(ctx->ed.dec, BCUR(&bin), BLEFT(&bin), bb);
182   assert(rc >= 0);
183   rc = GAEAD_DONE(ctx->ed.dec, 0, bb, t, ctx->tsz); assert(rc >= 0);
184   if (!rc) return ("authentication failure");
185   return (0);
186 }
187
188 static void aead_decdestroy(bulk *b)
189 {
190   aead_encctx *ctx = (aead_encctx *)b;
191   if (ctx->ed.dec) GAEAD_DESTROY(ctx->ed.dec);
192   aead_commondestroy(ctx);
193 }
194
195 static const struct bulkops aead_encops = {
196   aead_init, aead_encsetup, aead_overhead,
197   aead_encdoit, aead_encdestroy
198 }, aead_decops = {
199   aead_init, aead_decsetup, aead_overhead,
200   aead_decdoit, aead_decdestroy
201 };
202
203 /* --- NaCl `secretbox' in terms of AEAD --- */
204
205 static bulk *naclbox_init(key *k, const char *calg, const char *halg)
206 {
207   const gcaead *aec;
208   dstr t = DSTR_INIT;
209   const char *q;
210
211   key_fulltag(k, &t);
212
213   if ((q = key_getattr(0, k, "cipher")) != 0) calg = q;
214   if (!calg || STRCMP(calg, ==, "salsa20")) aec = &salsa20_naclbox;
215   else if (STRCMP(calg, ==, "salsa20/12")) aec = &salsa2012_naclbox;
216   else if (STRCMP(calg, ==, "salsa20/8")) aec = &salsa208_naclbox;
217   else if (STRCMP(calg, ==, "chacha20")) aec = &chacha20_naclbox;
218   else if (STRCMP(calg, ==, "chacha12")) aec = &chacha12_naclbox;
219   else if (STRCMP(calg, ==, "chacha8")) aec = &chacha8_naclbox;
220   else {
221     die(EXIT_FAILURE,
222         "unknown or inappropriate encryption scheme `%s' in key `%s'",
223         calg, t.buf);
224   }
225
226   dstr_destroy(&t);
227   return (aead_internalinit(k, aec));
228 }
229
230 static const bulkops naclbox_encops = {
231   naclbox_init, aead_encsetup, aead_overhead,
232   aead_encdoit, aead_encdestroy
233 }, naclbox_decops = {
234   naclbox_init, aead_decsetup, aead_overhead,
235   aead_decdoit, aead_decdestroy
236 };
237
238 /* --- Generic composition --- */
239
240 typedef struct gencomp_encctx {
241   bulk b;
242   const gccipher *cc;
243   const gcmac *mc;
244   gcipher *c, *cx;
245   gmac *m;
246   octet *t; size_t tsz;
247 } gencomp_encctx;
248
249 static bulk *gencomp_init(key *k, const char *calg, const char *halg)
250 {
251   gencomp_encctx *ctx = CREATE(gencomp_encctx);
252   const char *q;
253   dstr d = DSTR_INIT, t = DSTR_INIT;
254
255   key_fulltag(k, &t);
256
257   if ((q = key_getattr(0, k, "cipher")) != 0) calg = q;
258   if (!calg) ctx->cc = &blowfish_cbc;
259   else if ((ctx->cc = gcipher_byname(calg)) == 0) {
260     die(EXIT_FAILURE, "encryption scheme `%s' not found in key `%s'",
261         calg, t.buf);
262   }
263
264   dstr_reset(&d);
265   if ((q = key_getattr(0, k, "mac")) == 0) {
266     dstr_putf(&d, "%s-hmac", halg);
267     q = d.buf;
268   }
269   if ((ctx->mc = gmac_byname(q)) == 0) {
270     die(EXIT_FAILURE,
271         "message authentication code `%s' not found in key `%s'",
272         q, t.buf);
273   }
274
275   return (&ctx->b);
276 }
277
278 static int gencomp_setup(bulk *b, gcipher *cx)
279 {
280   gencomp_encctx *ctx = (gencomp_encctx *)b;
281   size_t cn, mn, n;
282   octet *kd;
283
284   ctx->cx = cx;
285   n = ctx->cc->blksz;
286   cn = keysz(0, ctx->cc->keysz); if (cn > n) n = cn;
287   mn = keysz(0, ctx->mc->keysz); if (mn > n) n = mn;
288   ctx->t = kd = xmalloc(n); ctx->tsz = n;
289   GC_ENCRYPT(cx, 0, kd, cn);
290   ctx->c = GC_INIT(ctx->cc, kd, cn);
291   GC_ENCRYPT(cx, 0, kd, mn);
292   ctx->m = GM_KEY(ctx->mc, kd, mn);
293   return (0);
294 }
295
296 static size_t gencomp_overhead(bulk *b)
297 {
298   gencomp_encctx *ctx = (gencomp_encctx *)b;
299   return (ctx->cc->blksz + ctx->mc->hashsz); }
300
301 static void gencomp_destroy(bulk *b)
302 {
303   gencomp_encctx *ctx = (gencomp_encctx *)b;
304
305   GC_DESTROY(ctx->c);
306   GC_DESTROY(ctx->m);
307   xfree(ctx->t);
308   DESTROY(ctx);
309 }
310
311 static const char *gencomp_encdoit(bulk *b, uint32 seq, buf *bb,
312                                    const void *p, size_t sz)
313 {
314   gencomp_encctx *ctx = (gencomp_encctx *)b;
315   octet *tag, *ct;
316   ghash *h = GM_INIT(ctx->m);
317
318   GH_HASHU32(h, seq);
319   if (ctx->cc->blksz) {
320     GC_ENCRYPT(ctx->cx, 0, ctx->t, ctx->cc->blksz);
321     GC_SETIV(ctx->c, ctx->t);
322   }
323   tag = buf_get(bb, ctx->mc->hashsz); assert(tag);
324   ct = buf_get(bb, sz); assert(ct);
325   GC_ENCRYPT(ctx->c, p, ct, sz);
326   GH_HASH(h, ct, sz);
327   GH_DONE(h, tag);
328   GH_DESTROY(h);
329   return (0);
330 }
331
332 static const char *gencomp_decdoit(bulk *b, uint32 seq, buf *bb,
333                                    const void *p, size_t sz)
334 {
335   gencomp_encctx *ctx = (gencomp_encctx *)b;
336   buf bin;
337   const octet *tag, *ct;
338   octet *pt;
339   ghash *h;
340   int ok;
341
342   buf_init(&bin, (/*unconst*/ void *)p, sz);
343   if ((tag = buf_get(&bin, ctx->mc->hashsz)) == 0) return ("no tag");
344   ct = BCUR(&bin); sz = BLEFT(&bin);
345   pt = buf_get(bb, sz); assert(pt);
346
347   h = GM_INIT(ctx->m);
348   GH_HASHU32(h, seq);
349   GH_HASH(h, ct, sz);
350   ok = ct_memeq(tag, GH_DONE(h, 0), ctx->mc->hashsz);
351   GH_DESTROY(h);
352   if (!ok) return ("authentication failure");
353
354   if (ctx->cc->blksz) {
355     GC_ENCRYPT(ctx->cx, 0, ctx->t, ctx->cc->blksz);
356     GC_SETIV(ctx->c, ctx->t);
357   }
358   GC_DECRYPT(ctx->c, ct, pt, sz);
359   return (0);
360 }
361
362 static const bulkops gencomp_encops = {
363   gencomp_init, gencomp_setup, gencomp_overhead,
364   gencomp_encdoit, gencomp_destroy
365 }, gencomp_decops = {
366   gencomp_init, gencomp_setup, gencomp_overhead,
367   gencomp_decdoit, gencomp_destroy
368 };
369
370 const struct bulktab bulktab[] = {
371   { "gencomp",  &gencomp_encops,        &gencomp_decops },
372   { "naclbox",  &naclbox_encops,        &naclbox_decops },
373   { "aead",     &aead_encops,           &aead_decops },
374   { 0,          0,                      0 }
375 };
376
377 /*----- Key encapsulation -------------------------------------------------*/
378
379 /* --- RSA --- */
380
381 typedef struct rsa_encctx {
382   kem k;
383   rsa_pubctx rp;
384 } rsa_encctx;
385
386 static kem *rsa_encinit(key *k, void *kd)
387 {
388   rsa_encctx *re = CREATE(rsa_encctx);
389   rsa_pubcreate(&re->rp, kd);
390   return (&re->k);
391 }
392
393 static int rsa_encdoit(kem *k, dstr *d, ghash *h)
394 {
395   rsa_encctx *re = (rsa_encctx *)k;
396   mp *x = mprand_range(MP_NEW, re->rp.rp->n, &rand_global, 0);
397   mp *y = rsa_pubop(&re->rp, MP_NEW, x);
398   size_t n = mp_octets(re->rp.rp->n);
399   dstr_ensure(d, n);
400   mp_storeb(x, d->buf, n);
401   GH_HASH(h, d->buf, n);
402   mp_storeb(y, d->buf, n);
403   d->len += n;
404   mp_drop(x);
405   mp_drop(y);
406   return (0);
407 }
408
409 static const char *rsa_lengthcheck(mp *n)
410 {
411   if (mp_bits(n) < 1020) return ("key too short");
412   return (0);
413 }
414
415 static const char *rsa_enccheck(kem *k)
416 {
417   rsa_encctx *re = (rsa_encctx *)k;
418   const char *e;
419   if ((e = rsa_lengthcheck(re->rp.rp->n)) != 0) return (e);
420   return (0);
421 }
422
423 static void rsa_encdestroy(kem *k)
424 {
425   rsa_encctx *re = (rsa_encctx *)k;
426   rsa_pubdestroy(&re->rp);
427   DESTROY(re);
428 }
429
430 static const kemops rsa_encops = {
431   rsa_pubfetch, sizeof(rsa_pub),
432   rsa_encinit, rsa_encdoit, rsa_enccheck, rsa_encdestroy
433 };
434
435 typedef struct rsa_decctx {
436   kem k;
437   rsa_privctx rp;
438 } rsa_decctx;
439
440 static kem *rsa_decinit(key *k, void *kd)
441 {
442   rsa_decctx *rd = CREATE(rsa_decctx);
443   rsa_privcreate(&rd->rp, kd, &rand_global);
444   return (&rd->k);
445 }
446
447 static int rsa_decdoit(kem *k, dstr *d, ghash *h)
448 {
449   rsa_decctx *rd = (rsa_decctx *)k;
450   mp *x = mp_loadb(MP_NEW, d->buf, d->len);
451   size_t n;
452   char *p;
453
454   if (MP_CMP(x, >=, rd->rp.rp->n)) {
455     mp_drop(x);
456     return (-1);
457   }
458   n = mp_octets(rd->rp.rp->n);
459   p = xmalloc(n);
460   x = rsa_privop(&rd->rp, x, x);
461   mp_storeb(x, p, n);
462   GH_HASH(h, p, n);
463   mp_drop(x);
464   xfree(p);
465   return (0);
466 }
467
468 static const char *rsa_deccheck(kem *k)
469 {
470   rsa_decctx *rd = (rsa_decctx *)k;
471   const char *e;
472   if ((e = rsa_lengthcheck(rd->rp.rp->n)) != 0) return (e);
473   return (0);
474 }
475
476 static void rsa_decdestroy(kem *k)
477 {
478   rsa_decctx *rd = (rsa_decctx *)k;
479   rsa_privdestroy(&rd->rp);
480   DESTROY(rd);
481 }
482
483 static const kemops rsa_decops = {
484   rsa_privfetch, sizeof(rsa_priv),
485   rsa_decinit, rsa_decdoit, rsa_deccheck, rsa_decdestroy
486 };
487
488 /* --- DH and EC --- */
489
490 typedef struct dh_encctx {
491   kem k;
492   group *g;
493   mp *x;
494   ge *y;
495 } dh_encctx;
496
497 static dh_encctx *dh_doinit(key *k, const gprime_param *gp, mp *y,
498                             group *(*makegroup)(const gprime_param *),
499                             const char *what)
500 {
501   dh_encctx *de = CREATE(dh_encctx);
502   dstr t = DSTR_INIT;
503
504   key_fulltag(k, &t);
505   if ((de->g = makegroup(gp)) == 0)
506     die(EXIT_FAILURE, "bad %s group in key `%s'", what, t.buf);
507   de->x = MP_NEW;
508   de->y = G_CREATE(de->g);
509   if (G_FROMINT(de->g, de->y, y))
510     die(EXIT_FAILURE, "bad public key `%s'", t.buf);
511   dstr_destroy(&t);
512   return (de);
513 }
514
515 static dh_encctx *ec_doinit(key *k, const char *cstr, const ec *y)
516 {
517   dh_encctx *de = CREATE(dh_encctx);
518   ec_info ei;
519   const char *e;
520   dstr t = DSTR_INIT;
521
522   key_fulltag(k, &t);
523   if ((e = ec_getinfo(&ei, cstr)) != 0 ||
524       (de->g = group_ec(&ei)) == 0)
525     die(EXIT_FAILURE, "bad elliptic curve spec in key `%s': %s", t.buf, e);
526   de->x = MP_NEW;
527   de->y = G_CREATE(de->g);
528   if (G_FROMEC(de->g, de->y, y))
529     die(EXIT_FAILURE, "bad public curve point `%s'", t.buf);
530   dstr_destroy(&t);
531   return (de);
532 }
533
534 static kem *dh_encinit(key *k, void *kd)
535 {
536   dh_pub *dp = kd;
537   dh_encctx *de = dh_doinit(k, &dp->dp, dp->y, group_prime, "prime");
538   return (&de->k);
539 }
540
541 static kem *bindh_encinit(key *k, void *kd)
542 {
543   dh_pub *dp = kd;
544   dh_encctx *de = dh_doinit(k, &dp->dp, dp->y, group_binary, "binary");
545   return (&de->k);
546 }
547
548 static kem *ec_encinit(key *k, void *kd)
549 {
550   ec_pub *ep = kd;
551   dh_encctx *de = ec_doinit(k, ep->cstr, &ep->p);
552   return (&de->k);
553 }
554
555 static int dh_encdoit(kem *k, dstr *d, ghash *h)
556 {
557   dh_encctx *de = (dh_encctx *)k;
558   mp *r = mprand_range(MP_NEW, de->g->r, &rand_global, 0);
559   ge *x = G_CREATE(de->g);
560   ge *y = G_CREATE(de->g);
561   size_t n = de->g->noctets;
562   buf b;
563
564   G_EXP(de->g, x, de->g->g, r);
565   G_EXP(de->g, y, de->y, r);
566   dstr_ensure(d, n);
567   buf_init(&b, d->buf, n);
568   G_TORAW(de->g, &b, y);
569   GH_HASH(h, BBASE(&b), BLEN(&b));
570   buf_init(&b, d->buf, n);
571   G_TORAW(de->g, &b, x);
572   GH_HASH(h, BBASE(&b), BLEN(&b));
573   d->len += BLEN(&b);
574   mp_drop(r);
575   G_DESTROY(de->g, x);
576   G_DESTROY(de->g, y);
577   return (0);
578 }
579
580 static const char *dh_enccheck(kem *k)
581 {
582   dh_encctx *de = (dh_encctx *)k;
583   const char *e;
584   if ((e = G_CHECK(de->g, &rand_global)) != 0)
585     return (0);
586   if (group_check(de->g, de->y))
587     return ("public key not in subgroup");
588   return (0);
589 }
590
591 static void dh_encdestroy(kem *k)
592 {
593   dh_encctx *de = (dh_encctx *)k;
594   G_DESTROY(de->g, de->y);
595   mp_drop(de->x);
596   G_DESTROYGROUP(de->g);
597   DESTROY(de);
598 }
599
600 static const kemops dh_encops = {
601   dh_pubfetch, sizeof(dh_pub),
602   dh_encinit, dh_encdoit, dh_enccheck, dh_encdestroy
603 };
604
605 static const kemops bindh_encops = {
606   dh_pubfetch, sizeof(dh_pub),
607   bindh_encinit, dh_encdoit, dh_enccheck, dh_encdestroy
608 };
609
610 static const kemops ec_encops = {
611   ec_pubfetch, sizeof(ec_pub),
612   ec_encinit, dh_encdoit, dh_enccheck, dh_encdestroy
613 };
614
615 static kem *dh_decinit(key *k, void *kd)
616 {
617   dh_priv *dp = kd;
618   dh_encctx *de = dh_doinit(k, &dp->dp, dp->y, group_prime, "prime");
619   de->x = MP_COPY(dp->x);
620   return (&de->k);
621 }
622
623 static kem *bindh_decinit(key *k, void *kd)
624 {
625   dh_priv *dp = kd;
626   dh_encctx *de = dh_doinit(k, &dp->dp, dp->y, group_binary, "binary");
627   de->x = MP_COPY(dp->x);
628   return (&de->k);
629 }
630
631 static kem *ec_decinit(key *k, void *kd)
632 {
633   ec_priv *ep = kd;
634   dh_encctx *de = ec_doinit(k, ep->cstr, &ep->p);
635   de->x = MP_COPY(ep->x);
636   return (&de->k);
637 }
638
639 static int dh_decdoit(kem *k, dstr *d, ghash *h)
640 {
641   dh_encctx *de = (dh_encctx *)k;
642   ge *x = G_CREATE(de->g);
643   size_t n = de->g->noctets;
644   void *p = xmalloc(n);
645   buf b;
646   int rc = -1;
647
648   buf_init(&b, d->buf, d->len);
649   if (G_FROMRAW(de->g, &b, x) || group_check(de->g, x))
650     goto done;
651   G_EXP(de->g, x, x, de->x);
652   buf_init(&b, p, n);
653   G_TORAW(de->g, &b, x);
654   GH_HASH(h, BBASE(&b), BLEN(&b));
655   GH_HASH(h, d->buf, d->len);
656   rc = 0;
657 done:
658   G_DESTROY(de->g, x);
659   xfree(p);
660   return (rc);
661 }
662
663 static const kemops dh_decops = {
664   dh_privfetch, sizeof(dh_priv),
665   dh_decinit, dh_decdoit, dh_enccheck, dh_encdestroy
666 };
667
668 static const kemops bindh_decops = {
669   dh_privfetch, sizeof(dh_priv),
670   bindh_decinit, dh_decdoit, dh_enccheck, dh_encdestroy
671 };
672
673 static const kemops ec_decops = {
674   ec_privfetch, sizeof(ec_priv),
675   ec_decinit, dh_decdoit, dh_enccheck, dh_encdestroy
676 };
677
678 /* --- X25519 and similar schemes --- */
679
680 #define XDHS(_)                                                         \
681   _(x25519, X25519)                                                     \
682   _(x448, X448)
683
684 #define XDHDEF(xdh, XDH)                                                \
685                                                                         \
686   static kem *xdh##_encinit(key *k, void *kd) { return (CREATE(kem)); } \
687   static void xdh##_encdestroy(kem *k) { DESTROY(k); }                  \
688                                                                         \
689   static const char *xdh##_enccheck(kem *k)                             \
690   {                                                                     \
691     xdh##_pub *kd = k->kd;                                              \
692                                                                         \
693     if (kd->pub.sz != XDH##_PUBSZ)                                      \
694       return ("incorrect " #XDH "public key length");                   \
695     return (0);                                                         \
696   }                                                                     \
697                                                                         \
698   static int xdh##_encdoit(kem *k, dstr *d, ghash *h)                   \
699   {                                                                     \
700     octet t[XDH##_KEYSZ], z[XDH##_OUTSZ];                               \
701     xdh##_pub *kd = k->kd;                                              \
702                                                                         \
703     rand_get(RAND_GLOBAL, t, sizeof(t));                                \
704     dstr_ensure(d, XDH##_PUBSZ);                                        \
705     xdh((octet *)d->buf, t, xdh##_base);                                \
706     xdh(z, t, kd->pub.k);                                               \
707     d->len += XDH##_PUBSZ;                                              \
708     GH_HASH(h, d->buf, XDH##_PUBSZ);                                    \
709     GH_HASH(h, z, XDH##_OUTSZ);                                         \
710     return (0);                                                         \
711   }                                                                     \
712                                                                         \
713   static const char *xdh##_deccheck(kem *k)                             \
714   {                                                                     \
715     xdh##_priv *kd = k->kd;                                             \
716                                                                         \
717     if (kd->priv.sz != XDH##_KEYSZ)                                     \
718       return ("incorrect " #XDH " private key length");                 \
719     if (kd->pub.sz != XDH##_PUBSZ)                                      \
720       return ("incorrect " #XDH " public key length");                  \
721     return (0);                                                         \
722   }                                                                     \
723                                                                         \
724   static int xdh##_decdoit(kem *k, dstr *d, ghash *h)                   \
725   {                                                                     \
726     octet z[XDH##_OUTSZ];                                               \
727     xdh##_priv *kd = k->kd;                                             \
728     int rc = -1;                                                        \
729                                                                         \
730     if (d->len != XDH##_PUBSZ) goto done;                               \
731     xdh(z, kd->priv.k, (const octet *)d->buf);                          \
732     GH_HASH(h, d->buf, XDH##_PUBSZ);                                    \
733     GH_HASH(h, z, XDH##_OUTSZ);                                         \
734     rc = 0;                                                             \
735   done:                                                                 \
736     return (rc);                                                        \
737   }                                                                     \
738                                                                         \
739   static const kemops xdh##_encops = {                                  \
740     xdh##_pubfetch, sizeof(xdh##_pub),                                  \
741     xdh##_encinit, xdh##_encdoit, xdh##_enccheck, xdh##_encdestroy      \
742   };                                                                    \
743                                                                         \
744   static const kemops xdh##_decops = {                                  \
745     xdh##_privfetch, sizeof(xdh##_priv),                                \
746     xdh##_encinit, xdh##_decdoit, xdh##_deccheck, xdh##_encdestroy      \
747   };
748
749 XDHS(XDHDEF)
750 #undef XDHDEF
751
752 /* --- Symmetric --- */
753
754 typedef struct symm_ctx {
755   kem k;
756   key_packdef kp;
757   key_bin kb;
758 } symm_ctx;
759
760 static kem *symm_init(key *k, void *kd)
761 {
762   symm_ctx *s;
763   dstr d = DSTR_INIT;
764   int err;
765
766   s = CREATE(symm_ctx);
767
768   key_fulltag(k, &d);
769   s->kp.e = KENC_BINARY;
770   s->kp.p = &s->kb;
771   s->kp.kd = 0;
772
773   if ((err = key_unpack(&s->kp, kd, &d)) != 0) {
774     die(EXIT_FAILURE, "failed to unpack symmetric key `%s': %s",
775         d.buf, key_strerror(err));
776   }
777   dstr_destroy(&d);
778   return (&s->k);
779 }
780
781 static int symm_decdoit(kem *k, dstr *d, ghash *h)
782 {
783   symm_ctx *s = (symm_ctx *)k;
784
785   GH_HASH(h, s->kb.k, s->kb.sz);
786   GH_HASH(h, d->buf, d->len);
787   return (0);
788 }
789
790 static int symm_encdoit(kem *k, dstr *d, ghash *h)
791 {
792   dstr_ensure(d, h->ops->c->hashsz);
793   d->len += h->ops->c->hashsz;
794   rand_get(RAND_GLOBAL, d->buf, d->len);
795   return (symm_decdoit(k, d, h));
796 }
797
798 static const char *symm_check(kem *k) { return (0); }
799
800 static void symm_destroy(kem *k)
801   { symm_ctx *s = (symm_ctx *)k; key_unpackdone(&s->kp); }
802
803 static const kemops symm_encops = {
804   0, 0,
805   symm_init, symm_encdoit, symm_check, symm_destroy
806 };
807
808 static const kemops symm_decops = {
809   0, 0,
810   symm_init, symm_decdoit, symm_check, symm_destroy
811 };
812
813 /* --- The switch table --- */
814
815 const struct kemtab kemtab[] = {
816   { "rsa",      &rsa_encops,    &rsa_decops },
817   { "dh",       &dh_encops,     &dh_decops },
818   { "bindh",    &bindh_encops,  &bindh_decops },
819   { "ec",       &ec_encops,     &ec_decops },
820 #define XDHTAB(xdh, XDH)                                                \
821   { #xdh,       &xdh##_encops,  &xdh##_decops },
822   XDHS(XDHTAB)
823 #undef XDHTAB
824   { "symm",     &symm_encops,   &symm_decops },
825   { 0,          0,              0 }
826 };
827
828 /* --- @getkem@ --- *
829  *
830  * Arguments:   @key *k@ = the key to load
831  *              @const char *app@ = application name
832  *              @int wantpriv@ = nonzero if we want to decrypt
833  *              @bulk **bc@ = bulk crypto context to set up
834  *
835  * Returns:     A key-encapsulating thing.
836  *
837  * Use:         Loads a key.
838  */
839
840 kem *getkem(key *k, const char *app, int wantpriv, bulk **bc)
841 {
842   const char *kalg, *halg = 0, *balg = 0;
843   dstr d = DSTR_INIT;
844   dstr t = DSTR_INIT;
845   size_t n;
846   char *p = 0;
847   const char *q;
848   kem *kk;
849   const struct kemtab *kt;
850   const kemops *ko;
851   const struct bulktab *bt;
852   const bulkops *bo;
853   void *kd;
854   int e;
855   key_packdef *kp;
856
857   /* --- Setup stuff --- */
858
859   key_fulltag(k, &t);
860
861   /* --- Get the KEM name --- *
862    *
863    * Take the attribute if it's there; otherwise use the key type.
864    */
865
866   n = strlen(app);
867   if ((q = key_getattr(0, k, "kem")) != 0) {
868     dstr_puts(&d, q);
869     p = d.buf;
870   } else if (STRNCMP(k->type, ==, app, n) && k->type[n] == '-') {
871     dstr_puts(&d, k->type);
872     p = d.buf + n + 1;
873   } else
874     die(EXIT_FAILURE, "no KEM for key `%s'", t.buf);
875   kalg = p;
876
877   /* --- Grab the bulk encryption scheme --- *
878    *
879    * Grab it from the KEM if it's there, but override it from the attribute.
880    */
881
882   if (p && (p = strchr(p, '/')) != 0) {
883     *p++ = 0;
884     balg = p;
885   }
886   if ((q = key_getattr(0, k, "bulk")) != 0)
887     balg = q;
888
889   /* --- Grab the hash function --- */
890
891   if (p && (p = strchr(p, '/')) != 0) {
892     *p++ = 0;
893     halg = p;
894   }
895   if ((q = key_getattr(0, k, "hash")) != 0)
896     halg = q;
897
898   /* --- Instantiate the KEM --- */
899
900   for (kt = kemtab; kt->name; kt++) {
901     if (STRCMP(kt->name, ==, kalg))
902       goto k_found;
903   }
904   die(EXIT_FAILURE, "key encapsulation mechanism `%s' not found in key `%s'",
905       kalg, t.buf);
906 k_found:;
907   ko = wantpriv ? kt->decops : kt->encops;
908   if (!ko->kf) {
909     kd = k->k;
910     key_incref(kd);
911     kp = 0;
912   } else {
913     kd = xmalloc(ko->kdsz);
914     kp = key_fetchinit(ko->kf, 0, kd);
915     if ((e = key_fetch(kp, k)) != 0) {
916       die(EXIT_FAILURE, "error fetching key `%s': %s",
917           t.buf, key_strerror(e));
918     }
919   }
920   kk = ko->init(k, kd);
921   kk->kp = kp;
922   kk->ops = ko;
923   kk->kd = kd;
924
925   /* --- Set up the bulk crypto --- */
926
927   if (!halg)
928     kk->hc = &rmd160;
929   else if ((kk->hc = ghash_byname(halg)) == 0) {
930     die(EXIT_FAILURE, "hash algorithm `%s' not found in key `%s'",
931         halg, t.buf);
932   }
933
934   if (!balg)
935     bt = bulktab;
936   else {
937     for (bt = bulktab, bo = 0; bt->name; bt++) {
938       if (STRCMP(balg, ==, bt->name))
939         { balg = 0; goto b_found; }
940       n = strlen(bt->name);
941       if (STRNCMP(balg, ==, bt->name, n) && balg[n] == '-')
942         { balg += n + 1; goto b_found; }
943     }
944     bt = bulktab;
945   b_found:;
946   }
947   bo = wantpriv ? bt->decops : bt->encops;
948   *bc = bo->init(k, balg, kk->hc->name);
949   (*bc)->ops = bo;
950
951   dstr_reset(&d);
952   if ((q = key_getattr(0, k, "kdf")) == 0) {
953     dstr_putf(&d, "%s-mgf", kk->hc->name);
954     q = d.buf;
955   }
956   if ((kk->cxc = gcipher_byname(q)) == 0) {
957     die(EXIT_FAILURE, "encryption scheme (KDF) `%s' not found in key `%s'",
958         q, t.buf);
959   }
960
961   /* --- Tidy up --- */
962
963   dstr_destroy(&d);
964   dstr_destroy(&t);
965   return (kk);
966 }
967
968 /* --- @setupkem@ --- *
969  *
970  * Arguments:   @kem *k@ = key-encapsulation thing
971  *              @dstr *d@ = key-encapsulation data
972  *              @bulk *bc@ = bulk crypto context to set up
973  *
974  * Returns:     Zero on success, nonzero on failure.
975  *
976  * Use:         Initializes all the various symmetric things from a KEM.
977  */
978
979 int setupkem(kem *k, dstr *d, bulk *bc)
980 {
981   octet *kd;
982   size_t n;
983   ghash *h;
984   int rc = -1;
985
986   h = GH_INIT(k->hc);
987   if (k->ops->doit(k, d, h))
988     goto done;
989   n = keysz(GH_CLASS(h)->hashsz, k->cxc->keysz);
990   if (!n)
991     goto done;
992   kd = GH_DONE(h, 0);
993   k->cx = GC_INIT(k->cxc, kd, n);
994   bc->ops->setup(bc, k->cx);
995
996   rc = 0;
997 done:
998   GH_DESTROY(h);
999   return (rc);
1000 }
1001
1002 /* --- @freekem@ --- *
1003  *
1004  * Arguments:   @kem *k@ = key-encapsulation thing
1005  *
1006  * Returns:     ---
1007  *
1008  * Use:         Frees up a key-encapsulation thing.
1009  */
1010
1011 void freekem(kem *k)
1012 {
1013   if (!k->ops->kf)
1014     key_drop(k->kd);
1015   else {
1016     key_fetchdone(k->kp);
1017     xfree(k->kd);
1018   }
1019   GC_DESTROY(k->cx);
1020   k->ops->destroy(k);
1021 }
1022
1023 /*----- That's all, folks -------------------------------------------------*/