chiark / gitweb /
f8a26119bf03492f8d0afd8281f1b11d29d5571d
[catacomb] / math / mpmont.c
1 /* -*-c-*-
2  *
3  * Montgomery reduction
4  *
5  * (c) 1999 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Catacomb.
11  *
12  * Catacomb is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU Library General Public License as
14  * published by the Free Software Foundation; either version 2 of the
15  * License, or (at your option) any later version.
16  *
17  * Catacomb is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU Library General Public License for more details.
21  *
22  * You should have received a copy of the GNU Library General Public
23  * License along with Catacomb; if not, write to the Free
24  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25  * MA 02111-1307, USA.
26  */
27
28 /*----- Header files ------------------------------------------------------*/
29
30 #include "config.h"
31 #include "dispatch.h"
32 #include "mp.h"
33 #include "mpmont.h"
34
35 /*----- Tweakables --------------------------------------------------------*/
36
37 /* --- @MPMONT_DISABLE@ --- *
38  *
39  * Replace all the clever Montgomery reduction with good old-fashioned long
40  * division.
41  */
42
43 /* #define MPMONT_DISABLE */
44
45 #define MPMONT_KTHRESH (16*MPK_THRESH)
46
47 /*----- Low-level implementation ------------------------------------------*/
48
49 #ifndef MPMONT_DISABLE
50
51 /* --- @redccore@ --- *
52  *
53  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
54  *              @const mpw *mv@ = base of modulus %$m$%
55  *              @size_t n@ = length of modulus
56  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
57  *
58  * Returns:     ---
59  *
60  * Use:         Let %$a$% be the input operand.  Store in %$d$% the value
61  *              %$a + (m' a \bmod R) m$%.  The destination has space for at
62  *              least %$2 n + 1$% words of result.
63  */
64
65 CPU_DISPATCH(static, (void), void, redccore,
66              (mpw *dv, mpw *dvl, const mpw *mv, size_t n, const mpw *mi),
67              (dv, dvl, mv, n, mi), pick_redccore, simple_redccore);
68
69 static void simple_redccore(mpw *dv, mpw *dvl, const mpw *mv,
70                             size_t n, const mpw *mi)
71 {
72   mpw mi0 = *mi;
73   size_t i;
74
75   for (i = 0; i < n; i++) {
76     MPX_UMLAN(dv, dvl, mv, mv + n, MPW(*dv*mi0));
77     dv++;
78   }
79 }
80
81 #define MAYBE_REDC4(impl)                                               \
82   extern void mpxmont_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,    \
83                                    size_t n, const mpw *mi);            \
84   static void maybe_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,      \
85                                  size_t n, const mpw *mi)               \
86   {                                                                     \
87     if (n%4) simple_redccore(dv, dvl, mv, n, mi);                       \
88     else mpxmont_redc4_##impl(dv, dvl, mv, n, mi);                      \
89   }
90
91 #if CPUFAM_X86
92   MAYBE_REDC4(x86_sse2)
93 #endif
94
95 #if CPUFAM_AMD64
96   MAYBE_REDC4(amd64_sse2)
97 #endif
98
99 static redccore__functype *pick_redccore(void)
100 {
101 #if CPUFAM_X86
102   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_x86_sse2,
103                      cpu_feature_p(CPUFEAT_X86_SSE2));
104 #endif
105 #if CPUFAM_AMD64
106   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_amd64_sse2,
107                      cpu_feature_p(CPUFEAT_X86_SSE2));
108 #endif
109   DISPATCH_PICK_FALLBACK(mpmont_reduce, simple_redccore);
110 }
111
112 /* --- @redccore@ --- *
113  *
114  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
115  *              @const mpw *av, *avl@ = base and limit of first multiplicand
116  *              @const mpw *bv, *bvl@ = base and limit of second multiplicand
117  *              @const mpw *mv@ = base of modulus %$m$%
118  *              @size_t n@ = length of modulus
119  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
120  *
121  * Returns:     ---
122  *
123  * Use:         Let %$a$% and %$b$% be the multiplicands.  Let %$w = a b$%.
124  *              Store in %$d$% the value %$a b + (m' a b \bmod R) m$%.
125  */
126
127 CPU_DISPATCH(static, (void), void, mulcore,
128              (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
129               const mpw *bv, const mpw *bvl, const mpw *mv,
130               size_t n, const mpw *mi),
131              (dv, dvl, av, avl, bv, bvl, mv, n, mi),
132              pick_mulcore, simple_mulcore);
133
134 static void simple_mulcore(mpw *dv, mpw *dvl,
135                            const mpw *av, const mpw *avl,
136                            const mpw *bv, const mpw *bvl,
137                            const mpw *mv, size_t n, const mpw *mi)
138 {
139   mpw ai, b0, y, mi0 = *mi;
140   const mpw *tv, *tvl;
141   const mpw *mvl = mv + n;
142   size_t i = 0;
143
144   /* --- Initial setup --- */
145
146   MPX_ZERO(dv, dvl);
147   if (avl - av > bvl - bv) {
148     tv = av; av = bv; bv = tv;
149     tvl = avl; avl = bvl; bvl = tvl;
150   }
151   b0 = *bv;
152
153   /* --- Multiply, until we run out of multiplicand --- */
154
155   while (i < n && av < avl) {
156     ai = *av++;
157     y = MPW((*dv + ai*b0)*mi0);
158     MPX_UMLAN(dv, dvl, bv, bvl, ai);
159     MPX_UMLAN(dv, dvl, mv, mvl, y);
160     dv++; i++;
161   }
162
163   /* --- Continue reducing until we run out of modulus --- */
164
165   while (i < n) {
166     y = MPW(*dv*mi0);
167     MPX_UMLAN(dv, dvl, mv, mvl, y);
168     dv++; i++;
169   }
170 }
171
172 #define MAYBE_MUL4(impl)                                                \
173   extern void mpxmont_mul4_##impl(mpw *dv,                              \
174                                   const mpw *av, const mpw *bv,         \
175                                   const mpw *mv,                        \
176                                   size_t n, const mpw *mi);             \
177   static void maybe_mul4_##impl(mpw *dv, mpw *dvl,                      \
178                            const mpw *av, const mpw *avl,               \
179                            const mpw *bv, const mpw *bvl,               \
180                            const mpw *mv, size_t n, const mpw *mi)      \
181   {                                                                     \
182     size_t an = avl - av, bn = bvl - bv;                                \
183     if (n%4 || an != n || bn != n)                                      \
184       simple_mulcore(dv, dvl, av, avl, bv, bvl, mv, n, mi);             \
185     else {                                                              \
186       mpxmont_mul4_##impl(dv, av, bv, mv, n, mi);                       \
187       MPX_ZERO(dv + 2*n + 1, dvl);                                      \
188     }                                                                   \
189   }
190
191 #if CPUFAM_X86
192   MAYBE_MUL4(x86_sse2)
193 #endif
194
195 #if CPUFAM_AMD64
196   MAYBE_MUL4(amd64_sse2)
197 #endif
198
199 static mulcore__functype *pick_mulcore(void)
200 {
201 #if CPUFAM_X86
202   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_x86_sse2,
203                      cpu_feature_p(CPUFEAT_X86_SSE2));
204 #endif
205 #if CPUFAM_AMD64
206   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_amd64_sse2,
207                      cpu_feature_p(CPUFEAT_X86_SSE2));
208 #endif
209   DISPATCH_PICK_FALLBACK(mpmont_mul, simple_mulcore);
210 }
211
212 /* --- @finish@ --- *
213  *
214  * Arguments:   @const mpmont *mm@ = pointer to a Montgomery reduction
215  *                      context
216  *              *mp *d@ = pointer to mostly-reduced operand
217  *
218  * Returns:     ---
219  *
220  * Use:         Applies the finishing touches to Montgomery reduction.  The
221  *              operand @d@ is a multiple of %$R%$ at this point, so it needs
222  *              to be shifted down; the result might need a further
223  *              subtraction to get it into the right interval; and we may
224  *              need to do an additional subtraction if %$d$% is negative.
225  */
226
227 static void finish(const mpmont *mm, mp *d)
228 {
229   mpw *dv = d->v, *dvl = d->vl;
230   size_t n = mm->n;
231
232   memmove(dv, dv + n, MPWS(dvl - (dv + n)));
233   dvl -= n;
234
235   if (MPX_UCMP(dv, dvl, >=, mm->m->v, mm->m->vl))
236     mpx_usub(dv, dvl, dv, dvl, mm->m->v, mm->m->vl);
237
238   if (d->f & MP_NEG) {
239     mpx_usub(dv, dvl, mm->m->v, mm->m->vl, dv, dvl);
240     d->f &= ~MP_NEG;
241   }
242
243   d->vl = dvl;
244   MP_SHRINK(d);
245 }
246
247 #endif
248
249 /*----- Reduction and multiplication --------------------------------------*/
250
251 /* --- @mpmont_create@ --- *
252  *
253  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
254  *              @mp *m@ = modulus to use
255  *
256  * Returns:     Zero on success, nonzero on error.
257  *
258  * Use:         Initializes a Montgomery reduction context ready for use.
259  *              The argument @m@ must be a positive odd integer.
260  */
261
262 #ifdef MPMONT_DISABLE
263
264 int mpmont_create(mpmont *mm, mp *m)
265 {
266   mp_shrink(m);
267   mm->m = MP_COPY(m);
268   mm->r = MP_ONE;
269   mm->r2 = MP_ONE;
270   mm->mi = MP_ONE;
271   return (0);
272 }
273
274 #else
275
276 int mpmont_create(mpmont *mm, mp *m)
277 {
278   size_t n = MP_LEN(m);
279   mp *r2 = mp_new(2 * n + 1, 0);
280   mp r;
281
282   /* --- Take a copy of the modulus --- */
283
284  if (!MP_POSP(m) || !MP_ODDP(m))
285    return (-1);
286   mm->m = MP_COPY(m);
287
288   /* --- Determine %$R^2$% --- */
289
290   mm->n = n;
291   MPX_ZERO(r2->v, r2->vl - 1);
292   r2->vl[-1] = 1;
293
294   /* --- Find the magic value @mi@ --- */
295
296   mp_build(&r, r2->v + n, r2->vl);
297   mm->mi = mp_modinv(MP_NEW, m, &r);
298   mm->mi = mp_sub(mm->mi, &r, mm->mi);
299   MP_ENSURE(mm->mi, n);
300
301   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
302
303   mm->r2 = MP_NEW;
304   mp_div(0, &mm->r2, r2, m);
305   mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
306   MP_DROP(r2);
307   return (0);
308 }
309
310 #endif
311
312 /* --- @mpmont_destroy@ --- *
313  *
314  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
315  *
316  * Returns:     ---
317  *
318  * Use:         Disposes of a context when it's no longer of any use to
319  *              anyone.
320  */
321
322 void mpmont_destroy(mpmont *mm)
323 {
324   MP_DROP(mm->m);
325   MP_DROP(mm->r);
326   MP_DROP(mm->r2);
327   MP_DROP(mm->mi);
328 }
329
330 /* --- @mpmont_reduce@ --- *
331  *
332  * Arguments:   @const mpmont *mm@ = pointer to Montgomery reduction context
333  *              @mp *d@ = destination
334  *              @mp *a@ = source, assumed positive
335  *
336  * Returns:     Result, %$a R^{-1} \bmod m$%.
337  */
338
339 #ifdef MPMONT_DISABLE
340
341 mp *mpmont_reduce(const mpmont *mm, mp *d, mp *a)
342 {
343   mp_div(0, &d, a, mm->m);
344   return (d);
345 }
346
347 #else
348
349 mp *mpmont_reduce(const mpmont *mm, mp *d, mp *a)
350 {
351   size_t n = mm->n;
352
353   /* --- Check for serious Karatsuba reduction --- */
354
355   if (n > MPMONT_KTHRESH) {
356     mp al;
357     mpw *vl;
358     mp *u;
359
360     if (MP_LEN(a) >= n) vl = a->v + n;
361     else vl = a->vl;
362     mp_build(&al, a->v, vl);
363     u = mp_mul(MP_NEW, &al, mm->mi);
364     if (MP_LEN(u) > n) u->vl = u->v + n;
365     u = mp_mul(u, u, mm->m);
366     d = mp_add(d, a, u);
367     MP_ENSURE(d, n);
368     mp_drop(u);
369   }
370
371   /* --- Otherwise do it the hard way --- */
372
373   else {
374     a = MP_COPY(a);
375     if (d) MP_DROP(d);
376     d = a;
377     MP_DEST(d, 2*mm->n + 1, a->f);
378     redccore(d->v, d->vl, mm->m->v, mm->n, mm->mi->v);
379   }
380
381   /* --- Wrap everything up --- */
382
383   finish(mm, d);
384   return (d);
385 }
386
387 #endif
388
389 /* --- @mpmont_mul@ --- *
390  *
391  * Arguments:   @const mpmont *mm@ = pointer to Montgomery reduction context
392  *              @mp *d@ = destination
393  *              @mp *a, *b@ = sources, assumed positive
394  *
395  * Returns:     Result, %$a b R^{-1} \bmod m$%.
396  */
397
398 #ifdef MPMONT_DISABLE
399
400 mp *mpmont_mul(const mpmont *mm, mp *d, mp *a, mp *b)
401 {
402   d = mp_mul(d, a, b);
403   mp_div(0, &d, d, mm->m);
404   return (d);
405 }
406
407 #else
408
409 mp *mpmont_mul(const mpmont *mm, mp *d, mp *a, mp *b)
410 {
411   size_t n = mm->n;
412
413   if (n > MPMONT_KTHRESH) {
414     d = mp_mul(d, a, b);
415     d = mpmont_reduce(mm, d, d);
416   } else {
417     a = MP_COPY(a); b = MP_COPY(b);
418     MP_DEST(d, 2*n + 1, a->f | b->f | MP_UNDEF);
419     mulcore(d->v, d->vl, a->v, a->vl, b->v, b->vl,
420             mm->m->v, mm->n, mm->mi->v);
421     d->f = ((a->f | b->f) & MP_BURN) | ((a->f ^ b->f) & MP_NEG);
422     finish(mm, d);
423     MP_DROP(a); MP_DROP(b);
424   }
425
426   return (d);
427 }
428
429 #endif
430
431 /*----- Test rig ----------------------------------------------------------*/
432
433 #ifdef TEST_RIG
434
435 static int tcreate(dstr *v)
436 {
437   mp *m = *(mp **)v[0].buf;
438   mp *mi = *(mp **)v[1].buf;
439   mp *r = *(mp **)v[2].buf;
440   mp *r2 = *(mp **)v[3].buf;
441
442   mpmont mm;
443   int ok = 1;
444
445   mpmont_create(&mm, m);
446
447   if (mm.mi->v[0] != mi->v[0]) {
448     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
449             (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
450     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
451     fputc('\n', stderr);
452     ok = 0;
453   }
454
455   if (!MP_EQ(mm.r, r)) {
456     fputs("\n*** bad r", stderr);
457     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
458     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
459     fputs("\n   found ", stderr); mp_writefile(mm.r, stderr, 10);
460     fputc('\n', stderr);
461     ok = 0;
462   }
463
464   if (!MP_EQ(mm.r2, r2)) {
465     fputs("\n*** bad r2", stderr);
466     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
467     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
468     fputs("\n   found ", stderr); mp_writefile(mm.r2, stderr, 10);
469     fputc('\n', stderr);
470     ok = 0;
471   }
472
473   MP_DROP(m);
474   MP_DROP(mi);
475   MP_DROP(r);
476   MP_DROP(r2);
477   mpmont_destroy(&mm);
478   assert(mparena_count(MPARENA_GLOBAL) == 0);
479   return (ok);
480 }
481
482 static int tmul(dstr *v)
483 {
484   mp *m = *(mp **)v[0].buf;
485   mp *a = *(mp **)v[1].buf;
486   mp *b = *(mp **)v[2].buf;
487   mp *r = *(mp **)v[3].buf;
488   int ok = 1;
489
490   mpmont mm;
491   mpmont_create(&mm, m);
492
493   {
494     mp *qr = mp_mul(MP_NEW, a, b);
495     mp_div(0, &qr, qr, m);
496
497     if (!MP_EQ(qr, r)) {
498       fputs("\n*** classical modmul failed", stderr);
499       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
500       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
501       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
502       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
503       fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
504       fputc('\n', stderr);
505       ok = 0;
506     }
507
508     mp_drop(qr);
509   }
510
511   {
512     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
513     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
514     mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
515     mr = mpmont_reduce(&mm, mr, mr);
516     if (!MP_EQ(mr, r)) {
517       fputs("\n*** montgomery modmul failed", stderr);
518       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
519       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
520       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
521       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
522       fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
523       fputc('\n', stderr);
524       ok = 0;
525     }
526     MP_DROP(ar); MP_DROP(br);
527     mp_drop(mr);
528   }
529
530
531   MP_DROP(m);
532   MP_DROP(a);
533   MP_DROP(b);
534   MP_DROP(r);
535   mpmont_destroy(&mm);
536   assert(mparena_count(MPARENA_GLOBAL) == 0);
537   return ok;
538 }
539
540 static test_chunk tests[] = {
541   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
542   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
543   { 0, 0, { 0 } },
544 };
545
546 int main(int argc, char *argv[])
547 {
548   sub_init();
549   test_run(argc, argv, tests, SRCDIR "/t/mpmont");
550   return (0);
551 }
552
553 #endif
554
555 /*----- That's all, folks -------------------------------------------------*/