]> granicus.if.org Git - python/commitdiff
bpo-37058: PEP 544: Add Protocol to typing module (GH-13585)
authorIvan Levkivskyi <levkivskyi@gmail.com>
Tue, 28 May 2019 07:40:15 +0000 (08:40 +0100)
committerGitHub <noreply@github.com>
Tue, 28 May 2019 07:40:15 +0000 (08:40 +0100)
I tried to get rid of the `_ProtocolMeta`, but unfortunately it didn'y work. My idea to return a generic alias from `@runtime_checkable` made runtime protocols unpickleable. I am not sure what is worse (a custom metaclass or having some classes unpickleable), so I decided to stick with the status quo (since there were no complains so far). So essentially this is a copy of the implementation in `typing_extensions` with two modifications:
* Rename `@runtime` to `@runtime_checkable` (plus corresponding updates).
* Allow protocols that extend `collections.abc.Iterable` etc.

Doc/library/typing.rst
Lib/test/test_typing.py
Lib/typing.py
Misc/NEWS.d/next/Library/2019-05-26-19-05-24.bpo-37058.jmRu_g.rst [new file with mode: 0644]

index 27787fc2cb86b34dd4c2b1412ec2f6495fc3a181..709580ad2159df2397f01109b56ba6eb57d27379 100644 (file)
@@ -17,7 +17,8 @@
 
 --------------
 
-This module supports type hints as specified by :pep:`484` and :pep:`526`.
+This module provides runtime support for type hints as specified by
+:pep:`484`, :pep:`526`, :pep:`544`, :pep:`586`, :pep:`589`, and :pep:`591`.
 The most fundamental support consists of the types :data:`Any`, :data:`Union`,
 :data:`Tuple`, :data:`Callable`, :class:`TypeVar`, and
 :class:`Generic`.  For full specification please see :pep:`484`.  For
@@ -392,6 +393,48 @@ it as a return value) of a more specialized type is a type error. For example::
 Use :class:`object` to indicate that a value could be any type in a typesafe
 manner. Use :data:`Any` to indicate that a value is dynamically typed.
 
+
+Nominal vs structural subtyping
+-------------------------------
+
+Initially :pep:`484` defined Python static type system as using
+*nominal subtyping*. This means that a class ``A`` is allowed where
+a class ``B`` is expected if and only if ``A`` is a subclass of ``B``.
+
+This requirement previously also applied to abstract base classes, such as
+:class:`Iterable`. The problem with this approach is that a class had
+to be explicitly marked to support them, which is unpythonic and unlike
+what one would normally do in idiomatic dynamically typed Python code.
+For example, this conforms to the :pep:`484`::
+
+   from typing import Sized, Iterable, Iterator
+
+   class Bucket(Sized, Iterable[int]):
+       ...
+       def __len__(self) -> int: ...
+       def __iter__(self) -> Iterator[int]: ...
+
+:pep:`544` allows to solve this problem by allowing users to write
+the above code without explicit base classes in the class definition,
+allowing ``Bucket`` to be implicitly considered a subtype of both ``Sized``
+and ``Iterable[int]`` by static type checkers. This is known as
+*structural subtyping* (or static duck-typing)::
+
+   from typing import Iterator, Iterable
+
+   class Bucket:  # Note: no base classes
+       ...
+       def __len__(self) -> int: ...
+       def __iter__(self) -> Iterator[int]: ...
+
+   def collect(items: Iterable[int]) -> int: ...
+   result = collect(Bucket())  # Passes type check
+
+Moreover, by subclassing a special class :class:`Protocol`, a user
+can define new custom protocols to fully enjoy structural subtyping
+(see examples below).
+
+
 Classes, functions, and decorators
 ----------------------------------
 
@@ -459,6 +502,39 @@ The module defines the following classes, functions and decorators:
           except KeyError:
               return default
 
+.. class:: Protocol(Generic)
+
+   Base class for protocol classes. Protocol classes are defined like this::
+
+      class Proto(Protocol):
+          def meth(self) -> int:
+              ...
+
+   Such classes are primarily used with static type checkers that recognize
+   structural subtyping (static duck-typing), for example::
+
+      class C:
+          def meth(self) -> int:
+              return 0
+
+      def func(x: Proto) -> int:
+          return x.meth()
+
+      func(C())  # Passes static type check
+
+   See :pep:`544` for details. Protocol classes decorated with
+   :func:`runtime_checkable` (described later) act as simple-minded runtime
+   protocols that check only the presence of given attributes, ignoring their
+   type signatures.
+
+   Protocol classes can be generic, for example::
+
+      class GenProto(Protocol[T]):
+          def meth(self) -> T:
+              ...
+
+   .. versionadded:: 3.8
+
 .. class:: Type(Generic[CT_co])
 
    A variable annotated with ``C`` may accept a value of type ``C``. In
@@ -1033,6 +1109,26 @@ The module defines the following classes, functions and decorators:
    Note that returning instances of private classes is not recommended.
    It is usually preferable to make such classes public.
 
+.. decorator:: runtime_checkable
+
+   Mark a protocol class as a runtime protocol.
+
+   Such a protocol can be used with :func:`isinstance` and :func:`issubclass`.
+   This raises :exc:`TypeError` when applied to a non-protocol class.  This
+   allows a simple-minded structural check, very similar to "one trick ponies"
+   in :mod:`collections.abc` such as :class:`Iterable`.  For example::
+
+      @runtime_checkable
+      class Closable(Protocol):
+          def close(self): ...
+
+      assert isinstance(open('/some/file'), Closable)
+
+   **Warning:** this will check only the presence of the required methods,
+   not their type signatures!
+
+   .. versionadded:: 3.8
+
 .. data:: Any
 
    Special type indicating an unconstrained type.
