]> granicus.if.org Git - python/commitdiff
Issue #13121: Support in-place math operators for collections.Counter().
authorRaymond Hettinger <python@rcn.com>
Wed, 19 Oct 2011 20:40:37 +0000 (13:40 -0700)
committerRaymond Hettinger <python@rcn.com>
Wed, 19 Oct 2011 20:40:37 +0000 (13:40 -0700)
Doc/library/collections.rst
Lib/collections/__init__.py
Lib/test/test_collections.py
Misc/NEWS

index f4edced6164c4928a291324aabf96eec249aab15..c9e386402271da9808efbad81cada1650418f9f7 100644 (file)
@@ -293,7 +293,7 @@ or subtracting from an empty counter.
     Counter({'b': 4})
 
 .. versionadded:: 3.3
-   Added support for unary plus and unary minus.
+   Added support for unary plus, unary minus, and in-place multiset operations.
 
 .. note::
 
index 3e864b6f7ac42110ed962b4e478360f4374530b3..68b63a87ee8165332562a073074fea5ad87a9213 100644 (file)
@@ -683,6 +683,69 @@ class Counter(dict):
         '''
         return Counter() - self
 
+    def _keep_positive(self):
+        '''Internal method to strip elements with a negative or zero count'''
+        nonpositive = [elem for elem, count in self.items() if not count > 0]
+        for elem in nonpositive:
+            del self[elem]
+        return self
+
+    def __iadd__(self, other):
+        '''Inplace add from another counter, keeping only positive counts.
+
+        >>> c = Counter('abbb')
+        >>> c += Counter('bcc')
+        >>> c
+        Counter({'b': 4, 'c': 2, 'a': 1})
+
+        '''
+        for elem, count in other.items():
+            self[elem] += count
+        return self._keep_positive()
+
+    def __isub__(self, other):
+        '''Inplace subtract counter, but keep only results with positive counts.
+
+        >>> c = Counter('abbbc')
+        >>> c -= Counter('bccd')
+        >>> c
+        Counter({'b': 2, 'a': 1})
+
+        '''
+        for elem, count in other.items():
+            self[elem] -= count
+        return self._keep_positive()
+
+    def __ior__(self, other):
+        '''Inplace union is the maximum of value from either counter.
+
+        >>> c = Counter('abbb')
+        >>> c |= Counter('bcc')
+        >>> c
+        Counter({'b': 3, 'c': 2, 'a': 1})
+
+        '''
+        for elem, other_count in other.items():
+            count = self[elem]
+            if other_count > count:
+                self[elem] = other_count
+        return self._keep_positive()
+
+    def __iand__(self, other):
+        '''Inplace intersection is the minimum of corresponding counts.
+
+        >>> c = Counter('abbb')
+        >>> c &= Counter('bcc')
+        >>> c
+        Counter({'b': 1})
+
+        '''
+        for elem, count in self.items():
+            other_count = other[elem]
+            if other_count < count:
+                self[elem] = other_count
+        return self._keep_positive()
+
 
 ########################################################################
 ###  ChainMap (helper for configparser and string.Template)
index 04c4d97c5577fa6ca5df5e92761199a929a4a24f..ec2093891791b238c706719a59bfa66eb45e1d70 100644 (file)
@@ -932,6 +932,27 @@ class TestCounter(unittest.TestCase):
                 set_result = setop(set(p.elements()), set(q.elements()))
                 self.assertEqual(counter_result, dict.fromkeys(set_result, 1))
 
+    def test_inplace_operations(self):
+        elements = 'abcd'
+        for i in range(1000):
+            # test random pairs of multisets
+            p = Counter(dict((elem, randrange(-2,4)) for elem in elements))
+            p.update(e=1, f=-1, g=0)
+            q = Counter(dict((elem, randrange(-2,4)) for elem in elements))
+            q.update(h=1, i=-1, j=0)
+            for inplace_op, regular_op in [
+                (Counter.__iadd__, Counter.__add__),
+                (Counter.__isub__, Counter.__sub__),
+                (Counter.__ior__, Counter.__or__),
+                (Counter.__iand__, Counter.__and__),
+            ]:
+                c = p.copy()
+                c_id = id(c)
+                regular_result = regular_op(c, q)
+                inplace_result = inplace_op(c, q)
+                self.assertEqual(inplace_result, regular_result)
+                self.assertEqual(id(inplace_result), c_id)
+
     def test_subtract(self):
         c = Counter(a=-5, b=0, c=5, d=10, e=15,g=40)
         c.subtract(a=1, b=2, c=-3, d=10, e=20, f=30, h=-50)
index 87743343024c4aabb71831fa2f91b435bde57ae8..cd6747e1539bafae8ec6f4c01101b0059c9598ec 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -488,6 +488,7 @@ Library
   in os.kill().
 
 - Add support for unary plus and unary minus to collections.Counter().
+  Issue #13121: Also an support for inplace math operators.
 
 - Issue #12683: urlparse updated to include svn as schemes that uses relative
   paths. (svn from 1.5 onwards support relative path).