chiark / gitweb /
Use sliding-window exponentiation.
[catacomb] / mpbarrett.c
1 /* -*-c-*-
2  *
3  * $Id: mpbarrett.c,v 1.7 2001/04/19 18:25:26 mdw Exp $
4  *
5  * Barrett modular reduction
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpbarrett.c,v $
33  * Revision 1.7  2001/04/19 18:25:26  mdw
34  * Use sliding-window exponentiation.
35  *
36  * Revision 1.6  2000/10/08 12:03:44  mdw
37  * (mpbarrett_reduce): Cope with negative numbers.
38  *
39  * Revision 1.5  2000/07/29 17:04:33  mdw
40  * Change to use left-to-right bitwise exponentiation.  This will improve
41  * performance when the base is small.
42  *
43  * Revision 1.4  2000/06/17 11:45:09  mdw
44  * Major memory management overhaul.  Added arena support.  Use the secure
45  * arena for secret integers.  Replace and improve the MP management macros
46  * (e.g., replace MP_MODIFY by MP_DEST).
47  *
48  * Revision 1.3  1999/12/12 15:08:52  mdw
49  * Don't bother shifting %$q$% in @mpbarrett_reduce@, just skip the least
50  * significant digits.
51  *
52  * Revision 1.2  1999/12/11 01:50:56  mdw
53  * Improve initialization slightly.
54  *
55  * Revision 1.1  1999/12/10 23:21:59  mdw
56  * Barrett reduction support: works with even moduli.
57  *
58  */
59
60 /*----- Header files ------------------------------------------------------*/
61
62 #include "mp.h"
63 #include "mpbarrett.h"
64
65 /*----- Main code ---------------------------------------------------------*/
66
67 /* --- @mpbarrett_create@ --- *
68  *
69  * Arguments:   @mpbarrett *mb@ = pointer to Barrett reduction context
70  *              @mp *m@ = modulus to work to
71  *
72  *
73  * Returns:     ---
74  *
75  * Use:         Initializes a Barrett reduction context ready for use.
76  */
77
78 void mpbarrett_create(mpbarrett *mb, mp *m)
79 {
80   mp *b;
81
82   /* --- Validate the arguments --- */
83
84   assert(((void)"Barrett modulus must be positive", (m->f & MP_NEG) == 0));
85
86   /* --- Compute %$\mu$% --- */
87
88   mp_shrink(m);
89   mb->k = MP_LEN(m);
90   mb->m = MP_COPY(m);
91   b = mp_new(2 * mb->k + 1, 0);
92   MPX_ZERO(b->v, b->vl - 1);
93   b->vl[-1] = 1;
94   mp_div(&b, 0, b, m);
95   mb->mu = b;
96 }
97
98 /* --- @mpbarrett_destroy@ --- *
99  *
100  * Arguments:   @mpbarrett *mb@ = pointer to Barrett reduction context
101  *
102  * Returns:     ---
103  *
104  * Use:         Destroys a Barrett reduction context releasing any resources
105  *              claimed.
106  */
107
108 void mpbarrett_destroy(mpbarrett *mb)
109 {
110   mp_drop(mb->m);
111   mp_drop(mb->mu);
112 }
113
114 /* --- @mpbarrett_reduce@ --- *
115  *
116  * Arguments:   @mpbarrett *mb@ = pointer to Barrett reduction context
117  *              @mp *d@ = destination for result
118  *              @mp *m@ = number to reduce
119  *
120  * Returns:     The residue of @m@ modulo the number in the reduction
121  *              context.
122  *
123  * Use:         Performs an efficient modular reduction.
124  */
125
126 mp *mpbarrett_reduce(mpbarrett *mb, mp *d, mp *m)
127 {
128   mp *q;
129   size_t k = mb->k;
130
131   /* --- Special case if @m@ is too small --- */
132
133   if (MP_LEN(m) < k) {
134     m = MP_COPY(m);
135     if (d)
136       MP_DROP(d);
137     return (m);
138   }
139
140   /* --- First stage --- */
141
142   {
143     mp qq;
144     mp_build(&qq, m->v + (k - 1), m->vl);
145     q = mp_mul(MP_NEW, &qq, mb->mu);
146     if (MP_LEN(q) <= k) {
147       m = MP_COPY(m);
148       if (d)
149         MP_DROP(d);
150       return (m);
151     }
152   }
153
154   /* --- Second stage --- */
155
156   {
157     mp *r;
158     mpw *mvl;
159
160     MP_COPY(m);
161     if (MP_LEN(m) <= k + 1)
162       mvl = m->vl;
163     else
164       mvl = m->v + k + 1;
165     r = mp_new(k + 1, (q->f | mb->m->f) & MP_BURN);
166     mpx_umul(r->v, r->vl, q->v + k + 1, q->vl, mb->m->v, mb->m->vl);
167     MP_DEST(d, k + 1, r->f);
168     mpx_usub(d->v, d->vl, m->v, mvl, r->v, r->vl);
169     d->f = (m->f | r->f) & (MP_BURN | MP_NEG);
170     MP_DROP(r);
171     MP_DROP(q);
172     MP_DROP(m);
173   }
174
175   /* --- Final stage --- */
176
177   MP_SHRINK(d);
178   while (MPX_UCMP(d->v, d->vl, >=, mb->m->v, mb->m->vl))
179     mpx_usub(d->v, d->vl, d->v, d->vl, mb->m->v, mb->m->vl);
180
181   /* --- Fix up the sign --- */
182
183   if (d->f & MP_NEG) {
184     mpx_usub(d->v, d->vl, mb->m->v, mb->m->vl, d->v, d->vl);
185     d->f &= ~MP_NEG;
186   }
187
188   MP_SHRINK(d);
189   return (d);
190 }
191
192 /* --- @mpbarrett_exp@ --- *
193  *
194  * Arguments:   @mpbarrett *mb@ = pointer to Barrett reduction context
195  *              @mp *d@ = fake destination
196  *              @mp *a@ = base
197  *              @mp *e@ = exponent
198  *
199  * Returns:     Result, %$a^e \bmod m$%.
200  */
201
202 #define WINSZ 5
203 #define TABSZ (1 << (WINSZ - 1))
204
205 #define THRESH (((MPW_BITS / WINSZ) << 2) + 1)
206
207 static mp *exp_simple(mpbarrett *mb, mp *d, mp *a, mp *e)
208 {
209   mpscan sc;
210   mp *x = MP_ONE;
211   mp *spare = (e->f & MP_BURN) ? MP_NEWSEC : MP_NEW;
212   unsigned sq = 0;
213
214   a = MP_COPY(a);
215   mp_rscan(&sc, e);
216   if (!MP_RSTEP(&sc))
217     goto exit;
218   while (!MP_RBIT(&sc))
219     MP_RSTEP(&sc);
220
221   /* --- Do the main body of the work --- */
222
223   for (;;) {
224     sq++;
225     while (sq) {
226       mp *y;
227       y = mp_sqr(spare, x);
228       y = mpbarrett_reduce(mb, y, y);
229       spare = x; x = y;
230       sq--;
231     }
232     {
233       mp *y = mp_mul(spare, x, a);
234       y = mpbarrett_reduce(mb, y, y);
235       spare = x; x = y;
236     }
237     for (;;) {
238       if (!MP_RSTEP(&sc))
239         goto done;
240       if (MP_RBIT(&sc))
241         break;
242       sq++;
243     }
244   }
245
246   /* --- Do a final round of squaring --- */
247
248 done:
249   while (sq) {
250     mp *y;
251     y = mp_sqr(spare, x);
252     y = mpbarrett_reduce(mb, y, y);
253     spare = x; x = y;
254     sq--;
255   }  
256
257 exit:
258   MP_DROP(a);
259   if (spare != MP_NEW)
260     MP_DROP(spare);
261   if (d != MP_NEW)
262     MP_DROP(d);
263   return (x);
264 }
265
266 mp *mpbarrett_exp(mpbarrett *mb, mp *d, mp *a, mp *e)
267 {
268   mp **tab;
269   mp *a2;
270   mp *spare = (e->f & MP_BURN) ? MP_NEWSEC : MP_NEW;
271   mp *x = MP_ONE;
272   unsigned i, sq = 0;
273   mpscan sc;
274
275   /* --- Do we bother? --- */
276
277   MP_SHRINK(e);
278   if (MP_LEN(e) == 0)
279     goto exit;
280   if (MP_LEN(e) < THRESH) {
281     x->ref--;
282     return (exp_simple(mb, d, a, e));
283   }
284
285   /* --- Do the precomputation --- */
286
287   a2 = mp_sqr(MP_NEW, a);
288   a2 = mpbarrett_reduce(mb, a2, a2);
289   tab = xmalloc(TABSZ * sizeof(mp *));
290   tab[0] = MP_COPY(a);
291   for (i = 1; i < TABSZ; i++) {
292     mp *x = mp_mul(MP_NEW, tab[i - 1], a2);
293     tab[i] = mpbarrett_reduce(mb, x, x);
294   }
295   mp_drop(a2);
296   mp_rscan(&sc, e);
297   
298   /* --- Skip top-end zero bits --- *
299    *
300    * If the initial step worked, there must be a set bit somewhere, so keep
301    * stepping until I find it.
302    */
303
304   MP_RSTEP(&sc);
305   while (!MP_RBIT(&sc))
306     MP_RSTEP(&sc);
307
308   /* --- Now for the main work --- */
309
310   for (;;) {
311     unsigned l = 0;
312     unsigned z = 0;
313
314     /* --- The next bit is set, so read a window index --- *
315      *
316      * Reset @i@ to zero and increment @sq@.  Then, until either I read
317      * @WINSZ@ bits or I run out of bits, scan in a bit: if it's clear, bump
318      * the @z@ counter; if it's set, push a set bit into @i@, shift it over
319      * by @z@ bits, bump @sq@ by @z + 1@ and clear @z@.  By the end of this
320      * palaver, @i@ is an index to the precomputed value in @tab@.
321      */
322
323     i = 0;
324     sq++;
325     for (;;) {
326       l++;
327       if (l >= WINSZ || !MP_RSTEP(&sc))
328         break;
329       if (!MP_RBIT(&sc))
330         z++;
331       else {
332         i = ((i << 1) | 1) << z;
333         sq += z + 1;
334         z = 0;
335       }
336     }
337
338     /* --- Do the squaring --- *
339      *
340      * Remember that @sq@ carries over from the zero-skipping stuff below.
341      */
342
343     while (sq) {
344       mp *y;
345       y = mp_sqr(spare, x);
346       y = mpbarrett_reduce(mb, y, y);
347       spare = x; x = y;
348       sq--;
349     }
350
351     /* --- Do the multiply --- */
352
353     { mp *y = mp_mul(spare, x, tab[i]); spare = x;
354       x = mpbarrett_reduce(mb, y, y); }
355
356     /* --- Now grind along through the rest of the bits --- */
357
358     sq = z;
359     for (;;) {
360       if (!MP_RSTEP(&sc))
361         goto done;
362       if (MP_RBIT(&sc))
363         break;
364       sq++;
365     }
366   }
367
368   /* --- Do a final round of squaring --- */
369
370 done:
371   while (sq) {
372     mp *y;
373     y = mp_sqr(spare, x);
374     y = mpbarrett_reduce(mb, y, y);
375     spare = x; x = y;
376     sq--;
377   }
378
379   /* --- Done --- */
380
381   for (i = 0; i < TABSZ; i++)
382     mp_drop(tab[i]);
383   xfree(tab);
384 exit:
385   mp_drop(d);
386   mp_drop(spare);
387   return (x);
388 }
389
390 /*----- Test rig ----------------------------------------------------------*/
391
392 #ifdef TEST_RIG
393
394 static int vmod(dstr *v)
395 {
396   mp *x = *(mp **)v[0].buf;
397   mp *n = *(mp **)v[1].buf;
398   mp *r = *(mp **)v[2].buf;
399   mp *s;
400   mpbarrett mb;
401   int ok = 1;
402
403   mpbarrett_create(&mb, n);
404   s = mpbarrett_reduce(&mb, MP_NEW, x);
405
406   if (!MP_EQ(s, r)) {
407     fputs("\n*** barrett reduction failure\n", stderr);
408     fputs("x = ", stderr); mp_writefile(x, stderr, 10); fputc('\n', stderr);
409     fputs("n = ", stderr); mp_writefile(n, stderr, 10); fputc('\n', stderr);
410     fputs("r = ", stderr); mp_writefile(r, stderr, 10); fputc('\n', stderr);
411     fputs("s = ", stderr); mp_writefile(s, stderr, 10); fputc('\n', stderr);
412     ok = 0;
413   }
414
415   mpbarrett_destroy(&mb);
416   mp_drop(x);
417   mp_drop(n);
418   mp_drop(r);
419   mp_drop(s);
420   assert(mparena_count(MPARENA_GLOBAL) == 0);
421   return (ok);
422 }
423
424 static int vexp(dstr *v)
425 {
426   mp *m = *(mp **)v[0].buf;
427   mp *a = *(mp **)v[1].buf;
428   mp *b = *(mp **)v[2].buf;
429   mp *r = *(mp **)v[3].buf;
430   mp *mr;
431   int ok = 1;
432
433   mpbarrett mb;
434   mpbarrett_create(&mb, m);
435
436   mr = mpbarrett_exp(&mb, MP_NEW, a, b);
437
438   if (!MP_EQ(mr, r)) {
439     fputs("\n*** barrett modexp failed", stderr);
440     fputs("\n m = ", stderr); mp_writefile(m, stderr, 10);
441     fputs("\n a = ", stderr); mp_writefile(a, stderr, 10);
442     fputs("\n e = ", stderr); mp_writefile(b, stderr, 10);
443     fputs("\n r = ", stderr); mp_writefile(r, stderr, 10);
444     fputs("\nmr = ", stderr); mp_writefile(mr, stderr, 10);
445     fputc('\n', stderr);
446     ok = 0;
447   }
448
449   mp_drop(m);
450   mp_drop(a);
451   mp_drop(b);
452   mp_drop(r);
453   mp_drop(mr);
454   mpbarrett_destroy(&mb);
455   assert(mparena_count(MPARENA_GLOBAL) == 0);
456   return ok;
457 }
458
459 static test_chunk tests[] = {
460   { "mpbarrett-reduce", vmod, { &type_mp, &type_mp, &type_mp, 0 } },
461   { "mpbarrett-exp", vexp, { &type_mp, &type_mp, &type_mp, &type_mp, 0 } },
462   { 0, 0, { 0 } }
463 };
464
465 int main(int argc, char *argv[])
466 {
467   sub_init();
468   test_run(argc, argv, tests, SRCDIR "/tests/mpbarrett");
469   return (0);
470 }
471
472 #endif
473
474 /*----- That's all, folks -------------------------------------------------*/