index 46b7621182d6fd6ad0fcb447e7a4c70ddb9a7fd4..2b4b934d69f2b5bc6b9d7476cdb01856492daddd 100644 (file)
@@ -12,8 +12,8 @@ from typing import T, KT, VT  # Not in __all__.
 from typing import Union, Optional, Literal
 from typing import Tuple, List, MutableMapping
 from typing import Callable
-from typing import Generic, ClassVar, Final, final
-from typing import cast
+from typing import Generic, ClassVar, Final, final, Protocol
+from typing import cast, runtime_checkable
 from typing import get_type_hints
 from typing import no_type_check, no_type_check_decorator
 from typing import Type
@@ -24,6 +24,7 @@ from typing import Pattern, Match
 import abc
 import typing
 import weakref
+import types
 
 from test import mod_generics_cache
 
@@ -585,7 +586,710 @@ class MySimpleMapping(SimpleMapping[XK, XV]):
             return default
 
 
+class Coordinate(Protocol):
+    x: int
+    y: int
+
+@runtime_checkable
+class Point(Coordinate, Protocol):
+    label: str
+
+class MyPoint:
+    x: int
+    y: int
+    label: str
+
+class XAxis(Protocol):
+    x: int
+
+class YAxis(Protocol):
+    y: int
+
+@runtime_checkable
+class Position(XAxis, YAxis, Protocol):
+    pass
+
+@runtime_checkable
+class Proto(Protocol):
+    attr: int
+    def meth(self, arg: str) -> int:
+        ...
+
+class Concrete(Proto):
+    pass
+
+class Other:
+    attr: int = 1
+    def meth(self, arg: str) -> int:
+        if arg == 'this':
+            return 1
+        return 0
+
+class NT(NamedTuple):
+    x: int
+    y: int
+
+@runtime_checkable
+class HasCallProtocol(Protocol):
+    __call__: typing.Callable
+
+
 class ProtocolTests(BaseTestCase):
