]> granicus.if.org Git - python/commitdiff
Improve the memory performance and speed of heapq.nsmallest() by using
authorRaymond Hettinger <python@rcn.com>
Sat, 12 Jun 2004 08:33:36 +0000 (08:33 +0000)
committerRaymond Hettinger <python@rcn.com>
Sat, 12 Jun 2004 08:33:36 +0000 (08:33 +0000)
an alternate algorithm when the number of selected items is small
relative to the full iterable.

Lib/heapq.py
Lib/test/test_heapq.py

index d1aad98a24d69b5e27cb35d1f2034416291f2b4f..65f415504201096aecf9c1d957096e064f167b30 100644 (file)
@@ -130,6 +130,7 @@ __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest',
            'nsmallest']
 
 from itertools import islice, repeat
+import bisect
 
 def heappush(heap, item):
     """Push item onto heap, maintaining the heap invariant."""
@@ -196,6 +197,28 @@ def nsmallest(iterable, n):
 
     Equivalent to:  sorted(iterable)[:n]
     """
+    if hasattr(iterable, '__len__') and n * 10 <= len(iterable):
+        # For smaller values of n, the bisect method is faster than a minheap.
+        # It is also memory efficient, consuming only n elements of space.
+        it = iter(iterable)
+        result = sorted(islice(it, 0, n))
+        if not result:
+            return result
+        insort = bisect.insort
+        pop = result.pop
+        los = result[-1]    # los --> Largest of the nsmallest
+        for elem in it:
+            if los <= elem:
+                continue
+            insort(result, elem)
+            pop()
+            los = result[-1]
+        return result
+    # An alternative approach manifests the whole iterable in memory but
+    # saves comparisons by heapifying all at once.  Also, saves time
+    # over bisect.insort() which has O(n) data movement time for every
+    # insertion.  Finding the n smallest of an m length iterable requires
+    #    O(m) + O(n log m) comparisons.
     h = list(iterable)
     heapify(h)
     return map(heappop, repeat(h, min(n, len(h))))
index 944b17dcc71a1d92fd81f8daaba27ec09176f17f..1cdaabe886154e23b6982596ebfd759abe35562d 100644 (file)
@@ -92,6 +92,7 @@ class TestHeap(unittest.TestCase):
     def test_nsmallest(self):
         data = [random.randrange(2000) for i in range(1000)]
         self.assertEqual(nsmallest(data, 400), sorted(data)[:400])
+        self.assertEqual(nsmallest(data, 50), sorted(data)[:50])
 
     def test_largest(self):
         data = [random.randrange(2000) for i in range(1000)]