granicus.if.org Git - python/commitdiff
SF patch 936813: fast modular exponentiation
author: Tim Peters <tim.peters@gmail.com>
Mon, 30 Aug 2004 02:44:38 +0000 (02:44 +0000)
committer: Tim Peters <tim.peters@gmail.com>
Mon, 30 Aug 2004 02:44:38 +0000 (02:44 +0000)
This checkin is adapted from part 2 (of 3) of Trevor Perrin's patch set.

BACKWARD INCOMPATIBILITY:  SHIFT must now be divisible by 5.  AFAIK,
nobody will care.  long_pow() could be complicated to worm around that,
if necessary.

long_pow():
  - BUGFIX:  This leaked the base and power when the power was negative
    (and so the computation was delegated to float pow).
  - Instead of doing right-to-left exponentiation, do left-to-right.  This
    is more efficient for small bases, which is the common case.
  - In addition, if the exponent is large (more than FIVEARY_CUTOFF
    digits), precompute [a**i % c for i in range(32)], and go left to
    right 5 bits at a time.
l_divmod():
  - The signature changed so that callers who don't want the quotient,
    or don't want the remainder, can pass NULL in the slot they don't
    want.  This saves them from having to declare a variable for unwanted
    stuff, and remembering to decref it.
long_mod(), long_div(), long_classic_div():
  - Adjust to new l_divmod() signature, and simplified as a result.

Include/longintrepr.h
Misc/NEWS
Objects/longobject.c