+    def test_basic_protocol(self):
+        @runtime_checkable
+        class P(Protocol):
+            def meth(self):
+                pass
+
+        class C: pass
+
+        class D:
+            def meth(self):
+                pass
+
+        def f():
+            pass
+
+        self.assertIsSubclass(D, P)
+        self.assertIsInstance(D(), P)
+        self.assertNotIsSubclass(C, P)
+        self.assertNotIsInstance(C(), P)
+        self.assertNotIsSubclass(types.FunctionType, P)
+        self.assertNotIsInstance(f, P)
+
+    def test_everything_implements_empty_protocol(self):
+        @runtime_checkable
+        class Empty(Protocol):
+            pass
+
+        class C:
+            pass
+
+        def f():
+            pass
+
+        for thing in (object, type, tuple, C, types.FunctionType):
+            self.assertIsSubclass(thing, Empty)
+        for thing in (object(), 1, (), typing, f):
+            self.assertIsInstance(thing, Empty)
+
+    def test_function_implements_protocol(self):
+        def f():
+            pass
+
+        self.assertIsInstance(f, HasCallProtocol)
+
+    def test_no_inheritance_from_nominal(self):
+        class C: pass
+
+        class BP(Protocol): pass
+
+        with self.assertRaises(TypeError):
+            class P(C, Protocol):
+                pass
+        with self.assertRaises(TypeError):
+            class P(Protocol, C):
+                pass
+        with self.assertRaises(TypeError):
+            class P(BP, C, Protocol):
+                pass
+
+        class D(BP, C): pass
+
+        class E(C, BP): pass
+
+        self.assertNotIsInstance(D(), E)
+        self.assertNotIsInstance(E(), D)
+
+    def test_no_instantiation(self):
+        class P(Protocol): pass
+
+        with self.assertRaises(TypeError):
+            P()
+
+        class C(P): pass
+
+        self.assertIsInstance(C(), C)
+        T = TypeVar('T')
+
+        class PG(Protocol[T]): pass
+
+        with self.assertRaises(TypeError):
+            PG()
+        with self.assertRaises(TypeError):
+            PG[int]()
+        with self.assertRaises(TypeError):
+            PG[T]()
+
+        class CG(PG[T]): pass
+
+        self.assertIsInstance(CG[int](), CG)
+
+    def test_cannot_instantiate_abstract(self):
+        @runtime_checkable
+        class P(Protocol):
+            @abc.abstractmethod
+            def ameth(self) -> int:
+                raise NotImplementedError
+
+        class B(P):
+            pass
+
+        class C(B):
+            def ameth(self) -> int:
+                return 26
+
+        with self.assertRaises(TypeError):
+            B()
+        self.assertIsInstance(C(), P)
+
+    def test_subprotocols_extending(self):
+        class P1(Protocol):
+            def meth1(self):
+                pass
+
+        @runtime_checkable
+        class P2(P1, Protocol):
+            def meth2(self):
+                pass
+
+        class C:
+            def meth1(self):
+                pass
+
+            def meth2(self):
+                pass
+
+        class C1:
+            def meth1(self):
+                pass
+
+        class C2:
+            def meth2(self):
+                pass
+
+        self.assertNotIsInstance(C1(), P2)
+        self.assertNotIsInstance(C2(), P2)
+        self.assertNotIsSubclass(C1, P2)
+        self.assertNotIsSubclass(C2, P2)
+        self.assertIsInstance(C(), P2)
+        self.assertIsSubclass(C, P2)
+
+    def test_subprotocols_merging(self):
+        class P1(Protocol):
+            def meth1(self):
+                pass
+
+        class P2(Protocol):
+            def meth2(self):
+                pass
+
+        @runtime_checkable
+        class P(P1, P2, Protocol):
+            pass
+
+        class C:
+            def meth1(self):
+                pass
+
+            def meth2(self):
+                pass
+
+        class C1:
+            def meth1(self):
+                pass
+
+        class C2:
+            def meth2(self):
+                pass
+
+        self.assertNotIsInstance(C1(), P)
+        self.assertNotIsInstance(C2(), P)
+        self.assertNotIsSubclass(C1, P)
+        self.assertNotIsSubclass(C2, P)
+        self.assertIsInstance(C(), P)
+        self.assertIsSubclass(C, P)
+
+    def test_protocols_issubclass(self):
+        T = TypeVar('T')
+
+        @runtime_checkable
+        class P(Protocol):
+            def x(self): ...
+
+        @runtime_checkable
+        class PG(Protocol[T]):
+            def x(self): ...
+
+        class BadP(Protocol):
+            def x(self): ...
+
+        class BadPG(Protocol[T]):
+            def x(self): ...
+
+        class C:
+            def x(self): ...
+
+        self.assertIsSubclass(C, P)
+        self.assertIsSubclass(C, PG)
+        self.assertIsSubclass(BadP, PG)
+
+        with self.assertRaises(TypeError):
+            issubclass(C, PG[T])
+        with self.assertRaises(TypeError):
+            issubclass(C, PG[C])
+        with self.assertRaises(TypeError):
+            issubclass(C, BadP)
+        with self.assertRaises(TypeError):
+            issubclass(C, BadPG)
+        with self.assertRaises(TypeError):
+            issubclass(P, PG[T])
+        with self.assertRaises(TypeError):
+            issubclass(PG, PG[int])
+
+    def test_protocols_issubclass_non_callable(self):
+        class C:
+            x = 1
+
+        @runtime_checkable
+        class PNonCall(Protocol):
+            x = 1
+
+        with self.assertRaises(TypeError):
+            issubclass(C, PNonCall)
+        self.assertIsInstance(C(), PNonCall)
+        PNonCall.register(C)
+        with self.assertRaises(TypeError):
+            issubclass(C, PNonCall)
+        self.assertIsInstance(C(), PNonCall)
+
+        # check that non-protocol subclasses are not affected
+        class D(PNonCall): ...
+
+        self.assertNotIsSubclass(C, D)
+        self.assertNotIsInstance(C(), D)
+        D.register(C)
+        self.assertIsSubclass(C, D)
+        self.assertIsInstance(C(), D)
+        with self.assertRaises(TypeError):
+            issubclass(D, PNonCall)
+
+    def test_protocols_isinstance(self):
+        T = TypeVar('T')
+
+        @runtime_checkable
+        class P(Protocol):
+            def meth(x): ...
+
+        @runtime_checkable
+        class PG(Protocol[T]):
+            def meth(x): ...
+
+        class BadP(Protocol):
+            def meth(x): ...
+
+        class BadPG(Protocol[T]):
+            def meth(x): ...
+
+        class C:
+            def meth(x): ...
+
+        self.assertIsInstance(C(), P)
+        self.assertIsInstance(C(), PG)
+        with self.assertRaises(TypeError):
+            isinstance(C(), PG[T])
+        with self.assertRaises(TypeError):
+            isinstance(C(), PG[C])
+        with self.assertRaises(TypeError):
+            isinstance(C(), BadP)
+        with self.assertRaises(TypeError):
+            isinstance(C(), BadPG)
+
+    def test_protocols_isinstance_py36(self):
+        class APoint:
+            def __init__(self, x, y, label):
+                self.x = x
+                self.y = y
+                self.label = label
+
+        class BPoint:
+            label = 'B'
+
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        class C:
+            def __init__(self, attr):
+                self.attr = attr
+
+            def meth(self, arg):
+                return 0
+
+        class Bad: pass
+
+        self.assertIsInstance(APoint(1, 2, 'A'), Point)
+        self.assertIsInstance(BPoint(1, 2), Point)
+        self.assertNotIsInstance(MyPoint(), Point)
+        self.assertIsInstance(BPoint(1, 2), Position)
+        self.assertIsInstance(Other(), Proto)
+        self.assertIsInstance(Concrete(), Proto)
+        self.assertIsInstance(C(42), Proto)
+        self.assertNotIsInstance(Bad(), Proto)
+        self.assertNotIsInstance(Bad(), Point)
+        self.assertNotIsInstance(Bad(), Position)
+        self.assertNotIsInstance(Bad(), Concrete)
+        self.assertNotIsInstance(Other(), Concrete)
+        self.assertIsInstance(NT(1, 2), Position)
+
+    def test_protocols_isinstance_init(self):
+        T = TypeVar('T')
+
+        @runtime_checkable
+        class P(Protocol):
+            x = 1
+
+        @runtime_checkable
+        class PG(Protocol[T]):
+            x = 1
+
+        class C:
+            def __init__(self, x):
+                self.x = x
+
+        self.assertIsInstance(C(1), P)
+        self.assertIsInstance(C(1), PG)
+
+    def test_protocols_support_register(self):
+        @runtime_checkable
+        class P(Protocol):
+            x = 1
+
+        class PM(Protocol):
+            def meth(self): pass
+
+        class D(PM): pass
+
+        class C: pass
+
+        D.register(C)
+        P.register(C)
+        self.assertIsInstance(C(), P)
+        self.assertIsInstance(C(), D)
+
+    def test_none_on_non_callable_doesnt_block_implementation(self):
+        @runtime_checkable
+        class P(Protocol):
+            x = 1
+
+        class A:
+            x = 1
+
+        class B(A):
+            x = None
+
+        class C:
+            def __init__(self):
+                self.x = None
+
+        self.assertIsInstance(B(), P)
+        self.assertIsInstance(C(), P)
+
+    def test_none_on_callable_blocks_implementation(self):
+        @runtime_checkable
+        class P(Protocol):
+            def x(self): ...
+
+        class A:
+            def x(self): ...
+
+        class B(A):
+            x = None
+
+        class C:
+            def __init__(self):
+                self.x = None
+
+        self.assertNotIsInstance(B(), P)
+        self.assertNotIsInstance(C(), P)
+
+    def test_non_protocol_subclasses(self):
+        class P(Protocol):
+            x = 1
+
+        @runtime_checkable
+        class PR(Protocol):
+            def meth(self): pass
+
+        class NonP(P):
+            x = 1
+
+        class NonPR(PR): pass
+
+        class C:
+            x = 1
+
+        class D:
+            def meth(self): pass
+
+        self.assertNotIsInstance(C(), NonP)
+        self.assertNotIsInstance(D(), NonPR)
+        self.assertNotIsSubclass(C, NonP)
+        self.assertNotIsSubclass(D, NonPR)
+        self.assertIsInstance(NonPR(), PR)
+        self.assertIsSubclass(NonPR, PR)
+
+    def test_custom_subclasshook(self):
+        class P(Protocol):
+            x = 1
+
+        class OKClass: pass
+
+        class BadClass:
+            x = 1
+
+        class C(P):
+            @classmethod
+            def __subclasshook__(cls, other):
+                return other.__name__.startswith("OK")
+
+        self.assertIsInstance(OKClass(), C)
+        self.assertNotIsInstance(BadClass(), C)
+        self.assertIsSubclass(OKClass, C)
+        self.assertNotIsSubclass(BadClass, C)
+
+    def test_issubclass_fails_correctly(self):
+        @runtime_checkable
+        class P(Protocol):
+            x = 1
+
+        class C: pass
+
+        with self.assertRaises(TypeError):
+            issubclass(C(), P)
+
+    def test_defining_generic_protocols(self):
+        T = TypeVar('T')
+        S = TypeVar('S')
+
+        @runtime_checkable
+        class PR(Protocol[T, S]):
+            def meth(self): pass
+
+        class P(PR[int, T], Protocol[T]):
+            y = 1
+
+        with self.assertRaises(TypeError):
+            PR[int]
+        with self.assertRaises(TypeError):
+            P[int, str]
+        with self.assertRaises(TypeError):
+            PR[int, 1]
+        with self.assertRaises(TypeError):
+            PR[int, ClassVar]
+
+        class C(PR[int, T]): pass
+
+        self.assertIsInstance(C[str](), C)
+
+    def test_defining_generic_protocols_old_style(self):
+        T = TypeVar('T')
+        S = TypeVar('S')
+
+        @runtime_checkable
+        class PR(Protocol, Generic[T, S]):
+            def meth(self): pass
+
+        class P(PR[int, str], Protocol):
+            y = 1
+
+        with self.assertRaises(TypeError):
+            issubclass(PR[int, str], PR)
+        self.assertIsSubclass(P, PR)
+        with self.assertRaises(TypeError):
+            PR[int]
+        with self.assertRaises(TypeError):
+            PR[int, 1]
+
+        class P1(Protocol, Generic[T]):
+            def bar(self, x: T) -> str: ...
+
+        class P2(Generic[T], Protocol):
+            def bar(self, x: T) -> str: ...
+
+        @runtime_checkable
+        class PSub(P1[str], Protocol):
+            x = 1
+
+        class Test:
+            x = 1
+
+            def bar(self, x: str) -> str:
+                return x
+
+        self.assertIsInstance(Test(), PSub)
+        with self.assertRaises(TypeError):
+            PR[int, ClassVar]
+
+    def test_init_called(self):
+        T = TypeVar('T')
+
+        class P(Protocol[T]): pass
+
+        class C(P[T]):
+            def __init__(self):
+                self.test = 'OK'
+
+        self.assertEqual(C[int]().test, 'OK')
+
+    def test_protocols_bad_subscripts(self):
+        T = TypeVar('T')
+        S = TypeVar('S')
+        with self.assertRaises(TypeError):
+            class P(Protocol[T, T]): pass
+        with self.assertRaises(TypeError):
+            class P(Protocol[int]): pass
+        with self.assertRaises(TypeError):
+            class P(Protocol[T], Protocol[S]): pass
+        with self.assertRaises(TypeError):
+            class P(typing.Mapping[T, S], Protocol[T]): pass
+
+    def test_generic_protocols_repr(self):
+        T = TypeVar('T')
+        S = TypeVar('S')
+
+        class P(Protocol[T, S]): pass
+
+        self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]'))
+        self.assertTrue(repr(P[int, str]).endswith('P[int, str]'))
+
+    def test_generic_protocols_eq(self):
+        T = TypeVar('T')
+        S = TypeVar('S')
+
+        class P(Protocol[T, S]): pass
+
+        self.assertEqual(P, P)
+        self.assertEqual(P[int, T], P[int, T])
+        self.assertEqual(P[T, T][Tuple[T, S]][int, str],
+                         P[Tuple[int, str], Tuple[int, str]])
+
+    def test_generic_protocols_special_from_generic(self):
+        T = TypeVar('T')
+
+        class P(Protocol[T]): pass
+
+        self.assertEqual(P.__parameters__, (T,))
+        self.assertEqual(P[int].__parameters__, ())
+        self.assertEqual(P[int].__args__, (int,))
+        self.assertIs(P[int].__origin__, P)
+
+    def test_generic_protocols_special_from_protocol(self):
+        @runtime_checkable
+        class PR(Protocol):
+            x = 1
+
+        class P(Protocol):
+            def meth(self):
+                pass
+
+        T = TypeVar('T')
+
+        class PG(Protocol[T]):
+            x = 1
+
+            def meth(self):
+                pass
+
+        self.assertTrue(P._is_protocol)
+        self.assertTrue(PR._is_protocol)
+        self.assertTrue(PG._is_protocol)
+        self.assertFalse(P._is_runtime_protocol)
+        self.assertTrue(PR._is_runtime_protocol)
+        self.assertTrue(PG[int]._is_protocol)
+        self.assertEqual(typing._get_protocol_attrs(P), {'meth'})
+        self.assertEqual(typing._get_protocol_attrs(PR), {'x'})
+        self.assertEqual(frozenset(typing._get_protocol_attrs(PG)),
+                         frozenset({'x', 'meth'}))
+
+    def test_no_runtime_deco_on_nominal(self):
+        with self.assertRaises(TypeError):
+            @runtime_checkable
+            class C: pass
+
+        class Proto(Protocol):
+            x = 1
+
+        with self.assertRaises(TypeError):
+            @runtime_checkable
+            class Concrete(Proto):
+                pass
+
+    def test_none_treated_correctly(self):
+        @runtime_checkable
+        class P(Protocol):
+            x = None  # type: int
+
+        class B(object): pass
+
+        self.assertNotIsInstance(B(), P)
+
+        class C:
+            x = 1
+
+        class D:
+            x = None
+
+        self.assertIsInstance(C(), P)
+        self.assertIsInstance(D(), P)
+
+        class CI:
+            def __init__(self):
+                self.x = 1
+
+        class DI:
+            def __init__(self):
+                self.x = None
+
+        self.assertIsInstance(C(), P)
+        self.assertIsInstance(D(), P)
+
+    def test_protocols_in_unions(self):
+        class P(Protocol):
+            x = None  # type: int
+
+        Alias = typing.Union[typing.Iterable, P]
+        Alias2 = typing.Union[P, typing.Iterable]
+        self.assertEqual(Alias, Alias2)
+
+    def test_protocols_pickleable(self):
+        global P, CP  # pickle wants to reference the class by name
+        T = TypeVar('T')
+
+        @runtime_checkable
+        class P(Protocol[T]):
+            x = 1
+
+        class CP(P[int]):
+            pass
+
+        c = CP()
+        c.foo = 42
+        c.bar = 'abc'
+        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+            z = pickle.dumps(c, proto)
+            x = pickle.loads(z)
+            self.assertEqual(x.foo, 42)
+            self.assertEqual(x.bar, 'abc')
+            self.assertEqual(x.x, 1)
+            self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'})
+            s = pickle.dumps(P)
+            D = pickle.loads(s)
+
+            class E:
+                x = 1
+
+            self.assertIsInstance(E(), D)
 
     def test_supports_int(self):
         self.assertIsSubclass(int, typing.SupportsInt)
