]> granicus.if.org Git - python/commitdiff
Close #18545: now only executes member_type if no _value_ is assigned in __new__.
authorEthan Furman <ethan@stoneleaf.us>
Thu, 25 Jul 2013 20:50:45 +0000 (13:50 -0700)
committerEthan Furman <ethan@stoneleaf.us>
Thu, 25 Jul 2013 20:50:45 +0000 (13:50 -0700)
Lib/enum.py
Lib/test/test_enum.py

index 0def138e3e4fedfdfdcbf87d2faf67bde47114cb..33af04262e84c22d5af02a5a6f4e22c30460f597 100644 (file)
@@ -152,12 +152,12 @@ class EnumMeta(type):
                 args = (args, )     # wrap it one more time
             if not use_args:
                 enum_member = __new__(enum_class)
-                original_value = value
+                if not hasattr(enum_member, '_value_'):
+                    enum_member._value_ = value
             else:
                 enum_member = __new__(enum_class, *args)
-                original_value = member_type(*args)
-            if not hasattr(enum_member, '_value_'):
-                enum_member._value_ = original_value
+                if not hasattr(enum_member, '_value_'):
+                    enum_member._value_ = member_type(*args)
             value = enum_member._value_
             enum_member._member_type_ = member_type
             enum_member._name_ = member_name
index d0b4a1c089247e3a97c9a40cb2a6b06438d0cfa3..91c4b69ef5ed48e130e4000fb3ba3cea12e589fd 100644 (file)
@@ -934,6 +934,22 @@ class TestEnum(unittest.TestCase):
         self.assertEqual(ColorInAList.red.value, [1])
         self.assertEqual(ColorInAList([1]), ColorInAList.red)
 
+    def test_conflicting_types_resolved_in_new(self):
+        class LabelledIntEnum(int, Enum):
+            def __new__(cls, *args):
+                value, label = args
+                obj = int.__new__(cls, value)
+                obj.label = label
+                obj._value_ = value
+                return obj
+
+        class LabelledList(LabelledIntEnum):
+            unprocessed = (1, "Unprocessed")
+            payment_complete = (2, "Payment Complete")
+
+        self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete])
+        self.assertEqual(LabelledList.unprocessed, 1)
+        self.assertEqual(LabelledList(1), LabelledList.unprocessed)
 
 
 class TestUnique(unittest.TestCase):