]> granicus.if.org Git - python/commitdiff
Issue 1979: Make Decimal comparisons (other than !=, ==) involving NaN
authorMark Dickinson <dickinsm@gmail.com>
Wed, 6 Feb 2008 22:10:50 +0000 (22:10 +0000)
committerMark Dickinson <dickinsm@gmail.com>
Wed, 6 Feb 2008 22:10:50 +0000 (22:10 +0000)
raise InvalidOperation (and return False if InvalidOperation is trapped).

Doc/library/decimal.rst
Lib/decimal.py
Lib/test/test_decimal.py
Misc/NEWS

index 7b8580685522a358eedb360bb8a2636fcba17bfe..ebb18bb2be5408e9fac0a915ea47907ada28e4c3 100644 (file)
@@ -1290,6 +1290,19 @@ A variant is :const:`sNaN` which signals rather than remaining quiet after every
 operation.  This is a useful return value when an invalid result needs to
 interrupt a calculation for special handling.
 
+The behavior of Python's comparison operators can be a little surprising where a
+:const:`NaN` is involved.  A test for equality where one of the operands is a
+quiet or signaling :const:`NaN` always returns :const:`False` (even when doing
+``Decimal('NaN')==Decimal('NaN')``), while a test for inequality always returns
+:const:`True`.  An attempt to compare two Decimals using any of the :const:'<',
+:const:'<=', :const:'>' or :const:'>=' operators will raise the
+:exc:`InvalidOperation` signal if either operand is a :const:`NaN`, and return
+:const:`False` if this signal is trapped.  Note that the General Decimal
+Arithmetic specification does not specify the behavior of direct comparisons;
+these rules for comparisons involving a :const:`NaN` were taken from the IEEE
+754 standard.  To ensure strict standards-compliance, use the :meth:`compare`
+and :meth:`compare-signal` methods instead.
+
 The signed zeros can result from calculations that underflow. They keep the sign
 that would have resulted if the calculation had been carried out to greater
 precision.  Since their magnitude is zero, both positive and negative zeros are
index eea9448f4317eba0e3b0c14b5b880fe73d0c6773..80340396658f116cdb34ad7d91cce51b51369780 100644 (file)
@@ -717,6 +717,39 @@ class Decimal(object):
             return other._fix_nan(context)
         return 0
 
+    def _compare_check_nans(self, other, context):
+        """Version of _check_nans used for the signaling comparisons
+        compare_signal, __le__, __lt__, __ge__, __gt__.
+
+        Signal InvalidOperation if either self or other is a (quiet
+        or signaling) NaN.  Signaling NaNs take precedence over quiet
+        NaNs.
+
+        Return 0 if neither operand is a NaN.
+
+        """
+        if context is None:
+            context = getcontext()
+
+        if self._is_special or other._is_special:
+            if self.is_snan():
+                return context._raise_error(InvalidOperation,
+                                            'comparison involving sNaN',
+                                            self)
+            elif other.is_snan():
+                return context._raise_error(InvalidOperation,
+                                            'comparison involving sNaN',
+                                            other)
+            elif self.is_qnan():
+                return context._raise_error(InvalidOperation,
+                                            'comparison involving NaN',
+                                            self)
+            elif other.is_qnan():
+                return context._raise_error(InvalidOperation,
+                                            'comparison involving NaN',
+                                            other)
+        return 0
+
     def __nonzero__(self):
         """Return True if self is nonzero; otherwise return False.
 
@@ -724,18 +757,13 @@ class Decimal(object):
         """
         return self._is_special or self._int != '0'
 
-    def __cmp__(self, other):
-        other = _convert_other(other)
-        if other is NotImplemented:
-            # Never return NotImplemented
-            return 1
+    def _cmp(self, other):
+        """Compare the two non-NaN decimal instances self and other.
 
-        if self._is_special or other._is_special:
-            # check for nans, without raising on a signaling nan
-            if self._isnan() or other._isnan():
-                return 1  # Comparison involving NaN's always reports self > other
+        Returns -1 if self < other, 0 if self == other and 1
+        if self > other.  This routine is for internal use only."""
 
-            # INF = INF
+        if self._is_special or other._is_special:
             return cmp(self._isinfinity(), other._isinfinity())
 
         # check for zeros;  note that cmp(0, -0) should return 0
@@ -764,15 +792,71 @@ class Decimal(object):
         else: # self_adjusted < other_adjusted
             return -((-1)**self._sign)
 
+    # Note: The Decimal standard doesn't cover rich comparisons for
+    # Decimals.  In particular, the specification is silent on the
+    # subject of what should happen for a comparison involving a NaN.
+    # We take the following approach:
+    #
+    #   == comparisons involving a NaN always return False
+    #   != comparisons involving a NaN always return True
+    #   <, >, <= and >= comparisons involving a (quiet or signaling)
+    #      NaN signal InvalidOperation, and return False if the
+    #      InvalidOperation is trapped.
+    #
+    # This behavior is designed to conform as closely as possible to
+    # that specified by IEEE 754.
+
     def __eq__(self, other):
-        if not isinstance(other, (Decimal, int, long)):
-            return NotImplemented
-        return self.__cmp__(other) == 0
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        if self.is_nan() or other.is_nan():
+            return False
+        return self._cmp(other) == 0
 
     def __ne__(self, other):
