]> granicus.if.org Git - python/commitdiff
* Add __eq__ and __ne__ so that things like list.index() work properly
authorRaymond Hettinger <python@rcn.com>
Mon, 5 Jul 2004 22:53:03 +0000 (22:53 +0000)
committerRaymond Hettinger <python@rcn.com>
Mon, 5 Jul 2004 22:53:03 +0000 (22:53 +0000)
  for lists of mixed types.
* Test that sort works.

Lib/decimal.py
Lib/test/test_decimal.py

index 500ba07a101489569e3d41fbdac26c0ac7bb98ca..1d13767746635d36c4d36e63ea6062907827bb7c 100644 (file)
@@ -8,10 +8,6 @@
 #    and Tim Peters
 
 
-# Todo:
-#    Add rich comparisons for equality testing with other types
-
-
 """
 This is a Py2.3 implementation of decimal floating point arithmetic based on
 the General Decimal Arithmetic Specification:
@@ -644,6 +640,16 @@ class Decimal(object):
             return -1
         return 1
 
+    def __eq__(self, other):
+        if not isinstance(other, (Decimal, int, long)):
+            return False
+        return self.__cmp__(other) == 0
+
+    def __ne__(self, other):
+        if not isinstance(other, (Decimal, int, long)):
+            return True
+        return self.__cmp__(other) != 0
+
     def compare(self, other, context=None):
         """Compares one to another.
 
index 51b352802352ff2effa1bb5c8b0f5a430f11c448..a03b7840720dfdbf0ab267072271b98537ace64a 100644 (file)
@@ -33,6 +33,7 @@ import pickle, copy
 from decimal import *
 from test.test_support import TestSkipped, run_unittest, run_doctest, is_resource_enabled
 import threading
+import random
 
 # Tests are built around these assumed context defaults
 DefaultContext.prec=9
@@ -841,17 +842,17 @@ class DecimalUsabilityTest(unittest.TestCase):
         self.assertEqual(cmp(dc,45), 0)
 
         #a Decimal and uncomparable
-        try: da == 'ugly'
-        except TypeError: pass
-        else: self.fail('Did not raised an error!')
-
-        try: da == '32.7'
-        except TypeError: pass
-        else: self.fail('Did not raised an error!')
-
-        try: da == object
-        except TypeError: pass
-        else: self.fail('Did not raised an error!')
+        self.assertNotEqual(da, 'ugly')
+        self.assertNotEqual(da, 32.7)
+        self.assertNotEqual(da, object())
+        self.assertNotEqual(da, object)
+
+        # sortable
+        a = map(Decimal, xrange(100))
+        b =  a[:]
+        random.shuffle(a)
+        a.sort()
+        self.assertEqual(a, b)
 
     def test_copy_and_deepcopy_methods(self):
         d = Decimal('43.24')
@@ -1078,6 +1079,10 @@ class ContextAPItests(unittest.TestCase):
             v2 = vars(e)[k]
             self.assertEqual(v1, v2)
 
+    def test_equality_with_other_types(self):
+        self.assert_(Decimal(10) in ['a', 1.0, Decimal(10), (1,2), {}])
+        self.assert_(Decimal(10) not in ['a', 1.0, (1,2), {}])
+
 def test_main(arith=False, verbose=None):
     """ Execute the tests.