]> granicus.if.org Git - python/commitdiff
bpo-32227: functools.singledispatch supports registering via type annotations (#4733)
authorŁukasz Langa <lukasz@langa.pl>
Mon, 11 Dec 2017 21:56:31 +0000 (13:56 -0800)
committerGitHub <noreply@github.com>
Mon, 11 Dec 2017 21:56:31 +0000 (13:56 -0800)
Doc/library/functools.rst
Lib/functools.py
Lib/test/test_functools.py
Misc/NEWS.d/next/Library/2017-12-05-13-25-15.bpo-32227.3vnWFS.rst [new file with mode: 0644]

index 28062c11890ed10304a4ede90a0a1452180dda11..a81e819103adf3f59e01373b50958963a1314d9a 100644 (file)
@@ -281,23 +281,34 @@ The :mod:`functools` module defines the following functions:
      ...     print(arg)
 
    To add overloaded implementations to the function, use the :func:`register`
-   attribute of the generic function.  It is a decorator, taking a type
-   parameter and decorating a function implementing the operation for that
-   type::
+   attribute of the generic function.  It is a decorator.  For functions
+   annotated with types, the decorator will infer the type of the first
+   argument automatically::
 
-     >>> @fun.register(int)
-     ... def _(arg, verbose=False):
+     >>> @fun.register
+     ... def _(arg: int, verbose=False):
      ...     if verbose:
      ...         print("Strength in numbers, eh?", end=" ")
      ...     print(arg)
      ...
-     >>> @fun.register(list)
-     ... def _(arg, verbose=False):
+     >>> @fun.register
+     ... def _(arg: list, verbose=False):
      ...     if verbose:
      ...         print("Enumerate this:")
      ...     for i, elem in enumerate(arg):
      ...         print(i, elem)
 
+   For code which doesn't use type annotations, the appropriate type
+   argument can be passed explicitly to the decorator itself::
+
+     >>> @fun.register(complex)
+     ... def _(arg, verbose=False):
+     ...     if verbose:
+     ...         print("Better than complicated.", end=" ")
+     ...     print(arg.real, arg.imag)
+     ...
+
+
    To enable registering lambdas and pre-existing functions, the
    :func:`register` attribute can be used in a functional form::
 
@@ -368,6 +379,9 @@ The :mod:`functools` module defines the following functions:
 
    .. versionadded:: 3.4
 
+   .. versionchanged:: 3.7
+      The :func:`register` attribute supports using type annotations.
+
 
 .. function:: update_wrapper(wrapper, wrapped, assigned=WRAPPER_ASSIGNMENTS, updated=WRAPPER_UPDATES)
 
index a51dddf878584b89fa28ee6f0a85688d010a7217..c8b79c2a7c2bd19e656d749bd2f7b68341cb91bc 100644 (file)
@@ -793,7 +793,23 @@ def singledispatch(func):
         """
         nonlocal cache_token
         if func is None:
-            return lambda f: register(cls, f)
+            if isinstance(cls, type):
+                return lambda f: register(cls, f)
+            ann = getattr(cls, '__annotations__', {})
+            if not ann:
+                raise TypeError(
+                    f"Invalid first argument to `register()`: {cls!r}. "
+                    f"Use either `@register(some_class)` or plain `@register` "
+                    f"on an annotated function."
+                )
+            func = cls
+
+            # only import typing if annotation parsing is necessary
+            from typing import get_type_hints
+            argname, cls = next(iter(get_type_hints(func).items()))
+            assert isinstance(cls, type), (
+                f"Invalid annotation for {argname!r}. {cls!r} is not a class."
+            )
         registry[cls] = func
         if cache_token is None and hasattr(cls, '__abstractmethods__'):
             cache_token = get_cache_token()
index 68e94e7ae17430a580c975ed0c319a700c4375d1..35ec2e2f481f815941942ab76ad088dff86da7ea 100644 (file)
@@ -10,6 +10,7 @@ import sys
 from test import support
 import threading
 import time
+import typing
 import unittest
 import unittest.mock
 from weakref import proxy
@@ -2119,6 +2120,73 @@ class TestSingleDispatch(unittest.TestCase):
             g._clear_cache()
             self.assertEqual(len(td), 0)
 
+    def test_annotations(self):
+        @functools.singledispatch
+        def i(arg):
+            return "base"
+        @i.register
+        def _(arg: collections.abc.Mapping):
+            return "mapping"
+        @i.register
+        def _(arg: "collections.abc.Sequence"):
+            return "sequence"
+        self.assertEqual(i(None), "base")
+        self.assertEqual(i({"a": 1}), "mapping")
+        self.assertEqual(i([1, 2, 3]), "sequence")
+        self.assertEqual(i((1, 2, 3)), "sequence")
+        self.assertEqual(i("str"), "sequence")
+
+        # Registering classes as callables doesn't work with annotations,
+        # you need to pass the type explicitly.
+        @i.register(str)
+        class _:
+            def __init__(self, arg):
+                self.arg = arg
+
+            def __eq__(self, other):
+                return self.arg == other
+        self.assertEqual(i("str"), "str")
+
+    def test_invalid_registrations(self):
+        msg_prefix = "Invalid first argument to `register()`: "
+        msg_suffix = (
+            ". Use either `@register(some_class)` or plain `@register` on an "
+            "annotated function."
+        )
+        @functools.singledispatch
+        def i(arg):
+            return "base"
+        with self.assertRaises(TypeError) as exc:
+            @i.register(42)
+            def _(arg):
+                return "I annotated with a non-type"
+        self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
+        self.assertTrue(str(exc.exception).endswith(msg_suffix))
+        with self.assertRaises(TypeError) as exc:
+            @i.register
+            def _(arg):
+                return "I forgot to annotate"
+        self.assertTrue(str(exc.exception).startswith(msg_prefix +
+            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
+        ))
+        self.assertTrue(str(exc.exception).endswith(msg_suffix))
+
+        # FIXME: The following will only work after PEP 560 is implemented.
+        return
+
+        with self.assertRaises(TypeError) as exc:
+            @i.register
+            def _(arg: typing.Iterable[str]):
+                # At runtime, dispatching on generics is impossible.
+                # When registering implementations with singledispatch, avoid
+                # types from `typing`. Instead, annotate with regular types
+                # or ABCs.
+                return "I annotated with a generic collection"
+        self.assertTrue(str(exc.exception).startswith(msg_prefix +
+            "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
+        ))
+        self.assertTrue(str(exc.exception).endswith(msg_suffix))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2017-12-05-13-25-15.bpo-32227.3vnWFS.rst b/Misc/NEWS.d/next/Library/2017-12-05-13-25-15.bpo-32227.3vnWFS.rst
new file mode 100644 (file)
index 0000000..4dbc7ba
--- /dev/null
@@ -0,0 +1,2 @@
+``functools.singledispatch`` now supports registering implementations using
+type annotations.