]> granicus.if.org Git - python/commitdiff
Issue #18244: Adopt C3-based linearization in functools.singledispatch for improved...
authorŁukasz Langa <lukasz@langa.pl>
Mon, 1 Jul 2013 14:00:38 +0000 (16:00 +0200)
committerŁukasz Langa <lukasz@langa.pl>
Mon, 1 Jul 2013 14:00:38 +0000 (16:00 +0200)
Lib/functools.py
Lib/test/test_functools.py
Misc/ACKS

index 9403e8e84320b646171672d2b084bd6babe58377..95c1a414e5040eab7e853d38c9f2bcab45acb1e8 100644 (file)
@@ -365,46 +365,138 @@ def lru_cache(maxsize=128, typed=False):
 ### singledispatch() - single-dispatch generic function decorator
 ################################################################################
 
-def _compose_mro(cls, haystack):
-    """Calculates the MRO for a given class `cls`, including relevant abstract
-    base classes from `haystack`.
+def _c3_merge(sequences):
+    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
+
+    Adapted from http://www.python.org/download/releases/2.3/mro/.
 
     """
-    bases = set(cls.__mro__)
-    mro = list(cls.__mro__)
-    for needle in haystack:
-        if (needle in bases or not hasattr(needle, '__mro__')
-                            or not issubclass(cls, needle)):
-            continue   # either present in the __mro__ already or unrelated
-        for index, base in enumerate(mro):
-            if not issubclass(base, needle):
+    result = []
+    while True:
+        sequences = [s for s in sequences if s]   # purge empty sequences
+        if not sequences:
+            return result
+        for s1 in sequences:   # find merge candidates among seq heads
+            candidate = s1[0]
+            for s2 in sequences:
+                if candidate in s2[1:]:
+                    candidate = None
+                    break      # reject the current head, it appears later
+            else:
                 break
-        if base in bases and not issubclass(needle, base):
-            # Conflict resolution: put classes present in __mro__ and their
-            # subclasses first. See test_mro_conflicts() in test_functools.py
-            # for examples.
-            index += 1
-        mro.insert(index, needle)
-    return mro
+        if not candidate:
+            raise RuntimeError("Inconsistent hierarchy")
+        result.append(candidate)
+        # remove the chosen candidate
+        for seq in sequences:
+            if seq[0] == candidate:
+                del seq[0]
+
+def _c3_mro(cls, abcs=None):
+    """Computes the method resolution order using extended C3 linearization.
+
+    If no *abcs* are given, the algorithm works exactly like the built-in C3
+    linearization used for method resolution.
+
+    If given, *abcs* is a list of abstract base classes that should be inserted
+    into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
+    result. The algorithm inserts ABCs where their functionality is introduced,
+    i.e. issubclass(cls, abc) returns True for the class itself but returns
+    False for all its direct base classes. Implicit ABCs for a given class
+    (either registered or inferred from the presence of a special method like
+    __len__) are inserted directly after the last ABC explicitly listed in the
+    MRO of said class. If two implicit ABCs end up next to each other in the
+    resulting MRO, their ordering depends on the order of types in *abcs*.
+
+    """
+    for i, base in enumerate(reversed(cls.__bases__)):
+        if hasattr(base, '__abstractmethods__'):
+            boundary = len(cls.__bases__) - i
+            break   # Bases up to the last explicit ABC are considered first.
+    else:
+        boundary = 0
+    abcs = list(abcs) if abcs else []
+    explicit_bases = list(cls.__bases__[:boundary])
+    abstract_bases = []
+    other_bases = list(cls.__bases__[boundary:])
+    for base in abcs:
+        if issubclass(cls, base) and not any(
+                issubclass(b, base) for b in cls.__bases__
+            ):
+            # If *cls* is the class that introduces behaviour described by
+            # an ABC *base*, insert said ABC to its MRO.
+            abstract_bases.append(base)
+    for base in abstract_bases:
+        abcs.remove(base)
+    explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
+    abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
+    other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
+    return _c3_merge(
+        [[cls]] +
+        explicit_c3_mros + abstract_c3_mros + other_c3_mros +
+        [explicit_bases] + [abstract_bases] + [other_bases]
+    )
+
+def _compose_mro(cls, types):
+    """Calculates the method resolution order for a given class *cls*.
+
+    Includes relevant abstract base classes (with their respective bases) from
+    the *types* iterable. Uses a modified C3 linearization algorithm.
+
+    """
+    bases = set(cls.__mro__)
+    # Remove entries which are already present in the __mro__ or unrelated.
+    def is_related(typ):
+        return (typ not in bases and hasattr(typ, '__mro__')
+                                 and issubclass(cls, typ))
+    types = [n for n in types if is_related(n)]
+    # Remove entries which are strict bases of other entries (they will end up
+    # in the MRO anyway.
+    def is_strict_base(typ):
+        for other in types:
+            if typ != other and typ in other.__mro__:
+                return True
+        return False
+    types = [n for n in types if not is_strict_base(n)]
+    # Subclasses of the ABCs in *types* which are also implemented by
+    # *cls* can be used to stabilize ABC ordering.
+    type_set = set(types)
+    mro = []
+    for typ in types:
+        found = []
+        for sub in typ.__subclasses__():
+            if sub not in bases and issubclass(cls, sub):
+                found.append([s for s in sub.__mro__ if s in type_set])
+        if not found:
+            mro.append(typ)
+            continue
+        # Favor subclasses with the biggest number of useful bases
+        found.sort(key=len, reverse=True)
+        for sub in found:
+            for subcls in sub:
+                if subcls not in mro:
+                    mro.append(subcls)
+    return _c3_mro(cls, abcs=mro)
 
 def _find_impl(cls, registry):