@@ -634,9 +1338,8 @@ class ProtocolTests(BaseTestCase):
         self.assertIsSubclass(int, typing.SupportsIndex)
         self.assertNotIsSubclass(str, typing.SupportsIndex)
 
-    def test_protocol_instance_type_error(self):
-        with self.assertRaises(TypeError):
-            isinstance(0, typing.SupportsAbs)
+    def test_bundled_protocol_instance_works(self):
+        self.assertIsInstance(0, typing.SupportsAbs)
         class C1(typing.SupportsInt):
             def __int__(self) -> int:
                 return 42
@@ -645,6 +1348,20 @@ class ProtocolTests(BaseTestCase):
         c = C2()
         self.assertIsInstance(c, C1)
 
+    def test_collections_protocols_allowed(self):
+        @runtime_checkable
+        class Custom(collections.abc.Iterable, Protocol):
+            def close(self): ...
+
+        class A: pass
+        class B:
+            def __iter__(self):
+                return []
+            def close(self):
+                return 0
+
+        self.assertIsSubclass(B, Custom)
+        self.assertNotIsSubclass(A, Custom)
 
 class GenericTests(BaseTestCase):
 
@@ -771,7 +1488,7 @@ class GenericTests(BaseTestCase):
     def test_new_repr_bare(self):
         T = TypeVar('T')
         self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]')
