]> granicus.if.org Git - python/commitdiff
Merged revisions 70542 via svnmerge from
authorMark Dickinson <dickinsm@gmail.com>
Mon, 23 Mar 2009 18:44:57 +0000 (18:44 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Mon, 23 Mar 2009 18:44:57 +0000 (18:44 +0000)
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r70542 | mark.dickinson | 2009-03-23 18:25:13 +0000 (Mon, 23 Mar 2009) | 14 lines

  Issue #5512: speed up the long division algorithm for Python longs.
  The basic algorithm remains the same; the most significant speedups
  come from the following three changes:

    (1) normalize by shifting instead of multiplying and dividing
    (2) the old algorithm usually did an unnecessary extra iteration of
        the outer loop; remove this.  As a special case, this means that
        long divisions with a single-digit result run twice as fast as
        before.
    (3) make inner loop much tighter.

  Various benchmarks show speedups of between 50% and 150% for long
  integer divisions and modulo operations.
........

Misc/NEWS
Objects/longobject.c

index 460a36edf7ff398c7b4729ad6bb4b516faf4f138..8c752a7a7d17ba231a5d067bc0af2a7d93478c25 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,10 @@ What's New in Python 3.1 alpha 2?
 Core and Builtins
 -----------------
 
+- Issue #5512: Rewrite PyLong long division algorithm (x_divrem) to
+  improve its performance.  Long divisions and remainder operations
+  are now between 50% and 150% faster.
+
 - Issue #4258: Make it possible to use base 2**30 instead of base
   2**15 for the internal representation of integers, for performance
   reasons.  Base 2**30 is enabled by default on 64-bit machines.  Add
index e1df9d9ef775d143dec76e0f9ccaf2c596e82eef..e5e23dfdb92862ce294b0dc3173e2d1b53ae8ad3 100644 (file)
@@ -1363,6 +1363,26 @@ PyLong_AsUnsignedLongLongMask(register PyObject *op)
                return Py_NotImplemented; \
        }
 
