]> granicus.if.org Git - python/commitdiff
* Fix decimal's handling of foreign types. Now returns NotImplemented
authorRaymond Hettinger <python@rcn.com>
Sun, 27 Mar 2005 10:47:39 +0000 (10:47 +0000)
committerRaymond Hettinger <python@rcn.com>
Sun, 27 Mar 2005 10:47:39 +0000 (10:47 +0000)
  instead of raising a TypeError.  Allows other types to successfully
  implement __radd__() style methods.
* Remove future division import from test suite.
* Remove test suite's shadowing of __builtin__.dir().

Lib/decimal.py
Lib/test/test_decimal.py
Misc/NEWS

index fb11e8f3e4fee5c08be6794ef6dc54ba3340eaf7..e3e7fd5c11c081b1efd11f748669ce38cb12cf14 100644 (file)
@@ -645,6 +645,8 @@ class Decimal(object):
 
     def __cmp__(self, other, context=None):
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             ans = self._check_nans(other, context)
@@ -696,12 +698,12 @@ class Decimal(object):
 
     def __eq__(self, other):
         if not isinstance(other, (Decimal, int, long)):
-            return False
+            return NotImplemented
         return self.__cmp__(other) == 0
 
     def __ne__(self, other):
         if not isinstance(other, (Decimal, int, long)):
-            return True
+            return NotImplemented
         return self.__cmp__(other) != 0
 
     def compare(self, other, context=None):
@@ -714,6 +716,8 @@ class Decimal(object):
         Like __cmp__, but returns Decimal instances.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         #compare(NaN, NaN) = NaN
         if (self._is_special or other and other._is_special):
@@ -919,6 +923,8 @@ class Decimal(object):
         -INF + INF (or the reverse) cause InvalidOperation errors.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if context is None:
             context = getcontext()
@@ -1006,6 +1012,8 @@ class Decimal(object):
     def __sub__(self, other, context=None):
         """Return self + (-other)"""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             ans = self._check_nans(other, context=context)
@@ -1023,6 +1031,8 @@ class Decimal(object):
     def __rsub__(self, other, context=None):
         """Return other + (-self)"""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         tmp = Decimal(self)
         tmp._sign = 1 - tmp._sign
@@ -1068,6 +1078,8 @@ class Decimal(object):
         (+-) INF * 0 (or its reverse) raise InvalidOperation.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if context is None:
             context = getcontext()
@@ -1140,6 +1152,10 @@ class Decimal(object):
         computing the other value are not raised.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            if divmod in (0, 1):
+                return NotImplemented
+            return (NotImplemented, NotImplemented)
 
         if context is None:
             context = getcontext()
@@ -1292,6 +1308,8 @@ class Decimal(object):
     def __rdiv__(self, other, context=None):
         """Swaps self/other and returns __div__."""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
         return other.__div__(self, context=context)
     __rtruediv__ = __rdiv__
 
@@ -1304,6 +1322,8 @@ class Decimal(object):
     def __rdivmod__(self, other, context=None):
         """Swaps self/other and returns __divmod__."""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
         return other.__divmod__(self, context=context)
 
     def __mod__(self, other, context=None):
@@ -1311,6 +1331,8 @@ class Decimal(object):
         self % other
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             ans = self._check_nans(other, context)
@@ -1325,6 +1347,8 @@ class Decimal(object):
     def __rmod__(self, other, context=None):
         """Swaps self/other and returns __mod__."""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
         return other.__mod__(self, context=context)
 
     def remainder_near(self, other, context=None):
@@ -1332,6 +1356,8 @@ class Decimal(object):
         Remainder nearest to 0-  abs(remainder-near) <= other/2
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             ans = self._check_nans(other, context)
@@ -1411,6 +1437,8 @@ class Decimal(object):
     def __rfloordiv__(self, other, context=None):
         """Swaps self/other and returns __floordiv__."""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
         return other.__floordiv__(self, context=context)
 
     def __float__(self):
@@ -1661,6 +1689,8 @@ class Decimal(object):
         If modulo is None (default), don't take it mod modulo.
         """
         n = _convert_other(n)
+        if n is NotImplemented:
+            return n
 
         if context is None:
             context = getcontext()
@@ -1747,6 +1777,8 @@ class Decimal(object):
     def __rpow__(self, other, context=None):
         """Swaps self/other and returns __pow__."""
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
         return other.__pow__(self, context=context)
 
     def normalize(self, context=None):
@@ -2001,6 +2033,8 @@ class Decimal(object):
         NaN (and signals if one is sNaN).  Also rounds.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             # if one operand is a quiet NaN and the other is number, then the
@@ -2048,6 +2082,8 @@ class Decimal(object):
         NaN (and signals if one is sNaN).  Also rounds.
         """
         other = _convert_other(other)
+        if other is NotImplemented:
+            return other
 
         if self._is_special or other._is_special:
             # if one operand is a quiet NaN and the other is number, then the
@@ -2874,8 +2910,7 @@ def _convert_other(other):
         return other
     if isinstance(other, (int, long)):
         return Decimal(other)
-
-    raise TypeError, "You can interact Decimal only with int, long or Decimal data types."
+    return NotImplemented
 
 _infinity_map = {
     'inf' : 1,
index fc1e0482846521bc1a5fb901c73743ee4839f2ed..34f034b850bc3c8262c945329c4a62168da1db46 100644 (file)
@@ -24,8 +24,6 @@ you're working through IDLE, you can import this test module and call test_main(
 with the corresponding argument.
 """
 
-from __future__ import division
-
 import unittest
 import glob
 import os, sys
@@ -54,9 +52,9 @@ if __name__ == '__main__':
 else:
     file = __file__
 testdir = os.path.dirname(file) or os.curdir
-dir = testdir + os.sep + TESTDATADIR + os.sep
+directory = testdir + os.sep + TESTDATADIR + os.sep
 
-skip_expected = not os.path.isdir(dir)
+skip_expected = not os.path.isdir(directory)
 
 # Make sure it actually raises errors when not expected and caught in flags
 # Slower, since it runs some things several times.
@@ -109,7 +107,6 @@ class DecimalTest(unittest.TestCase):
     Changed for unittest.
     """
     def setUp(self):
-        global dir
         self.context = Context()
         for key in DefaultContext.traps.keys():
             DefaultContext.traps[key] = 1
@@ -302,11 +299,11 @@ class DecimalTest(unittest.TestCase):
 # Dynamically build custom test definition for each file in the test
 # directory and add the definitions to the DecimalTest class.  This
 # procedure insures that new files do not get skipped.
-for filename in os.listdir(dir):
+for filename in os.listdir(directory):
     if '.decTest' not in filename:
         continue
     head, tail = filename.split('.')
-    tester = lambda self, f=filename: self.eval_file(dir + f)
+    tester = lambda self, f=filename: self.eval_file(directory + f)
     setattr(DecimalTest, 'test_' + head, tester)
     del filename, head, tail, tester
 
@@ -476,6 +473,52 @@ class DecimalImplicitConstructionTest(unittest.TestCase):
     def test_implicit_from_Decimal(self):
         self.assertEqual(Decimal(5) + Decimal(45), Decimal(50))
 
+    def test_rop(self):
+        # Allow other classes to be trained to interact with Decimals
+        class E:
+            def __divmod__(self, other):
+                return 'divmod ' + str(other)
+            def __rdivmod__(self, other):
+                return str(other) + ' rdivmod'
+            def __lt__(self, other):
+                return 'lt ' + str(other)
+            def __gt__(self, other):
+                return 'gt ' + str(other)
+            def __le__(self, other):
+                return 'le ' + str(other)
+            def __ge__(self, other):
+                return 'ge ' + str(other)
+            def __eq__(self, other):
+                return 'eq ' + str(other)
+            def __ne__(self, other):
+                return 'ne ' + str(other)
+
+        self.assertEqual(divmod(E(), Decimal(10)), 'divmod 10')
+        self.assertEqual(divmod(Decimal(10), E()), '10 rdivmod')
+        self.assertEqual(eval('Decimal(10) < E()'), 'gt 10')
+        self.assertEqual(eval('Decimal(10) > E()'), 'lt 10')
+        self.assertEqual(eval('Decimal(10) <= E()'), 'ge 10')
+        self.assertEqual(eval('Decimal(10) >= E()'), 'le 10')
+        self.assertEqual(eval('Decimal(10) == E()'), 'eq 10')
+        self.assertEqual(eval('Decimal(10) != E()'), 'ne 10')
+
+        # insert operator methods and then exercise them
+        for sym, lop, rop in (
+                ('+', '__add__', '__radd__'),
+                ('-', '__sub__', '__rsub__'),
+                ('*', '__mul__', '__rmul__'),
+                ('/', '__div__', '__rdiv__'),
+                ('%', '__mod__', '__rmod__'),
+                ('//', '__floordiv__', '__rfloordiv__'),
+                ('**', '__pow__', '__rpow__'),
+            ):
+
+            setattr(E, lop, lambda self, other: 'str' + lop + str(other))
+            setattr(E, rop, lambda self, other: str(other) + rop + 'str')
+            self.assertEqual(eval('E()' + sym + 'Decimal(10)'),
+                             'str' + lop + '10')
+            self.assertEqual(eval('Decimal(10)' + sym + 'E()'),
+                             '10' + rop + 'str')
 
 class DecimalArithmeticOperatorsTest(unittest.TestCase):
     '''Unit tests for all arithmetic operators, binary and unary.'''
index 1706874c24da2a94b5a381039d10a1e92ce06df9..9a63f2103812dccc1989dc27f4f5fa3217a8d5e0 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -91,6 +91,11 @@ Library
 - distutils.commands.upload was added to support uploading distribution
   files to PyPI.
 
+- decimal operator and comparison methods now return NotImplemented
+  instead of raising a TypeError when interacting with other types.  This
+  allows other classes to implement __radd__ style methods and have them
+  work as expected.
+
 - Bug #1163325:  Decimal infinities failed to hash.  Attempting to
   hash a NaN raised an InvalidOperation instead of a TypeError.