]> granicus.if.org Git - postgresql/blob - contrib/pgcrypto/imath.c
Fix most -Wundef warnings
[postgresql] / contrib / pgcrypto / imath.c
1 /*-------------------------------------------------------------------------
2  *
3  * imath.c
4  *
5  * Last synchronized from https://github.com/creachadair/imath/tree/v1.29,
6  * using the following procedure:
7  *
8  * 1. Download imath.c and imath.h of the last synchronized version.  Remove
9  *    "#ifdef __cplusplus" blocks, which upset pgindent.  Run pgindent on the
10  *    two files.  Filter the two files through "unexpand -t4 --first-only".
11  *    Diff the result against the PostgreSQL versions.  As of the last
12  *    synchronization, changes were as follows:
13  *
14  *    - replace malloc(), realloc() and free() with px_ versions
15  *    - redirect assert() to Assert()
16  *    - #undef MIN, #undef MAX before defining them
17  *    - remove includes covered by c.h
18  *    - rename DEBUG to IMATH_DEBUG
19  *    - replace stdint.h usage with c.h equivalents
20  *    - suppress MSVC warning 4146
21  *    - add required PG_USED_FOR_ASSERTS_ONLY
22  *
23  * 2. Download a newer imath.c and imath.h.  Transform them like in step 1.
24  *    Apply to these files the diff you saved in step 1.  Look for new lines
25  *    requiring the same kind of change, such as new malloc() calls.
26  *
27  * 3. Configure PostgreSQL using --without-openssl.  Run "make -C
28  *    contrib/pgcrypto check".
29  *
30  * 4. Update this header comment.
31  *
32  * Portions Copyright (c) 1996-2019, PostgreSQL Global Development Group
33  *
34  * IDENTIFICATION
35  *        contrib/pgcrypto/imath.c
36  *
37  * Upstream copyright terms follow.
38  *-------------------------------------------------------------------------
39  */
40
41 /*
42   Name:         imath.c
43   Purpose:      Arbitrary precision integer arithmetic routines.
44   Author:   M. J. Fromberger
45
46   Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
47
48   Permission is hereby granted, free of charge, to any person obtaining a copy
49   of this software and associated documentation files (the "Software"), to deal
50   in the Software without restriction, including without limitation the rights
51   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
52   copies of the Software, and to permit persons to whom the Software is
53   furnished to do so, subject to the following conditions:
54
55   The above copyright notice and this permission notice shall be included in
56   all copies or substantial portions of the Software.
57
58   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
59   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
60   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
61   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
62   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
63   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
64   SOFTWARE.
65  */
66
67 #include "postgres.h"
68
69 #include "imath.h"
70 #include "px.h"
71
72 #undef assert
73 #define assert(TEST) Assert(TEST)
74
75 const mp_result MP_OK = 0;              /* no error, all is well  */
76 const mp_result MP_FALSE = 0;   /* boolean false          */
77 const mp_result MP_TRUE = -1;   /* boolean true           */
78 const mp_result MP_MEMORY = -2; /* out of memory          */
79 const mp_result MP_RANGE = -3;  /* argument out of range  */
80 const mp_result MP_UNDEF = -4;  /* result undefined       */
81 const mp_result MP_TRUNC = -5;  /* output truncated       */
82 const mp_result MP_BADARG = -6; /* invalid null argument  */
83 const mp_result MP_MINERR = -6;
84
85 const mp_sign MP_NEG = 1;               /* value is strictly negative */
86 const mp_sign MP_ZPOS = 0;              /* value is non-negative      */
87
88 static const char *s_unknown_err = "unknown result code";
89 static const char *s_error_msg[] = {"error code 0", "boolean true",
90         "out of memory", "argument out of range",
91         "result undefined", "output truncated",
92 "invalid argument", NULL};
93
94 /* The ith entry of this table gives the value of log_i(2).
95
96    An integer value n requires ceil(log_i(n)) digits to be represented
97    in base i.  Since it is easy to compute lg(n), by counting bits, we
98    can compute log_i(n) = lg(n) * log_i(2).
99
100    The use of this table eliminates a dependency upon linkage against
101    the standard math libraries.
102
103    If MP_MAX_RADIX is increased, this table should be expanded too.
104  */
105 static const double s_log2[] = {
106         0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2  3 */
107         0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4  5  6  7 */
108         0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8  9 10 11 */
109         0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */
110         0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */
111         0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */
112         0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */
113         0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */
114         0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */
115         0.193426404,                            /* 36          */
116 };
117
118 /* Return the number of digits needed to represent a static value */
119 #define MP_VALUE_DIGITS(V) \
120   ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit))
121
122 /* Round precision P to nearest word boundary */
123 static inline mp_size
124 s_round_prec(mp_size P)
125 {
126         return 2 * ((P + 1) / 2);
127 }
128
129 /* Set array P of S digits to zero */
130 static inline void
131 ZERO(mp_digit *P, mp_size S)
132 {
133         mp_size         i__ = S * sizeof(mp_digit);
134         mp_digit   *p__ = P;
135
136         memset(p__, 0, i__);
137 }
138
139 /* Copy S digits from array P to array Q */
140 static inline void
141 COPY(mp_digit *P, mp_digit *Q, mp_size S)
142 {
143         mp_size         i__ = S * sizeof(mp_digit);
144         mp_digit   *p__ = P;
145         mp_digit   *q__ = Q;
146
147         memcpy(q__, p__, i__);
148 }
149
150 /* Reverse N elements of unsigned char in A. */
151 static inline void
152 REV(unsigned char *A, int N)
153 {
154         unsigned char *u_ = A;
155         unsigned char *v_ = u_ + N - 1;
156
157         while (u_ < v_)
158         {
159                 unsigned char xch = *u_;
160
161                 *u_++ = *v_;
162                 *v_-- = xch;
163         }
164 }
165
166 /* Strip leading zeroes from z_ in-place. */
167 static inline void
168 CLAMP(mp_int z_)
169 {
170         mp_size         uz_ = MP_USED(z_);
171         mp_digit   *dz_ = MP_DIGITS(z_) + uz_ - 1;
172
173         while (uz_ > 1 && (*dz_-- == 0))
174                 --uz_;
175         z_->used = uz_;
176 }
177
178 /* Select min/max. */
179 #undef MIN
180 #undef MAX
181 static inline int
182 MIN(int A, int B)
183 {
184         return (B < A ? B : A);
185 }
186 static inline mp_size
187 MAX(mp_size A, mp_size B)
188 {
189         return (B > A ? B : A);
190 }
191
192 /* Exchange lvalues A and B of type T, e.g.
193    SWAP(int, x, y) where x and y are variables of type int. */
194 #define SWAP(T, A, B) \
195   do {                \
196         T t_ = (A);       \
197         A = (B);          \
198         B = t_;           \
199   } while (0)
200
201 /* Declare a block of N temporary mpz_t values.
202    These values are initialized to zero.
203    You must add CLEANUP_TEMP() at the end of the function.
204    Use TEMP(i) to access a pointer to the ith value.
205  */
206 #define DECLARE_TEMP(N)                   \
207   struct {                                \
208         mpz_t value[(N)];                     \
209         int len;                              \
210         mp_result err;                        \
211   } temp_ = {                             \
212           .len = (N),                         \
213           .err = MP_OK,                       \
214   };                                      \
215   do {                                    \
216         for (int i = 0; i < temp_.len; i++) { \
217           mp_int_init(TEMP(i));               \
218         }                                     \
219   } while (0)
220
221 /* Clear all allocated temp values. */
222 #define CLEANUP_TEMP()                    \
223   CLEANUP:                                \
224   do {                                    \
225         for (int i = 0; i < temp_.len; i++) { \
226           mp_int_clear(TEMP(i));              \
227         }                                     \
228         if (temp_.err != MP_OK) {             \
229           return temp_.err;                   \
230         }                                     \
231   } while (0)
232
233 /* A pointer to the kth temp value. */
234 #define TEMP(K) (temp_.value + (K))
235
236 /* Evaluate E, an expression of type mp_result expected to return MP_OK.  If
237    the value is not MP_OK, the error is cached and control resumes at the
238    cleanup handler, which returns it.
239 */
240 #define REQUIRE(E)                        \
241   do {                                    \
242         temp_.err = (E);                      \
243         if (temp_.err != MP_OK) goto CLEANUP; \
244   } while (0)
245
246 /* Compare value to zero. */
247 static inline int
248 CMPZ(mp_int Z)
249 {
250         if (Z->used == 1 && Z->digits[0] == 0)
251                 return 0;
252         return (Z->sign == MP_NEG) ? -1 : 1;
253 }
254
255 static inline mp_word
256 UPPER_HALF(mp_word W)
257 {
258         return (W >> MP_DIGIT_BIT);
259 }
260 static inline mp_digit
261 LOWER_HALF(mp_word W)
262 {
263         return (mp_digit) (W);
264 }
265
266 /* Report whether the highest-order bit of W is 1. */
267 static inline bool
268 HIGH_BIT_SET(mp_word W)
269 {
270         return (W >> (MP_WORD_BIT - 1)) != 0;
271 }
272
273 /* Report whether adding W + V will carry out. */
274 static inline bool
275 ADD_WILL_OVERFLOW(mp_word W, mp_word V)
276 {
277         return ((MP_WORD_MAX - V) < W);
278 }
279
280 /* Default number of digits allocated to a new mp_int */
281 static mp_size default_precision = 8;
282
283 void
284 mp_int_default_precision(mp_size size)
285 {
286         assert(size > 0);
287         default_precision = size;
288 }
289
290 /* Minimum number of digits to invoke recursive multiply */
291 static mp_size multiply_threshold = 32;
292
293 void
294 mp_int_multiply_threshold(mp_size thresh)
295 {
296         assert(thresh >= sizeof(mp_word));
297         multiply_threshold = thresh;
298 }
299
300 /* Allocate a buffer of (at least) num digits, or return
301    NULL if that couldn't be done.  */
302 static mp_digit *s_alloc(mp_size num);
303
304 /* Release a buffer of digits allocated by s_alloc(). */
305 static void s_free(void *ptr);
306
307 /* Insure that z has at least min digits allocated, resizing if
308    necessary.  Returns true if successful, false if out of memory. */
309 static bool s_pad(mp_int z, mp_size min);
310
311 /* Ensure Z has at least N digits allocated. */
312 static inline mp_result
313 GROW(mp_int Z, mp_size N)
314 {
315         return s_pad(Z, N) ? MP_OK : MP_MEMORY;
316 }
317
318 /* Fill in a "fake" mp_int on the stack with a given value */
319 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
320 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);
321
322 /* Compare two runs of digits of given length, returns <0, 0, >0 */
323 static int      s_cdig(mp_digit *da, mp_digit *db, mp_size len);
324
325 /* Pack the unsigned digits of v into array t */
326 static int      s_uvpack(mp_usmall v, mp_digit t[]);
327
328 /* Compare magnitudes of a and b, returns <0, 0, >0 */
329 static int      s_ucmp(mp_int a, mp_int b);
330
331 /* Compare magnitudes of a and v, returns <0, 0, >0 */
332 static int      s_vcmp(mp_int a, mp_small v);
333 static int      s_uvcmp(mp_int a, mp_usmall uv);
334
335 /* Unsigned magnitude addition; assumes dc is big enough.
336    Carry out is returned (no memory allocated). */
337 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
338                                            mp_size size_b);
339
340 /* Unsigned magnitude subtraction.  Assumes dc is big enough. */
341 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
342                                    mp_size size_b);
343
344 /* Unsigned recursive multiplication.  Assumes dc is big enough. */
345 static int      s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
346                                    mp_size size_b);
347
348 /* Unsigned magnitude multiplication.  Assumes dc is big enough. */
349 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
350                                    mp_size size_b);
351
352 /* Unsigned recursive squaring.  Assumes dc is big enough. */
353 static int      s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);
354
355 /* Unsigned magnitude squaring.  Assumes dc is big enough. */
356 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);
357
358 /* Single digit addition.  Assumes a is big enough. */
359 static void s_dadd(mp_int a, mp_digit b);
360
361 /* Single digit multiplication.  Assumes a is big enough. */
362 static void s_dmul(mp_int a, mp_digit b);
363
364 /* Single digit multiplication on buffers; assumes dc is big enough. */
365 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a);
366
367 /* Single digit division.  Replaces a with the quotient,
368    returns the remainder.  */
369 static mp_digit s_ddiv(mp_int a, mp_digit b);
370
371 /* Quick division by a power of 2, replaces z (no allocation) */
372 static void s_qdiv(mp_int z, mp_size p2);
373
374 /* Quick remainder by a power of 2, replaces z (no allocation) */
375 static void s_qmod(mp_int z, mp_size p2);
376
377 /* Quick multiplication by a power of 2, replaces z.
378    Allocates if necessary; returns false in case this fails. */
379 static int      s_qmul(mp_int z, mp_size p2);
380
381 /* Quick subtraction from a power of 2, replaces z.
382    Allocates if necessary; returns false in case this fails. */
383 static int      s_qsub(mp_int z, mp_size p2);
384
385 /* Return maximum k such that 2^k divides z. */
386 static int      s_dp2k(mp_int z);
387
388 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
389 static int      s_isp2(mp_int z);
390
391 /* Set z to 2^k.  May allocate; returns false in case this fails. */
392 static int      s_2expt(mp_int z, mp_small k);
393
394 /* Normalize a and b for division, returns normalization constant */
395 static int      s_norm(mp_int a, mp_int b);
396
397 /* Compute constant mu for Barrett reduction, given modulus m, result
398    replaces z, m is untouched. */
399 static mp_result s_brmu(mp_int z, mp_int m);
400
401 /* Reduce a modulo m, using Barrett's algorithm. */
402 static int      s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
403
404 /* Modular exponentiation, using Barrett reduction */
405 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
406
407 /* Unsigned magnitude division.  Assumes |a| > |b|.  Allocates temporaries;
408    overwrites a with quotient, b with remainder. */
409 static mp_result s_udiv_knuth(mp_int a, mp_int b);
410
411 /* Compute the number of digits in radix r required to represent the given
412    value.  Does not account for sign flags, terminators, etc. */
413 static int      s_outlen(mp_int z, mp_size r);
414
415 /* Guess how many digits of precision will be needed to represent a radix r
416    value of the specified number of digits.  Returns a value guaranteed to be
417    no smaller than the actual number required. */
418 static mp_size s_inlen(int len, mp_size r);
419
420 /* Convert a character to a digit value in radix r, or
421    -1 if out of range */
422 static int      s_ch2val(char c, int r);
423
424 /* Convert a digit value to a character */
425 static char s_val2ch(int v, int caps);
426
427 /* Take 2's complement of a buffer in place */
428 static void s_2comp(unsigned char *buf, int len);
429
430 /* Convert a value to binary, ignoring sign.  On input, *limpos is the bound on
431    how many bytes should be written to buf; on output, *limpos is set to the
432    number of bytes actually written. */
433 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
434
435 /* Multiply X by Y into Z, ignoring signs.  Requires that Z have enough storage
436    preallocated to hold the result. */
437 static inline void
438 UMUL(mp_int X, mp_int Y, mp_int Z)
439 {
440         mp_size         ua_ = MP_USED(X);
441         mp_size         ub_ = MP_USED(Y);
442         mp_size         o_ = ua_ + ub_;
443
444         ZERO(MP_DIGITS(Z), o_);
445         (void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_);
446         Z->used = o_;
447         CLAMP(Z);
448 }
449
450 /* Square X into Z.  Requires that Z have enough storage to hold the result. */
451 static inline void
452 USQR(mp_int X, mp_int Z)
453 {
454         mp_size         ua_ = MP_USED(X);
455         mp_size         o_ = ua_ + ua_;
456
457         ZERO(MP_DIGITS(Z), o_);
458         (void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_);
459         Z->used = o_;
460         CLAMP(Z);
461 }
462
463 mp_result
464 mp_int_init(mp_int z)
465 {
466         if (z == NULL)
467                 return MP_BADARG;
468
469         z->single = 0;
470         z->digits = &(z->single);
471         z->alloc = 1;
472         z->used = 1;
473         z->sign = MP_ZPOS;
474
475         return MP_OK;
476 }
477
478 mp_int
479 mp_int_alloc(void)
480 {
481         mp_int          out = px_alloc(sizeof(mpz_t));
482
483         if (out != NULL)
484                 mp_int_init(out);
485
486         return out;
487 }
488
489 mp_result
490 mp_int_init_size(mp_int z, mp_size prec)
491 {
492         assert(z != NULL);
493
494         if (prec == 0)
495         {
496                 prec = default_precision;
497         }
498         else if (prec == 1)
499         {
500                 return mp_int_init(z);
501         }
502         else
503         {
504                 prec = s_round_prec(prec);
505         }
506
507         z->digits = s_alloc(prec);
508         if (MP_DIGITS(z) == NULL)
509                 return MP_MEMORY;
510
511         z->digits[0] = 0;
512         z->used = 1;
513         z->alloc = prec;
514         z->sign = MP_ZPOS;
515
516         return MP_OK;
517 }
518
519 mp_result
520 mp_int_init_copy(mp_int z, mp_int old)
521 {
522         assert(z != NULL && old != NULL);
523
524         mp_size         uold = MP_USED(old);
525
526         if (uold == 1)
527         {
528                 mp_int_init(z);
529         }
530         else
531         {
532                 mp_size         target = MAX(uold, default_precision);
533                 mp_result       res = mp_int_init_size(z, target);
534
535                 if (res != MP_OK)
536                         return res;
537         }
538
539         z->used = uold;
540         z->sign = old->sign;
541         COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
542
543         return MP_OK;
544 }
545
546 mp_result
547 mp_int_init_value(mp_int z, mp_small value)
548 {
549         mpz_t           vtmp;
550         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
551
552         s_fake(&vtmp, value, vbuf);
553         return mp_int_init_copy(z, &vtmp);
554 }
555
556 mp_result
557 mp_int_init_uvalue(mp_int z, mp_usmall uvalue)
558 {
559         mpz_t           vtmp;
560         mp_digit        vbuf[MP_VALUE_DIGITS(uvalue)];
561
562         s_ufake(&vtmp, uvalue, vbuf);
563         return mp_int_init_copy(z, &vtmp);
564 }
565
566 mp_result
567 mp_int_set_value(mp_int z, mp_small value)
568 {
569         mpz_t           vtmp;
570         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
571
572         s_fake(&vtmp, value, vbuf);
573         return mp_int_copy(&vtmp, z);
574 }
575
576 mp_result
577 mp_int_set_uvalue(mp_int z, mp_usmall uvalue)
578 {
579         mpz_t           vtmp;
580         mp_digit        vbuf[MP_VALUE_DIGITS(uvalue)];
581
582         s_ufake(&vtmp, uvalue, vbuf);
583         return mp_int_copy(&vtmp, z);
584 }
585
586 void
587 mp_int_clear(mp_int z)
588 {
589         if (z == NULL)
590                 return;
591
592         if (MP_DIGITS(z) != NULL)
593         {
594                 if (MP_DIGITS(z) != &(z->single))
595                         s_free(MP_DIGITS(z));
596
597                 z->digits = NULL;
598         }
599 }
600
601 void
602 mp_int_free(mp_int z)
603 {
604         assert(z != NULL);
605
606         mp_int_clear(z);
607         px_free(z);                                     /* note: NOT s_free() */
608 }
609
610 mp_result
611 mp_int_copy(mp_int a, mp_int c)
612 {
613         assert(a != NULL && c != NULL);
614
615         if (a != c)
616         {
617                 mp_size         ua = MP_USED(a);
618                 mp_digit   *da,
619                                    *dc;
620
621                 if (!s_pad(c, ua))
622                         return MP_MEMORY;
623
624                 da = MP_DIGITS(a);
625                 dc = MP_DIGITS(c);
626                 COPY(da, dc, ua);
627
628                 c->used = ua;
629                 c->sign = a->sign;
630         }
631
632         return MP_OK;
633 }
634
635 void
636 mp_int_swap(mp_int a, mp_int c)
637 {
638         if (a != c)
639         {
640                 mpz_t           tmp = *a;
641
642                 *a = *c;
643                 *c = tmp;
644
645                 if (MP_DIGITS(a) == &(c->single))
646                         a->digits = &(a->single);
647                 if (MP_DIGITS(c) == &(a->single))
648                         c->digits = &(c->single);
649         }
650 }
651
652 void
653 mp_int_zero(mp_int z)
654 {
655         assert(z != NULL);
656
657         z->digits[0] = 0;
658         z->used = 1;
659         z->sign = MP_ZPOS;
660 }
661
662 mp_result
663 mp_int_abs(mp_int a, mp_int c)
664 {
665         assert(a != NULL && c != NULL);
666
667         mp_result       res;
668
669         if ((res = mp_int_copy(a, c)) != MP_OK)
670                 return res;
671
672         c->sign = MP_ZPOS;
673         return MP_OK;
674 }
675
676 mp_result
677 mp_int_neg(mp_int a, mp_int c)
678 {
679         assert(a != NULL && c != NULL);
680
681         mp_result       res;
682
683         if ((res = mp_int_copy(a, c)) != MP_OK)
684                 return res;
685
686         if (CMPZ(c) != 0)
687                 c->sign = 1 - MP_SIGN(a);
688
689         return MP_OK;
690 }
691
692 mp_result
693 mp_int_add(mp_int a, mp_int b, mp_int c)
694 {
695         assert(a != NULL && b != NULL && c != NULL);
696
697         mp_size         ua = MP_USED(a);
698         mp_size         ub = MP_USED(b);
699         mp_size         max = MAX(ua, ub);
700
701         if (MP_SIGN(a) == MP_SIGN(b))
702         {
703                 /* Same sign -- add magnitudes, preserve sign of addends */
704                 if (!s_pad(c, max))
705                         return MP_MEMORY;
706
707                 mp_digit        carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
708                 mp_size         uc = max;
709
710                 if (carry)
711                 {
712                         if (!s_pad(c, max + 1))
713                                 return MP_MEMORY;
714
715                         c->digits[max] = carry;
716                         ++uc;
717                 }
718
719                 c->used = uc;
720                 c->sign = a->sign;
721
722         }
723         else
724         {
725                 /* Different signs -- subtract magnitudes, preserve sign of greater */
726                 int                     cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */
727
728                 /*
729                  * Set x to max(a, b), y to min(a, b) to simplify later code. A
730                  * special case yields zero for equal magnitudes.
731                  */
732                 mp_int          x,
733                                         y;
734
735                 if (cmp == 0)
736                 {
737                         mp_int_zero(c);
738                         return MP_OK;
739                 }
740                 else if (cmp < 0)
741                 {
742                         x = b;
743                         y = a;
744                 }
745                 else
746                 {
747                         x = a;
748                         y = b;
749                 }
750
751                 if (!s_pad(c, MP_USED(x)))
752                         return MP_MEMORY;
753
754                 /* Subtract smaller from larger */
755                 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
756                 c->used = x->used;
757                 CLAMP(c);
758
759                 /* Give result the sign of the larger */
760                 c->sign = x->sign;
761         }
762
763         return MP_OK;
764 }
765
766 mp_result
767 mp_int_add_value(mp_int a, mp_small value, mp_int c)
768 {
769         mpz_t           vtmp;
770         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
771
772         s_fake(&vtmp, value, vbuf);
773
774         return mp_int_add(a, &vtmp, c);
775 }
776
777 mp_result
778 mp_int_sub(mp_int a, mp_int b, mp_int c)
779 {
780         assert(a != NULL && b != NULL && c != NULL);
781
782         mp_size         ua = MP_USED(a);
783         mp_size         ub = MP_USED(b);
784         mp_size         max = MAX(ua, ub);
785
786         if (MP_SIGN(a) != MP_SIGN(b))
787         {
788                 /* Different signs -- add magnitudes and keep sign of a */
789                 if (!s_pad(c, max))
790                         return MP_MEMORY;
791
792                 mp_digit        carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
793                 mp_size         uc = max;
794
795                 if (carry)
796                 {
797                         if (!s_pad(c, max + 1))
798                                 return MP_MEMORY;
799
800                         c->digits[max] = carry;
801                         ++uc;
802                 }
803
804                 c->used = uc;
805                 c->sign = a->sign;
806
807         }
808         else
809         {
810                 /* Same signs -- subtract magnitudes */
811                 if (!s_pad(c, max))
812                         return MP_MEMORY;
813                 mp_int          x,
814                                         y;
815                 mp_sign         osign;
816
817                 int                     cmp = s_ucmp(a, b);
818
819                 if (cmp >= 0)
820                 {
821                         x = a;
822                         y = b;
823                         osign = MP_ZPOS;
824                 }
825                 else
826                 {
827                         x = b;
828                         y = a;
829                         osign = MP_NEG;
830                 }
831
832                 if (MP_SIGN(a) == MP_NEG && cmp != 0)
833                         osign = 1 - osign;
834
835                 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
836                 c->used = x->used;
837                 CLAMP(c);
838
839                 c->sign = osign;
840         }
841
842         return MP_OK;
843 }
844
845 mp_result
846 mp_int_sub_value(mp_int a, mp_small value, mp_int c)
847 {
848         mpz_t           vtmp;
849         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
850
851         s_fake(&vtmp, value, vbuf);
852
853         return mp_int_sub(a, &vtmp, c);
854 }
855
856 mp_result
857 mp_int_mul(mp_int a, mp_int b, mp_int c)
858 {
859         assert(a != NULL && b != NULL && c != NULL);
860
861         /* If either input is zero, we can shortcut multiplication */
862         if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0)
863         {
864                 mp_int_zero(c);
865                 return MP_OK;
866         }
867
868         /* Output is positive if inputs have same sign, otherwise negative */
869         mp_sign         osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
870
871         /*
872          * If the output is not identical to any of the inputs, we'll write the
873          * results directly; otherwise, allocate a temporary space.
874          */
875         mp_size         ua = MP_USED(a);
876         mp_size         ub = MP_USED(b);
877         mp_size         osize = MAX(ua, ub);
878
879         osize = 4 * ((osize + 1) / 2);
880
881         mp_digit   *out;
882         mp_size         p = 0;
883
884         if (c == a || c == b)
885         {
886                 p = MAX(s_round_prec(osize), default_precision);
887
888                 if ((out = s_alloc(p)) == NULL)
889                         return MP_MEMORY;
890         }
891         else
892         {
893                 if (!s_pad(c, osize))
894                         return MP_MEMORY;
895
896                 out = MP_DIGITS(c);
897         }
898         ZERO(out, osize);
899
900         if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
901                 return MP_MEMORY;
902
903         /*
904          * If we allocated a new buffer, get rid of whatever memory c was already
905          * using, and fix up its fields to reflect that.
906          */
907         if (out != MP_DIGITS(c))
908         {
909                 if ((void *) MP_DIGITS(c) != (void *) c)
910                         s_free(MP_DIGITS(c));
911                 c->digits = out;
912                 c->alloc = p;
913         }
914
915         c->used = osize;                        /* might not be true, but we'll fix it ... */
916         CLAMP(c);                                       /* ... right here */
917         c->sign = osign;
918
919         return MP_OK;
920 }
921
922 mp_result
923 mp_int_mul_value(mp_int a, mp_small value, mp_int c)
924 {
925         mpz_t           vtmp;
926         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
927
928         s_fake(&vtmp, value, vbuf);
929
930         return mp_int_mul(a, &vtmp, c);
931 }
932
933 mp_result
934 mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c)
935 {
936         assert(a != NULL && c != NULL && p2 >= 0);
937
938         mp_result       res = mp_int_copy(a, c);
939
940         if (res != MP_OK)
941                 return res;
942
943         if (s_qmul(c, (mp_size) p2))
944         {
945                 return MP_OK;
946         }
947         else
948         {
949                 return MP_MEMORY;
950         }
951 }
952
953 mp_result
954 mp_int_sqr(mp_int a, mp_int c)
955 {
956         assert(a != NULL && c != NULL);
957
958         /* Get a temporary buffer big enough to hold the result */
959         mp_size         osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2);
960         mp_size         p = 0;
961         mp_digit   *out;
962
963         if (a == c)
964         {
965                 p = s_round_prec(osize);
966                 p = MAX(p, default_precision);
967
968                 if ((out = s_alloc(p)) == NULL)
969                         return MP_MEMORY;
970         }
971         else
972         {
973                 if (!s_pad(c, osize))
974                         return MP_MEMORY;
975
976                 out = MP_DIGITS(c);
977         }
978         ZERO(out, osize);
979
980         s_ksqr(MP_DIGITS(a), out, MP_USED(a));
981
982         /*
983          * Get rid of whatever memory c was already using, and fix up its fields
984          * to reflect the new digit array it's using
985          */
986         if (out != MP_DIGITS(c))
987         {
988                 if ((void *) MP_DIGITS(c) != (void *) c)
989                         s_free(MP_DIGITS(c));
990                 c->digits = out;
991                 c->alloc = p;
992         }
993
994         c->used = osize;                        /* might not be true, but we'll fix it ... */
995         CLAMP(c);                                       /* ... right here */
996         c->sign = MP_ZPOS;
997
998         return MP_OK;
999 }
1000
1001 mp_result
1002 mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
1003 {
1004         assert(a != NULL && b != NULL && q != r);
1005
1006         int                     cmp;
1007         mp_result       res = MP_OK;
1008         mp_int          qout,
1009                                 rout;
1010         mp_sign         sa = MP_SIGN(a);
1011         mp_sign         sb = MP_SIGN(b);
1012
1013         if (CMPZ(b) == 0)
1014         {
1015                 return MP_UNDEF;
1016         }
1017         else if ((cmp = s_ucmp(a, b)) < 0)
1018         {
1019                 /*
1020                  * If |a| < |b|, no division is required: q = 0, r = a
1021                  */
1022                 if (r && (res = mp_int_copy(a, r)) != MP_OK)
1023                         return res;
1024
1025                 if (q)
1026                         mp_int_zero(q);
1027
1028                 return MP_OK;
1029         }
1030         else if (cmp == 0)
1031         {
1032                 /*
1033                  * If |a| = |b|, no division is required: q = 1 or -1, r = 0
1034                  */
1035                 if (r)
1036                         mp_int_zero(r);
1037
1038                 if (q)
1039                 {
1040                         mp_int_zero(q);
1041                         q->digits[0] = 1;
1042
1043                         if (sa != sb)
1044                                 q->sign = MP_NEG;
1045                 }
1046
1047                 return MP_OK;
1048         }
1049
1050         /*
1051          * When |a| > |b|, real division is required.  We need someplace to store
1052          * quotient and remainder, but q and r are allowed to be NULL or to
1053          * overlap with the inputs.
1054          */
1055         DECLARE_TEMP(2);
1056         int                     lg;
1057
1058         if ((lg = s_isp2(b)) < 0)
1059         {
1060                 if (q && b != q)
1061                 {
1062                         REQUIRE(mp_int_copy(a, q));
1063                         qout = q;
1064                 }
1065                 else
1066                 {
1067                         REQUIRE(mp_int_copy(a, TEMP(0)));
1068                         qout = TEMP(0);
1069                 }
1070
1071                 if (r && a != r)
1072                 {
1073                         REQUIRE(mp_int_copy(b, r));
1074                         rout = r;
1075                 }
1076                 else
1077                 {
1078                         REQUIRE(mp_int_copy(b, TEMP(1)));
1079                         rout = TEMP(1);
1080                 }
1081
1082                 REQUIRE(s_udiv_knuth(qout, rout));
1083         }
1084         else
1085         {
1086                 if (q)
1087                         REQUIRE(mp_int_copy(a, q));
1088                 if (r)
1089                         REQUIRE(mp_int_copy(a, r));
1090
1091                 if (q)
1092                         s_qdiv(q, (mp_size) lg);
1093                 qout = q;
1094                 if (r)
1095                         s_qmod(r, (mp_size) lg);
1096                 rout = r;
1097         }
1098
1099         /* Recompute signs for output */
1100         if (rout)
1101         {
1102                 rout->sign = sa;
1103                 if (CMPZ(rout) == 0)
1104                         rout->sign = MP_ZPOS;
1105         }
1106         if (qout)
1107         {
1108                 qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG;
1109                 if (CMPZ(qout) == 0)
1110                         qout->sign = MP_ZPOS;
1111         }
1112
1113         if (q)
1114                 REQUIRE(mp_int_copy(qout, q));
1115         if (r)
1116                 REQUIRE(mp_int_copy(rout, r));
1117         CLEANUP_TEMP();
1118         return res;
1119 }
1120
1121 mp_result
1122 mp_int_mod(mp_int a, mp_int m, mp_int c)
1123 {
1124         DECLARE_TEMP(1);
1125         mp_int          out = (m == c) ? TEMP(0) : c;
1126
1127         REQUIRE(mp_int_div(a, m, NULL, out));
1128         if (CMPZ(out) < 0)
1129         {
1130                 REQUIRE(mp_int_add(out, m, c));
1131         }
1132         else
1133         {
1134                 REQUIRE(mp_int_copy(out, c));
1135         }
1136         CLEANUP_TEMP();
1137         return MP_OK;
1138 }
1139
1140 mp_result
1141 mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r)
1142 {
1143         mpz_t           vtmp;
1144         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1145
1146         s_fake(&vtmp, value, vbuf);
1147
1148         DECLARE_TEMP(1);
1149         REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0)));
1150
1151         if (r)
1152                 (void) mp_int_to_int(TEMP(0), r);       /* can't fail */
1153
1154         CLEANUP_TEMP();
1155         return MP_OK;
1156 }
1157
1158 mp_result
1159 mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r)
1160 {
1161         assert(a != NULL && p2 >= 0 && q != r);
1162
1163         mp_result       res = MP_OK;
1164
1165         if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
1166         {
1167                 s_qdiv(q, (mp_size) p2);
1168         }
1169
1170         if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
1171         {
1172                 s_qmod(r, (mp_size) p2);
1173         }
1174
1175         return res;
1176 }
1177
1178 mp_result
1179 mp_int_expt(mp_int a, mp_small b, mp_int c)
1180 {
1181         assert(c != NULL);
1182         if (b < 0)
1183                 return MP_RANGE;
1184
1185         DECLARE_TEMP(1);
1186         REQUIRE(mp_int_copy(a, TEMP(0)));
1187
1188         (void) mp_int_set_value(c, 1);
1189         unsigned int v = labs(b);
1190
1191         while (v != 0)
1192         {
1193                 if (v & 1)
1194                 {
1195                         REQUIRE(mp_int_mul(c, TEMP(0), c));
1196                 }
1197
1198                 v >>= 1;
1199                 if (v == 0)
1200                         break;
1201
1202                 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1203         }
1204
1205         CLEANUP_TEMP();
1206         return MP_OK;
1207 }
1208
1209 mp_result
1210 mp_int_expt_value(mp_small a, mp_small b, mp_int c)
1211 {
1212         assert(c != NULL);
1213         if (b < 0)
1214                 return MP_RANGE;
1215
1216         DECLARE_TEMP(1);
1217         REQUIRE(mp_int_set_value(TEMP(0), a));
1218
1219         (void) mp_int_set_value(c, 1);
1220         unsigned int v = labs(b);
1221
1222         while (v != 0)
1223         {
1224                 if (v & 1)
1225                 {
1226                         REQUIRE(mp_int_mul(c, TEMP(0), c));
1227                 }
1228
1229                 v >>= 1;
1230                 if (v == 0)
1231                         break;
1232
1233                 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1234         }
1235
1236         CLEANUP_TEMP();
1237         return MP_OK;
1238 }
1239
1240 mp_result
1241 mp_int_expt_full(mp_int a, mp_int b, mp_int c)
1242 {
1243         assert(a != NULL && b != NULL && c != NULL);
1244         if (MP_SIGN(b) == MP_NEG)
1245                 return MP_RANGE;
1246
1247         DECLARE_TEMP(1);
1248         REQUIRE(mp_int_copy(a, TEMP(0)));
1249
1250         (void) mp_int_set_value(c, 1);
1251         for (unsigned ix = 0; ix < MP_USED(b); ++ix)
1252         {
1253                 mp_digit        d = b->digits[ix];
1254
1255                 for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx)
1256                 {
1257                         if (d & 1)
1258                         {
1259                                 REQUIRE(mp_int_mul(c, TEMP(0), c));
1260                         }
1261
1262                         d >>= 1;
1263                         if (d == 0 && ix + 1 == MP_USED(b))
1264                                 break;
1265                         REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1266                 }
1267         }
1268
1269         CLEANUP_TEMP();
1270         return MP_OK;
1271 }
1272
1273 int
1274 mp_int_compare(mp_int a, mp_int b)
1275 {
1276         assert(a != NULL && b != NULL);
1277
1278         mp_sign         sa = MP_SIGN(a);
1279
1280         if (sa == MP_SIGN(b))
1281         {
1282                 int                     cmp = s_ucmp(a, b);
1283
1284                 /*
1285                  * If they're both zero or positive, the normal comparison applies; if
1286                  * both negative, the sense is reversed.
1287                  */
1288                 if (sa == MP_ZPOS)
1289                 {
1290                         return cmp;
1291                 }
1292                 else
1293                 {
1294                         return -cmp;
1295                 }
1296         }
1297         else if (sa == MP_ZPOS)
1298         {
1299                 return 1;
1300         }
1301         else
1302         {
1303                 return -1;
1304         }
1305 }
1306
1307 int
1308 mp_int_compare_unsigned(mp_int a, mp_int b)
1309 {
1310         assert(a != NULL && b != NULL);
1311
1312         return s_ucmp(a, b);
1313 }
1314
1315 int
1316 mp_int_compare_zero(mp_int z)
1317 {
1318         assert(z != NULL);
1319
1320         if (MP_USED(z) == 1 && z->digits[0] == 0)
1321         {
1322                 return 0;
1323         }
1324         else if (MP_SIGN(z) == MP_ZPOS)
1325         {
1326                 return 1;
1327         }
1328         else
1329         {
1330                 return -1;
1331         }
1332 }
1333
1334 int
1335 mp_int_compare_value(mp_int z, mp_small value)
1336 {
1337         assert(z != NULL);
1338
1339         mp_sign         vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1340
1341         if (vsign == MP_SIGN(z))
1342         {
1343                 int                     cmp = s_vcmp(z, value);
1344
1345                 return (vsign == MP_ZPOS) ? cmp : -cmp;
1346         }
1347         else
1348         {
1349                 return (value < 0) ? 1 : -1;
1350         }
1351 }
1352
1353 int
1354 mp_int_compare_uvalue(mp_int z, mp_usmall uv)
1355 {
1356         assert(z != NULL);
1357
1358         if (MP_SIGN(z) == MP_NEG)
1359         {
1360                 return -1;
1361         }
1362         else
1363         {
1364                 return s_uvcmp(z, uv);
1365         }
1366 }
1367
1368 mp_result
1369 mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1370 {
1371         assert(a != NULL && b != NULL && c != NULL && m != NULL);
1372
1373         /* Zero moduli and negative exponents are not considered. */
1374         if (CMPZ(m) == 0)
1375                 return MP_UNDEF;
1376         if (CMPZ(b) < 0)
1377                 return MP_RANGE;
1378
1379         mp_size         um = MP_USED(m);
1380
1381         DECLARE_TEMP(3);
1382         REQUIRE(GROW(TEMP(0), 2 * um));
1383         REQUIRE(GROW(TEMP(1), 2 * um));
1384
1385         mp_int          s;
1386
1387         if (c == b || c == m)
1388         {
1389                 REQUIRE(GROW(TEMP(2), 2 * um));
1390                 s = TEMP(2);
1391         }
1392         else
1393         {
1394                 s = c;
1395         }
1396
1397         REQUIRE(mp_int_mod(a, m, TEMP(0)));
1398         REQUIRE(s_brmu(TEMP(1), m));
1399         REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s));
1400         REQUIRE(mp_int_copy(s, c));
1401
1402         CLEANUP_TEMP();
1403         return MP_OK;
1404 }
1405
1406 mp_result
1407 mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c)
1408 {
1409         mpz_t           vtmp;
1410         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1411
1412         s_fake(&vtmp, value, vbuf);
1413
1414         return mp_int_exptmod(a, &vtmp, m, c);
1415 }
1416
1417 mp_result
1418 mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c)
1419 {
1420         mpz_t           vtmp;
1421         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1422
1423         s_fake(&vtmp, value, vbuf);
1424
1425         return mp_int_exptmod(&vtmp, b, m, c);
1426 }
1427
1428 mp_result
1429 mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu,
1430                                          mp_int c)
1431 {
1432         assert(a && b && m && c);
1433
1434         /* Zero moduli and negative exponents are not considered. */
1435         if (CMPZ(m) == 0)
1436                 return MP_UNDEF;
1437         if (CMPZ(b) < 0)
1438                 return MP_RANGE;
1439
1440         DECLARE_TEMP(2);
1441         mp_size         um = MP_USED(m);
1442
1443         REQUIRE(GROW(TEMP(0), 2 * um));
1444
1445         mp_int          s;
1446
1447         if (c == b || c == m)
1448         {
1449                 REQUIRE(GROW(TEMP(1), 2 * um));
1450                 s = TEMP(1);
1451         }
1452         else
1453         {
1454                 s = c;
1455         }
1456
1457         REQUIRE(mp_int_mod(a, m, TEMP(0)));
1458         REQUIRE(s_embar(TEMP(0), b, m, mu, s));
1459         REQUIRE(mp_int_copy(s, c));
1460
1461         CLEANUP_TEMP();
1462         return MP_OK;
1463 }
1464
1465 mp_result
1466 mp_int_redux_const(mp_int m, mp_int c)
1467 {
1468         assert(m != NULL && c != NULL && m != c);
1469
1470         return s_brmu(c, m);
1471 }
1472
1473 mp_result
1474 mp_int_invmod(mp_int a, mp_int m, mp_int c)
1475 {
1476         assert(a != NULL && m != NULL && c != NULL);
1477
1478         if (CMPZ(a) == 0 || CMPZ(m) <= 0)
1479                 return MP_RANGE;
1480
1481         DECLARE_TEMP(2);
1482
1483         REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL));
1484
1485         if (mp_int_compare_value(TEMP(0), 1) != 0)
1486         {
1487                 REQUIRE(MP_UNDEF);
1488         }
1489
1490         /* It is first necessary to constrain the value to the proper range */
1491         REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1)));
1492
1493         /*
1494          * Now, if 'a' was originally negative, the value we have is actually the
1495          * magnitude of the negative representative; to get the positive value we
1496          * have to subtract from the modulus.  Otherwise, the value is okay as it
1497          * stands.
1498          */
1499         if (MP_SIGN(a) == MP_NEG)
1500         {
1501                 REQUIRE(mp_int_sub(m, TEMP(1), c));
1502         }
1503         else
1504         {
1505                 REQUIRE(mp_int_copy(TEMP(1), c));
1506         }
1507
1508         CLEANUP_TEMP();
1509         return MP_OK;
1510 }
1511
1512 /* Binary GCD algorithm due to Josef Stein, 1961 */
1513 mp_result
1514 mp_int_gcd(mp_int a, mp_int b, mp_int c)
1515 {
1516         assert(a != NULL && b != NULL && c != NULL);
1517
1518         int                     ca = CMPZ(a);
1519         int                     cb = CMPZ(b);
1520
1521         if (ca == 0 && cb == 0)
1522         {
1523                 return MP_UNDEF;
1524         }
1525         else if (ca == 0)
1526         {
1527                 return mp_int_abs(b, c);
1528         }
1529         else if (cb == 0)
1530         {
1531                 return mp_int_abs(a, c);
1532         }
1533
1534         DECLARE_TEMP(3);
1535         REQUIRE(mp_int_copy(a, TEMP(0)));
1536         REQUIRE(mp_int_copy(b, TEMP(1)));
1537
1538         TEMP(0)->sign = MP_ZPOS;
1539         TEMP(1)->sign = MP_ZPOS;
1540
1541         int                     k = 0;
1542
1543         {                                                       /* Divide out common factors of 2 from u and v */
1544                 int                     div2_u = s_dp2k(TEMP(0));
1545                 int                     div2_v = s_dp2k(TEMP(1));
1546
1547                 k = MIN(div2_u, div2_v);
1548                 s_qdiv(TEMP(0), (mp_size) k);
1549                 s_qdiv(TEMP(1), (mp_size) k);
1550         }
1551
1552         if (mp_int_is_odd(TEMP(0)))
1553         {
1554                 REQUIRE(mp_int_neg(TEMP(1), TEMP(2)));
1555         }
1556         else
1557         {
1558                 REQUIRE(mp_int_copy(TEMP(0), TEMP(2)));
1559         }
1560
1561         for (;;)
1562         {
1563                 s_qdiv(TEMP(2), s_dp2k(TEMP(2)));
1564
1565                 if (CMPZ(TEMP(2)) > 0)
1566                 {
1567                         REQUIRE(mp_int_copy(TEMP(2), TEMP(0)));
1568                 }
1569                 else
1570                 {
1571                         REQUIRE(mp_int_neg(TEMP(2), TEMP(1)));
1572                 }
1573
1574                 REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2)));
1575
1576                 if (CMPZ(TEMP(2)) == 0)
1577                         break;
1578         }
1579
1580         REQUIRE(mp_int_abs(TEMP(0), c));
1581         if (!s_qmul(c, (mp_size) k))
1582                 REQUIRE(MP_MEMORY);
1583
1584         CLEANUP_TEMP();
1585         return MP_OK;
1586 }
1587
1588 /* This is the binary GCD algorithm again, but this time we keep track of the
1589    elementary matrix operations as we go, so we can get values x and y
1590    satisfying c = ax + by.
1591  */
1592 mp_result
1593 mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y)
1594 {
1595         assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL));
1596
1597         mp_result       res = MP_OK;
1598         int                     ca = CMPZ(a);
1599         int                     cb = CMPZ(b);
1600
1601         if (ca == 0 && cb == 0)
1602         {
1603                 return MP_UNDEF;
1604         }
1605         else if (ca == 0)
1606         {
1607                 if ((res = mp_int_abs(b, c)) != MP_OK)
1608                         return res;
1609                 mp_int_zero(x);
1610                 (void) mp_int_set_value(y, 1);
1611                 return MP_OK;
1612         }
1613         else if (cb == 0)
1614         {
1615                 if ((res = mp_int_abs(a, c)) != MP_OK)
1616                         return res;
1617                 (void) mp_int_set_value(x, 1);
1618                 mp_int_zero(y);
1619                 return MP_OK;
1620         }
1621
1622         /*
1623          * Initialize temporaries: A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7
1624          */
1625         DECLARE_TEMP(8);
1626         REQUIRE(mp_int_set_value(TEMP(0), 1));
1627         REQUIRE(mp_int_set_value(TEMP(3), 1));
1628         REQUIRE(mp_int_copy(a, TEMP(4)));
1629         REQUIRE(mp_int_copy(b, TEMP(5)));
1630
1631         /* We will work with absolute values here */
1632         TEMP(4)->sign = MP_ZPOS;
1633         TEMP(5)->sign = MP_ZPOS;
1634
1635         int                     k = 0;
1636
1637         {                                                       /* Divide out common factors of 2 from u and v */
1638                 int                     div2_u = s_dp2k(TEMP(4)),
1639                                         div2_v = s_dp2k(TEMP(5));
1640
1641                 k = MIN(div2_u, div2_v);
1642                 s_qdiv(TEMP(4), k);
1643                 s_qdiv(TEMP(5), k);
1644         }
1645
1646         REQUIRE(mp_int_copy(TEMP(4), TEMP(6)));
1647         REQUIRE(mp_int_copy(TEMP(5), TEMP(7)));
1648
1649         for (;;)
1650         {
1651                 while (mp_int_is_even(TEMP(4)))
1652                 {
1653                         s_qdiv(TEMP(4), 1);
1654
1655                         if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1)))
1656                         {
1657                                 REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0)));
1658                                 REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1)));
1659                         }
1660
1661                         s_qdiv(TEMP(0), 1);
1662                         s_qdiv(TEMP(1), 1);
1663                 }
1664
1665                 while (mp_int_is_even(TEMP(5)))
1666                 {
1667                         s_qdiv(TEMP(5), 1);
1668
1669                         if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3)))
1670                         {
1671                                 REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2)));
1672                                 REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3)));
1673                         }
1674
1675                         s_qdiv(TEMP(2), 1);
1676                         s_qdiv(TEMP(3), 1);
1677                 }
1678
1679                 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0)
1680                 {
1681                         REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4)));
1682                         REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0)));
1683                         REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1)));
1684                 }
1685                 else
1686                 {
1687                         REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5)));
1688                         REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1689                         REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3)));
1690                 }
1691
1692                 if (CMPZ(TEMP(4)) == 0)
1693                 {
1694                         if (x)
1695                                 REQUIRE(mp_int_copy(TEMP(2), x));
1696                         if (y)
1697                                 REQUIRE(mp_int_copy(TEMP(3), y));
1698                         if (c)
1699                         {
1700                                 if (!s_qmul(TEMP(5), k))
1701                                 {
1702                                         REQUIRE(MP_MEMORY);
1703                                 }
1704                                 REQUIRE(mp_int_copy(TEMP(5), c));
1705                         }
1706
1707                         break;
1708                 }
1709         }
1710
1711         CLEANUP_TEMP();
1712         return MP_OK;
1713 }
1714
1715 mp_result
1716 mp_int_lcm(mp_int a, mp_int b, mp_int c)
1717 {
1718         assert(a != NULL && b != NULL && c != NULL);
1719
1720         /*
1721          * Since a * b = gcd(a, b) * lcm(a, b), we can compute lcm(a, b) = (a /
1722          * gcd(a, b)) * b.
1723          *
1724          * This formulation insures everything works even if the input variables
1725          * share space.
1726          */
1727         DECLARE_TEMP(1);
1728         REQUIRE(mp_int_gcd(a, b, TEMP(0)));
1729         REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL));
1730         REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0)));
1731         REQUIRE(mp_int_copy(TEMP(0), c));
1732
1733         CLEANUP_TEMP();
1734         return MP_OK;
1735 }
1736
1737 bool
1738 mp_int_divisible_value(mp_int a, mp_small v)
1739 {
1740         mp_small        rem = 0;
1741
1742         if (mp_int_div_value(a, v, NULL, &rem) != MP_OK)
1743         {
1744                 return false;
1745         }
1746         return rem == 0;
1747 }
1748
1749 int
1750 mp_int_is_pow2(mp_int z)
1751 {
1752         assert(z != NULL);
1753
1754         return s_isp2(z);
1755 }
1756
1757 /* Implementation of Newton's root finding method, based loosely on a patch
1758    contributed by Hal Finkel <half@halssoftware.com>
1759    modified by M. J. Fromberger.
1760  */
1761 mp_result
1762 mp_int_root(mp_int a, mp_small b, mp_int c)
1763 {
1764         assert(a != NULL && c != NULL && b > 0);
1765
1766         if (b == 1)
1767         {
1768                 return mp_int_copy(a, c);
1769         }
1770         bool            flips = false;
1771
1772         if (MP_SIGN(a) == MP_NEG)
1773         {
1774                 if (b % 2 == 0)
1775                 {
1776                         return MP_UNDEF;        /* root does not exist for negative a with
1777                                                                  * even b */
1778                 }
1779                 else
1780                 {
1781                         flips = true;
1782                 }
1783         }
1784
1785         DECLARE_TEMP(5);
1786         REQUIRE(mp_int_copy(a, TEMP(0)));
1787         REQUIRE(mp_int_copy(a, TEMP(1)));
1788         TEMP(0)->sign = MP_ZPOS;
1789         TEMP(1)->sign = MP_ZPOS;
1790
1791         for (;;)
1792         {
1793                 REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2)));
1794
1795                 if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0)
1796                         break;
1797
1798                 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1799                 REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3)));
1800                 REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3)));
1801                 REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL));
1802                 REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4)));
1803
1804                 if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0)
1805                 {
1806                         REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4)));
1807                 }
1808                 REQUIRE(mp_int_copy(TEMP(4), TEMP(1)));
1809         }
1810
1811         REQUIRE(mp_int_copy(TEMP(1), c));
1812
1813         /* If the original value of a was negative, flip the output sign. */
1814         if (flips)
1815                 (void) mp_int_neg(c, c);        /* cannot fail */
1816
1817         CLEANUP_TEMP();
1818         return MP_OK;
1819 }
1820
1821 mp_result
1822 mp_int_to_int(mp_int z, mp_small *out)
1823 {
1824         assert(z != NULL);
1825
1826         /* Make sure the value is representable as a small integer */
1827         mp_sign         sz = MP_SIGN(z);
1828
1829         if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1830                 mp_int_compare_value(z, MP_SMALL_MIN) < 0)
1831         {
1832                 return MP_RANGE;
1833         }
1834
1835         mp_usmall       uz = MP_USED(z);
1836         mp_digit   *dz = MP_DIGITS(z) + uz - 1;
1837         mp_small        uv = 0;
1838
1839         while (uz > 0)
1840         {
1841                 uv <<= MP_DIGIT_BIT / 2;
1842                 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1843                 --uz;
1844         }
1845
1846         if (out)
1847                 *out = (mp_small) ((sz == MP_NEG) ? -uv : uv);
1848
1849         return MP_OK;
1850 }
1851
1852 mp_result
1853 mp_int_to_uint(mp_int z, mp_usmall *out)
1854 {
1855         assert(z != NULL);
1856
1857         /* Make sure the value is representable as an unsigned small integer */
1858         mp_size         sz = MP_SIGN(z);
1859
1860         if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0)
1861         {
1862                 return MP_RANGE;
1863         }
1864
1865         mp_size         uz = MP_USED(z);
1866         mp_digit   *dz = MP_DIGITS(z) + uz - 1;
1867         mp_usmall       uv = 0;
1868
1869         while (uz > 0)
1870         {
1871                 uv <<= MP_DIGIT_BIT / 2;
1872                 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1873                 --uz;
1874         }
1875
1876         if (out)
1877                 *out = uv;
1878
1879         return MP_OK;
1880 }
1881
1882 mp_result
1883 mp_int_to_string(mp_int z, mp_size radix, char *str, int limit)
1884 {
1885         assert(z != NULL && str != NULL && limit >= 2);
1886         assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1887
1888         int                     cmp = 0;
1889
1890         if (CMPZ(z) == 0)
1891         {
1892                 *str++ = s_val2ch(0, 1);
1893         }
1894         else
1895         {
1896                 mp_result       res;
1897                 mpz_t           tmp;
1898                 char       *h,
1899                                    *t;
1900
1901                 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1902                         return res;
1903
1904                 if (MP_SIGN(z) == MP_NEG)
1905                 {
1906                         *str++ = '-';
1907                         --limit;
1908                 }
1909                 h = str;
1910
1911                 /* Generate digits in reverse order until finished or limit reached */
1912                 for ( /* */ ; limit > 0; --limit)
1913                 {
1914                         mp_digit        d;
1915
1916                         if ((cmp = CMPZ(&tmp)) == 0)
1917                                 break;
1918
1919                         d = s_ddiv(&tmp, (mp_digit) radix);
1920                         *str++ = s_val2ch(d, 1);
1921                 }
1922                 t = str - 1;
1923
1924                 /* Put digits back in correct output order */
1925                 while (h < t)
1926                 {
1927                         char            tc = *h;
1928
1929                         *h++ = *t;
1930                         *t-- = tc;
1931                 }
1932
1933                 mp_int_clear(&tmp);
1934         }
1935
1936         *str = '\0';
1937         if (cmp == 0)
1938         {
1939                 return MP_OK;
1940         }
1941         else
1942         {
1943                 return MP_TRUNC;
1944         }
1945 }
1946
1947 mp_result
1948 mp_int_string_len(mp_int z, mp_size radix)
1949 {
1950         assert(z != NULL);
1951         assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1952
1953         int                     len = s_outlen(z, radix) + 1;   /* for terminator */
1954
1955         /* Allow for sign marker on negatives */
1956         if (MP_SIGN(z) == MP_NEG)
1957                 len += 1;
1958
1959         return len;
1960 }
1961
1962 /* Read zero-terminated string into z */
1963 mp_result
1964 mp_int_read_string(mp_int z, mp_size radix, const char *str)
1965 {
1966         return mp_int_read_cstring(z, radix, str, NULL);
1967 }
1968
1969 mp_result
1970 mp_int_read_cstring(mp_int z, mp_size radix, const char *str,
1971                                         char **end)
1972 {
1973         assert(z != NULL && str != NULL);
1974         assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1975
1976         /* Skip leading whitespace */
1977         while (isspace((unsigned char) *str))
1978                 ++str;
1979
1980         /* Handle leading sign tag (+/-, positive default) */
1981         switch (*str)
1982         {
1983                 case '-':
1984                         z->sign = MP_NEG;
1985                         ++str;
1986                         break;
1987                 case '+':
1988                         ++str;                          /* fallthrough */
1989                 default:
1990                         z->sign = MP_ZPOS;
1991                         break;
1992         }
1993
1994         /* Skip leading zeroes */
1995         int                     ch;
1996
1997         while ((ch = s_ch2val(*str, radix)) == 0)
1998                 ++str;
1999
2000         /* Make sure there is enough space for the value */
2001         if (!s_pad(z, s_inlen(strlen(str), radix)))
2002                 return MP_MEMORY;
2003
2004         z->used = 1;
2005         z->digits[0] = 0;
2006
2007         while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0))
2008         {
2009                 s_dmul(z, (mp_digit) radix);
2010                 s_dadd(z, (mp_digit) ch);
2011                 ++str;
2012         }
2013
2014         CLAMP(z);
2015
2016         /* Override sign for zero, even if negative specified. */
2017         if (CMPZ(z) == 0)
2018                 z->sign = MP_ZPOS;
2019
2020         if (end != NULL)
2021                 *end = unconstify(char *, str);
2022
2023         /*
2024          * Return a truncation error if the string has unprocessed characters
2025          * remaining, so the caller can tell if the whole string was done
2026          */
2027         if (*str != '\0')
2028         {
2029                 return MP_TRUNC;
2030         }
2031         else
2032         {
2033                 return MP_OK;
2034         }
2035 }
2036
2037 mp_result
2038 mp_int_count_bits(mp_int z)
2039 {
2040         assert(z != NULL);
2041
2042         mp_size         uz = MP_USED(z);
2043
2044         if (uz == 1 && z->digits[0] == 0)
2045                 return 1;
2046
2047         --uz;
2048         mp_size         nbits = uz * MP_DIGIT_BIT;
2049         mp_digit        d = z->digits[uz];
2050
2051         while (d != 0)
2052         {
2053                 d >>= 1;
2054                 ++nbits;
2055         }
2056
2057         return nbits;
2058 }
2059
2060 mp_result
2061 mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
2062 {
2063         static const int PAD_FOR_2C = 1;
2064
2065         assert(z != NULL && buf != NULL);
2066
2067         int                     limpos = limit;
2068         mp_result       res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
2069
2070         if (MP_SIGN(z) == MP_NEG)
2071                 s_2comp(buf, limpos);
2072
2073         return res;
2074 }
2075
2076 mp_result
2077 mp_int_read_binary(mp_int z, unsigned char *buf, int len)
2078 {
2079         assert(z != NULL && buf != NULL && len > 0);
2080
2081         /* Figure out how many digits are needed to represent this value */
2082         mp_size         need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2083
2084         if (!s_pad(z, need))
2085                 return MP_MEMORY;
2086
2087         mp_int_zero(z);
2088
2089         /*
2090          * If the high-order bit is set, take the 2's complement before reading
2091          * the value (it will be restored afterward)
2092          */
2093         if (buf[0] >> (CHAR_BIT - 1))
2094         {
2095                 z->sign = MP_NEG;
2096                 s_2comp(buf, len);
2097         }
2098
2099         mp_digit   *dz = MP_DIGITS(z);
2100         unsigned char *tmp = buf;
2101
2102         for (int i = len; i > 0; --i, ++tmp)
2103         {
2104                 s_qmul(z, (mp_size) CHAR_BIT);
2105                 *dz |= *tmp;
2106         }
2107
2108         /* Restore 2's complement if we took it before */
2109         if (MP_SIGN(z) == MP_NEG)
2110                 s_2comp(buf, len);
2111
2112         return MP_OK;
2113 }
2114
2115 mp_result
2116 mp_int_binary_len(mp_int z)
2117 {
2118         mp_result       res = mp_int_count_bits(z);
2119
2120         if (res <= 0)
2121                 return res;
2122
2123         int                     bytes = mp_int_unsigned_len(z);
2124
2125         /*
2126          * If the highest-order bit falls exactly on a byte boundary, we need to
2127          * pad with an extra byte so that the sign will be read correctly when
2128          * reading it back in.
2129          */
2130         if (bytes * CHAR_BIT == res)
2131                 ++bytes;
2132
2133         return bytes;
2134 }
2135
2136 mp_result
2137 mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
2138 {
2139         static const int NO_PADDING = 0;
2140
2141         assert(z != NULL && buf != NULL);
2142
2143         return s_tobin(z, buf, &limit, NO_PADDING);
2144 }
2145
2146 mp_result
2147 mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
2148 {
2149         assert(z != NULL && buf != NULL && len > 0);
2150
2151         /* Figure out how many digits are needed to represent this value */
2152         mp_size         need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2153
2154         if (!s_pad(z, need))
2155                 return MP_MEMORY;
2156
2157         mp_int_zero(z);
2158
2159         unsigned char *tmp = buf;
2160
2161         for (int i = len; i > 0; --i, ++tmp)
2162         {
2163                 (void) s_qmul(z, CHAR_BIT);
2164                 *MP_DIGITS(z) |= *tmp;
2165         }
2166
2167         return MP_OK;
2168 }
2169
2170 mp_result
2171 mp_int_unsigned_len(mp_int z)
2172 {
2173         mp_result       res = mp_int_count_bits(z);
2174
2175         if (res <= 0)
2176                 return res;
2177
2178         int                     bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2179
2180         return bytes;
2181 }
2182
2183 const char *
2184 mp_error_string(mp_result res)
2185 {
2186         if (res > 0)
2187                 return s_unknown_err;
2188
2189         res = -res;
2190         int                     ix;
2191
2192         for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
2193                 ;
2194
2195         if (s_error_msg[ix] != NULL)
2196         {
2197                 return s_error_msg[ix];
2198         }
2199         else
2200         {
2201                 return s_unknown_err;
2202         }
2203 }
2204
2205 /*------------------------------------------------------------------------*/
2206 /* Private functions for internal use.  These make assumptions.           */
2207
2208 #if IMATH_DEBUG
2209 static const mp_digit fill = (mp_digit) 0xdeadbeefabad1dea;
2210 #endif
2211
2212 static mp_digit *
2213 s_alloc(mp_size num)
2214 {
2215         mp_digit   *out = px_alloc(num * sizeof(mp_digit));
2216
2217         assert(out != NULL);
2218
2219 #if IMATH_DEBUG
2220         for (mp_size ix = 0; ix < num; ++ix)
2221                 out[ix] = fill;
2222 #endif
2223         return out;
2224 }
2225
2226 static mp_digit *
2227 s_realloc(mp_digit *old, mp_size osize, mp_size nsize)
2228 {
2229 #if IMATH_DEBUG
2230         mp_digit   *new = s_alloc(nsize);
2231
2232         assert(new != NULL);
2233
2234         for (mp_size ix = 0; ix < nsize; ++ix)
2235                 new[ix] = fill;
2236         memcpy(new, old, osize * sizeof(mp_digit));
2237 #else
2238         mp_digit   *new = px_realloc(old, nsize * sizeof(mp_digit));
2239
2240         assert(new != NULL);
2241 #endif
2242
2243         return new;
2244 }
2245
2246 static void
2247 s_free(void *ptr)
2248 {
2249         px_free(ptr);
2250 }
2251
2252 static bool
2253 s_pad(mp_int z, mp_size min)
2254 {
2255         if (MP_ALLOC(z) < min)
2256         {
2257                 mp_size         nsize = s_round_prec(min);
2258                 mp_digit   *tmp;
2259
2260                 if (z->digits == &(z->single))
2261                 {
2262                         if ((tmp = s_alloc(nsize)) == NULL)
2263                                 return false;
2264                         tmp[0] = z->single;
2265                 }
2266                 else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL)
2267                 {
2268                         return false;
2269                 }
2270
2271                 z->digits = tmp;
2272                 z->alloc = nsize;
2273         }
2274
2275         return true;
2276 }
2277
2278 /* Note: This will not work correctly when value == MP_SMALL_MIN */
2279 static void
2280 s_fake(mp_int z, mp_small value, mp_digit vbuf[])
2281 {
2282         mp_usmall       uv = (mp_usmall) (value < 0) ? -value : value;
2283
2284         s_ufake(z, uv, vbuf);
2285         if (value < 0)
2286                 z->sign = MP_NEG;
2287 }
2288
2289 static void
2290 s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[])
2291 {
2292         mp_size         ndig = (mp_size) s_uvpack(value, vbuf);
2293
2294         z->used = ndig;
2295         z->alloc = MP_VALUE_DIGITS(value);
2296         z->sign = MP_ZPOS;
2297         z->digits = vbuf;
2298 }
2299
2300 static int
2301 s_cdig(mp_digit *da, mp_digit *db, mp_size len)
2302 {
2303         mp_digit   *dat = da + len - 1,
2304                            *dbt = db + len - 1;
2305
2306         for ( /* */ ; len != 0; --len, --dat, --dbt)
2307         {
2308                 if (*dat > *dbt)
2309                 {
2310                         return 1;
2311                 }
2312                 else if (*dat < *dbt)
2313                 {
2314                         return -1;
2315                 }
2316         }
2317
2318         return 0;
2319 }
2320
2321 static int
2322 s_uvpack(mp_usmall uv, mp_digit t[])
2323 {
2324         int                     ndig = 0;
2325
2326         if (uv == 0)
2327                 t[ndig++] = 0;
2328         else
2329         {
2330                 while (uv != 0)
2331                 {
2332                         t[ndig++] = (mp_digit) uv;
2333                         uv >>= MP_DIGIT_BIT / 2;
2334                         uv >>= MP_DIGIT_BIT / 2;
2335                 }
2336         }
2337
2338         return ndig;
2339 }
2340
2341 static int
2342 s_ucmp(mp_int a, mp_int b)
2343 {
2344         mp_size         ua = MP_USED(a),
2345                                 ub = MP_USED(b);
2346
2347         if (ua > ub)
2348         {
2349                 return 1;
2350         }
2351         else if (ub > ua)
2352         {
2353                 return -1;
2354         }
2355         else
2356         {
2357                 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2358         }
2359 }
2360
2361 static int
2362 s_vcmp(mp_int a, mp_small v)
2363 {
2364 #ifdef _MSC_VER
2365 #pragma warning(push)
2366 #pragma warning(disable: 4146)
2367 #endif
2368         mp_usmall       uv = (v < 0) ? -(mp_usmall) v : (mp_usmall) v;
2369 #ifdef _MSC_VER
2370 #pragma warning(pop)
2371 #endif
2372
2373         return s_uvcmp(a, uv);
2374 }
2375
2376 static int
2377 s_uvcmp(mp_int a, mp_usmall uv)
2378 {
2379         mpz_t           vtmp;
2380         mp_digit        vdig[MP_VALUE_DIGITS(uv)];
2381
2382         s_ufake(&vtmp, uv, vdig);
2383         return s_ucmp(a, &vtmp);
2384 }
2385
2386 static mp_digit
2387 s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2388            mp_size size_b)
2389 {
2390         mp_size         pos;
2391         mp_word         w = 0;
2392
2393         /* Insure that da is the longer of the two to simplify later code */
2394         if (size_b > size_a)
2395         {
2396                 SWAP(mp_digit *, da, db);
2397                 SWAP(mp_size, size_a, size_b);
2398         }
2399
2400         /* Add corresponding digits until the shorter number runs out */
2401         for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2402         {
2403                 w = w + (mp_word) *da + (mp_word) *db;
2404                 *dc = LOWER_HALF(w);
2405                 w = UPPER_HALF(w);
2406         }
2407
2408         /* Propagate carries as far as necessary */
2409         for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2410         {
2411                 w = w + *da;
2412
2413                 *dc = LOWER_HALF(w);
2414                 w = UPPER_HALF(w);
2415         }
2416
2417         /* Return carry out */
2418         return (mp_digit) w;
2419 }
2420
2421 static void
2422 s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2423            mp_size size_b)
2424 {
2425         mp_size         pos;
2426         mp_word         w = 0;
2427
2428         /* We assume that |a| >= |b| so this should definitely hold */
2429         assert(size_a >= size_b);
2430
2431         /* Subtract corresponding digits and propagate borrow */
2432         for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2433         {
2434                 w = ((mp_word) MP_DIGIT_MAX + 1 +       /* MP_RADIX */
2435                          (mp_word) *da) -
2436                         w - (mp_word) *db;
2437
2438                 *dc = LOWER_HALF(w);
2439                 w = (UPPER_HALF(w) == 0);
2440         }
2441
2442         /* Finish the subtraction for remaining upper digits of da */
2443         for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2444         {
2445                 w = ((mp_word) MP_DIGIT_MAX + 1 +       /* MP_RADIX */
2446                          (mp_word) *da) -
2447                         w;
2448
2449                 *dc = LOWER_HALF(w);
2450                 w = (UPPER_HALF(w) == 0);
2451         }
2452
2453         /* If there is a borrow out at the end, it violates the precondition */
2454         assert(w == 0);
2455 }
2456
2457 static int
2458 s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2459            mp_size size_b)
2460 {
2461         mp_size         bot_size;
2462
2463         /* Make sure b is the smaller of the two input values */
2464         if (size_b > size_a)
2465         {
2466                 SWAP(mp_digit *, da, db);
2467                 SWAP(mp_size, size_a, size_b);
2468         }
2469
2470         /*
2471          * Insure that the bottom is the larger half in an odd-length split; the
2472          * code below relies on this being true.
2473          */
2474         bot_size = (size_a + 1) / 2;
2475
2476         /*
2477          * If the values are big enough to bother with recursion, use the
2478          * Karatsuba algorithm to compute the product; otherwise use the normal
2479          * multiplication algorithm
2480          */
2481         if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size)
2482         {
2483                 mp_digit   *t1,
2484                                    *t2,
2485                                    *t3,
2486                                         carry;
2487
2488                 mp_digit   *a_top = da + bot_size;
2489                 mp_digit   *b_top = db + bot_size;
2490
2491                 mp_size         at_size = size_a - bot_size;
2492                 mp_size         bt_size = size_b - bot_size;
2493                 mp_size         buf_size = 2 * bot_size;
2494
2495                 /*
2496                  * Do a single allocation for all three temporary buffers needed; each
2497                  * buffer must be big enough to hold the product of two bottom halves,
2498                  * and one buffer needs space for the completed product; twice the
2499                  * space is plenty.
2500                  */
2501                 if ((t1 = s_alloc(4 * buf_size)) == NULL)
2502                         return 0;
2503                 t2 = t1 + buf_size;
2504                 t3 = t2 + buf_size;
2505                 ZERO(t1, 4 * buf_size);
2506
2507                 /*
2508                  * t1 and t2 are initially used as temporaries to compute the inner
2509                  * product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
2510                  */
2511                 carry = s_uadd(da, a_top, t1, bot_size, at_size);       /* t1 = a1 + a0 */
2512                 t1[bot_size] = carry;
2513
2514                 carry = s_uadd(db, b_top, t2, bot_size, bt_size);       /* t2 = b1 + b0 */
2515                 t2[bot_size] = carry;
2516
2517                 (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1);  /* t3 = t1 * t2 */
2518
2519                 /*
2520                  * Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so
2521                  * that we're left with only the pieces we want:  t3 = a1b0 + a0b1
2522                  */
2523                 ZERO(t1, buf_size);
2524                 ZERO(t2, buf_size);
2525                 (void) s_kmul(da, db, t1, bot_size, bot_size);  /* t1 = a0 * b0 */
2526                 (void) s_kmul(a_top, b_top, t2, at_size, bt_size);      /* t2 = a1 * b1 */
2527
2528                 /* Subtract out t1 and t2 to get the inner product */
2529                 s_usub(t3, t1, t3, buf_size + 2, buf_size);
2530                 s_usub(t3, t2, t3, buf_size + 2, buf_size);
2531
2532                 /* Assemble the output value */
2533                 COPY(t1, dc, buf_size);
2534                 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
2535                 assert(carry == 0);
2536
2537                 carry =
2538                         s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
2539                 assert(carry == 0);
2540
2541                 s_free(t1);                             /* note t2 and t3 are just internal pointers
2542                                                                  * to t1 */
2543         }
2544         else
2545         {
2546                 s_umul(da, db, dc, size_a, size_b);
2547         }
2548
2549         return 1;
2550 }
2551
2552 static void
2553 s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2554            mp_size size_b)
2555 {
2556         mp_size         a,
2557                                 b;
2558         mp_word         w;
2559
2560         for (a = 0; a < size_a; ++a, ++dc, ++da)
2561         {
2562                 mp_digit   *dct = dc;
2563                 mp_digit   *dbt = db;
2564
2565                 if (*da == 0)
2566                         continue;
2567
2568                 w = 0;
2569                 for (b = 0; b < size_b; ++b, ++dbt, ++dct)
2570                 {
2571                         w = (mp_word) *da * (mp_word) *dbt + w + (mp_word) *dct;
2572
2573                         *dct = LOWER_HALF(w);
2574                         w = UPPER_HALF(w);
2575                 }
2576
2577                 *dct = (mp_digit) w;
2578         }
2579 }
2580
2581 static int
2582 s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2583 {
2584         if (multiply_threshold && size_a > multiply_threshold)
2585         {
2586                 mp_size         bot_size = (size_a + 1) / 2;
2587                 mp_digit   *a_top = da + bot_size;
2588                 mp_digit   *t1,
2589                                    *t2,
2590                                    *t3,
2591                                         carry PG_USED_FOR_ASSERTS_ONLY;
2592                 mp_size         at_size = size_a - bot_size;
2593                 mp_size         buf_size = 2 * bot_size;
2594
2595                 if ((t1 = s_alloc(4 * buf_size)) == NULL)
2596                         return 0;
2597                 t2 = t1 + buf_size;
2598                 t3 = t2 + buf_size;
2599                 ZERO(t1, 4 * buf_size);
2600
2601                 (void) s_ksqr(da, t1, bot_size);        /* t1 = a0 ^ 2 */
2602                 (void) s_ksqr(a_top, t2, at_size);      /* t2 = a1 ^ 2 */
2603
2604                 (void) s_kmul(da, a_top, t3, bot_size, at_size);        /* t3 = a0 * a1 */
2605
2606                 /* Quick multiply t3 by 2, shifting left (can't overflow) */
2607                 {
2608                         int                     i,
2609                                                 top = bot_size + at_size;
2610                         mp_word         w,
2611                                                 save = 0;
2612
2613                         for (i = 0; i < top; ++i)
2614                         {
2615                                 w = t3[i];
2616                                 w = (w << 1) | save;
2617                                 t3[i] = LOWER_HALF(w);
2618                                 save = UPPER_HALF(w);
2619                         }
2620                         t3[i] = LOWER_HALF(save);
2621                 }
2622
2623                 /* Assemble the output value */
2624                 COPY(t1, dc, 2 * bot_size);
2625                 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
2626                 assert(carry == 0);
2627
2628                 carry =
2629                         s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
2630                 assert(carry == 0);
2631
2632                 s_free(t1);                             /* note that t2 and t2 are internal pointers
2633                                                                  * only */
2634
2635         }
2636         else
2637         {
2638                 s_usqr(da, dc, size_a);
2639         }
2640
2641         return 1;
2642 }
2643
2644 static void
2645 s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2646 {
2647         mp_size         i,
2648                                 j;
2649         mp_word         w;
2650
2651         for (i = 0; i < size_a; ++i, dc += 2, ++da)
2652         {
2653                 mp_digit   *dct = dc,
2654                                    *dat = da;
2655
2656                 if (*da == 0)
2657                         continue;
2658
2659                 /* Take care of the first digit, no rollover */
2660                 w = (mp_word) *dat * (mp_word) *dat + (mp_word) *dct;
2661                 *dct = LOWER_HALF(w);
2662                 w = UPPER_HALF(w);
2663                 ++dat;
2664                 ++dct;
2665
2666                 for (j = i + 1; j < size_a; ++j, ++dat, ++dct)
2667                 {
2668                         mp_word         t = (mp_word) *da * (mp_word) *dat;
2669                         mp_word         u = w + (mp_word) *dct,
2670                                                 ov = 0;
2671
2672                         /* Check if doubling t will overflow a word */
2673                         if (HIGH_BIT_SET(t))
2674                                 ov = 1;
2675
2676                         w = t + t;
2677
2678                         /* Check if adding u to w will overflow a word */
2679                         if (ADD_WILL_OVERFLOW(w, u))
2680                                 ov = 1;
2681
2682                         w += u;
2683
2684                         *dct = LOWER_HALF(w);
2685                         w = UPPER_HALF(w);
2686                         if (ov)
2687                         {
2688                                 w += MP_DIGIT_MAX;      /* MP_RADIX */
2689                                 ++w;
2690                         }
2691                 }
2692
2693                 w = w + *dct;
2694                 *dct = (mp_digit) w;
2695                 while ((w = UPPER_HALF(w)) != 0)
2696                 {
2697                         ++dct;
2698                         w = w + *dct;
2699                         *dct = LOWER_HALF(w);
2700                 }
2701
2702                 assert(w == 0);
2703         }
2704 }
2705
2706 static void
2707 s_dadd(mp_int a, mp_digit b)
2708 {
2709         mp_word         w = 0;
2710         mp_digit   *da = MP_DIGITS(a);
2711         mp_size         ua = MP_USED(a);
2712
2713         w = (mp_word) *da + b;
2714         *da++ = LOWER_HALF(w);
2715         w = UPPER_HALF(w);
2716
2717         for (ua -= 1; ua > 0; --ua, ++da)
2718         {
2719                 w = (mp_word) *da + w;
2720
2721                 *da = LOWER_HALF(w);
2722                 w = UPPER_HALF(w);
2723         }
2724
2725         if (w)
2726         {
2727                 *da = (mp_digit) w;
2728                 a->used += 1;
2729         }
2730 }
2731
2732 static void
2733 s_dmul(mp_int a, mp_digit b)
2734 {
2735         mp_word         w = 0;
2736         mp_digit   *da = MP_DIGITS(a);
2737         mp_size         ua = MP_USED(a);
2738
2739         while (ua > 0)
2740         {
2741                 w = (mp_word) *da * b + w;
2742                 *da++ = LOWER_HALF(w);
2743                 w = UPPER_HALF(w);
2744                 --ua;
2745         }
2746
2747         if (w)
2748         {
2749                 *da = (mp_digit) w;
2750                 a->used += 1;
2751         }
2752 }
2753
2754 static void
2755 s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a)
2756 {
2757         mp_word         w = 0;
2758
2759         while (size_a > 0)
2760         {
2761                 w = (mp_word) *da++ * (mp_word) b + w;
2762
2763                 *dc++ = LOWER_HALF(w);
2764                 w = UPPER_HALF(w);
2765                 --size_a;
2766         }
2767
2768         if (w)
2769                 *dc = LOWER_HALF(w);
2770 }
2771
2772 static mp_digit
2773 s_ddiv(mp_int a, mp_digit b)
2774 {
2775         mp_word         w = 0,
2776                                 qdigit;
2777         mp_size         ua = MP_USED(a);
2778         mp_digit   *da = MP_DIGITS(a) + ua - 1;
2779
2780         for ( /* */ ; ua > 0; --ua, --da)
2781         {
2782                 w = (w << MP_DIGIT_BIT) | *da;
2783
2784                 if (w >= b)
2785                 {
2786                         qdigit = w / b;
2787                         w = w % b;
2788                 }
2789                 else
2790                 {
2791                         qdigit = 0;
2792                 }
2793
2794                 *da = (mp_digit) qdigit;
2795         }
2796
2797         CLAMP(a);
2798         return (mp_digit) w;
2799 }
2800
2801 static void
2802 s_qdiv(mp_int z, mp_size p2)
2803 {
2804         mp_size         ndig = p2 / MP_DIGIT_BIT,
2805                                 nbits = p2 % MP_DIGIT_BIT;
2806         mp_size         uz = MP_USED(z);
2807
2808         if (ndig)
2809         {
2810                 mp_size         mark;
2811                 mp_digit   *to,
2812                                    *from;
2813
2814                 if (ndig >= uz)
2815                 {
2816                         mp_int_zero(z);
2817                         return;
2818                 }
2819
2820                 to = MP_DIGITS(z);
2821                 from = to + ndig;
2822
2823                 for (mark = ndig; mark < uz; ++mark)
2824                 {
2825                         *to++ = *from++;
2826                 }
2827
2828                 z->used = uz - ndig;
2829         }
2830
2831         if (nbits)
2832         {
2833                 mp_digit        d = 0,
2834                                    *dz,
2835                                         save;
2836                 mp_size         up = MP_DIGIT_BIT - nbits;
2837
2838                 uz = MP_USED(z);
2839                 dz = MP_DIGITS(z) + uz - 1;
2840
2841                 for ( /* */ ; uz > 0; --uz, --dz)
2842                 {
2843                         save = *dz;
2844
2845                         *dz = (*dz >> nbits) | (d << up);
2846                         d = save;
2847                 }
2848
2849                 CLAMP(z);
2850         }
2851
2852         if (MP_USED(z) == 1 && z->digits[0] == 0)
2853                 z->sign = MP_ZPOS;
2854 }
2855
2856 static void
2857 s_qmod(mp_int z, mp_size p2)
2858 {
2859         mp_size         start = p2 / MP_DIGIT_BIT + 1,
2860                                 rest = p2 % MP_DIGIT_BIT;
2861         mp_size         uz = MP_USED(z);
2862         mp_digit        mask = (1u << rest) - 1;
2863
2864         if (start <= uz)
2865         {
2866                 z->used = start;
2867                 z->digits[start - 1] &= mask;
2868                 CLAMP(z);
2869         }
2870 }
2871
2872 static int
2873 s_qmul(mp_int z, mp_size p2)
2874 {
2875         mp_size         uz,
2876                                 need,
2877                                 rest,
2878                                 extra,
2879                                 i;
2880         mp_digit   *from,
2881                            *to,
2882                                 d;
2883
2884         if (p2 == 0)
2885                 return 1;
2886
2887         uz = MP_USED(z);
2888         need = p2 / MP_DIGIT_BIT;
2889         rest = p2 % MP_DIGIT_BIT;
2890
2891         /*
2892          * Figure out if we need an extra digit at the top end; this occurs if the
2893          * topmost `rest' bits of the high-order digit of z are not zero, meaning
2894          * they will be shifted off the end if not preserved
2895          */
2896         extra = 0;
2897         if (rest != 0)
2898         {
2899                 mp_digit   *dz = MP_DIGITS(z) + uz - 1;
2900
2901                 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
2902                         extra = 1;
2903         }
2904
2905         if (!s_pad(z, uz + need + extra))
2906                 return 0;
2907
2908         /*
2909          * If we need to shift by whole digits, do that in one pass, then to back
2910          * and shift by partial digits.
2911          */
2912         if (need > 0)
2913         {
2914                 from = MP_DIGITS(z) + uz - 1;
2915                 to = from + need;
2916
2917                 for (i = 0; i < uz; ++i)
2918                         *to-- = *from--;
2919
2920                 ZERO(MP_DIGITS(z), need);
2921                 uz += need;
2922         }
2923
2924         if (rest)
2925         {
2926                 d = 0;
2927                 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from)
2928                 {
2929                         mp_digit        save = *from;
2930
2931                         *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
2932                         d = save;
2933                 }
2934
2935                 d >>= (MP_DIGIT_BIT - rest);
2936                 if (d != 0)
2937                 {
2938                         *from = d;
2939                         uz += extra;
2940                 }
2941         }
2942
2943         z->used = uz;
2944         CLAMP(z);
2945
2946         return 1;
2947 }
2948
2949 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2950    The sign of the result is always zero/positive.
2951  */
2952 static int
2953 s_qsub(mp_int z, mp_size p2)
2954 {
2955         mp_digit        hi = (1u << (p2 % MP_DIGIT_BIT)),
2956                            *zp;
2957         mp_size         tdig = (p2 / MP_DIGIT_BIT),
2958                                 pos;
2959         mp_word         w = 0;
2960
2961         if (!s_pad(z, tdig + 1))
2962                 return 0;
2963
2964         for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp)
2965         {
2966                 w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word) *zp;
2967
2968                 *zp = LOWER_HALF(w);
2969                 w = UPPER_HALF(w) ? 0 : 1;
2970         }
2971
2972         w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word) *zp;
2973         *zp = LOWER_HALF(w);
2974
2975         assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2976
2977         z->sign = MP_ZPOS;
2978         CLAMP(z);
2979
2980         return 1;
2981 }
2982
2983 static int
2984 s_dp2k(mp_int z)
2985 {
2986         int                     k = 0;
2987         mp_digit   *dp = MP_DIGITS(z),
2988                                 d;
2989
2990         if (MP_USED(z) == 1 && *dp == 0)
2991                 return 1;
2992
2993         while (*dp == 0)
2994         {
2995                 k += MP_DIGIT_BIT;
2996                 ++dp;
2997         }
2998
2999         d = *dp;
3000         while ((d & 1) == 0)
3001         {
3002                 d >>= 1;
3003                 ++k;
3004         }
3005
3006         return k;
3007 }
3008
3009 static int
3010 s_isp2(mp_int z)
3011 {
3012         mp_size         uz = MP_USED(z),
3013                                 k = 0;
3014         mp_digit   *dz = MP_DIGITS(z),
3015                                 d;
3016
3017         while (uz > 1)
3018         {
3019                 if (*dz++ != 0)
3020                         return -1;
3021                 k += MP_DIGIT_BIT;
3022                 --uz;
3023         }
3024
3025         d = *dz;
3026         while (d > 1)
3027         {
3028                 if (d & 1)
3029                         return -1;
3030                 ++k;
3031                 d >>= 1;
3032         }
3033
3034         return (int) k;
3035 }
3036
3037 static int
3038 s_2expt(mp_int z, mp_small k)
3039 {
3040         mp_size         ndig,
3041                                 rest;
3042         mp_digit   *dz;
3043
3044         ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
3045         rest = k % MP_DIGIT_BIT;
3046
3047         if (!s_pad(z, ndig))
3048                 return 0;
3049
3050         dz = MP_DIGITS(z);
3051         ZERO(dz, ndig);
3052         *(dz + ndig - 1) = (1u << rest);
3053         z->used = ndig;
3054
3055         return 1;
3056 }
3057
3058 static int
3059 s_norm(mp_int a, mp_int b)
3060 {
3061         mp_digit        d = b->digits[MP_USED(b) - 1];
3062         int                     k = 0;
3063
3064         while (d < (1u << (mp_digit) (MP_DIGIT_BIT - 1)))
3065         {                                                       /* d < (MP_RADIX / 2) */
3066                 d <<= 1;
3067                 ++k;
3068         }
3069
3070         /* These multiplications can't fail */
3071         if (k != 0)
3072         {
3073                 (void) s_qmul(a, (mp_size) k);
3074                 (void) s_qmul(b, (mp_size) k);
3075         }
3076
3077         return k;
3078 }
3079
3080 static mp_result
3081 s_brmu(mp_int z, mp_int m)
3082 {
3083         mp_size         um = MP_USED(m) * 2;
3084
3085         if (!s_pad(z, um))
3086                 return MP_MEMORY;
3087
3088         s_2expt(z, MP_DIGIT_BIT * um);
3089         return mp_int_div(z, m, z, NULL);
3090 }
3091
3092 static int
3093 s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
3094 {
3095         mp_size         um = MP_USED(m),
3096                                 umb_p1,
3097                                 umb_m1;
3098
3099         umb_p1 = (um + 1) * MP_DIGIT_BIT;
3100         umb_m1 = (um - 1) * MP_DIGIT_BIT;
3101
3102         if (mp_int_copy(x, q1) != MP_OK)
3103                 return 0;
3104
3105         /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
3106         s_qdiv(q1, umb_m1);
3107         UMUL(q1, mu, q2);
3108         s_qdiv(q2, umb_p1);
3109
3110         /* Set x = x mod b^(k+1) */
3111         s_qmod(x, umb_p1);
3112
3113         /*
3114          * Now, q is a guess for the quotient a / m. Compute x - q * m mod
3115          * b^(k+1), replacing x.  This may be off by a factor of 2m, but no more
3116          * than that.
3117          */
3118         UMUL(q2, m, q1);
3119         s_qmod(q1, umb_p1);
3120         (void) mp_int_sub(x, q1, x);    /* can't fail */
3121
3122         /*
3123          * The result may be < 0; if it is, add b^(k+1) to pin it in the proper
3124          * range.
3125          */
3126         if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
3127                 return 0;
3128
3129         /*
3130          * If x > m, we need to back it off until it is in range.  This will be
3131          * required at most twice.
3132          */
3133         if (mp_int_compare(x, m) >= 0)
3134         {
3135                 (void) mp_int_sub(x, m, x);
3136                 if (mp_int_compare(x, m) >= 0)
3137                 {
3138                         (void) mp_int_sub(x, m, x);
3139                 }
3140         }
3141
3142         /* At this point, x has been properly reduced. */
3143         return 1;
3144 }
3145
3146 /* Perform modular exponentiation using Barrett's method, where mu is the
3147    reduction constant for m.  Assumes a < m, b > 0. */
3148 static mp_result
3149 s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
3150 {
3151         mp_digit        umu = MP_USED(mu);
3152         mp_digit   *db = MP_DIGITS(b);
3153         mp_digit   *dbt = db + MP_USED(b) - 1;
3154
3155         DECLARE_TEMP(3);
3156         REQUIRE(GROW(TEMP(0), 4 * umu));
3157         REQUIRE(GROW(TEMP(1), 4 * umu));
3158         REQUIRE(GROW(TEMP(2), 4 * umu));
3159         ZERO(TEMP(0)->digits, TEMP(0)->alloc);
3160         ZERO(TEMP(1)->digits, TEMP(1)->alloc);
3161         ZERO(TEMP(2)->digits, TEMP(2)->alloc);
3162
3163         (void) mp_int_set_value(c, 1);
3164
3165         /* Take care of low-order digits */
3166         while (db < dbt)
3167         {
3168                 mp_digit        d = *db;
3169
3170                 for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1)
3171                 {
3172                         if (d & 1)
3173                         {
3174                                 /* The use of a second temporary avoids allocation */
3175                                 UMUL(c, a, TEMP(0));
3176                                 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3177                                 {
3178                                         REQUIRE(MP_MEMORY);
3179                                 }
3180                                 mp_int_copy(TEMP(0), c);
3181                         }
3182
3183                         USQR(a, TEMP(0));
3184                         assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3185                         if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3186                         {
3187                                 REQUIRE(MP_MEMORY);
3188                         }
3189                         assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3190                         mp_int_copy(TEMP(0), a);
3191                 }
3192
3193                 ++db;
3194         }
3195
3196         /* Take care of highest-order digit */
3197         mp_digit        d = *dbt;
3198
3199         for (;;)
3200         {
3201                 if (d & 1)
3202                 {
3203                         UMUL(c, a, TEMP(0));
3204                         if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3205                         {
3206                                 REQUIRE(MP_MEMORY);
3207                         }
3208                         mp_int_copy(TEMP(0), c);
3209                 }
3210
3211                 d >>= 1;
3212                 if (!d)
3213                         break;
3214
3215                 USQR(a, TEMP(0));
3216                 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3217                 {
3218                         REQUIRE(MP_MEMORY);
3219                 }
3220                 (void) mp_int_copy(TEMP(0), a);
3221         }
3222
3223         CLEANUP_TEMP();
3224         return MP_OK;
3225 }
3226
3227 /* Division of nonnegative integers
3228
3229    This function implements division algorithm for unsigned multi-precision
3230    integers. The algorithm is based on Algorithm D from Knuth's "The Art of
3231    Computer Programming", 3rd ed. 1998, pg 272-273.
3232
3233    We diverge from Knuth's algorithm in that we do not perform the subtraction
3234    from the remainder until we have determined that we have the correct
3235    quotient digit. This makes our algorithm less efficient that Knuth because
3236    we might have to perform multiple multiplication and comparison steps before
3237    the subtraction. The advantage is that it is easy to implement and ensure
3238    correctness without worrying about underflow from the subtraction.
3239
3240    inputs: u   a n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
3241                    v   a n   digit integer in base b (b is 2^MP_DIGIT_BIT)
3242                    n >= 1
3243                    m >= 0
3244   outputs: u / v stored in u
3245                    u % v stored in v
3246  */
3247 static mp_result
3248 s_udiv_knuth(mp_int u, mp_int v)
3249 {
3250         /* Force signs to positive */
3251         u->sign = MP_ZPOS;
3252         v->sign = MP_ZPOS;
3253
3254         /* Use simple division algorithm when v is only one digit long */
3255         if (MP_USED(v) == 1)
3256         {
3257                 mp_digit        d,
3258                                         rem;
3259
3260                 d = v->digits[0];
3261                 rem = s_ddiv(u, d);
3262                 mp_int_set_value(v, rem);
3263                 return MP_OK;
3264         }
3265
3266         /*
3267          * Algorithm D
3268          *
3269          * The n and m variables are defined as used by Knuth. u is an n digit
3270          * number with digits u_{n-1}..u_0. v is an n+m digit number with digits
3271          * from v_{m+n-1}..v_0. We require that n > 1 and m >= 0
3272          */
3273         mp_size         n = MP_USED(v);
3274         mp_size         m = MP_USED(u) - n;
3275
3276         assert(n > 1);
3277         /* assert(m >= 0) follows because m is unsigned. */
3278
3279         /*
3280          * D1: Normalize. The normalization step provides the necessary condition
3281          * for Theorem B, which states that the quotient estimate for q_j, call it
3282          * qhat
3283          *
3284          * qhat = u_{j+n}u_{j+n-1} / v_{n-1}
3285          *
3286          * is bounded by
3287          *
3288          * qhat - 2 <= q_j <= qhat.
3289          *
3290          * That is, qhat is always greater than the actual quotient digit q, and
3291          * it is never more than two larger than the actual quotient digit.
3292          */
3293         int                     k = s_norm(u, v);
3294
3295         /*
3296          * Extend size of u by one if needed.
3297          *
3298          * The algorithm begins with a value of u that has one more digit of
3299          * input. The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0.
3300          * If the multiplication did not increase the number of digits of u, we
3301          * need to add a leading zero here.
3302          */
3303         if (k == 0 || MP_USED(u) != m + n + 1)
3304         {
3305                 if (!s_pad(u, m + n + 1))
3306                         return MP_MEMORY;
3307                 u->digits[m + n] = 0;
3308                 u->used = m + n + 1;
3309         }
3310
3311         /*
3312          * Add a leading 0 to v.
3313          *
3314          * The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0.  We need
3315          * to add the leading zero to v here to ensure that the multiplication
3316          * will produce the full n+1 digit result.
3317          */
3318         if (!s_pad(v, n + 1))
3319                 return MP_MEMORY;
3320         v->digits[n] = 0;
3321
3322         /*
3323          * Initialize temporary variables q and t. q allocates space for m+1
3324          * digits to store the quotient digits t allocates space for n+1 digits to
3325          * hold the result of q_j*v
3326          */
3327         DECLARE_TEMP(2);
3328         REQUIRE(GROW(TEMP(0), m + 1));
3329         REQUIRE(GROW(TEMP(1), n + 1));
3330
3331         /* D2: Initialize j */
3332         int                     j = m;
3333         mpz_t           r;
3334
3335         r.digits = MP_DIGITS(u) + j;    /* The contents of r are shared with u */
3336         r.used = n + 1;
3337         r.sign = MP_ZPOS;
3338         r.alloc = MP_ALLOC(u);
3339         ZERO(TEMP(1)->digits, TEMP(1)->alloc);
3340
3341         /* Calculate the m+1 digits of the quotient result */
3342         for (; j >= 0; j--)
3343         {
3344                 /* D3: Calculate q' */
3345                 /* r->digits is aligned to position j of the number u */
3346                 mp_word         pfx,
3347                                         qhat;
3348
3349                 pfx = r.digits[n];
3350                 pfx <<= MP_DIGIT_BIT / 2;
3351                 pfx <<= MP_DIGIT_BIT / 2;
3352                 pfx |= r.digits[n - 1]; /* pfx = u_{j+n}{j+n-1} */
3353
3354                 qhat = pfx / v->digits[n - 1];
3355
3356                 /*
3357                  * Check to see if qhat > b, and decrease qhat if so. Theorem B
3358                  * guarantess that qhat is at most 2 larger than the actual value, so
3359                  * it is possible that qhat is greater than the maximum value that
3360                  * will fit in a digit
3361                  */
3362                 if (qhat > MP_DIGIT_MAX)
3363                         qhat = MP_DIGIT_MAX;
3364
3365                 /*
3366                  * D4,D5,D6: Multiply qhat * v and test for a correct value of q
3367                  *
3368                  * We proceed a bit different than the way described by Knuth. This
3369                  * way is simpler but less efficent. Instead of doing the multiply and
3370                  * subtract then checking for underflow, we first do the multiply of
3371                  * qhat * v and see if it is larger than the current remainder r. If
3372                  * it is larger, we decrease qhat by one and try again. We may need to
3373                  * decrease qhat one more time before we get a value that is smaller
3374                  * than r.
3375                  *
3376                  * This way is less efficent than Knuth becuase we do more multiplies,
3377                  * but we do not need to worry about underflow this way.
3378                  */
3379                 /* t = qhat * v */
3380                 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
3381                 TEMP(1)->used = n + 1;
3382                 CLAMP(TEMP(1));
3383
3384                 /* Clamp r for the comparison. Comparisons do not like leading zeros. */
3385                 CLAMP(&r);
3386                 if (s_ucmp(TEMP(1), &r) > 0)
3387                 {                                               /* would the remainder be negative? */
3388                         qhat -= 1;                      /* try a smaller q */
3389                         s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
3390                         TEMP(1)->used = n + 1;
3391                         CLAMP(TEMP(1));
3392                         if (s_ucmp(TEMP(1), &r) > 0)
3393                         {                                       /* would the remainder be negative? */
3394                                 assert(qhat > 0);
3395                                 qhat -= 1;              /* try a smaller q */
3396                                 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
3397                                 TEMP(1)->used = n + 1;
3398                                 CLAMP(TEMP(1));
3399                         }
3400                         assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us.");
3401                 }
3402
3403                 /*
3404                  * Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be
3405                  * n+1 digits long.
3406                  */
3407                 r.used = n + 1;
3408
3409                 /*
3410                  * D4: Multiply and subtract
3411                  *
3412                  * Note: The multiply was completed above so we only need to subtract
3413                  * here.
3414                  */
3415                 s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used);
3416
3417                 /*
3418                  * D5: Test remainder
3419                  *
3420                  * Note: Not needed because we always check that qhat is the correct
3421                  * value before performing the subtract.  Value cast to mp_digit to
3422                  * prevent warning, qhat has been clamped to MP_DIGIT_MAX
3423                  */
3424                 TEMP(0)->digits[j] = (mp_digit) qhat;
3425
3426                 /*
3427                  * D6: Add back Note: Not needed because we always check that qhat is
3428                  * the correct value before performing the subtract.
3429                  */
3430
3431                 /* D7: Loop on j */
3432                 r.digits--;
3433                 ZERO(TEMP(1)->digits, TEMP(1)->alloc);
3434         }
3435
3436         /* Get rid of leading zeros in q */
3437         TEMP(0)->used = m + 1;
3438         CLAMP(TEMP(0));
3439
3440         /* Denormalize the remainder */
3441         CLAMP(u);                                       /* use u here because the r.digits pointer is
3442                                                                  * off-by-one */
3443         if (k != 0)
3444                 s_qdiv(u, k);
3445
3446         mp_int_copy(u, v);                      /* ok:  0 <= r < v */
3447         mp_int_copy(TEMP(0), u);        /* ok:  q <= u     */
3448
3449         CLEANUP_TEMP();
3450         return MP_OK;
3451 }
3452
3453 static int
3454 s_outlen(mp_int z, mp_size r)
3455 {
3456         assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);
3457
3458         mp_result       bits = mp_int_count_bits(z);
3459         double          raw = (double) bits * s_log2[r];
3460
3461         return (int) (raw + 0.999999);
3462 }
3463
3464 static mp_size
3465 s_inlen(int len, mp_size r)
3466 {
3467         double          raw = (double) len / s_log2[r];
3468         mp_size         bits = (mp_size) (raw + 0.5);
3469
3470         return (mp_size) ((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1;
3471 }
3472
3473 static int
3474 s_ch2val(char c, int r)
3475 {
3476         int                     out;
3477
3478         /*
3479          * In some locales, isalpha() accepts characters outside the range A-Z,
3480          * producing out<0 or out>=36.  The "out >= r" check will always catch
3481          * out>=36.  Though nothing explicitly catches out<0, our caller reacts
3482          * the same way to every negative return value.
3483          */
3484         if (isdigit((unsigned char) c))
3485                 out = c - '0';
3486         else if (r > 10 && isalpha((unsigned char) c))
3487                 out = toupper((unsigned char) c) - 'A' + 10;
3488         else
3489                 return -1;
3490
3491         return (out >= r) ? -1 : out;
3492 }
3493
3494 static char
3495 s_val2ch(int v, int caps)
3496 {
3497         assert(v >= 0);
3498
3499         if (v < 10)
3500         {
3501                 return v + '0';
3502         }
3503         else
3504         {
3505                 char            out = (v - 10) + 'a';
3506
3507                 if (caps)
3508                 {
3509                         return toupper((unsigned char) out);
3510                 }
3511                 else
3512                 {
3513                         return out;
3514                 }
3515         }
3516 }
3517
3518 static void
3519 s_2comp(unsigned char *buf, int len)
3520 {
3521         unsigned short s = 1;
3522
3523         for (int i = len - 1; i >= 0; --i)
3524         {
3525                 unsigned char c = ~buf[i];
3526
3527                 s = c + s;
3528                 c = s & UCHAR_MAX;
3529                 s >>= CHAR_BIT;
3530
3531                 buf[i] = c;
3532         }
3533
3534         /* last carry out is ignored */
3535 }
3536
3537 static mp_result
3538 s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3539 {
3540         int                     pos = 0,
3541                                 limit = *limpos;
3542         mp_size         uz = MP_USED(z);
3543         mp_digit   *dz = MP_DIGITS(z);
3544
3545         while (uz > 0 && pos < limit)
3546         {
3547                 mp_digit        d = *dz++;
3548                 int                     i;
3549
3550                 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i)
3551                 {
3552                         buf[pos++] = (unsigned char) d;
3553                         d >>= CHAR_BIT;
3554
3555                         /* Don't write leading zeroes */
3556                         if (d == 0 && uz == 1)
3557                                 i = 0;                  /* exit loop without signaling truncation */
3558                 }
3559
3560                 /* Detect truncation (loop exited with pos >= limit) */
3561                 if (i > 0)
3562                         break;
3563
3564                 --uz;
3565         }
3566
3567         if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1)))
3568         {
3569                 if (pos < limit)
3570                 {
3571                         buf[pos++] = 0;
3572                 }
3573                 else
3574                 {
3575                         uz = 1;
3576                 }
3577         }
3578
3579         /* Digits are in reverse order, fix that */
3580         REV(buf, pos);
3581
3582         /* Return the number of bytes actually written */
3583         *limpos = pos;
3584
3585         return (uz == 0) ? MP_OK : MP_TRUNC;
3586 }
3587
3588 /* Here there be dragons */