granicus.if.org Git - python/commitdiff
Issue #18844: Add random.weighted_choices()
author Raymond Hettinger <python@rcn.com>
Wed, 7 Sep 2016 00:15:29 +0000 (17:15 -0700)
committer Raymond Hettinger <python@rcn.com>
Wed, 7 Sep 2016 00:15:29 +0000 (17:15 -0700)
Doc/library/random.rst
Lib/random.py
Lib/test/test_random.py
Misc/NEWS

index 6dc54d2877620c4a44ec5c1aca235318619329f4..330cce15b8cce4457ab2e0389565fbf519c2366c 100644 (file)
@@ -124,6 +124,27 @@ Functions for sequences:
    Return a random element from the non-empty sequence *seq*. If *seq* is empty,
    raises :exc:`IndexError`.
 
+.. function:: weighted_choices(k, population, weights=None, *, cum_weights=None)
+
+   Return a *k* sized list of elements chosen from the *population* with replacement.
+   If the *population* is empty, raises :exc:`IndexError`.
+
+   If a *weights* sequence is specified, selections are made according to the
+   relative weights.  Alternatively, if a *cum_weights* sequence is given, the
+   selections are made according to the cumulative weights.  For example, the
+   relative weights ``[10, 5, 30, 5]`` are equivalent to the cumulative
+   weights ``[10, 15, 45, 50]``.  Internally, the relative weights are
+   converted to cumulative weights before making selections, so supplying the
+   cumulative weights saves work.
+
+   If neither *weights* nor *cum_weights* is specified, selections are made
+   with equal probability.  If a weights sequence is supplied, it must be
+   the same length as the *population* sequence.  It is a :exc:`TypeError`
+   to specify both *weights* and *cum_weights*.
+
+   The *weights* or *cum_weights* can use any numeric type that interoperates
+   with the :class:`float` values returned by :func:`random` (that includes
+   integers, floats, and fractions but excludes decimals).
 
 .. function:: shuffle(x[, random])
 
index 82f6013b1fe1ed7fb8288d779673e74b4f62fcf7..136395e938f72d152114fc120002f6d52811a97f 100644 (file)
@@ -8,6 +8,7 @@
     ---------
            pick random element
            pick random sample
+           pick weighted random sample
            generate random permutation
 
     distributions on the real line:
@@ -43,12 +44,14 @@ from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
 from os import urandom as _urandom
 from _collections_abc import Set as _Set, Sequence as _Sequence
 from hashlib import sha512 as _sha512
+import itertools as _itertools
+import bisect as _bisect
 
 __all__ = ["Random","seed","random","uniform","randint","choice","sample",
            "randrange","shuffle","normalvariate","lognormvariate",
            "expovariate","vonmisesvariate","gammavariate","triangular",
            "gauss","betavariate","paretovariate","weibullvariate",
-           "getstate","setstate", "getrandbits",
+           "getstate","setstate", "getrandbits", "weighted_choices",
            "SystemRandom"]
 
 NV_MAGICCONST = 4 * _exp(-0.5)/_sqrt(2.0)
@@ -334,6 +337,28 @@ class Random(_random.Random):
                 result[i] = population[j]
         return result
 
+    def weighted_choices(self, k, population, weights=None, *, cum_weights=None):
+        """Return a k sized list of population elements chosen with replacement.
+
+        With neither weights nor cum_weights, each element is equally likely;
+        otherwise selection follows the relative or cumulative weights, which
+        must match len(population) (ValueError).  Giving both is a TypeError.
+        """
+        if cum_weights is None:
+            if weights is None:
+                choice = self.choice
+                return [choice(population) for i in range(k)]
+            # Relative weights are turned into running totals for bisection.
+            cum_weights = list(_itertools.accumulate(weights))
+        elif weights is not None:
+            raise TypeError('Cannot specify both weights and cum_weights')
+        if len(cum_weights) != len(population):
+            raise ValueError('The number of weights does not match the population')
+        bisect = _bisect.bisect
+        random = self.random
+        total = cum_weights[-1]
+        return [population[bisect(cum_weights, random() * total)] for i in range(k)]
+
 ## -------------------- real-valued distributions  -------------------
 
 ## -------------------- uniform distribution -------------------
@@ -724,6 +749,7 @@ choice = _inst.choice
 randrange = _inst.randrange
 sample = _inst.sample
 shuffle = _inst.shuffle
+weighted_choices = _inst.weighted_choices
 normalvariate = _inst.normalvariate
 lognormvariate = _inst.lognormvariate
 expovariate = _inst.expovariate