-        self.assertEqual(repr(typing._Protocol[T]), 'typing._Protocol[~T]')
+        self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]')
         class C(typing.Dict[Any, Any]): ...
         # this line should just work
         repr(C.__mro__)
@@ -1067,7 +1784,7 @@ class GenericTests(BaseTestCase):
         with self.assertRaises(TypeError):
             Tuple[Generic[T]]
         with self.assertRaises(TypeError):
-            List[typing._Protocol]
+            List[typing.Protocol]
 
     def test_type_erasure_special(self):
         T = TypeVar('T')
index d3e84cd64abe8c0c1363bbbf6d546b1c2f9f69f3..14bd06b2b745b9505393a0935115c0ada9b78800 100644 (file)
@@ -9,8 +9,7 @@ At large scale, the structure of the module is following:
 * The core of internal generics API: _GenericAlias and _VariadicGenericAlias, the latter is
   currently only used by Tuple and Callable. All subscripted types like X[int], Union[int, str],
   etc., are instances of either of these classes.
-* The public counterpart of the generics API consists of two classes: Generic and Protocol
-  (the latter is currently private, but will be made public after PEP 544 acceptance).
+* The public counterpart of the generics API consists of two classes: Generic and Protocol.
 * Public helper functions: get_type_hints, overload, cast, no_type_check,
   no_type_check_decorator.
 * Generic aliases for collections.abc ABCs and few additional protocols.
