chiark / gitweb /
Version bump.
[catacomb] / mpmont.c
1 /* -*-c-*-
2  *
3  * $Id: mpmont.c,v 1.9 2000/06/17 11:45:09 mdw Exp $
4  *
5  * Montgomery reduction
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpmont.c,v $
33  * Revision 1.9  2000/06/17 11:45:09  mdw
34  * Major memory management overhaul.  Added arena support.  Use the secure
35  * arena for secret integers.  Replace and improve the MP management macros
36  * (e.g., replace MP_MODIFY by MP_DEST).
37  *
38  * Revision 1.8  1999/12/22 15:55:00  mdw
39  * Adjust Karatsuba parameters.
40  *
41  * Revision 1.7  1999/12/11 01:51:14  mdw
42  * Use a Karatsuba-based reduction for large moduli.
43  *
44  * Revision 1.6  1999/12/10 23:18:39  mdw
45  * Change interface for suggested destinations.
46  *
47  * Revision 1.5  1999/11/22 13:58:40  mdw
48  * Add an option to disable Montgomery reduction, so that performance
49  * comparisons can be done.
50  *
51  * Revision 1.4  1999/11/21 12:27:06  mdw
52  * Remove a division from the Montgomery setup by calculating
53  * %$R^2 \bmod m$% first and then %$R \bmod m$% by Montgomery reduction of
54  * %$R^2$%.
55  *
56  * Revision 1.3  1999/11/21 11:35:10  mdw
57  * Performance improvement: use @mp_sqr@ and @mpmont_reduce@ instead of
58  * @mpmont_mul@ for squaring in exponentiation.
59  *
60  * Revision 1.2  1999/11/19 13:17:26  mdw
61  * Add extra interface to exponentiation which returns a Montgomerized
62  * result.
63  *
64  * Revision 1.1  1999/11/17 18:02:16  mdw
65  * New multiprecision integer arithmetic suite.
66  *
67  */
68
69 /*----- Header files ------------------------------------------------------*/
70
71 #include "mp.h"
72 #include "mpmont.h"
73
74 /*----- Tweakables --------------------------------------------------------*/
75
76 /* --- @MPMONT_DISABLE@ --- *
77  *
78  * Replace all the clever Montgomery reduction with good old-fashioned long
79  * division.
80  */
81
82 /* #define MPMONT_DISABLE */
83
84 /*----- Main code ---------------------------------------------------------*/
85
86 /* --- @mpmont_create@ --- *
87  *
88  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
89  *              @mp *m@ = modulus to use
90  *
91  * Returns:     ---
92  *
93  * Use:         Initializes a Montgomery reduction context ready for use.
94  *              The argument @m@ must be a positive odd integer.
95  */
96
97 #ifdef MPMONT_DISABLE
98
99 void mpmont_create(mpmont *mm, mp *m)
100 {
101   mp_shrink(m);
102   mm->m = MP_COPY(m);
103   mm->r = MP_ONE;
104   mm->r2 = MP_ONE;
105   mm->mi = MP_ONE;
106 }
107
108 #else
109
110 void mpmont_create(mpmont *mm, mp *m)
111 {
112   size_t n = MP_LEN(m);
113   mp *r2 = mp_new(2 * n + 1, 0);
114   mp r;
115
116   /* --- Validate the arguments --- */
117
118   assert(((void)"Montgomery modulus must be positive",
119           (m->f & MP_NEG) == 0));
120   assert(((void)"Montgomery modulus must be odd", m->v[0] & 1));
121
122   /* --- Take a copy of the modulus --- */
123
124   mp_shrink(m);
125   mm->m = MP_COPY(m);
126
127   /* --- Determine %$R^2$% --- */
128
129   mm->n = n;
130   MPX_ZERO(r2->v, r2->vl - 1);
131   r2->vl[-1] = 1;
132
133   /* --- Find the magic value @mi@ --- */
134
135   mp_build(&r, r2->v + n, r2->vl);
136   mm->mi = MP_NEW;
137   mp_gcd(0, 0, &mm->mi, &r, m);
138   mm->mi = mp_sub(mm->mi, &r, mm->mi);
139
140   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
141
142   mm->r2 = MP_NEW;
143   mp_div(0, &mm->r2, r2, m);
144   mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
145   MP_DROP(r2);
146 }
147
148 #endif
149
150 /* --- @mpmont_destroy@ --- *
151  *
152  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
153  *
154  * Returns:     ---
155  *
156  * Use:         Disposes of a context when it's no longer of any use to
157  *              anyone.
158  */
159
160 void mpmont_destroy(mpmont *mm)
161 {
162   MP_DROP(mm->m);
163   MP_DROP(mm->r);
164   MP_DROP(mm->r2);
165   MP_DROP(mm->mi);
166 }
167
168 /* --- @mpmont_reduce@ --- *
169  *
170  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
171  *              @mp *d@ = destination
172  *              @mp *a@ = source, assumed positive
173  *
174  * Returns:     Result, %$a R^{-1} \bmod m$%.
175  */
176
177 #ifdef MPMONT_DISABLE
178
179 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
180 {
181   mp_div(0, &d, a, mm->m);
182   return (d);
183 }
184
185 #else
186
187 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
188 {
189   size_t n = mm->n;
190
191   /* --- Check for serious Karatsuba reduction --- */
192
193   if (n > KARATSUBA_CUTOFF * 3) {
194     mp al;
195     mpw *vl;
196     mp *u;
197
198     if (MP_LEN(a) >= n)
199       vl = a->v + n;
200     else
201       vl = a->vl;
202     mp_build(&al, a->v, vl);
203     u = mp_mul(MP_NEW, &al, mm->mi);
204     if (MP_LEN(u) > n)
205       u->vl = u->v + n;
206     u = mp_mul(u, u, mm->m);
207     d = mp_add(d, a, u);
208     mp_drop(u);
209   }
210
211   /* --- Otherwise do it the hard way --- */
212
213   else {
214     mpw *dv, *dvl;
215     mpw *mv, *mvl;
216     mpw mi;
217     size_t k = n;
218
219     /* --- Initial conditioning of the arguments --- */
220
221     a = MP_COPY(a);
222     if (d)
223       MP_DROP(d);
224     d = a;
225     MP_DEST(d, 2 * n + 1, a->f);
226
227     dv = d->v; dvl = d->vl;
228     mv = mm->m->v; mvl = mm->m->vl;
229
230     /* --- Let's go to work --- */
231
232     mi = mm->mi->v[0];
233     while (k--) {
234       mpw u = MPW(*dv * mi);
235       MPX_UMLAN(dv, dvl, mv, mvl, u);
236       dv++;
237     }
238   }
239
240   /* --- Wrap everything up --- */
241
242   memmove(d->v, d->v + n, MPWS(MP_LEN(d) - n));
243   d->vl -= n;
244   if (MP_CMP(d, >=, mm->m))
245     d = mp_sub(d, d, mm->m);
246   MP_SHRINK(d);
247   return (d);
248 }
249
250 #endif
251
252 /* --- @mpmont_mul@ --- *
253  *
254  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
255  *              @mp *d@ = destination
256  *              @mp *a, *b@ = sources, assumed positive
257  *
258  * Returns:     Result, %$a b R^{-1} \bmod m$%.
259  */
260
261 #ifdef MPMONT_DISABLE
262
263 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
264 {
265   d = mp_mul(d, a, b);
266   mp_div(0, &d, d, mm->m);
267   return (d);
268 }
269
270 #else
271
272 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
273 {
274   if (mm->n > KARATSUBA_CUTOFF * 3) {
275     d = mp_mul(d, a, b);
276     d = mpmont_reduce(mm, d, d);
277   } else {
278     mpw *dv, *dvl;
279     mpw *av, *avl;
280     mpw *bv, *bvl;
281     mpw *mv, *mvl;
282     mpw y;
283     size_t n, i;
284     mpw mi;
285
286     /* --- Initial conditioning of the arguments --- */
287
288     if (MP_LEN(a) > MP_LEN(b)) {
289       mp *t = a; a = b; b = t;
290     }
291     n = MP_LEN(mm->m);
292
293     a = MP_COPY(a);
294     b = MP_COPY(b);
295     MP_DEST(d, 2 * n + 1, a->f | b->f | MP_UNDEF);
296     dv = d->v; dvl = d->vl;
297     MPX_ZERO(dv, dvl);
298     av = a->v; avl = a->vl;
299     bv = b->v; bvl = b->vl;
300     mv = mm->m->v; mvl = mm->m->vl;
301     y = *bv;
302
303     /* --- Montgomery multiplication phase --- */
304
305     i = 0;
306     mi = mm->mi->v[0];
307     while (i < n && av < avl) {
308       mpw x = *av++;
309       mpw u = MPW((*dv + x * y) * mi);
310       MPX_UMLAN(dv, dvl, bv, bvl, x);
311       MPX_UMLAN(dv, dvl, mv, mvl, u);
312       dv++;
313       i++;
314     }
315
316     /* --- Simpler Montgomery reduction phase --- */
317
318     while (i < n) {
319       mpw u = MPW(*dv * mi);
320       MPX_UMLAN(dv, dvl, mv, mvl, u);
321       dv++;
322       i++;
323     }
324
325     /* --- Done --- */
326
327     memmove(d->v, dv, MPWS(dvl - dv));
328     d->vl -= dv - d->v;
329     MP_SHRINK(d);
330     d->f = (a->f | b->f) & MP_BURN;
331     if (MP_CMP(d, >=, mm->m))
332       d = mp_sub(d, d, mm->m);
333     MP_DROP(a);
334     MP_DROP(b);
335   }
336
337   return (d);
338 }
339
340 #endif
341
342 /* --- @mpmont_expr@ --- *
343  *
344  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
345  *              @mp *d@ = fake destination
346  *              @mp *a@ = base
347  *              @mp *e@ = exponent
348  *
349  * Returns:     Result, %$a^e R \bmod m$%.
350  */
351
352 mp *mpmont_expr(mpmont *mm, mp *d, mp *a, mp *e)
353 {
354   mpscan sc;
355   mp *ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
356   mp *x = MP_COPY(mm->r);
357   mp *spare = (e->f & MP_BURN) ? MP_NEWSEC : MP_NEW;
358
359   mp_scan(&sc, e);
360
361   if (MP_STEP(&sc)) {
362     size_t sq = 0;
363     for (;;) {
364       mp *dd;
365       if (MP_BIT(&sc)) {
366         while (sq) {
367           dd = mp_sqr(spare, ar);
368           dd = mpmont_reduce(mm, dd, dd);
369           spare = ar; ar = dd;
370           sq--;
371         }
372         dd = mpmont_mul(mm, spare, x, ar);
373         spare = x; x = dd;
374       }
375       sq++;
376       if (!MP_STEP(&sc))
377         break;
378     }
379   }
380   MP_DROP(ar);
381   if (spare != MP_NEW)
382     MP_DROP(spare);
383   if (d != MP_NEW)
384     MP_DROP(d);
385   return (x);
386 }
387
388 /* --- @mpmont_exp@ --- *
389  *
390  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
391  *              @mp *d@ = fake destination
392  *              @mp *a@ = base
393  *              @mp *e@ = exponent
394  *
395  * Returns:     Result, %$a^e \bmod m$%.
396  */
397
398 mp *mpmont_exp(mpmont *mm, mp *d, mp *a, mp *e)
399 {
400   d = mpmont_expr(mm, d, a, e);
401   d = mpmont_reduce(mm, d, d);
402   return (d);
403 }
404
405 /*----- Test rig ----------------------------------------------------------*/
406
407 #ifdef TEST_RIG
408
409 static int tcreate(dstr *v)
410 {
411   mp *m = *(mp **)v[0].buf;
412   mp *mi = *(mp **)v[1].buf;
413   mp *r = *(mp **)v[2].buf;
414   mp *r2 = *(mp **)v[3].buf;
415
416   mpmont mm;
417   int ok = 1;
418
419   mpmont_create(&mm, m);
420
421   if (mm.mi->v[0] != mi->v[0]) {
422     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
423             (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
424     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
425     fputc('\n', stderr);
426     ok = 0;
427   }
428
429   if (MP_CMP(mm.r, !=, r)) {
430     fputs("\n*** bad r", stderr);
431     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
432     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
433     fputs("\n   found ", stderr); mp_writefile(mm.r, stderr, 10);
434     fputc('\n', stderr);
435     ok = 0;
436   }
437
438   if (MP_CMP(mm.r2, !=, r2)) {
439     fputs("\n*** bad r2", stderr);
440     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
441     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
442     fputs("\n   found ", stderr); mp_writefile(mm.r2, stderr, 10);
443     fputc('\n', stderr);
444     ok = 0;
445   }
446
447   MP_DROP(m);
448   MP_DROP(mi);
449   MP_DROP(r);
450   MP_DROP(r2);
451   mpmont_destroy(&mm);
452   assert(mparena_count(MPARENA_GLOBAL) == 0);
453   return (ok);
454 }
455
456 static int tmul(dstr *v)
457 {
458   mp *m = *(mp **)v[0].buf;
459   mp *a = *(mp **)v[1].buf;
460   mp *b = *(mp **)v[2].buf;
461   mp *r = *(mp **)v[3].buf;
462   int ok = 1;
463
464   mpmont mm;
465   mpmont_create(&mm, m);
466
467   {
468     mp *qr = mp_mul(MP_NEW, a, b);
469     mp_div(0, &qr, qr, m);
470
471     if (MP_CMP(qr, !=, r)) {
472       fputs("\n*** classical modmul failed", stderr);
473       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
474       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
475       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
476       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
477       fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
478       fputc('\n', stderr);
479       ok = 0;
480     }
481
482     mp_drop(qr);
483   }
484
485   {
486     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
487     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
488     mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
489     mr = mpmont_reduce(&mm, mr, mr);
490     if (MP_CMP(mr, !=, r)) {
491       fputs("\n*** montgomery modmul failed", stderr);
492       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
493       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
494       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
495       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
496       fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
497       fputc('\n', stderr);
498       ok = 0;
499     }
500     MP_DROP(ar); MP_DROP(br);
501     mp_drop(mr);
502   }
503
504
505   MP_DROP(m);
506   MP_DROP(a);
507   MP_DROP(b);
508   MP_DROP(r);
509   mpmont_destroy(&mm);
510   assert(mparena_count(MPARENA_GLOBAL) == 0);
511   return ok;
512 }
513
514 static int texp(dstr *v)
515 {
516   mp *m = *(mp **)v[0].buf;
517   mp *a = *(mp **)v[1].buf;
518   mp *b = *(mp **)v[2].buf;
519   mp *r = *(mp **)v[3].buf;
520   mp *mr;
521   int ok = 1;
522
523   mpmont mm;
524   mpmont_create(&mm, m);
525
526   mr = mpmont_exp(&mm, MP_NEW, a, b);
527
528   if (MP_CMP(mr, !=, r)) {
529     fputs("\n*** montgomery modexp failed", stderr);
530     fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
531     fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
532     fputs("\n e = ", stderr); mp_writefile(b, stderr, 10);
533     fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
534     fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
535     fputc('\n', stderr);
536     ok = 0;
537   }
538
539   MP_DROP(m);
540   MP_DROP(a);
541   MP_DROP(b);
542   MP_DROP(r);
543   MP_DROP(mr);
544   mpmont_destroy(&mm);
545   assert(mparena_count(MPARENA_GLOBAL) == 0);
546   return ok;
547 }
548
549
550 static test_chunk tests[] = {
551   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
552   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
553   { "exp", texp, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
554   { 0, 0, { 0 } },
555 };
556
557 int main(int argc, char *argv[])
558 {
559   sub_init();
560   test_run(argc, argv, tests, SRCDIR "/tests/mpmont");
561   return (0);
562 }
563
564 #endif
565
566 /*----- That's all, folks -------------------------------------------------*/