]> granicus.if.org Git - python/commitdiff
inspect.signature: Add (restore) support for builtin classes #20473
authorYury Selivanov <yselivanov@sprymix.com>
Mon, 3 Feb 2014 07:46:07 +0000 (02:46 -0500)
committerYury Selivanov <yselivanov@sprymix.com>
Mon, 3 Feb 2014 07:46:07 +0000 (02:46 -0500)
Lib/inspect.py
Lib/test/test_inspect.py

index 4e33a22e310b6a94d5ca64e6825cca07c09bf38b..7a2739f80422752a74e8219776283f85e0cf7945 100644 (file)
@@ -1648,6 +1648,139 @@ def _signature_get_bound_param(spec):
     return spec[2:pos]
 
 
+def _signature_fromstr(cls, obj, s):
+    # Internal helper to parse content of '__text_signature__'
+    # and return a Signature based on it
+    Parameter = cls._parameter_cls
+
+    if s.endswith("/)"):
+        kind = Parameter.POSITIONAL_ONLY
+        s = s[:-2] + ')'
+    else:
+        kind = Parameter.POSITIONAL_OR_KEYWORD
+
+    first_parameter_is_self = s.startswith("($")
+    if first_parameter_is_self:
+        s = '(' + s[2:]
+
+    s = "def foo" + s + ": pass"
+
+    try:
+        module = ast.parse(s)
+    except SyntaxError:
+        module = None
+
+    if not isinstance(module, ast.Module):
+        raise ValueError("{!r} builtin has invalid signature".format(obj))
+
+    f = module.body[0]
+
+    parameters = []
+    empty = Parameter.empty
+    invalid = object()
+
+    module = None
+    module_dict = {}
+    module_name = getattr(obj, '__module__', None)
+    if module_name:
+        module = sys.modules.get(module_name, None)
+        if module:
+            module_dict = module.__dict__
+    sys_module_dict = sys.modules
+
+    def parse_name(node):
+        assert isinstance(node, ast.arg)
+        if node.annotation != None:
+            raise ValueError("Annotations are not currently supported")
+        return node.arg
+
+    def wrap_value(s):
+        try:
+            value = eval(s, module_dict)
+        except NameError:
+            try:
+                value = eval(s, sys_module_dict)
+            except NameError:
+                raise RuntimeError()
+
+        if isinstance(value, str):
+            return ast.Str(value)
+        if isinstance(value, (int, float)):
+            return ast.Num(value)
+        if isinstance(value, bytes):
+            return ast.Bytes(value)
+        if value in (True, False, None):
+            return ast.NameConstant(value)
+        raise RuntimeError()
+
+    class RewriteSymbolics(ast.NodeTransformer):
+        def visit_Attribute(self, node):
+            a = []
+            n = node
+            while isinstance(n, ast.Attribute):
+                a.append(n.attr)
+                n = n.value
+            if not isinstance(n, ast.Name):
+                raise RuntimeError()
+            a.append(n.id)
+            value = ".".join(reversed(a))
+            return wrap_value(value)
+
+        def visit_Name(self, node):
+            if not isinstance(node.ctx, ast.Load):
+                raise ValueError()
+            return wrap_value(node.id)
+
+    def p(name_node, default_node, default=empty):
+        name = parse_name(name_node)
+        if name is invalid:
+            return None
+        if default_node and default_node is not _empty:
+            try:
+                default_node = RewriteSymbolics().visit(default_node)
+                o = ast.literal_eval(default_node)
+            except ValueError:
+                o = invalid
+            if o is invalid:
+                return None
+            default = o if o is not invalid else default
+        parameters.append(Parameter(name, kind, default=default, annotation=empty))
+
+    # non-keyword-only parameters
+    args = reversed(f.args.args)
+    defaults = reversed(f.args.defaults)
+    iter = itertools.zip_longest(args, defaults, fillvalue=None)
+    for name, default in reversed(list(iter)):
+        p(name, default)
+
+    # *args
+    if f.args.vararg:
+        kind = Parameter.VAR_POSITIONAL
+        p(f.args.vararg, empty)
+
+    # keyword-only arguments
+    kind = Parameter.KEYWORD_ONLY
+    for name, default in zip(f.args.kwonlyargs, f.args.kw_defaults):
+        p(name, default)
+
+    # **kwargs
+    if f.args.kwarg:
+        kind = Parameter.VAR_KEYWORD
+        p(f.args.kwarg, empty)
+
+    if first_parameter_is_self:
+        assert parameters
+        if getattr(obj, '__self__', None):
+            # strip off self, it's already been bound
+            parameters.pop(0)
+        else:
+            # for builtins, self parameter is always positional-only!
+            p = parameters[0].replace(kind=Parameter.POSITIONAL_ONLY)
+            parameters[0] = p
+
+    return cls(parameters, return_annotation=cls.empty)
+
+
 def signature(obj):
     '''Get a signature object for the passed callable.'''
 
