]> granicus.if.org Git - python/commitdiff
Merged revisions 68546 via svnmerge from
authorBenjamin Peterson <benjamin@python.org>
Sun, 18 Jan 2009 22:46:33 +0000 (22:46 +0000)
committerBenjamin Peterson <benjamin@python.org>
Sun, 18 Jan 2009 22:46:33 +0000 (22:46 +0000)
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r68546 | raymond.hettinger | 2009-01-12 04:37:32 -0600 (Mon, 12 Jan 2009) | 1 line

  Optimize heapq.nsmallest/nlargest for cases where n==1 or n>=size.
........

Lib/heapq.py

index 2d3404644aa4db308213f816c416f1fa95be1639..24997bf6fd2c12ba01520fb528cff55b1a1c2305 100644 (file)
@@ -129,7 +129,7 @@ From all times, sorting has always been a Great Art! :-)
 __all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'merge',
            'nlargest', 'nsmallest', 'heappushpop']
 
-from itertools import islice, repeat, count, tee
+from itertools import islice, repeat, count, tee, chain
 from operator import itemgetter, neg
 import bisect
 
@@ -354,10 +354,32 @@ def nsmallest(n, iterable, key=None):
 
     Equivalent to:  sorted(iterable, key=key)[:n]
     """
+    # Short-cut for n==1 is to use min() when len(iterable)>0
+    if n == 1:
+        it = iter(iterable)
+        head = list(islice(it, 1))
+        if not head:
+            return []
+        if key is None:
+            return [min(chain(head, it))]
+        return [min(chain(head, it), key=key)]
+
+    # When n>=size, it's faster to use sort()
+    try:
+        size = len(iterable)
+    except (TypeError, AttributeError):
+        pass
+    else:
+        if n >= size:
+            return sorted(iterable, key=key)[:n]
+
+    # When key is none, use simpler decoration
     if key is None:
         it = zip(iterable, count())                         # decorate
         result = _nsmallest(n, it)
         return list(map(itemgetter(0), result))             # undecorate
+
+    # General case, slowest method
     in1, in2 = tee(iterable)
     it = zip(map(key, in1), count(), in2)                   # decorate
     result = _nsmallest(n, it)
@@ -369,10 +391,33 @@ def nlargest(n, iterable, key=None):
 
     Equivalent to:  sorted(iterable, key=key, reverse=True)[:n]
     """
+
+    # Short-cut for n==1 is to use max() when len(iterable)>0
+    if n == 1:
+        it = iter(iterable)
+        head = list(islice(it, 1))
+        if not head:
+            return []
+        if key is None:
+            return [max(chain(head, it))]
+        return [max(chain(head, it), key=key)]
+
+    # When n>=size, it's faster to use sort()
+    try:
+        size = len(iterable)
+    except (TypeError, AttributeError):
+        pass
+    else:
+        if n >= size:
+            return sorted(iterable, key=key, reverse=True)[:n]
+
+    # When key is none, use simpler decoration
     if key is None:
         it = zip(iterable, map(neg, count()))               # decorate
         result = _nlargest(n, it)
         return list(map(itemgetter(0), result))             # undecorate
+
+    # General case, slowest method
     in1, in2 = tee(iterable)
     it = zip(map(key, in1), map(neg, count()), in2)         # decorate
     result = _nlargest(n, it)