-        if not isinstance(other, (Decimal, int, long)):
-            return NotImplemented
-        return self.__cmp__(other) != 0
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        if self.is_nan() or other.is_nan():
+            return True
+        return self._cmp(other) != 0
+
+    def __lt__(self, other, context=None):
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        ans = self._compare_check_nans(other, context)
+        if ans:
+            return False
+        return self._cmp(other) < 0
+
+    def __le__(self, other, context=None):
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        ans = self._compare_check_nans(other, context)
+        if ans:
+            return False
+        return self._cmp(other) <= 0
+
+    def __gt__(self, other, context=None):
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        ans = self._compare_check_nans(other, context)
+        if ans:
+            return False
+        return self._cmp(other) > 0
+
+    def __ge__(self, other, context=None):
+        other = _convert_other(other)
+        if other is NotImplemented:
+            return other
+        ans = self._compare_check_nans(other, context)
+        if ans:
+            return False
+        return self._cmp(other) >= 0
 
     def compare(self, other, context=None):
         """Compares one to another.
@@ -791,7 +875,7 @@ class Decimal(object):
             if ans:
                 return ans
 
-        return Decimal(self.__cmp__(other))
+        return Decimal(self._cmp(other))
 
     def __hash__(self):
         """x.__hash__() <==> hash(x)"""
@@ -2452,7 +2536,7 @@ class Decimal(object):
                     return other._fix_nan(context)
                 return self._check_nans(other, context)
 
-        c = self.__cmp__(other)
+        c = self._cmp(other)
         if c == 0:
             # If both operands are finite and equal in numerical value
             # then an ordering is applied:
@@ -2494,7 +2578,7 @@ class Decimal(object):
                     return other._fix_nan(context)
                 return self._check_nans(other, context)
 
-        c = self.__cmp__(other)
+        c = self._cmp(other)
         if c == 0:
             c = self.compare_total(other)
 
@@ -2542,23 +2626,10 @@ class Decimal(object):
         It's pretty much like compare(), but all NaNs signal, with signaling
         NaNs taking precedence over quiet NaNs.
         """
-        if context is None:
-            context = getcontext()
-
-        self_is_nan = self._isnan()
-        other_is_nan = other._isnan()
-        if self_is_nan == 2:
-            return context._raise_error(InvalidOperation, 'sNaN',
-                                        self)
-        if other_is_nan == 2:
-            return context._raise_error(InvalidOperation, 'sNaN',
-                                        other)
-        if self_is_nan:
-            return context._raise_error(InvalidOperation, 'NaN in compare_signal',
-                                        self)
-        if other_is_nan:
-            return context._raise_error(InvalidOperation, 'NaN in compare_signal',
-                                        other)
+        other = _convert_other(other, raiseit = True)
+        ans = self._compare_check_nans(other, context)
+        if ans:
+            return ans
         return self.compare(other, context=context)
 
     def compare_total(self, other):
@@ -3065,7 +3136,7 @@ class Decimal(object):
                     return other._fix_nan(context)
                 return self._check_nans(other, context)
 
-        c = self.copy_abs().__cmp__(other.copy_abs())
+        c = self.copy_abs()._cmp(other.copy_abs())
         if c == 0:
             c = self.compare_total(other)
 
@@ -3095,7 +3166,7 @@ class Decimal(object):
                     return other._fix_nan(context)
                 return self._check_nans(other, context)
 
-        c = self.copy_abs().__cmp__(other.copy_abs())
+        c = self.copy_abs()._cmp(other.copy_abs())
         if c == 0:
             c = self.compare_total(other)
 
@@ -3170,7 +3241,7 @@ class Decimal(object):
         if ans:
             return ans
 
-        comparison = self.__cmp__(other)
+        comparison = self._cmp(other)
         if comparison == 0:
             return self.copy_sign(other)
 
index 0c8852c4001f6d1bf2b38d891fc4c9c1fa4e45ab..3b15df76c6f918c9f9ccf32c61dc536a8cefeefc 100644 (file)
@@ -838,6 +838,19 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
         self.assertEqual(-Decimal(45), Decimal(-45))           #  -
         self.assertEqual(abs(Decimal(45)), abs(Decimal(-45)))  # abs
 
+    def test_nan_comparisons(self):
+        n = Decimal('NaN')
+        s = Decimal('sNaN')
+        i = Decimal('Inf')
+        f = Decimal('2')
+        for x, y in [(n, n), (n, i), (i, n), (n, f), (f, n),
+                     (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)]:
+            self.assert_(x != y)
+            self.assert_(not (x == y))
+            self.assert_(not (x < y))
+            self.assert_(not (x <= y))
+            self.assert_(not (x > y))
+            self.assert_(not (x >= y))
 
 # The following are two functions used to test threading in the next class
 
@@ -1147,7 +1160,12 @@ class DecimalUsabilityTest(unittest.TestCase):
         checkSameDec("__add__", True)
         checkSameDec("__div__", True)
         checkSameDec("__divmod__", True)
-        checkSameDec("__cmp__", True)
+        checkSameDec("__eq__", True)
+        checkSameDec("__ne__", True)
+        checkSameDec("__le__", True)
+        checkSameDec("__lt__", True)
+        checkSameDec("__ge__", True)
+        checkSameDec("__gt__", True)
         checkSameDec("__float__")
         checkSameDec("__floordiv__", True)
         checkSameDec("__hash__")
index 0c5cdf2523ecde8696779ecea584950b12386d5c..172678e4b91dc56b8eb050bdeef3ea4ce2add50d 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -391,6 +391,9 @@ Core and builtins
 Library
 -------
 
+- #1979: Add rich comparisons to Decimal, and make Decimal comparisons
+  involving a NaN follow the IEEE 754 standard.
+
 - #2004: tarfile.py: Use mode 0700 for temporary directories and default
   permissions for missing directories.