]> granicus.if.org Git - postgresql/blob - contrib/pgcrypto/imath.c
pgindent run for 8.2.
[postgresql] / contrib / pgcrypto / imath.c
1 /* imath version 1.3 */
2 /*
3   Name:         imath.c
4   Purpose:      Arbitrary precision integer arithmetic routines.
5   Author:       M. J. Fromberger <http://www.dartmouth.edu/~sting/>
6   Info:         Id: imath.c 21 2006-04-02 18:58:36Z sting
7
8   Copyright (C) 2002 Michael J. Fromberger, All Rights Reserved.
9
10   Permission is hereby granted, free of charge, to any person
11   obtaining a copy of this software and associated documentation files
12   (the "Software"), to deal in the Software without restriction,
13   including without limitation the rights to use, copy, modify, merge,
14   publish, distribute, sublicense, and/or sell copies of the Software,
15   and to permit persons to whom the Software is furnished to do so,
16   subject to the following conditions:
17
18   The above copyright notice and this permission notice shall be
19   included in all copies or substantial portions of the Software.
20
21   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22   EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23   MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
24   NONINFRINGEMENT.      IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
25   BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
26   ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
27   CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28   SOFTWARE.
29  */
30 /* $PostgreSQL: pgsql/contrib/pgcrypto/imath.c,v 1.6 2006/10/04 00:29:46 momjian Exp $ */
31
32 #include "postgres.h"
33 #include "px.h"
34 #include "imath.h"
35
36 #undef assert
37 #define assert(TEST) Assert(TEST)
38 #define TRACEABLE_CLAMP 0
39 #define TRACEABLE_FREE 0
40
41 /* {{{ Constants */
42
43 const mp_result MP_OK = 0;              /* no error, all is well  */
44 const mp_result MP_FALSE = 0;   /* boolean false                  */
45 const mp_result MP_TRUE = -1;   /* boolean true                   */
46 const mp_result MP_MEMORY = -2; /* out of memory                  */
47 const mp_result MP_RANGE = -3;  /* argument out of range  */
48 const mp_result MP_UNDEF = -4;  /* result undefined               */
49 const mp_result MP_TRUNC = -5;  /* output truncated               */
50 const mp_result MP_BADARG = -6; /* invalid null argument  */
51
52 const mp_sign MP_NEG = 1;               /* value is strictly negative */
53 const mp_sign MP_ZPOS = 0;              /* value is non-negative          */
54
55 static const char *s_unknown_err = "unknown result code";
56 static const char *s_error_msg[] = {
57         "error code 0",
58         "boolean true",
59         "out of memory",
60         "argument out of range",
61         "result undefined",
62         "output truncated",
63         "invalid null argument",
64         NULL
65 };
66
67 /* }}} */
68
69 /* Optional library flags */
70 #define MP_CAP_DIGITS   1               /* flag bit to capitalize letter digits */
71
72 /* Argument checking macros
73    Use CHECK() where a return value is required; NRCHECK() elsewhere */
74 #define CHECK(TEST)   assert(TEST)
75 #define NRCHECK(TEST) assert(TEST)
76
77 /* {{{ Logarithm table for computing output sizes */
78
79 /* The ith entry of this table gives the value of log_i(2).
80
81    An integer value n requires ceil(log_i(n)) digits to be represented
82    in base i.  Since it is easy to compute lg(n), by counting bits, we
83    can compute log_i(n) = lg(n) * log_i(2).
84  */
85 static const double s_log2[] = {
86         0.000000000, 0.000000000, 1.000000000, 0.630929754, /* 0  1  2  3 */
87         0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4  5  6  7 */
88         0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8  9 10 11 */
89         0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */
90         0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */
91         0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */
92         0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */
93         0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */
94         0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */
95         0.193426404, 0.191958720, 0.190551412, 0.189200360, /* 36 37 38 39 */
96         0.187901825, 0.186652411, 0.185449023, 0.184288833, /* 40 41 42 43 */
97         0.183169251, 0.182087900, 0.181042597, 0.180031327, /* 44 45 46 47 */
98         0.179052232, 0.178103594, 0.177183820, 0.176291434, /* 48 49 50 51 */
99         0.175425064, 0.174583430, 0.173765343, 0.172969690, /* 52 53 54 55 */
100         0.172195434, 0.171441601, 0.170707280, 0.169991616, /* 56 57 58 59 */
101         0.169293808, 0.168613099, 0.167948779, 0.167300179, /* 60 61 62 63 */
102         0.166666667
103 };
104
105 /* }}} */
106 /* {{{ Various macros */
107
108 /* Return the number of digits needed to represent a static value */
109 #define MP_VALUE_DIGITS(V) \
110 ((sizeof(V)+(sizeof(mp_digit)-1))/sizeof(mp_digit))
111
112 /* Round precision P to nearest word boundary */
113 #define ROUND_PREC(P) ((mp_size)(2*(((P)+1)/2)))
114
115 /* Set array P of S digits to zero */
116 #define ZERO(P, S) \
117 do{mp_size i__=(S)*sizeof(mp_digit);mp_digit *p__=(P);memset(p__,0,i__);}while(0)
118
119 /* Copy S digits from array P to array Q */
120 #define COPY(P, Q, S) \
121 do{mp_size i__=(S)*sizeof(mp_digit);mp_digit *p__=(P),*q__=(Q);\
122 memcpy(q__,p__,i__);}while(0)
123
124 /* Reverse N elements of type T in array A */
125 #define REV(T, A, N) \
126 do{T *u_=(A),*v_=u_+(N)-1;while(u_<v_){T xch=*u_;*u_++=*v_;*v_--=xch;}}while(0)
127
128 #if TRACEABLE_CLAMP
129 #define CLAMP(Z) s_clamp(Z)
130 #else
131 #define CLAMP(Z) \
132 do{mp_int z_=(Z);mp_size uz_=MP_USED(z_);mp_digit *dz_=MP_DIGITS(z_)+uz_-1;\
133 while(uz_ > 1 && (*dz_-- == 0)) --uz_;MP_USED(z_)=uz_;}while(0)
134 #endif
135
136 #undef MIN
137 #undef MAX
138 #define MIN(A, B) ((B)<(A)?(B):(A))
139 #define MAX(A, B) ((B)>(A)?(B):(A))
140 #define SWAP(T, A, B) do{T t_=(A);A=(B);B=t_;}while(0)
141
142 #define TEMP(K) (temp + (K))
143 #define SETUP(E, C) \
144 do{if((res = (E)) != MP_OK) goto CLEANUP; ++(C);}while(0)
145
146 #define CMPZ(Z) \
147 (((Z)->used==1&&(Z)->digits[0]==0)?0:((Z)->sign==MP_NEG)?-1:1)
148
149 #define UMUL(X, Y, Z) \
150 do{mp_size ua_=MP_USED(X),ub_=MP_USED(Y);mp_size o_=ua_+ub_;\
151 ZERO(MP_DIGITS(Z),o_);\
152 (void) s_kmul(MP_DIGITS(X),MP_DIGITS(Y),MP_DIGITS(Z),ua_,ub_);\
153 MP_USED(Z)=o_;CLAMP(Z);}while(0)
154
155 #define USQR(X, Z) \
156 do{mp_size ua_=MP_USED(X),o_=ua_+ua_;ZERO(MP_DIGITS(Z),o_);\
157 (void) s_ksqr(MP_DIGITS(X),MP_DIGITS(Z),ua_);MP_USED(Z)=o_;CLAMP(Z);}while(0)
158
159 #define UPPER_HALF(W)                   ((mp_word)((W) >> MP_DIGIT_BIT))
160 #define LOWER_HALF(W)                   ((mp_digit)(W))
161 #define HIGH_BIT_SET(W)                 ((W) >> (MP_WORD_BIT - 1))
162 #define ADD_WILL_OVERFLOW(W, V) ((MP_WORD_MAX - (V)) < (W))
163
164 /* }}} */
165
166 /* Default number of digits allocated to a new mp_int */
167 static mp_size default_precision = 64;
168
169 /* Minimum number of digits to invoke recursive multiply */
170 static mp_size multiply_threshold = 32;
171
172 /* Default library configuration flags */
173 static mp_word mp_flags = MP_CAP_DIGITS;
174
175 /* Allocate a buffer of (at least) num digits, or return
176    NULL if that couldn't be done.  */
177 static mp_digit *s_alloc(mp_size num);
178
179 #if TRACEABLE_FREE
180 static void s_free(void *ptr);
181 #else
182 #define s_free(P) px_free(P)
183 #endif
184
185 /* Insure that z has at least min digits allocated, resizing if
186    necessary.  Returns true if successful, false if out of memory. */
187 static int      s_pad(mp_int z, mp_size min);
188
189 /* Normalize by removing leading zeroes (except when z = 0) */
190 #if TRACEABLE_CLAMP
191 static void s_clamp(mp_int z);
192 #endif
193
194 /* Fill in a "fake" mp_int on the stack with a given value */
195 static void s_fake(mp_int z, int value, mp_digit vbuf[]);
196
197 /* Compare two runs of digits of given length, returns <0, 0, >0 */
198 static int      s_cdig(mp_digit * da, mp_digit * db, mp_size len);
199
200 /* Pack the unsigned digits of v into array t */
201 static int      s_vpack(int v, mp_digit t[]);
202
203 /* Compare magnitudes of a and b, returns <0, 0, >0 */
204 static int      s_ucmp(mp_int a, mp_int b);
205
206 /* Compare magnitudes of a and v, returns <0, 0, >0 */
207 static int      s_vcmp(mp_int a, int v);
208
209 /* Unsigned magnitude addition; assumes dc is big enough.
210    Carry out is returned (no memory allocated). */
211 static mp_digit s_uadd(mp_digit * da, mp_digit * db, mp_digit * dc,
212            mp_size size_a, mp_size size_b);
213
214 /* Unsigned magnitude subtraction.      Assumes dc is big enough. */
215 static void s_usub(mp_digit * da, mp_digit * db, mp_digit * dc,
216            mp_size size_a, mp_size size_b);
217
218 /* Unsigned recursive multiplication.  Assumes dc is big enough. */
219 static int s_kmul(mp_digit * da, mp_digit * db, mp_digit * dc,
220            mp_size size_a, mp_size size_b);
221
222 /* Unsigned magnitude multiplication.  Assumes dc is big enough. */
223 static void s_umul(mp_digit * da, mp_digit * db, mp_digit * dc,
224            mp_size size_a, mp_size size_b);
225
226 /* Unsigned recursive squaring.  Assumes dc is big enough. */
227 static int      s_ksqr(mp_digit * da, mp_digit * dc, mp_size size_a);
228
229 /* Unsigned magnitude squaring.  Assumes dc is big enough. */
230 static void s_usqr(mp_digit * da, mp_digit * dc, mp_size size_a);
231
232 /* Single digit addition.  Assumes a is big enough. */
233 static void s_dadd(mp_int a, mp_digit b);
234
235 /* Single digit multiplication.  Assumes a is big enough. */
236 static void s_dmul(mp_int a, mp_digit b);
237
238 /* Single digit multiplication on buffers; assumes dc is big enough. */
239 static void s_dbmul(mp_digit * da, mp_digit b, mp_digit * dc,
240                 mp_size size_a);
241
242 /* Single digit division.  Replaces a with the quotient,
243    returns the remainder.  */
244 static mp_digit s_ddiv(mp_int a, mp_digit b);
245
246 /* Quick division by a power of 2, replaces z (no allocation) */
247 static void s_qdiv(mp_int z, mp_size p2);
248
249 /* Quick remainder by a power of 2, replaces z (no allocation) */
250 static void s_qmod(mp_int z, mp_size p2);
251
252 /* Quick multiplication by a power of 2, replaces z.
253    Allocates if necessary; returns false in case this fails. */
254 static int      s_qmul(mp_int z, mp_size p2);
255
256 /* Quick subtraction from a power of 2, replaces z.
257    Allocates if necessary; returns false in case this fails. */
258 static int      s_qsub(mp_int z, mp_size p2);
259
260 /* Return maximum k such that 2^k divides z. */
261 static int      s_dp2k(mp_int z);
262
263 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
264 static int      s_isp2(mp_int z);
265
266 /* Set z to 2^k.  May allocate; returns false in case this fails. */
267 static int      s_2expt(mp_int z, int k);
268
269 /* Normalize a and b for division, returns normalization constant */
270 static int      s_norm(mp_int a, mp_int b);
271
272 /* Compute constant mu for Barrett reduction, given modulus m, result
273    replaces z, m is untouched. */
274 static mp_result s_brmu(mp_int z, mp_int m);
275
276 /* Reduce a modulo m, using Barrett's algorithm. */
277 static int      s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
278
279 /* Modular exponentiation, using Barrett reduction */
280 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
281
282 /* Unsigned magnitude division.  Assumes |a| > |b|.  Allocates
283    temporaries; overwrites a with quotient, b with remainder. */
284 static mp_result s_udiv(mp_int a, mp_int b);
285
286 /* Compute the number of digits in radix r required to represent the
287    given value.  Does not account for sign flags, terminators, etc. */
288 static int      s_outlen(mp_int z, mp_size r);
289
290 /* Guess how many digits of precision will be needed to represent a
291    radix r value of the specified number of digits.  Returns a value
292    guaranteed to be no smaller than the actual number required. */
293 static mp_size s_inlen(int len, mp_size r);
294
295 /* Convert a character to a digit value in radix r, or
296    -1 if out of range */
297 static int      s_ch2val(char c, int r);
298
299 /* Convert a digit value to a character */
300 static char s_val2ch(int v, int caps);
301
302 /* Take 2's complement of a buffer in place */
303 static void s_2comp(unsigned char *buf, int len);
304
305 /* Convert a value to binary, ignoring sign.  On input, *limpos is the
306    bound on how many bytes should be written to buf; on output, *limpos
307    is set to the number of bytes actually written. */
308 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
309
310 #if 0
311 /* Dump a representation of the mp_int to standard output */
312 void            s_print(char *tag, mp_int z);
313 void            s_print_buf(char *tag, mp_digit * buf, mp_size num);
314 #endif
315
316 /* {{{ get_default_precision() */
317
318 mp_size
319 mp_get_default_precision(void)
320 {
321         return default_precision;
322 }
323
324 /* }}} */
325
326 /* {{{ mp_set_default_precision(s) */
327
328 void
329 mp_set_default_precision(mp_size s)
330 {
331         NRCHECK(s > 0);
332
333         default_precision = (mp_size) ROUND_PREC(s);
334 }
335
336 /* }}} */
337
338 /* {{{ mp_get_multiply_threshold() */
339
340 mp_size
341 mp_get_multiply_threshold(void)
342 {
343         return multiply_threshold;
344 }
345
346 /* }}} */
347
348 /* {{{ mp_set_multiply_threshold(s) */
349
350 void
351 mp_set_multiply_threshold(mp_size s)
352 {
353         multiply_threshold = s;
354 }
355
356 /* }}} */
357
358 /* {{{ mp_int_init(z) */
359
360 mp_result
361 mp_int_init(mp_int z)
362 {
363         return mp_int_init_size(z, default_precision);
364 }
365
366 /* }}} */
367
368 /* {{{ mp_int_alloc() */
369
370 mp_int
371 mp_int_alloc(void)
372 {
373         mp_int          out = px_alloc(sizeof(mpz_t));
374
375         assert(out != NULL);
376         out->digits = NULL;
377         out->used = 0;
378         out->alloc = 0;
379         out->sign = 0;
380
381         return out;
382 }
383
384 /* }}} */
385
386 /* {{{ mp_int_init_size(z, prec) */
387
388 mp_result
389 mp_int_init_size(mp_int z, mp_size prec)
390 {
391         CHECK(z != NULL);
392
393         prec = (mp_size) ROUND_PREC(prec);
394         prec = MAX(prec, default_precision);
395
396         if ((MP_DIGITS(z) = s_alloc(prec)) == NULL)
397                 return MP_MEMORY;
398
399         z->digits[0] = 0;
400         MP_USED(z) = 1;
401         MP_ALLOC(z) = prec;
402         MP_SIGN(z) = MP_ZPOS;
403
404         return MP_OK;
405 }
406
407 /* }}} */
408
409 /* {{{ mp_int_init_copy(z, old) */
410
411 mp_result
412 mp_int_init_copy(mp_int z, mp_int old)
413 {
414         mp_result       res;
415         mp_size         uold,
416                                 target;
417
418         CHECK(z != NULL && old != NULL);
419
420         uold = MP_USED(old);
421         target = MAX(uold, default_precision);
422
423         if ((res = mp_int_init_size(z, target)) != MP_OK)
424                 return res;
425
426         MP_USED(z) = uold;
427         MP_SIGN(z) = MP_SIGN(old);
428         COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
429
430         return MP_OK;
431 }
432
433 /* }}} */
434
435 /* {{{ mp_int_init_value(z, value) */
436
437 mp_result
438 mp_int_init_value(mp_int z, int value)
439 {
440         mp_result       res;
441
442         CHECK(z != NULL);
443
444         if ((res = mp_int_init(z)) != MP_OK)
445                 return res;
446
447         return mp_int_set_value(z, value);
448 }
449
450 /* }}} */
451
452 /* {{{ mp_int_set_value(z, value) */
453
454 mp_result
455 mp_int_set_value(mp_int z, int value)
456 {
457         mp_size         ndig;
458
459         CHECK(z != NULL);
460
461         /* How many digits to copy */
462         ndig = (mp_size) MP_VALUE_DIGITS(value);
463
464         if (!s_pad(z, ndig))
465                 return MP_MEMORY;
466
467         MP_USED(z) = (mp_size) s_vpack(value, MP_DIGITS(z));
468         MP_SIGN(z) = (value < 0) ? MP_NEG : MP_ZPOS;
469
470         return MP_OK;
471 }
472
473 /* }}} */
474
475 /* {{{ mp_int_clear(z) */
476
477 void
478 mp_int_clear(mp_int z)
479 {
480         if (z == NULL)
481                 return;
482
483         if (MP_DIGITS(z) != NULL)
484         {
485                 s_free(MP_DIGITS(z));
486                 MP_DIGITS(z) = NULL;
487         }
488 }
489
490 /* }}} */
491
492 /* {{{ mp_int_free(z) */
493
494 void
495 mp_int_free(mp_int z)
496 {
497         NRCHECK(z != NULL);
498
499         if (z->digits != NULL)
500                 mp_int_clear(z);
501
502         px_free(z);
503 }
504
505 /* }}} */
506
507 /* {{{ mp_int_copy(a, c) */
508
509 mp_result
510 mp_int_copy(mp_int a, mp_int c)
511 {
512         CHECK(a != NULL && c != NULL);
513
514         if (a != c)
515         {
516                 mp_size         ua = MP_USED(a);
517                 mp_digit   *da,
518                                    *dc;
519
520                 if (!s_pad(c, ua))
521                         return MP_MEMORY;
522
523                 da = MP_DIGITS(a);
524                 dc = MP_DIGITS(c);
525                 COPY(da, dc, ua);
526
527                 MP_USED(c) = ua;
528                 MP_SIGN(c) = MP_SIGN(a);
529         }
530
531         return MP_OK;
532 }
533
534 /* }}} */
535
536 /* {{{ mp_int_swap(a, c) */
537
538 void
539 mp_int_swap(mp_int a, mp_int c)
540 {
541         if (a != c)
542         {
543                 mpz_t           tmp = *a;
544
545                 *a = *c;
546                 *c = tmp;
547         }
548 }
549
550 /* }}} */
551
552 /* {{{ mp_int_zero(z) */
553
554 void
555 mp_int_zero(mp_int z)
556 {
557         NRCHECK(z != NULL);
558
559         z->digits[0] = 0;
560         MP_USED(z) = 1;
561         MP_SIGN(z) = MP_ZPOS;
562 }
563
564 /* }}} */
565
566 /* {{{ mp_int_abs(a, c) */
567
568 mp_result
569 mp_int_abs(mp_int a, mp_int c)
570 {
571         mp_result       res;
572
573         CHECK(a != NULL && c != NULL);
574
575         if ((res = mp_int_copy(a, c)) != MP_OK)
576                 return res;
577
578         MP_SIGN(c) = MP_ZPOS;
579         return MP_OK;
580 }
581
582 /* }}} */
583
584 /* {{{ mp_int_neg(a, c) */
585
586 mp_result
587 mp_int_neg(mp_int a, mp_int c)
588 {
589         mp_result       res;
590
591         CHECK(a != NULL && c != NULL);
592
593         if ((res = mp_int_copy(a, c)) != MP_OK)
594                 return res;
595
596         if (CMPZ(c) != 0)
597                 MP_SIGN(c) = 1 - MP_SIGN(a);
598
599         return MP_OK;
600 }
601
602 /* }}} */
603
604 /* {{{ mp_int_add(a, b, c) */
605
606 mp_result
607 mp_int_add(mp_int a, mp_int b, mp_int c)
608 {
609         mp_size         ua,
610                                 ub,
611                                 uc,
612                                 max;
613
614         CHECK(a != NULL && b != NULL && c != NULL);
615
616         ua = MP_USED(a);
617         ub = MP_USED(b);
618         uc = MP_USED(c);
619         max = MAX(ua, ub);
620
621         if (MP_SIGN(a) == MP_SIGN(b))
622         {
623                 /* Same sign -- add magnitudes, preserve sign of addends */
624                 mp_digit        carry;
625
626                 if (!s_pad(c, max))
627                         return MP_MEMORY;
628
629                 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
630                 uc = max;
631
632                 if (carry)
633                 {
634                         if (!s_pad(c, max + 1))
635                                 return MP_MEMORY;
636
637                         c->digits[max] = carry;
638                         ++uc;
639                 }
640
641                 MP_USED(c) = uc;
642                 MP_SIGN(c) = MP_SIGN(a);
643
644         }
645         else
646         {
647                 /* Different signs -- subtract magnitudes, preserve sign of greater */
648                 mp_int          x,
649                                         y;
650                 int                     cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */
651
652                 /* Set x to max(a, b), y to min(a, b) to simplify later code */
653                 if (cmp >= 0)
654                 {
655                         x = a;
656                         y = b;
657                 }
658                 else
659                 {
660                         x = b;
661                         y = a;
662                 }
663
664                 if (!s_pad(c, MP_USED(x)))
665                         return MP_MEMORY;
666
667                 /* Subtract smaller from larger */
668                 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
669                 MP_USED(c) = MP_USED(x);
670                 CLAMP(c);
671
672                 /* Give result the sign of the larger */
673                 MP_SIGN(c) = MP_SIGN(x);
674         }
675
676         return MP_OK;
677 }
678
679 /* }}} */
680
681 /* {{{ mp_int_add_value(a, value, c) */
682
683 mp_result
684 mp_int_add_value(mp_int a, int value, mp_int c)
685 {
686         mpz_t           vtmp;
687         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
688
689         s_fake(&vtmp, value, vbuf);
690
691         return mp_int_add(a, &vtmp, c);
692 }
693
694 /* }}} */
695
696 /* {{{ mp_int_sub(a, b, c) */
697
698 mp_result
699 mp_int_sub(mp_int a, mp_int b, mp_int c)
700 {
701         mp_size         ua,
702                                 ub,
703                                 uc,
704                                 max;
705
706         CHECK(a != NULL && b != NULL && c != NULL);
707
708         ua = MP_USED(a);
709         ub = MP_USED(b);
710         uc = MP_USED(c);
711         max = MAX(ua, ub);
712
713         if (MP_SIGN(a) != MP_SIGN(b))
714         {
715                 /* Different signs -- add magnitudes and keep sign of a */
716                 mp_digit        carry;
717
718                 if (!s_pad(c, max))
719                         return MP_MEMORY;
720
721                 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
722                 uc = max;
723
724                 if (carry)
725                 {
726                         if (!s_pad(c, max + 1))
727                                 return MP_MEMORY;
728
729                         c->digits[max] = carry;
730                         ++uc;
731                 }
732
733                 MP_USED(c) = uc;
734                 MP_SIGN(c) = MP_SIGN(a);
735
736         }
737         else
738         {
739                 /* Same signs -- subtract magnitudes */
740                 mp_int          x,
741                                         y;
742                 mp_sign         osign;
743                 int                     cmp = s_ucmp(a, b);
744
745                 if (!s_pad(c, max))
746                         return MP_MEMORY;
747
748                 if (cmp >= 0)
749                 {
750                         x = a;
751                         y = b;
752                         osign = MP_ZPOS;
753                 }
754                 else
755                 {
756                         x = b;
757                         y = a;
758                         osign = MP_NEG;
759                 }
760
761                 if (MP_SIGN(a) == MP_NEG && cmp != 0)
762                         osign = 1 - osign;
763
764                 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
765                 MP_USED(c) = MP_USED(x);
766                 CLAMP(c);
767
768                 MP_SIGN(c) = osign;
769         }
770
771         return MP_OK;
772 }
773
774 /* }}} */
775
776 /* {{{ mp_int_sub_value(a, value, c) */
777
778 mp_result
779 mp_int_sub_value(mp_int a, int value, mp_int c)
780 {
781         mpz_t           vtmp;
782         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
783
784         s_fake(&vtmp, value, vbuf);
785
786         return mp_int_sub(a, &vtmp, c);
787 }
788
789 /* }}} */
790
791 /* {{{ mp_int_mul(a, b, c) */
792
793 mp_result
794 mp_int_mul(mp_int a, mp_int b, mp_int c)
795 {
796         mp_digit   *out;
797         mp_size         osize,
798                                 ua,
799                                 ub,
800                                 p = 0;
801         mp_sign         osign;
802
803         CHECK(a != NULL && b != NULL && c != NULL);
804
805         /* If either input is zero, we can shortcut multiplication */
806         if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0)
807         {
808                 mp_int_zero(c);
809                 return MP_OK;
810         }
811
812         /* Output is positive if inputs have same sign, otherwise negative */
813         osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
814
815         /*
816          * If the output is not equal to any of the inputs, we'll write the
817          * results there directly; otherwise, allocate a temporary space.
818          */
819         ua = MP_USED(a);
820         ub = MP_USED(b);
821         osize = ua + ub;
822
823         if (c == a || c == b)
824         {
825                 p = ROUND_PREC(osize);
826                 p = MAX(p, default_precision);
827
828                 if ((out = s_alloc(p)) == NULL)
829                         return MP_MEMORY;
830         }
831         else
832         {
833                 if (!s_pad(c, osize))
834                         return MP_MEMORY;
835
836                 out = MP_DIGITS(c);
837         }
838         ZERO(out, osize);
839
840         if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
841                 return MP_MEMORY;
842
843         /*
844          * If we allocated a new buffer, get rid of whatever memory c was already
845          * using, and fix up its fields to reflect that.
846          */
847         if (out != MP_DIGITS(c))
848         {
849                 s_free(MP_DIGITS(c));
850                 MP_DIGITS(c) = out;
851                 MP_ALLOC(c) = p;
852         }
853
854         MP_USED(c) = osize;                     /* might not be true, but we'll fix it ... */
855         CLAMP(c);                                       /* ... right here */
856         MP_SIGN(c) = osign;
857
858         return MP_OK;
859 }
860
861 /* }}} */
862
863 /* {{{ mp_int_mul_value(a, value, c) */
864
865 mp_result
866 mp_int_mul_value(mp_int a, int value, mp_int c)
867 {
868         mpz_t           vtmp;
869         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
870
871         s_fake(&vtmp, value, vbuf);
872
873         return mp_int_mul(a, &vtmp, c);
874 }
875
876 /* }}} */
877
878 /* {{{ mp_int_mul_pow2(a, p2, c) */
879
880 mp_result
881 mp_int_mul_pow2(mp_int a, int p2, mp_int c)
882 {
883         mp_result       res;
884
885         CHECK(a != NULL && c != NULL && p2 >= 0);
886
887         if ((res = mp_int_copy(a, c)) != MP_OK)
888                 return res;
889
890         if (s_qmul(c, (mp_size) p2))
891                 return MP_OK;
892         else
893                 return MP_MEMORY;
894 }
895
896 /* }}} */
897
898 /* {{{ mp_int_sqr(a, c) */
899
900 mp_result
901 mp_int_sqr(mp_int a, mp_int c)
902 {
903         mp_digit   *out;
904         mp_size         osize,
905                                 p = 0;
906
907         CHECK(a != NULL && c != NULL);
908
909         /* Get a temporary buffer big enough to hold the result */
910         osize = (mp_size) 2 *MP_USED(a);
911
912         if (a == c)
913         {
914                 p = ROUND_PREC(osize);
915                 p = MAX(p, default_precision);
916
917                 if ((out = s_alloc(p)) == NULL)
918                         return MP_MEMORY;
919         }
920         else
921         {
922                 if (!s_pad(c, osize))
923                         return MP_MEMORY;
924
925                 out = MP_DIGITS(c);
926         }
927         ZERO(out, osize);
928
929         s_ksqr(MP_DIGITS(a), out, MP_USED(a));
930
931         /*
932          * Get rid of whatever memory c was already using, and fix up its fields
933          * to reflect the new digit array it's using
934          */
935         if (out != MP_DIGITS(c))
936         {
937                 s_free(MP_DIGITS(c));
938                 MP_DIGITS(c) = out;
939                 MP_ALLOC(c) = p;
940         }
941
942         MP_USED(c) = osize;                     /* might not be true, but we'll fix it ... */
943         CLAMP(c);                                       /* ... right here */
944         MP_SIGN(c) = MP_ZPOS;
945
946         return MP_OK;
947 }
948
949 /* }}} */
950
951 /* {{{ mp_int_div(a, b, q, r) */
952
953 mp_result
954 mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
955 {
956         int                     cmp,
957                                 last = 0,
958                                 lg;
959         mp_result       res = MP_OK;
960         mpz_t           temp[2];
961         mp_int          qout,
962                                 rout;
963         mp_sign         sa = MP_SIGN(a),
964                                 sb = MP_SIGN(b);
965
966         CHECK(a != NULL && b != NULL && q != r);
967
968         if (CMPZ(b) == 0)
969                 return MP_UNDEF;
970         else if ((cmp = s_ucmp(a, b)) < 0)
971         {
972                 /*
973                  * If |a| < |b|, no division is required: q = 0, r = a
974                  */
975                 if (r && (res = mp_int_copy(a, r)) != MP_OK)
976                         return res;
977
978                 if (q)
979                         mp_int_zero(q);
980
981                 return MP_OK;
982         }
983         else if (cmp == 0)
984         {
985                 /*
986                  * If |a| = |b|, no division is required: q = 1 or -1, r = 0
987                  */
988                 if (r)
989                         mp_int_zero(r);
990
991                 if (q)
992                 {
993                         mp_int_zero(q);
994                         q->digits[0] = 1;
995
996                         if (sa != sb)
997                                 MP_SIGN(q) = MP_NEG;
998                 }
999
1000                 return MP_OK;
1001         }
1002
1003         /*
1004          * When |a| > |b|, real division is required.  We need someplace to store
1005          * quotient and remainder, but q and r are allowed to be NULL or to
1006          * overlap with the inputs.
1007          */
1008         if ((lg = s_isp2(b)) < 0)
1009         {
1010                 if (q && b != q && (res = mp_int_copy(a, q)) == MP_OK)
1011                 {
1012                         qout = q;
1013                 }
1014                 else
1015                 {
1016                         qout = TEMP(last);
1017                         SETUP(mp_int_init_copy(TEMP(last), a), last);
1018                 }
1019
1020                 if (r && a != r && (res = mp_int_copy(b, r)) == MP_OK)
1021                 {
1022                         rout = r;
1023                 }
1024                 else
1025                 {
1026                         rout = TEMP(last);
1027                         SETUP(mp_int_init_copy(TEMP(last), b), last);
1028                 }
1029
1030                 if ((res = s_udiv(qout, rout)) != MP_OK)
1031                         goto CLEANUP;
1032         }
1033         else
1034         {
1035                 if (q && (res = mp_int_copy(a, q)) != MP_OK)
1036                         goto CLEANUP;
1037                 if (r && (res = mp_int_copy(a, r)) != MP_OK)
1038                         goto CLEANUP;
1039
1040                 if (q)
1041                         s_qdiv(q, (mp_size) lg);
1042                 qout = q;
1043                 if (r)
1044                         s_qmod(r, (mp_size) lg);
1045                 rout = r;
1046         }
1047
1048         /* Recompute signs for output */
1049         if (rout)
1050         {
1051                 MP_SIGN(rout) = sa;
1052                 if (CMPZ(rout) == 0)
1053                         MP_SIGN(rout) = MP_ZPOS;
1054         }
1055         if (qout)
1056         {
1057                 MP_SIGN(qout) = (sa == sb) ? MP_ZPOS : MP_NEG;
1058                 if (CMPZ(qout) == 0)
1059                         MP_SIGN(qout) = MP_ZPOS;
1060         }
1061
1062         if (q && (res = mp_int_copy(qout, q)) != MP_OK)
1063                 goto CLEANUP;
1064         if (r && (res = mp_int_copy(rout, r)) != MP_OK)
1065                 goto CLEANUP;
1066
1067 CLEANUP:
1068         while (--last >= 0)
1069                 mp_int_clear(TEMP(last));
1070
1071         return res;
1072 }
1073
1074 /* }}} */
1075
1076 /* {{{ mp_int_mod(a, m, c) */
1077
1078 mp_result
1079 mp_int_mod(mp_int a, mp_int m, mp_int c)
1080 {
1081         mp_result       res;
1082         mpz_t           tmp;
1083         mp_int          out;
1084
1085         if (m == c)
1086         {
1087                 if ((res = mp_int_init(&tmp)) != MP_OK)
1088                         return res;
1089
1090                 out = &tmp;
1091         }
1092         else
1093         {
1094                 out = c;
1095         }
1096
1097         if ((res = mp_int_div(a, m, NULL, out)) != MP_OK)
1098                 goto CLEANUP;
1099
1100         if (CMPZ(out) < 0)
1101                 res = mp_int_add(out, m, c);
1102         else
1103                 res = mp_int_copy(out, c);
1104
1105 CLEANUP:
1106         if (out != c)
1107                 mp_int_clear(&tmp);
1108
1109         return res;
1110 }
1111
1112 /* }}} */
1113
1114
1115 /* {{{ mp_int_div_value(a, value, q, r) */
1116
1117 mp_result
1118 mp_int_div_value(mp_int a, int value, mp_int q, int *r)
1119 {
1120         mpz_t           vtmp,
1121                                 rtmp;
1122         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1123         mp_result       res;
1124
1125         if ((res = mp_int_init(&rtmp)) != MP_OK)
1126                 return res;
1127         s_fake(&vtmp, value, vbuf);
1128
1129         if ((res = mp_int_div(a, &vtmp, q, &rtmp)) != MP_OK)
1130                 goto CLEANUP;
1131
1132         if (r)
1133                 (void) mp_int_to_int(&rtmp, r); /* can't fail */
1134
1135 CLEANUP:
1136         mp_int_clear(&rtmp);
1137         return res;
1138 }
1139
1140 /* }}} */
1141
1142 /* {{{ mp_int_div_pow2(a, p2, q, r) */
1143
1144 mp_result
1145 mp_int_div_pow2(mp_int a, int p2, mp_int q, mp_int r)
1146 {
1147         mp_result       res = MP_OK;
1148
1149         CHECK(a != NULL && p2 >= 0 && q != r);
1150
1151         if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
1152                 s_qdiv(q, (mp_size) p2);
1153
1154         if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
1155                 s_qmod(r, (mp_size) p2);
1156
1157         return res;
1158 }
1159
1160 /* }}} */
1161
1162 /* {{{ mp_int_expt(a, b, c) */
1163
1164 mp_result
1165 mp_int_expt(mp_int a, int b, mp_int c)
1166 {
1167         mpz_t           t;
1168         mp_result       res;
1169         unsigned int v = abs(b);
1170
1171         CHECK(b >= 0 && c != NULL);
1172
1173         if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1174                 return res;
1175
1176         (void) mp_int_set_value(c, 1);
1177         while (v != 0)
1178         {
1179                 if (v & 1)
1180                 {
1181                         if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1182                                 goto CLEANUP;
1183                 }
1184
1185                 v >>= 1;
1186                 if (v == 0)
1187                         break;
1188
1189                 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1190                         goto CLEANUP;
1191         }
1192
1193 CLEANUP:
1194         mp_int_clear(&t);
1195         return res;
1196 }
1197
1198 /* }}} */
1199
1200 /* {{{ mp_int_expt_value(a, b, c) */
1201
1202 mp_result
1203 mp_int_expt_value(int a, int b, mp_int c)
1204 {
1205         mpz_t           t;
1206         mp_result       res;
1207         unsigned int v = abs(b);
1208
1209         CHECK(b >= 0 && c != NULL);
1210
1211         if ((res = mp_int_init_value(&t, a)) != MP_OK)
1212                 return res;
1213
1214         (void) mp_int_set_value(c, 1);
1215         while (v != 0)
1216         {
1217                 if (v & 1)
1218                 {
1219                         if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1220                                 goto CLEANUP;
1221                 }
1222
1223                 v >>= 1;
1224                 if (v == 0)
1225                         break;
1226
1227                 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1228                         goto CLEANUP;
1229         }
1230
1231 CLEANUP:
1232         mp_int_clear(&t);
1233         return res;
1234 }
1235
1236 /* }}} */
1237
1238 /* {{{ mp_int_compare(a, b) */
1239
1240 int
1241 mp_int_compare(mp_int a, mp_int b)
1242 {
1243         mp_sign         sa;
1244
1245         CHECK(a != NULL && b != NULL);
1246
1247         sa = MP_SIGN(a);
1248         if (sa == MP_SIGN(b))
1249         {
1250                 int                     cmp = s_ucmp(a, b);
1251
1252                 /*
1253                  * If they're both zero or positive, the normal comparison applies; if
1254                  * both negative, the sense is reversed.
1255                  */
1256                 if (sa == MP_ZPOS)
1257                         return cmp;
1258                 else
1259                         return -cmp;
1260
1261         }
1262         else
1263         {
1264                 if (sa == MP_ZPOS)
1265                         return 1;
1266                 else
1267                         return -1;
1268         }
1269 }
1270
1271 /* }}} */
1272
1273 /* {{{ mp_int_compare_unsigned(a, b) */
1274
1275 int
1276 mp_int_compare_unsigned(mp_int a, mp_int b)
1277 {
1278         NRCHECK(a != NULL && b != NULL);
1279
1280         return s_ucmp(a, b);
1281 }
1282
1283 /* }}} */
1284
1285 /* {{{ mp_int_compare_zero(z) */
1286
1287 int
1288 mp_int_compare_zero(mp_int z)
1289 {
1290         NRCHECK(z != NULL);
1291
1292         if (MP_USED(z) == 1 && z->digits[0] == 0)
1293                 return 0;
1294         else if (MP_SIGN(z) == MP_ZPOS)
1295                 return 1;
1296         else
1297                 return -1;
1298 }
1299
1300 /* }}} */
1301
1302 /* {{{ mp_int_compare_value(z, value) */
1303
1304 int
1305 mp_int_compare_value(mp_int z, int value)
1306 {
1307         mp_sign         vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1308         int                     cmp;
1309
1310         CHECK(z != NULL);
1311
1312         if (vsign == MP_SIGN(z))
1313         {
1314                 cmp = s_vcmp(z, value);
1315
1316                 if (vsign == MP_ZPOS)
1317                         return cmp;
1318                 else
1319                         return -cmp;
1320         }
1321         else
1322         {
1323                 if (value < 0)
1324                         return 1;
1325                 else
1326                         return -1;
1327         }
1328 }
1329
1330 /* }}} */
1331
1332 /* {{{ mp_int_exptmod(a, b, m, c) */
1333
1334 mp_result
1335 mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1336 {
1337         mp_result       res;
1338         mp_size         um;
1339         mpz_t           temp[3];
1340         mp_int          s;
1341         int                     last = 0;
1342
1343         CHECK(a != NULL && b != NULL && c != NULL && m != NULL);
1344
1345         /* Zero moduli and negative exponents are not considered. */
1346         if (CMPZ(m) == 0)
1347                 return MP_UNDEF;
1348         if (CMPZ(b) < 0)
1349                 return MP_RANGE;
1350
1351         um = MP_USED(m);
1352         SETUP(mp_int_init_size(TEMP(0), 2 * um), last);
1353         SETUP(mp_int_init_size(TEMP(1), 2 * um), last);
1354
1355         if (c == b || c == m)
1356         {
1357                 SETUP(mp_int_init_size(TEMP(2), 2 * um), last);
1358                 s = TEMP(2);
1359         }
1360         else
1361         {
1362                 s = c;
1363         }
1364
1365         if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK)
1366                 goto CLEANUP;
1367
1368         if ((res = s_brmu(TEMP(1), m)) != MP_OK)
1369                 goto CLEANUP;
1370
1371         if ((res = s_embar(TEMP(0), b, m, TEMP(1), s)) != MP_OK)
1372                 goto CLEANUP;
1373
1374         res = mp_int_copy(s, c);
1375
1376 CLEANUP:
1377         while (--last >= 0)
1378                 mp_int_clear(TEMP(last));
1379
1380         return res;
1381 }
1382
1383 /* }}} */
1384
1385 /* {{{ mp_int_exptmod_evalue(a, value, m, c) */
1386
1387 mp_result
1388 mp_int_exptmod_evalue(mp_int a, int value, mp_int m, mp_int c)
1389 {
1390         mpz_t           vtmp;
1391         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1392
1393         s_fake(&vtmp, value, vbuf);
1394
1395         return mp_int_exptmod(a, &vtmp, m, c);
1396 }
1397
1398 /* }}} */
1399
1400 /* {{{ mp_int_exptmod_bvalue(v, b, m, c) */
1401
1402 mp_result
1403 mp_int_exptmod_bvalue(int value, mp_int b,
1404                                           mp_int m, mp_int c)
1405 {
1406         mpz_t           vtmp;
1407         mp_digit        vbuf[MP_VALUE_DIGITS(value)];
1408
1409         s_fake(&vtmp, value, vbuf);
1410
1411         return mp_int_exptmod(&vtmp, b, m, c);
1412 }
1413
1414 /* }}} */
1415
1416 /* {{{ mp_int_exptmod_known(a, b, m, mu, c) */
1417
1418 mp_result
1419 mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
1420 {
1421         mp_result       res;
1422         mp_size         um;
1423         mpz_t           temp[2];
1424         mp_int          s;
1425         int                     last = 0;
1426
1427         CHECK(a && b && m && c);
1428
1429         /* Zero moduli and negative exponents are not considered. */
1430         if (CMPZ(m) == 0)
1431                 return MP_UNDEF;
1432         if (CMPZ(b) < 0)
1433                 return MP_RANGE;
1434
1435         um = MP_USED(m);
1436         SETUP(mp_int_init_size(TEMP(0), 2 * um), last);
1437
1438         if (c == b || c == m)
1439         {
1440                 SETUP(mp_int_init_size(TEMP(1), 2 * um), last);
1441                 s = TEMP(1);
1442         }
1443         else
1444         {
1445                 s = c;
1446         }
1447
1448         if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK)
1449                 goto CLEANUP;
1450
1451         if ((res = s_embar(TEMP(0), b, m, mu, s)) != MP_OK)
1452                 goto CLEANUP;
1453
1454         res = mp_int_copy(s, c);
1455
1456 CLEANUP:
1457         while (--last >= 0)
1458                 mp_int_clear(TEMP(last));
1459
1460         return res;
1461 }
1462
1463 /* }}} */
1464
1465 /* {{{ mp_int_redux_const(m, c) */
1466
1467 mp_result
1468 mp_int_redux_const(mp_int m, mp_int c)
1469 {
1470         CHECK(m != NULL && c != NULL && m != c);
1471
1472         return s_brmu(c, m);
1473 }
1474
1475 /* }}} */
1476
1477 /* {{{ mp_int_invmod(a, m, c) */
1478
1479 mp_result
1480 mp_int_invmod(mp_int a, mp_int m, mp_int c)
1481 {
1482         mp_result       res;
1483         mp_sign         sa;
1484         int                     last = 0;
1485         mpz_t           temp[2];
1486
1487         CHECK(a != NULL && m != NULL && c != NULL);
1488
1489         if (CMPZ(a) == 0 || CMPZ(m) <= 0)
1490                 return MP_RANGE;
1491
1492         sa = MP_SIGN(a);                        /* need this for the result later */
1493
1494         for (last = 0; last < 2; ++last)
1495                 if ((res = mp_int_init(TEMP(last))) != MP_OK)
1496                         goto CLEANUP;
1497
1498         if ((res = mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)) != MP_OK)
1499                 goto CLEANUP;
1500
1501         if (mp_int_compare_value(TEMP(0), 1) != 0)
1502         {
1503                 res = MP_UNDEF;
1504                 goto CLEANUP;
1505         }
1506
1507         /* It is first necessary to constrain the value to the proper range */
1508         if ((res = mp_int_mod(TEMP(1), m, TEMP(1))) != MP_OK)
1509                 goto CLEANUP;
1510
1511         /*
1512          * Now, if 'a' was originally negative, the value we have is actually the
1513          * magnitude of the negative representative; to get the positive value we
1514          * have to subtract from the modulus.  Otherwise, the value is okay as it
1515          * stands.
1516          */
1517         if (sa == MP_NEG)
1518                 res = mp_int_sub(m, TEMP(1), c);
1519         else
1520                 res = mp_int_copy(TEMP(1), c);
1521
1522 CLEANUP:
1523         while (--last >= 0)
1524                 mp_int_clear(TEMP(last));
1525
1526         return res;
1527 }
1528
1529 /* }}} */
1530
1531 /* {{{ mp_int_gcd(a, b, c) */
1532
1533 /* Binary GCD algorithm due to Josef Stein, 1961 */
1534 mp_result
1535 mp_int_gcd(mp_int a, mp_int b, mp_int c)
1536 {
1537         int                     ca,
1538                                 cb,
1539                                 k = 0;
1540         mpz_t           u,
1541                                 v,
1542                                 t;
1543         mp_result       res;
1544
1545         CHECK(a != NULL && b != NULL && c != NULL);
1546
1547         ca = CMPZ(a);
1548         cb = CMPZ(b);
1549         if (ca == 0 && cb == 0)
1550                 return MP_UNDEF;
1551         else if (ca == 0)
1552                 return mp_int_abs(b, c);
1553         else if (cb == 0)
1554                 return mp_int_abs(a, c);
1555
1556         if ((res = mp_int_init(&t)) != MP_OK)
1557                 return res;
1558         if ((res = mp_int_init_copy(&u, a)) != MP_OK)
1559                 goto U;
1560         if ((res = mp_int_init_copy(&v, b)) != MP_OK)
1561                 goto V;
1562
1563         MP_SIGN(&u) = MP_ZPOS;
1564         MP_SIGN(&v) = MP_ZPOS;
1565
1566         {                                                       /* Divide out common factors of 2 from u and v */
1567                 int                     div2_u = s_dp2k(&u),
1568                                         div2_v = s_dp2k(&v);
1569
1570                 k = MIN(div2_u, div2_v);
1571                 s_qdiv(&u, (mp_size) k);
1572                 s_qdiv(&v, (mp_size) k);
1573         }
1574
1575         if (mp_int_is_odd(&u))
1576         {
1577                 if ((res = mp_int_neg(&v, &t)) != MP_OK)
1578                         goto CLEANUP;
1579         }
1580         else
1581         {
1582                 if ((res = mp_int_copy(&u, &t)) != MP_OK)
1583                         goto CLEANUP;
1584         }
1585
1586         for (;;)
1587         {
1588                 s_qdiv(&t, s_dp2k(&t));
1589
1590                 if (CMPZ(&t) > 0)
1591                 {
1592                         if ((res = mp_int_copy(&t, &u)) != MP_OK)
1593                                 goto CLEANUP;
1594                 }
1595                 else
1596                 {
1597                         if ((res = mp_int_neg(&t, &v)) != MP_OK)
1598                                 goto CLEANUP;
1599                 }
1600
1601                 if ((res = mp_int_sub(&u, &v, &t)) != MP_OK)
1602                         goto CLEANUP;
1603
1604                 if (CMPZ(&t) == 0)
1605                         break;
1606         }
1607
1608         if ((res = mp_int_abs(&u, c)) != MP_OK)
1609                 goto CLEANUP;
1610         if (!s_qmul(c, (mp_size) k))
1611                 res = MP_MEMORY;
1612
1613 CLEANUP:
1614         mp_int_clear(&v);
1615 V: mp_int_clear(&u);
1616 U: mp_int_clear(&t);
1617
1618         return res;
1619 }
1620
1621 /* }}} */
1622
1623 /* {{{ mp_int_egcd(a, b, c, x, y) */
1624
1625 /* This is the binary GCD algorithm again, but this time we keep track
1626    of the elementary matrix operations as we go, so we can get values
1627    x and y satisfying c = ax + by.
1628  */
1629 mp_result
1630 mp_int_egcd(mp_int a, mp_int b, mp_int c,
1631                         mp_int x, mp_int y)
1632 {
1633         int                     k,
1634                                 last = 0,
1635                                 ca,
1636                                 cb;
1637         mpz_t           temp[8];
1638         mp_result       res;
1639
1640         CHECK(a != NULL && b != NULL && c != NULL &&
1641                   (x != NULL || y != NULL));
1642
1643         ca = CMPZ(a);
1644         cb = CMPZ(b);
1645         if (ca == 0 && cb == 0)
1646                 return MP_UNDEF;
1647         else if (ca == 0)
1648         {
1649                 if ((res = mp_int_abs(b, c)) != MP_OK)
1650                         return res;
1651                 mp_int_zero(x);
1652                 (void) mp_int_set_value(y, 1);
1653                 return MP_OK;
1654         }
1655         else if (cb == 0)
1656         {
1657                 if ((res = mp_int_abs(a, c)) != MP_OK)
1658                         return res;
1659                 (void) mp_int_set_value(x, 1);
1660                 mp_int_zero(y);
1661                 return MP_OK;
1662         }
1663
1664         /*
1665          * Initialize temporaries: A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7
1666          */
1667         for (last = 0; last < 4; ++last)
1668         {
1669                 if ((res = mp_int_init(TEMP(last))) != MP_OK)
1670                         goto CLEANUP;
1671         }
1672         TEMP(0)->digits[0] = 1;
1673         TEMP(3)->digits[0] = 1;
1674
1675         SETUP(mp_int_init_copy(TEMP(4), a), last);
1676         SETUP(mp_int_init_copy(TEMP(5), b), last);
1677
1678         /* We will work with absolute values here */
1679         MP_SIGN(TEMP(4)) = MP_ZPOS;
1680         MP_SIGN(TEMP(5)) = MP_ZPOS;
1681
1682         {                                                       /* Divide out common factors of 2 from u and v */
1683                 int                     div2_u = s_dp2k(TEMP(4)),
1684                                         div2_v = s_dp2k(TEMP(5));
1685
1686                 k = MIN(div2_u, div2_v);
1687                 s_qdiv(TEMP(4), k);
1688                 s_qdiv(TEMP(5), k);
1689         }
1690
1691         SETUP(mp_int_init_copy(TEMP(6), TEMP(4)), last);
1692         SETUP(mp_int_init_copy(TEMP(7), TEMP(5)), last);
1693
1694         for (;;)
1695         {
1696                 while (mp_int_is_even(TEMP(4)))
1697                 {
1698                         s_qdiv(TEMP(4), 1);
1699
1700                         if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1)))
1701                         {
1702                                 if ((res = mp_int_add(TEMP(0), TEMP(7), TEMP(0))) != MP_OK)
1703                                         goto CLEANUP;
1704                                 if ((res = mp_int_sub(TEMP(1), TEMP(6), TEMP(1))) != MP_OK)
1705                                         goto CLEANUP;
1706                         }
1707
1708                         s_qdiv(TEMP(0), 1);
1709                         s_qdiv(TEMP(1), 1);
1710                 }
1711
1712                 while (mp_int_is_even(TEMP(5)))
1713                 {
1714                         s_qdiv(TEMP(5), 1);
1715
1716                         if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3)))
1717                         {
1718                                 if ((res = mp_int_add(TEMP(2), TEMP(7), TEMP(2))) != MP_OK)
1719                                         goto CLEANUP;
1720                                 if ((res = mp_int_sub(TEMP(3), TEMP(6), TEMP(3))) != MP_OK)
1721                                         goto CLEANUP;
1722                         }
1723
1724                         s_qdiv(TEMP(2), 1);
1725                         s_qdiv(TEMP(3), 1);
1726                 }
1727
1728                 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0)
1729                 {
1730                         if ((res = mp_int_sub(TEMP(4), TEMP(5), TEMP(4))) != MP_OK)
1731                                 goto CLEANUP;
1732                         if ((res = mp_int_sub(TEMP(0), TEMP(2), TEMP(0))) != MP_OK)
1733                                 goto CLEANUP;
1734                         if ((res = mp_int_sub(TEMP(1), TEMP(3), TEMP(1))) != MP_OK)
1735                                 goto CLEANUP;
1736                 }
1737                 else
1738                 {
1739                         if ((res = mp_int_sub(TEMP(5), TEMP(4), TEMP(5))) != MP_OK)
1740                                 goto CLEANUP;
1741                         if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK)
1742                                 goto CLEANUP;
1743                         if ((res = mp_int_sub(TEMP(3), TEMP(1), TEMP(3))) != MP_OK)
1744                                 goto CLEANUP;
1745                 }
1746
1747                 if (CMPZ(TEMP(4)) == 0)
1748                 {
1749                         if (x && (res = mp_int_copy(TEMP(2), x)) != MP_OK)
1750                                 goto CLEANUP;
1751                         if (y && (res = mp_int_copy(TEMP(3), y)) != MP_OK)
1752                                 goto CLEANUP;
1753                         if (c)
1754                         {
1755                                 if (!s_qmul(TEMP(5), k))
1756                                 {
1757                                         res = MP_MEMORY;
1758                                         goto CLEANUP;
1759                                 }
1760
1761                                 res = mp_int_copy(TEMP(5), c);
1762                         }
1763
1764                         break;
1765                 }
1766         }
1767
1768 CLEANUP:
1769         while (--last >= 0)
1770                 mp_int_clear(TEMP(last));
1771
1772         return res;
1773 }
1774
1775 /* }}} */
1776
1777 /* {{{ mp_int_divisible_value(a, v) */
1778
1779 int
1780 mp_int_divisible_value(mp_int a, int v)
1781 {
1782         int                     rem = 0;
1783
1784         if (mp_int_div_value(a, v, NULL, &rem) != MP_OK)
1785                 return 0;
1786
1787         return rem == 0;
1788 }
1789
1790 /* }}} */
1791
1792 /* {{{ mp_int_is_pow2(z) */
1793
1794 int
1795 mp_int_is_pow2(mp_int z)
1796 {
1797         CHECK(z != NULL);
1798
1799         return s_isp2(z);
1800 }
1801
1802 /* }}} */
1803
1804 /* {{{ mp_int_sqrt(a, c) */
1805
1806 mp_result
1807 mp_int_sqrt(mp_int a, mp_int c)
1808 {
1809         mp_result       res = MP_OK;
1810         mpz_t           temp[2];
1811         int                     last = 0;
1812
1813         CHECK(a != NULL && c != NULL);
1814
1815         /* The square root of a negative value does not exist in the integers. */
1816         if (MP_SIGN(a) == MP_NEG)
1817                 return MP_UNDEF;
1818
1819         SETUP(mp_int_init_copy(TEMP(last), a), last);
1820         SETUP(mp_int_init(TEMP(last)), last);
1821
1822         for (;;)
1823         {
1824                 if ((res = mp_int_sqr(TEMP(0), TEMP(1))) != MP_OK)
1825                         goto CLEANUP;
1826
1827                 if (mp_int_compare_unsigned(a, TEMP(1)) == 0)
1828                         break;
1829
1830                 if ((res = mp_int_copy(a, TEMP(1))) != MP_OK)
1831                         goto CLEANUP;
1832                 if ((res = mp_int_div(TEMP(1), TEMP(0), TEMP(1), NULL)) != MP_OK)
1833                         goto CLEANUP;
1834                 if ((res = mp_int_add(TEMP(0), TEMP(1), TEMP(1))) != MP_OK)
1835                         goto CLEANUP;
1836                 if ((res = mp_int_div_pow2(TEMP(1), 1, TEMP(1), NULL)) != MP_OK)
1837                         goto CLEANUP;
1838
1839                 if (mp_int_compare_unsigned(TEMP(0), TEMP(1)) == 0)
1840                         break;
1841                 if ((res = mp_int_sub_value(TEMP(0), 1, TEMP(0))) != MP_OK)
1842                         goto CLEANUP;
1843                 if (mp_int_compare_unsigned(TEMP(0), TEMP(1)) == 0)
1844                         break;
1845
1846                 if ((res = mp_int_copy(TEMP(1), TEMP(0))) != MP_OK)
1847                         goto CLEANUP;
1848         }
1849
1850         res = mp_int_copy(TEMP(0), c);
1851
1852 CLEANUP:
1853         while (--last >= 0)
1854                 mp_int_clear(TEMP(last));
1855
1856         return res;
1857 }
1858
1859 /* }}} */
1860
1861 /* {{{ mp_int_to_int(z, out) */
1862
1863 mp_result
1864 mp_int_to_int(mp_int z, int *out)
1865 {
1866         unsigned int uv = 0;
1867         mp_size         uz;
1868         mp_digit   *dz;
1869         mp_sign         sz;
1870
1871         CHECK(z != NULL);
1872
1873         /* Make sure the value is representable as an int */
1874         sz = MP_SIGN(z);
1875         if ((sz == MP_ZPOS && mp_int_compare_value(z, INT_MAX) > 0) ||
1876                 mp_int_compare_value(z, INT_MIN) < 0)
1877                 return MP_RANGE;
1878
1879         uz = MP_USED(z);
1880         dz = MP_DIGITS(z) + uz - 1;
1881
1882         while (uz > 0)
1883         {
1884                 uv <<= MP_DIGIT_BIT / 2;
1885                 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1886                 --uz;
1887         }
1888
1889         if (out)
1890                 *out = (sz == MP_NEG) ? -(int) uv : (int) uv;
1891
1892         return MP_OK;
1893 }
1894
1895 /* }}} */
1896
1897 /* {{{ mp_int_to_string(z, radix, str, limit) */
1898
1899 mp_result
1900 mp_int_to_string(mp_int z, mp_size radix,
1901                                  char *str, int limit)
1902 {
1903         mp_result       res;
1904         int                     cmp = 0;
1905
1906         CHECK(z != NULL && str != NULL && limit >= 2);
1907
1908         if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1909                 return MP_RANGE;
1910
1911         if (CMPZ(z) == 0)
1912         {
1913                 *str++ = s_val2ch(0, mp_flags & MP_CAP_DIGITS);
1914         }
1915         else
1916         {
1917                 mpz_t           tmp;
1918                 char       *h,
1919                                    *t;
1920
1921                 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1922                         return res;
1923
1924                 if (MP_SIGN(z) == MP_NEG)
1925                 {
1926                         *str++ = '-';
1927                         --limit;
1928                 }
1929                 h = str;
1930
1931                 /* Generate digits in reverse order until finished or limit reached */
1932                 for ( /* */ ; limit > 0; --limit)
1933                 {
1934                         mp_digit        d;
1935
1936                         if ((cmp = CMPZ(&tmp)) == 0)
1937                                 break;
1938
1939                         d = s_ddiv(&tmp, (mp_digit) radix);
1940                         *str++ = s_val2ch(d, mp_flags & MP_CAP_DIGITS);
1941                 }
1942                 t = str - 1;
1943
1944                 /* Put digits back in correct output order */
1945                 while (h < t)
1946                 {
1947                         char            tc = *h;
1948
1949                         *h++ = *t;
1950                         *t-- = tc;
1951                 }
1952
1953                 mp_int_clear(&tmp);
1954         }
1955
1956         *str = '\0';
1957         if (cmp == 0)
1958                 return MP_OK;
1959         else
1960                 return MP_TRUNC;
1961 }
1962
1963 /* }}} */
1964
1965 /* {{{ mp_int_string_len(z, radix) */
1966
1967 mp_result
1968 mp_int_string_len(mp_int z, mp_size radix)
1969 {
1970         int                     len;
1971
1972         CHECK(z != NULL);
1973
1974         if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1975                 return MP_RANGE;
1976
1977         len = s_outlen(z, radix) + 1;           /* for terminator */
1978
1979         /* Allow for sign marker on negatives */
1980         if (MP_SIGN(z) == MP_NEG)
1981                 len += 1;
1982
1983         return len;
1984 }
1985
1986 /* }}} */
1987
1988 /* {{{ mp_int_read_string(z, radix, *str) */
1989
1990 /* Read zero-terminated string into z */
1991 mp_result
1992 mp_int_read_string(mp_int z, mp_size radix, const char *str)
1993 {
1994         return mp_int_read_cstring(z, radix, str, NULL);
1995
1996 }
1997
1998 /* }}} */
1999
2000 /* {{{ mp_int_read_cstring(z, radix, *str, **end) */
2001
2002 mp_result
2003 mp_int_read_cstring(mp_int z, mp_size radix, const char *str, char **end)
2004 {
2005         int                     ch;
2006
2007         CHECK(z != NULL && str != NULL);
2008
2009         if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
2010                 return MP_RANGE;
2011
2012         /* Skip leading whitespace */
2013         while (isspace((unsigned char) *str))
2014                 ++str;
2015
2016         /* Handle leading sign tag (+/-, positive default) */
2017         switch (*str)
2018         {
2019                 case '-':
2020                         MP_SIGN(z) = MP_NEG;
2021                         ++str;
2022                         break;
2023                 case '+':
2024                         ++str;                          /* fallthrough */
2025                 default:
2026                         MP_SIGN(z) = MP_ZPOS;
2027                         break;
2028         }
2029
2030         /* Skip leading zeroes */
2031         while ((ch = s_ch2val(*str, radix)) == 0)
2032                 ++str;
2033
2034         /* Make sure there is enough space for the value */
2035         if (!s_pad(z, s_inlen(strlen(str), radix)))
2036                 return MP_MEMORY;
2037
2038         MP_USED(z) = 1;
2039         z->digits[0] = 0;
2040
2041         while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0))
2042         {
2043                 s_dmul(z, (mp_digit) radix);
2044                 s_dadd(z, (mp_digit) ch);
2045                 ++str;
2046         }
2047
2048         CLAMP(z);
2049
2050         /* Override sign for zero, even if negative specified. */
2051         if (CMPZ(z) == 0)
2052                 MP_SIGN(z) = MP_ZPOS;
2053
2054         if (end != NULL)
2055                 *end = (char *) str;
2056
2057         /*
2058          * Return a truncation error if the string has unprocessed characters
2059          * remaining, so the caller can tell if the whole string was done
2060          */
2061         if (*str != '\0')
2062                 return MP_TRUNC;
2063         else
2064                 return MP_OK;
2065 }
2066
2067 /* }}} */
2068
2069 /* {{{ mp_int_count_bits(z) */
2070
2071 mp_result
2072 mp_int_count_bits(mp_int z)
2073 {
2074         mp_size         nbits = 0,
2075                                 uz;
2076         mp_digit        d;
2077
2078         CHECK(z != NULL);
2079
2080         uz = MP_USED(z);
2081         if (uz == 1 && z->digits[0] == 0)
2082                 return 1;
2083
2084         --uz;
2085         nbits = uz * MP_DIGIT_BIT;
2086         d = z->digits[uz];
2087
2088         while (d != 0)
2089         {
2090                 d >>= 1;
2091                 ++nbits;
2092         }
2093
2094         return nbits;
2095 }
2096
2097 /* }}} */
2098
2099 /* {{{ mp_int_to_binary(z, buf, limit) */
2100
2101 mp_result
2102 mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
2103 {
2104         static const int PAD_FOR_2C = 1;
2105
2106         mp_result       res;
2107         int                     limpos = limit;
2108
2109         CHECK(z != NULL && buf != NULL);
2110
2111         res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
2112
2113         if (MP_SIGN(z) == MP_NEG)
2114                 s_2comp(buf, limpos);
2115
2116         return res;
2117 }
2118
2119 /* }}} */
2120
2121 /* {{{ mp_int_read_binary(z, buf, len) */
2122
2123 mp_result
2124 mp_int_read_binary(mp_int z, unsigned char *buf, int len)
2125 {
2126         mp_size         need,
2127                                 i;
2128         unsigned char *tmp;
2129         mp_digit   *dz;
2130
2131         CHECK(z != NULL && buf != NULL && len > 0);
2132
2133         /* Figure out how many digits are needed to represent this value */
2134         need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2135         if (!s_pad(z, need))
2136                 return MP_MEMORY;
2137
2138         mp_int_zero(z);
2139
2140         /*
2141          * If the high-order bit is set, take the 2's complement before reading
2142          * the value (it will be restored afterward)
2143          */
2144         if (buf[0] >> (CHAR_BIT - 1))
2145         {
2146                 MP_SIGN(z) = MP_NEG;
2147                 s_2comp(buf, len);
2148         }
2149
2150         dz = MP_DIGITS(z);
2151         for (tmp = buf, i = len; i > 0; --i, ++tmp)
2152         {
2153                 s_qmul(z, (mp_size) CHAR_BIT);
2154                 *dz |= *tmp;
2155         }
2156
2157         /* Restore 2's complement if we took it before */
2158         if (MP_SIGN(z) == MP_NEG)
2159                 s_2comp(buf, len);
2160
2161         return MP_OK;
2162 }
2163
2164 /* }}} */
2165
2166 /* {{{ mp_int_binary_len(z) */
2167
2168 mp_result
2169 mp_int_binary_len(mp_int z)
2170 {
2171         mp_result       res = mp_int_count_bits(z);
2172         int                     bytes = mp_int_unsigned_len(z);
2173
2174         if (res <= 0)
2175                 return res;
2176
2177         bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2178
2179         /*
2180          * If the highest-order bit falls exactly on a byte boundary, we need to
2181          * pad with an extra byte so that the sign will be read correctly when
2182          * reading it back in.
2183          */
2184         if (bytes * CHAR_BIT == res)
2185                 ++bytes;
2186
2187         return bytes;
2188 }
2189
2190 /* }}} */
2191
2192 /* {{{ mp_int_to_unsigned(z, buf, limit) */
2193
2194 mp_result
2195 mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
2196 {
2197         static const int NO_PADDING = 0;
2198
2199         CHECK(z != NULL && buf != NULL);
2200
2201         return s_tobin(z, buf, &limit, NO_PADDING);
2202 }
2203
2204 /* }}} */
2205
2206 /* {{{ mp_int_read_unsigned(z, buf, len) */
2207
2208 mp_result
2209 mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
2210 {
2211         mp_size         need,
2212                                 i;
2213         unsigned char *tmp;
2214         mp_digit   *dz;
2215
2216         CHECK(z != NULL && buf != NULL && len > 0);
2217
2218         /* Figure out how many digits are needed to represent this value */
2219         need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2220         if (!s_pad(z, need))
2221                 return MP_MEMORY;
2222
2223         mp_int_zero(z);
2224
2225         dz = MP_DIGITS(z);
2226         for (tmp = buf, i = len; i > 0; --i, ++tmp)
2227         {
2228                 (void) s_qmul(z, CHAR_BIT);
2229                 *dz |= *tmp;
2230         }
2231
2232         return MP_OK;
2233 }
2234
2235 /* }}} */
2236
2237 /* {{{ mp_int_unsigned_len(z) */
2238
2239 mp_result
2240 mp_int_unsigned_len(mp_int z)
2241 {
2242         mp_result       res = mp_int_count_bits(z);
2243         int                     bytes;
2244
2245         if (res <= 0)
2246                 return res;
2247
2248         bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2249
2250         return bytes;
2251 }
2252
2253 /* }}} */
2254
2255 /* {{{ mp_error_string(res) */
2256
2257 const char *
2258 mp_error_string(mp_result res)
2259 {
2260         int                     ix;
2261
2262         if (res > 0)
2263                 return s_unknown_err;
2264
2265         res = -res;
2266         for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
2267                 ;
2268
2269         if (s_error_msg[ix] != NULL)
2270                 return s_error_msg[ix];
2271         else
2272                 return s_unknown_err;
2273 }
2274
2275 /* }}} */
2276
2277 /*------------------------------------------------------------------------*/
2278 /* Private functions for internal use.  These make assumptions.                   */
2279
2280 /* {{{ s_alloc(num) */
2281
2282 static mp_digit *
2283 s_alloc(mp_size num)
2284 {
2285         mp_digit   *out = px_alloc(num * sizeof(mp_digit));
2286
2287         assert(out != NULL);            /* for debugging */
2288
2289         return out;
2290 }
2291
2292 /* }}} */
2293
2294 /* {{{ s_realloc(old, num) */
2295
2296 static mp_digit *
2297 s_realloc(mp_digit * old, mp_size num)
2298 {
2299         mp_digit   *new = px_realloc(old, num * sizeof(mp_digit));
2300
2301         assert(new != NULL);            /* for debugging */
2302
2303         return new;
2304 }
2305
2306 /* }}} */
2307
2308 /* {{{ s_free(ptr) */
2309
2310 #if TRACEABLE_FREE
2311 static void
2312 s_free(void *ptr)
2313 {
2314         px_free(ptr);
2315 }
2316 #endif
2317
2318 /* }}} */
2319
2320 /* {{{ s_pad(z, min) */
2321
2322 static int
2323 s_pad(mp_int z, mp_size min)
2324 {
2325         if (MP_ALLOC(z) < min)
2326         {
2327                 mp_size         nsize = ROUND_PREC(min);
2328                 mp_digit   *tmp = s_realloc(MP_DIGITS(z), nsize);
2329
2330                 if (tmp == NULL)
2331                         return 0;
2332
2333                 MP_DIGITS(z) = tmp;
2334                 MP_ALLOC(z) = nsize;
2335         }
2336
2337         return 1;
2338 }
2339
2340 /* }}} */
2341
2342 /* {{{ s_clamp(z) */
2343
2344 #if TRACEABLE_CLAMP
2345 static void
2346 s_clamp(mp_int z)
2347 {
2348         mp_size         uz = MP_USED(z);
2349         mp_digit   *zd = MP_DIGITS(z) + uz - 1;
2350
2351         while (uz > 1 && (*zd-- == 0))
2352                 --uz;
2353
2354         MP_USED(z) = uz;
2355 }
2356 #endif
2357
2358 /* }}} */
2359
2360 /* {{{ s_fake(z, value, vbuf) */
2361
2362 static void
2363 s_fake(mp_int z, int value, mp_digit vbuf[])
2364 {
2365         mp_size         uv = (mp_size) s_vpack(value, vbuf);
2366
2367         z->used = uv;
2368         z->alloc = MP_VALUE_DIGITS(value);
2369         z->sign = (value < 0) ? MP_NEG : MP_ZPOS;
2370         z->digits = vbuf;
2371 }
2372
2373 /* }}} */
2374
2375 /* {{{ s_cdig(da, db, len) */
2376
2377 static int
2378 s_cdig(mp_digit * da, mp_digit * db, mp_size len)
2379 {
2380         mp_digit   *dat = da + len - 1,
2381                            *dbt = db + len - 1;
2382
2383         for ( /* */ ; len != 0; --len, --dat, --dbt)
2384         {
2385                 if (*dat > *dbt)
2386                         return 1;
2387                 else if (*dat < *dbt)
2388                         return -1;
2389         }
2390
2391         return 0;
2392 }
2393
2394 /* }}} */
2395
2396 /* {{{ s_vpack(v, t[]) */
2397
2398 static int
2399 s_vpack(int v, mp_digit t[])
2400 {
2401         unsigned int uv = (unsigned int) ((v < 0) ? -v : v);
2402         int                     ndig = 0;
2403
2404         if (uv == 0)
2405                 t[ndig++] = 0;
2406         else
2407         {
2408                 while (uv != 0)
2409                 {
2410                         t[ndig++] = (mp_digit) uv;
2411                         uv >>= MP_DIGIT_BIT / 2;
2412                         uv >>= MP_DIGIT_BIT / 2;
2413                 }
2414         }
2415
2416         return ndig;
2417 }
2418
2419 /* }}} */
2420
2421 /* {{{ s_ucmp(a, b) */
2422
2423 static int
2424 s_ucmp(mp_int a, mp_int b)
2425 {
2426         mp_size         ua = MP_USED(a),
2427                                 ub = MP_USED(b);
2428
2429         if (ua > ub)
2430                 return 1;
2431         else if (ub > ua)
2432                 return -1;
2433         else
2434                 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2435 }
2436
2437 /* }}} */
2438
2439 /* {{{ s_vcmp(a, v) */
2440
2441 static int
2442 s_vcmp(mp_int a, int v)
2443 {
2444         mp_digit        vdig[MP_VALUE_DIGITS(v)];
2445         int                     ndig = 0;
2446         mp_size         ua = MP_USED(a);
2447
2448         ndig = s_vpack(v, vdig);
2449
2450         if (ua > ndig)
2451                 return 1;
2452         else if (ua < ndig)
2453                 return -1;
2454         else
2455                 return s_cdig(MP_DIGITS(a), vdig, ndig);
2456 }
2457
2458 /* }}} */
2459
2460 /* {{{ s_uadd(da, db, dc, size_a, size_b) */
2461
2462 static mp_digit
2463 s_uadd(mp_digit * da, mp_digit * db, mp_digit * dc,
2464            mp_size size_a, mp_size size_b)
2465 {
2466         mp_size         pos;
2467         mp_word         w = 0;
2468
2469         /* Insure that da is the longer of the two to simplify later code */
2470         if (size_b > size_a)
2471         {
2472                 SWAP(mp_digit *, da, db);
2473                 SWAP(mp_size, size_a, size_b);
2474         }
2475
2476         /* Add corresponding digits until the shorter number runs out */
2477         for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2478         {
2479                 w = w + (mp_word) * da + (mp_word) * db;
2480                 *dc = LOWER_HALF(w);
2481                 w = UPPER_HALF(w);
2482         }
2483
2484         /* Propagate carries as far as necessary */
2485         for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2486         {
2487                 w = w + *da;
2488
2489                 *dc = LOWER_HALF(w);
2490                 w = UPPER_HALF(w);
2491         }
2492
2493         /* Return carry out */
2494         return (mp_digit) w;
2495 }
2496
2497 /* }}} */
2498
2499 /* {{{ s_usub(da, db, dc, size_a, size_b) */
2500
2501 static void
2502 s_usub(mp_digit * da, mp_digit * db, mp_digit * dc,
2503            mp_size size_a, mp_size size_b)
2504 {
2505         mp_size         pos;
2506         mp_word         w = 0;
2507
2508         /* We assume that |a| >= |b| so this should definitely hold */
2509         assert(size_a >= size_b);
2510
2511         /* Subtract corresponding digits and propagate borrow */
2512         for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2513         {
2514                 w = ((mp_word) MP_DIGIT_MAX + 1 +               /* MP_RADIX */
2515                          (mp_word) * da) - w - (mp_word) * db;
2516
2517                 *dc = LOWER_HALF(w);
2518                 w = (UPPER_HALF(w) == 0);
2519         }
2520
2521         /* Finish the subtraction for remaining upper digits of da */
2522         for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2523         {
2524                 w = ((mp_word) MP_DIGIT_MAX + 1 +               /* MP_RADIX */
2525                          (mp_word) * da) - w;
2526
2527                 *dc = LOWER_HALF(w);
2528                 w = (UPPER_HALF(w) == 0);
2529         }
2530
2531         /* If there is a borrow out at the end, it violates the precondition */
2532         assert(w == 0);
2533 }
2534
2535 /* }}} */
2536
2537 /* {{{ s_kmul(da, db, dc, size_a, size_b) */
2538
2539 static int
2540 s_kmul(mp_digit * da, mp_digit * db, mp_digit * dc,
2541            mp_size size_a, mp_size size_b)
2542 {
2543         mp_size         bot_size;
2544
2545         /* Make sure b is the smaller of the two input values */
2546         if (size_b > size_a)
2547         {
2548                 SWAP(mp_digit *, da, db);
2549                 SWAP(mp_size, size_a, size_b);
2550         }
2551
2552         /*
2553          * Insure that the bottom is the larger half in an odd-length split; the
2554          * code below relies on this being true.
2555          */
2556         bot_size = (size_a + 1) / 2;
2557
2558         /*
2559          * If the values are big enough to bother with recursion, use the
2560          * Karatsuba algorithm to compute the product; otherwise use the normal
2561          * multiplication algorithm
2562          */
2563         if (multiply_threshold &&
2564                 size_a >= multiply_threshold &&
2565                 size_b > bot_size)
2566         {
2567
2568                 mp_digit   *t1,
2569                                    *t2,
2570                                    *t3,
2571                                         carry;
2572
2573                 mp_digit   *a_top = da + bot_size;
2574                 mp_digit   *b_top = db + bot_size;
2575
2576                 mp_size         at_size = size_a - bot_size;
2577                 mp_size         bt_size = size_b - bot_size;
2578                 mp_size         buf_size = 2 * bot_size;
2579
2580                 /*
2581                  * Do a single allocation for all three temporary buffers needed; each
2582                  * buffer must be big enough to hold the product of two bottom halves,
2583                  * and one buffer needs space for the completed product; twice the
2584                  * space is plenty.
2585                  */
2586                 if ((t1 = s_alloc(4 * buf_size)) == NULL)
2587                         return 0;
2588                 t2 = t1 + buf_size;
2589                 t3 = t2 + buf_size;
2590                 ZERO(t1, 4 * buf_size);
2591
2592                 /*
2593                  * t1 and t2 are initially used as temporaries to compute the inner
2594                  * product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
2595                  */
2596                 carry = s_uadd(da, a_top, t1, bot_size, at_size);               /* t1 = a1 + a0 */
2597                 t1[bot_size] = carry;
2598
2599                 carry = s_uadd(db, b_top, t2, bot_size, bt_size);               /* t2 = b1 + b0 */
2600                 t2[bot_size] = carry;
2601
2602                 (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1);  /* t3 = t1 * t2 */
2603
2604                 /*
2605                  * Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so
2606                  * that we're left with only the pieces we want:  t3 = a1b0 + a0b1
2607                  */
2608                 ZERO(t1, bot_size + 1);
2609                 ZERO(t2, bot_size + 1);
2610                 (void) s_kmul(da, db, t1, bot_size, bot_size);  /* t1 = a0 * b0 */
2611                 (void) s_kmul(a_top, b_top, t2, at_size, bt_size);              /* t2 = a1 * b1 */
2612
2613                 /* Subtract out t1 and t2 to get the inner product */
2614                 s_usub(t3, t1, t3, buf_size + 2, buf_size);
2615                 s_usub(t3, t2, t3, buf_size + 2, buf_size);
2616
2617                 /* Assemble the output value */
2618                 COPY(t1, dc, buf_size);
2619                 (void) s_uadd(t3, dc + bot_size, dc + bot_size,
2620                                           buf_size + 1, buf_size + 1);
2621
2622                 (void) s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size,
2623                                           buf_size, buf_size);
2624
2625                 s_free(t1);                             /* note t2 and t3 are just internal pointers
2626                                                                  * to t1 */
2627         }
2628         else
2629         {
2630                 s_umul(da, db, dc, size_a, size_b);
2631         }
2632
2633         return 1;
2634 }
2635
2636 /* }}} */
2637
2638 /* {{{ s_umul(da, db, dc, size_a, size_b) */
2639
2640 static void
2641 s_umul(mp_digit * da, mp_digit * db, mp_digit * dc,
2642            mp_size size_a, mp_size size_b)
2643 {
2644         mp_size         a,
2645                                 b;
2646         mp_word         w;
2647
2648         for (a = 0; a < size_a; ++a, ++dc, ++da)
2649         {
2650                 mp_digit   *dct = dc;
2651                 mp_digit   *dbt = db;
2652
2653                 if (*da == 0)
2654                         continue;
2655
2656                 w = 0;
2657                 for (b = 0; b < size_b; ++b, ++dbt, ++dct)
2658                 {
2659                         w = (mp_word) * da * (mp_word) * dbt + w + (mp_word) * dct;
2660
2661                         *dct = LOWER_HALF(w);
2662                         w = UPPER_HALF(w);
2663                 }
2664
2665                 *dct = (mp_digit) w;
2666         }
2667 }
2668
2669 /* }}} */
2670
2671 /* {{{ s_ksqr(da, dc, size_a) */
2672
2673 static int
2674 s_ksqr(mp_digit * da, mp_digit * dc, mp_size size_a)
2675 {
2676         if (multiply_threshold && size_a > multiply_threshold)
2677         {
2678                 mp_size         bot_size = (size_a + 1) / 2;
2679                 mp_digit   *a_top = da + bot_size;
2680                 mp_digit   *t1,
2681                                    *t2,
2682                                    *t3;
2683                 mp_size         at_size = size_a - bot_size;
2684                 mp_size         buf_size = 2 * bot_size;
2685
2686                 if ((t1 = s_alloc(4 * buf_size)) == NULL)
2687                         return 0;
2688                 t2 = t1 + buf_size;
2689                 t3 = t2 + buf_size;
2690                 ZERO(t1, 4 * buf_size);
2691
2692                 (void) s_ksqr(da, t1, bot_size);                /* t1 = a0 ^ 2 */
2693                 (void) s_ksqr(a_top, t2, at_size);              /* t2 = a1 ^ 2 */
2694
2695                 (void) s_kmul(da, a_top, t3, bot_size, at_size);                /* t3 = a0 * a1 */
2696
2697                 /* Quick multiply t3 by 2, shifting left (can't overflow) */
2698                 {
2699                         int                     i,
2700                                                 top = bot_size + at_size;
2701                         mp_word         w,
2702                                                 save = 0;
2703
2704                         for (i = 0; i < top; ++i)
2705                         {
2706                                 w = t3[i];
2707                                 w = (w << 1) | save;
2708                                 t3[i] = LOWER_HALF(w);
2709                                 save = UPPER_HALF(w);
2710                         }
2711                         t3[i] = LOWER_HALF(save);
2712                 }
2713
2714                 /* Assemble the output value */
2715                 COPY(t1, dc, 2 * bot_size);
2716                 (void) s_uadd(t3, dc + bot_size, dc + bot_size,
2717                                           buf_size + 1, buf_size + 1);
2718
2719                 (void) s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size,
2720                                           buf_size, buf_size);
2721
2722                 px_free(t1);                    /* note that t2 and t2 are internal pointers
2723                                                                  * only */
2724
2725         }
2726         else
2727         {
2728                 s_usqr(da, dc, size_a);
2729         }
2730
2731         return 1;
2732 }
2733
2734 /* }}} */
2735
2736 /* {{{ s_usqr(da, dc, size_a) */
2737
2738 static void
2739 s_usqr(mp_digit * da, mp_digit * dc, mp_size size_a)
2740 {
2741         mp_size         i,
2742                                 j;
2743         mp_word         w;
2744
2745         for (i = 0; i < size_a; ++i, dc += 2, ++da)
2746         {
2747                 mp_digit   *dct = dc,
2748                                    *dat = da;
2749
2750                 if (*da == 0)
2751                         continue;
2752
2753                 /* Take care of the first digit, no rollover */
2754                 w = (mp_word) * dat * (mp_word) * dat + (mp_word) * dct;
2755                 *dct = LOWER_HALF(w);
2756                 w = UPPER_HALF(w);
2757                 ++dat;
2758                 ++dct;
2759
2760                 for (j = i + 1; j < size_a; ++j, ++dat, ++dct)
2761                 {
2762                         mp_word         t = (mp_word) * da * (mp_word) * dat;
2763                         mp_word         u = w + (mp_word) * dct,
2764                                                 ov = 0;
2765
2766                         /* Check if doubling t will overflow a word */
2767                         if (HIGH_BIT_SET(t))
2768                                 ov = 1;
2769
2770                         w = t + t;
2771
2772                         /* Check if adding u to w will overflow a word */
2773                         if (ADD_WILL_OVERFLOW(w, u))
2774                                 ov = 1;
2775
2776                         w += u;
2777
2778                         *dct = LOWER_HALF(w);
2779                         w = UPPER_HALF(w);
2780                         if (ov)
2781                         {
2782                                 w += MP_DIGIT_MAX;              /* MP_RADIX */
2783                                 ++w;
2784                         }
2785                 }
2786
2787                 w = w + *dct;
2788                 *dct = (mp_digit) w;
2789                 while ((w = UPPER_HALF(w)) != 0)
2790                 {
2791                         ++dct;
2792                         w = w + *dct;
2793                         *dct = LOWER_HALF(w);
2794                 }
2795
2796                 assert(w == 0);
2797         }
2798 }
2799
2800 /* }}} */
2801
2802 /* {{{ s_dadd(a, b) */
2803
2804 static void
2805 s_dadd(mp_int a, mp_digit b)
2806 {
2807         mp_word         w = 0;
2808         mp_digit   *da = MP_DIGITS(a);
2809         mp_size         ua = MP_USED(a);
2810
2811         w = (mp_word) * da + b;
2812         *da++ = LOWER_HALF(w);
2813         w = UPPER_HALF(w);
2814
2815         for (ua -= 1; ua > 0; --ua, ++da)
2816         {
2817                 w = (mp_word) * da + w;
2818
2819                 *da = LOWER_HALF(w);
2820                 w = UPPER_HALF(w);
2821         }
2822
2823         if (w)
2824         {
2825                 *da = (mp_digit) w;
2826                 MP_USED(a) += 1;
2827         }
2828 }
2829
2830 /* }}} */
2831
2832 /* {{{ s_dmul(a, b) */
2833
2834 static void
2835 s_dmul(mp_int a, mp_digit b)
2836 {
2837         mp_word         w = 0;
2838         mp_digit   *da = MP_DIGITS(a);
2839         mp_size         ua = MP_USED(a);
2840
2841         while (ua > 0)
2842         {
2843                 w = (mp_word) * da * b + w;
2844                 *da++ = LOWER_HALF(w);
2845                 w = UPPER_HALF(w);
2846                 --ua;
2847         }
2848
2849         if (w)
2850         {
2851                 *da = (mp_digit) w;
2852                 MP_USED(a) += 1;
2853         }
2854 }
2855
2856 /* }}} */
2857
2858 /* {{{ s_dbmul(da, b, dc, size_a) */
2859
2860 static void
2861 s_dbmul(mp_digit * da, mp_digit b, mp_digit * dc, mp_size size_a)
2862 {
2863         mp_word         w = 0;
2864
2865         while (size_a > 0)
2866         {
2867                 w = (mp_word) * da++ * (mp_word) b + w;
2868
2869                 *dc++ = LOWER_HALF(w);
2870                 w = UPPER_HALF(w);
2871                 --size_a;
2872         }
2873
2874         if (w)
2875                 *dc = LOWER_HALF(w);
2876 }
2877
2878 /* }}} */
2879
2880 /* {{{ s_ddiv(da, d, dc, size_a) */
2881
2882 static mp_digit
2883 s_ddiv(mp_int a, mp_digit b)
2884 {
2885         mp_word         w = 0,
2886                                 qdigit;
2887         mp_size         ua = MP_USED(a);
2888         mp_digit   *da = MP_DIGITS(a) + ua - 1;
2889
2890         for ( /* */ ; ua > 0; --ua, --da)
2891         {
2892                 w = (w << MP_DIGIT_BIT) | *da;
2893
2894                 if (w >= b)
2895                 {
2896                         qdigit = w / b;
2897                         w = w % b;
2898                 }
2899                 else
2900                 {
2901                         qdigit = 0;
2902                 }
2903
2904                 *da = (mp_digit) qdigit;
2905         }
2906
2907         CLAMP(a);
2908         return (mp_digit) w;
2909 }
2910
2911 /* }}} */
2912
2913 /* {{{ s_qdiv(z, p2) */
2914
2915 static void
2916 s_qdiv(mp_int z, mp_size p2)
2917 {
2918         mp_size         ndig = p2 / MP_DIGIT_BIT,
2919                                 nbits = p2 % MP_DIGIT_BIT;
2920         mp_size         uz = MP_USED(z);
2921
2922         if (ndig)
2923         {
2924                 mp_size         mark;
2925                 mp_digit   *to,
2926                                    *from;
2927
2928                 if (ndig >= uz)
2929                 {
2930                         mp_int_zero(z);
2931                         return;
2932                 }
2933
2934                 to = MP_DIGITS(z);
2935                 from = to + ndig;
2936
2937                 for (mark = ndig; mark < uz; ++mark)
2938                         *to++ = *from++;
2939
2940                 MP_USED(z) = uz - ndig;
2941         }
2942
2943         if (nbits)
2944         {
2945                 mp_digit        d = 0,
2946                                    *dz,
2947                                         save;
2948                 mp_size         up = MP_DIGIT_BIT - nbits;
2949
2950                 uz = MP_USED(z);
2951                 dz = MP_DIGITS(z) + uz - 1;
2952
2953                 for ( /* */ ; uz > 0; --uz, --dz)
2954                 {
2955                         save = *dz;
2956
2957                         *dz = (*dz >> nbits) | (d << up);
2958                         d = save;
2959                 }
2960
2961                 CLAMP(z);
2962         }
2963
2964         if (MP_USED(z) == 1 && z->digits[0] == 0)
2965                 MP_SIGN(z) = MP_ZPOS;
2966 }
2967
2968 /* }}} */
2969
2970 /* {{{ s_qmod(z, p2) */
2971
2972 static void
2973 s_qmod(mp_int z, mp_size p2)
2974 {
2975         mp_size         start = p2 / MP_DIGIT_BIT + 1,
2976                                 rest = p2 % MP_DIGIT_BIT;
2977         mp_size         uz = MP_USED(z);
2978         mp_digit        mask = (1 << rest) - 1;
2979
2980         if (start <= uz)
2981         {
2982                 MP_USED(z) = start;
2983                 z->digits[start - 1] &= mask;
2984                 CLAMP(z);
2985         }
2986 }
2987
2988 /* }}} */
2989
2990 /* {{{ s_qmul(z, p2) */
2991
2992 static int
2993 s_qmul(mp_int z, mp_size p2)
2994 {
2995         mp_size         uz,
2996                                 need,
2997                                 rest,
2998                                 extra,
2999                                 i;
3000         mp_digit   *from,
3001                            *to,
3002                                 d;
3003
3004         if (p2 == 0)
3005                 return 1;
3006
3007         uz = MP_USED(z);
3008         need = p2 / MP_DIGIT_BIT;
3009         rest = p2 % MP_DIGIT_BIT;
3010
3011         /*
3012          * Figure out if we need an extra digit at the top end; this occurs if the
3013          * topmost `rest' bits of the high-order digit of z are not zero, meaning
3014          * they will be shifted off the end if not preserved
3015          */
3016         extra = 0;
3017         if (rest != 0)
3018         {
3019                 mp_digit   *dz = MP_DIGITS(z) + uz - 1;
3020
3021                 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
3022                         extra = 1;
3023         }
3024
3025         if (!s_pad(z, uz + need + extra))
3026                 return 0;
3027
3028         /*
3029          * If we need to shift by whole digits, do that in one pass, then to back
3030          * and shift by partial digits.
3031          */
3032         if (need > 0)
3033         {
3034                 from = MP_DIGITS(z) + uz - 1;
3035                 to = from + need;
3036
3037                 for (i = 0; i < uz; ++i)
3038                         *to-- = *from--;
3039
3040                 ZERO(MP_DIGITS(z), need);
3041                 uz += need;
3042         }
3043
3044         if (rest)
3045         {
3046                 d = 0;
3047                 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from)
3048                 {
3049                         mp_digit        save = *from;
3050
3051                         *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
3052                         d = save;
3053                 }
3054
3055                 d >>= (MP_DIGIT_BIT - rest);
3056                 if (d != 0)
3057                 {
3058                         *from = d;
3059                         uz += extra;
3060                 }
3061         }
3062
3063         MP_USED(z) = uz;
3064         CLAMP(z);
3065
3066         return 1;
3067 }
3068
3069 /* }}} */
3070
3071 /* {{{ s_qsub(z, p2) */
3072
3073 /* Subtract |z| from 2^p2, assuming 2^p2 > |z|, and set z to be positive */
3074 static int
3075 s_qsub(mp_int z, mp_size p2)
3076 {
3077         mp_digit        hi = (1 << (p2 % MP_DIGIT_BIT)),
3078                            *zp;
3079         mp_size         tdig = (p2 / MP_DIGIT_BIT),
3080                                 pos;
3081         mp_word         w = 0;
3082
3083         if (!s_pad(z, tdig + 1))
3084                 return 0;
3085
3086         for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp)
3087         {
3088                 w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word) * zp;
3089
3090                 *zp = LOWER_HALF(w);
3091                 w = UPPER_HALF(w) ? 0 : 1;
3092         }
3093
3094         w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word) * zp;
3095         *zp = LOWER_HALF(w);
3096
3097         assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
3098
3099         MP_SIGN(z) = MP_ZPOS;
3100         CLAMP(z);
3101
3102         return 1;
3103 }
3104
3105 /* }}} */
3106
3107 /* {{{ s_dp2k(z) */
3108
3109 static int
3110 s_dp2k(mp_int z)
3111 {
3112         int                     k = 0;
3113         mp_digit   *dp = MP_DIGITS(z),
3114                                 d;
3115
3116         if (MP_USED(z) == 1 && *dp == 0)
3117                 return 1;
3118
3119         while (*dp == 0)
3120         {
3121                 k += MP_DIGIT_BIT;
3122                 ++dp;
3123         }
3124
3125         d = *dp;
3126         while ((d & 1) == 0)
3127         {
3128                 d >>= 1;
3129                 ++k;
3130         }
3131
3132         return k;
3133 }
3134
3135 /* }}} */
3136
3137 /* {{{ s_isp2(z) */
3138
3139 static int
3140 s_isp2(mp_int z)
3141 {
3142         mp_size         uz = MP_USED(z),
3143                                 k = 0;
3144         mp_digit   *dz = MP_DIGITS(z),
3145                                 d;
3146
3147         while (uz > 1)
3148         {
3149                 if (*dz++ != 0)
3150                         return -1;
3151                 k += MP_DIGIT_BIT;
3152                 --uz;
3153         }
3154
3155         d = *dz;
3156         while (d > 1)
3157         {
3158                 if (d & 1)
3159                         return -1;
3160                 ++k;
3161                 d >>= 1;
3162         }
3163
3164         return (int) k;
3165 }
3166
3167 /* }}} */
3168
3169 /* {{{ s_2expt(z, k) */
3170
3171 static int
3172 s_2expt(mp_int z, int k)
3173 {
3174         mp_size         ndig,
3175                                 rest;
3176         mp_digit   *dz;
3177
3178         ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
3179         rest = k % MP_DIGIT_BIT;
3180
3181         if (!s_pad(z, ndig))
3182                 return 0;
3183
3184         dz = MP_DIGITS(z);
3185         ZERO(dz, ndig);
3186         *(dz + ndig - 1) = (1 << rest);
3187         MP_USED(z) = ndig;
3188
3189         return 1;
3190 }
3191
3192 /* }}} */
3193
3194 /* {{{ s_norm(a, b) */
3195
3196 static int
3197 s_norm(mp_int a, mp_int b)
3198 {
3199         mp_digit        d = b->digits[MP_USED(b) - 1];
3200         int                     k = 0;
3201
3202         while (d < (mp_digit) (1 << (MP_DIGIT_BIT - 1)))
3203         {                                                       /* d < (MP_RADIX / 2) */
3204                 d <<= 1;
3205                 ++k;
3206         }
3207
3208         /* These multiplications can't fail */
3209         if (k != 0)
3210         {
3211                 (void) s_qmul(a, (mp_size) k);
3212                 (void) s_qmul(b, (mp_size) k);
3213         }
3214
3215         return k;
3216 }
3217
3218 /* }}} */
3219
3220 /* {{{ s_brmu(z, m) */
3221
3222 static mp_result
3223 s_brmu(mp_int z, mp_int m)
3224 {
3225         mp_size         um = MP_USED(m) * 2;
3226
3227         if (!s_pad(z, um))
3228                 return MP_MEMORY;
3229
3230         s_2expt(z, MP_DIGIT_BIT * um);
3231         return mp_int_div(z, m, z, NULL);
3232 }
3233
3234 /* }}} */
3235
3236 /* {{{ s_reduce(x, m, mu, q1, q2) */
3237
3238 static int
3239 s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
3240 {
3241         mp_size         um = MP_USED(m),
3242                                 umb_p1,
3243                                 umb_m1;
3244
3245         umb_p1 = (um + 1) * MP_DIGIT_BIT;
3246         umb_m1 = (um - 1) * MP_DIGIT_BIT;
3247
3248         if (mp_int_copy(x, q1) != MP_OK)
3249                 return 0;
3250
3251         /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
3252         s_qdiv(q1, umb_m1);
3253         UMUL(q1, mu, q2);
3254         s_qdiv(q2, umb_p1);
3255
3256         /* Set x = x mod b^(k+1) */
3257         s_qmod(x, umb_p1);
3258
3259         /*
3260          * Now, q is a guess for the quotient a / m. Compute x - q * m mod
3261          * b^(k+1), replacing x.  This may be off by a factor of 2m, but no more
3262          * than that.
3263          */
3264         UMUL(q2, m, q1);
3265         s_qmod(q1, umb_p1);
3266         (void) mp_int_sub(x, q1, x);    /* can't fail */
3267
3268         /*
3269          * The result may be < 0; if it is, add b^(k+1) to pin it in the proper
3270          * range.
3271          */
3272         if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
3273                 return 0;
3274
3275         /*
3276          * If x > m, we need to back it off until it is in range. This will be
3277          * required at most twice.
3278          */
3279         if (mp_int_compare(x, m) >= 0)
3280                 (void) mp_int_sub(x, m, x);
3281         if (mp_int_compare(x, m) >= 0)
3282                 (void) mp_int_sub(x, m, x);
3283
3284         /* At this point, x has been properly reduced. */
3285         return 1;
3286 }
3287
3288 /* }}} */
3289
3290 /* {{{ s_embar(a, b, m, mu, c) */
3291
3292 /* Perform modular exponentiation using Barrett's method, where mu is
3293    the reduction constant for m.  Assumes a < m, b > 0. */
3294 static mp_result
3295 s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
3296 {
3297         mp_digit   *db,
3298                            *dbt,
3299                                 umu,
3300                                 d;
3301         mpz_t           temp[3];
3302         mp_result       res;
3303         int                     last = 0;
3304
3305         umu = MP_USED(mu);
3306         db = MP_DIGITS(b);
3307         dbt = db + MP_USED(b) - 1;
3308
3309         while (last < 3)
3310                 SETUP(mp_int_init_size(TEMP(last), 2 * umu), last);
3311
3312         (void) mp_int_set_value(c, 1);
3313
3314         /* Take care of low-order digits */
3315         while (db < dbt)
3316         {
3317                 int                     i;
3318
3319                 for (d = *db, i = MP_DIGIT_BIT; i > 0; --i, d >>= 1)
3320                 {
3321                         if (d & 1)
3322                         {
3323                                 /* The use of a second temporary avoids allocation */
3324                                 UMUL(c, a, TEMP(0));
3325                                 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3326                                 {
3327                                         res = MP_MEMORY;
3328                                         goto CLEANUP;
3329                                 }
3330                                 mp_int_copy(TEMP(0), c);
3331                         }
3332
3333
3334                         USQR(a, TEMP(0));
3335                         assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3336                         if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3337                         {
3338                                 res = MP_MEMORY;
3339                                 goto CLEANUP;
3340                         }
3341                         assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3342                         mp_int_copy(TEMP(0), a);
3343
3344
3345                 }
3346
3347                 ++db;
3348         }
3349
3350         /* Take care of highest-order digit */
3351         d = *dbt;
3352         for (;;)
3353         {
3354                 if (d & 1)
3355                 {
3356                         UMUL(c, a, TEMP(0));
3357                         if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3358                         {
3359                                 res = MP_MEMORY;
3360                                 goto CLEANUP;
3361                         }
3362                         mp_int_copy(TEMP(0), c);
3363                 }
3364
3365                 d >>= 1;
3366                 if (!d)
3367                         break;
3368
3369                 USQR(a, TEMP(0));
3370                 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3371                 {
3372                         res = MP_MEMORY;
3373                         goto CLEANUP;
3374                 }
3375                 (void) mp_int_copy(TEMP(0), a);
3376         }
3377
3378 CLEANUP:
3379         while (--last >= 0)
3380                 mp_int_clear(TEMP(last));
3381
3382         return res;
3383 }
3384
3385 /* }}} */
3386
3387 /* {{{ s_udiv(a, b) */
3388
3389 /* Precondition:  a >= b and b > 0
3390    Postcondition: a' = a / b, b' = a % b
3391  */
3392 static mp_result
3393 s_udiv(mp_int a, mp_int b)
3394 {
3395         mpz_t           q,
3396                                 r,
3397                                 t;
3398         mp_size         ua,
3399                                 ub,
3400                                 qpos = 0;
3401         mp_digit   *da,
3402                                 btop;
3403         mp_result       res = MP_OK;
3404         int                     k,
3405                                 skip = 0;
3406
3407         /* Force signs to positive */
3408         MP_SIGN(a) = MP_ZPOS;
3409         MP_SIGN(b) = MP_ZPOS;
3410
3411         /* Normalize, per Knuth */
3412         k = s_norm(a, b);
3413
3414         ua = MP_USED(a);
3415         ub = MP_USED(b);
3416         btop = b->digits[ub - 1];
3417         if ((res = mp_int_init_size(&q, ua)) != MP_OK)
3418                 return res;
3419         if ((res = mp_int_init_size(&t, ua + 1)) != MP_OK)
3420                 goto CLEANUP;
3421
3422         da = MP_DIGITS(a);
3423         r.digits = da + ua - 1;         /* The contents of r are shared with a */
3424         r.used = 1;
3425         r.sign = MP_ZPOS;
3426         r.alloc = MP_ALLOC(a);
3427         ZERO(t.digits, t.alloc);
3428
3429         /* Solve for quotient digits, store in q.digits in reverse order */
3430         while (r.digits >= da)
3431         {
3432                 assert(qpos <= q.alloc);
3433
3434                 if (s_ucmp(b, &r) > 0)
3435                 {
3436                         r.digits -= 1;
3437                         r.used += 1;
3438
3439                         if (++skip > 1)
3440                                 q.digits[qpos++] = 0;
3441
3442                         CLAMP(&r);
3443                 }
3444                 else
3445                 {
3446                         mp_word         pfx = r.digits[r.used - 1];
3447                         mp_word         qdigit;
3448
3449                         if (r.used > 1 && (pfx < btop || r.digits[r.used - 2] == 0))
3450                         {
3451                                 pfx <<= MP_DIGIT_BIT / 2;
3452                                 pfx <<= MP_DIGIT_BIT / 2;
3453                                 pfx |= r.digits[r.used - 2];
3454                         }
3455
3456                         qdigit = pfx / btop;
3457                         if (qdigit > MP_DIGIT_MAX)
3458                                 qdigit = 1;
3459
3460                         s_dbmul(MP_DIGITS(b), (mp_digit) qdigit, t.digits, ub);
3461                         t.used = ub + 1;
3462                         CLAMP(&t);
3463                         while (s_ucmp(&t, &r) > 0)
3464                         {
3465                                 --qdigit;
3466                                 (void) mp_int_sub(&t, b, &t);   /* cannot fail */
3467                         }
3468
3469                         s_usub(r.digits, t.digits, r.digits, r.used, t.used);
3470                         CLAMP(&r);
3471
3472                         q.digits[qpos++] = (mp_digit) qdigit;
3473                         ZERO(t.digits, t.used);
3474                         skip = 0;
3475                 }
3476         }
3477
3478         /* Put quotient digits in the correct order, and discard extra zeroes */
3479         q.used = qpos;
3480         REV(mp_digit, q.digits, qpos);
3481         CLAMP(&q);
3482
3483         /* Denormalize the remainder */
3484         CLAMP(a);
3485         if (k != 0)
3486                 s_qdiv(a, k);
3487
3488         mp_int_copy(a, b);                      /* ok:  0 <= r < b */
3489         mp_int_copy(&q, a);                     /* ok:  q <= a     */
3490
3491         mp_int_clear(&t);
3492 CLEANUP:
3493         mp_int_clear(&q);
3494         return res;
3495 }
3496
3497 /* }}} */
3498
3499 /* {{{ s_outlen(z, r) */
3500
3501 /* Precondition:  2 <= r < 64 */
3502 static int
3503 s_outlen(mp_int z, mp_size r)
3504 {
3505         mp_result       bits;
3506         double          raw;
3507
3508         bits = mp_int_count_bits(z);
3509         raw = (double) bits *s_log2[r];
3510
3511         return (int) (raw + 0.999999);
3512 }
3513
3514 /* }}} */
3515
3516 /* {{{ s_inlen(len, r) */
3517
3518 static mp_size
3519 s_inlen(int len, mp_size r)
3520 {
3521         double          raw = (double) len / s_log2[r];
3522         mp_size         bits = (mp_size) (raw + 0.5);
3523
3524         return (mp_size) ((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT);
3525 }
3526
3527 /* }}} */
3528
3529 /* {{{ s_ch2val(c, r) */
3530
3531 static int
3532 s_ch2val(char c, int r)
3533 {
3534         int                     out;
3535
3536         if (isdigit((unsigned char) c))
3537                 out = c - '0';
3538         else if (r > 10 && isalpha((unsigned char) c))
3539                 out = toupper((unsigned char) c) - 'A' + 10;
3540         else
3541                 return -1;
3542
3543         return (out >= r) ? -1 : out;
3544 }
3545
3546 /* }}} */
3547
3548 /* {{{ s_val2ch(v, caps) */
3549
3550 static char
3551 s_val2ch(int v, int caps)
3552 {
3553         assert(v >= 0);
3554
3555         if (v < 10)
3556                 return v + '0';
3557         else
3558         {
3559                 char            out = (v - 10) + 'a';
3560
3561                 if (caps)
3562                         return toupper((unsigned char) out);
3563                 else
3564                         return out;
3565         }
3566 }
3567
3568 /* }}} */
3569
3570 /* {{{ s_2comp(buf, len) */
3571
3572 static void
3573 s_2comp(unsigned char *buf, int len)
3574 {
3575         int                     i;
3576         unsigned short s = 1;
3577
3578         for (i = len - 1; i >= 0; --i)
3579         {
3580                 unsigned char c = ~buf[i];
3581
3582                 s = c + s;
3583                 c = s & UCHAR_MAX;
3584                 s >>= CHAR_BIT;
3585
3586                 buf[i] = c;
3587         }
3588
3589         /* last carry out is ignored */
3590 }
3591
3592 /* }}} */
3593
3594 /* {{{ s_tobin(z, buf, *limpos) */
3595
3596 static mp_result
3597 s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3598 {
3599         mp_size         uz;
3600         mp_digit   *dz;
3601         int                     pos = 0,
3602                                 limit = *limpos;
3603
3604         uz = MP_USED(z);
3605         dz = MP_DIGITS(z);
3606         while (uz > 0 && pos < limit)
3607         {
3608                 mp_digit        d = *dz++;
3609                 int                     i;
3610
3611                 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i)
3612                 {
3613                         buf[pos++] = (unsigned char) d;
3614                         d >>= CHAR_BIT;
3615
3616                         /* Don't write leading zeroes */
3617                         if (d == 0 && uz == 1)
3618                                 i = 0;                  /* exit loop without signaling truncation */
3619                 }
3620
3621                 /* Detect truncation (loop exited with pos >= limit) */
3622                 if (i > 0)
3623                         break;
3624
3625                 --uz;
3626         }
3627
3628         if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1)))
3629         {
3630                 if (pos < limit)
3631                         buf[pos++] = 0;
3632                 else
3633                         uz = 1;
3634         }
3635
3636         /* Digits are in reverse order, fix that */
3637         REV(unsigned char, buf, pos);
3638
3639         /* Return the number of bytes actually written */
3640         *limpos = pos;
3641
3642         return (uz == 0) ? MP_OK : MP_TRUNC;
3643 }
3644
3645 /* }}} */
3646
3647 /* {{{ s_print(tag, z) */
3648
3649 #if 0
3650 void
3651 s_print(char *tag, mp_int z)
3652 {
3653         int                     i;
3654
3655         fprintf(stderr, "%s: %c ", tag,
3656                         (MP_SIGN(z) == MP_NEG) ? '-' : '+');
3657
3658         for (i = MP_USED(z) - 1; i >= 0; --i)
3659                 fprintf(stderr, "%0*X", (int) (MP_DIGIT_BIT / 4), z->digits[i]);
3660
3661         fputc('\n', stderr);
3662
3663 }
3664
3665 void
3666 s_print_buf(char *tag, mp_digit * buf, mp_size num)
3667 {
3668         int                     i;
3669
3670         fprintf(stderr, "%s: ", tag);
3671
3672         for (i = num - 1; i >= 0; --i)
3673                 fprintf(stderr, "%0*X", (int) (MP_DIGIT_BIT / 4), buf[i]);
3674
3675         fputc('\n', stderr);
3676 }
3677 #endif
3678
3679 /* }}} */
3680
3681 /* HERE THERE BE DRAGONS */