]> granicus.if.org Git - python/commitdiff
Change the binary operators |, &, ^, - to return NotImplemented rather
authorGuido van Rossum <guido@python.org>
Thu, 22 Aug 2002 17:23:33 +0000 (17:23 +0000)
committerGuido van Rossum <guido@python.org>
Thu, 22 Aug 2002 17:23:33 +0000 (17:23 +0000)
than raising TypeError when the other argument is not a BaseSet.  This
made it necessary to separate the implementation of e.g. __or__ from
the union method; the latter should not return NotImplemented but
raise TypeError.  This is accomplished by making union(self, other)
return self|other, etc.; Python's binary operator machinery will raise
TypeError.

The idea behind this change is to allow other set implementations with
an incompatible internal structure; these can provide union (etc.) with
standard sets by implementing __ror__ etc.

I wish I could do this for comparisons too, but the default comparison
implementation allows comparing anything to anything else (returning
false); we don't want that (at least the test suite makes sure
e.g. Set()==42 raises TypeError).  That's probably fine; otherwise
other set implementations would be constrained to implementing a hash
that's compatible with ours.

Lib/sets.py

index eeef0e8087febfb12d5181193c79650bb5677753..fee06d76be0dc717af24cf33a6fba0d2aa616314 100644 (file)
@@ -53,7 +53,7 @@ what's tested is actually `z in y'.
 #   and cleaned up the docstrings.
 #
 # - Raymond Hettinger added a number of speedups and other
-#   bugs^H^H^H^Himprovements.
+#   improvements.
 
 
 __all__ = ['BaseSet', 'Set', 'ImmutableSet']
@@ -155,26 +155,35 @@ class BaseSet(object):
             data[deepcopy(elt, memo)] = value
         return result
 
-    # Standard set operations: union, intersection, both differences
+    # Standard set operations: union, intersection, both differences.
+    # Each has an operator version (e.g. __or__, invoked with |) and a
+    # method version (e.g. union).
 
-    def union(self, other):
+    def __or__(self, other):
         """Return the union of two sets as a new set.
 
         (I.e. all elements that are in either set.)
         """
-        self._binary_sanity_check(other)
+        if not isinstance(other, BaseSet):
+            return NotImplemented
         result = self.__class__(self._data)
         result._data.update(other._data)
         return result
 
-    __or__ = union
+    def union(self, other):
+        """Return the union of two sets as a new set.
 
-    def intersection(self, other):
+        (I.e. all elements that are in either set.)
+        """
+        return self | other
+
+    def __and__(self, other):
         """Return the intersection of two sets as a new set.
 
         (I.e. all elements that are in both sets.)
         """
-        self._binary_sanity_check(other)
+        if not isinstance(other, BaseSet):
+            return NotImplemented
         if len(self) <= len(other):
             little, big = self, other
         else:
@@ -187,14 +196,20 @@ class BaseSet(object):
                 data[elt] = value
         return result
 
-    __and__ = intersection
+    def intersection(self, other):
+        """Return the intersection of two sets as a new set.
 
-    def symmetric_difference(self, other):
+        (I.e. all elements that are in both sets.)
+        """
+        return self & other
+
+    def __xor__(self, other):
         """Return the symmetric difference of two sets as a new set.
 
         (I.e. all elements that are in exactly one of the sets.)
         """
-        self._binary_sanity_check(other)
+        if not isinstance(other, BaseSet):
+            return NotImplemented
         result = self.__class__([])
         data = result._data
         value = True
@@ -206,14 +221,20 @@ class BaseSet(object):
                 data[elt] = value
         return result
 
-    __xor__ = symmetric_difference
+    def symmetric_difference(self, other):
+        """Return the symmetric difference of two sets as a new set.
 
-    def difference(self, other):
+        (I.e. all elements that are in exactly one of the sets.)
+        """
+        return self ^ other
+
+    def  __sub__(self, other):
         """Return the difference of two sets as a new Set.
 
         (I.e. all elements that are in this set and not in the other.)
         """
-        self._binary_sanity_check(other)
+        if not isinstance(other, BaseSet):
+            return NotImplemented
         result = self.__class__([])
         data = result._data
         value = True
@@ -222,7 +243,12 @@ class BaseSet(object):
                 data[elt] = value
         return result
 
-    __sub__ = difference
+    def difference(self, other):
+        """Return the difference of two sets as a new Set.
+
+        (I.e. all elements that are in this set and not in the other.)
+        """
+        return self - other
 
     # Membership test