]> granicus.if.org Git - python/commitdiff
Tweak recipes and tests
authorRaymond Hettinger <python@rcn.com>
Fri, 7 Mar 2008 01:33:20 +0000 (01:33 +0000)
committerRaymond Hettinger <python@rcn.com>
Fri, 7 Mar 2008 01:33:20 +0000 (01:33 +0000)
Doc/library/itertools.rst
Lib/test/test_itertools.py

index 9ed0c54de1d7b362cf8eeea9c10a91c98bb8f1ae..f546fe16ee2bf55ac5c5f725891722f9bc41979a 100644 (file)
@@ -662,15 +662,15 @@ which incur interpreter overhead. ::
    def pairwise(iterable):
        "s -> (s0,s1), (s1,s2), (s2, s3), ..."
        a, b = tee(iterable)
-       try:
-           b.next()
-       except StopIteration:
-           pass
+       for elem in b:
+           break
        return izip(a, b)
 
-   def grouper(n, iterable, padvalue=None):
+   def grouper(n, iterable, fillvalue=None):
        "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
-       return izip(*[chain(iterable, repeat(padvalue, n-1))]*n)
+       args = [iter(iterable)] * n
+       kwds = dict(fillvalue=fillvalue)
+       return izip_longest(*args, **kwds)
 
    def roundrobin(*iterables):
        "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"
index 3bd2255ab9a532cc58d482c61aca354400b50a22..696fdebf1ef028256f822011e7468f13b685ca8d 100644 (file)
@@ -410,6 +410,28 @@ class TestBasicOps(unittest.TestCase):
         self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
         self.assertRaises(TypeError, product, range(6), None)
 
+        def product1(*args, **kwds):
+            pools = map(tuple, args) * kwds.get('repeat', 1)
+            n = len(pools)
+            if n == 0:
+                yield ()
+                return
+            if any(len(pool) == 0 for pool in pools):
+                return
+            indices = [0] * n
+            yield tuple(pool[i] for pool, i in zip(pools, indices))
+            while 1:
+                for i in reversed(range(n)):  # right to left
+                    if indices[i] == len(pools[i]) - 1:
+                        continue
+                    indices[i] += 1
+                    for j in range(i+1, n):
+                        indices[j] = 0
+                    yield tuple(pool[i] for pool, i in zip(pools, indices))
+                    break
+                else:
+                    return
+
         def product2(*args, **kwds):
             'Pure python version used in docs'
             pools = map(tuple, args) * kwds.get('repeat', 1)
@@ -425,6 +447,7 @@ class TestBasicOps(unittest.TestCase):
             args = [random.choice(argtypes) for j in range(random.randrange(5))]
             expected_len = prod(map(len, args))
             self.assertEqual(len(list(product(*args))), expected_len)
+            self.assertEqual(list(product(*args)), list(product1(*args)))
             self.assertEqual(list(product(*args)), list(product2(*args)))
             args = map(iter, args)
             self.assertEqual(len(list(product(*args))), expected_len)
@@ -1213,7 +1236,7 @@ Samuele
 ...     return sum(imap(operator.mul, vec1, vec2))
 
 >>> def flatten(listOfLists):
-...     return list(chain(*listOfLists))
+...     return list(chain.from_iterable(listOfLists))
 
 >>> def repeatfunc(func, times=None, *args):
 ...     "Repeat calls to func with specified arguments."
@@ -1226,15 +1249,15 @@ Samuele
 >>> def pairwise(iterable):
 ...     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
 ...     a, b = tee(iterable)
-...     try:
-...         b.next()
-...     except StopIteration:
-...         pass
+...     for elem in b:
+...         break
 ...     return izip(a, b)
 
->>> def grouper(n, iterable, padvalue=None):
+>>> def grouper(n, iterable, fillvalue=None):
 ...     "grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), ('g','x','x')"
-...     return izip(*[chain(iterable, repeat(padvalue, n-1))]*n)
+...     args = [iter(iterable)] * n
+...     kwds = dict(fillvalue=fillvalue)
+...     return izip_longest(*args, **kwds)
 
 >>> def roundrobin(*iterables):
 ...     "roundrobin('abc', 'd', 'ef') --> 'a', 'd', 'e', 'b', 'f', 'c'"