chiark / gitweb /
Multiprecision routines finished and tested.
[catacomb] / mpx.c
1 /* -*-c-*-
2  *
3  * $Id: mpx.c,v 1.2 1999/11/13 01:50:59 mdw Exp $
4  *
5  * Low-level multiprecision arithmetic
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpx.c,v $
33  * Revision 1.2  1999/11/13 01:50:59  mdw
34  * Multiprecision routines finished and tested.
35  *
36  * Revision 1.1  1999/09/03 08:41:12  mdw
37  * Initial import.
38  *
39  */
40
41 /*----- Header files ------------------------------------------------------*/
42
43 #include <assert.h>
44 #include <stdio.h>
45 #include <stdlib.h>
46 #include <string.h>
47
48 #include <mLib/bits.h>
49
50 #include "mptypes.h"
51 #include "mpx.h"
52
53 /*----- Loading and storing -----------------------------------------------*/
54
55 /* --- @mpx_storel@ --- *
56  *
57  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
58  *              @void *pp@ = pointer to octet array
59  *              @size_t sz@ = size of octet array
60  *
61  * Returns:     ---
62  *
63  * Use:         Stores an MP in an octet array, least significant octet
64  *              first.  High-end octets are silently discarded if there
65  *              isn't enough space for them.
66  */
67
68 void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
69 {
70   mpw n, w = 0;
71   octet *p = pp, *q = p + sz;
72   unsigned bits = 0;
73
74   while (p < q) {
75     if (bits < 8) {
76       if (v >= vl) {
77         *p++ = U8(w);
78         break;
79       }
80       n = *v++;
81       *p++ = U8(w | n << bits);
82       w = n >> (8 - bits);
83       bits += MPW_BITS - 8;
84     } else {
85       *p++ = U8(w);
86       w >>= 8;
87       bits -= 8;
88     }
89   }
90   memset(p, 0, q - p);
91 }
92
93 /* --- @mpx_loadl@ --- *
94  *
95  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
96  *              @const void *pp@ = pointer to octet array
97  *              @size_t sz@ = size of octet array
98  *
99  * Returns:     ---
100  *
101  * Use:         Loads an MP in an octet array, least significant octet
102  *              first.  High-end octets are ignored if there isn't enough
103  *              space for them.
104  */
105
106 void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
107 {
108   unsigned n;
109   mpw w = 0;
110   const octet *p = pp, *q = p + sz;
111   unsigned bits = 0;
112
113   if (v >= vl)
114     return;
115   while (p < q) {
116     n = U8(*p++);
117     w |= n << bits;
118     bits += 8;
119     if (bits >= MPW_BITS) {
120       *v++ = MPW(w);
121       w = n >> (MPW_BITS - bits + 8);
122       bits -= MPW_BITS;
123       if (v >= vl)
124         return;
125     }
126   }
127   *v++ = w;
128   MPX_ZERO(v, vl);
129 }
130
131 /* --- @mpx_storeb@ --- *
132  *
133  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
134  *              @void *pp@ = pointer to octet array
135  *              @size_t sz@ = size of octet array
136  *
137  * Returns:     ---
138  *
139  * Use:         Stores an MP in an octet array, most significant octet
140  *              first.  High-end octets are silently discarded if there
141  *              isn't enough space for them.
142  */
143
144 void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
145 {
146   mpw n, w = 0;
147   octet *p = pp, *q = p + sz;
148   unsigned bits = 0;
149
150   while (q > p) {
151     if (bits < 8) {
152       if (v >= vl) {
153         *--q = U8(w);
154         break;
155       }
156       n = *v++;
157       *--q = U8(w | n << bits);
158       w = n >> (8 - bits);
159       bits += MPW_BITS - 8;
160     } else {
161       *--q = U8(w);
162       w >>= 8;
163       bits -= 8;
164     }
165   }
166   memset(p, 0, q - p);
167 }
168
169 /* --- @mpx_loadb@ --- *
170  *
171  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
172  *              @const void *pp@ = pointer to octet array
173  *              @size_t sz@ = size of octet array
174  *
175  * Returns:     ---
176  *
177  * Use:         Loads an MP in an octet array, most significant octet
178  *              first.  High-end octets are ignored if there isn't enough
179  *              space for them.
180  */
181
182 void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
183 {
184   unsigned n;
185   mpw w = 0;
186   const octet *p = pp, *q = p + sz;
187   unsigned bits = 0;
188
189   if (v >= vl)
190     return;
191   while (q > p) {
192     n = U8(*--q);
193     w |= n << bits;
194     bits += 8;
195     if (bits >= MPW_BITS) {
196       *v++ = MPW(w);
197       w = n >> (MPW_BITS - bits + 8);
198       bits -= MPW_BITS;
199       if (v >= vl)
200         return;
201     }
202   }
203   *v++ = w;
204   MPX_ZERO(v, vl);
205 }
206
207 /*----- Logical shifting --------------------------------------------------*/
208
209 /* --- @mpx_lsl@ --- *
210  *
211  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
212  *              @const mpw *av, *avl@ = source vector base and limit
213  *              @size_t n@ = number of bit positions to shift by
214  *
215  * Returns:     ---
216  *
217  * Use:         Performs a logical shift left operation on an integer.
218  */
219
220 void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
221 {
222   size_t nw;
223   unsigned nb;
224
225   /* --- Trivial special case --- */
226
227   if (n == 0)
228     MPX_COPY(dv, dvl, av, avl);
229
230   /* --- Single bit shifting --- */
231
232   else if (n == 1) {
233     mpw w = 0;
234     while (av < avl) {
235       mpw t;
236       if (dv >= dvl)
237         goto done;
238       t = *av++;
239       *dv++ = MPW((t << 1) | w);
240       w = t >> (MPW_BITS - 1);
241     }
242     if (dv >= dvl)
243       goto done;
244     *dv++ = MPW(w);
245     MPX_ZERO(dv, dvl);
246     goto done;
247   }
248
249   /* --- Break out word and bit shifts for more sophisticated work --- */
250         
251   nw = n / MPW_BITS;
252   nb = n % MPW_BITS;
253
254   /* --- Handle a shift by a multiple of the word size --- */
255
256   if (nb == 0) {
257     MPX_COPY(dv + nw, dvl, av, avl);
258     memset(dv, 0, MPWS(nw));
259   }
260
261   /* --- And finally the difficult case --- *
262    *
263    * This is a little convoluted, because I have to start from the end and
264    * work backwards to avoid overwriting the source, if they're both the same
265    * block of memory.
266    */
267
268   else {
269     mpw w;
270     size_t nr = MPW_BITS - nb;
271     size_t dvn = dvl - dv;
272     size_t avn = avl - av;
273
274     if (dvn <= nw) {
275       MPX_ZERO(dv, dvl);
276       goto done;
277     }
278
279     if (dvn > avn + nw) {
280       size_t off = avn + nw + 1;
281       MPX_ZERO(dv + off, dvl);
282       dvl = dv + off;
283       w = 0;
284     } else {
285       avl = av + dvn - nw;
286       w = *--avl << nb;
287     }
288
289     while (avl > av) {
290       mpw t = *--avl;
291       *--dvl = (t >> nr) | w;
292       w = t << nb;
293     }
294
295     *--dvl = w;
296     MPX_ZERO(dv, dvl);
297   }
298
299 done:;
300 }
301
302 /* --- @mpx_lsr@ --- *
303  *
304  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
305  *              @const mpw *av, *avl@ = source vector base and limit
306  *              @size_t n@ = number of bit positions to shift by
307  *
308  * Returns:     ---
309  *
310  * Use:         Performs a logical shift right operation on an integer.
311  */
312
313 void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
314 {
315   size_t nw;
316   unsigned nb;
317
318   /* --- Trivial special case --- */
319
320   if (n == 0)
321     MPX_COPY(dv, dvl, av, avl);
322
323   /* --- Single bit shifting --- */
324
325   else if (n == 1) {
326     mpw w = *av++ >> 1;
327     while (av < avl) {
328       mpw t;
329       if (dv >= dvl)
330         goto done;
331       t = *av++;
332       *dv++ = MPW((t << (MPW_BITS - 1)) | w);
333       w = t >> 1;
334     }
335     if (dv >= dvl)
336       goto done;
337     *dv++ = MPW(w);
338     MPX_ZERO(dv, dvl);
339     goto done;
340   }
341
342   /* --- Break out word and bit shifts for more sophisticated work --- */
343
344   nw = n / MPW_BITS;
345   nb = n % MPW_BITS;
346
347   /* --- Handle a shift by a multiple of the word size --- */
348
349   if (nb == 0)
350     MPX_COPY(dv, dvl, av + nw, avl);
351
352   /* --- And finally the difficult case --- */
353
354   else {
355     mpw w;
356     size_t nr = MPW_BITS - nb;
357
358     av += nw;
359     w = *av++;
360     while (av < avl) {
361       mpw t;
362       if (dv >= dvl)
363         goto done;
364       t = *av++;
365       *dv++ = MPW((w >> nb) | (t << nr));
366       w = t;
367     }
368     if (dv < dvl) {
369       *dv++ = MPW(w >> nb);
370       MPX_ZERO(dv, dvl);
371     }
372   }
373
374 done:;
375 }
376
377 /*----- Unsigned arithmetic -----------------------------------------------*/
378
379 /* --- @mpx_ucmp@ --- *
380  *
381  * Arguments:   @const mpw *av, *avl@ = first argument vector base and limit
382  *              @const mpw *bv, *bvl@ = second argument vector base and limit
383  *
384  * Returns:     Less than, equal to, or greater than zero depending on
385  *              whether @a@ is less than, equal to or greater than @b@,
386  *              respectively.
387  *
388  * Use:         Performs an unsigned integer comparison.
389  */
390
391 int mpx_ucmp(const mpw *av, const mpw *avl, const mpw *bv, const mpw *bvl)
392 {
393   MPX_SHRINK(av, avl);
394   MPX_SHRINK(bv, bvl);
395
396   if (avl - av > bvl - bv)
397     return (+1);
398   else if (avl - av < bvl - bv)
399     return (-1);
400   else while (avl > av) {
401     mpw a = *--avl, b = *--bvl;
402     if (a > b)
403       return (+1);
404     else if (a < b)
405       return (-1);
406   }
407   return (0);
408 }
409   
410 /* --- @mpx_uadd@ --- *
411  *
412  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
413  *              @const mpw *av, *avl@ = first addend vector base and limit
414  *              @const mpw *bv, *bvl@ = second addend vector base and limit
415  *
416  * Returns:     ---
417  *
418  * Use:         Performs unsigned integer addition.  If the result overflows
419  *              the destination vector, high-order bits are discarded.  This
420  *              means that two's complement addition happens more or less for
421  *              free, although that's more a side-effect than anything else.
422  *              The result vector may be equal to either or both source
423  *              vectors, but may not otherwise overlap them.
424  */
425
426 void mpx_uadd(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
427               const mpw *bv, const mpw *bvl)
428 {
429   mpw c = 0;
430
431   while (av < avl || bv < bvl) {
432     mpw a, b;
433     mpd x;
434     if (dv >= dvl)
435       return;
436     a = (av < avl) ? *av++ : 0;
437     b = (bv < bvl) ? *bv++ : 0;
438     x = (mpd)a + (mpd)b + c;
439     *dv++ = MPW(x);
440     c = x >> MPW_BITS;
441   }
442   if (dv < dvl) {
443     *dv++ = c;
444     MPX_ZERO(dv, dvl);
445   }
446 }
447
448 /* --- @mpx_usub@ --- *
449  *
450  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
451  *              @const mpw *av, *avl@ = first argument vector base and limit
452  *              @const mpw *bv, *bvl@ = second argument vector base and limit
453  *
454  * Returns:     ---
455  *
456  * Use:         Performs unsigned integer subtraction.  If the result
457  *              overflows the destination vector, high-order bits are
458  *              discarded.  This means that two's complement subtraction
459  *              happens more or less for free, althuogh that's more a side-
460  *              effect than anything else.  The result vector may be equal to
461  *              either or both source vectors, but may not otherwise overlap
462  *              them.
463  */
464
465 void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
466               const mpw *bv, const mpw *bvl)
467 {
468   mpw c = 0;
469
470   while (av < avl || bv < bvl) {
471     mpw a, b;
472     mpd x;
473     if (dv >= dvl)
474       return;
475     a = (av < avl) ? *av++ : 0;
476     b = (bv < bvl) ? *bv++ : 0;
477     x = (mpd)a - (mpd)b - c;
478     *dv++ = MPW(x);
479     if (x >> MPW_BITS)
480       c = 1;
481     else
482       c = 0;
483   }
484   if (c)
485     c = MPW_MAX;
486   while (dv < dvl)
487     *dv++ = c;
488 }
489
490 /* --- @mpx_umul@ --- *
491  *
492  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
493  *              @const mpw *av, *avl@ = multiplicand vector base and limit
494  *              @const mpw *bv, *bvl@ = multiplier vector base and limit
495  *
496  * Returns:     ---
497  *
498  * Use:         Performs unsigned integer multiplication.  If the result
499  *              overflows the desination vector, high-order bits are
500  *              discarded.  The result vector may not overlap the argument
501  *              vectors in any way.
502  */
503
504 void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
505               const mpw *bv, const mpw *bvl)
506 {
507   /* --- This is probably worthwhile on a multiply --- */
508
509   MPX_SHRINK(av, avl);
510   MPX_SHRINK(bv, bvl);
511
512   /* --- Deal with a multiply by zero --- */
513   
514   if (bv == bvl) {
515     MPX_ZERO(dv, dvl);
516     return;
517   }
518
519   /* --- Do the initial multiply and initialize the accumulator --- */
520
521   MPX_UMULN(dv, dvl, av, avl, *bv++);
522
523   /* --- Do the remaining multiply/accumulates --- */
524
525   while (dv < dvl && bv < bvl) {
526     mpw m = *bv++;
527     mpw c = 0;
528     const mpw *avv = av;
529     mpw *dvv = ++dv;
530
531     while (avv < avl) {
532       mpd x;
533       if (dvv >= dvl)
534         goto next;
535       x = (mpd)*dvv + (mpd)m * (mpd)*avv++ + c;
536       *dvv++ = MPW(x);
537       c = x >> MPW_BITS;
538     }
539     MPX_UADDN(dvv, dvl, c);
540   next:;
541   }
542 }
543
544 /* --- @mpx_usqr@ --- *
545  *
546  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
547  *              @const mpw *av, *av@ = source vector base and limit
548  *
549  * Returns:     ---
550  *
551  * Use:         Performs unsigned integer squaring.  The result vector must
552  *              not overlap the source vector in any way.
553  */
554
555 void mpx_usqr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl)
556 {
557   MPX_ZERO(dv, dvl);
558
559   /* --- Main loop --- */
560
561   while (av < avl) {
562     const mpw *avv = av;
563     mpw *dvv = dv;
564     mpw a = *av;
565     mpd c;
566
567     /* --- Stop if I've run out of destination --- */
568
569     if (dvv >= dvl)
570       break;
571
572     /* --- Work out the square at this point in the proceedings --- */
573
574     {
575       mpw d = *dvv;
576       mpd x = (mpd)a * (mpd)a + *dvv;
577       *dvv++ = MPW(x);
578       c = MPW(x >> MPW_BITS);
579     }
580
581     /* --- Now fix up the rest of the vector upwards --- */
582
583     avv++;
584     while (dvv < dvl && avv < avl) {
585       mpw aa = *avv;
586       mpd x = (mpd)a * (mpd)*avv++;
587       mpd y = ((x << 1) & MPW_MAX) + c + *dvv;
588       c = (x >> (MPW_BITS - 1)) + (y >> MPW_BITS);
589       *dvv++ = MPW(y);
590     }
591     while (dvv < dvl && c) {
592       mpd x = c + *dvv;
593       *dvv++ = MPW(x);
594       c = x >> MPW_BITS;
595     }
596
597     /* --- Get ready for the next round --- */
598
599     av++;
600     dv += 2;
601   }
602 }
603
604 /* --- @mpx_udiv@ --- *
605  *
606  * Arguments:   @mpw *qv, *qvl@ = quotient vector base and limit
607  *              @mpw *rv, *rvl@ = dividend/remainder vector base and limit
608  *              @const mpw *dv, *dvl@ = divisor vector base and limit
609  *              @mpw *sv, *svl@ = scratch workspace
610  *
611  * Returns:     ---
612  *
613  * Use:         Performs unsigned integer division.  If the result overflows
614  *              the quotient vector, high-order bits are discarded.  (Clearly
615  *              the remainder vector can't overflow.)  The various vectors
616  *              may not overlap in any way.  Yes, I know it's a bit odd
617  *              requiring the dividend to be in the result position but it
618  *              does make some sense really.  The remainder must have
619  *              headroom for at least two extra words.  The scratch space
620  *              must be at least two words larger than twice the size of the
621  *              divisor.
622  */
623
624 void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
625               const mpw *dv, const mpw *dvl,
626               mpw *sv, mpw *svl)
627 {
628   unsigned norm = 0;
629   size_t scale;
630   mpw d, dd;
631
632   /* --- Initialize the quotient --- */
633
634   MPX_ZERO(qv, qvl);
635
636   /* --- Perform some sanity checks --- */
637
638   MPX_SHRINK(dv, dvl);
639   assert(((void)"division by zero in mpx_udiv", dv < dvl));
640
641   /* --- Normalize the divisor --- *
642    *
643    * The algorithm requires that the divisor be at least two digits long.
644    * This is easy to fix.
645    */
646
647   {
648     unsigned b;
649
650     d = dvl[-1];
651     for (b = MPW_BITS / 2; b; b >>= 1) {
652       if (d < (MPW_MAX >> b)) {
653         d <<= b;
654         norm += b;
655       }
656     }
657     if (dv + 1 == dvl)
658       norm += MPW_BITS;
659   }
660
661   /* --- Normalize the dividend/remainder to match --- */
662
663   if (norm) {
664     mpw *svvl = sv + (dvl - dv) + 1;
665     mpx_lsl(rv, rvl, rv, rvl, norm);
666     mpx_lsl(sv, svvl, dv, dvl, norm);
667     dv = sv;
668     sv = svvl;
669     dvl = svvl;
670     MPX_SHRINK(dv, dvl);
671   }
672
673   MPX_SHRINK(rv, rvl);
674   d = dvl[-1];
675   dd = dvl[-2];
676
677   /* --- Work out the relative scales --- */
678
679   {
680     size_t rvn = rvl - rv;
681     size_t dvn = dvl - dv;
682
683     /* --- If the divisor is clearly larger, notice this --- */
684
685     if (dvn > rvn) {
686       mpx_lsr(rv, rvl, rv, rvl, norm);
687       return;
688     }
689
690     scale = rvn - dvn;
691   }
692
693   /* --- Calculate the most significant quotient digit --- *
694    *
695    * Because the divisor has its top bit set, this can only happen once.  The
696    * pointer arithmetic is a little contorted, to make sure that the
697    * behaviour is defined.
698    */
699
700   if (MPX_UCMP(rv + scale, rvl, >=, dv, dvl)) {
701     mpx_usub(rv + scale, rvl, rv + scale, rvl, dv, dvl);
702     if (qvl - qv > scale)
703       qv[scale] = 1;
704   }
705
706   /* --- Now for the main loop --- */
707
708   {
709     mpw *rvv = rvl - 2;
710
711     while (scale) {
712       mpw q;
713       mpd rh;
714
715       /* --- Get an estimate for the next quotient digit --- */
716
717       mpw r = rvv[1];
718       mpw rr = rvv[0];
719       mpw rrr = *--rvv;
720
721       scale--;
722       rh = ((mpd)r << MPW_BITS) | rr;
723       if (r == d)
724         q = MPW_MAX;
725       else
726         q = MPW(rh / d);
727
728       /* --- Refine the estimate --- */
729
730       {
731         mpd yh = (mpd)d * q;
732         mpd yl = (mpd)dd * q;
733
734         if (yl > MPW_MAX) {
735           yh += yl >> MPW_BITS;
736           yl &= MPW_MAX;
737         }
738
739         while (yh > rh || (yh == rh && yl > rrr)) {
740           q--;
741           yh -= d;
742           if (yl < dd) {
743             yh++;
744             yl += MPW_MAX;
745           }
746           yl -= dd;
747         }
748       }
749
750       /* --- Remove a chunk from the dividend --- */
751
752       {
753         mpw *svv;
754         const mpw *dvv;
755         mpw c = 0;
756
757         /* --- Calculate the size of the chunk --- */
758
759         for (svv = sv, dvv = dv; dvv < dvl; svv++, dvv++) {
760           mpd x = (mpd)*dvv * (mpd)q + c;
761           *svv = MPW(x);
762           c = x >> MPW_BITS;
763         }
764         if (c)
765           *svv++ = c;
766
767         /* --- Now make sure that we can cope with the difference --- *
768          *
769          * Take advantage of the fact that subtraction works two's-
770          * complement.
771          */
772
773         mpx_usub(rv + scale, rvl, rv + scale, rvl, sv, svv);
774         if (rvl[-1] > MPW_MAX / 2) {
775           mpx_uadd(rv + scale, rvl, rv + scale, rvl, dv, dvl);
776           q--;
777         }
778       }
779
780       /* --- Done for another iteration --- */
781
782       if (qvl - qv > scale)
783         qv[scale] = q;
784       r = rr;
785       rr = rrr;
786     }
787   }
788
789   /* --- Now fiddle with unnormalizing and things --- */
790
791   mpx_lsr(rv, rvl, rv, rvl, norm);
792 }
793
794 /*----- That's all, folks -------------------------------------------------*/