chiark / gitweb /
Remove stray debugging code.
[catacomb] / mpx.c
1 /* -*-c-*-
2  *
3  * $Id: mpx.c,v 1.3 1999/11/13 01:57:31 mdw Exp $
4  *
5  * Low-level multiprecision arithmetic
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpx.c,v $
33  * Revision 1.3  1999/11/13 01:57:31  mdw
34  * Remove stray debugging code.
35  *
36  * Revision 1.2  1999/11/13 01:50:59  mdw
37  * Multiprecision routines finished and tested.
38  *
39  * Revision 1.1  1999/09/03 08:41:12  mdw
40  * Initial import.
41  *
42  */
43
44 /*----- Header files ------------------------------------------------------*/
45
46 #include <assert.h>
47 #include <stdio.h>
48 #include <stdlib.h>
49 #include <string.h>
50
51 #include <mLib/bits.h>
52
53 #include "mptypes.h"
54 #include "mpx.h"
55
56 /*----- Loading and storing -----------------------------------------------*/
57
58 /* --- @mpx_storel@ --- *
59  *
60  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
61  *              @void *pp@ = pointer to octet array
62  *              @size_t sz@ = size of octet array
63  *
64  * Returns:     ---
65  *
66  * Use:         Stores an MP in an octet array, least significant octet
67  *              first.  High-end octets are silently discarded if there
68  *              isn't enough space for them.
69  */
70
71 void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
72 {
73   mpw n, w = 0;
74   octet *p = pp, *q = p + sz;
75   unsigned bits = 0;
76
77   while (p < q) {
78     if (bits < 8) {
79       if (v >= vl) {
80         *p++ = U8(w);
81         break;
82       }
83       n = *v++;
84       *p++ = U8(w | n << bits);
85       w = n >> (8 - bits);
86       bits += MPW_BITS - 8;
87     } else {
88       *p++ = U8(w);
89       w >>= 8;
90       bits -= 8;
91     }
92   }
93   memset(p, 0, q - p);
94 }
95
96 /* --- @mpx_loadl@ --- *
97  *
98  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
99  *              @const void *pp@ = pointer to octet array
100  *              @size_t sz@ = size of octet array
101  *
102  * Returns:     ---
103  *
104  * Use:         Loads an MP in an octet array, least significant octet
105  *              first.  High-end octets are ignored if there isn't enough
106  *              space for them.
107  */
108
109 void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
110 {
111   unsigned n;
112   mpw w = 0;
113   const octet *p = pp, *q = p + sz;
114   unsigned bits = 0;
115
116   if (v >= vl)
117     return;
118   while (p < q) {
119     n = U8(*p++);
120     w |= n << bits;
121     bits += 8;
122     if (bits >= MPW_BITS) {
123       *v++ = MPW(w);
124       w = n >> (MPW_BITS - bits + 8);
125       bits -= MPW_BITS;
126       if (v >= vl)
127         return;
128     }
129   }
130   *v++ = w;
131   MPX_ZERO(v, vl);
132 }
133
134 /* --- @mpx_storeb@ --- *
135  *
136  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
137  *              @void *pp@ = pointer to octet array
138  *              @size_t sz@ = size of octet array
139  *
140  * Returns:     ---
141  *
142  * Use:         Stores an MP in an octet array, most significant octet
143  *              first.  High-end octets are silently discarded if there
144  *              isn't enough space for them.
145  */
146
147 void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
148 {
149   mpw n, w = 0;
150   octet *p = pp, *q = p + sz;
151   unsigned bits = 0;
152
153   while (q > p) {
154     if (bits < 8) {
155       if (v >= vl) {
156         *--q = U8(w);
157         break;
158       }
159       n = *v++;
160       *--q = U8(w | n << bits);
161       w = n >> (8 - bits);
162       bits += MPW_BITS - 8;
163     } else {
164       *--q = U8(w);
165       w >>= 8;
166       bits -= 8;
167     }
168   }
169   memset(p, 0, q - p);
170 }
171
172 /* --- @mpx_loadb@ --- *
173  *
174  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
175  *              @const void *pp@ = pointer to octet array
176  *              @size_t sz@ = size of octet array
177  *
178  * Returns:     ---
179  *
180  * Use:         Loads an MP in an octet array, most significant octet
181  *              first.  High-end octets are ignored if there isn't enough
182  *              space for them.
183  */
184
185 void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
186 {
187   unsigned n;
188   mpw w = 0;
189   const octet *p = pp, *q = p + sz;
190   unsigned bits = 0;
191
192   if (v >= vl)
193     return;
194   while (q > p) {
195     n = U8(*--q);
196     w |= n << bits;
197     bits += 8;
198     if (bits >= MPW_BITS) {
199       *v++ = MPW(w);
200       w = n >> (MPW_BITS - bits + 8);
201       bits -= MPW_BITS;
202       if (v >= vl)
203         return;
204     }
205   }
206   *v++ = w;
207   MPX_ZERO(v, vl);
208 }
209
210 /*----- Logical shifting --------------------------------------------------*/
211
212 /* --- @mpx_lsl@ --- *
213  *
214  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
215  *              @const mpw *av, *avl@ = source vector base and limit
216  *              @size_t n@ = number of bit positions to shift by
217  *
218  * Returns:     ---
219  *
220  * Use:         Performs a logical shift left operation on an integer.
221  */
222
223 void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
224 {
225   size_t nw;
226   unsigned nb;
227
228   /* --- Trivial special case --- */
229
230   if (n == 0)
231     MPX_COPY(dv, dvl, av, avl);
232
233   /* --- Single bit shifting --- */
234
235   else if (n == 1) {
236     mpw w = 0;
237     while (av < avl) {
238       mpw t;
239       if (dv >= dvl)
240         goto done;
241       t = *av++;
242       *dv++ = MPW((t << 1) | w);
243       w = t >> (MPW_BITS - 1);
244     }
245     if (dv >= dvl)
246       goto done;
247     *dv++ = MPW(w);
248     MPX_ZERO(dv, dvl);
249     goto done;
250   }
251
252   /* --- Break out word and bit shifts for more sophisticated work --- */
253         
254   nw = n / MPW_BITS;
255   nb = n % MPW_BITS;
256
257   /* --- Handle a shift by a multiple of the word size --- */
258
259   if (nb == 0) {
260     MPX_COPY(dv + nw, dvl, av, avl);
261     memset(dv, 0, MPWS(nw));
262   }
263
264   /* --- And finally the difficult case --- *
265    *
266    * This is a little convoluted, because I have to start from the end and
267    * work backwards to avoid overwriting the source, if they're both the same
268    * block of memory.
269    */
270
271   else {
272     mpw w;
273     size_t nr = MPW_BITS - nb;
274     size_t dvn = dvl - dv;
275     size_t avn = avl - av;
276
277     if (dvn <= nw) {
278       MPX_ZERO(dv, dvl);
279       goto done;
280     }
281
282     if (dvn > avn + nw) {
283       size_t off = avn + nw + 1;
284       MPX_ZERO(dv + off, dvl);
285       dvl = dv + off;
286       w = 0;
287     } else {
288       avl = av + dvn - nw;
289       w = *--avl << nb;
290     }
291
292     while (avl > av) {
293       mpw t = *--avl;
294       *--dvl = (t >> nr) | w;
295       w = t << nb;
296     }
297
298     *--dvl = w;
299     MPX_ZERO(dv, dvl);
300   }
301
302 done:;
303 }
304
305 /* --- @mpx_lsr@ --- *
306  *
307  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
308  *              @const mpw *av, *avl@ = source vector base and limit
309  *              @size_t n@ = number of bit positions to shift by
310  *
311  * Returns:     ---
312  *
313  * Use:         Performs a logical shift right operation on an integer.
314  */
315
316 void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
317 {
318   size_t nw;
319   unsigned nb;
320
321   /* --- Trivial special case --- */
322
323   if (n == 0)
324     MPX_COPY(dv, dvl, av, avl);
325
326   /* --- Single bit shifting --- */
327
328   else if (n == 1) {
329     mpw w = *av++ >> 1;
330     while (av < avl) {
331       mpw t;
332       if (dv >= dvl)
333         goto done;
334       t = *av++;
335       *dv++ = MPW((t << (MPW_BITS - 1)) | w);
336       w = t >> 1;
337     }
338     if (dv >= dvl)
339       goto done;
340     *dv++ = MPW(w);
341     MPX_ZERO(dv, dvl);
342     goto done;
343   }
344
345   /* --- Break out word and bit shifts for more sophisticated work --- */
346
347   nw = n / MPW_BITS;
348   nb = n % MPW_BITS;
349
350   /* --- Handle a shift by a multiple of the word size --- */
351
352   if (nb == 0)
353     MPX_COPY(dv, dvl, av + nw, avl);
354
355   /* --- And finally the difficult case --- */
356
357   else {
358     mpw w;
359     size_t nr = MPW_BITS - nb;
360
361     av += nw;
362     w = *av++;
363     while (av < avl) {
364       mpw t;
365       if (dv >= dvl)
366         goto done;
367       t = *av++;
368       *dv++ = MPW((w >> nb) | (t << nr));
369       w = t;
370     }
371     if (dv < dvl) {
372       *dv++ = MPW(w >> nb);
373       MPX_ZERO(dv, dvl);
374     }
375   }
376
377 done:;
378 }
379
380 /*----- Unsigned arithmetic -----------------------------------------------*/
381
382 /* --- @mpx_ucmp@ --- *
383  *
384  * Arguments:   @const mpw *av, *avl@ = first argument vector base and limit
385  *              @const mpw *bv, *bvl@ = second argument vector base and limit
386  *
387  * Returns:     Less than, equal to, or greater than zero depending on
388  *              whether @a@ is less than, equal to or greater than @b@,
389  *              respectively.
390  *
391  * Use:         Performs an unsigned integer comparison.
392  */
393
394 int mpx_ucmp(const mpw *av, const mpw *avl, const mpw *bv, const mpw *bvl)
395 {
396   MPX_SHRINK(av, avl);
397   MPX_SHRINK(bv, bvl);
398
399   if (avl - av > bvl - bv)
400     return (+1);
401   else if (avl - av < bvl - bv)
402     return (-1);
403   else while (avl > av) {
404     mpw a = *--avl, b = *--bvl;
405     if (a > b)
406       return (+1);
407     else if (a < b)
408       return (-1);
409   }
410   return (0);
411 }
412   
413 /* --- @mpx_uadd@ --- *
414  *
415  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
416  *              @const mpw *av, *avl@ = first addend vector base and limit
417  *              @const mpw *bv, *bvl@ = second addend vector base and limit
418  *
419  * Returns:     ---
420  *
421  * Use:         Performs unsigned integer addition.  If the result overflows
422  *              the destination vector, high-order bits are discarded.  This
423  *              means that two's complement addition happens more or less for
424  *              free, although that's more a side-effect than anything else.
425  *              The result vector may be equal to either or both source
426  *              vectors, but may not otherwise overlap them.
427  */
428
429 void mpx_uadd(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
430               const mpw *bv, const mpw *bvl)
431 {
432   mpw c = 0;
433
434   while (av < avl || bv < bvl) {
435     mpw a, b;
436     mpd x;
437     if (dv >= dvl)
438       return;
439     a = (av < avl) ? *av++ : 0;
440     b = (bv < bvl) ? *bv++ : 0;
441     x = (mpd)a + (mpd)b + c;
442     *dv++ = MPW(x);
443     c = x >> MPW_BITS;
444   }
445   if (dv < dvl) {
446     *dv++ = c;
447     MPX_ZERO(dv, dvl);
448   }
449 }
450
451 /* --- @mpx_usub@ --- *
452  *
453  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
454  *              @const mpw *av, *avl@ = first argument vector base and limit
455  *              @const mpw *bv, *bvl@ = second argument vector base and limit
456  *
457  * Returns:     ---
458  *
459  * Use:         Performs unsigned integer subtraction.  If the result
460  *              overflows the destination vector, high-order bits are
461  *              discarded.  This means that two's complement subtraction
462  *              happens more or less for free, althuogh that's more a side-
463  *              effect than anything else.  The result vector may be equal to
464  *              either or both source vectors, but may not otherwise overlap
465  *              them.
466  */
467
468 void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
469               const mpw *bv, const mpw *bvl)
470 {
471   mpw c = 0;
472
473   while (av < avl || bv < bvl) {
474     mpw a, b;
475     mpd x;
476     if (dv >= dvl)
477       return;
478     a = (av < avl) ? *av++ : 0;
479     b = (bv < bvl) ? *bv++ : 0;
480     x = (mpd)a - (mpd)b - c;
481     *dv++ = MPW(x);
482     if (x >> MPW_BITS)
483       c = 1;
484     else
485       c = 0;
486   }
487   if (c)
488     c = MPW_MAX;
489   while (dv < dvl)
490     *dv++ = c;
491 }
492
493 /* --- @mpx_umul@ --- *
494  *
495  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
496  *              @const mpw *av, *avl@ = multiplicand vector base and limit
497  *              @const mpw *bv, *bvl@ = multiplier vector base and limit
498  *
499  * Returns:     ---
500  *
501  * Use:         Performs unsigned integer multiplication.  If the result
502  *              overflows the desination vector, high-order bits are
503  *              discarded.  The result vector may not overlap the argument
504  *              vectors in any way.
505  */
506
507 void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
508               const mpw *bv, const mpw *bvl)
509 {
510   /* --- This is probably worthwhile on a multiply --- */
511
512   MPX_SHRINK(av, avl);
513   MPX_SHRINK(bv, bvl);
514
515   /* --- Deal with a multiply by zero --- */
516   
517   if (bv == bvl) {
518     MPX_ZERO(dv, dvl);
519     return;
520   }
521
522   /* --- Do the initial multiply and initialize the accumulator --- */
523
524   MPX_UMULN(dv, dvl, av, avl, *bv++);
525
526   /* --- Do the remaining multiply/accumulates --- */
527
528   while (dv < dvl && bv < bvl) {
529     mpw m = *bv++;
530     mpw c = 0;
531     const mpw *avv = av;
532     mpw *dvv = ++dv;
533
534     while (avv < avl) {
535       mpd x;
536       if (dvv >= dvl)
537         goto next;
538       x = (mpd)*dvv + (mpd)m * (mpd)*avv++ + c;
539       *dvv++ = MPW(x);
540       c = x >> MPW_BITS;
541     }
542     MPX_UADDN(dvv, dvl, c);
543   next:;
544   }
545 }
546
547 /* --- @mpx_usqr@ --- *
548  *
549  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
550  *              @const mpw *av, *av@ = source vector base and limit
551  *
552  * Returns:     ---
553  *
554  * Use:         Performs unsigned integer squaring.  The result vector must
555  *              not overlap the source vector in any way.
556  */
557
558 void mpx_usqr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl)
559 {
560   MPX_ZERO(dv, dvl);
561
562   /* --- Main loop --- */
563
564   while (av < avl) {
565     const mpw *avv = av;
566     mpw *dvv = dv;
567     mpw a = *av;
568     mpd c;
569
570     /* --- Stop if I've run out of destination --- */
571
572     if (dvv >= dvl)
573       break;
574
575     /* --- Work out the square at this point in the proceedings --- */
576
577     {
578       mpd x = (mpd)a * (mpd)a + *dvv;
579       *dvv++ = MPW(x);
580       c = MPW(x >> MPW_BITS);
581     }
582
583     /* --- Now fix up the rest of the vector upwards --- */
584
585     avv++;
586     while (dvv < dvl && avv < avl) {
587       mpd x = (mpd)a * (mpd)*avv++;
588       mpd y = ((x << 1) & MPW_MAX) + c + *dvv;
589       c = (x >> (MPW_BITS - 1)) + (y >> MPW_BITS);
590       *dvv++ = MPW(y);
591     }
592     while (dvv < dvl && c) {
593       mpd x = c + *dvv;
594       *dvv++ = MPW(x);
595       c = x >> MPW_BITS;
596     }
597
598     /* --- Get ready for the next round --- */
599
600     av++;
601     dv += 2;
602   }
603 }
604
605 /* --- @mpx_udiv@ --- *
606  *
607  * Arguments:   @mpw *qv, *qvl@ = quotient vector base and limit
608  *              @mpw *rv, *rvl@ = dividend/remainder vector base and limit
609  *              @const mpw *dv, *dvl@ = divisor vector base and limit
610  *              @mpw *sv, *svl@ = scratch workspace
611  *
612  * Returns:     ---
613  *
614  * Use:         Performs unsigned integer division.  If the result overflows
615  *              the quotient vector, high-order bits are discarded.  (Clearly
616  *              the remainder vector can't overflow.)  The various vectors
617  *              may not overlap in any way.  Yes, I know it's a bit odd
618  *              requiring the dividend to be in the result position but it
619  *              does make some sense really.  The remainder must have
620  *              headroom for at least two extra words.  The scratch space
621  *              must be at least two words larger than twice the size of the
622  *              divisor.
623  */
624
625 void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
626               const mpw *dv, const mpw *dvl,
627               mpw *sv, mpw *svl)
628 {
629   unsigned norm = 0;
630   size_t scale;
631   mpw d, dd;
632
633   /* --- Initialize the quotient --- */
634
635   MPX_ZERO(qv, qvl);
636
637   /* --- Perform some sanity checks --- */
638
639   MPX_SHRINK(dv, dvl);
640   assert(((void)"division by zero in mpx_udiv", dv < dvl));
641
642   /* --- Normalize the divisor --- *
643    *
644    * The algorithm requires that the divisor be at least two digits long.
645    * This is easy to fix.
646    */
647
648   {
649     unsigned b;
650
651     d = dvl[-1];
652     for (b = MPW_BITS / 2; b; b >>= 1) {
653       if (d < (MPW_MAX >> b)) {
654         d <<= b;
655         norm += b;
656       }
657     }
658     if (dv + 1 == dvl)
659       norm += MPW_BITS;
660   }
661
662   /* --- Normalize the dividend/remainder to match --- */
663
664   if (norm) {
665     mpw *svvl = sv + (dvl - dv) + 1;
666     mpx_lsl(rv, rvl, rv, rvl, norm);
667     mpx_lsl(sv, svvl, dv, dvl, norm);
668     dv = sv;
669     sv = svvl;
670     dvl = svvl;
671     MPX_SHRINK(dv, dvl);
672   }
673
674   MPX_SHRINK(rv, rvl);
675   d = dvl[-1];
676   dd = dvl[-2];
677
678   /* --- Work out the relative scales --- */
679
680   {
681     size_t rvn = rvl - rv;
682     size_t dvn = dvl - dv;
683
684     /* --- If the divisor is clearly larger, notice this --- */
685
686     if (dvn > rvn) {
687       mpx_lsr(rv, rvl, rv, rvl, norm);
688       return;
689     }
690
691     scale = rvn - dvn;
692   }
693
694   /* --- Calculate the most significant quotient digit --- *
695    *
696    * Because the divisor has its top bit set, this can only happen once.  The
697    * pointer arithmetic is a little contorted, to make sure that the
698    * behaviour is defined.
699    */
700
701   if (MPX_UCMP(rv + scale, rvl, >=, dv, dvl)) {
702     mpx_usub(rv + scale, rvl, rv + scale, rvl, dv, dvl);
703     if (qvl - qv > scale)
704       qv[scale] = 1;
705   }
706
707   /* --- Now for the main loop --- */
708
709   {
710     mpw *rvv = rvl - 2;
711
712     while (scale) {
713       mpw q;
714       mpd rh;
715
716       /* --- Get an estimate for the next quotient digit --- */
717
718       mpw r = rvv[1];
719       mpw rr = rvv[0];
720       mpw rrr = *--rvv;
721
722       scale--;
723       rh = ((mpd)r << MPW_BITS) | rr;
724       if (r == d)
725         q = MPW_MAX;
726       else
727         q = MPW(rh / d);
728
729       /* --- Refine the estimate --- */
730
731       {
732         mpd yh = (mpd)d * q;
733         mpd yl = (mpd)dd * q;
734
735         if (yl > MPW_MAX) {
736           yh += yl >> MPW_BITS;
737           yl &= MPW_MAX;
738         }
739
740         while (yh > rh || (yh == rh && yl > rrr)) {
741           q--;
742           yh -= d;
743           if (yl < dd) {
744             yh++;
745             yl += MPW_MAX;
746           }
747           yl -= dd;
748         }
749       }
750
751       /* --- Remove a chunk from the dividend --- */
752
753       {
754         mpw *svv;
755         const mpw *dvv;
756         mpw c = 0;
757
758         /* --- Calculate the size of the chunk --- */
759
760         for (svv = sv, dvv = dv; dvv < dvl; svv++, dvv++) {
761           mpd x = (mpd)*dvv * (mpd)q + c;
762           *svv = MPW(x);
763           c = x >> MPW_BITS;
764         }
765         if (c)
766           *svv++ = c;
767
768         /* --- Now make sure that we can cope with the difference --- *
769          *
770          * Take advantage of the fact that subtraction works two's-
771          * complement.
772          */
773
774         mpx_usub(rv + scale, rvl, rv + scale, rvl, sv, svv);
775         if (rvl[-1] > MPW_MAX / 2) {
776           mpx_uadd(rv + scale, rvl, rv + scale, rvl, dv, dvl);
777           q--;
778         }
779       }
780
781       /* --- Done for another iteration --- */
782
783       if (qvl - qv > scale)
784         qv[scale] = q;
785       r = rr;
786       rr = rrr;
787     }
788   }
789
790   /* --- Now fiddle with unnormalizing and things --- */
791
792   mpx_lsr(rv, rvl, rv, rvl, norm);
793 }
794
795 /*----- That's all, folks -------------------------------------------------*/