]> granicus.if.org Git - python/commitdiff
k_mul: Rearranged computation for better cache use. Ignored overflow
authorTim Peters <tim.peters@gmail.com>
Mon, 12 Aug 2002 15:08:20 +0000 (15:08 +0000)
committerTim Peters <tim.peters@gmail.com>
Mon, 12 Aug 2002 15:08:20 +0000 (15:08 +0000)
(it's possible, but should be harmless -- this requires more thought,
and allocating enough space in advance to prevent it requires exactly
as much thought, to know exactly how much that is -- the end result
certainly fits in the allocated space -- hmm, but that's really all
the thought it needs!  borrows/carries out of the high digits really
are harmless).

Objects/longobject.c

index 6dedd389f6eeae16213d020eda16e8af80b969eb..bf82d732785626e4b77be8909bf3c1b728b5d1e5 100644 (file)
@@ -1598,20 +1598,17 @@ kmul_split(PyLongObject *n, int size, PyLongObject **high, PyLongObject **low)
 static PyLongObject *
 k_mul(PyLongObject *a, PyLongObject *b)
 {
+       int asize = ABS(a->ob_size);
+       int bsize = ABS(b->ob_size);
        PyLongObject *ah = NULL;
        PyLongObject *al = NULL;
        PyLongObject *bh = NULL;
        PyLongObject *bl = NULL;
-       PyLongObject *albl = NULL;
-       PyLongObject *ahbh = NULL;
-       PyLongObject *k = NULL;
        PyLongObject *ret = NULL;
-       PyLongObject *t1, *t2;
+       PyLongObject *t1, *t2, *t3;
        int shift;      /* the number of digits we split off */
        int i;
-#ifdef Py_DEBUG
-       digit d;
-#endif
+
        /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl
         * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh  + ah*bh + al*bl
         * Then the original product is
@@ -1623,59 +1620,75 @@ k_mul(PyLongObject *a, PyLongObject *b)
        /* We want to split based on the larger number; fiddle so that b
         * is largest.
         */
-       if (ABS(a->ob_size) > ABS(b->ob_size)) {
+       if (asize > bsize) {
                t1 = a;
                a = b;
                b = t1;
+
+               i = asize;
+               asize = bsize;
+               bsize = i;
        }
 
        /* Use gradeschool math when either number is too small. */
-       if (ABS(a->ob_size) <= KARATSUBA_CUTOFF) {
+       if (asize <= KARATSUBA_CUTOFF) {
                /* 0 is inevitable if one kmul arg has more than twice
                 * the digits of another, so it's worth special-casing.
                 */
-               if (a->ob_size == 0)
+               if (asize == 0)
                        return _PyLong_New(0);
                else
                        return x_mul(a, b);
        }
 
-       shift = ABS(b->ob_size) >> 1;
+       shift = bsize >> 1;
        if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
        if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
 
-       if ((ahbh = k_mul(ah, bh)) == NULL) goto fail;
-       assert(ahbh->ob_size >= 0);
-
-       /* Allocate result space, and copy ahbh into the high digits. */
-       ret = _PyLong_New(ABS(a->ob_size) + ABS(b->ob_size));
+       /* Allocate result space. */
+       ret = _PyLong_New(asize + bsize);
        if (ret == NULL) goto fail;
 #ifdef Py_DEBUG
        /* Fill with trash, to catch reference to uninitialized digits. */
        memset(ret->ob_digit, 0xDF, ret->ob_size * sizeof(digit));
 #endif
-       assert(2*shift + ahbh->ob_size <= ret->ob_size);
-       memcpy(ret->ob_digit + 2*shift, ahbh->ob_digit,
-              ahbh->ob_size * sizeof(digit));
 
-       /* Zero-out the digits higher than the ahbh copy. */
-       i = ret->ob_size - 2*shift - ahbh->ob_size;
+       /* t1 <- ah*bh, and copy into high digits of result. */
+       if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
+       assert(t1->ob_size >= 0);
+       assert(2*shift + t1->ob_size <= ret->ob_size);
+       memcpy(ret->ob_digit + 2*shift, t1->ob_digit,
+              t1->ob_size * sizeof(digit));
+
+       /* Zero-out the digits higher than the ah*bh copy. */
+       i = ret->ob_size - 2*shift - t1->ob_size;
        if (i)
-               memset(ret->ob_digit + 2*shift + ahbh->ob_size, 0,
+               memset(ret->ob_digit + 2*shift + t1->ob_size, 0,
                       i * sizeof(digit));
 
-       /* Compute al*bl, and copy into the low digits. */
-       if ((albl = k_mul(al, bl)) == NULL) goto fail;
-       assert(albl->ob_size >= 0);
-       assert(albl->ob_size <= 2*shift); /* no overlap with high digits */
-       memcpy(ret->ob_digit, albl->ob_digit, albl->ob_size * sizeof(digit));
+       /* t2 <- al*bl, and copy into the low digits. */
+       if ((t2 = k_mul(al, bl)) == NULL) {
+               Py_DECREF(t1);
+               goto fail;
+       }
+       assert(t2->ob_size >= 0);
+       assert(t2->ob_size <= 2*shift); /* no overlap with high digits */
+       memcpy(ret->ob_digit, t2->ob_digit, t2->ob_size * sizeof(digit));
 
        /* Zero out remaining digits. */
-       i = 2*shift - albl->ob_size;    /* number of uninitialized digits */
+       i = 2*shift - t2->ob_size;      /* number of uninitialized digits */
        if (i)
-               memset(ret->ob_digit + albl->ob_size, 0, i * sizeof(digit));
+               memset(ret->ob_digit + t2->ob_size, 0, i * sizeof(digit));
+
+       /* Subtract ah*bh (t1) and al*bl (t2) from "the middle" digits. */
+       i = ret->ob_size - shift;  /* # digits after shift */
+       v_isub(ret->ob_digit + shift, i, t2->ob_digit, t2->ob_size);
+       Py_DECREF(t2);
 
-       /* k = (ah+al)(bh+bl) */
+       v_isub(ret->ob_digit + shift, i, t1->ob_digit, t1->ob_size);
+       Py_DECREF(t1);
+
+       /* t3 <- (ah+al)(bh+bl) */
        if ((t1 = x_add(ah, al)) == NULL) goto fail;
        Py_DECREF(ah);
        Py_DECREF(al);
@@ -1689,36 +1702,16 @@ k_mul(PyLongObject *a, PyLongObject *b)
        Py_DECREF(bl);
        bh = bl = NULL;
 
-       k = k_mul(t1, t2);
+       t3 = k_mul(t1, t2);
+       assert(t3->ob_size >= 0);
        Py_DECREF(t1);
        Py_DECREF(t2);
-       if (k == NULL) goto fail;
-
-       /* Add k into the result, starting at the shift'th LSD. */
-       i = ret->ob_size - shift;  /* # digits after shift */
-#ifdef Py_DEBUG
-       d =
-#endif
-       v_iadd(ret->ob_digit + shift, i, k->ob_digit, k->ob_size);
-       assert(d == 0);
-       Py_DECREF(k);
+       if (t3 == NULL) goto fail;
 
-       /* Subtract ahbh and albl from the result.  Note that this can't
-        * become negative, since k = ahbh + albl + other stuff.
-        */
-#ifdef Py_DEBUG
-       d =
-#endif
-       v_isub(ret->ob_digit + shift, i, ahbh->ob_digit, ahbh->ob_size);
-       assert(d == 0);
-       Py_DECREF(ahbh);
-
-#ifdef Py_DEBUG
-       d =
-#endif
-       v_isub(ret->ob_digit + shift, i, albl->ob_digit, albl->ob_size);
-       assert(d == 0);
-       Py_DECREF(albl);
+       /* Add t3. */
+       v_iadd(ret->ob_digit + shift, ret->ob_size - shift,
+              t3->ob_digit, t3->ob_size);
+       Py_DECREF(t3);
 
        return long_normalize(ret);
 
@@ -1728,9 +1721,6 @@ k_mul(PyLongObject *a, PyLongObject *b)
        Py_XDECREF(al);
        Py_XDECREF(bh);
        Py_XDECREF(bl);
-       Py_XDECREF(ahbh);
-       Py_XDECREF(albl);
-       Py_XDECREF(k);
        return NULL;
 }