]> granicus.if.org Git - python/commitdiff
Issue #8692: Improve performance of math.factorial:
authorMark Dickinson <dickinsm@gmail.com>
Sat, 15 May 2010 17:02:38 +0000 (17:02 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Sat, 15 May 2010 17:02:38 +0000 (17:02 +0000)
(1) use a different algorithm that roughly halves the total number of
    multiplications required and results in more balanced multiplications
(2) use a lookup table for small arguments
(3) fast accumulation of products in C integer arithmetic rather than
    PyLong arithmetic when possible.

Typical speedup, from unscientific testing on a 64-bit laptop, is 4.5x
to 6.5x for arguments in the range 100 - 10000.

Patch by Daniel Stutzbach; extensive reviews by Alexander Belopolsky.

Lib/test/test_math.py
Misc/NEWS
Modules/mathmodule.c

index 48d9b1abbdc88c17072d7d92ec5634e4bc2f1258..6c4443540f3d873c0043ca269e5841e030a74f09 100644 (file)
@@ -60,6 +60,56 @@ def ulps_check(expected, got, ulps=20):
     return "error = {} ulps; permitted error = {} ulps".format(ulps_error,
                                                                ulps)
 
+# Here's a pure Python version of the math.factorial algorithm, for
+# documentation and comparison purposes.
+#
+# Formula:
+#
+#   factorial(n) = factorial_odd_part(n) << (n - count_set_bits(n))
+#
+# where
+#
+#   factorial_odd_part(n) = product_{i >= 0} product_{0 < j <= n >> i; j odd} j
+#
+# The outer product above is an infinite product, but once i >= n.bit_length,
+# (n >> i) < 1 and the corresponding term of the product is empty.  So only the
+# finitely many terms for 0 <= i < n.bit_length() contribute anything.
+#
+# We iterate downwards from i == n.bit_length() - 1 to i == 0.  The inner
+# product in the formula above starts at 1 for i == n.bit_length(); for each i
+# < n.bit_length() we get the inner product for i from that for i + 1 by
+# multiplying by all j in {n >> i+1 < j <= n >> i; j odd}.  In Python terms,
+# this set is range((n >> i+1) + 1 | 1, (n >> i) + 1 | 1, 2).
+
+def count_set_bits(n):
+    """Number of '1' bits in binary expansion of a nonnnegative integer."""
+    return 1 + count_set_bits(n & n - 1) if n else 0
+
+def partial_product(start, stop):
+    """Product of integers in range(start, stop, 2), computed recursively.
+    start and stop should both be odd, with start <= stop.
+
+    """
+    numfactors = (stop - start) >> 1
+    if not numfactors:
+        return 1
+    elif numfactors == 1:
+        return start
+    else:
+        mid = (start + numfactors) | 1
+        return partial_product(start, mid) * partial_product(mid, stop)
+
+def py_factorial(n):
+    """Factorial of nonnegative integer n, via "Binary Split Factorial Formula"
+    described at http://www.luschny.de/math/factorial/binarysplitfact.html
+
+    """
+    inner = outer = 1
+    for i in reversed(range(n.bit_length())):
+        inner *= partial_product((n >> i + 1) + 1 | 1, (n >> i) + 1 | 1)
+        outer *= inner
+    return outer << (n - count_set_bits(n))
+
 def acc_check(expected, got, rel_err=2e-15, abs_err = 5e-323):
     """Determine whether non-NaN floats a and b are equal to within a
     (small) rounding error.  The default values for rel_err and
@@ -365,18 +415,19 @@ class MathTests(unittest.TestCase):
         self.ftest('fabs(1)', math.fabs(1), 1)
 
     def testFactorial(self):
-        def fact(n):
-            result = 1
-            for i in range(1, int(n)+1):
-                result *= i
-            return result
-        values = list(range(10)) + [50, 100, 500]
-        random.shuffle(values)
-        for x in values:
-            for cast in (int, float):
-                self.assertEqual(math.factorial(cast(x)), fact(x), (x, fact(x), math.factorial(x)))
+        self.assertEqual(math.factorial(0), 1)
+        self.assertEqual(math.factorial(0.0), 1)
+        total = 1
+        for i in range(1, 1000):
+            total *= i
+            self.assertEqual(math.factorial(i), total)
+            self.assertEqual(math.factorial(float(i)), total)
+            self.assertEqual(math.factorial(i), py_factorial(i))
         self.assertRaises(ValueError, math.factorial, -1)
+        self.assertRaises(ValueError, math.factorial, -1.0)
         self.assertRaises(ValueError, math.factorial, math.pi)
+        self.assertRaises(OverflowError, math.factorial, sys.maxsize+1)
+        self.assertRaises(OverflowError, math.factorial, 10e100)
 
     def testFloor(self):
         self.assertRaises(TypeError, math.floor)
index 3da54ab4f73640f9a83db2f2e2c24009928ea568..a4a69388cf0145d577509f43d0840e99b3baf8dc 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -1132,6 +1132,12 @@ Library
 Extension Modules
 -----------------
 
+- Issue #8692: Optimize math.factorial: replace the previous naive
+  algorithm with an improved 'binary-split' algorithm that uses fewer
+  multiplications and allows many of the multiplications to be
+  performed using plain C integer arithmetic instead of PyLong
+  arithmetic.  Also uses a lookup table for small arguments.
+
 - Issue #8674: Fixed a number of incorrect or undefined-behaviour-inducing
   overflow checks in the audioop module.
 
index 76d7906a19cf64a90e3a26ce8bcfd43d978460ab..d57ad90b107a016560b089b6c616b6be5a47dacc 100644 (file)
@@ -1129,18 +1129,239 @@ PyDoc_STRVAR(math_fsum_doc,
 Return an accurate floating point sum of values in the iterable.\n\
 Assumes IEEE-754 floating point arithmetic.");
 
+/* Return the smallest integer k such that n < 2**k, or 0 if n == 0.
+ * Equivalent to floor(lg(x))+1.  Also equivalent to: bitwidth_of_type -
+ * count_leading_zero_bits(x)
+ */
+
+/* XXX: This routine does more or less the same thing as
+ * bits_in_digit() in Objects/longobject.c.  Someday it would be nice to
+ * consolidate them.  On BSD, there's a library function called fls()
+ * that we could use, and GCC provides __builtin_clz().
+ */
+
+static unsigned long
+bit_length(unsigned long n)
+{
+    unsigned long len = 0;
+    while (n != 0) {
+        ++len;
+        n >>= 1;
+    }
+    return len;
+}
+
+static unsigned long
+count_set_bits(unsigned long n)
+{
+    unsigned long count = 0;
+    while (n != 0) {
+        ++count;
+        n &= n - 1; /* clear least significant bit */
+    }
+    return count;
+}
+
+/* Divide-and-conquer factorial algorithm
+ *
+ * Based on the formula and psuedo-code provided at:
+ * http://www.luschny.de/math/factorial/binarysplitfact.html
+ *
+ * Faster algorithms exist, but they're more complicated and depend on
+ * a fast prime factoriazation algorithm.
+ *
+ * Notes on the algorithm
+ * ----------------------
+ *
+ * factorial(n) is written in the form 2**k * m, with m odd.  k and m are
+ * computed separately, and then combined using a left shift.
+ *
+ * The function factorial_odd_part computes the odd part m (i.e., the greatest
+ * odd divisor) of factorial(n), using the formula:
+ *
+ *   factorial_odd_part(n) =
+ *
+ *        product_{i >= 0} product_{0 < j <= n / 2**i, j odd} j
+ *
+ * Example: factorial_odd_part(20) =
+ *
+ *        (1) *
+ *        (1) *
+ *        (1 * 3 * 5) *
+ *        (1 * 3 * 5 * 7 * 9)
+ *        (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * Here i goes from large to small: the first term corresponds to i=4 (any
+ * larger i gives an empty product), and the last term corresponds to i=0.
+ * Each term can be computed from the last by multiplying by the extra odd
+ * numbers required: e.g., to get from the penultimate term to the last one,
+ * we multiply by (11 * 13 * 15 * 17 * 19).
+ *
+ * To see a hint of why this formula works, here are the same numbers as above
+ * but with the even parts (i.e., the appropriate powers of 2) included.  For
+ * each subterm in the product for i, we multiply that subterm by 2**i:
+ *
+ *   factorial(20) =
+ *
+ *        (16) *
+ *        (8) *
+ *        (4 * 12 * 20) *
+ *        (2 * 6 * 10 * 14 * 18) *
+ *        (1 * 3 * 5 * 7 * 9 * 11 * 13 * 15 * 17 * 19)
+ *
+ * The factorial_partial_product function computes the product of all odd j in
+ * range(start, stop) for given start and stop.  It's used to compute the
+ * partial products like (11 * 13 * 15 * 17 * 19) in the example above.  It
+ * operates recursively, repeatedly splitting the range into two roughly equal
+ * pieces until the subranges are small enough to be computed using only C
+ * integer arithmetic.
+ *
+ * The two-valuation k (i.e., the exponent of the largest power of 2 dividing
+ * the factorial) is computed independently in the main math_factorial
+ * function.  By standard results, its value is:
+ *
+ *    two_valuation = n//2 + n//4 + n//8 + ....
+ *
+ * It can be shown (e.g., by complete induction on n) that two_valuation is
+ * equal to n - count_set_bits(n), where count_set_bits(n) gives the number of
+ * '1'-bits in the binary expansion of n.
+ */
+
+/* factorial_partial_product: Compute product(range(start, stop, 2)) using
+ * divide and conquer.  Assumes start and stop are odd and stop > start.
+ * max_bits must be >= bit_length(stop - 2). */
+
+static PyObject *
+factorial_partial_product(unsigned long start, unsigned long stop,
+                          unsigned long max_bits)
+{
+    unsigned long midpoint, num_operands;
+    PyObject *left = NULL, *right = NULL, *result = NULL;
+
+    /* If the return value will fit an unsigned long, then we can
+     * multiply in a tight, fast loop where each multiply is O(1).
+     * Compute an upper bound on the number of bits required to store
+     * the answer.
+     *
+     * Storing some integer z requires floor(lg(z))+1 bits, which is
+     * conveniently the value returned by bit_length(z).  The
+     * product x*y will require at most
+     * bit_length(x) + bit_length(y) bits to store, based
+     * on the idea that lg product = lg x + lg y.
+     *
+     * We know that stop - 2 is the largest number to be multiplied.  From
+     * there, we have: bit_length(answer) <= num_operands *
+     * bit_length(stop - 2)
+     */
+
+    num_operands = (stop - start) / 2;
+    /* The "num_operands <= 8 * SIZEOF_LONG" check guards against the
+     * unlikely case of an overflow in num_operands * max_bits. */
+    if (num_operands <= 8 * SIZEOF_LONG &&
+        num_operands * max_bits <= 8 * SIZEOF_LONG) {
+        unsigned long j, total;
+        for (total = start, j = start + 2; j < stop; j += 2)
+            total *= j;
+        return PyLong_FromUnsignedLong(total);
+    }
+
+    /* find midpoint of range(start, stop), rounded up to next odd number. */
+    midpoint = (start + num_operands) | 1;
+    left = factorial_partial_product(start, midpoint,
+                                     bit_length(midpoint - 2));
+    if (left == NULL)
+        goto error;
+    right = factorial_partial_product(midpoint, stop, max_bits);
+    if (right == NULL)
+        goto error;
+    result = PyNumber_Multiply(left, right);
+
+  error:
+    Py_XDECREF(left);
+    Py_XDECREF(right);
+    return result;
+}
+
+/* factorial_odd_part:  compute the odd part of factorial(n). */
+
+static PyObject *
+factorial_odd_part(unsigned long n)
+{
+    long i;
+    unsigned long v, lower, upper;
+    PyObject *partial, *tmp, *inner, *outer;
+
+    inner = PyLong_FromLong(1);
+    if (inner == NULL)
+        return NULL;
+    outer = inner;
+    Py_INCREF(outer);
+
+    upper = 3;
+    for (i = bit_length(n) - 2; i >= 0; i--) {
+        v = n >> i;
+        if (v <= 2)
+            continue;
+        lower = upper;
+        /* (v + 1) | 1 = least odd integer strictly larger than n / 2**i */
+        upper = (v + 1) | 1;
+        /* Here inner is the product of all odd integers j in the range (0,
+           n/2**(i+1)].  The factorial_partial_product call below gives the
+           product of all odd integers j in the range (n/2**(i+1), n/2**i]. */
+        partial = factorial_partial_product(lower, upper, bit_length(upper-2));
+        /* inner *= partial */
+        if (partial == NULL)
+            goto error;
+        tmp = PyNumber_Multiply(inner, partial);
+        Py_DECREF(partial);
+        if (tmp == NULL)
+            goto error;
+        Py_DECREF(inner);
+        inner = tmp;
+        /* Now inner is the product of all odd integers j in the range (0,
+           n/2**i], giving the inner product in the formula above. */
+
+        /* outer *= inner; */
+        tmp = PyNumber_Multiply(outer, inner);
+        if (tmp == NULL)
+            goto error;
+        Py_DECREF(outer);
+        outer = tmp;
+    }
+
+    goto done;
+
+  error:
+    Py_DECREF(outer);
+  done:
+    Py_DECREF(inner);
+    return outer;
+}
+
+/* Lookup table for small factorial values */
+
+static const unsigned long SmallFactorials[] = {
+    1, 1, 2, 6, 24, 120, 720, 5040, 40320,
+    362880, 3628800, 39916800, 479001600,
+#if SIZEOF_LONG >= 8
+    6227020800, 87178291200, 1307674368000,
+    20922789888000, 355687428096000, 6402373705728000,
+    121645100408832000, 2432902008176640000
+#endif
+};
+
 static PyObject *
 math_factorial(PyObject *self, PyObject *arg)
 {
-    long i, x;
-    PyObject *result, *iobj, *newresult;
+    long x;
+    PyObject *result, *odd_part, *two_valuation;
 
     if (PyFloat_Check(arg)) {
         PyObject *lx;
         double dx = PyFloat_AS_DOUBLE((PyFloatObject *)arg);
         if (!(Py_IS_FINITE(dx) && dx == floor(dx))) {
             PyErr_SetString(PyExc_ValueError,
-                "factorial() only accepts integral values");
+                            "factorial() only accepts integral values");
             return NULL;
         }
         lx = PyLong_FromDouble(dx);
@@ -1156,29 +1377,28 @@ math_factorial(PyObject *self, PyObject *arg)
         return NULL;
     if (x < 0) {
         PyErr_SetString(PyExc_ValueError,
-            "factorial() not defined for negative values");
+                        "factorial() not defined for negative values");
         return NULL;
     }
 
-    result = (PyObject *)PyLong_FromLong(1);
-    if (result == NULL)
+    /* use lookup table if x is small */
+    if (x < (long)(sizeof(SmallFactorials)/sizeof(SmallFactorials[0])))
+        return PyLong_FromUnsignedLong(SmallFactorials[x]);
+
+    /* else express in the form odd_part * 2**two_valuation, and compute as
+       odd_part << two_valuation. */
+    odd_part = factorial_odd_part(x);
+    if (odd_part == NULL)
+        return NULL;
+    two_valuation = PyLong_FromLong(x - count_set_bits(x));
+    if (two_valuation == NULL) {
+        Py_DECREF(odd_part);
         return NULL;
-    for (i=1 ; i<=x ; i++) {
-        iobj = (PyObject *)PyLong_FromLong(i);
-        if (iobj == NULL)
-            goto error;
-        newresult = PyNumber_Multiply(result, iobj);
-        Py_DECREF(iobj);
-        if (newresult == NULL)
-            goto error;
-        Py_DECREF(result);
-        result = newresult;
     }
+    result = PyNumber_Lshift(odd_part, two_valuation);
+    Py_DECREF(two_valuation);
+    Py_DECREF(odd_part);
     return result;
-
-error:
-    Py_DECREF(result);
-    return NULL;
 }
 
 PyDoc_STRVAR(math_factorial_doc,