-    """Returns the best matching implementation for the given class `cls` in
-    `registry`. Where there is no registered implementation for a specific
-    type, its method resolution order is used to find a more generic
-    implementation.
+    """Returns the best matching implementation from *registry* for type *cls*.
+
+    Where there is no registered implementation for a specific type, its method
+    resolution order is used to find a more generic implementation.
 
-    Note: if `registry` does not contain an implementation for the base
-    `object` type, this function may return None.
+    Note: if *registry* does not contain an implementation for the base
+    *object* type, this function may return None.
 
     """
     mro = _compose_mro(cls, registry.keys())
     match = None
     for t in mro:
         if match is not None:
-            # If `match` is an ABC but there is another unrelated, equally
-            # matching ABC. Refuse the temptation to guess.
-            if (t in registry and not issubclass(match, t)
-                              and match not in cls.__mro__):
+            # If *match* is an implicit ABC but there is another unrelated,
+            # equally matching implicit ABC, refuse the temptation to guess.
+            if (t in registry and t not in cls.__mro__
+                              and match not in cls.__mro__
+                              and not issubclass(match, t)):
                 raise RuntimeError("Ambiguous dispatch: {} or {}".format(
                     match, t))
             break
@@ -418,19 +510,19 @@ def singledispatch(func):
     Transforms a function into a generic function, which can have different
     behaviours depending upon the type of its first argument. The decorated
     function acts as the default implementation, and additional
-    implementations can be registered using the 'register()' attribute of
-    the generic function.
+    implementations can be registered using the register() attribute of the
+    generic function.
 
     """
     registry = {}
     dispatch_cache = WeakKeyDictionary()
     cache_token = None
 
-    def dispatch(typ):
-        """generic_func.dispatch(type) -> <function implementation>
+    def dispatch(cls):
+        """generic_func.dispatch(cls) -> <function implementation>
 
         Runs the dispatch algorithm to return the best available implementation
-        for the given `type` registered on `generic_func`.
+        for the given *cls* registered on *generic_func*.
 
         """
         nonlocal cache_token
@@ -440,26 +532,26 @@ def singledispatch(func):
                 dispatch_cache.clear()
                 cache_token = current_token
         try:
-            impl = dispatch_cache[typ]
+            impl = dispatch_cache[cls]
         except KeyError:
             try:
-                impl = registry[typ]
+                impl = registry[cls]
             except KeyError:
-                impl = _find_impl(typ, registry)
-            dispatch_cache[typ] = impl
+                impl = _find_impl(cls, registry)
+            dispatch_cache[cls] = impl
         return impl
 
-    def register(typ, func=None):
-        """generic_func.register(type, func) -> func
+    def register(cls, func=None):
+        """generic_func.register(cls, func) -> func
 
-        Registers a new implementation for the given `type` on a `generic_func`.
+        Registers a new implementation for the given *cls* on a *generic_func*.
 
         """
         nonlocal cache_token
         if func is None:
