]> granicus.if.org Git - python/commitdiff
Add tests to verify combinatoric relationships.
authorRaymond Hettinger <python@rcn.com>
Tue, 27 Jan 2009 09:35:21 +0000 (09:35 +0000)
committerRaymond Hettinger <python@rcn.com>
Tue, 27 Jan 2009 09:35:21 +0000 (09:35 +0000)
Lib/test/test_itertools.py

index 7c858eb97e09735935c651d0bb48af55c4a0c204..affad7f3b8c755d4cea510d23ffcdc5824197dde 100644 (file)
@@ -110,6 +110,14 @@ class TestBasicOps(unittest.TestCase):
                 if sorted(indices) == list(indices):
                     yield tuple(pool[i] for i in indices)
 
+        def combinations3(iterable, r):
+            'Pure python version from cwr()'
+            pool = tuple(iterable)
+            n = len(pool)
+            for indices in combinations_with_replacement(range(n), r):
+                if len(set(indices)) == r:
+                    yield tuple(pool[i] for i in indices)
+
         for n in range(7):
             values = [5*x-12 for x in range(n)]
             for r in range(n+2):
@@ -126,6 +134,7 @@ class TestBasicOps(unittest.TestCase):
                                      [e for e in values if e in c])      # comb is a subsequence of the input iterable
                 self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
                 self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
+                self.assertEqual(result, list(combinations3(values, r))) # matches second pure python version
 
         # Test implementation detail:  tuple re-use
         self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
@@ -265,6 +274,23 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
         self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
 
+    def test_combinatorics(self):
+        # Test relationships between product(), permutations(),
+        # combinations() and combinations_with_replacement().
+
+        s = 'ABCDE'
+        for r in range(8):
+            prod = list(product(s, repeat=r))
+            cwr = list(combinations_with_replacement(s, r))
+            perm = list(permutations(s, r))
+            comb = list(combinations(s, r))
+
+            self.assertEquals(cwr, [t for t in prod if sorted(t)==list(t)]) # cwr: prods which are sorted
+            self.assertEquals(perm, [t for t in prod if len(set(t))==r])    # perm: prods with no dups
+            self.assertEqual(comb, [t for t in perm if sorted(t)==list(t)]) # comb: perms that are sorted
+            self.assertEqual(comb, [t for t in cwr if len(set(t))==r])      # comb: cwrs without dups
+            self.assertEqual(set(comb), set(cwr) & set(perm))               # comb: both a cwr and a perm
+
     def test_compress(self):
         self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
         self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))