]> granicus.if.org Git - python/commitdiff
issue23673
authorEthan Furman <ethan@stoneleaf.us>
Thu, 19 Mar 2015 00:27:57 +0000 (17:27 -0700)
committerEthan Furman <ethan@stoneleaf.us>
Thu, 19 Mar 2015 00:27:57 +0000 (17:27 -0700)
add private method to enum to support replacing global constants with Enum members:
- search for candidate constants via supplied filter
- create new enum class and members
- insert enum class and replace constants with members via supplied module name
- replace __reduce_ex__ with function that returns member name, so previous Python versions can unpickle
modify IntEnum classes to use new method

Lib/enum.py
Lib/signal.py
Lib/socket.py
Lib/ssl.py
Lib/test/test_enum.py
Lib/test/test_socket.py

index 5be13d5c71ab1e7e22aea75c42dc0dade0d0e27d..1d9ebf0f94136c41846fef66b7a94d30ffe62387 100644 (file)
@@ -519,11 +519,37 @@ class Enum(metaclass=EnumMeta):
         """The value of the Enum member."""
         return self._value_
 
+    @classmethod
+    def _convert(cls, name, module, filter, source=None):
+        """
+        Create a new Enum subclass that replaces a collection of global constants
+        """
+        # convert all constants from source (or module) that pass filter() to
+        # a new Enum called name, and export the enum and its members back to
+        # module;
+        # also, replace the __reduce_ex__ method so unpickling works in
+        # previous Python versions
+        module_globals = vars(sys.modules[module])
+        if source:
+            source = vars(source)
+        else:
+            source = module_globals
+        members = {name: value for name, value in source.items()
+                if filter(name)}
+        cls = cls(name, members, module=module)
+        cls.__reduce_ex__ = _reduce_ex_by_name
+        module_globals.update(cls.__members__)
+        module_globals[name] = cls
+        return cls
+
 
 class IntEnum(int, Enum):
     """Enum where members are also (and must be) ints"""
 
 
+def _reduce_ex_by_name(self, proto):
+    return self.name
+
 def unique(enumeration):
     """Class decorator for enumerations ensuring unique member values."""
     duplicates = []
index 0db3df8bd6188b2ecdb0e3a6166f6a27caee86a0..371d7128f85183f2fe5b6927f332b8b9c55a58a1 100644 (file)
@@ -5,27 +5,21 @@ from enum import IntEnum as _IntEnum
 
 _globals = globals()
 
-Signals = _IntEnum(
-    'Signals',
-    {name: value for name, value in _globals.items()
-     if name.isupper()
-        and (name.startswith('SIG') and not name.startswith('SIG_'))
-        or name.startswith('CTRL_')})
+_IntEnum._convert(
+        'Signals', __name__,
+        lambda name:
+            name.isupper()
+            and (name.startswith('SIG') and not name.startswith('SIG_'))
+            or name.startswith('CTRL_'))
 
-class Handlers(_IntEnum):
-    SIG_DFL = _signal.SIG_DFL
-    SIG_IGN = _signal.SIG_IGN
-
-_globals.update(Signals.__members__)
-_globals.update(Handlers.__members__)
+_IntEnum._convert(
+        'Handlers', __name__,
+        lambda name: name in ('SIG_DFL', 'SIG_IGN'))
 
 if 'pthread_sigmask' in _globals:
-    class Sigmasks(_IntEnum):
-        SIG_BLOCK = _signal.SIG_BLOCK
-        SIG_UNBLOCK = _signal.SIG_UNBLOCK
-        SIG_SETMASK = _signal.SIG_SETMASK
-
-    _globals.update(Sigmasks.__members__)
+    _IntEnum._convert(
+            'Sigmasks', __name__,
+            lambda name: name in ('SIG_BLOCK', 'SIG_UNBLOCK', 'SIG_SETMASK'))
 
 
 def _int_to_enum(value, enum_klass):
index 9c39c69eaced11d445333289956e1518bc6094f4..db34ab37ee676cae8479260fb377379c249ccf0d 100644 (file)
@@ -69,16 +69,16 @@ __all__.extend(os._get_exports_list(_socket))
 # Note that _socket only knows about the integer values. The public interface
 # in this module understands the enums and translates them back from integers
 # where needed (e.g. .family property of a socket object).
-AddressFamily = IntEnum('AddressFamily',
-                        {name: value for name, value in globals().items()
-                         if name.isupper() and name.startswith('AF_')})
-globals().update(AddressFamily.__members__)
 
-SocketKind = IntEnum('SocketKind',
-                     {name: value for name, value in globals().items()
-                      if name.isupper() and name.startswith('SOCK_')})
-globals().update(SocketKind.__members__)
+IntEnum._convert(
+        'AddressFamily',
+        __name__,
+        lambda C: C.isupper() and C.startswith('AF_'))
 
+IntEnum._convert(
+        'SocketKind',
+        __name__,
+        lambda C: C.isupper() and C.startswith('SOCK_'))
 
 _LOCALHOST    = '127.0.0.1'
 _LOCALHOST_V6 = '::1'
index 18730cb2e989de2623a29dd317e8e5f1861c37d9..ab7a49b5763f85b1e28f47da56016fbe4d01eebd 100644 (file)
@@ -126,10 +126,10 @@ from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
 
 from _ssl import _OPENSSL_API_VERSION
 
-_SSLMethod = _IntEnum('_SSLMethod',
-                      {name: value for name, value in vars(_ssl).items()
-                       if name.startswith('PROTOCOL_')})
-globals().update(_SSLMethod.__members__)
+_IntEnum._convert(
+        '_SSLMethod', __name__,
+        lambda name: name.startswith('PROTOCOL_'),
+        source=_ssl)
 
 _PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
 
index 7d172c86e4f5a02c9478b105392837f7c5cfeea1..51d9deb2e69473fc2250230e66b04d89d25cede8 100644 (file)
@@ -581,6 +581,14 @@ class TestEnum(unittest.TestCase):
         test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs,
                 protocol=(4, HIGHEST_PROTOCOL))
 
+    def test_pickle_by_name(self):
+        class ReplaceGlobalInt(IntEnum):
+            ONE = 1
+            TWO = 2
+        ReplaceGlobalInt.__reduce_ex__ = enum._reduce_ex_by_name
+        for proto in range(HIGHEST_PROTOCOL):
+            self.assertEqual(ReplaceGlobalInt.TWO.__reduce_ex__(proto), 'TWO')
+
     def test_exploding_pickle(self):
         BadPickle = Enum(
                 'BadPickle', 'dill sweet bread-n-butter', module=__name__)
index cf45b7346ca154d25b7dfd2120ca1ab2c683d836..d43e56d45608c0ea6588fc7297092b537a895171 100644 (file)
@@ -1377,6 +1377,11 @@ class GeneralModuleTests(unittest.TestCase):
         with sock:
             for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
                 self.assertRaises(TypeError, pickle.dumps, sock, protocol)
+        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
+            family = pickle.loads(pickle.dumps(socket.AF_INET, protocol))
+            self.assertEqual(family, socket.AF_INET)
+            type = pickle.loads(pickle.dumps(socket.SOCK_STREAM, protocol))
+            self.assertEqual(type, socket.SOCK_STREAM)
 
     def test_listen_backlog(self):
         for backlog in 0, -1: