]> granicus.if.org Git - python/commitdiff
The methods always return Decimal classes, even if they're
authorFacundo Batista <facundobatista@gmail.com>
Mon, 17 Sep 2007 17:30:13 +0000 (17:30 +0000)
committerFacundo Batista <facundobatista@gmail.com>
Mon, 17 Sep 2007 17:30:13 +0000 (17:30 +0000)
executed through a subclass (thanks Mark Dickinson).
Added a bit of testing for this.

Lib/decimal.py
Lib/test/test_decimal.py

index b2bfc567a4e083acf8d9eb99be9d521e51600f98..6481d6c3f293d948ba67858944d5b052a534218b 100644 (file)
@@ -1473,7 +1473,7 @@ class Decimal(object):
                 pos += 1
             payload = payload[pos:]
             return Decimal((self._sign, payload, self._exp))
-        return self
+        return Decimal(self)
 
     def _fix(self, context):
         """Round if it is necessary to keep self within prec precision.
@@ -1494,7 +1494,7 @@ class Decimal(object):
                 return self._fix_nan(context)
             else:
                 # self is +/-Infinity; return unaltered
-                return self
+                return Decimal(self)
 
         # if self is zero then exponent should be between Etiny and
         # Emax if _clamp==0, and between Etiny and Etop if _clamp==1.
@@ -1507,7 +1507,7 @@ class Decimal(object):
                 context._raise_error(Clamped)
                 return Decimal((self._sign, (0,), new_exp))
             else:
-                return self
+                return Decimal(self)
 
         # exp_min is the smallest allowable exponent of the result,
         # equal to max(self.adjusted()-context.prec+1, Etiny)
@@ -1551,7 +1551,7 @@ class Decimal(object):
             return Decimal((self._sign, self_padded, Etop))
 
         # here self was representable to begin with; return unchanged
-        return self
+        return Decimal(self)
 
     _pick_rounding_function = {}
 
@@ -1678,10 +1678,10 @@ class Decimal(object):
                 return context._raise_error(InvalidOperation, 'sNaN',
                                         1, modulo)
             if self_is_nan:
-                return self
+                return self._fix_nan(context)
             if other_is_nan:
-                return other
-            return modulo
+                return other._fix_nan(context)
+            return modulo._fix_nan(context)
 
         # check inputs: we apply same restrictions as Python's pow()
         if not (self._isinteger() and
@@ -2179,7 +2179,7 @@ class Decimal(object):
 
             if exp._isinfinity() or self._isinfinity():
                 if exp._isinfinity() and self._isinfinity():
-                    return self  # if both are inf, it is OK
+                    return Decimal(self)  # if both are inf, it is OK
                 return context._raise_error(InvalidOperation,
                                         'quantize with one INF')
 
@@ -2254,7 +2254,7 @@ class Decimal(object):
         rounding = rounding mode
         """
         if self._is_special:
-            return self
+            return Decimal(self)
         if not self:
             return Decimal((self._sign, (0,), exp))
 
@@ -2285,9 +2285,9 @@ class Decimal(object):
             ans = self._check_nans(context=context)
             if ans:
                 return ans
-            return self
+            return Decimal(self)
         if self._exp >= 0:
-            return self
+            return Decimal(self)
         if not self:
             return Decimal((self._sign, (0,), 0))
         if context is None:
@@ -2310,9 +2310,9 @@ class Decimal(object):
             ans = self._check_nans(context=context)
             if ans:
                 return ans
-            return self
+            return Decimal(self)
         if self._exp >= 0:
-            return self
+            return Decimal(self)
         else:
             return self._rescale(0, rounding)
 
@@ -2426,6 +2426,9 @@ class Decimal(object):
         """
         other = _convert_other(other, raiseit=True)
 
+        if context is None:
+            context = getcontext()
+
         if self._is_special or other._is_special:
             # If one operand is a quiet NaN and the other is number, then the
             # number is always returned
@@ -2433,9 +2436,9 @@ class Decimal(object):
             on = other._isnan()
             if sn or on:
                 if on == 1 and sn != 2:
-                    return self
+                    return self._fix_nan(context)
                 if sn == 1 and on != 2:
-                    return other
+                    return other._fix_nan(context)
                 return self._check_nans(other, context)
 
         c = self.__cmp__(other)
@@ -2455,8 +2458,6 @@ class Decimal(object):
         else:
             ans = self
 
-        if context is None:
-            context = getcontext()
         if context._rounding_decision == ALWAYS_ROUND:
             return ans._fix(context)
         return ans
@@ -2469,6 +2470,9 @@ class Decimal(object):
         """
         other = _convert_other(other, raiseit=True)
 
+        if context is None:
+            context = getcontext()
+
         if self._is_special or other._is_special:
             # If one operand is a quiet NaN and the other is number, then the
             # number is always returned
@@ -2476,9 +2480,9 @@ class Decimal(object):
             on = other._isnan()
             if sn or on:
                 if on == 1 and sn != 2:
-                    return self
+                    return self._fix_nan(context)
                 if sn == 1 and on != 2:
-                    return other
+                    return other._fix_nan(context)
                 return self._check_nans(other, context)
 
         c = self.__cmp__(other)
@@ -2490,8 +2494,6 @@ class Decimal(object):
         else:
             ans = other
 
-        if context is None:
-            context = getcontext()
         if context._rounding_decision == ALWAYS_ROUND:
             return ans._fix(context)
         return ans
@@ -3087,6 +3089,9 @@ class Decimal(object):
         """Compares the values numerically with their sign ignored."""
         other = _convert_other(other, raiseit=True)
 
+        if context is None:
+            context = getcontext()
+
         if self._is_special or other._is_special:
             # If one operand is a quiet NaN and the other is number, then the
             # number is always returned
@@ -3094,9 +3099,9 @@ class Decimal(object):
             on = other._isnan()
             if sn or on:
                 if on == 1 and sn != 2:
-                    return self
+                    return self._fix_nan(context)
                 if sn == 1 and on != 2:
-                    return other
+                    return other._fix_nan(context)
                 return self._check_nans(other, context)
 
         c = self.copy_abs().__cmp__(other.copy_abs())
@@ -3108,8 +3113,6 @@ class Decimal(object):
         else:
             ans = self
 
-        if context is None:
-            context = getcontext()
         if context._rounding_decision == ALWAYS_ROUND:
             return ans._fix(context)
         return ans
@@ -3118,6 +3121,9 @@ class Decimal(object):
         """Compares the values numerically with their sign ignored."""
         other = _convert_other(other, raiseit=True)
 
+        if context is None:
+            context = getcontext()
+
         if self._is_special or other._is_special:
             # If one operand is a quiet NaN and the other is number, then the
             # number is always returned
@@ -3125,9 +3131,9 @@ class Decimal(object):
             on = other._isnan()
             if sn or on:
                 if on == 1 and sn != 2:
-                    return self
+                    return self._fix_nan(context)
                 if sn == 1 and on != 2:
-                    return other
+                    return other._fix_nan(context)
                 return self._check_nans(other, context)
 
         c = self.copy_abs().__cmp__(other.copy_abs())
@@ -3139,8 +3145,6 @@ class Decimal(object):
         else:
             ans = other
 
-        if context is None:
-            context = getcontext()
         if context._rounding_decision == ALWAYS_ROUND:
             return ans._fix(context)
         return ans
@@ -3296,7 +3300,7 @@ class Decimal(object):
             return context._raise_error(InvalidOperation)
 
         if self._isinfinity():
-            return self
+            return Decimal(self)
 
         # get values, pad if necessary
         torot = int(other)
@@ -3334,7 +3338,7 @@ class Decimal(object):
             return context._raise_error(InvalidOperation)
 
         if self._isinfinity():
-            return self
+            return Decimal(self)
 
         d = Decimal((self._sign, self._int, self._exp + int(other)))
         d = d._fix(context)
@@ -3355,12 +3359,12 @@ class Decimal(object):
             return context._raise_error(InvalidOperation)
 
         if self._isinfinity():
-            return self
+            return Decimal(self)
 
         # get values, pad if necessary
         torot = int(other)
         if not torot:
-            return self
+            return Decimal(self)
         rotdig = self._int
         topad = context.prec - len(rotdig)
         if topad:
@@ -3751,7 +3755,7 @@ class Context(object):
         >>> ExtendedContext.copy_decimal(Decimal('-1.00'))
         Decimal("-1.00")
         """
-        return a
+        return Decimal(a)
 
     def copy_negate(self, a):
         """Returns a copy of the operand with the sign inverted.
index bc299ec5ecc4ce9af464908c7114508dbb80c094..2777b225c5fbd1e621762deb1b403a0b96208f37 100644 (file)
@@ -1072,6 +1072,21 @@ class DecimalUsabilityTest(unittest.TestCase):
         checkSameDec("to_eng_string")
         checkSameDec("to_integral")
 
+    def test_subclassing(self):
+        # Different behaviours when subclassing Decimal
+
+        class MyDecimal(Decimal):
+            pass
+
+        d1 = MyDecimal(1)
+        d2 = MyDecimal(2)
+        d = d1 + d2
+        self.assertTrue(type(d) is Decimal)
+
+        d = d1.max(d2)
+        self.assertTrue(type(d) is Decimal)
+
+
 class DecimalPythonAPItests(unittest.TestCase):
 
     def test_pickle(self):