@@ -18,7 +17,7 @@ At large scale, the structure of the module is following:
 * Wrapper submodules for re and io related types.
 """
 
-from abc import abstractmethod, abstractproperty
+from abc import abstractmethod, abstractproperty, ABCMeta
 import collections
 import collections.abc
 import contextlib
@@ -39,6 +38,7 @@ __all__ = [
     'Generic',
     'Literal',
     'Optional',
+    'Protocol',
     'Tuple',
     'Type',
     'TypeVar',
@@ -102,6 +102,7 @@ __all__ = [
     'no_type_check_decorator',
     'NoReturn',
     'overload',
+    'runtime_checkable',
     'Text',
     'TYPE_CHECKING',
 ]
@@ -123,7 +124,7 @@ def _type_check(arg, msg, is_argument=True):
 
     We append the repr() of the actual value (truncated to 100 chars).
     """
-    invalid_generic_forms = (Generic, _Protocol)
+    invalid_generic_forms = (Generic, Protocol)
     if is_argument:
         invalid_generic_forms = invalid_generic_forms + (ClassVar, Final)
 
@@ -135,7 +136,7 @@ def _type_check(arg, msg, is_argument=True):
             arg.__origin__ in invalid_generic_forms):
         raise TypeError(f"{arg} is not valid as type argument")
     if (isinstance(arg, _SpecialForm) and arg not in (Any, NoReturn) or
-            arg in (Generic, _Protocol)):
+            arg in (Generic, Protocol)):
         raise TypeError(f"Plain {arg} is not valid as type argument")
     if isinstance(arg, (type, TypeVar, ForwardRef)):
         return arg
@@ -665,8 +666,8 @@ class _GenericAlias(_Final, _root=True):
 
     @_tp_cache
     def __getitem__(self, params):
-        if self.__origin__ in (Generic, _Protocol):
-            # Can't subscript Generic[...] or _Protocol[...].
+        if self.__origin__ in (Generic, Protocol):
+            # Can't subscript Generic[...] or Protocol[...].
             raise TypeError(f"Cannot subscript already-subscripted {self}")
         if not isinstance(params, tuple):
             params = (params,)
@@ -733,6 +734,8 @@ class _GenericAlias(_Final, _root=True):
                 res.append(Generic)
             return tuple(res)
         if self.__origin__ is Generic:
+            if Protocol in bases:
+                return ()
             i = bases.index(self)
             for b in bases[i+1:]:
                 if isinstance(b, _GenericAlias) and b is not self:
