diff --git a/flake8_future_import.py b/flake8_future_import.py index c2a22ef..481d809 100755 --- a/flake8_future_import.py +++ b/flake8_future_import.py @@ -13,29 +13,61 @@ except ImportError as e: argparse = e -from ast import NodeVisitor, Str, Module, parse +import ast __version__ = '0.4.3' -class FutureImportVisitor(NodeVisitor): +class FutureImportVisitor(ast.NodeVisitor): def __init__(self): super(FutureImportVisitor, self).__init__() self.future_imports = [] + self._uses_code = False + self._uses_print = False + self._uses_division = False + self._uses_import = False + self._uses_str_literals = False + self._uses_generators = False + self._uses_with = False + + def _is_print(self, node): + # python 2 + if hasattr(ast, 'Print') and isinstance(node, ast.Print): + return True + + # python 3 + if isinstance(node, ast.Call) and \ + isinstance(node.func, ast.Name) and \ + node.func.id == 'print': + return True + + return False def visit_ImportFrom(self, node): if node.module == '__future__': self.future_imports += [node] - - def visit_Expr(self, node): - if not isinstance(node.value, Str) or node.value.col_offset != 0: - self._uses_code = True + else: + self._uses_import = True def generic_visit(self, node): - if not isinstance(node, Module): + if not isinstance(node, ast.Module): self._uses_code = True + + if isinstance(node, ast.Str): + self._uses_str_literals = True + elif self._is_print(node): + self._uses_print = True + elif isinstance(node, ast.Div): + self._uses_division = True + elif isinstance(node, ast.Import): + self._uses_import = True + elif isinstance(node, ast.With): + self._uses_with = True + elif isinstance(node, ast.Yield): + self._uses_generators = True + super(FutureImportVisitor, self).generic_visit(node) @property @@ -94,6 +126,7 @@ class FutureImportChecker(Flake8Argparse): name = 'flake8-future-import' require_code = True min_version = False + require_used = False def __init__(self, tree, filename): self.tree = tree @@ -106,6 +139,8 @@ def add_arguments(cls, parser): parser.add_argument('--min-version', default=False, help='The minimum version supported so that it can ' 'ignore mandatory and non-existent features') + parser.add_argument('--require-used', action='store_true', + help='Only alert when relevant features are used') @classmethod def parse_options(cls, options): @@ -122,6 +157,7 @@ def parse_options(cls, options): 'like "A.B.C"'.format(options.min_version)) min_version += (0, ) * (max(3 - len(min_version), 0)) cls.min_version = min_version + cls.require_used = options.require_used def _generate_error(self, future_import, lineno, present): feature = FEATURES.get(future_import) @@ -156,10 +192,31 @@ def run(self): yield err present.add(alias.name) for name in FEATURES: - if name not in present: - err = self._generate_error(name, 1, False) - if err: - yield err + if name in present: + continue + + if self.require_used: + if name == 'print_function' and not visitor._uses_print: + continue + + if name == 'division' and not visitor._uses_division: + continue + + if name == 'absolute_import' and not visitor._uses_import: + continue + + if name == 'unicode_literals' and not visitor._uses_str_literals: + continue + + if name == 'generators' and not visitor._uses_generators: + continue + + if name == 'with_statement' and not visitor._uses_with: + continue + + err = self._generate_error(name, 1, False) + if err: + yield err def main(args): @@ -199,7 +256,7 @@ def main(args): has_errors = False for filename in args.files: with open(filename, 'rb') as f: - tree = parse(f.read(), filename=filename, mode='exec') + tree = ast.parse(f.read(), filename=filename, mode='exec') for line, char, msg, checker in FutureImportChecker(tree, filename).run(): if msg[:4] not in ignored: diff --git a/test_flake8_future_import.py b/test_flake8_future_import.py index 64aed86..3da9c11 100644 --- a/test_flake8_future_import.py +++ b/test_flake8_future_import.py @@ -349,5 +349,64 @@ class TestFeatures(TestCaseBase): """Verify that the features are up to date.""" +class FeatureDetectionTestCase(TestCaseBase): + + ALWAYS_MISSING = frozenset(('generator_stop', 'nested_scopes')) + + def check_code(self, code): + tree = ast.parse(code) + checker = flake8_future_import.FutureImportChecker(tree, 'fn') + checker.require_used = True + iterator = self.iterator(checker) + return self.check_result(iterator) + + def assert_errors(self, code, missing=None, forbidden=None): + missing = missing or set() + forbidden = forbidden or set() + + found_missing, found_forbidden, _ = self.check_code(code) + + self.assertEqual(missing, found_missing) + self.assertEqual(forbidden, found_forbidden) + + def test_no_code(self): + self.assert_errors('') + self.assert_errors('# comment only') + + def test_simple_statement(self): + self.assert_errors('1+1', missing=self.ALWAYS_MISSING) + + def test_print_function(self): + self.assert_errors('print(foo)', self.ALWAYS_MISSING | set(['print_function'])) + + def test_unicode_literals(self): + expected_missing = self.ALWAYS_MISSING | set(['unicode_literals']) + self.assert_errors('"foo"', expected_missing) + self.assert_errors('u"foo"', expected_missing) + self.assert_errors('r"foo"', expected_missing) + self.assert_errors('fn("foo")', expected_missing) + + def test_division(self): + # not division + self.assert_errors('a % b', self.ALWAYS_MISSING) + + expected_missing = self.ALWAYS_MISSING | set(['division']) + self.assert_errors('1 / 0', expected_missing) + self.assert_errors('1 / 2 / 1', expected_missing) + self.assert_errors('a /= b', expected_missing) + self.assert_errors('fn(3 / 2)', expected_missing) + + def test_absolute_import(self): + expected_missing = self.ALWAYS_MISSING | set(['absolute_import']) + self.assert_errors('import foo\npass', expected_missing) + self.assert_errors('from foo import bar\npass', expected_missing) + + def test_with_statement(self): + self.assert_errors('with foo: foo()', self.ALWAYS_MISSING | set(['with_statement'])) + + def test_generators(self): + self.assert_errors('def foo(): yield', self.ALWAYS_MISSING | set(['generators'])) + + if __name__ == '__main__': unittest.main()