index e80ed17a8cb6bc3d1972d60541a6a5de8fb638e6..b3741a8845ad1aacc95000bc67609512e7ce2315 100644 (file)
@@ -7,6 +7,7 @@ import warnings
 from functools import partial
 from math import log, exp, pi, fsum, sin
 from test import support
+from fractions import Fraction
 
 class TestBasicOps:
     # Superclass with tests common to all generators.
@@ -141,6 +142,73 @@ class TestBasicOps:
     def test_sample_on_dicts(self):
         self.assertRaises(TypeError, self.gen.sample, dict.fromkeys('abcdef'), 2)
 
+    def test_weighted_choices(self):
+        weighted_choices = self.gen.weighted_choices
+        data = ['red', 'green', 'blue', 'yellow']
+        str_data = 'abcd'
+        range_data = range(4)
+        set_data = set(range(4))
+
+        # basic functionality: each call form returns a k-long list drawn from data
+        for sample in [
+            weighted_choices(5, data),
+            weighted_choices(5, data, range(4)),
+            weighted_choices(k=5, population=data, weights=range(4)),
+            weighted_choices(k=5, population=data, cum_weights=range(4)),
+        ]:
+            self.assertEqual(len(sample), 5)
+            self.assertEqual(type(sample), list)
+            self.assertTrue(set(sample) <= set(data))
+
+        # test argument handling
+        with self.assertRaises(TypeError):                                        # missing arguments
+            weighted_choices(2)
+
+        self.assertEqual(weighted_choices(0, data), [])                           # k == 0
+        self.assertEqual(weighted_choices(-1, data), [])                          # negative k behaves like ``[0] * -1``
+        with self.assertRaises(TypeError):
+            weighted_choices(2.5, data)                                           # k is a float
+
+        self.assertTrue(set(weighted_choices(5, str_data)) <= set(str_data))      # population is a string sequence
+        self.assertTrue(set(weighted_choices(5, range_data)) <= set(range_data))  # population is a range
+        with self.assertRaises(TypeError):
+            weighted_choices(5, set_data)                                         # population is not a sequence (valid k, so the set itself is what fails)
+
+        self.assertTrue(set(weighted_choices(5, data, None)) <= set(data))        # weights is None
+        self.assertTrue(set(weighted_choices(5, data, weights=None)) <= set(data))
+        with self.assertRaises(ValueError):
+            weighted_choices(5, data, [1,2])                                      # len(weights) != len(population)
+        with self.assertRaises(IndexError):
+            weighted_choices(5, data, [0]*4)                                      # weights sum to zero
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, 10)                                         # non-iterable weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, [None]*4)                                   # non-numeric weights
+        for weights in [
+                [15, 10, 25, 30],                                                 # integer weights
+                [15.1, 10.2, 25.2, 30.3],                                         # float weights
+                [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional weights
+                [True, False, True, False]                                        # booleans (include / exclude)
+        ]:
+            self.assertTrue(set(weighted_choices(5, data, weights)) <= set(data))
+
+        with self.assertRaises(ValueError):
+            weighted_choices(5, data, cum_weights=[1,2])                          # len(cum_weights) != len(population)
+        with self.assertRaises(IndexError):
+            weighted_choices(5, data, cum_weights=[0]*4)                          # cum_weights sum to zero
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, cum_weights=10)                             # non-iterable cum_weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, cum_weights=[None]*4)                       # non-numeric cum_weights
+        with self.assertRaises(TypeError):
+            weighted_choices(5, data, range(4), cum_weights=range(4))             # both weights and cum_weights
+        for weights in [
+                [15, 10, 25, 30],                                                 # integer cum_weights
+                [15.1, 10.2, 25.2, 30.3],                                         # float cum_weights
+                [Fraction(1, 3), Fraction(2, 6), Fraction(3, 6), Fraction(4, 6)], # fractional cum_weights
+        ]:
+            self.assertTrue(set(weighted_choices(5, data, cum_weights=weights)) <= set(data))
+
     def test_gauss(self):
         # Ensure that the seed() method initializes all the hidden state.  In
         # particular, through 2.2.1 it failed to reset a piece of state used
index e913ef81d28cb4867c75dba18e5935828b5710a1..fbf7b2b976470e7809c5161410b05ec2cc75e4b5 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -101,6 +101,8 @@ Library
 - Issue #27691: Fix ssl module's parsing of GEN_RID subject alternative name
   fields in X.509 certs.
 
+- Issue #18844: Add random.weighted_choices().
+
 - Issue #25761: Improved error reporting about truncated pickle data in
   C implementation of unpickler.  UnpicklingError is now raised instead of
   AttributeError and ValueError in some cases.