@@ -1725,14 +1858,41 @@ def signature(obj):
                     sig = signature(init)
 
         if sig is None:
+            # At this point we know, that `obj` is a class, with no user-
+            # defined '__init__', '__new__', or class-level '__call__'
+
+            for base in obj.__mro__:
+                # Since '__text_signature__' is implemented as a
+                # descriptor that extracts text signature from the
+                # class docstring, if 'obj' is derived from a builtin
+                # class, its own '__text_signature__' may be 'None'.
+                # Therefore, we go through the MRO to find the first
+                # class with non-empty text signature.
+                try:
+                    text_sig = base.__text_signature__
+                except AttributeError:
+                    pass
+                else:
+                    if text_sig:
+                        # If 'obj' class has a __text_signature__ attribute:
+                        # return a signature based on it
+                        return _signature_fromstr(Signature, obj, text_sig)
+
+            # No '__text_signature__' was found for the 'obj' class.
+            # Last option is to check if its '__init__' is
+            # object.__init__ or type.__init__.
             if type in obj.__mro__:
                 # 'obj' is a metaclass without user-defined __init__
-                # or __new__. Return a signature of 'type' builtin.
-                return signature(type)
+                # or __new__.
+                if obj.__init__ is type.__init__:
+                    # Return a signature of 'type' builtin.
+                    return signature(type)
             else:
                 # We have a class (not metaclass), but no user-defined
                 # __init__ or __new__ for it
-                return signature(object)
+                if obj.__init__ is object.__init__:
+                    # Return a signature of 'object' builtin.
+                    return signature(object)
 
     elif not isinstance(obj, _NonUserDefinedCallables):
         # An object with __call__
@@ -2196,134 +2356,7 @@ class Signature:
         if not s:
             raise ValueError("no signature found for builtin {!r}".format(func))
 
