chiark / gitweb /
New multiprecision integer arithmetic suite.
[catacomb] / mpmont.c
1 /* -*-c-*-
2  *
3  * $Id: mpmont.c,v 1.1 1999/11/17 18:02:16 mdw Exp $
4  *
5  * Montgomery reduction
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpmont.c,v $
33  * Revision 1.1  1999/11/17 18:02:16  mdw
34  * New multiprecision integer arithmetic suite.
35  *
36  */
37
38 /*----- Header files ------------------------------------------------------*/
39
40 #include "mp.h"
41 #include "mpmont.h"
42
43 /*----- Main code ---------------------------------------------------------*/
44
45 /* --- @mpmont_create@ --- *
46  *
47  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
48  *              @mp *m@ = modulus to use
49  *
50  * Returns:     ---
51  *
52  * Use:         Initializes a Montgomery reduction context ready for use.
53  */
54
55 void mpmont_create(mpmont *mm, mp *m)
56 {
57   /* --- Take a copy of the modulus --- */
58
59   mp_shrink(m);
60   mm->m = MP_COPY(m);
61
62   /* --- Find the magic value @mi@ --- *
63    *
64    * This is a slightly grungy way of solving the problem, but it does work.
65    */
66
67   {
68     mpw av[2] = { 0, 1 };
69     mp a, b;
70     mp *i;
71     mpw mi;
72
73     mp_build(&a, av, av + 2);
74     mp_build(&b, m->v, m->v + 1);
75     mp_gcd(0, 0, &i, &a, &b);
76     mi = i->v[0];
77     if (!(i->f & MP_NEG))
78       mi = MPW(-mi);
79     mm->mi = mi;
80     MP_DROP(i);
81   }
82
83   /* --- Discover the values %$R \bmod m$% and %$R^2 \bmod m$% --- */
84
85   {
86     size_t l = MP_LEN(m);
87     mp *r = mp_create(l + 1);
88
89     mm->shift = l * MPW_BITS;
90     MPX_ZERO(r->v, r->vl - 1);
91     r->vl[-1] = 1;
92     mm->r = mm->r2 = MP_NEW;
93     mp_div(0, &mm->r, r, m);
94     r = mp_sqr(r, mm->r);
95     mp_div(0, &mm->r2, r, m);
96     MP_DROP(r);
97   }
98 }
99
100 /* --- @mpmont_destroy@ --- *
101  *
102  * Arguments:   @mpmont *mm@ = pointer to a Montgomery reduction context
103  *
104  * Returns:     ---
105  *
106  * Use:         Disposes of a context when it's no longer of any use to
107  *              anyone.
108  */
109
110 void mpmont_destroy(mpmont *mm)
111 {
112   MP_DROP(mm->m);
113   MP_DROP(mm->r);
114   MP_DROP(mm->r2);
115 }
116
117 /* --- @mpmont_reduce@ --- *
118  *
119  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
120  *              @mp *d@ = destination
121  *              @const mp *a@ = source, assumed positive
122  *
123  * Returns:     Result, %$a R^{-1} \bmod m$%.
124  */
125
126 mp *mpmont_reduce(mpmont *mm, mp *d, const mp *a)
127 {
128   mpw *dv, *dvl;
129   const mpw *mv, *mvl;
130   size_t n;
131
132   /* --- Initial conditioning of the arguments --- */
133
134   n = MP_LEN(mm->m);
135
136   if (d == a)
137     MP_MODIFY(d, 2 * n);
138   else {
139     MP_MODIFY(d, 2 * n);
140     memcpy(d->v, a->v, MPWS(MP_LEN(a)));
141     memset(d->v + MP_LEN(a), 0, MPWS(MP_LEN(d) - MP_LEN(a)));
142   }
143     
144   dv = d->v; dvl = d->vl;
145   mv = mm->m->v; mvl = mm->m->vl;
146
147   /* --- Let's go to work --- */
148
149   while (n--) {
150     mpw u = MPW(*dv * mm->mi);
151     MPX_UMLAN(dv, dvl, mv, mvl, u);
152     dv++;
153   }
154
155   /* --- Done --- */
156
157   memmove(d->v, dv, MPWS(dvl - dv));
158   d->vl -= dv - d->v;
159   MP_SHRINK(d);
160   d->f = a->f & MP_BURN;
161   return (d);
162 }
163
164 /* --- @mpmont_mul@ --- *
165  *
166  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
167  *              @mp *d@ = destination
168  *              @const mp *a, *b@ = sources, assumed positive
169  *
170  * Returns:     Result, %$a b R^{-1} \bmod m$%.
171  */
172
173 mp *mpmont_mul(mpmont *mm, mp *d, const mp *a, const mp *b)
174 {
175   mpw *dv, *dvl;
176   const mpw *av, *avl;
177   const mpw *bv, *bvl;
178   const mpw *mv, *mvl;
179   mpw y;
180   size_t n, i;
181
182   /* --- Initial conditioning of the arguments --- */
183
184   if (MP_LEN(a) > MP_LEN(b)) {
185     const mp *t = a; a = b; b = t;
186   }
187   n = MP_LEN(mm->m);
188     
189   MP_MODIFY(d, 2 * n + 1);
190   dv = d->v; dvl = d->vl;
191   MPX_ZERO(dv, dvl);
192   av = a->v; avl = a->vl;
193   bv = b->v; bvl = b->vl;
194   mv = mm->m->v; mvl = mm->m->vl;
195   y = *bv;
196
197   /* --- Montgomery multiplication phase --- */
198
199   i = 0;
200   while (i < n && av < avl) {
201     mpw x = *av++;
202     mpw u = MPW((*dv + x * y) * mm->mi);
203     MPX_UMLAN(dv, dvl, bv, bvl, x);
204     MPX_UMLAN(dv, dvl, mv, mvl, u);
205     dv++;
206     i++;
207   }
208
209   /* --- Simpler Montgomery reduction phase --- */
210
211   while (i < n) {
212     mpw u = MPW(*dv * mm->mi);
213     MPX_UMLAN(dv, dvl, mv, mvl, u);
214     dv++;
215     i++;
216   }
217
218   /* --- Done --- */
219
220   memmove(d->v, dv, MPWS(dvl - dv));
221   d->vl -= dv - d->v;
222   MP_SHRINK(d);
223   d->f = (a->f | b->f) & MP_BURN;
224   return (d);
225 }
226
227 /* --- @mpmont_exp@ --- *
228  *
229  * Arguments:   @mpmont *mm@ = pointer to Montgomery reduction context
230  *              @const mp *a@ = base
231  *              @const mp *e@ = exponent
232  *
233  * Returns:     Result, %$a^e \bmod m$%.
234  */
235
236 mp *mpmont_exp(mpmont *mm, const mp *a, const mp *e)
237 {
238   mpscan sc;
239   mp *ar = mpmont_mul(mm, MP_NEW, a, mm->r2);
240   mp *d = MP_COPY(mm->r);
241
242   mp_scan(&sc, e);
243
244   if (MP_STEP(&sc)) {
245     for (;;) {
246       mp *dd;
247       if (MP_BIT(&sc)) {
248         dd = mpmont_mul(mm, MP_NEW, d, ar);
249         MP_DROP(d);
250         d = dd;
251       }
252       if (!MP_STEP(&sc))
253         break;
254       dd = mpmont_mul(mm, MP_NEW, ar, ar);
255       MP_DROP(ar);
256       ar = dd;
257     }
258   }
259   MP_DROP(ar);
260   return (mpmont_reduce(mm, d, d));
261 }
262
263 /*----- Test rig ----------------------------------------------------------*/
264
265 #ifdef TEST_RIG
266
267 static int tcreate(dstr *v)
268 {
269   mp *m = *(mp **)v[0].buf;
270   mp *mi = *(mp **)v[1].buf;
271   mp *r = *(mp **)v[2].buf;
272   mp *r2 = *(mp **)v[3].buf;
273
274   mpmont mm;
275   int ok = 1;
276
277   mpmont_create(&mm, m);
278
279   if (mm.mi != mi->v[0]) {
280     fprintf(stderr, "\n*** bad mi: found %lu, expected %lu",
281             (unsigned long)mm.mi, (unsigned long)mi->v[0]);
282     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
283     fputc('\n', stderr);
284     ok = 0;
285   }
286
287   if (MP_CMP(mm.r, !=, r)) {
288     fputs("\n*** bad r", stderr);
289     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
290     fputs("\nexpected ", stderr); mp_writefile(r, stderr, 10);
291     fputs("\n   found ", stderr); mp_writefile(r, stderr, 10);
292     fputc('\n', stderr);
293     ok = 0;
294   }
295
296   if (MP_CMP(mm.r2, !=, r2)) {
297     fputs("\n*** bad r2", stderr);
298     fputs("\nm = ", stderr); mp_writefile(m, stderr, 10);
299     fputs("\nexpected ", stderr); mp_writefile(r2, stderr, 10);
300     fputs("\n   found ", stderr); mp_writefile(r2, stderr, 10);
301     fputc('\n', stderr);
302     ok = 0;
303   }
304
305   MP_DROP(m);
306   MP_DROP(mi);
307   MP_DROP(r);
308   MP_DROP(r2);
309   mpmont_destroy(&mm);
310   return (ok);
311 }
312
313 static int tmul(dstr *v)
314 {
315   mp *m = *(mp **)v[0].buf;
316   mp *a = *(mp **)v[1].buf;
317   mp *b = *(mp **)v[2].buf;
318   mp *r = *(mp **)v[3].buf;
319   mp *mr, *qr;
320   int ok = 1;
321
322   mpmont mm;
323   mpmont_create(&mm, m);
324
325   {
326     mp *ar = mpmont_mul(&mm, MP_NEW, a, mm.r2);
327     mp *br = mpmont_mul(&mm, MP_NEW, b, mm.r2);
328     mr = mpmont_mul(&mm, MP_NEW, ar, br);
329     mr = mpmont_reduce(&mm, mr, mr);
330     MP_DROP(ar); MP_DROP(br);
331   }
332
333   qr = mp_mul(MP_NEW, a, b);
334   mp_div(0, &qr, qr, m);
335
336   if (MP_CMP(qr, !=, r)) {
337     fputs("\n*** classical modmul failed", stderr);
338     fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
339     fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
340     fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
341     fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
342     fputs("\nqr = ", stderr); mp_writefile(qr, stderr, 10);
343     fputc('\n', stderr);
344     ok = 0;
345   }
346
347   if (MP_CMP(mr, !=, r)) {
348     fputs("\n*** montgomery modmul failed", stderr);
349     fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
350     fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
351     fputs("\n b = ", stderr); mp_writefile(b, stderr, 10);
352     fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
353     fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
354     fputc('\n', stderr);
355     ok = 0;
356   }
357
358   MP_DROP(m);
359   MP_DROP(a);
360   MP_DROP(b);
361   MP_DROP(r);
362   MP_DROP(mr);
363   MP_DROP(qr);
364   mpmont_destroy(&mm);
365   return ok;
366 }
367
368 static int texp(dstr *v)
369 {
370   mp *m = *(mp **)v[0].buf;
371   mp *a = *(mp **)v[1].buf;
372   mp *b = *(mp **)v[2].buf;
373   mp *r = *(mp **)v[3].buf;
374   mp *mr;
375   int ok = 1;
376
377   mpmont mm;
378   mpmont_create(&mm, m);
379
380   mr = mpmont_exp(&mm, a, b);
381
382   if (MP_CMP(mr, !=, r)) {
383     fputs("\n*** montgomery modexp failed", stderr);
384     fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
385     fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
386     fputs("\n e = ", stderr); mp_writefile(b, stderr, 10);
387     fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
388     fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
389     fputc('\n', stderr);
390     ok = 0;
391   }
392
393   MP_DROP(m);
394   MP_DROP(a);
395   MP_DROP(b);
396   MP_DROP(r);
397   MP_DROP(mr);
398   mpmont_destroy(&mm);
399   return ok;
400 }
401
402
403 static test_chunk tests[] = {
404   { "create", tcreate, { &type_mp, &type_mp, &type_mp, &type_mp } },
405   { "mul", tmul, { &type_mp, &type_mp, &type_mp, &type_mp } },
406   { "exp", texp, { &type_mp, &type_mp, &type_mp, &type_mp } },
407   { 0, 0, { 0 } },
408 };
409
410 int main(int argc, char *argv[])
411 {
412   sub_init();
413   test_run(argc, argv, tests, SRCDIR "/tests/mpmont");
414   return (0);
415 }
416
417 #endif
418
419 /*----- That's all, folks -------------------------------------------------*/