index 9ed1fe737b7e7ad32ca6438e302a0b40a560c2c4..254076e4d4dc5d8b3ee2d159a6bbca36827ad47c 100644 (file)
@@ -15,7 +15,8 @@ extern "C" {
    (at most (BASE-1)*(2*BASE+1) == MASK*(2*MASK+3)).
    Also, x_sub assumes that 'digit' is an unsigned type, and overflow
    is handled by taking the result mod 2**N for some N > SHIFT.
-   And, at some places it is assumed that MASK fits in an int, as well. */
+   And, at some places it is assumed that MASK fits in an int, as well.
+   long_pow() requires that SHIFT be divisible by 5. */
 
 typedef unsigned short digit;
 typedef unsigned int wdigit; /* digit widened to parameter size */
@@ -27,6 +28,10 @@ typedef BASE_TWODIGITS_TYPE stwodigits; /* signed variant of twodigits */
 #define BASE   ((digit)1 << SHIFT)
 #define MASK   ((int)(BASE - 1))
 
+#if SHIFT % 5 != 0
+#error "longobject.c requires that SHIFT be divisible by 5"
+#endif
+
 /* Long integer representation.
    The absolute value of a number is equal to
        SUM(for i=0 through abs(ob_size)-1) ob_digit[i] * 2**(SHIFT*i)
index 431b343aa06c7bf8664eed331ee16df709c4f93d..660c49fa755d08e8829eb458f98a276d2a823b98 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -20,7 +20,11 @@ Core and builtins
   to compute 17**1000000 dropped from about 14 seconds to 9 on my box due
   to this much.  The cutoff for Karatsuba multiplication was raised,
   since gradeschool multiplication got quicker, and the cutoff was
-  aggressively small regardless.
+  aggressively small regardless.  The exponentiation algorithm was switched
+  from right-to-left to left-to-right, which is more efficient for small
+  bases.  In addition, if the exponent is large, the algorithm now does
+  5 bits (instead of 1 bit) at a time.  That cut the time to compute
+  17**1000000 on my box in half again, down to about 4.5 seconds.
 
 - OverflowWarning is no longer generated.  PEP 237 scheduled this to
   occur in Python 2.3, but since OverflowWarning was disabled by default,
@@ -156,6 +160,14 @@ Tools/Demos
 Build
 -----
 
+- Backward incompatibility:  longintrepr.h now triggers a compile-time
+  error if SHIFT (the number of bits in a Python long "digit") isn't
+  divisible by 5.  This new requirement allows simple code for the new
+  5-bits-at-a-time long_pow() implementation.  If necessary, the
+  restriction could be removed (by complicating long_pow(), or by
+  falling back to the 1-bit-at-a-time algorithm), but there are no
+  plans to do so.
+
 - bug #991962: When building with --disable-toolbox-glue on Darwin no
   attempt to build Mac-specific modules occurs.
 
index 2f6d103bfec5fc74797669f97e4af4c197cdec9f..05c42ad47d747632b95292ed4055db062139eec7 100644 (file)
 #define KARATSUBA_CUTOFF 70
 #define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
 
+/* For exponentiation, use the binary left-to-right algorithm
+ * unless the exponent contains more than FIVEARY_CUTOFF digits.
+ * In that case, do 5 bits at a time.  The potential drawback is that
+ * a table of 2**5 intermediate results is computed.
+ */
+#define FIVEARY_CUTOFF 8
+
 #define ABS(x) ((x) < 0 ? -(x) : (x))
 
 #undef MIN
@@ -2136,6 +2143,12 @@ long_mul(PyLongObject *v, PyLongObject *w)
    have different signs.  We then subtract one from the 'div'
    part of the outcome to keep the invariant intact. */
 
+/* Compute
+ *     *pdiv, *pmod = divmod(v, w)
+ * NULL can be passed for pdiv or pmod, in which case that part of
+ * the result is simply thrown away.  The caller owns a reference to
+ * each of these it requests (does not pass NULL for).
+ */
 static int
 l_divmod(PyLongObject *v, PyLongObject *w,
         PyLongObject **pdiv, PyLongObject **pmod)
@@ -2167,44 +2180,43 @@ l_divmod(PyLongObject *v, PyLongObject *w,
                Py_DECREF(div);
                div = temp;
        }
-       *pdiv = div;
-       *pmod = mod;
+       if (pdiv != NULL)
+               *pdiv = div;
+       else
+               Py_DECREF(div);
+
+       if (pmod != NULL)
+               *pmod = mod;
+       else
+               Py_DECREF(mod);
+
        return 0;
 }
 
 static PyObject *
 long_div(PyObject *v, PyObject *w)
 {
-       PyLongObject *a, *b, *div, *mod;
+       PyLongObject *a, *b, *div;
 
        CONVERT_BINOP(v, w, &a, &b);
-
-       if (l_divmod(a, b, &div, &mod) < 0) {
-               Py_DECREF(a);
-               Py_DECREF(b);
-               return NULL;
-       }
+       if (l_divmod(a, b, &div, NULL) < 0)
+               div = NULL;
        Py_DECREF(a);
        Py_DECREF(b);
-       Py_DECREF(mod);
        return (PyObject *)div;
 }
 
 static PyObject *
 long_classic_div(PyObject *v, PyObject *w)
 {
-       PyLongObject *a, *b, *div, *mod;
+       PyLongObject *a, *b, *div;
 
        CONVERT_BINOP(v, w, &a, &b);
-
        if (Py_DivisionWarningFlag &&
            PyErr_Warn(PyExc_DeprecationWarning, "classic long division") < 0)
                div = NULL;
-       else if (l_divmod(a, b, &div, &mod) < 0)
+       else if (l_divmod(a, b, &div, NULL) < 0)
                div = NULL;
-       else
-               Py_DECREF(mod);
-
        Py_DECREF(a);
        Py_DECREF(b);
        return (PyObject *)div;
@@ -2255,18 +2267,14 @@ overflow:
 static PyObject *
 long_mod(PyObject *v, PyObject *w)
 {
-       PyLongObject *a, *b, *div, *mod;
+       PyLongObject *a, *b, *mod;
 
        CONVERT_BINOP(v, w, &a, &b);
 
-       if (l_divmod(a, b, &div, &mod) < 0) {
-               Py_DECREF(a);
-               Py_DECREF(b);
-               return NULL;
-       }
+       if (l_divmod(a, b, NULL, &mod) < 0)
+               mod = NULL;
        Py_DECREF(a);
        Py_DECREF(b);
-       Py_DECREF(div);
        return (PyObject *)mod;
 }
 
@@ -2297,22 +2305,33 @@ long_divmod(PyObject *v, PyObject *w)
        return z;
 }
 