-        Parameter = cls._parameter_cls
-
-        if s.endswith("/)"):
-            kind = Parameter.POSITIONAL_ONLY
-            s = s[:-2] + ')'
-        else:
-            kind = Parameter.POSITIONAL_OR_KEYWORD
-
-        first_parameter_is_self = s.startswith("($")
-        if first_parameter_is_self:
-            s = '(' + s[2:]
-
-        s = "def foo" + s + ": pass"
-
-        try:
-            module = ast.parse(s)
-        except SyntaxError:
-            module = None
-
-        if not isinstance(module, ast.Module):
-            raise ValueError("{!r} builtin has invalid signature".format(func))
-
-        f = module.body[0]
-
-        parameters = []
-        empty = Parameter.empty
-        invalid = object()
-
-        module = None
-        module_dict = {}
-        module_name = getattr(func, '__module__', None)
-        if module_name:
-            module = sys.modules.get(module_name, None)
-            if module:
-                module_dict = module.__dict__
-        sys_module_dict = sys.modules
-
-        def parse_name(node):
-            assert isinstance(node, ast.arg)
-            if node.annotation != None:
-                raise ValueError("Annotations are not currently supported")
-            return node.arg
-
-        def wrap_value(s):
-            try:
-                value = eval(s, module_dict)
-            except NameError:
-                try:
-                    value = eval(s, sys_module_dict)
-                except NameError:
-                    raise RuntimeError()
-
-            if isinstance(value, str):
-                return ast.Str(value)
-            if isinstance(value, (int, float)):
-                return ast.Num(value)
-            if isinstance(value, bytes):
-                return ast.Bytes(value)
-            if value in (True, False, None):
-                return ast.NameConstant(value)
-            raise RuntimeError()
-
-        class RewriteSymbolics(ast.NodeTransformer):
-            def visit_Attribute(self, node):
-                a = []
-                n = node
-                while isinstance(n, ast.Attribute):
-                    a.append(n.attr)
-                    n = n.value
-                if not isinstance(n, ast.Name):
-                    raise RuntimeError()
-                a.append(n.id)
-                value = ".".join(reversed(a))
-                return wrap_value(value)
-
-            def visit_Name(self, node):
-                if not isinstance(node.ctx, ast.Load):
-                    raise ValueError()
-                return wrap_value(node.id)
-
-        def p(name_node, default_node, default=empty):
-            name = parse_name(name_node)
-            if name is invalid:
-                return None
-            if default_node and default_node is not _empty:
-                try:
-                    default_node = RewriteSymbolics().visit(default_node)
-                    o = ast.literal_eval(default_node)
-                except ValueError:
-                    o = invalid
-                if o is invalid:
-                    return None
-                default = o if o is not invalid else default
-            parameters.append(Parameter(name, kind, default=default, annotation=empty))
-
-        # non-keyword-only parameters
-        args = reversed(f.args.args)
-        defaults = reversed(f.args.defaults)
-        iter = itertools.zip_longest(args, defaults, fillvalue=None)
-        for name, default in reversed(list(iter)):
-            p(name, default)
-
-        # *args
-        if f.args.vararg:
-            kind = Parameter.VAR_POSITIONAL
-            p(f.args.vararg, empty)
-
-        # keyword-only arguments
-        kind = Parameter.KEYWORD_ONLY
-        for name, default in zip(f.args.kwonlyargs, f.args.kw_defaults):
-            p(name, default)
-
-        # **kwargs
-        if f.args.kwarg:
-            kind = Parameter.VAR_KEYWORD
-            p(f.args.kwarg, empty)
-
-        if first_parameter_is_self:
-            assert parameters
-            if getattr(func, '__self__', None):
-                # strip off self, it's already been bound
-                parameters.pop(0)
-            else:
-                # for builtins, self parameter is always positional-only!
-                p = parameters[0].replace(kind=Parameter.POSITIONAL_ONLY)
-                parameters[0] = p
-
-        return cls(parameters, return_annotation=cls.empty)
+        return _signature_fromstr(cls, func, s)
 
     @property
     def parameters(self):
index efed714a26c34dd23b27f9fbced170cfba709c45..12a315ebb15a558eb618c458453272211601ec1b 100644 (file)
@@ -2210,6 +2210,32 @@ class TestSignatureObject(unittest.TestCase):
         self.assertEqual(str(inspect.signature(D)),
                          '(object_or_name, bases, dict)')
 
+    @unittest.skipIf(MISSING_C_DOCSTRINGS,
+                     "Signature information for builtins requires docstrings")
+    def test_signature_on_builtin_class(self):
+        self.assertEqual(str(inspect.signature(_pickle.Pickler)),
+                         '(file, protocol=None, fix_imports=True)')
+
+        class P(_pickle.Pickler): pass
+        class EmptyTrait: pass
+        class P2(EmptyTrait, P): pass
+        self.assertEqual(str(inspect.signature(P)),
+                         '(file, protocol=None, fix_imports=True)')
+        self.assertEqual(str(inspect.signature(P2)),
+                         '(file, protocol=None, fix_imports=True)')
+
+        class P3(P2):
+            def __init__(self, spam):
+                pass
+        self.assertEqual(str(inspect.signature(P3)), '(spam)')
+
+        class MetaP(type):
+            def __call__(cls, foo, bar):
+                pass
+        class P4(P2, metaclass=MetaP):
+            pass
+        self.assertEqual(str(inspect.signature(P4)), '(foo, bar)')
+
     def test_signature_on_callable_objects(self):
         class Foo:
             def __call__(self, a):