chiark / gitweb /
Simplify and improve.
[catacomb] / mpx-ksqr.c
1 /* -*-c-*-
2  *
3  * $Id: mpx-ksqr.c,v 1.2 1999/12/13 15:35:01 mdw Exp $
4  *
5  * Karatsuba-based squaring algorithm
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpx-ksqr.c,v $
33  * Revision 1.2  1999/12/13 15:35:01  mdw
34  * Simplify and improve.
35  *
36  * Revision 1.1  1999/12/11 10:57:43  mdw
37  * Karatsuba squaring algorithm.
38  *
39  */
40
41 /*----- Header files ------------------------------------------------------*/
42
43 #include <assert.h>
44 #include <stdio.h>
45
46 #include "mpx.h"
47
48 /*----- Tweakables --------------------------------------------------------*/
49
50 #ifdef TEST_RIG
51 #  undef KARATSUBA_CUTOFF
52 #  define KARATSUBA_CUTOFF 2
53 #endif
54
55 /*----- Addition macros ---------------------------------------------------*/
56
57 #define ULSL1(dv, av, avl) do {                                         \
58   mpw *_dv = (dv);                                                      \
59   const mpw *_av = (av), *_avl = (avl);                                 \
60   mpw _c = 0;                                                           \
61                                                                         \
62   while (_av < _avl) {                                                  \
63     mpw _x = *_av++;                                                    \
64     *_dv++ = MPW(_c | (_x << 1));                                       \
65     _c = MPW(_x >> (MPW_BITS - 1));                                     \
66   }                                                                     \
67   *_dv++ = _c;                                                          \
68 } while (0)
69
70 #define UADD(dv, av, avl) do {                                          \
71   mpw *_dv = (dv);                                                      \
72   const mpw *_av = (av), *_avl = (avl);                                 \
73   mpw _c = 0;                                                           \
74                                                                         \
75   while (_av < _avl) {                                                  \
76     mpw _a, _b;                                                         \
77     mpd _x;                                                             \
78     _a = *_av++;                                                        \
79     _b = *_dv;                                                          \
80     _x = (mpd)_a + (mpd)_b + _c;                                        \
81     *_dv++ = MPW(_x);                                                   \
82     _c = _x >> MPW_BITS;                                                \
83   }                                                                     \
84   while (_c) {                                                          \
85     mpd _x = (mpd)*_dv + (mpd)_c;                                       \
86     *_dv++ = MPW(_x);                                                   \
87     _c = _x >> MPW_BITS;                                                \
88   }                                                                     \
89 } while (0)
90
91 /*----- Main code ---------------------------------------------------------*/
92
93 /* --- @mpx_ksqr@ --- *
94  *
95  * Arguments:   @mpw *dv, *dvl@ = pointer to destination buffer
96  *              @const mpw *av, *avl@ = pointer to first argument
97  *              @mpw *sv, *svl@ = pointer to scratch workspace
98  *
99  * Returns:     ---
100  *
101  * Use:         Squares a multiprecision integers using something similar to
102  *              Karatsuba's multiplication algorithm.  This is rather faster
103  *              than traditional long multiplication (e.g., @mpx_umul@) on
104  *              large numbers, although more expensive on small ones, and
105  *              rather simpler than full-blown Karatsuba multiplication.
106  *
107  *              The destination must be twice as large as the argument.  The
108  *              scratch space must be twice as large as the argument, plus
109  *              the magic number @KARATSUBA_SLOP@.
110  */
111
112 void mpx_ksqr(mpw *dv, mpw *dvl,
113               const mpw *av, const mpw *avl,
114               mpw *sv, mpw *svl)
115 {
116   const mpw *avm;
117   size_t m;
118
119   /* --- Dispose of easy cases to @mpx_usqr@ --- *
120    *
121    * Karatsuba is only a win on large numbers, because of all the
122    * recursiveness and bookkeeping.  The recursive calls make a quick check
123    * to see whether to bottom out to @mpx_usqr@ which should help quite a
124    * lot, but sometimes the only way to know is to make sure...
125    */
126
127   MPX_SHRINK(av, avl);
128
129   if (avl - av <= KARATSUBA_CUTOFF) {
130     mpx_usqr(dv, dvl, av, avl);
131     return;
132   }
133
134   /* --- How the algorithm works --- *
135    *
136    * Unlike Karatsuba's identity for multiplication which isn't particularly
137    * obvious, the identity for multiplication is known to all schoolchildren.
138    * Let %$A = xb + y$%.  Then %$A^2 = x^2 b^x + 2 x y b + y^2$%.  So now I
139    * have three multiplications, each four times easier, and that's a win.
140    */
141
142   /* --- First things --- *
143    *
144    * Sort out where to break the factor in half.
145    */
146
147   m = (avl - av + 1) >> 1;
148   avm = av + m;
149
150   assert(((void)"Destination too small for Karatsuba square",
151           dvl - dv >= 4 * m));
152   assert(((void)"Not enough workspace for Karatsuba square",
153           svl - sv >= 4 * m));
154
155   /* --- Sort out everything --- */
156
157   {
158     mpw *svm = sv + m, *svn = svm + m, *ssv = svn + 4;
159     mpw *tdv = dv + m;
160     mpw *rdv = tdv + m;
161
162     /* --- The cross term in the middle needs a multiply --- *
163      *
164      * This isn't actually true, since %$x y = ((x + y)^2 - (x - y)^2)/4%.
165      * But that's two squarings, versus one multiplication.
166      */
167
168     if (m > KARATSUBA_CUTOFF)
169       mpx_kmul(sv, ssv, av, avm, avm, avl, ssv, svl);
170     else
171       mpx_umul(sv, ssv, av, avm, avm, avl);
172     ULSL1(tdv, sv, svn);
173
174     if (m > KARATSUBA_CUTOFF)
175       mpx_ksqr(sv, ssv, avm, avl, ssv, svl);
176     else
177       mpx_usqr(sv, ssv, avm, avl);
178     MPX_COPY(rdv + m + 1, dvl, svm + 1, svn);
179     UADD(rdv, sv, svm + 1);
180     
181     if (m > KARATSUBA_CUTOFF)
182       mpx_ksqr(sv, ssv, av, avm, ssv, svl);
183     else
184       mpx_usqr(sv, ssv, av, avm);
185     MPX_COPY(dv, tdv, sv, svm);
186     UADD(tdv, svm, svn);
187   }
188 }
189
190 /*----- Test rig ----------------------------------------------------------*/
191
192 #ifdef TEST_RIG
193
194 #include <mLib/alloc.h>
195 #include <mLib/testrig.h>
196
197 #include "mpscan.h"
198
199 #define ALLOC(v, vl, sz) do {                                           \
200   size_t _sz = (sz);                                                    \
201   mpw *_vv = xmalloc(MPWS(_sz));                                        \
202   mpw *_vvl = _vv + _sz;                                                \
203   (v) = _vv;                                                            \
204   (vl) = _vvl;                                                          \
205 } while (0)
206
207 #define LOAD(v, vl, d) do {                                             \
208   const dstr *_d = (d);                                                 \
209   mpw *_v, *_vl;                                                        \
210   ALLOC(_v, _vl, MPW_RQ(_d->len));                                      \
211   mpx_loadb(_v, _vl, _d->buf, _d->len);                                 \
212   (v) = _v;                                                             \
213   (vl) = _vl;                                                           \
214 } while (0)
215
216 #define MAX(x, y) ((x) > (y) ? (x) : (y))
217
218 static void dumpmp(const char *msg, const mpw *v, const mpw *vl)
219 {
220   fputs(msg, stderr);
221   MPX_SHRINK(v, vl);
222   while (v < vl)
223     fprintf(stderr, " %08lx", (unsigned long)*--vl);
224   fputc('\n', stderr);
225 }
226
227 static int usqr(dstr *v)
228 {
229   mpw *a, *al;
230   mpw *c, *cl;
231   mpw *d, *dl;
232   mpw *s, *sl;
233   size_t m;
234   int ok = 1;
235
236   LOAD(a, al, &v[0]);
237   LOAD(c, cl, &v[1]);
238   m = al - a + 1;
239   ALLOC(d, dl, 2 * m);
240   ALLOC(s, sl, 2 * m + 32);
241
242   mpx_ksqr(d, dl, a, al, s, sl);
243   if (MPX_UCMP(d, dl, !=, c, cl)) {
244     fprintf(stderr, "\n*** usqr failed\n");
245     dumpmp("       a", a, al);
246     dumpmp("expected", c, cl);
247     dumpmp("  result", d, dl);
248     ok = 0;
249   }
250
251   free(a); free(c); free(d); free(s);
252   return (ok);
253 }
254
255 static test_chunk defs[] = {
256   { "usqr", usqr, { &type_hex, &type_hex, 0 } },
257   { 0, 0, { 0 } }
258 };
259
260 int main(int argc, char *argv[])
261 {
262   test_run(argc, argv, defs, SRCDIR"/tests/mpx");
263   return (0);
264 }
265
266 #endif
267
268 /*----- That's all, folks -------------------------------------------------*/