+/* pow(v, w, x) */
 static PyObject *
 long_pow(PyObject *v, PyObject *w, PyObject *x)
 {
-       PyLongObject *a, *b;
-       PyObject *c;
-       PyLongObject *z, *div, *mod;
-       int size_b, i;
+       PyLongObject *a, *b, *c; /* a,b,c = v,w,x */
+       int negativeOutput = 0;  /* if x<0 return negative output */
+
+       PyLongObject *z = NULL;  /* accumulated result */
+       int i, j, k;             /* counters */
+       PyLongObject *temp = NULL;
 
+       /* 5-ary values.  If the exponent is large enough, table is
+        * precomputed so that table[i] == a**i % c for i in range(32).
+        */
+       PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
+                                  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
+       /* a, b, c = v, w, x */
        CONVERT_BINOP(v, w, &a, &b);
-       if (PyLong_Check(x) || Py_None == x) {
-               c = x;
+       if (PyLong_Check(x)) {
+               c = (PyLongObject *)x;
                Py_INCREF(x);
        }
-       else if (PyInt_Check(x)) {
-               c = PyLong_FromLong(PyInt_AS_LONG(x));
-       }
+       else if (PyInt_Check(x))
+               c = (PyLongObject *)PyLong_FromLong(PyInt_AS_LONG(x));
+       else if (x == Py_None)
+               c = NULL;
        else {
                Py_DECREF(a);
                Py_DECREF(b);
@@ -2320,95 +2339,154 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
                return Py_NotImplemented;
        }
 
-       if (c != Py_None && ((PyLongObject *)c)->ob_size == 0) {
-               PyErr_SetString(PyExc_ValueError,
-                               "pow() 3rd argument cannot be 0");
-               z = NULL;
-               goto error;
-       }
-
-       size_b = b->ob_size;
-       if (size_b < 0) {
-               Py_DECREF(a);
-               Py_DECREF(b);
-               Py_DECREF(c);
-               if (x != Py_None) {
+       if (b->ob_size < 0) {  /* if exponent is negative */
+               if (c) {
                        PyErr_SetString(PyExc_TypeError, "pow() 2nd argument "
-                            "cannot be negative when 3rd argument specified");
+                           "cannot be negative when 3rd argument specified");
                        return NULL;
                }
-               /* Return a float.  This works because we know that
-                  this calls float_pow() which converts its
-                  arguments to double. */
-               return PyFloat_Type.tp_as_number->nb_power(v, w, x);
+               else {
+                       /* else return a float.  This works because we know
+                          that this calls float_pow() which converts its
+                          arguments to double. */
+                       Py_DECREF(a);
+                       Py_DECREF(b);
+                       return PyFloat_Type.tp_as_number->nb_power(v, w, x);
+               }
        }
-       z = (PyLongObject *)PyLong_FromLong(1L);
-       for (i = 0; i < size_b; ++i) {
-               digit bi = b->ob_digit[i];
-               int j;
 
-               for (j = 0; j < SHIFT; ++j) {
-                       PyLongObject *temp;
+       if (c) {
+               /* if modulus == 0:
+                      raise ValueError() */
+               if (c->ob_size == 0) {
+                       PyErr_SetString(PyExc_ValueError,
+                                       "pow() 3rd argument cannot be 0");
+                       goto Done;
+               }
 
-                       if (bi & 1) {
-                               temp = (PyLongObject *)long_mul(z, a);
-                               Py_DECREF(z);
-                               if (c!=Py_None && temp!=NULL) {
-                                       if (l_divmod(temp,(PyLongObject *)c,
-                                                       &div,&mod) < 0) {
-                                               Py_DECREF(temp);
-                                               z = NULL;
-                                               goto error;
-                                       }
-                                       Py_XDECREF(div);
-                                       Py_DECREF(temp);
-                                       temp = mod;
-                               }
-                               z = temp;
-                               if (z == NULL)
-                                       break;
-                       }
-                       bi >>= 1;
-                       if (bi == 0 && i+1 == size_b)
-                               break;
-                       temp = (PyLongObject *)long_mul(a, a);
+               /* if modulus < 0:
+                      negativeOutput = True
+                      modulus = -modulus */
+               if (c->ob_size < 0) {
+                       negativeOutput = 1;
+                       temp = (PyLongObject *)_PyLong_Copy(c);
+                       if (temp == NULL)
+                               goto Error;
+                       Py_DECREF(c);
+                       c = temp;
+                       temp = NULL;
+                       c->ob_size = - c->ob_size;
+               }
+
+               /* if modulus == 1:
+                      return 0 */
+               if ((c->ob_size == 1) && (c->ob_digit[0] == 1)) {
+                       z = (PyLongObject *)PyLong_FromLong(0L);
+                       goto Done;
+               }
+
+               /* if base < 0:
+                      base = base % modulus
+                  Having the base positive just makes things easier. */
+               if (a->ob_size < 0) {
+                       if (l_divmod(a, c, NULL, &temp) < 0)
+                               goto Done;
                        Py_DECREF(a);
-                       if (c!=Py_None && temp!=NULL) {
-                               if (l_divmod(temp, (PyLongObject *)c, &div,
-                                                       &mod) < 0) {
-                                       Py_DECREF(temp);
-                                       z = NULL;
-                                       goto error;
-                               }
-                               Py_XDECREF(div);
-                               Py_DECREF(temp);
-                               temp = mod;
-                       }
                        a = temp;
-                       if (a == NULL) {
-                               Py_DECREF(z);
-                               z = NULL;
-                               break;
-                       }
+                       temp = NULL;
                }
-               if (a == NULL || z == NULL)
-                       break;
        }
-       if (c!=Py_None && z!=NULL) {
-               if (l_divmod(z, (PyLongObject *)c, &div, &mod) < 0) {
-                       Py_DECREF(z);
-                       z = NULL;
+
+       /* At this point a, b, and c are guaranteed non-negative UNLESS
+          c is NULL, in which case a may be negative. */
+
+       z = (PyLongObject *)PyLong_FromLong(1L);
+       if (z == NULL)
+               goto Error;
+
+       /* Perform a modular reduction, X = X % c, but leave X alone if c
+        * is NULL.
+        */
+#define REDUCE(X)                                      \
+       if (c != NULL) {                                \
+               if (l_divmod(X, c, NULL, &temp) < 0)    \
+                       goto Error;                     \
+               Py_XDECREF(X);                          \
+               X = temp;                               \
+               temp = NULL;                            \
+       }
+
+       /* Multiply two values, then reduce the result:
+          result = X*Y % c.  If c is NULL, skip the mod. */
+#define MULT(X, Y, result)                             \
+{                                                      \
+       temp = (PyLongObject *)long_mul(X, Y);          \
+       if (temp == NULL)                               \
+               goto Error;                             \
+       Py_XDECREF(result);                             \
+       result = temp;                                  \
+       temp = NULL;                                    \
+       REDUCE(result)                                  \
+}
+
+       if (b->ob_size <= FIVEARY_CUTOFF) {
+               /* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
+               /* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf    */
+               for (i = b->ob_size - 1; i >= 0; --i) {
+                       digit bi = b->ob_digit[i];
+
+                       for (j = 1 << (SHIFT-1); j != 0; j >>= 1) {
+                               MULT(z, z, z)
+                               if (bi & j)
+                                       MULT(z, a, z)
+                       }
                }
-               else {
-                       Py_XDECREF(div);
-                       Py_DECREF(z);
-                       z = mod;
+       }
+       else {
+               /* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
+               Py_INCREF(z);   /* still holds 1L */
+               table[0] = z;
+               for (i = 1; i < 32; ++i)
+                       MULT(table[i-1], a, table[i])
+
+               for (i = b->ob_size - 1; i >= 0; --i) {
+                       const digit bi = b->ob_digit[i];
+
+                       for (j = SHIFT - 5; j >= 0; j -= 5) {
+                               const int index = (bi >> j) & 0x1f;
+                               for (k = 0; k < 5; ++k)
+                                       MULT(z, z, z)
+                               if (index)
+                                       MULT(z, table[index], z)
+                       }
                }
        }
-  error:
+
+       if (negativeOutput && (z->ob_size != 0)) {
+               temp = (PyLongObject *)long_sub(z, c);
+               if (temp == NULL)
+                       goto Error;
+               Py_DECREF(z);
+               z = temp;
+               temp = NULL;
+       }
+       goto Done;
+
+ Error:
+       if (z != NULL) {
+               Py_DECREF(z);
+               z = NULL;
+       }
+       /* fall through */
+ Done:
        Py_XDECREF(a);
-       Py_DECREF(b);
-       Py_DECREF(c);
+       Py_XDECREF(b);
+       Py_XDECREF(c);
+       Py_XDECREF(temp);
+       if (b->ob_size > FIVEARY_CUTOFF) {
+               for (i = 0; i < 32; ++i)
+                       Py_XDECREF(table[i]);
+       }
        return (PyObject *)z;
 }