chiark / gitweb /
Add some simple bitwise operations so that Perl can use them.
[catacomb] / mpx.c
1 /* -*-c-*-
2  *
3  * $Id: mpx.c,v 1.11 2001/04/03 19:36:05 mdw Exp $
4  *
5  * Low-level multiprecision arithmetic
6  *
7  * (c) 1999 Straylight/Edgeware
8  */
9
10 /*----- Licensing notice --------------------------------------------------* 
11  *
12  * This file is part of Catacomb.
13  *
14  * Catacomb is free software; you can redistribute it and/or modify
15  * it under the terms of the GNU Library General Public License as
16  * published by the Free Software Foundation; either version 2 of the
17  * License, or (at your option) any later version.
18  * 
19  * Catacomb is distributed in the hope that it will be useful,
20  * but WITHOUT ANY WARRANTY; without even the implied warranty of
21  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
22  * GNU Library General Public License for more details.
23  * 
24  * You should have received a copy of the GNU Library General Public
25  * License along with Catacomb; if not, write to the Free
26  * Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
27  * MA 02111-1307, USA.
28  */
29
30 /*----- Revision history --------------------------------------------------* 
31  *
32  * $Log: mpx.c,v $
33  * Revision 1.11  2001/04/03 19:36:05  mdw
34  * Add some simple bitwise operations so that Perl can use them.
35  *
36  * Revision 1.10  2000/10/08 12:06:12  mdw
37  * Provide @mpx_ueq@ for rapidly testing equality of two integers.
38  *
39  * Revision 1.9  2000/06/26 07:52:50  mdw
40  * Portability fix for the bug fix.
41  *
42  * Revision 1.8  2000/06/25 12:59:02  mdw
43  * (mpx_udiv): Fix bug in quotient digit estimation.
44  *
45  * Revision 1.7  1999/12/22 15:49:07  mdw
46  * New function for division by a small integer.
47  *
48  * Revision 1.6  1999/11/20 22:43:44  mdw
49  * Integrate testing for MPX routines.
50  *
51  * Revision 1.5  1999/11/20 22:23:27  mdw
52  * Add function versions of some low-level macros with wider use.
53  *
54  * Revision 1.4  1999/11/17 18:04:09  mdw
55  * Add two's-complement functionality.  Improve mpx_udiv a little by
56  * performing the multiplication of the divisor by q with the subtraction
57  * from r.
58  *
59  * Revision 1.3  1999/11/13 01:57:31  mdw
60  * Remove stray debugging code.
61  *
62  * Revision 1.2  1999/11/13 01:50:59  mdw
63  * Multiprecision routines finished and tested.
64  *
65  * Revision 1.1  1999/09/03 08:41:12  mdw
66  * Initial import.
67  *
68  */
69
70 /*----- Header files ------------------------------------------------------*/
71
72 #include <assert.h>
73 #include <stdio.h>
74 #include <stdlib.h>
75 #include <string.h>
76
77 #include <mLib/bits.h>
78
79 #include "mptypes.h"
80 #include "mpx.h"
81
82 /*----- Loading and storing -----------------------------------------------*/
83
84 /* --- @mpx_storel@ --- *
85  *
86  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
87  *              @void *pp@ = pointer to octet array
88  *              @size_t sz@ = size of octet array
89  *
90  * Returns:     ---
91  *
92  * Use:         Stores an MP in an octet array, least significant octet
93  *              first.  High-end octets are silently discarded if there
94  *              isn't enough space for them.
95  */
96
97 void mpx_storel(const mpw *v, const mpw *vl, void *pp, size_t sz)
98 {
99   mpw n, w = 0;
100   octet *p = pp, *q = p + sz;
101   unsigned bits = 0;
102
103   while (p < q) {
104     if (bits < 8) {
105       if (v >= vl) {
106         *p++ = U8(w);
107         break;
108       }
109       n = *v++;
110       *p++ = U8(w | n << bits);
111       w = n >> (8 - bits);
112       bits += MPW_BITS - 8;
113     } else {
114       *p++ = U8(w);
115       w >>= 8;
116       bits -= 8;
117     }
118   }
119   memset(p, 0, q - p);
120 }
121
122 /* --- @mpx_loadl@ --- *
123  *
124  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
125  *              @const void *pp@ = pointer to octet array
126  *              @size_t sz@ = size of octet array
127  *
128  * Returns:     ---
129  *
130  * Use:         Loads an MP in an octet array, least significant octet
131  *              first.  High-end octets are ignored if there isn't enough
132  *              space for them.
133  */
134
135 void mpx_loadl(mpw *v, mpw *vl, const void *pp, size_t sz)
136 {
137   unsigned n;
138   mpw w = 0;
139   const octet *p = pp, *q = p + sz;
140   unsigned bits = 0;
141
142   if (v >= vl)
143     return;
144   while (p < q) {
145     n = U8(*p++);
146     w |= n << bits;
147     bits += 8;
148     if (bits >= MPW_BITS) {
149       *v++ = MPW(w);
150       w = n >> (MPW_BITS - bits + 8);
151       bits -= MPW_BITS;
152       if (v >= vl)
153         return;
154     }
155   }
156   *v++ = w;
157   MPX_ZERO(v, vl);
158 }
159
160 /* --- @mpx_storeb@ --- *
161  *
162  * Arguments:   @const mpw *v, *vl@ = base and limit of source vector
163  *              @void *pp@ = pointer to octet array
164  *              @size_t sz@ = size of octet array
165  *
166  * Returns:     ---
167  *
168  * Use:         Stores an MP in an octet array, most significant octet
169  *              first.  High-end octets are silently discarded if there
170  *              isn't enough space for them.
171  */
172
173 void mpx_storeb(const mpw *v, const mpw *vl, void *pp, size_t sz)
174 {
175   mpw n, w = 0;
176   octet *p = pp, *q = p + sz;
177   unsigned bits = 0;
178
179   while (q > p) {
180     if (bits < 8) {
181       if (v >= vl) {
182         *--q = U8(w);
183         break;
184       }
185       n = *v++;
186       *--q = U8(w | n << bits);
187       w = n >> (8 - bits);
188       bits += MPW_BITS - 8;
189     } else {
190       *--q = U8(w);
191       w >>= 8;
192       bits -= 8;
193     }
194   }
195   memset(p, 0, q - p);
196 }
197
198 /* --- @mpx_loadb@ --- *
199  *
200  * Arguments:   @mpw *v, *vl@ = base and limit of destination vector
201  *              @const void *pp@ = pointer to octet array
202  *              @size_t sz@ = size of octet array
203  *
204  * Returns:     ---
205  *
206  * Use:         Loads an MP in an octet array, most significant octet
207  *              first.  High-end octets are ignored if there isn't enough
208  *              space for them.
209  */
210
211 void mpx_loadb(mpw *v, mpw *vl, const void *pp, size_t sz)
212 {
213   unsigned n;
214   mpw w = 0;
215   const octet *p = pp, *q = p + sz;
216   unsigned bits = 0;
217
218   if (v >= vl)
219     return;
220   while (q > p) {
221     n = U8(*--q);
222     w |= n << bits;
223     bits += 8;
224     if (bits >= MPW_BITS) {
225       *v++ = MPW(w);
226       w = n >> (MPW_BITS - bits + 8);
227       bits -= MPW_BITS;
228       if (v >= vl)
229         return;
230     }
231   }
232   *v++ = w;
233   MPX_ZERO(v, vl);
234 }
235
236 /*----- Logical shifting --------------------------------------------------*/
237
238 /* --- @mpx_lsl@ --- *
239  *
240  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
241  *              @const mpw *av, *avl@ = source vector base and limit
242  *              @size_t n@ = number of bit positions to shift by
243  *
244  * Returns:     ---
245  *
246  * Use:         Performs a logical shift left operation on an integer.
247  */
248
249 void mpx_lsl(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
250 {
251   size_t nw;
252   unsigned nb;
253
254   /* --- Trivial special case --- */
255
256   if (n == 0)
257     MPX_COPY(dv, dvl, av, avl);
258
259   /* --- Single bit shifting --- */
260
261   else if (n == 1) {
262     mpw w = 0;
263     while (av < avl) {
264       mpw t;
265       if (dv >= dvl)
266         goto done;
267       t = *av++;
268       *dv++ = MPW((t << 1) | w);
269       w = t >> (MPW_BITS - 1);
270     }
271     if (dv >= dvl)
272       goto done;
273     *dv++ = MPW(w);
274     MPX_ZERO(dv, dvl);
275     goto done;
276   }
277
278   /* --- Break out word and bit shifts for more sophisticated work --- */
279         
280   nw = n / MPW_BITS;
281   nb = n % MPW_BITS;
282
283   /* --- Handle a shift by a multiple of the word size --- */
284
285   if (nb == 0) {
286     MPX_COPY(dv + nw, dvl, av, avl);
287     memset(dv, 0, MPWS(nw));
288   }
289
290   /* --- And finally the difficult case --- *
291    *
292    * This is a little convoluted, because I have to start from the end and
293    * work backwards to avoid overwriting the source, if they're both the same
294    * block of memory.
295    */
296
297   else {
298     mpw w;
299     size_t nr = MPW_BITS - nb;
300     size_t dvn = dvl - dv;
301     size_t avn = avl - av;
302
303     if (dvn <= nw) {
304       MPX_ZERO(dv, dvl);
305       goto done;
306     }
307
308     if (dvn > avn + nw) {
309       size_t off = avn + nw + 1;
310       MPX_ZERO(dv + off, dvl);
311       dvl = dv + off;
312       w = 0;
313     } else {
314       avl = av + dvn - nw;
315       w = *--avl << nb;
316     }
317
318     while (avl > av) {
319       mpw t = *--avl;
320       *--dvl = (t >> nr) | w;
321       w = t << nb;
322     }
323
324     *--dvl = w;
325     MPX_ZERO(dv, dvl);
326   }
327
328 done:;
329 }
330
331 /* --- @mpx_lsr@ --- *
332  *
333  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
334  *              @const mpw *av, *avl@ = source vector base and limit
335  *              @size_t n@ = number of bit positions to shift by
336  *
337  * Returns:     ---
338  *
339  * Use:         Performs a logical shift right operation on an integer.
340  */
341
342 void mpx_lsr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, size_t n)
343 {
344   size_t nw;
345   unsigned nb;
346
347   /* --- Trivial special case --- */
348
349   if (n == 0)
350     MPX_COPY(dv, dvl, av, avl);
351
352   /* --- Single bit shifting --- */
353
354   else if (n == 1) {
355     mpw w = *av++ >> 1;
356     while (av < avl) {
357       mpw t;
358       if (dv >= dvl)
359         goto done;
360       t = *av++;
361       *dv++ = MPW((t << (MPW_BITS - 1)) | w);
362       w = t >> 1;
363     }
364     if (dv >= dvl)
365       goto done;
366     *dv++ = MPW(w);
367     MPX_ZERO(dv, dvl);
368     goto done;
369   }
370
371   /* --- Break out word and bit shifts for more sophisticated work --- */
372
373   nw = n / MPW_BITS;
374   nb = n % MPW_BITS;
375
376   /* --- Handle a shift by a multiple of the word size --- */
377
378   if (nb == 0)
379     MPX_COPY(dv, dvl, av + nw, avl);
380
381   /* --- And finally the difficult case --- */
382
383   else {
384     mpw w;
385     size_t nr = MPW_BITS - nb;
386
387     av += nw;
388     w = *av++;
389     while (av < avl) {
390       mpw t;
391       if (dv >= dvl)
392         goto done;
393       t = *av++;
394       *dv++ = MPW((w >> nb) | (t << nr));
395       w = t;
396     }
397     if (dv < dvl) {
398       *dv++ = MPW(w >> nb);
399       MPX_ZERO(dv, dvl);
400     }
401   }
402
403 done:;
404 }
405
406 /*----- Bitwise operations ------------------------------------------------*/
407
408 /* --- @mpx_and@, @mpx_or@, @mpx_xor@, @mpx_not@ --- *
409  *
410  * Arguments:   @mpw *dv, *dvl@ = destination vector
411  *              @const mpw *av, *avl@ = first source vector
412  *              @const mpw *bv, *bvl@ = second source vector
413  *
414  * Returns:     ---
415  *
416  * Use;         Does the obvious bitwise operations.
417  */
418
419 #define MPX_BITBINOP(name, op)                                          \
420                                                                         \
421 void mpx_##name(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,       \
422                 const mpw *bv, const mpw *bvl)                          \
423 {                                                                       \
424   MPX_SHRINK(av, avl);                                                  \
425   MPX_SHRINK(bv, bvl);                                                  \
426                                                                         \
427   while (dv < dvl) {                                                    \
428     mpw a, b;                                                           \
429     a = (av < avl) ? *av++ : 0;                                         \
430     b = (bv < bvl) ? *bv++ : 0;                                         \
431     *dv++ = a op b;                                                     \
432   }                                                                     \
433 }
434
435 MPX_BITBINOP(and, &)
436 MPX_BITBINOP(or, |)
437 MPX_BITBINOP(xor, ^)
438
439 void mpx_not(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl)
440 {
441   MPX_SHRINK(av, avl);
442
443   while (dv < dvl) {
444     mpw a;
445     a = (av < avl) ? *av++ : 0;
446     *dv++ = ~a;
447   }
448 }
449
450 /*----- Unsigned arithmetic -----------------------------------------------*/
451
452 /* --- @mpx_2c@ --- *
453  *
454  * Arguments:   @mpw *dv, *dvl@ = destination vector
455  *              @const mpw *v, *vl@ = source vector
456  *
457  * Returns:     ---
458  *
459  * Use:         Calculates the two's complement of @v@.
460  */
461
462 void mpx_2c(mpw *dv, mpw *dvl, const mpw *v, const mpw *vl)
463 {
464   mpw c = 0;
465   while (dv < dvl && v < vl)
466     *dv++ = c = MPW(~*v++);
467   if (dv < dvl) {
468     if (c > MPW_MAX / 2)
469       c = MPW(~0);
470     while (dv < dvl)
471       *dv++ = c;
472   }
473   MPX_UADDN(dv, dvl, 1);
474 }
475
476 /* --- @mpx_ueq@ --- *
477  *
478  * Arguments:   @const mpw *av, *avl@ = first argument vector base and limit
479  *              @const mpw *bv, *bvl@ = second argument vector base and limit
480  *
481  * Returns:     Nonzero if the two vectors are equal.
482  *
483  * Use:         Performs an unsigned integer test for equality.
484  */
485
486 int mpx_ueq(const mpw *av, const mpw *avl, const mpw *bv, const mpw *bvl)
487 {
488   MPX_SHRINK(av, avl);
489   MPX_SHRINK(bv, bvl);
490   if (avl - av != bvl - bv)
491     return (0);
492   while (av < avl) {
493     if (*av++ != *bv++)
494       return (0);
495   }
496   return (1);
497 }
498
499 /* --- @mpx_ucmp@ --- *
500  *
501  * Arguments:   @const mpw *av, *avl@ = first argument vector base and limit
502  *              @const mpw *bv, *bvl@ = second argument vector base and limit
503  *
504  * Returns:     Less than, equal to, or greater than zero depending on
505  *              whether @a@ is less than, equal to or greater than @b@,
506  *              respectively.
507  *
508  * Use:         Performs an unsigned integer comparison.
509  */
510
511 int mpx_ucmp(const mpw *av, const mpw *avl, const mpw *bv, const mpw *bvl)
512 {
513   MPX_SHRINK(av, avl);
514   MPX_SHRINK(bv, bvl);
515
516   if (avl - av > bvl - bv)
517     return (+1);
518   else if (avl - av < bvl - bv)
519     return (-1);
520   else while (avl > av) {
521     mpw a = *--avl, b = *--bvl;
522     if (a > b)
523       return (+1);
524     else if (a < b)
525       return (-1);
526   }
527   return (0);
528 }
529
530 /* --- @mpx_uadd@ --- *
531  *
532  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
533  *              @const mpw *av, *avl@ = first addend vector base and limit
534  *              @const mpw *bv, *bvl@ = second addend vector base and limit
535  *
536  * Returns:     ---
537  *
538  * Use:         Performs unsigned integer addition.  If the result overflows
539  *              the destination vector, high-order bits are discarded.  This
540  *              means that two's complement addition happens more or less for
541  *              free, although that's more a side-effect than anything else.
542  *              The result vector may be equal to either or both source
543  *              vectors, but may not otherwise overlap them.
544  */
545
546 void mpx_uadd(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
547               const mpw *bv, const mpw *bvl)
548 {
549   mpw c = 0;
550
551   while (av < avl || bv < bvl) {
552     mpw a, b;
553     mpd x;
554     if (dv >= dvl)
555       return;
556     a = (av < avl) ? *av++ : 0;
557     b = (bv < bvl) ? *bv++ : 0;
558     x = (mpd)a + (mpd)b + c;
559     *dv++ = MPW(x);
560     c = x >> MPW_BITS;
561   }
562   if (dv < dvl) {
563     *dv++ = c;
564     MPX_ZERO(dv, dvl);
565   }
566 }
567
568 /* --- @mpx_uaddn@ --- *
569  *
570  * Arguments:   @mpw *dv, *dvl@ = source and destination base and limit
571  *              @mpw n@ = other addend
572  *
573  * Returns:     ---
574  *
575  * Use:         Adds a small integer to a multiprecision number.
576  */
577
578 void mpx_uaddn(mpw *dv, mpw *dvl, mpw n) { MPX_UADDN(dv, dvl, n); }
579
580 /* --- @mpx_usub@ --- *
581  *
582  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
583  *              @const mpw *av, *avl@ = first argument vector base and limit
584  *              @const mpw *bv, *bvl@ = second argument vector base and limit
585  *
586  * Returns:     ---
587  *
588  * Use:         Performs unsigned integer subtraction.  If the result
589  *              overflows the destination vector, high-order bits are
590  *              discarded.  This means that two's complement subtraction
591  *              happens more or less for free, althuogh that's more a side-
592  *              effect than anything else.  The result vector may be equal to
593  *              either or both source vectors, but may not otherwise overlap
594  *              them.
595  */
596
597 void mpx_usub(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
598               const mpw *bv, const mpw *bvl)
599 {
600   mpw c = 0;
601
602   while (av < avl || bv < bvl) {
603     mpw a, b;
604     mpd x;
605     if (dv >= dvl)
606       return;
607     a = (av < avl) ? *av++ : 0;
608     b = (bv < bvl) ? *bv++ : 0;
609     x = (mpd)a - (mpd)b - c;
610     *dv++ = MPW(x);
611     if (x >> MPW_BITS)
612       c = 1;
613     else
614       c = 0;
615   }
616   if (c)
617     c = MPW_MAX;
618   while (dv < dvl)
619     *dv++ = c;
620 }
621
622 /* --- @mpx_usubn@ --- *
623  *
624  * Arguments:   @mpw *dv, *dvl@ = source and destination base and limit
625  *              @n@ = subtrahend
626  *
627  * Returns:     ---
628  *
629  * Use:         Subtracts a small integer from a multiprecision number.
630  */
631
632 void mpx_usubn(mpw *dv, mpw *dvl, mpw n) { MPX_USUBN(dv, dvl, n); }
633
634 /* --- @mpx_umul@ --- *
635  *
636  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
637  *              @const mpw *av, *avl@ = multiplicand vector base and limit
638  *              @const mpw *bv, *bvl@ = multiplier vector base and limit
639  *
640  * Returns:     ---
641  *
642  * Use:         Performs unsigned integer multiplication.  If the result
643  *              overflows the desination vector, high-order bits are
644  *              discarded.  The result vector may not overlap the argument
645  *              vectors in any way.
646  */
647
648 void mpx_umul(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl,
649               const mpw *bv, const mpw *bvl)
650 {
651   /* --- This is probably worthwhile on a multiply --- */
652
653   MPX_SHRINK(av, avl);
654   MPX_SHRINK(bv, bvl);
655
656   /* --- Deal with a multiply by zero --- */
657   
658   if (bv == bvl) {
659     MPX_ZERO(dv, dvl);
660     return;
661   }
662
663   /* --- Do the initial multiply and initialize the accumulator --- */
664
665   MPX_UMULN(dv, dvl, av, avl, *bv++);
666
667   /* --- Do the remaining multiply/accumulates --- */
668
669   while (dv < dvl && bv < bvl) {
670     mpw m = *bv++;
671     mpw c = 0;
672     const mpw *avv = av;
673     mpw *dvv = ++dv;
674
675     while (avv < avl) {
676       mpd x;
677       if (dvv >= dvl)
678         goto next;
679       x = (mpd)*dvv + (mpd)m * (mpd)*avv++ + c;
680       *dvv++ = MPW(x);
681       c = x >> MPW_BITS;
682     }
683     MPX_UADDN(dvv, dvl, c);
684   next:;
685   }
686 }
687
688 /* --- @mpx_umuln@ --- *
689  *
690  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
691  *              @const mpw *av, *avl@ = multiplicand vector base and limit
692  *              @mpw m@ = multiplier
693  *
694  * Returns:     ---
695  *
696  * Use:         Multiplies a multiprecision integer by a single-word value.
697  *              The destination and source may be equal.  The destination
698  *              is completely cleared after use.
699  */
700
701 void mpx_umuln(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, mpw m)
702 {
703   MPX_UMULN(dv, dvl, av, avl, m);
704 }
705
706 /* --- @mpx_umlan@ --- *
707  *
708  * Arguments:   @mpw *dv, *dvl@ = destination/accumulator base and limit
709  *              @const mpw *av, *avl@ = multiplicand vector base and limit
710  *              @mpw m@ = multiplier
711  *
712  * Returns:     ---
713  *
714  * Use:         Multiplies a multiprecision integer by a single-word value
715  *              and adds the result to an accumulator.
716  */
717
718 void mpx_umlan(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl, mpw m)
719 {
720   MPX_UMLAN(dv, dvl, av, avl, m);
721 }
722
723 /* --- @mpx_usqr@ --- *
724  *
725  * Arguments:   @mpw *dv, *dvl@ = destination vector base and limit
726  *              @const mpw *av, *av@ = source vector base and limit
727  *
728  * Returns:     ---
729  *
730  * Use:         Performs unsigned integer squaring.  The result vector must
731  *              not overlap the source vector in any way.
732  */
733
734 void mpx_usqr(mpw *dv, mpw *dvl, const mpw *av, const mpw *avl)
735 {
736   MPX_ZERO(dv, dvl);
737
738   /* --- Main loop --- */
739
740   while (av < avl) {
741     const mpw *avv = av;
742     mpw *dvv = dv;
743     mpw a = *av;
744     mpd c;
745
746     /* --- Stop if I've run out of destination --- */
747
748     if (dvv >= dvl)
749       break;
750
751     /* --- Work out the square at this point in the proceedings --- */
752
753     {
754       mpd x = (mpd)a * (mpd)a + *dvv;
755       *dvv++ = MPW(x);
756       c = MPW(x >> MPW_BITS);
757     }
758
759     /* --- Now fix up the rest of the vector upwards --- */
760
761     avv++;
762     while (dvv < dvl && avv < avl) {
763       mpd x = (mpd)a * (mpd)*avv++;
764       mpd y = ((x << 1) & MPW_MAX) + c + *dvv;
765       c = (x >> (MPW_BITS - 1)) + (y >> MPW_BITS);
766       *dvv++ = MPW(y);
767     }
768     while (dvv < dvl && c) {
769       mpd x = c + *dvv;
770       *dvv++ = MPW(x);
771       c = x >> MPW_BITS;
772     }
773
774     /* --- Get ready for the next round --- */
775
776     av++;
777     dv += 2;
778   }
779 }
780
781 /* --- @mpx_udiv@ --- *
782  *
783  * Arguments:   @mpw *qv, *qvl@ = quotient vector base and limit
784  *              @mpw *rv, *rvl@ = dividend/remainder vector base and limit
785  *              @const mpw *dv, *dvl@ = divisor vector base and limit
786  *              @mpw *sv, *svl@ = scratch workspace
787  *
788  * Returns:     ---
789  *
790  * Use:         Performs unsigned integer division.  If the result overflows
791  *              the quotient vector, high-order bits are discarded.  (Clearly
792  *              the remainder vector can't overflow.)  The various vectors
793  *              may not overlap in any way.  Yes, I know it's a bit odd
794  *              requiring the dividend to be in the result position but it
795  *              does make some sense really.  The remainder must have
796  *              headroom for at least two extra words.  The scratch space
797  *              must be at least one word larger than the divisor.
798  */
799
800 void mpx_udiv(mpw *qv, mpw *qvl, mpw *rv, mpw *rvl,
801               const mpw *dv, const mpw *dvl,
802               mpw *sv, mpw *svl)
803 {
804   unsigned norm = 0;
805   size_t scale;
806   mpw d, dd;
807
808   /* --- Initialize the quotient --- */
809
810   MPX_ZERO(qv, qvl);
811
812   /* --- Perform some sanity checks --- */
813
814   MPX_SHRINK(dv, dvl);
815   assert(((void)"division by zero in mpx_udiv", dv < dvl));
816
817   /* --- Normalize the divisor --- *
818    *
819    * The algorithm requires that the divisor be at least two digits long.
820    * This is easy to fix.
821    */
822
823   {
824     unsigned b;
825
826     d = dvl[-1];
827     for (b = MPW_BITS / 2; b; b >>= 1) {
828       if (d < (MPW_MAX >> b)) {
829         d <<= b;
830         norm += b;
831       }
832     }
833     if (dv + 1 == dvl)
834       norm += MPW_BITS;
835   }
836
837   /* --- Normalize the dividend/remainder to match --- */
838
839   if (norm) {
840     mpx_lsl(rv, rvl, rv, rvl, norm);
841     mpx_lsl(sv, svl, dv, dvl, norm);
842     dv = sv;
843     dvl = svl;
844     MPX_SHRINK(dv, dvl);
845   }
846
847   MPX_SHRINK(rv, rvl);
848   d = dvl[-1];
849   dd = dvl[-2];
850
851   /* --- Work out the relative scales --- */
852
853   {
854     size_t rvn = rvl - rv;
855     size_t dvn = dvl - dv;
856
857     /* --- If the divisor is clearly larger, notice this --- */
858
859     if (dvn > rvn) {
860       mpx_lsr(rv, rvl, rv, rvl, norm);
861       return;
862     }
863
864     scale = rvn - dvn;
865   }
866
867   /* --- Calculate the most significant quotient digit --- *
868    *
869    * Because the divisor has its top bit set, this can only happen once.  The
870    * pointer arithmetic is a little contorted, to make sure that the
871    * behaviour is defined.
872    */
873
874   if (MPX_UCMP(rv + scale, rvl, >=, dv, dvl)) {
875     mpx_usub(rv + scale, rvl, rv + scale, rvl, dv, dvl);
876     if (qvl - qv > scale)
877       qv[scale] = 1;
878   }
879
880   /* --- Now for the main loop --- */
881
882   {
883     mpw *rvv = rvl - 2;
884
885     while (scale) {
886       mpw q;
887       mpd rh;
888
889       /* --- Get an estimate for the next quotient digit --- */
890
891       mpw r = rvv[1];
892       mpw rr = rvv[0];
893       mpw rrr = *--rvv;
894
895       scale--;
896       rh = ((mpd)r << MPW_BITS) | rr;
897       if (r == d)
898         q = MPW_MAX;
899       else
900         q = MPW(rh / d);
901
902       /* --- Refine the estimate --- */
903
904       {
905         mpd yh = (mpd)d * q;
906         mpd yy = (mpd)dd * q;
907         mpw yl;
908
909         if (yy > MPW_MAX)
910           yh += yy >> MPW_BITS;
911         yl = MPW(yy);
912
913         while (yh > rh || (yh == rh && yl > rrr)) {
914           q--;
915           yh -= d;
916           if (yl < dd)
917             yh--;
918           yl = MPW(yl - dd);
919         }
920       }
921
922       /* --- Remove a chunk from the dividend --- */
923
924       {
925         mpw *svv;
926         const mpw *dvv;
927         mpw mc = 0, sc = 0;
928
929         /* --- Calculate the size of the chunk --- *
930          *
931          * This does the whole job of calculating @r >> scale - qd@.
932          */
933
934         for (svv = rv + scale, dvv = dv;
935              dvv < dvl && svv < rvl;
936              svv++, dvv++) {
937           mpd x = (mpd)*dvv * (mpd)q + mc;
938           mc = x >> MPW_BITS;
939           x = (mpd)*svv - MPW(x) - sc;
940           *svv = MPW(x);
941           if (x >> MPW_BITS)
942             sc = 1;
943           else
944             sc = 0;
945         }
946
947         if (svv < rvl) {
948           mpd x = (mpd)*svv - mc - sc;
949           *svv++ = MPW(x);
950           if (x >> MPW_BITS)
951             sc = MPW_MAX;
952           else
953             sc = 0;
954           while (svv < rvl)
955             *svv++ = sc;
956         }
957
958         /* --- Fix if the quotient was too large --- *
959          *
960          * This doesn't seem to happen very often.
961          */
962
963         if (rvl[-1] > MPW_MAX / 2) {
964           mpx_uadd(rv + scale, rvl, rv + scale, rvl, dv, dvl);
965           q--;
966         }
967       }
968
969       /* --- Done for another iteration --- */
970
971       if (qvl - qv > scale)
972         qv[scale] = q;
973       r = rr;
974       rr = rrr;
975     }
976   }
977
978   /* --- Now fiddle with unnormalizing and things --- */
979
980   mpx_lsr(rv, rvl, rv, rvl, norm);
981 }
982
983 /* --- @mpx_udivn@ --- *
984  *
985  * Arguments:   @mpw *qv, *qvl@ = storage for the quotient (may overlap
986  *                      dividend)
987  *              @const mpw *rv, *rvl@ = dividend
988  *              @mpw d@ = single-precision divisor
989  *
990  * Returns:     Remainder after divison.
991  *
992  * Use:         Performs a single-precision division operation.
993  */
994
995 mpw mpx_udivn(mpw *qv, mpw *qvl, const mpw *rv, const mpw *rvl, mpw d)
996 {
997   size_t i;
998   size_t ql = qvl - qv;
999   mpd r = 0;
1000
1001   i = rvl - rv;
1002   while (i > 0) {
1003     i--;
1004     r = (r << MPW_BITS) | rv[i];
1005     if (i < ql)
1006       qv[i] = r / d;
1007     r %= d;
1008   }
1009   return (MPW(r));
1010 }
1011
1012 /*----- Test rig ----------------------------------------------------------*/
1013
1014 #ifdef TEST_RIG
1015
1016 #include <mLib/alloc.h>
1017 #include <mLib/dstr.h>
1018 #include <mLib/quis.h>
1019 #include <mLib/testrig.h>
1020
1021 #include "mpscan.h"
1022
1023 #define ALLOC(v, vl, sz) do {                                           \
1024   size_t _sz = (sz);                                                    \
1025   mpw *_vv = xmalloc(MPWS(_sz));                                        \
1026   mpw *_vvl = _vv + _sz;                                                \
1027   (v) = _vv;                                                            \
1028   (vl) = _vvl;                                                          \
1029 } while (0)
1030
1031 #define LOAD(v, vl, d) do {                                             \
1032   const dstr *_d = (d);                                                 \
1033   mpw *_v, *_vl;                                                        \
1034   ALLOC(_v, _vl, MPW_RQ(_d->len));                                      \
1035   mpx_loadb(_v, _vl, _d->buf, _d->len);                                 \
1036   (v) = _v;                                                             \
1037   (vl) = _vl;                                                           \
1038 } while (0)
1039
1040 #define MAX(x, y) ((x) > (y) ? (x) : (y))
1041   
1042 static void dumpbits(const char *msg, const void *pp, size_t sz)
1043 {
1044   const octet *p = pp;
1045   fputs(msg, stderr);
1046   for (; sz; sz--)
1047     fprintf(stderr, " %02x", *p++);
1048   fputc('\n', stderr);
1049 }
1050
1051 static void dumpmp(const char *msg, const mpw *v, const mpw *vl)
1052 {
1053   fputs(msg, stderr);
1054   MPX_SHRINK(v, vl);
1055   while (v < vl)
1056     fprintf(stderr, " %08lx", (unsigned long)*--vl);
1057   fputc('\n', stderr);
1058 }
1059
1060 static int chkscan(const mpw *v, const mpw *vl,
1061                    const void *pp, size_t sz, int step)
1062 {
1063   mpscan mps;
1064   const octet *p = pp;
1065   unsigned bit = 0;
1066   int ok = 1;
1067
1068   mpscan_initx(&mps, v, vl);
1069   while (sz) {
1070     unsigned x = *p;
1071     int i;
1072     p += step;
1073     for (i = 0; i < 8 && MPSCAN_STEP(&mps); i++) {
1074       if (MPSCAN_BIT(&mps) != (x & 1)) {
1075         fprintf(stderr,
1076                 "\n*** error, step %i, bit %u, expected %u, found %u\n",
1077                 step, bit, x & 1, MPSCAN_BIT(&mps));
1078         ok = 0;
1079       }
1080       x >>= 1;
1081       bit++;
1082     }
1083     sz--;
1084   }
1085
1086   return (ok);
1087 }
1088
1089 static int loadstore(dstr *v)
1090 {
1091   dstr d = DSTR_INIT;
1092   size_t sz = MPW_RQ(v->len) * 2, diff;
1093   mpw *m, *ml;
1094   int ok = 1;
1095
1096   dstr_ensure(&d, v->len);
1097   m = xmalloc(MPWS(sz));
1098
1099   for (diff = 0; diff < sz; diff += 5) {
1100     size_t oct;
1101
1102     ml = m + sz - diff;
1103
1104     mpx_loadl(m, ml, v->buf, v->len);
1105     if (!chkscan(m, ml, v->buf, v->len, +1))
1106       ok = 0;
1107     MPX_OCTETS(oct, m, ml);
1108     mpx_storel(m, ml, d.buf, d.sz);
1109     if (memcmp(d.buf, v->buf, oct) != 0) {
1110       dumpbits("\n*** storel failed", d.buf, d.sz);
1111       ok = 0;
1112     }
1113
1114     mpx_loadb(m, ml, v->buf, v->len);
1115     if (!chkscan(m, ml, v->buf + v->len - 1, v->len, -1))
1116       ok = 0;
1117     MPX_OCTETS(oct, m, ml);
1118     mpx_storeb(m, ml, d.buf, d.sz);
1119     if (memcmp(d.buf + d.sz - oct, v->buf + v->len - oct, oct) != 0) {
1120       dumpbits("\n*** storeb failed", d.buf, d.sz);
1121       ok = 0;
1122     }
1123   }
1124
1125   if (!ok)
1126     dumpbits("input data", v->buf, v->len);
1127
1128   free(m);
1129   dstr_destroy(&d);
1130   return (ok);
1131 }
1132
1133 static int lsl(dstr *v)
1134 {
1135   mpw *a, *al;
1136   int n = *(int *)v[1].buf;
1137   mpw *c, *cl;
1138   mpw *d, *dl;
1139   int ok = 1;
1140
1141   LOAD(a, al, &v[0]);
1142   LOAD(c, cl, &v[2]);
1143   ALLOC(d, dl, al - a + (n + MPW_BITS - 1) / MPW_BITS);
1144
1145   mpx_lsl(d, dl, a, al, n);
1146   if (!mpx_ueq(d, dl, c, cl)) {
1147     fprintf(stderr, "\n*** lsl(%i) failed\n", n);
1148     dumpmp("       a", a, al);
1149     dumpmp("expected", c, cl);
1150     dumpmp("  result", d, dl);
1151     ok = 0;
1152   }
1153
1154   free(a); free(c); free(d);
1155   return (ok);
1156 }
1157
1158 static int lsr(dstr *v)
1159 {
1160   mpw *a, *al;
1161   int n = *(int *)v[1].buf;
1162   mpw *c, *cl;
1163   mpw *d, *dl;
1164   int ok = 1;
1165
1166   LOAD(a, al, &v[0]);
1167   LOAD(c, cl, &v[2]);
1168   ALLOC(d, dl, al - a + (n + MPW_BITS - 1) / MPW_BITS + 1);
1169
1170   mpx_lsr(d, dl, a, al, n);
1171   if (!mpx_ueq(d, dl, c, cl)) {
1172     fprintf(stderr, "\n*** lsr(%i) failed\n", n);
1173     dumpmp("       a", a, al);
1174     dumpmp("expected", c, cl);
1175     dumpmp("  result", d, dl);
1176     ok = 0;
1177   }
1178
1179   free(a); free(c); free(d);
1180   return (ok);
1181 }
1182
1183 static int uadd(dstr *v)
1184 {
1185   mpw *a, *al;
1186   mpw *b, *bl;
1187   mpw *c, *cl;
1188   mpw *d, *dl;
1189   int ok = 1;
1190
1191   LOAD(a, al, &v[0]);
1192   LOAD(b, bl, &v[1]);
1193   LOAD(c, cl, &v[2]);
1194   ALLOC(d, dl, MAX(al - a, bl - b) + 1);
1195
1196   mpx_uadd(d, dl, a, al, b, bl);
1197   if (!mpx_ueq(d, dl, c, cl)) {
1198     fprintf(stderr, "\n*** uadd failed\n");
1199     dumpmp("       a", a, al);
1200     dumpmp("       b", b, bl);
1201     dumpmp("expected", c, cl);
1202     dumpmp("  result", d, dl);
1203     ok = 0;
1204   }
1205
1206   free(a); free(b); free(c); free(d);
1207   return (ok);
1208 }
1209
1210 static int usub(dstr *v)
1211 {
1212   mpw *a, *al;
1213   mpw *b, *bl;
1214   mpw *c, *cl;
1215   mpw *d, *dl;
1216   int ok = 1;
1217
1218   LOAD(a, al, &v[0]);
1219   LOAD(b, bl, &v[1]);
1220   LOAD(c, cl, &v[2]);
1221   ALLOC(d, dl, al - a);
1222
1223   mpx_usub(d, dl, a, al, b, bl);
1224   if (!mpx_ueq(d, dl, c, cl)) {
1225     fprintf(stderr, "\n*** usub failed\n");
1226     dumpmp("       a", a, al);
1227     dumpmp("       b", b, bl);
1228     dumpmp("expected", c, cl);
1229     dumpmp("  result", d, dl);
1230     ok = 0;
1231   }
1232
1233   free(a); free(b); free(c); free(d);
1234   return (ok);
1235 }
1236
1237 static int umul(dstr *v)
1238 {
1239   mpw *a, *al;
1240   mpw *b, *bl;
1241   mpw *c, *cl;
1242   mpw *d, *dl;
1243   int ok = 1;
1244
1245   LOAD(a, al, &v[0]);
1246   LOAD(b, bl, &v[1]);
1247   LOAD(c, cl, &v[2]);
1248   ALLOC(d, dl, (al - a) + (bl - b));
1249
1250   mpx_umul(d, dl, a, al, b, bl);
1251   if (!mpx_ueq(d, dl, c, cl)) {
1252     fprintf(stderr, "\n*** umul failed\n");
1253     dumpmp("       a", a, al);
1254     dumpmp("       b", b, bl);
1255     dumpmp("expected", c, cl);
1256     dumpmp("  result", d, dl);
1257     ok = 0;
1258   }
1259
1260   free(a); free(b); free(c); free(d);
1261   return (ok);
1262 }
1263
1264 static int usqr(dstr *v)
1265 {
1266   mpw *a, *al;
1267   mpw *c, *cl;
1268   mpw *d, *dl;
1269   int ok = 1;
1270
1271   LOAD(a, al, &v[0]);
1272   LOAD(c, cl, &v[1]);
1273   ALLOC(d, dl, 2 * (al - a));
1274
1275   mpx_usqr(d, dl, a, al);
1276   if (!mpx_ueq(d, dl, c, cl)) {
1277     fprintf(stderr, "\n*** usqr failed\n");
1278     dumpmp("       a", a, al);
1279     dumpmp("expected", c, cl);
1280     dumpmp("  result", d, dl);
1281     ok = 0;
1282   }
1283
1284   free(a); free(c); free(d);
1285   return (ok);
1286 }
1287
1288 static int udiv(dstr *v)
1289 {
1290   mpw *a, *al;
1291   mpw *b, *bl;
1292   mpw *q, *ql;
1293   mpw *r, *rl;
1294   mpw *qq, *qql;
1295   mpw *s, *sl;
1296   int ok = 1;
1297
1298   ALLOC(a, al, MPW_RQ(v[0].len) + 2); mpx_loadb(a, al, v[0].buf, v[0].len);
1299   LOAD(b, bl, &v[1]);
1300   LOAD(q, ql, &v[2]);
1301   LOAD(r, rl, &v[3]);
1302   ALLOC(qq, qql, al - a);
1303   ALLOC(s, sl, (bl - b) + 1);
1304
1305   mpx_udiv(qq, qql, a, al, b, bl, s, sl);
1306   if (!mpx_ueq(qq, qql, q, ql) ||
1307       !mpx_ueq(a, al, r, rl)) {
1308     fprintf(stderr, "\n*** udiv failed\n");
1309     dumpmp(" divisor", b, bl);
1310     dumpmp("expect r", r, rl);
1311     dumpmp("result r", a, al);
1312     dumpmp("expect q", q, ql);
1313     dumpmp("result q", qq, qql);
1314     ok = 0;
1315   }
1316
1317   free(a); free(b); free(r); free(q); free(s); free(qq);
1318   return (ok);
1319 }
1320
1321 static test_chunk defs[] = {
1322   { "load-store", loadstore, { &type_hex, 0 } },
1323   { "lsl", lsl, { &type_hex, &type_int, &type_hex, 0 } },
1324   { "lsr", lsr, { &type_hex, &type_int, &type_hex, 0 } },
1325   { "uadd", uadd, { &type_hex, &type_hex, &type_hex, 0 } },
1326   { "usub", usub, { &type_hex, &type_hex, &type_hex, 0 } },
1327   { "umul", umul, { &type_hex, &type_hex, &type_hex, 0 } },
1328   { "usqr", usqr, { &type_hex, &type_hex, 0 } },
1329   { "udiv", udiv, { &type_hex, &type_hex, &type_hex, &type_hex, 0 } },
1330   { 0, 0, { 0 } }
1331 };
1332
1333 int main(int argc, char *argv[])
1334 {
1335   test_run(argc, argv, defs, SRCDIR"/tests/mpx");
1336   return (0);
1337 }
1338
1339 #endif
1340
1341 /*----- That's all, folks -------------------------------------------------*/