-            return lambda f: register(typ, f)
-        registry[typ] = func
-        if cache_token is None and hasattr(typ, '__abstractmethods__'):
+            return lambda f: register(cls, f)
+        registry[cls] = func
+        if cache_token is None and hasattr(cls, '__abstractmethods__'):
             cache_token = get_cache_token()
         dispatch_cache.clear()
         return func
index 49c807d059241a47539e7c3fc3d70c165f9d3bea..99dccb096a14044181a123e59068dfcc92e88bfe 100644 (file)
@@ -929,22 +929,55 @@ class TestSingleDispatch(unittest.TestCase):
         self.assertEqual(g(rnd), ("Number got rounded",))
 
     def test_compose_mro(self):
+        # None of the examples in this test depend on haystack ordering.
         c = collections
         mro = functools._compose_mro
         bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
         for haystack in permutations(bases):
             m = mro(dict, haystack)
-            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object])
+            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
+                                 c.Iterable, c.Container, object])
         bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
         for haystack in permutations(bases):
             m = mro(c.ChainMap, haystack)
             self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                  c.Sized, c.Iterable, c.Container, object])
-        # Note: The MRO order below depends on haystack ordering.
-        m = mro(c.defaultdict, [c.Sized, c.Container, str])
-        self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object])
-        m = mro(c.defaultdict, [c.Container, c.Sized, str])
-        self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object])
+
+        # If there's a generic function with implementations registered for
+        # both Sized and Container, passing a defaultdict to it results in an
+        # ambiguous dispatch which will cause a RuntimeError (see
+        # test_mro_conflicts).
+        bases = [c.Container, c.Sized, str]
+        for haystack in permutations(bases):
+            m = mro(c.defaultdict, [c.Sized, c.Container, str])
+            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
+                                 object])
+
+        # MutableSequence below is registered directly on D. In other words, it
+        # preceeds MutableMapping which means single dispatch will always
+        # choose MutableSequence here.
+        class D(c.defaultdict):
+            pass
+        c.MutableSequence.register(D)
+        bases = [c.MutableSequence, c.MutableMapping]
+        for haystack in permutations(bases):
+            m = mro(D, bases)
+            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
+                                 c.defaultdict, dict, c.MutableMapping,
+                                 c.Mapping, c.Sized, c.Iterable, c.Container,
+                                 object])
+
+        # Container and Callable are registered on different base classes and
+        # a generic function supporting both should always pick the Callable
+        # implementation if a C instance is passed.
+        class C(c.defaultdict):
+            def __call__(self):
+                pass
+        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
+        for haystack in permutations(bases):
+            m = mro(C, haystack)
+            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
+                                 c.Sized, c.Iterable, c.Container, object])
 
     def test_register_abc(self):
         c = collections
@@ -1040,17 +1073,37 @@ class TestSingleDispatch(unittest.TestCase):
         self.assertEqual(g(f), "frozen-set")
         self.assertEqual(g(t), "tuple")
 
-    def test_mro_conflicts(self):
+    def test_c3_abc(self):
         c = collections
+        mro = functools._c3_mro
+        class A(object):
+            pass
+        class B(A):
+            def __len__(self):
+                return 0   # implies Sized
+        @c.Container.register
+        class C(object):
+            pass
+        class D(object):
+            pass   # unrelated
+        class X(D, C, B):
+            def __call__(self):
+                pass   # implies Callable
+        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
+        for abcs in permutations([c.Sized, c.Callable, c.Container]):
+            self.assertEqual(mro(X, abcs=abcs), expected)
+        # unrelated ABCs don't appear in the resulting MRO
+        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
+        self.assertEqual(mro(X, abcs=many_abcs), expected)
 
+    def test_mro_conflicts(self):
+        c = collections
         @functools.singledispatch
         def g(arg):
             return "base"
-
         class O(c.Sized):
             def __len__(self):
                 return 0