+/* bits_in_digit(d) returns the unique integer k such that 2**(k-1) <= d <
+   2**k if d is nonzero, else 0. */
+
+static const unsigned char BitLengthTable[32] = {
+       0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
+       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5
+};
+
+static int
+bits_in_digit(digit d)
+{
+       int d_bits = 0;
+       while (d >= 32) {
+               d_bits += 6;
+               d >>= 6;
+       }
+       d_bits += (int)BitLengthTable[d];
+       return d_bits;
+}
+
 /* x[0:m] and y[0:n] are digit vectors, LSD first, m >= n required.  x[0:n]
  * is modified in place, by adding y to it.  Carries are propagated as far as
  * x[m-1], and the remaining carry (0 or 1) is returned.
@@ -1415,25 +1435,41 @@ v_isub(digit *x, Py_ssize_t m, digit *y, Py_ssize_t n)
        return borrow;
 }
 
-/* Multiply by a single digit, ignoring the sign. */
+/* Shift digit vector a[0:m] d bits left, with 0 <= d < PyLong_SHIFT.  Put
+ * result in z[0:m], and return the d bits shifted out of the top.
+ */
+static digit
+v_lshift(digit *z, digit *a, Py_ssize_t m, int d)
+{
+       Py_ssize_t i;
+       digit carry = 0;
 
-static PyLongObject *
-mul1(PyLongObject *a, digit n)
+       assert(0 <= d && d < PyLong_SHIFT);
+       for (i=0; i < m; i++) {
+               twodigits acc = (twodigits)a[i] << d | carry;
+               z[i] = (digit)acc & PyLong_MASK;
+               carry = (digit)(acc >> PyLong_SHIFT);
+       }
+       return carry;
+}
+
+/* Shift digit vector a[0:m] d bits right, with 0 <= d < PyLong_SHIFT.  Put
+ * result in z[0:m], and return the d bits shifted out of the bottom.
+ */
+static digit
+v_rshift(digit *z, digit *a, Py_ssize_t m, int d)
 {
-       Py_ssize_t size_a = ABS(Py_SIZE(a));
-       PyLongObject *z = _PyLong_New(size_a+1);
-       twodigits carry = 0;
        Py_ssize_t i;
+       digit carry = 0;
+       digit mask = ((digit)1 << d) - 1U;
 
-       if (z == NULL)
-               return NULL;
-       for (i = 0; i < size_a; ++i) {
-               carry += (twodigits)a->ob_digit[i] * n;
-               z->ob_digit[i] = (digit) (carry & PyLong_MASK);
-               carry >>= PyLong_SHIFT;
+       assert(0 <= d && d < PyLong_SHIFT);
+       for (i=m; i-- > 0;) {
+               twodigits acc = (twodigits)carry << PyLong_SHIFT | a[i];
+               carry = (digit)acc & mask;
+               z[i] = (digit)(acc >> d);
        }
-       z->ob_digit[i] = (digit) carry;
-       return long_normalize(z);
+       return carry;
 }
 
 /* Divide long pin, w/ size digits, by non-zero digit n, storing quotient
@@ -2089,104 +2125,131 @@ long_divrem(PyLongObject *a, PyLongObject *b,
        return 0;
 }
 
-/* Unsigned long division with remainder -- the algorithm */
+/* Unsigned long division with remainder -- the algorithm.  The arguments v1
+   and w1 should satisfy 2 <= ABS(Py_SIZE(w1)) <= ABS(Py_SIZE(v1)). */
 
 static PyLongObject *
 x_divrem(PyLongObject *v1, PyLongObject *w1, PyLongObject **prem)
 {
-       Py_ssize_t size_v = ABS(Py_SIZE(v1)), size_w = ABS(Py_SIZE(w1));
-       digit d = (digit) ((twodigits)PyLong_BASE / (w1->ob_digit[size_w-1] + 1));
-       PyLongObject *v = mul1(v1, d);
-       PyLongObject *w = mul1(w1, d);
-       PyLongObject *a;
-       Py_ssize_t j, k;
-
-       if (v == NULL || w == NULL) {
-               Py_XDECREF(v);
-               Py_XDECREF(w);
+       PyLongObject *v, *w, *a;
+       Py_ssize_t i, k, size_v, size_w;
+       int d;
+       digit wm1, wm2, carry, q, r, vtop, *v0, *vk, *w0, *ak;
+       twodigits vv;
+       sdigit zhi;
+       stwodigits z;
+
+       /* We follow Knuth [The Art of Computer Programming, Vol. 2 (3rd
+          edn.), section 4.3.1, Algorithm D], except that we don't explicitly
+          handle the special case when the initial estimate q for a quotient
+          digit is >= PyLong_BASE: the max value for q is PyLong_BASE+1, and
+          that won't overflow a digit. */
+
+       /* allocate space; w will also be used to hold the final remainder */
+       size_v = ABS(Py_SIZE(v1));
+       size_w = ABS(Py_SIZE(w1));
+       assert(size_v >= size_w && size_w >= 2); /* Assert checks by div() */
+       v = _PyLong_New(size_v+1);
+       if (v == NULL) {
+               *prem = NULL;
+               return NULL;
+       }
+       w = _PyLong_New(size_w);
+       if (w == NULL) {
+               Py_DECREF(v);
+               *prem = NULL;
                return NULL;
        }
 
-       assert(size_v >= size_w && size_w > 1); /* Assert checks by div() */
-       assert(Py_REFCNT(v) == 1); /* Since v will be used as accumulator! */
-       assert(size_w == ABS(Py_SIZE(w))); /* That's how d was calculated */
+       /* normalize: shift w1 left so that its top digit is >= PyLong_BASE/2.
+          shift v1 left by the same amount.  Results go into w and v. */
+       d = PyLong_SHIFT - bits_in_digit(w1->ob_digit[size_w-1]);
+       carry = v_lshift(w->ob_digit, w1->ob_digit, size_w, d);
+       assert(carry == 0);
+       carry = v_lshift(v->ob_digit, v1->ob_digit, size_v, d);
+       if (carry != 0 || v->ob_digit[size_v-1] >= w->ob_digit[size_w-1]) {
+               v->ob_digit[size_v] = carry;
+               size_v++;
+       }
 
-       size_v = ABS(Py_SIZE(v));
+       /* Now v->ob_digit[size_v-1] < w->ob_digit[size_w-1], so quotient has
+          at most (and usually exactly) k = size_v - size_w digits. */
        k = size_v - size_w;
-       a = _PyLong_New(k + 1);
-
-       for (j = size_v; a != NULL && k >= 0; --j, --k) {
-               digit vj = (j >= size_v) ? 0 : v->ob_digit[j];
-               twodigits q;
-               stwodigits carry = 0;
-               Py_ssize_t i;
+       assert(k >= 0);
+       a = _PyLong_New(k);
+       if (a == NULL) {
+               Py_DECREF(w);
+               Py_DECREF(v);
+               *prem = NULL;
+               return NULL;
+       }
+       v0 = v->ob_digit;
+       w0 = w->ob_digit;
+       wm1 = w0[size_w-1];
+       wm2 = w0[size_w-2];
+       for (vk = v0+k, ak = a->ob_digit + k; vk-- > v0;) {
+               /* inner loop: divide vk[0:size_w+1] by w0[0:size_w], giving
+                  single-digit quotient q, remainder in vk[0:size_w]. */
 
                SIGCHECK({
                        Py_DECREF(a);
-                       a = NULL;
-                       break;
+                       Py_DECREF(w);
+                       Py_DECREF(v);
+                       *prem = NULL;
+                       return NULL;
                })
-               if (vj == w->ob_digit[size_w-1])
-                       q = PyLong_MASK;
-               else
-                       q = (((twodigits)vj << PyLong_SHIFT) + v->ob_digit[j-1]) /
-                               w->ob_digit[size_w-1];
-
-               while (w->ob_digit[size_w-2]*q >
-                               ((
-                                       ((twodigits)vj << PyLong_SHIFT)
-                                       + v->ob_digit[j-1]
-                                       - q*w->ob_digit[size_w-1]
-                                                               ) << PyLong_SHIFT)
-                               + v->ob_digit[j-2])
-                       --q;
 
-               for (i = 0; i < size_w && i+k < size_v; ++i) {
-                       twodigits z = w->ob_digit[i] * q;
-                       digit zz = (digit) (z >> PyLong_SHIFT);
-                       carry += v->ob_digit[i+k] - z
-                               + ((twodigits)zz << PyLong_SHIFT);
-                       v->ob_digit[i+k] = (digit)(carry & PyLong_MASK);
-                       carry = Py_ARITHMETIC_RIGHT_SHIFT(stwodigits,
-                                                         carry, PyLong_SHIFT);
-                       carry -= zz;
+               /* estimate quotient digit q; may overestimate by 1 (rare) */
+               vtop = vk[size_w];
+               assert(vtop <= wm1);
+               vv = ((twodigits)vtop << PyLong_SHIFT) | vk[size_w-1];
+               q = (digit)(vv / wm1);
+               r = (digit)(vv - (twodigits)wm1 * q); /* r = vv % wm1 */
+               while ((twodigits)wm2 * q > (((twodigits)r << PyLong_SHIFT)
+                                            | vk[size_w-2])) {
+                       --q;
+                       r += wm1;
+                       if (r >= PyLong_BASE)
+                               break;
                }
-
-               if (i+k < size_v) {
-                       carry += v->ob_digit[i+k];
-                       v->ob_digit[i+k] = 0;
+               assert(q <= PyLong_BASE);
+
+               /* subtract q*w0[0:size_w] from vk[0:size_w+1] */
+               zhi = 0;
+               for (i = 0; i < size_w; ++i) {
+                       /* invariants: -PyLong_BASE <= -q <= zhi <= 0;
+                          -PyLong_BASE * q <= z < PyLong_BASE */
+                       z = (sdigit)vk[i] + zhi -
+                               (stwodigits)q * (stwodigits)w0[i];
+                       vk[i] = (digit)z & PyLong_MASK;
+                       zhi = (sdigit)Py_ARITHMETIC_RIGHT_SHIFT(stwodigits,
+                                                       z, PyLong_SHIFT);
                }
 
-               if (carry == 0)
-                       a->ob_digit[k] = (digit) q;
-               else {
-                       assert(carry == -1);
-                       a->ob_digit[k] = (digit) q-1;
+               /* add w back if q was too large (this branch taken rarely) */
+               assert((sdigit)vtop + zhi == -1 || (sdigit)vtop + zhi == 0);
+               if ((sdigit)vtop + zhi < 0) {
                        carry = 0;
-                       for (i = 0; i < size_w && i+k < size_v; ++i) {
-                               carry += v->ob_digit[i+k] + w->ob_digit[i];
-                               v->ob_digit[i+k] = (digit)(carry & PyLong_MASK);
-                               carry = Py_ARITHMETIC_RIGHT_SHIFT(
-                                               stwodigits,
-                                               carry, PyLong_SHIFT);
+                       for (i = 0; i < size_w; ++i) {
+                               carry += vk[i] + w0[i];
+                               vk[i] = carry & PyLong_MASK;
+                               carry >>= PyLong_SHIFT;
                        }
+                       --q;
                }
-       } /* for j, k */
 
-       if (a == NULL)
-               *prem = NULL;
-       else {
-               a = long_normalize(a);
-               *prem = divrem1(v, d, &d);
-               /* d receives the (unused) remainder */
-               if (*prem == NULL) {
-                       Py_DECREF(a);
-                       a = NULL;
-               }
+               /* store quotient digit */
+               assert(q < PyLong_BASE);
+               *--ak = q;
        }
+
+       /* unshift remainder; we reuse w to store the result */
+       carry = v_rshift(w0, v0, size_w, d);
+       assert(carry==0);
        Py_DECREF(v);
-       Py_DECREF(w);
-       return a;
+
+       *prem = long_normalize(w);
+       return long_normalize(a);
 }
 
 /* Methods */
@@ -3793,11 +3856,6 @@ long_sizeof(PyLongObject *v)
        return PyLong_FromSsize_t(res);
 }
 
-static const unsigned char BitLengthTable[32] = {
-       0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
-       5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5
-};
-
 static PyObject *
 long_bit_length(PyLongObject *v)
 {