explicit = False # Is this ignored by refactor.py -f all?
run_order = 5 # Fixers will be sorted by run order before execution
# Lower numbers will be run first.
+ _accept_type = None # [Advanced and not public] This tells RefactoringTool
+ # which node type to accept when there's not a pattern.
# Shortcut for access to Python grammar symbols
syms = pygram.python_symbols
# Local imports
from .. import fixer_base
-from os.path import dirname, join, exists, pathsep
+from os.path import dirname, join, exists, sep
from ..fixer_util import FromImport, syms, token
# so can't be a relative import.
if not exists(join(dirname(base_path), '__init__.py')):
return False
- for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']:
+ for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd']:
if exists(base_path + ext):
return True
return False
class FixImports(fixer_base.BaseFix):
- order = "pre" # Pre-order tree traversal
-
# This is overridden in fix_imports2.
mapping = MAPPING
"""
# Local imports
-from .. import fixer_base
-from ..fixer_util import Name, Number, is_probably_builtin
+from lib2to3 import fixer_base
+from lib2to3.fixer_util import is_probably_builtin
class FixLong(fixer_base.BaseFix):
PATTERN = "'long'"
- static_int = Name(u"int")
-
def transform(self, node, results):
if is_probably_builtin(node):
- new = self.static_int.clone()
- new.prefix = node.prefix
- return new
+ node.value = u"int"
+ node.changed()
class FixNe(fixer_base.BaseFix):
# This is so simple that we don't need the pattern compiler.
+ _accept_type = token.NOTEQUAL
+
def match(self, node):
# Override
- return node.type == token.NOTEQUAL and node.value == u"<>"
+ return node.value == u"<>"
def transform(self, node, results):
new = pytree.Leaf(token.NOTEQUAL, u"!=", prefix=node.prefix)
class FixNumliterals(fixer_base.BaseFix):
# This is so simple that we don't need the pattern compiler.
+ _accept_type = token.NUMBER
+
def match(self, node):
# Override
- return (node.type == token.NUMBER and
- (node.value.startswith(u"0") or node.value[-1] in u"Ll"))
+ return (node.value.startswith(u"0") or node.value[-1] in u"Ll")
def transform(self, node, results):
val = node.value
--- /dev/null
+"""Fixer for operator.{isCallable,sequenceIncludes}
+
+operator.isCallable(obj) -> hasattr(obj, '__call__')
+operator.sequenceIncludes(obj) -> operator.contains(obj)
+"""
+
+# Local imports
+from .. import fixer_base
+from ..fixer_util import Call, Name, String
+
+class FixOperator(fixer_base.BaseFix):
+
+ methods = "method=('isCallable'|'sequenceIncludes')"
+ func = "'(' func=any ')'"
+ PATTERN = """
+ power< module='operator'
+ trailer< '.' {methods} > trailer< {func} > >
+ |
+ power< {methods} trailer< {func} > >
+ """.format(methods=methods, func=func)
+
+ def transform(self, node, results):
+ method = results["method"][0]
+
+ if method.value == u"sequenceIncludes":
+ if "module" not in results:
+ # operator may not be in scope, so we can't make a change.
+ self.warning(node, "You should use operator.contains here.")
+ else:
+ method.value = u"contains"
+ method.changed()
+ elif method.value == u"isCallable":
+ if "module" not in results:
+ self.warning(node,
+ "You should use hasattr(%s, '__call__') here." %
+ results["func"].value)
+ else:
+ func = results["func"]
+ args = [func.clone(), String(u", "), String(u"'__call__'")]
+ return Call(Name(u"hasattr"), args, prefix=node.prefix)
)
-class FixPrint(fixer_base.ConditionalFix):
+class FixPrint(fixer_base.BaseFix):
PATTERN = """
simple_stmt< any* bare='print' any* > | print_stmt
"""
- skip_on = '__future__.print_function'
-
def transform(self, node, results):
assert results
- if self.should_skip(node):
- return
-
bare_print = results.get("bare")
if bare_print:
MAPPING = {'urllib': [
('urllib.request',
['URLOpener', 'FancyURLOpener', 'urlretrieve',
- '_urlopener', 'urlopen', 'urlcleanup']),
+ '_urlopener', 'urlopen', 'urlcleanup',
+ 'pathname2url', 'url2pathname']),
('urllib.parse',
['quote', 'quote_plus', 'unquote', 'unquote_plus',
- 'urlencode', 'pathname2url', 'url2pathname', 'splitattr',
- 'splithost', 'splitnport', 'splitpasswd', 'splitport',
- 'splitquery', 'splittag', 'splittype', 'splituser',
- 'splitvalue', ]),
+ 'urlencode', 'splitattr', 'splithost', 'splitnport',
+ 'splitpasswd', 'splitport', 'splitquery', 'splittag',
+ 'splittype', 'splituser', 'splitvalue', ]),
('urllib.error',
['ContentTooShortError'])],
'urllib2' : [
import sys
import os
+import difflib
import logging
import shutil
import optparse
from . import refactor
+
+def diff_texts(a, b, filename):
+ """Return a unified diff of two strings."""
+ a = a.splitlines()
+ b = b.splitlines()
+ return difflib.unified_diff(a, b, filename, filename,
+ "(original)", "(refactored)",
+ lineterm="")
+
+
class StdoutRefactoringTool(refactor.MultiprocessRefactoringTool):
"""
Prints output to stdout.
"""
- def __init__(self, fixers, options, explicit, nobackups):
+ def __init__(self, fixers, options, explicit, nobackups, show_diffs):
self.nobackups = nobackups
+ self.show_diffs = show_diffs
super(StdoutRefactoringTool, self).__init__(fixers, options, explicit)
def log_error(self, msg, *args, **kwargs):
if not self.nobackups:
shutil.copymode(backup, filename)
- def print_output(self, lines):
- for line in lines:
- print line
+ def print_output(self, old, new, filename, equal):
+ if equal:
+ self.log_message("No changes to %s", filename)
+ else:
+ self.log_message("Refactored %s", filename)
+ if self.show_diffs:
+ for line in diff_texts(old, new, filename):
+ print line
+
+
+def warn(msg):
+ print >> sys.stderr, "WARNING: %s" % (msg,)
def main(fixer_pkg, args=None):
parser.add_option("-l", "--list-fixes", action="store_true",
help="List available transformations (fixes/fix_*.py)")
parser.add_option("-p", "--print-function", action="store_true",
- help="Modify the grammar so that print() is a function")
+ help="DEPRECATED Modify the grammar so that print() is "
+ "a function")
parser.add_option("-v", "--verbose", action="store_true",
help="More verbose logging")
+ parser.add_option("--no-diffs", action="store_true",
+ help="Don't show diffs of the refactoring")
parser.add_option("-w", "--write", action="store_true",
help="Write back modified files")
parser.add_option("-n", "--nobackups", action="store_true", default=False,
# Parse command line arguments
refactor_stdin = False
options, args = parser.parse_args(args)
+ if not options.write and options.no_diffs:
+ warn("not writing files and not printing diffs; that's not very useful")
+ if options.print_function:
+ warn("-p is deprecated; "
+ "detection of from __future__ import print_function is automatic")
if not options.write and options.nobackups:
parser.error("Can't use -n without -w")
if options.list_fixes:
if not args:
return 0
if not args:
- print >>sys.stderr, "At least one file or directory argument required."
- print >>sys.stderr, "Use --help to show usage."
+ print >> sys.stderr, "At least one file or directory argument required."
+ print >> sys.stderr, "Use --help to show usage."
return 2
if "-" in args:
refactor_stdin = True
if options.write:
- print >>sys.stderr, "Can't write to stdin."
+ print >> sys.stderr, "Can't write to stdin."
return 2
# Set up logging handler
logging.basicConfig(format='%(name)s: %(message)s', level=level)
# Initialize the refactoring tool
- rt_opts = {"print_function" : options.print_function}
avail_fixes = set(refactor.get_fixers_from_package(fixer_pkg))
unwanted_fixes = set(fixer_pkg + ".fix_" + fix for fix in options.nofix)
explicit = set()
else:
requested = avail_fixes.union(explicit)
fixer_names = requested.difference(unwanted_fixes)
- rt = StdoutRefactoringTool(sorted(fixer_names), rt_opts, sorted(explicit),
- options.nobackups)
+ rt = StdoutRefactoringTool(sorted(fixer_names), None, sorted(explicit),
+ options.nobackups, not options.no_diffs)
# Refactor all files and directories passed as arguments
if not rt.errors:
import os
# Fairly local imports
-from .pgen2 import driver, literals, token, tokenize, parse
+from .pgen2 import driver, literals, token, tokenize, parse, grammar
# Really local imports
from . import pytree
node = nodes[0]
if node.type == token.STRING:
value = unicode(literals.evalString(node.value))
- return pytree.LeafPattern(content=value)
+ return pytree.LeafPattern(_type_of_literal(value), value)
elif node.type == token.NAME:
value = node.value
if value.isupper():
"TOKEN": None}
+def _type_of_literal(value):
+ if value[0].isalpha():
+ return token.NAME
+ elif value in grammar.opmap:
+ return grammar.opmap[value]
+ else:
+ return None
+
+
def pattern_convert(grammar, raw_node_info):
"""Converts raw node information to a Node or Leaf instance."""
type, value, context, children = raw_node_info
f.close()
self.__dict__.update(d)
+ def copy(self):
+ """
+ Copy the grammar.
+ """
+ new = self.__class__()
+ for dict_attr in ("symbol2number", "number2symbol", "dfas", "keywords",
+ "tokens", "symbol2label"):
+ setattr(new, dict_attr, getattr(self, dict_attr).copy())
+ new.labels = self.labels[:]
+ new.states = self.states[:]
+ new.start = self.start
+ return new
+
def report(self):
"""Dump the grammar tables to standard output, for debugging."""
from pprint import pprint
python_grammar = driver.load_grammar(_GRAMMAR_FILE)
+
python_symbols = Symbols(python_grammar)
+
+python_grammar_no_print_statement = python_grammar.copy()
+del python_grammar_no_print_statement.keywords["print"]
# Python imports
import os
import sys
-import difflib
import logging
import operator
-from collections import defaultdict
+import collections
+import StringIO
+import warnings
from itertools import chain
# Local imports
-from .pgen2 import driver, tokenize
+from .pgen2 import driver, tokenize, token
from . import pytree, pygram
fix_names.append(name[:-3])
return fix_names
-def get_head_types(pat):
+
+class _EveryNode(Exception):
+ pass
+
+
+def _get_head_types(pat):
""" Accepts a pytree Pattern Node and returns a set
of the pattern types which will match first. """
# NodePatters must either have no type and no content
# or a type and content -- so they don't get any farther
# Always return leafs
+ if pat.type is None:
+ raise _EveryNode
return set([pat.type])
if isinstance(pat, pytree.NegatedPattern):
if pat.content:
- return get_head_types(pat.content)
- return set([None]) # Negated Patterns don't have a type
+ return _get_head_types(pat.content)
+ raise _EveryNode # Negated Patterns don't have a type
if isinstance(pat, pytree.WildcardPattern):
# Recurse on each node in content
r = set()
for p in pat.content:
for x in p:
- r.update(get_head_types(x))
+ r.update(_get_head_types(x))
return r
raise Exception("Oh no! I don't understand pattern %s" %(pat))
-def get_headnode_dict(fixer_list):
+
+def _get_headnode_dict(fixer_list):
""" Accepts a list of fixers and returns a dictionary
of head node type --> fixer list. """
- head_nodes = defaultdict(list)
+ head_nodes = collections.defaultdict(list)
+ every = []
for fixer in fixer_list:
- if not fixer.pattern:
- head_nodes[None].append(fixer)
- continue
- for t in get_head_types(fixer.pattern):
- head_nodes[t].append(fixer)
- return head_nodes
+ if fixer.pattern:
+ try:
+ heads = _get_head_types(fixer.pattern)
+ except _EveryNode:
+ every.append(fixer)
+ else:
+ for node_type in heads:
+ head_nodes[node_type].append(fixer)
+ else:
+ if fixer._accept_type is not None:
+ head_nodes[fixer._accept_type].append(fixer)
+ else:
+ every.append(fixer)
+ for node_type in chain(pygram.python_grammar.symbol2number.itervalues(),
+ pygram.python_grammar.tokens):
+ head_nodes[node_type].extend(every)
+ return dict(head_nodes)
+
def get_fixers_from_package(pkg_name):
"""
_to_system_newlines = _identity
+def _detect_future_print(source):
+ have_docstring = False
+ gen = tokenize.generate_tokens(StringIO.StringIO(source).readline)
+ def advance():
+ tok = next(gen)
+ return tok[0], tok[1]
+ ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
+ try:
+ while True:
+ tp, value = advance()
+ if tp in ignore:
+ continue
+ elif tp == token.STRING:
+ if have_docstring:
+ break
+ have_docstring = True
+ elif tp == token.NAME:
+ if value == u"from":
+ tp, value = advance()
+ if tp != token.NAME and value != u"__future__":
+ break
+ tp, value = advance()
+ if tp != token.NAME and value != u"import":
+ break
+ tp, value = advance()
+ if tp == token.OP and value == u"(":
+ tp, value = advance()
+ while tp == token.NAME:
+ if value == u"print_function":
+ return True
+ tp, value = advance()
+ if tp != token.OP and value != u",":
+ break
+ tp, value = advance()
+ else:
+ break
+ else:
+ break
+ except StopIteration:
+ pass
+ return False
+
+
class FixerError(Exception):
"""A fixer could not be loaded."""
class RefactoringTool(object):
- _default_options = {"print_function": False}
+ _default_options = {}
CLASS_PREFIX = "Fix" # The prefix for fixer classes
FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
self.explicit = explicit or []
self.options = self._default_options.copy()
if options is not None:
+ if "print_function" in options:
+ warnings.warn("the 'print_function' option is deprecated",
+ DeprecationWarning)
self.options.update(options)
self.errors = []
self.logger = logging.getLogger("RefactoringTool")
self.fixer_log = []
self.wrote = False
- if self.options["print_function"]:
- del pygram.python_grammar.keywords["print"]
self.driver = driver.Driver(pygram.python_grammar,
convert=pytree.convert,
logger=self.logger)
self.pre_order, self.post_order = self.get_fixers()
- self.pre_order_heads = get_headnode_dict(self.pre_order)
- self.post_order_heads = get_headnode_dict(self.post_order)
+ self.pre_order_heads = _get_headnode_dict(self.pre_order)
+ self.post_order_heads = _get_headnode_dict(self.post_order)
self.files = [] # List of files that were or should be modified
msg = msg % args
self.logger.debug(msg)
- def print_output(self, lines):
- """Called with lines of output to give to the user."""
+ def print_output(self, old_text, new_text, filename, equal):
+ """Called with the old version, new version, and filename of a
+ refactored file."""
pass
def refactor(self, items, write=False, doctests_only=False):
dirnames.sort()
filenames.sort()
for name in filenames:
- if not name.startswith(".") and name.endswith("py"):
+ if not name.startswith(".") and \
+ os.path.splitext(name)[1].endswith("py"):
fullname = os.path.join(dirpath, name)
self.refactor_file(fullname, write, doctests_only)
# Modify dirnames in-place to remove subdirs with leading dots
An AST corresponding to the refactored input stream; None if
there were errors during the parse.
"""
+ if _detect_future_print(data):
+ self.driver.grammar = pygram.python_grammar_no_print_statement
try:
tree = self.driver.parse_string(data)
except Exception, err:
self.log_error("Can't parse %s: %s: %s",
name, err.__class__.__name__, err)
return
+ finally:
+ self.driver.grammar = pygram.python_grammar
self.log_debug("Refactoring %s", name)
self.refactor_tree(tree, name)
return tree
else:
tree = self.refactor_string(input, "<stdin>")
if tree and tree.was_changed:
- self.processed_file(str(tree), "<stdin>", input)
+ self.processed_file(unicode(tree), "<stdin>", input)
else:
self.log_debug("No changes in stdin")
if not fixers:
return
for node in traversal:
- for fixer in fixers[node.type] + fixers[None]:
+ for fixer in fixers[node.type]:
results = fixer.match(node)
if results:
new = fixer.transform(node, results)
- if new is not None and (new != node or
- str(new) != str(node)):
+ if new is not None:
node.replace(new)
node = new
old_text = self._read_python_source(filename)[0]
if old_text is None:
return
- if old_text == new_text:
+ equal = old_text == new_text
+ self.print_output(old_text, new_text, filename, equal)
+ if equal:
self.log_debug("No changes to %s", filename)
return
- self.print_output(diff_texts(old_text, new_text, filename))
if write:
self.write_file(new_text, filename, old_text, encoding)
else:
filename, lineno, err.__class__.__name__, err)
return block
if self.refactor_tree(tree, filename):
- new = str(tree).splitlines(True)
+ new = unicode(tree).splitlines(True)
# Undo the adjustment of the line numbers in wrap_toks() below.
clipped, new = new[:lineno-1], new[lineno-1:]
assert clipped == [u"\n"] * (lineno-1), clipped
else:
return super(MultiprocessRefactoringTool, self).refactor_file(
*args, **kwargs)
-
-
-def diff_texts(a, b, filename):
- """Return a unified diff of two strings."""
- a = a.splitlines()
- b = b.splitlines()
- return difflib.unified_diff(a, b, filename, filename,
- "(original)", "(refactored)",
- lineterm="")
#!/usr/bin/env python
-# -*- coding: iso-8859-1 -*-
-print u'ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞ'
+# -*- coding: utf-8 -*-
+print u'ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞ'
+
+def f(x):
+ print '%s\t-> α(%2i):%s β(%s)'
def setUp(self, fix_list=None, fixer_pkg="lib2to3", options=None):
if fix_list is None:
fix_list = [self.fixer]
- if options is None:
- options = {"print_function" : False}
self.refactor = support.get_refactorer(fixer_pkg, fix_list, options)
self.fixer_log = []
self.filename = u"<string>"
def assert_runs_after(self, *names):
fixes = [self.fixer]
fixes.extend(names)
- options = {"print_function" : False}
- r = support.get_refactorer("lib2to3", fixes, options)
+ r = support.get_refactorer("lib2to3", fixes)
(pre, post) = r.get_fixers()
n = "fix_" + self.fixer
if post and post[-1].__class__.__module__.endswith(n):
self.unchanged(s)
def test_idempotency_print_as_function(self):
- print_stmt = pygram.python_grammar.keywords.pop("print")
- try:
- s = """print(1, 1+1, 1+1+1)"""
- self.unchanged(s)
+ self.refactor.driver.grammar = pygram.python_grammar_no_print_statement
+ s = """print(1, 1+1, 1+1+1)"""
+ self.unchanged(s)
- s = """print()"""
- self.unchanged(s)
+ s = """print()"""
+ self.unchanged(s)
- s = """print('')"""
- self.unchanged(s)
- finally:
- pygram.python_grammar.keywords["print"] = print_stmt
+ s = """print('')"""
+ self.unchanged(s)
def test_1(self):
b = """print 1, 1+1, 1+1+1"""
a = """print(file=sys.stderr)"""
self.check(b, a)
- # With from __future__ import print_function
def test_with_future_print_function(self):
- # XXX: These tests won't actually do anything until the parser
- # is fixed so it won't crash when it sees print(x=y).
- # When #2412 is fixed, the try/except block can be taken
- # out and the tests can be run like normal.
- # MvL: disable entirely for now, so that it doesn't print to stdout
- return
- try:
- s = "from __future__ import print_function\n"\
- "print('Hai!', end=' ')"
- self.unchanged(s)
+ s = "from __future__ import print_function\n" \
+ "print('Hai!', end=' ')"
+ self.unchanged(s)
- b = "print 'Hello, world!'"
- a = "print('Hello, world!')"
- self.check(b, a)
+ b = "print 'Hello, world!'"
+ a = "print('Hello, world!')"
+ self.check(b, a)
- s = "from __future__ import *\n"\
- "print('Hai!', end=' ')"
- self.unchanged(s)
- except:
- return
- else:
- self.assertFalse(True, "#2421 has been fixed -- printing tests "\
- "need to be updated!")
class Test_exec(FixerTestCase):
fixer = "exec"
for key in ('dbhash', 'dumbdbm', 'dbm', 'gdbm'):
self.modules[key] = mapping1[key]
+ def test_after_local_imports_refactoring(self):
+ for fix in ("imports", "imports2"):
+ self.fixer = fix
+ self.assert_runs_after("import")
+
class Test_urllib(FixerTestCase):
fixer = "urllib"
s = "from itertools import foo"
self.unchanged(s)
+
class Test_import(FixerTestCase):
fixer = "import"
self.always_exists = False
self.present_files = set(['__init__.py'])
- expected_extensions = ('.py', os.path.pathsep, '.pyc', '.so',
- '.sl', '.pyd')
+ expected_extensions = ('.py', os.path.sep, '.pyc', '.so', '.sl', '.pyd')
names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py"))
for name in names_to_test:
self.present_files = set(["__init__.py", "bar.py"])
self.check(b, a)
+ def test_import_from_package(self):
+ b = "import bar"
+ a = "from . import bar"
+ self.always_exists = False
+ self.present_files = set(["__init__.py", "bar/"])
+ self.check(b, a)
+
def test_comments_and_indent(self):
b = "import bar # Foo"
a = "from . import bar # Foo"
b = """os.getcwdu ( )"""
a = """os.getcwd ( )"""
self.check(b, a)
+
+
+class Test_operator(FixerTestCase):
+
+ fixer = "operator"
+
+ def test_operator_isCallable(self):
+ b = "operator.isCallable(x)"
+ a = "hasattr(x, '__call__')"
+ self.check(b, a)
+
+ def test_operator_sequenceIncludes(self):
+ b = "operator.sequenceIncludes(x, y)"
+ a = "operator.contains(x, y)"
+ self.check(b, a)
+
+ def test_bare_isCallable(self):
+ s = "isCallable(x)"
+ self.warns_unchanged(s, "You should use hasattr(x, '__call__') here.")
+
+ def test_bare_sequenceIncludes(self):
+ s = "sequenceIncludes(x, y)"
+ self.warns_unchanged(s, "You should use operator.contains here.")
import operator
import StringIO
import tempfile
+import shutil
import unittest
+import warnings
from lib2to3 import refactor, pygram, fixer_base
+from lib2to3.pgen2 import token
from . import support
return refactor.RefactoringTool(fixers, options, explicit)
def test_print_function_option(self):
- gram = pygram.python_grammar
- save = gram.keywords["print"]
- try:
- rt = self.rt({"print_function" : True})
- self.assertRaises(KeyError, operator.itemgetter("print"),
- gram.keywords)
- finally:
- gram.keywords["print"] = save
+ with warnings.catch_warnings(record=True) as w:
+ refactor.RefactoringTool(_DEFAULT_FIXERS, {"print_function" : True})
+ self.assertEqual(len(w), 1)
+ msg, = w
+ self.assertTrue(msg.category is DeprecationWarning)
def test_fixer_loading_helpers(self):
contents = ["explicit", "first", "last", "parrot", "preorder"]
self.assertEqual(full_names,
["myfixes.fix_" + name for name in contents])
+ def test_detect_future_print(self):
+ run = refactor._detect_future_print
+ self.assertFalse(run(""))
+ self.assertTrue(run("from __future__ import print_function"))
+ self.assertFalse(run("from __future__ import generators"))
+ self.assertFalse(run("from __future__ import generators, feature"))
+ input = "from __future__ import generators, print_function"
+ self.assertTrue(run(input))
+ input ="from __future__ import print_function, generators"
+ self.assertTrue(run(input))
+ input = "from __future__ import (print_function,)"
+ self.assertTrue(run(input))
+ input = "from __future__ import (generators, print_function)"
+ self.assertTrue(run(input))
+ input = "from __future__ import (generators, nested_scopes)"
+ self.assertFalse(run(input))
+ input = """from __future__ import generators
+from __future__ import print_function"""
+ self.assertTrue(run(input))
+ self.assertFalse(run("from"))
+ self.assertFalse(run("from 4"))
+ self.assertFalse(run("from x"))
+ self.assertFalse(run("from x 5"))
+ self.assertFalse(run("from x im"))
+ self.assertFalse(run("from x import"))
+ self.assertFalse(run("from x import 4"))
+ input = "'docstring'\nfrom __future__ import print_function"
+ self.assertTrue(run(input))
+ input = "'docstring'\n'somng'\nfrom __future__ import print_function"
+ self.assertFalse(run(input))
+ input = "# comment\nfrom __future__ import print_function"
+ self.assertTrue(run(input))
+ input = "# comment\n'doc'\nfrom __future__ import print_function"
+ self.assertTrue(run(input))
+ input = "class x: pass\nfrom __future__ import print_function"
+ self.assertFalse(run(input))
+
def test_get_headnode_dict(self):
class NoneFix(fixer_base.BaseFix):
- PATTERN = None
+ pass
class FileInputFix(fixer_base.BaseFix):
PATTERN = "file_input< any * >"
+ class SimpleFix(fixer_base.BaseFix):
+ PATTERN = "'name'"
+
no_head = NoneFix({}, [])
with_head = FileInputFix({}, [])
- d = refactor.get_headnode_dict([no_head, with_head])
- expected = {None: [no_head],
- pygram.python_symbols.file_input : [with_head]}
- self.assertEqual(d, expected)
+ simple = SimpleFix({}, [])
+ d = refactor._get_headnode_dict([no_head, with_head, simple])
+ top_fixes = d.pop(pygram.python_symbols.file_input)
+ self.assertEqual(top_fixes, [with_head, no_head])
+ name_fixes = d.pop(token.NAME)
+ self.assertEqual(name_fixes, [simple, no_head])
+ for fixes in d.itervalues():
+ self.assertEqual(fixes, [no_head])
def test_fixer_loading(self):
from myfixes.fix_first import FixFirst
class MyRT(refactor.RefactoringTool):
- def print_output(self, lines):
- diff_lines.extend(lines)
+ def print_output(self, old_text, new_text, filename, equal):
+ results.extend([old_text, new_text, filename, equal])
- diff_lines = []
+ results = []
rt = MyRT(_DEFAULT_FIXERS)
save = sys.stdin
sys.stdin = StringIO.StringIO("def parrot(): pass\n\n")
rt.refactor_stdin()
finally:
sys.stdin = save
- expected = """--- <stdin> (original)
-+++ <stdin> (refactored)
-@@ -1,2 +1,2 @@
--def parrot(): pass
-+def cheese(): pass""".splitlines()
- self.assertEqual(diff_lines[:-1], expected)
+ expected = ["def parrot(): pass\n\n",
+ "def cheese(): pass\n\n",
+ "<stdin>", False]
+ self.assertEqual(results, expected)
def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS):
def read_file():
test_file = os.path.join(FIXER_DIR, "parrot_example.py")
self.check_file_refactoring(test_file, _DEFAULT_FIXERS)
+ def test_refactor_dir(self):
+ def check(structure, expected):
+ def mock_refactor_file(self, f, *args):
+ got.append(f)
+ save_func = refactor.RefactoringTool.refactor_file
+ refactor.RefactoringTool.refactor_file = mock_refactor_file
+ rt = self.rt()
+ got = []
+ dir = tempfile.mkdtemp(prefix="2to3-test_refactor")
+ try:
+ os.mkdir(os.path.join(dir, "a_dir"))
+ for fn in structure:
+ open(os.path.join(dir, fn), "wb").close()
+ rt.refactor_dir(dir)
+ finally:
+ refactor.RefactoringTool.refactor_file = save_func
+ shutil.rmtree(dir)
+ self.assertEqual(got,
+ [os.path.join(dir, path) for path in expected])
+ check([], [])
+ tree = ["nothing",
+ "hi.py",
+ ".dumb",
+ ".after.py",
+ "sappy"]
+ expected = ["hi.py"]
+ check(tree, expected)
+ tree = ["hi.py",
+ "a_dir/stuff.py"]
+ check(tree, tree)
+
def test_file_encoding(self):
fn = os.path.join(TEST_DATA_DIR, "different_encoding.py")
self.check_file_refactoring(fn)
-""" Test suite for the code in fixes.util """
+""" Test suite for the code in fixer_util """
# Testing imports
from . import support
import os.path
# Local imports
-from .. import pytree
-from .. import fixer_util
-from ..fixer_util import Attr, Name
-
+from lib2to3.pytree import Node, Leaf
+from lib2to3 import fixer_util
+from lib2to3.fixer_util import Attr, Name, Call, Comma
+from lib2to3.pgen2 import token
def parse(code, strip_levels=0):
# The topmost node is file_input, which we don't care about.
class MacroTestCase(support.TestCase):
def assertStr(self, node, string):
if isinstance(node, (tuple, list)):
- node = pytree.Node(fixer_util.syms.simple_stmt, node)
+ node = Node(fixer_util.syms.simple_stmt, node)
self.assertEqual(str(node), string)
self.assertStr(Name("a", prefix="b"), "ba")
+class Test_Call(MacroTestCase):
+ def _Call(self, name, args=None, prefix=None):
+ """Help the next test"""
+ children = []
+ if isinstance(args, list):
+ for arg in args:
+ children.append(arg)
+ children.append(Comma())
+ children.pop()
+ return Call(Name(name), children, prefix)
+
+ def test(self):
+ kids = [None,
+ [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
+ Leaf(token.NUMBER, 3)],
+ [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
+ Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
+ [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
+ ]
+ self.assertStr(self._Call("A"), "A()")
+ self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
+ self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
+ self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
+
+
class Test_does_tree_import(support.TestCase):
def _find_bind_rec(self, name, node):
# Search a tree for a binding -- used to find the starting