-
         o = O()
         self.assertEqual(g(o), "base")
         g.register(c.Iterable, lambda arg: "iterable")
@@ -1062,35 +1115,114 @@ class TestSingleDispatch(unittest.TestCase):
         self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
         c.Container.register(O)
         self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
-
+        c.Set.register(O)
+        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
+                                          # c.Sized and c.Container
         class P:
             pass
-
         p = P()
         self.assertEqual(g(p), "base")
         c.Iterable.register(P)
         self.assertEqual(g(p), "iterable")
         c.Container.register(P)
-        with self.assertRaises(RuntimeError) as re:
+        with self.assertRaises(RuntimeError) as re_one:
             g(p)
-            self.assertEqual(
-                str(re),
-                ("Ambiguous dispatch: <class 'collections.abc.Container'> "
-                    "or <class 'collections.abc.Iterable'>"),
-            )
-
+        self.assertIn(
+            str(re_one.exception),
+            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+              "or <class 'collections.abc.Iterable'>"),
+             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
+              "or <class 'collections.abc.Container'>")),
+        )
         class Q(c.Sized):
             def __len__(self):
                 return 0
-
         q = Q()
         self.assertEqual(g(q), "sized")
         c.Iterable.register(Q)
         self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
         c.Set.register(Q)
         self.assertEqual(g(q), "set")     # because c.Set is a subclass of
-                                          # c.Sized which is explicitly in
-                                          # __mro__
+                                          # c.Sized and c.Iterable
+        @functools.singledispatch
+        def h(arg):
+            return "base"
+        @h.register(c.Sized)
+        def _(arg):
+            return "sized"
+        @h.register(c.Container)
+        def _(arg):
+            return "container"
+        # Even though Sized and Container are explicit bases of MutableMapping,
+        # this ABC is implicitly registered on defaultdict which makes all of
+        # MutableMapping's bases implicit as well from defaultdict's
+        # perspective.
+        with self.assertRaises(RuntimeError) as re_two:
+            h(c.defaultdict(lambda: 0))
+        self.assertIn(
+            str(re_two.exception),
+            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+              "or <class 'collections.abc.Sized'>"),
+             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+              "or <class 'collections.abc.Container'>")),
+        )
+        class R(c.defaultdict):
+            pass
+        c.MutableSequence.register(R)
+        @functools.singledispatch
+        def i(arg):
+            return "base"
+        @i.register(c.MutableMapping)
+        def _(arg):
+            return "mapping"
+        @i.register(c.MutableSequence)
+        def _(arg):
+            return "sequence"
+        r = R()
+        self.assertEqual(i(r), "sequence")
+        class S:
+            pass
+        class T(S, c.Sized):
+            def __len__(self):
+                return 0
+        t = T()
+        self.assertEqual(h(t), "sized")
+        c.Container.register(T)
+        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
+        class U:
+            def __len__(self):
+                return 0
+        u = U()
+        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
+                                          # from the existence of __len__()
+        c.Container.register(U)
+        # There is no preference for registered versus inferred ABCs.
+        with self.assertRaises(RuntimeError) as re_three:
+            h(u)
+        self.assertIn(
+            str(re_three.exception),
+            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
+              "or <class 'collections.abc.Sized'>"),
+             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
+              "or <class 'collections.abc.Container'>")),
+        )
+        class V(c.Sized, S):
+            def __len__(self):
+                return 0
+        @functools.singledispatch
+        def j(arg):
+            return "base"
+        @j.register(S)
+        def _(arg):
+            return "s"
+        @j.register(c.Container)
+        def _(arg):
+            return "container"
+        v = V()
+        self.assertEqual(j(v), "s")
+        c.Container.register(V)
+        self.assertEqual(j(v), "container")   # because it ends up right after
+                                              # Sized in the MRO
 
     def test_cache_invalidation(self):
         from collections import UserDict
index ecf081e82ce91fe3ef22cfd27f337ef30eff5598..64f0b20e7197956c972cdff31b214c36d464f7aa 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -195,6 +195,7 @@ Brett Cannon
 Mike Carlton
 Pierre Carrier
 Terry Carroll
+Edward Catmur
 Lorenzo M. Catucci
 Donn Cave
 Charles Cazabon