chiark / gitweb /
ec-bin (ec_binproj): Make curve setup faster.
[catacomb] / mpmont.c
1 /* -*-c-*-
2  *
3  * $Id$
4  *
5  * Montgomery reduction
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Header files ------------------------------------------------------*/
31
32 #include "mp.h"
33 #include "mpmont.h"
34
35 /*----- Tweakables --------------------------------------------------------*/
36
37 /* --- @MPMONT_DISABLE@ --- *
38  *
39  * Replace all the clever Montgomery reduction with good old-fashioned long
40  * division.
41  */
42
43 /* #define MPMONT_DISABLE */
44
45 /*----- Reduction and multiplication --------------------------------------*/
46
47 /* --- @mpmont_create@ --- *
48  *
49  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
50  *              @mp *m@ = modulus to use
51  *
52  * Returns:     Zero on success, nonzero on error.
53  *
54  * Use:         Initializes a Montgomery reduction context ready for use.
55  *              The argument @m@ must be a positive odd integer.
56  */
57
58 #ifdef MPMONT_DISABLE
59
60 int mpmont_create(mpmont *mm, mp *m)
61 {
62   mp_shrink(m);
63   mm->m = MP_COPY(m);
64   mm->r = MP_ONE;
65   mm->r2 = MP_ONE;
66   mm->mi = MP_ONE;
67   return (0);
68 }
69
70 #else
71
72 int mpmont_create(mpmont *mm, mp *m)
73 {
74   size_t n = MP_LEN(m);
75   mp *r2 = mp_new(2 * n + 1, 0);
76   mp r;
77
78   /* --- Take a copy of the modulus --- */
79
80  if (!MP_POSP(m) || !MP_ODDP(m))
81    return (-1);
82   mm->m = MP_COPY(m);
83
84   /* --- Determine %$R^2$% --- */
85
86   mm->n = n;
87   MPX_ZERO(r2->v, r2->vl - 1);
88   r2->vl[-1] = 1;
89
90   /* --- Find the magic value @mi@ --- */
91
92   mp_build(&r, r2->v + n, r2->vl);
93   mm->mi = mp_modinv(MP_NEW, m, &r);
94   mm->mi = mp_sub(mm->mi, &r, mm->mi);
95
96   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
97
98   mm->r2 = MP_NEW;
99   mp_div(0, &mm->r2, r2, m);
100   mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
101   MP_DROP(r2);
102   return (0);
103 }
104
105 #endif
106
107 /* --- @mpmont_destroy@ --- *
108  *
109  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
110  *
111  * Returns:     ---
112  *
113  * Use:         Disposes of a context when it's no longer of any use to
114  *              anyone.
115  */
116
117 void mpmont_destroy(mpmont *mm)
118 {
119   MP_DROP(mm->m);
120   MP_DROP(mm->r);
121   MP_DROP(mm->r2);
122   MP_DROP(mm->mi);
123 }
124
125 /* --- @mpmont_reduce@ --- *
126  *
127  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
128  *              @mp *d@ = destination
129  *              @mp *a@ = source, assumed positive
130  *
131  * Returns:     Result, %$a R^{-1} \bmod m$%.
132  */
133
134 #ifdef MPMONT_DISABLE
135
136 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
137 {
138   mp_div(0, &d, a, mm->m);
139   return (d);
140 }
141
142 #else
143
144 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
145 {
146   size_t n = mm->n;
147
148   /* --- Check for serious Karatsuba reduction --- */
149
150   if (n > MPK_THRESH * 3) {
151     mp al;
152     mpw *vl;
153     mp *u;
154
155     if (MP_LEN(a) >= n)
156       vl = a->v + n;
157     else
158       vl = a->vl;
159     mp_build(&al, a->v, vl);
160     u = mp_mul(MP_NEW, &al, mm->mi);
161     if (MP_LEN(u) > n)
162       u->vl = u->v + n;
163     u = mp_mul(u, u, mm->m);
164     d = mp_add(d, a, u);
165     mp_drop(u);
166   }
167
168   /* --- Otherwise do it the hard way --- */
169
170   else {
171     mpw *dv, *dvl;
172     mpw *mv, *mvl;
173     mpw mi;
174     size_t k = n;
175
176     /* --- Initial conditioning of the arguments --- */
177
178     a = MP_COPY(a);
179     if (d)
180       MP_DROP(d);
181     d = a;
182     MP_DEST(d, 2 * n + 1, a->f);
183
184     dv = d->v; dvl = d->vl;
185     mv = mm->m->v; mvl = mm->m->vl;
186
187     /* --- Let's go to work --- */
188
189     mi = mm->mi->v[0];
190     while (k--) {
191       mpw u = MPW(*dv * mi);
192       MPX_UMLAN(dv, dvl, mv, mvl, u);
193       dv++;
194     }
195   }
196
197   /* --- Wrap everything up --- */
198
199   memmove(d->v, d->v + n, MPWS(MP_LEN(d) - n));
200   d->vl -= n;
201   if (MPX_UCMP(d->v, d->vl, >=, mm->m->v, mm->m->vl))
202     mpx_usub(d->v, d->vl, d->v, d->vl, mm->m->v, mm->m->vl);
203   if (d->f & MP_NEG) {
204     mpx_usub(d->v, d->vl, mm->m->v, mm->m->vl, d->v, d->vl);
205     d->f &= ~MP_NEG;
206   }
207   MP_SHRINK(d);
208   return (d);
209 }
210
211 #endif
212
213 /* --- @mpmont_mul@ --- *
214  *
215  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
216  *              @mp *d@ = destination
217  *              @mp *a, *b@ = sources, assumed positive
218  *
219  * Returns:     Result, %$a b R^{-1} \bmod m$%.
220  */
221
222 #ifdef MPMONT_DISABLE
223
224 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
225 {
226   d = mp_mul(d, a, b);
227   mp_div(0, &d, d, mm->m);
228   return (d);
229 }
230
231 #else
232
233 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
234 {
235   if (mm->n > MPK_THRESH * 3) {
236     d = mp_mul(d, a, b);
237     d = mpmont_reduce(mm, d, d);
238   } else {
239     mpw *dv, *dvl;
240     mpw *av, *avl;
241     mpw *bv, *bvl;
242     mpw *mv, *mvl;
243     mpw y;
244     size_t n, i;
245     mpw mi;
246
247     /* --- Initial conditioning of the arguments --- */
248
249     if (MP_LEN(a) > MP_LEN(b)) {
250       mp *t = a; a = b; b = t;
251     }
252     n = MP_LEN(mm->m);
253
254     a = MP_COPY(a);
255     b = MP_COPY(b);
256     MP_DEST(d, 2 * n + 1, a->f | b->f | MP_UNDEF);
257     dv = d->v; dvl = d->vl;
258     MPX_ZERO(dv, dvl);
259     av = a->v; avl = a->vl;
260     bv = b->v; bvl = b->vl;
261     mv = mm->m->v; mvl = mm->m->vl;
262     y = *bv;
263
264     /* --- Montgomery multiplication phase --- */
265
266     i = 0;
267     mi = mm->mi->v[0];
268     while (i < n && av < avl) {
269       mpw x = *av++;
270       mpw u = MPW((*dv + x * y) * mi);
271       MPX_UMLAN(dv, dvl, bv, bvl, x);
272       MPX_UMLAN(dv, dvl, mv, mvl, u);
273       dv++;
274       i++;
275     }
276
277     /* --- Simpler Montgomery reduction phase --- */
278
279     while (i < n) {
280       mpw u = MPW(*dv * mi);
281       MPX_UMLAN(dv, dvl, mv, mvl, u);
282       dv++;
283       i++;
284     }
285
286     /* --- Done --- */
287
288     memmove(d->v, dv, MPWS(dvl - dv));
289     d->vl -= dv - d->v;
290     if (MPX_UCMP(d->v, d->vl, >=, mm->m->v, mm->m->vl))
291       mpx_usub(d->v, d->vl, d->v, d->vl, mm->m->v, mm->m->vl);
292     if ((a->f ^ b->f) & MP_NEG)
293       mpx_usub(d->v, d->vl, mm->m->v, mm->m->vl, d->v, d->vl);
294     MP_SHRINK(d);
295     d->f = (a->f | b->f) & MP_BURN;
296     MP_DROP(a);
297     MP_DROP(b);
298   }
299
300   return (d);
301 }
302
303 #endif
304
305 /*----- Test rig ----------------------------------------------------------*/
306
307 #ifdef TEST_RIG
308
309 static int tcreate(dstr *v)
310 {
311   mp *m = *(mp **)v[0].buf;
312   mp *mi = *(mp **)v[1].buf;
313   mp *r = *(mp **)v[2].buf;
314   mp *r2 = *(mp **)v[3].buf;
315
316   mpmont mm;
317   int ok = 1;
318
319   mpmont_create(&mm, m);
320
321   if (mm.mi->v[0] != mi->v[0]) {
322     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
323             (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
324     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
325     fputc('\n', stderr);
326     ok = 0;
327   }
328
329   if (!MP_EQ(mm.r, r)) {
330     fputs("\n*** bad r", stderr);
331     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
332     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
333     fputs("\n   found ", stderr); mp_writefile(mm.r, stderr, 10);
334     fputc('\n', stderr);
335     ok = 0;
336   }
337
338   if (!MP_EQ(mm.r2, r2)) {
339     fputs("\n*** bad r2", stderr);
340     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
341     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
342     fputs("\n   found ", stderr); mp_writefile(mm.r2, stderr, 10);
343     fputc('\n', stderr);
344     ok = 0;
345   }
346
347   MP_DROP(m);
348   MP_DROP(mi);
349   MP_DROP(r);
350   MP_DROP(r2);
351   mpmont_destroy(&mm);
352   assert(mparena_count(MPARENA_GLOBAL) == 0);
353   return (ok);
354 }
355
356 static int tmul(dstr *v)
357 {
358   mp *m = *(mp **)v[0].buf;
359   mp *a = *(mp **)v[1].buf;
360   mp *b = *(mp **)v[2].buf;
361   mp *r = *(mp **)v[3].buf;
362   int ok = 1;
363
364   mpmont mm;
365   mpmont_create(&mm, m);
366
367   {
368     mp *qr = mp_mul(MP_NEW, a, b);
369     mp_div(0, &qr, qr, m);
370
371     if (!MP_EQ(qr, r)) {
372       fputs("\n*** classical modmul failed", stderr);
373       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
374       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
375       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
376       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
377       fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
378       fputc('\n', stderr);
379       ok = 0;
380     }
381
382     mp_drop(qr);
383   }
384
385   {
386     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
387     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
388     mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
389     mr = mpmont_reduce(&mm, mr, mr);
390     if (!MP_EQ(mr, r)) {
391       fputs("\n*** montgomery modmul failed", stderr);
392       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
393       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
394       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
395       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
396       fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
397       fputc('\n', stderr);
398       ok = 0;
399     }
400     MP_DROP(ar); MP_DROP(br);
401     mp_drop(mr);
402   }
403
404
405   MP_DROP(m);
406   MP_DROP(a);
407   MP_DROP(b);
408   MP_DROP(r);
409   mpmont_destroy(&mm);
410   assert(mparena_count(MPARENA_GLOBAL) == 0);
411   return ok;
412 }
413
414 static test_chunk tests[] = {
415   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
416   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
417   { 0, 0, { 0 } },
418 };
419
420 int main(int argc, char *argv[])
421 {
422   sub_init();
423   test_run(argc, argv, tests, SRCDIR "/tests/mpmont");
424   return (0);
425 }
426
427 #endif
428
429 /*----- That's all, folks -------------------------------------------------*/