@@ -850,10 +853,11 @@ class Generic:
               return default
     """
     __slots__ = ()
+    _is_protocol = False
 
     def __new__(cls, *args, **kwds):
-        if cls is Generic:
-            raise TypeError("Type Generic cannot be instantiated; "
+        if cls in (Generic, Protocol):
+            raise TypeError(f"Type {cls.__name__} cannot be instantiated; "
                             "it can be used only as a base class")
         if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:
             obj = super().__new__(cls)
@@ -870,17 +874,14 @@ class Generic:
                 f"Parameter list to {cls.__qualname__}[...] cannot be empty")
         msg = "Parameters to generic types must be types."
         params = tuple(_type_check(p, msg) for p in params)
-        if cls is Generic:
-            # Generic can only be subscripted with unique type variables.
+        if cls in (Generic, Protocol):
+            # Generic and Protocol can only be subscripted with unique type variables.
             if not all(isinstance(p, TypeVar) for p in params):
                 raise TypeError(
-                    "Parameters to Generic[...] must all be type variables")
+                    f"Parameters to {cls.__name__}[...] must all be type variables")
             if len(set(params)) != len(params):
                 raise TypeError(
-                    "Parameters to Generic[...] must all be unique")
-        elif cls is _Protocol:
-            # _Protocol is internal at the moment, just skip the check
-            pass
+                    f"Parameters to {cls.__name__}[...] must all be unique")
         else:
             # Subscripting a regular Generic subclass.
             _check_generic(cls, params)
@@ -892,7 +893,7 @@ class Generic:
         if '__orig_bases__' in cls.__dict__:
             error = Generic in cls.__orig_bases__
         else:
-            error = Generic in cls.__bases__ and cls.__name__ != '_Protocol'
+            error = Generic in cls.__bases__ and cls.__name__ != 'Protocol'
         if error:
             raise TypeError("Cannot inherit from plain Generic")
         if '__orig_bases__' in cls.__dict__:
@@ -910,9 +911,7 @@ class Generic:
                         raise TypeError(
                             "Cannot inherit from Generic[...] multiple types.")
                     gvars = base.__parameters__
-            if gvars is None:
-                gvars = tvars
-            else:
+            if gvars is not None:
                 tvarset = set(tvars)
                 gvarset = set(gvars)
                 if not tvarset <= gvarset:
@@ -935,6 +934,204 @@ class _TypingEllipsis:
     """Internal placeholder for ... (ellipsis)."""
 
 
+_TYPING_INTERNALS = ['__parameters__', '__orig_bases__',  '__orig_class__',
+                     '_is_protocol', '_is_runtime_protocol']
+
+_SPECIAL_NAMES = ['__abstractmethods__', '__annotations__', '__dict__', '__doc__',
+                  '__init__', '__module__', '__new__', '__slots__',
+                  '__subclasshook__', '__weakref__']
+
+# These special attributes will be not collected as protocol members.
+EXCLUDED_ATTRIBUTES = _TYPING_INTERNALS + _SPECIAL_NAMES + ['_MutableMapping__marker']
+
+
+def _get_protocol_attrs(cls):
+    """Collect protocol members from a protocol class objects.
+
+    This includes names actually defined in the class dictionary, as well
+    as names that appear in annotations. Special names (above) are skipped.
+    """
+    attrs = set()
+    for base in cls.__mro__[:-1]:  # without object
+        if base.__name__ in ('Protocol', 'Generic'):
+            continue
+        annotations = getattr(base, '__annotations__', {})
+        for attr in list(base.__dict__.keys()) + list(annotations.keys()):
+            if not attr.startswith('_abc_') and attr not in EXCLUDED_ATTRIBUTES:
+                attrs.add(attr)
+    return attrs
+
+
+def _is_callable_members_only(cls):
+    # PEP 544 prohibits using issubclass() with protocols that have non-method members.
+    return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls))
+
+
+def _no_init(self, *args, **kwargs):
+    if type(self)._is_protocol:
+        raise TypeError('Protocols cannot be instantiated')
+
+
+def _allow_reckless_class_cheks():
+    """Allow instnance and class checks for special stdlib modules.
+
+    The abc and functools modules indiscriminately call isinstance() and
+    issubclass() on the whole MRO of a user class, which may contain protocols.
+    """
+    try:
+        return sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']
+    except (AttributeError, ValueError):  # For platforms without _getframe().
+        return True
+
+
+_PROTO_WHITELIST = ['Callable', 'Awaitable',
+                    'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator',
+                    'Hashable', 'Sized', 'Container', 'Collection', 'Reversible',
+                    'ContextManager', 'AsyncContextManager']
+
+
+class _ProtocolMeta(ABCMeta):
+    # This metaclass is really unfortunate and exists only because of
+    # the lack of __instancehook__.
+    def __instancecheck__(cls, instance):
+        # We need this method for situations where attributes are
+        # assigned in __init__.
+        if ((not getattr(cls, '_is_protocol', False) or
+                _is_callable_members_only(cls)) and
+                issubclass(instance.__class__, cls)):
+            return True
+        if cls._is_protocol:
+            if all(hasattr(instance, attr) and
+                    # All *methods* can be blocked by setting them to None.
+                    (not callable(getattr(cls, attr, None)) or
+                     getattr(instance, attr) is not None)
+                    for attr in _get_protocol_attrs(cls)):
+                return True
+        return super().__instancecheck__(instance)
+
+
+class Protocol(Generic, metaclass=_ProtocolMeta):
+    """Base class for protocol classes.
+
+    Protocol classes are defined as::
+
+        class Proto(Protocol):
+            def meth(self) -> int:
+                ...
+
+    Such classes are primarily used with static type checkers that recognize
+    structural subtyping (static duck-typing), for example::
+
+        class C:
+            def meth(self) -> int:
+                return 0
+
+        def func(x: Proto) -> int:
+            return x.meth()
+
+        func(C())  # Passes static type check
+
+    See PEP 544 for details. Protocol classes decorated with
+    @typing.runtime_checkable act as simple-minded runtime protocols that check
+    only the presence of given attributes, ignoring their type signatures.
+    Protocol classes can be generic, they are defined as::
+
+        class GenProto(Protocol[T]):
+            def meth(self) -> T:
+                ...
+    """
+    __slots__ = ()
+    _is_protocol = True
+    _is_runtime_protocol = False
+
+    def __init_subclass__(cls, *args, **kwargs):
+        super().__init_subclass__(*args, **kwargs)
+
+        # Determine if this is a protocol or a concrete subclass.
+        if not cls.__dict__.get('_is_protocol', False):
+            cls._is_protocol = any(b is Protocol for b in cls.__bases__)
+
+        # Set (or override) the protocol subclass hook.
+        def _proto_hook(other):
+            if not cls.__dict__.get('_is_protocol', False):
+                return NotImplemented
+
+            # First, perform various sanity checks.
+            if not getattr(cls, '_is_runtime_protocol', False):
+                if _allow_reckless_class_cheks():
+                    return NotImplemented
+                raise TypeError("Instance and class checks can only be used with"
+                                " @runtime_checkable protocols")
+            if not _is_callable_members_only(cls):
+                if _allow_reckless_class_cheks():
+                    return NotImplemented
+                raise TypeError("Protocols with non-method members"
+                                " don't support issubclass()")
+            if not isinstance(other, type):
+                # Same error message as for issubclass(1, int).
+                raise TypeError('issubclass() arg 1 must be a class')
+
+            # Second, perform the actual structural compatibility check.
+            for attr in _get_protocol_attrs(cls):
+                for base in other.__mro__:
+                    # Check if the members appears in the class dictionary...
+                    if attr in base.__dict__:
+                        if base.__dict__[attr] is None:
+                            return NotImplemented
+                        break
+
+                    # ...or in annotations, if it is a sub-protocol.
+                    annotations = getattr(base, '__annotations__', {})
+                    if (isinstance(annotations, collections.abc.Mapping) and
+                            attr in annotations and
+                            issubclass(other, Generic) and other._is_protocol):
+                        break
+                else:
+                    return NotImplemented
+            return True
+
+        if '__subclasshook__' not in cls.__dict__:
+            cls.__subclasshook__ = _proto_hook
+
+        # We have nothing more to do for non-protocols...
+        if not cls._is_protocol:
+            return
+
+        # ... otherwise check consistency of bases, and prohibit instantiation.
+        for base in cls.__bases__:
+            if not (base in (object, Generic) or
+                    base.__module__ == 'collections.abc' and base.__name__ in _PROTO_WHITELIST or
+                    issubclass(base, Generic) and base._is_protocol):
+                raise TypeError('Protocols can only inherit from other'
+                                ' protocols, got %r' % base)
+        cls.__init__ = _no_init
+
+
+def runtime_checkable(cls):
+    """Mark a protocol class as a runtime protocol.
+
+    Such protocol can be used with isinstance() and issubclass().
+    Raise TypeError if applied to a non-protocol class.
+    This allows a simple-minded structural check very similar to
+    one trick ponies in collections.abc such as Iterable.
+    For example::
+
+        @runtime_checkable
+        class Closable(Protocol):
+            def close(self): ...
+
+        assert isinstance(open('/some/file'), Closable)
+
+    Warning: this will check only the presence of the required methods,
+    not their type signatures!
+    """
+    if not issubclass(cls, Generic) or not cls._is_protocol:
+        raise TypeError('@runtime_checkable can be only applied to protocol classes,'
+                        ' got %r' % cls)
+    cls._is_runtime_protocol = True
+    return cls
+
+
 def cast(typ, val):
     """Cast a value to a type.
 
