]> granicus.if.org Git - python/commitdiff
k_mul() and long_mul(): I'm confident that the Karatsuba algorithm is
authorTim Peters <tim.peters@gmail.com>
Mon, 12 Aug 2002 17:36:03 +0000 (17:36 +0000)
committerTim Peters <tim.peters@gmail.com>
Mon, 12 Aug 2002 17:36:03 +0000 (17:36 +0000)
correct now, so added some final comments, did some cleanup, and enabled
it for all long-int multiplies.  The KARAT envar no longer matters,
although I left some #if 0'ed code in there for my own use (temporary).
k_mul() is still much slower than x_mul() if the inputs have very
differenent sizes, and that still needs to be addressed.

Misc/NEWS
Objects/longobject.c

index 9d278b646a16851d8dcb05a7d2559beee775e134..efeb3acc66d0632592f5c11586ea71934a3e1d8f 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -57,9 +57,16 @@ Type/class unification and new-style classes
 
 Core and builtins
 
-- XXX Karatsuba multiplication.  This is currently used if and only
-  if envar KARAT exists.  It needs more correctness and speed testing,
-  the latter especially with unbalanced bit lengths.
+- When multiplying very large integers, a version of the so-called
+  Karatsuba algorithm is now used.  This is most effective if the
+  inputs have roughly the same size.  If they both have about N digits,
+  Karatsuba multiplication has O(N**1.58) runtime (the exponent is
+  log_base_2(3)) instead of the previous O(N**2).  Measured results may
+  be better or worse than that, depending on platform quirks.  Note that
+  this is a simple implementation, and there's no intent here to compete
+  with, e.g., gmp.  It simply gives a very nice speedup when it applies.
+  XXX Karatsuba multiplication can be slower when the inputs have very
+  XXX different sizes.
 
 - u'%c' will now raise a ValueError in case the argument is an
   integer outside the valid range of Unicode code point ordinals.
index bf82d732785626e4b77be8909bf3c1b728b5d1e5..0eefc90b1ea3df21a4c84b2d9200eec6d18a8a04 100644 (file)
@@ -1645,7 +1645,23 @@ k_mul(PyLongObject *a, PyLongObject *b)
        if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
        if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
 
-       /* Allocate result space. */
+       /* The plan:
+        * 1. Allocate result space (asize + bsize digits:  that's always
+        *    enough).
+        * 2. Compute ah*bh, and copy into result at 2*shift.
+        * 3. Compute al*bl, and copy into result at 0.  Note that this
+        *    can't overlap with #2.
+        * 4. Subtract al*bl from the result, starting at shift.  This may
+        *    underflow (borrow out of the high digit), but we don't care:
+        *    we're effectively doing unsigned arithmetic mod
+        *    BASE**(sizea + sizeb), and so long as the *final* result fits,
+        *    borrows and carries out of the high digit can be ignored.
+        * 5. Subtract ah*bh from the result, starting at shift.
+        * 6. Compute (ah+al)*(bh+bl), and add it into the result starting
+        *    at shift.
+        */
+
+       /* 1. Allocate result space. */
        ret = _PyLong_New(asize + bsize);
        if (ret == NULL) goto fail;
 #ifdef Py_DEBUG
@@ -1653,7 +1669,7 @@ k_mul(PyLongObject *a, PyLongObject *b)
        memset(ret->ob_digit, 0xDF, ret->ob_size * sizeof(digit));
 #endif
 
-       /* t1 <- ah*bh, and copy into high digits of result. */
+       /* 2. t1 <- ah*bh, and copy into high digits of result. */
        if ((t1 = k_mul(ah, bh)) == NULL) goto fail;
        assert(t1->ob_size >= 0);
        assert(2*shift + t1->ob_size <= ret->ob_size);
@@ -1666,7 +1682,7 @@ k_mul(PyLongObject *a, PyLongObject *b)
                memset(ret->ob_digit + 2*shift + t1->ob_size, 0,
                       i * sizeof(digit));
 
-       /* t2 <- al*bl, and copy into the low digits. */
+       /* 3. t2 <- al*bl, and copy into the low digits. */
        if ((t2 = k_mul(al, bl)) == NULL) {
                Py_DECREF(t1);
                goto fail;
@@ -1680,15 +1696,17 @@ k_mul(PyLongObject *a, PyLongObject *b)
        if (i)
                memset(ret->ob_digit + t2->ob_size, 0, i * sizeof(digit));
 
-       /* Subtract ah*bh (t1) and al*bl (t2) from "the middle" digits. */
+       /* 4 & 5. Subtract ah*bh (t1) and al*bl (t2).  We do al*bl first
+        * because it's fresher in cache.
+        */
        i = ret->ob_size - shift;  /* # digits after shift */
-       v_isub(ret->ob_digit + shift, i, t2->ob_digit, t2->ob_size);
+       (void)v_isub(ret->ob_digit + shift, i, t2->ob_digit, t2->ob_size);
        Py_DECREF(t2);
 
-       v_isub(ret->ob_digit + shift, i, t1->ob_digit, t1->ob_size);
+       (void)v_isub(ret->ob_digit + shift, i, t1->ob_digit, t1->ob_size);
        Py_DECREF(t1);
 
-       /* t3 <- (ah+al)(bh+bl) */
+       /* 6. t3 <- (ah+al)(bh+bl), and add into result. */
        if ((t1 = x_add(ah, al)) == NULL) goto fail;
        Py_DECREF(ah);
        Py_DECREF(al);
@@ -1709,8 +1727,7 @@ k_mul(PyLongObject *a, PyLongObject *b)
        if (t3 == NULL) goto fail;
 
        /* Add t3. */
-       v_iadd(ret->ob_digit + shift, ret->ob_size - shift,
-              t3->ob_digit, t3->ob_size);
+       (void)v_iadd(ret->ob_digit + shift, i, t3->ob_digit, t3->ob_size);
        Py_DECREF(t3);
 
        return long_normalize(ret);
@@ -1743,10 +1760,14 @@ long_mul(PyLongObject *v, PyLongObject *w)
                return Py_NotImplemented;
        }
 
+#if 0
        if (Py_GETENV("KARAT") != NULL)
                z = k_mul(a, b);
        else
                z = x_mul(a, b);
+#else
+       z = k_mul(a, b);
+#endif
        if(z == NULL) {
                Py_DECREF(a);
                Py_DECREF(b);