chiark / gitweb /
math/mpx-mul4-x86-sse2.S: Slightly reorder to reduce dependence.
[catacomb] / math / mpmont.c
1 /* -*-c-*-
2  *
3  * Montgomery reduction
4  *
5  * (c) 1999 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of Catacomb.
11  *
12  * Catacomb is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU Library General Public License as
14  * published by the Free Software Foundation; either version 2 of the
15  * License, or (at your option) any later version.
16  *
17  * Catacomb is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU Library General Public License for more details.
21  *
22  * You should have received a copy of the GNU Library General Public
23  * License along with Catacomb; if not, write to the Free
24  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
25  * MA 02111-1307, USA.
26  */
27
28 /*----- Header files ------------------------------------------------------*/
29
30 #include "config.h"
31 #include "dispatch.h"
32 #include "mp.h"
33 #include "mpmont.h"
34
35 /*----- Tweakables --------------------------------------------------------*/
36
37 /* --- @MPMONT_DISABLE@ --- *
38  *
39  * Replace all the clever Montgomery reduction with good old-fashioned long
40  * division.
41  */
42
43 /* #define MPMONT_DISABLE */
44
45 #define MPMONT_KTHRESH (16*MPK_THRESH)
46
47 /*----- Low-level implementation ------------------------------------------*/
48
49 #ifndef MPMONT_DISABLE
50
51 /* --- @redccore@ --- *
52  *
53  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
54  *              @const mpw *mv@ = base of modulus %$m$%
55  *              @size_t n@ = length of modulus
56  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
57  *
58  * Returns:     ---
59  *
60  * Use:         Let %$a$% be the input operand.  Store in %$d$% the value
61  *              %$a + (m' a \bmod R) m$%.  The destination has space for at
62  *              least %$2 n + 1$% words of result.
63  */
64
65 CPU_DISPATCH(static, (void), void, redccore,
66              (mpw *dv, mpw *dvl, const mpw *mv, size_t n, const mpw *mi),
67              (dv, dvl, mv, n, mi), pick_redccore, simple_redccore);
68
69 static void simple_redccore(mpw *dv, mpw *dvl, const mpw *mv,
70                             size_t n, const mpw *mi)
71 {
72   mpw mi0 = *mi;
73   size_t i;
74
75   for (i = 0; i < n; i++) {
76     MPX_UMLAN(dv, dvl, mv, mv + n, MPW(*dv*mi0));
77     dv++;
78   }
79 }
80
81 #define MAYBE_REDC4(impl)                                               \
82   extern void mpxmont_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,    \
83                                    size_t n, const mpw *mi);            \
84   static void maybe_redc4_##impl(mpw *dv, mpw *dvl, const mpw *mv,      \
85                                  size_t n, const mpw *mi)               \
86   {                                                                     \
87     if (n%4) simple_redccore(dv, dvl, mv, n, mi);                       \
88     else mpxmont_redc4_##impl(dv, dvl, mv, n, mi);                      \
89   }
90
91 #if CPUFAM_X86
92   MAYBE_REDC4(x86_sse2)
93 #endif
94
95 static redccore__functype *pick_redccore(void)
96 {
97 #if CPUFAM_X86
98   DISPATCH_PICK_COND(mpmont_reduce, maybe_redc4_x86_sse2,
99                      cpu_feature_p(CPUFEAT_X86_SSE2));
100 #endif
101   DISPATCH_PICK_FALLBACK(mpmont_reduce, simple_redccore);
102 }
103
104 /* --- @redccore@ --- *
105  *
106  * Arguments:   @mpw *dv, *dvl@ = base and limit of source/destination
107  *              @const mpw *av, *avl@ = base and limit of first multiplicand
108  *              @const mpw *bv, *bvl@ = base and limit of second multiplicand
109  *              @const mpw *mv@ = base of modulus %$m$%
110  *              @size_t n@ = length of modulus
111  *              @const mpw *mi@ = base of REDC coefficient %$m'$%
112  *
113  * Returns:     ---
114  *
115  * Use:         Let %$a$% and %$b$% be the multiplicands.  Let %$w = a b$%.
116  *              Store in %$d$% the value %$a b + (m' a b \bmod R) m$%.
117  */
118
119 CPU_DISPATCH(static, (void), void, mulcore,
120              (mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
121               const mpw *bv, const mpw *bvl, const mpw *mv,
122               size_t n, const mpw *mi),
123              (dv, dvl, av, avl, bv, bvl, mv, n, mi),
124              pick_mulcore, simple_mulcore);
125
126 static void simple_mulcore(mpw *dv, mpw *dvl,
127                            const mpw *av, const mpw *avl,
128                            const mpw *bv, const mpw *bvl,
129                            const mpw *mv, size_t n, const mpw *mi)
130 {
131   mpw ai, b0, y, mi0 = *mi;
132   const mpw *tv, *tvl;
133   const mpw *mvl = mv + n;
134   size_t i = 0;
135
136   /* --- Initial setup --- */
137
138   MPX_ZERO(dv, dvl);
139   if (avl - av > bvl - bv) {
140     tv = av; av = bv; bv = tv;
141     tvl = avl; avl = bvl; bvl = tvl;
142   }
143   b0 = *bv;
144
145   /* --- Multiply, until we run out of multiplicand --- */
146
147   while (i < n && av < avl) {
148     ai = *av++;
149     y = MPW((*dv + ai*b0)*mi0);
150     MPX_UMLAN(dv, dvl, bv, bvl, ai);
151     MPX_UMLAN(dv, dvl, mv, mvl, y);
152     dv++; i++;
153   }
154
155   /* --- Continue reducing until we run out of modulus --- */
156
157   while (i < n) {
158     y = MPW(*dv*mi0);
159     MPX_UMLAN(dv, dvl, mv, mvl, y);
160     dv++; i++;
161   }
162 }
163
164 #define MAYBE_MUL4(impl)                                                \
165   extern void mpxmont_mul4_##impl(mpw *dv,                              \
166                                   const mpw *av, const mpw *bv,         \
167                                   const mpw *mv,                        \
168                                   size_t n, const mpw *mi);             \
169   static void maybe_mul4_##impl(mpw *dv, mpw *dvl,                      \
170                            const mpw *av, const mpw *avl,               \
171                            const mpw *bv, const mpw *bvl,               \
172                            const mpw *mv, size_t n, const mpw *mi)      \
173   {                                                                     \
174     size_t an = avl - av, bn = bvl - bv;                                \
175     if (n%4 || an != n || bn != n)                                      \
176       simple_mulcore(dv, dvl, av, avl, bv, bvl, mv, n, mi);             \
177     else {                                                              \
178       mpxmont_mul4_##impl(dv, av, bv, mv, n, mi);                       \
179       MPX_ZERO(dv + 2*n + 1, dvl);                                      \
180     }                                                                   \
181   }
182
183 #if CPUFAM_X86
184   MAYBE_MUL4(x86_sse2)
185 #endif
186
187 static mulcore__functype *pick_mulcore(void)
188 {
189 #if CPUFAM_X86
190   DISPATCH_PICK_COND(mpmont_mul, maybe_mul4_x86_sse2,
191                      cpu_feature_p(CPUFEAT_X86_SSE2));
192 #endif
193   DISPATCH_PICK_FALLBACK(mpmont_mul, simple_mulcore);
194 }
195
196 /* --- @finish@ --- *
197  *
198  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
199  *              *mp *d@ = pointer to mostly-reduced operand
200  *
201  * Returns:     ---
202  *
203  * Use:         Applies the finishing touches to Montgomery reduction.  The
204  *              operand @d@ is a multiple of %$R%$ at this point, so it needs
205  *              to be shifted down; the result might need a further
206  *              subtraction to get it into the right interval; and we may
207  *              need to do an additional subtraction if %$d$% is negative.
208  */
209
210 static void finish(mpmont *mm, mp *d)
211 {
212   mpw *dv = d->v, *dvl = d->vl;
213   size_t n = mm->n;
214
215   memmove(dv, dv + n, MPWS(dvl - (dv + n)));
216   dvl -= n;
217
218   if (MPX_UCMP(dv, dvl, >=, mm->m->v, mm->m->vl))
219     mpx_usub(dv, dvl, dv, dvl, mm->m->v, mm->m->vl);
220
221   if (d->f & MP_NEG) {
222     mpx_usub(dv, dvl, mm->m->v, mm->m->vl, dv, dvl);
223     d->f &= ~MP_NEG;
224   }
225
226   d->vl = dvl;
227   MP_SHRINK(d);
228 }
229
230 #endif
231
232 /*----- Reduction and multiplication --------------------------------------*/
233
234 /* --- @mpmont_create@ --- *
235  *
236  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
237  *              @mp *m@ = modulus to use
238  *
239  * Returns:     Zero on success, nonzero on error.
240  *
241  * Use:         Initializes a Montgomery reduction context ready for use.
242  *              The argument @m@ must be a positive odd integer.
243  */
244
245 #ifdef MPMONT_DISABLE
246
247 int mpmont_create(mpmont *mm, mp *m)
248 {
249   mp_shrink(m);
250   mm->m = MP_COPY(m);
251   mm->r = MP_ONE;
252   mm->r2 = MP_ONE;
253   mm->mi = MP_ONE;
254   return (0);
255 }
256
257 #else
258
259 int mpmont_create(mpmont *mm, mp *m)
260 {
261   size_t n = MP_LEN(m);
262   mp *r2 = mp_new(2 * n + 1, 0);
263   mp r;
264
265   /* --- Take a copy of the modulus --- */
266
267  if (!MP_POSP(m) || !MP_ODDP(m))
268    return (-1);
269   mm->m = MP_COPY(m);
270
271   /* --- Determine %$R^2$% --- */
272
273   mm->n = n;
274   MPX_ZERO(r2->v, r2->vl - 1);
275   r2->vl[-1] = 1;
276
277   /* --- Find the magic value @mi@ --- */
278
279   mp_build(&r, r2->v + n, r2->vl);
280   mm->mi = mp_modinv(MP_NEW, m, &r);
281   mm->mi = mp_sub(mm->mi, &r, mm->mi);
282   MP_ENSURE(mm->mi, n);
283
284   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
285
286   mm->r2 = MP_NEW;
287   mp_div(0, &mm->r2, r2, m);
288   mm->r = mpmont_reduce(mm, MP_NEW, mm->r2);
289   MP_DROP(r2);
290   return (0);
291 }
292
293 #endif
294
295 /* --- @mpmont_destroy@ --- *
296  *
297  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
298  *
299  * Returns:     ---
300  *
301  * Use:         Disposes of a context when it's no longer of any use to
302  *              anyone.
303  */
304
305 void mpmont_destroy(mpmont *mm)
306 {
307   MP_DROP(mm->m);
308   MP_DROP(mm->r);
309   MP_DROP(mm->r2);
310   MP_DROP(mm->mi);
311 }
312
313 /* --- @mpmont_reduce@ --- *
314  *
315  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
316  *              @mp *d@ = destination
317  *              @mp *a@ = source, assumed positive
318  *
319  * Returns:     Result, %$a R^{-1} \bmod m$%.
320  */
321
322 #ifdef MPMONT_DISABLE
323
324 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
325 {
326   mp_div(0, &d, a, mm->m);
327   return (d);
328 }
329
330 #else
331
332 mp *mpmont_reduce(mpmont *mm, mp *d, mp *a)
333 {
334   size_t n = mm->n;
335
336   /* --- Check for serious Karatsuba reduction --- */
337
338   if (n > MPMONT_KTHRESH) {
339     mp al;
340     mpw *vl;
341     mp *u;
342
343     if (MP_LEN(a) >= n) vl = a->v + n;
344     else vl = a->vl;
345     mp_build(&al, a->v, vl);
346     u = mp_mul(MP_NEW, &al, mm->mi);
347     if (MP_LEN(u) > n) u->vl = u->v + n;
348     u = mp_mul(u, u, mm->m);
349     d = mp_add(d, a, u);
350     MP_ENSURE(d, n);
351     mp_drop(u);
352   }
353
354   /* --- Otherwise do it the hard way --- */
355
356   else {
357     a = MP_COPY(a);
358     if (d) MP_DROP(d);
359     d = a;
360     MP_DEST(d, 2*mm->n + 1, a->f);
361     redccore(d->v, d->vl, mm->m->v, mm->n, mm->mi->v);
362   }
363
364   /* --- Wrap everything up --- */
365
366   finish(mm, d);
367   return (d);
368 }
369
370 #endif
371
372 /* --- @mpmont_mul@ --- *
373  *
374  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
375  *              @mp *d@ = destination
376  *              @mp *a, *b@ = sources, assumed positive
377  *
378  * Returns:     Result, %$a b R^{-1} \bmod m$%.
379  */
380
381 #ifdef MPMONT_DISABLE
382
383 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
384 {
385   d = mp_mul(d, a, b);
386   mp_div(0, &d, d, mm->m);
387   return (d);
388 }
389
390 #else
391
392 mp *mpmont_mul(mpmont *mm, mp *d, mp *a, mp *b)
393 {
394   size_t n = mm->n;
395
396   if (n > MPMONT_KTHRESH) {
397     d = mp_mul(d, a, b);
398     d = mpmont_reduce(mm, d, d);
399   } else {
400     a = MP_COPY(a); b = MP_COPY(b);
401     MP_DEST(d, 2*n + 1, a->f | b->f | MP_UNDEF);
402     mulcore(d->v, d->vl, a->v, a->vl, b->v, b->vl,
403             mm->m->v, mm->n, mm->mi->v);
404     d->f = ((a->f | b->f) & MP_BURN) | ((a->f ^ b->f) & MP_NEG);
405     finish(mm, d);
406     MP_DROP(a); MP_DROP(b);
407   }
408
409   return (d);
410 }
411
412 #endif
413
414 /*----- Test rig ----------------------------------------------------------*/
415
416 #ifdef TEST_RIG
417
418 static int tcreate(dstr *v)
419 {
420   mp *m = *(mp **)v[0].buf;
421   mp *mi = *(mp **)v[1].buf;
422   mp *r = *(mp **)v[2].buf;
423   mp *r2 = *(mp **)v[3].buf;
424
425   mpmont mm;
426   int ok = 1;
427
428   mpmont_create(&mm, m);
429
430   if (mm.mi->v[0] != mi->v[0]) {
431     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
432             (unsigned long)mm.mi->v[0], (unsigned long)mi->v[0]);
433     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
434     fputc('\n', stderr);
435     ok = 0;
436   }
437
438   if (!MP_EQ(mm.r, r)) {
439     fputs("\n*** bad r", stderr);
440     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
441     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
442     fputs("\n   found ", stderr); mp_writefile(mm.r, stderr, 10);
443     fputc('\n', stderr);
444     ok = 0;
445   }
446
447   if (!MP_EQ(mm.r2, r2)) {
448     fputs("\n*** bad r2", stderr);
449     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
450     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
451     fputs("\n   found ", stderr); mp_writefile(mm.r2, stderr, 10);
452     fputc('\n', stderr);
453     ok = 0;
454   }
455
456   MP_DROP(m);
457   MP_DROP(mi);
458   MP_DROP(r);
459   MP_DROP(r2);
460   mpmont_destroy(&mm);
461   assert(mparena_count(MPARENA_GLOBAL) == 0);
462   return (ok);
463 }
464
465 static int tmul(dstr *v)
466 {
467   mp *m = *(mp **)v[0].buf;
468   mp *a = *(mp **)v[1].buf;
469   mp *b = *(mp **)v[2].buf;
470   mp *r = *(mp **)v[3].buf;
471   int ok = 1;
472
473   mpmont mm;
474   mpmont_create(&mm, m);
475
476   {
477     mp *qr = mp_mul(MP_NEW, a, b);
478     mp_div(0, &qr, qr, m);
479
480     if (!MP_EQ(qr, r)) {
481       fputs("\n*** classical modmul failed", stderr);
482       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
483       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
484       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
485       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
486       fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
487       fputc('\n', stderr);
488       ok = 0;
489     }
490
491     mp_drop(qr);
492   }
493
494   {
495     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
496     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
497     mp *mr = mpmont_mul(&mm, MP_NEW, ar, br);
498     mr = mpmont_reduce(&mm, mr, mr);
499     if (!MP_EQ(mr, r)) {
500       fputs("\n*** montgomery modmul failed", stderr);
501       fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
502       fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
503       fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
504       fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
505       fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
506       fputc('\n', stderr);
507       ok = 0;
508     }
509     MP_DROP(ar); MP_DROP(br);
510     mp_drop(mr);
511   }
512
513
514   MP_DROP(m);
515   MP_DROP(a);
516   MP_DROP(b);
517   MP_DROP(r);
518   mpmont_destroy(&mm);
519   assert(mparena_count(MPARENA_GLOBAL) == 0);
520   return ok;
521 }
522
523 static test_chunk tests[] = {
524   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
525   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
526   { 0, 0, { 0 } },
527 };
528
529 int main(int argc, char *argv[])
530 {
531   sub_init();
532   test_run(argc, argv, tests, SRCDIR "/t/mpmont");
533   return (0);
534 }
535
536 #endif
537
538 /*----- That's all, folks -------------------------------------------------*/