@@ -1159,90 +1356,6 @@ def final(f):
     return f
 
 
-class _ProtocolMeta(type):
-    """Internal metaclass for _Protocol.
-
-    This exists so _Protocol classes can be generic without deriving
-    from Generic.
-    """
-
-    def __instancecheck__(self, obj):
-        if _Protocol not in self.__bases__:
-            return super().__instancecheck__(obj)
-        raise TypeError("Protocols cannot be used with isinstance().")
-
-    def __subclasscheck__(self, cls):
-        if not self._is_protocol:
-            # No structural checks since this isn't a protocol.
-            return NotImplemented
-
-        if self is _Protocol:
-            # Every class is a subclass of the empty protocol.
-            return True
-
-        # Find all attributes defined in the protocol.
-        attrs = self._get_protocol_attrs()
-
-        for attr in attrs:
-            if not any(attr in d.__dict__ for d in cls.__mro__):
-                return False
-        return True
-
-    def _get_protocol_attrs(self):
-        # Get all Protocol base classes.
-        protocol_bases = []
-        for c in self.__mro__:
-            if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol':
-                protocol_bases.append(c)
-
-        # Get attributes included in protocol.
-        attrs = set()
-        for base in protocol_bases:
-            for attr in base.__dict__.keys():
-                # Include attributes not defined in any non-protocol bases.
-                for c in self.__mro__:
-                    if (c is not base and attr in c.__dict__ and
-                            not getattr(c, '_is_protocol', False)):
-                        break
-                else:
-                    if (not attr.startswith('_abc_') and
-                            attr != '__abstractmethods__' and
-                            attr != '__annotations__' and
-                            attr != '__weakref__' and
-                            attr != '_is_protocol' and
-                            attr != '_gorg' and
-                            attr != '__dict__' and
-                            attr != '__args__' and
-                            attr != '__slots__' and
-                            attr != '_get_protocol_attrs' and
-                            attr != '__next_in_mro__' and
-                            attr != '__parameters__' and
-                            attr != '__origin__' and
-                            attr != '__orig_bases__' and
-                            attr != '__extra__' and
-                            attr != '__tree_hash__' and
-                            attr != '__module__'):
-                        attrs.add(attr)
-
-        return attrs
-
-
-class _Protocol(Generic, metaclass=_ProtocolMeta):
-    """Internal base class for protocol classes.
-
-    This implements a simple-minded structural issubclass check
-    (similar but more general than the one-offs in collections.abc
-    such as Hashable).
-    """
-
-    __slots__ = ()
-
-    _is_protocol = True
-
-    def __class_getitem__(cls, params):
-        return super().__class_getitem__(params)
-
-
 # Some unconstrained type variables.  These are used by the container types.
 # (These are not for export.)
 T = TypeVar('T')  # Any type.
@@ -1347,7 +1460,8 @@ Type.__doc__ = \
     """
 
 
-class SupportsInt(_Protocol):
+@runtime_checkable
+class SupportsInt(Protocol):
     __slots__ = ()
 
     @abstractmethod
@@ -1355,7 +1469,8 @@ class SupportsInt(_Protocol):
         pass
 
 
-class SupportsFloat(_Protocol):
+@runtime_checkable
+class SupportsFloat(Protocol):
     __slots__ = ()
 
     @abstractmethod
@@ -1363,7 +1478,8 @@ class SupportsFloat(_Protocol):
         pass
 
 
-class SupportsComplex(_Protocol):
+@runtime_checkable
+class SupportsComplex(Protocol):
     __slots__ = ()
 
     @abstractmethod
@@ -1371,7 +1487,8 @@ class SupportsComplex(_Protocol):
         pass
 
 
-class SupportsBytes(_Protocol):
+@runtime_checkable
+class SupportsBytes(Protocol):
     __slots__ = ()
 
     @abstractmethod
@@ -1379,7 +1496,8 @@ class SupportsBytes(_Protocol):
         pass
 
 
-class SupportsIndex(_Protocol):
+@runtime_checkable
+class SupportsIndex(Protocol):
     __slots__ = ()
 
     @abstractmethod
@@ -1387,7 +1505,8 @@ class SupportsIndex(_Protocol):
         pass
 
 
-class SupportsAbs(_Protocol[T_co]):
+@runtime_checkable
+class SupportsAbs(Protocol[T_co]):
     __slots__ = ()
 
     @abstractmethod
@@ -1395,7 +1514,8 @@ class SupportsAbs(_Protocol[T_co]):
         pass
 
 
-class SupportsRound(_Protocol[T_co]):
+@runtime_checkable
+class SupportsRound(Protocol[T_co]):
     __slots__ = ()
 
     @abstractmethod
diff --git a/Misc/NEWS.d/next/Library/2019-05-26-19-05-24.bpo-37058.jmRu_g.rst b/Misc/NEWS.d/next/Library/2019-05-26-19-05-24.bpo-37058.jmRu_g.rst
new file mode 100644 (file)
index 0000000..329b82c
--- /dev/null
@@ -0,0 +1 @@
+PEP 544: Add ``Protocol`` and ``@runtime_checkable`` to the ``typing`` module.
\ No newline at end of file