chiark / gitweb /
Merge branch 'mdw/rsvr'
[catacomb] / math / mpmont.c
1 /* -*-c-*-
2  *
3  * Montgomery reduction
4  *
5  * (c) 1999 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Catacomb.
11  *
12  * Catacomb is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU Library General Public License as
14  * published by the Free Software Foundation; either version 2 of the
15  * License, or (at your option) any later version.
16  *
17  * Catacomb is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU Library General Public License for more details.
21  *
22  * You should have received a copy of the GNU Library General Public
23  * License along with Catacomb; if not, write to the Free
24  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25  * MA 02111-1307, USA.
26  */
27
28 /*----- Header files ------------------------------------------------------*/
29
30 #include "config.h"
31 #include "dispatch.h"
32 #include "mp.h"
33 #include "mpmont.h"
34
35 /*----- Tweakables --------------------------------------------------------*/
36
37 /* --- @MPMONT_DISABLE@ --- *
38  *
39  * Replace all the clever Montgomery reduction with good old-fashioned long
40  * division.
41  */
42
43 /* #define MPMONT_DISABLE */
44
45 #define MPMONT_KTHRESH (16*MPK_THRESH)
46
47 /*----- Low-level implementation ------------------------------------------*/
48
49 #ifndef MPMONT_DISABLE
50
51 /* --- @redccore@ --- *
52  *
53  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
54  *              @const mpw *mv@ = base of modulus %$m$%
55  *              @size_t n@ = length of modulus
56  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
57  *
58  * Returns:     ---
59  *
60  * Use:         Let %$a$% be the input operand.  Store in %$d$% the value
61  *              %$a + (m' a \bmod R) m$%.  The destination has space for at
62  *              least %$2 n + 1$% words of result.
63  */
64
65 CPU_DISPATCH(static, (void), void, redccore,
66              (mpw *dv, mpw *dvl, const mpw *mv, size_t n, const mpw *mi),
67              (dv, dvl, mv, n, mi), pick_redccore, simple_redccore);
68
69 static void simple_redccore(mpw *dv, mpw *dvl, const mpw *mv,
70                             size_t n, const mpw *mi)
71 {
72   mpw mi0 = *mi;
73   size_t i;
74
75   for (i = 0; i < n; i++) {
76     MPX_UMLAN(dv, dvl, mv, mv + n, MPW(*dv*mi0));
77     dv++;
78   }
79 }
80
81 #define MAYBE_REDC4(impl)                                               \
82   extern void mpxmont_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,    \
83                                    size_t n, const mpw *mi);            \
84   static void maybe_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,      \
85                                  size_t n, const mpw *mi)               \
86   {                                                                     \
87     if (n%4) simple_redccore(dv, dvl, mv, n, mi);                       \
88     else mpxmont_redc4_##impl(dv, dvl, mv, n, mi);                      \
89   }
90
91 #if CPUFAM_X86
92   MAYBE_REDC4(x86_sse2)
93   MAYBE_REDC4(x86_avx)
94 #endif
95
96 #if CPUFAM_AMD64
97   MAYBE_REDC4(amd64_sse2)
98   MAYBE_REDC4(amd64_avx)
99 #endif
100
101 static redccore__functype *pick_redccore(void)
102 {
103 #if CPUFAM_X86
104   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_x86_avx,
105                      cpu_feature_p(CPUFEAT_X86_AVX));
106   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_x86_sse2,
107                      cpu_feature_p(CPUFEAT_X86_SSE2));
108 #endif
109 #if CPUFAM_AMD64
110   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_amd64_avx,
111                      cpu_feature_p(CPUFEAT_X86_AVX));
112   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_amd64_sse2,
113                      cpu_feature_p(CPUFEAT_X86_SSE2));
114 #endif
115   DISPATCH_PICK_FALLBACK(mpmont_reduce, simple_redccore);
116 }
117
118 /* --- @redccore@ --- *
119  *
120  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
121  *              @const mpw *av, *avl@ = base and limit of first multiplicand
122  *              @const mpw *bv, *bvl@ = base and limit of second multiplicand
123  *              @const mpw *mv@ = base of modulus %$m$%
124  *              @size_t n@ = length of modulus
125  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
126  *
127  * Returns:     ---
128  *
129  * Use:         Let %$a$% and %$b$% be the multiplicands.  Let %$w = a b$%.
130  *              Store in %$d$% the value %$a b + (m' a b \bmod R) m$%.
131  */
132
133 CPU_DISPATCH(static, (void), void, mulcore,
134              (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
135               const mpw *bv, const mpw *bvl, const mpw *mv,
136               size_t n, const mpw *mi),
137              (dv, dvl, av, avl, bv, bvl, mv, n, mi),
138              pick_mulcore, simple_mulcore);
139
140 static void simple_mulcore(mpw *dv, mpw *dvl,
141                            const mpw *av, const mpw *avl,
142                            const mpw *bv, const mpw *bvl,
143                            const mpw *mv, size_t n, const mpw *mi)
144 {
145   mpw ai, b0, y, mi0 = *mi;
146   const mpw *tv, *tvl;
147   const mpw *mvl = mv + n;
148   size_t i = 0;
149
150   /* --- Initial setup --- */
151
152   MPX_ZERO(dv, dvl);
153   if (avl - av > bvl - bv) {
154     tv = av; av = bv; bv = tv;
155     tvl = avl; avl = bvl; bvl = tvl;
156   }
157   b0 = *bv;
158
159   /* --- Multiply, until we run out of multiplicand --- */
160
161   while (i < n && av < avl) {
162     ai = *av++;
163     y = MPW((*dv + ai*b0)*mi0);
164     MPX_UMLAN(dv, dvl, bv, bvl, ai);
165     MPX_UMLAN(dv, dvl, mv, mvl, y);
166     dv++; i++;
167   }
168
169   /* --- Continue reducing until we run out of modulus --- */
170
171   while (i < n) {
172     y = MPW(*dv*mi0);
173     MPX_UMLAN(dv, dvl, mv, mvl, y);
174     dv++; i++;
175   }
176 }
177
178 #define MAYBE_MUL4(impl)                                                \
179   extern void mpxmont_mul4_##impl(mpw *dv,                              \
180                                   const mpw *av, const mpw *bv,         \
181                                   const mpw *mv,                        \
182                                   size_t n, const mpw *mi);             \
183   static void maybe_mul4_##impl(mpw *dv, mpw *dvl,                      \
184                            const mpw *av, const mpw *avl,               \
185                            const mpw *bv, const mpw *bvl,               \
186                            const mpw *mv, size_t n, const mpw *mi)      \
187   {                                                                     \
188     size_t an = avl - av, bn = bvl - bv;                                \
189     if (n%4 || an != n || bn != n)                                      \
190       simple_mulcore(dv, dvl, av, avl, bv, bvl, mv, n, mi);             \
191     else {                                                              \
192       mpxmont_mul4_##impl(dv, av, bv, mv, n, mi);                       \
193       MPX_ZERO(dv + 2*n + 1, dvl);                                      \
194     }                                                                   \
195   }
196
197 #if CPUFAM_X86
198   MAYBE_MUL4(x86_sse2)
199   MAYBE_MUL4(x86_avx)
200 #endif
201
202 #if CPUFAM_AMD64
203   MAYBE_MUL4(amd64_sse2)
204   MAYBE_MUL4(amd64_avx)
205 #endif
206
207 static mulcore__functype *pick_mulcore(void)
208 {
209 #if CPUFAM_X86
210   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_x86_avx,
211                      cpu_feature_p(CPUFEAT_X86_AVX));
212   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_x86_sse2,
213                      cpu_feature_p(CPUFEAT_X86_SSE2));
214 #endif
215 #if CPUFAM_AMD64
216   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_amd64_avx,
217                      cpu_feature_p(CPUFEAT_X86_AVX));
218   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_amd64_sse2,
219                      cpu_feature_p(CPUFEAT_X86_SSE2));
220 #endif
221   DISPATCH_PICK_FALLBACK(mpmont_mul, simple_mulcore);
222 }
223
224 /* --- @finish@ --- *
225  *
226  * Arguments:   @const mpmont *mm@ = pointer to a Montgomery reduction
227  *                      context
228  *              *mp *d@ = pointer to mostly-reduced operand
229  *
230  * Returns:     ---
231  *
232  * Use:         Applies the finishing touches to Montgomery reduction.  The
233  *              operand @d@ is a multiple of %$R%$ at this point, so it needs
234  *              to be shifted down; the result might need a further
235  *              subtraction to get it into the right interval; and we may
236  *              need to do an additional subtraction if %$d$% is negative.
237  */
238
239 static void finish(const mpmont *mm, mp *d)
240 {
241   mpw *dv = d->v, *dvl = d->vl;
242   size_t n = mm->n;
243
244   memmove(dv, dv + n, MPWS(dvl - (dv + n)));
245   dvl -= n;
246
247   if (MPX_UCMP(dv, dvl, >=, mm->m->v, mm->m->vl))
248     mpx_usub(dv, dvl, dv, dvl, mm->m->v, mm->m->vl);
249
250   if (d->f & MP_NEG) {
251     mpx_usub(dv, dvl, mm->m->v, mm->m->vl, dv, dvl);
252     d->f &= ~MP_NEG;
253   }
254
255   d->vl = dvl;
256   MP_SHRINK(d);
257 }
258
259 #endif
260
261 /*----- Reduction and multiplication --------------------------------------*/
262
263 /* --- @mpmont_create@ --- *
264  *
265  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
266  *              @mp *m@ = modulus to use
267  *
268  * Returns:     Zero on success, nonzero on error.
269  *
270  * Use:         Initializes a Montgomery reduction context ready for use.
271  *              The argument @m@ must be a positive odd integer.
272  */
273
274 #ifdef MPMONT_DISABLE
275
276 int mpmont_create(mpmont *mm, mp *m)
277 {
278   mp_shrink(m);
279   mm->m = MP_COPY(m);
280   mm->r = MP_ONE;
281   mm->r2 = MP_ONE;
282   mm->mi = MP_ONE;
283   return (0);
284 }
285
286 #else
287
288 int mpmont_create(mpmont *mm, mp *m)
289 {
290   size_t n = MP_LEN(m);
291   mp *r2 = mp_new(2 * n + 1, 0);
292   mp r;
293
294   /* --- Take a copy of the modulus --- */
295
296  if (!MP_POSP(m) || !MP_ODDP(m))
297    return (-1);
298   mm->m = MP_COPY(m);
299
300   /* --- Determine %$R^2$% --- */
301
302   mm->n = n;
303   MPX_ZERO(r2->v, r2->vl - 1);
304   r2->vl[-1] = 1;
305
306   /* --- Find the magic value @mi@ --- */
307
308   mp_build(&r, r2->v + n, r2->vl);
309   mm->mi = mp_modinv(MP_NEW, m, &r);
310   mm->mi = mp_sub(mm->mi, &r, mm->mi);
311   MP_ENSURE(mm->mi, n);
312
313   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
314
315   mm->r2 = MP_NEW;
316   mp_div(0, &mm->r2, r2, m);
317   mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
318   MP_DROP(r2);
319   return (0);
320 }
321
322 #endif
323
324 /* --- @mpmont_destroy@ --- *
325  *
326  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
327  *
328  * Returns:     ---
329  *
330  * Use:         Disposes of a context when it's no longer of any use to
331  *              anyone.
332  */
333
334 void mpmont_destroy(mpmont *mm)
335 {
336   MP_DROP(mm->m);
337   MP_DROP(mm->r);
338   MP_DROP(mm->r2);
339   MP_DROP(mm->mi);
340 }
341
342 /* --- @mpmont_reduce@ --- *
343  *
344  * Arguments:   @const mpmont *mm@ = pointer to Montgomery reduction context
345  *              @mp *d@ = destination
346  *              @mp *a@ = source, assumed positive
347  *
348  * Returns:     Result, %$a R^{-1} \bmod m$%.
349  */
350
351 #ifdef MPMONT_DISABLE
352
353 mp *mpmont_reduce(const mpmont *mm, mp *d, mp *a)
354 {
355   mp_div(0, &d, a, mm->m);
356   return (d);
357 }
358
359 #else
360
361 mp *mpmont_reduce(const mpmont *mm, mp *d, mp *a)
362 {
363   size_t n = mm->n;
364
365   /* --- Check for serious Karatsuba reduction --- */
366
367   if (n > MPMONT_KTHRESH) {
368     mp al;
369     mpw *vl;
370     mp *u;
371
372     if (MP_LEN(a) >= n) vl = a->v + n;
373     else vl = a->vl;
374     mp_build(&al, a->v, vl);
375     u = mp_mul(MP_NEW, &al, mm->mi);
376     if (MP_LEN(u) > n) u->vl = u->v + n;
377     u = mp_mul(u, u, mm->m);
378     d = mp_add(d, a, u);
379     MP_ENSURE(d, n);
380     mp_drop(u);
381   }
382
383   /* --- Otherwise do it the hard way --- */
384
385   else {
386     a = MP_COPY(a);
387     if (d) MP_DROP(d);
388     d = a;
389     MP_DEST(d, 2*mm->n + 1, a->f);
390     redccore(d->v, d->vl, mm->m->v, mm->n, mm->mi->v);
391   }
392
393   /* --- Wrap everything up --- */
394
395   finish(mm, d);
396   return (d);
397 }
398
399 #endif
400
401 /* --- @mpmont_mul@ --- *
402  *
403  * Arguments:   @const mpmont *mm@ = pointer to Montgomery reduction context
404  *              @mp *d@ = destination
405  *              @mp *a, *b@ = sources, assumed positive
406  *
407  * Returns:     Result, %$a b R^{-1} \bmod m$%.
408  */
409
410 #ifdef MPMONT_DISABLE
411
412 mp *mpmont_mul(const mpmont *mm, mp *d, mp *a, mp *b)
413 {
414   d = mp_mul(d, a, b);
415   mp_div(0, &d, d, mm->m);
416   return (d);
417 }
418
419 #else
420
421 mp *mpmont_mul(const mpmont *mm, mp *d, mp *a, mp *b)
422 {
423   size_t n = mm->n;
424
425   if (n > MPMONT_KTHRESH) {
426     d = mp_mul(d, a, b);
427     d = mpmont_reduce(mm, d, d);
428   } else {
429     a = MP_COPY(a); b = MP_COPY(b);
430     MP_DEST(d, 2*n + 1, a->f | b->f | MP_UNDEF);
431     mulcore(d->v, d->vl, a->v, a->vl, b->v, b->vl,
432             mm->m->v, mm->n, mm->mi->v);
433     d->f = ((a->f | b->f) & MP_BURN) | ((a->f ^ b->f) & MP_NEG);
434     finish(mm, d);
435     MP_DROP(a); MP_DROP(b);
436   }
437
438   return (d);
439 }
440
441 #endif
442
443 /*----- Test rig ----------------------------------------------------------*/
444
445 #ifdef TEST_RIG
446
447 static int tcreate(dstr *v)
448 {
449   mp *m = *(mp **)v[0].buf;
450   mp *mi = *(mp **)v[1].buf;
451   mp *r = *(mp **)v[2].buf;
452   mp *r2 = *(mp **)v[3].buf;
453
454   mpmont mm;
455   int ok = 1;
456
457   mpmont_create(&mm, m);
458
459   if (mm.mi->v[0] != mi->v[0]) {
460     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
461             (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
462     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
463     fputc('\n', stderr);
464     ok = 0;
465   }
466
467   if (!MP_EQ(mm.r, r)) {
468     fputs("\n*** bad r", stderr);
469     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
470     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
471     fputs("\n   found ", stderr); mp_writefile(mm.r, stderr, 10);
472     fputc('\n', stderr);
473     ok = 0;
474   }
475
476   if (!MP_EQ(mm.r2, r2)) {
477     fputs("\n*** bad r2", stderr);
478     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
479     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
480     fputs("\n   found ", stderr); mp_writefile(mm.r2, stderr, 10);
481     fputc('\n', stderr);
482     ok = 0;
483   }
484
485   MP_DROP(m);
486   MP_DROP(mi);
487   MP_DROP(r);
488   MP_DROP(r2);
489   mpmont_destroy(&mm);
490   assert(mparena_count(MPARENA_GLOBAL) == 0);
491   return (ok);
492 }
493
494 static int tmul(dstr *v)
495 {
496   mp *m = *(mp **)v[0].buf;
497   mp *a = *(mp **)v[1].buf;
498   mp *b = *(mp **)v[2].buf;
499   mp *r = *(mp **)v[3].buf;
500   int ok = 1;
501
502   mpmont mm;
503   mpmont_create(&mm, m);
504
505   {
506     mp *qr = mp_mul(MP_NEW, a, b);
507     mp_div(0, &qr, qr, m);
508
509     if (!MP_EQ(qr, r)) {
510       fputs("\n*** classical modmul failed", stderr);
511       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
512       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
513       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
514       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
515       fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
516       fputc('\n', stderr);
517       ok = 0;
518     }
519
520     mp_drop(qr);
521   }
522
523   {
524     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
525     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
526     mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
527     mr = mpmont_reduce(&mm, mr, mr);
528     if (!MP_EQ(mr, r)) {
529       fputs("\n*** montgomery modmul failed", stderr);
530       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
531       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
532       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
533       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
534       fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
535       fputc('\n', stderr);
536       ok = 0;
537     }
538     MP_DROP(ar); MP_DROP(br);
539     mp_drop(mr);
540   }
541
542
543   MP_DROP(m);
544   MP_DROP(a);
545   MP_DROP(b);
546   MP_DROP(r);
547   mpmont_destroy(&mm);
548   assert(mparena_count(MPARENA_GLOBAL) == 0);
549   return ok;
550 }
551
552 static test_chunk tests[] = {
553   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
554   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
555   { 0, 0, { 0 } },
556 };
557
558 int main(int argc, char *argv[])
559 {
560   sub_init();
561   test_run(argc, argv, tests, SRCDIR "/t/mpmont");
562   return (0);
563 }
564
565 #endif
566
567 /*----- That's all, folks -------------------------------------------------*/