Skip to content

Commit

Permalink
intelligently complain about python 3 imports
Browse files Browse the repository at this point in the history
Don't complain all the time about print, division, etc. Instead, just
complain when the module in question actually uses a language feature
that would be affected.

travis rebuild
  • Loading branch information
terite committed Mar 28, 2017
1 parent 43eac27 commit 597e20e
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 12 deletions.
81 changes: 69 additions & 12 deletions flake8_future_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,61 @@
except ImportError as e:
argparse = e

from ast import NodeVisitor, Str, Module, parse
import ast

__version__ = '0.4.3'


class FutureImportVisitor(NodeVisitor):
class FutureImportVisitor(ast.NodeVisitor):

def __init__(self):
super(FutureImportVisitor, self).__init__()
self.future_imports = []

self._uses_code = False
self._uses_print = False
self._uses_division = False
self._uses_import = False
self._uses_str_literals = False
self._uses_generators = False
self._uses_with = False

def _is_print(self, node):
# python 2
if hasattr(ast, 'Print') and isinstance(node, ast.Print):
return True

# python 3
if isinstance(node, ast.Call) and \
isinstance(node.func, ast.Name) and \
node.func.id == 'print':
return True

return False

def visit_ImportFrom(self, node):
if node.module == '__future__':
self.future_imports += [node]

def visit_Expr(self, node):
if not isinstance(node.value, Str) or node.value.col_offset != 0:
self._uses_code = True
else:
self._uses_import = True

def generic_visit(self, node):
if not isinstance(node, Module):
if not isinstance(node, ast.Module):
self._uses_code = True

if isinstance(node, ast.Str):
self._uses_str_literals = True
elif self._is_print(node):
self._uses_print = True
elif isinstance(node, ast.Div):
self._uses_division = True
elif isinstance(node, ast.Import):
self._uses_import = True
elif isinstance(node, ast.With):
self._uses_with = True
elif isinstance(node, ast.Yield):
self._uses_generators = True

super(FutureImportVisitor, self).generic_visit(node)

@property
Expand Down Expand Up @@ -94,6 +126,7 @@ class FutureImportChecker(Flake8Argparse):
name = 'flake8-future-import'
require_code = True
min_version = False
require_used = False

def __init__(self, tree, filename):
self.tree = tree
Expand All @@ -106,6 +139,8 @@ def add_arguments(cls, parser):
parser.add_argument('--min-version', default=False,
help='The minimum version supported so that it can '
'ignore mandatory and non-existent features')
parser.add_argument('--require-used', action='store_true',
help='Only alert when relevant features are used')

@classmethod
def parse_options(cls, options):
Expand All @@ -122,6 +157,7 @@ def parse_options(cls, options):
'like "A.B.C"'.format(options.min_version))
min_version += (0, ) * (max(3 - len(min_version), 0))
cls.min_version = min_version
cls.require_used = options.require_used

def _generate_error(self, future_import, lineno, present):
feature = FEATURES.get(future_import)
Expand Down Expand Up @@ -156,10 +192,31 @@ def run(self):
yield err
present.add(alias.name)
for name in FEATURES:
if name not in present:
err = self._generate_error(name, 1, False)
if err:
yield err
if name in present:
continue

if self.require_used:
if name == 'print_function' and not visitor._uses_print:
continue

if name == 'division' and not visitor._uses_division:
continue

if name == 'absolute_import' and not visitor._uses_import:
continue

if name == 'unicode_literals' and not visitor._uses_str_literals:
continue

if name == 'generators' and not visitor._uses_generators:
continue

if name == 'with_statement' and not visitor._uses_with:
continue

err = self._generate_error(name, 1, False)
if err:
yield err


def main(args):
Expand Down Expand Up @@ -199,7 +256,7 @@ def main(args):
has_errors = False
for filename in args.files:
with open(filename, 'rb') as f:
tree = parse(f.read(), filename=filename, mode='exec')
tree = ast.parse(f.read(), filename=filename, mode='exec')
for line, char, msg, checker in FutureImportChecker(tree,
filename).run():
if msg[:4] not in ignored:
Expand Down
59 changes: 59 additions & 0 deletions test_flake8_future_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,5 +349,64 @@ class TestFeatures(TestCaseBase):
"""Verify that the features are up to date."""


class FeatureDetectionTestCase(TestCaseBase):

ALWAYS_MISSING = frozenset(('generator_stop', 'nested_scopes'))

def check_code(self, code):
tree = ast.parse(code)
checker = flake8_future_import.FutureImportChecker(tree, 'fn')
checker.require_used = True
iterator = self.iterator(checker)
return self.check_result(iterator)

def assert_errors(self, code, missing=None, forbidden=None):
missing = missing or set()
forbidden = forbidden or set()

found_missing, found_forbidden, _ = self.check_code(code)

self.assertEqual(missing, found_missing)
self.assertEqual(forbidden, found_forbidden)

def test_no_code(self):
self.assert_errors('')
self.assert_errors('# comment only')

def test_simple_statement(self):
self.assert_errors('1+1', missing=self.ALWAYS_MISSING)

def test_print_function(self):
self.assert_errors('print(foo)', self.ALWAYS_MISSING | set(['print_function']))

def test_unicode_literals(self):
expected_missing = self.ALWAYS_MISSING | set(['unicode_literals'])
self.assert_errors('"foo"', expected_missing)
self.assert_errors('u"foo"', expected_missing)
self.assert_errors('r"foo"', expected_missing)
self.assert_errors('fn("foo")', expected_missing)

def test_division(self):
# not division
self.assert_errors('a % b', self.ALWAYS_MISSING)

expected_missing = self.ALWAYS_MISSING | set(['division'])
self.assert_errors('1 / 0', expected_missing)
self.assert_errors('1 / 2 / 1', expected_missing)
self.assert_errors('a /= b', expected_missing)
self.assert_errors('fn(3 / 2)', expected_missing)

def test_absolute_import(self):
expected_missing = self.ALWAYS_MISSING | set(['absolute_import'])
self.assert_errors('import foo\npass', expected_missing)
self.assert_errors('from foo import bar\npass', expected_missing)

def test_with_statement(self):
self.assert_errors('with foo: foo()', self.ALWAYS_MISSING | set(['with_statement']))

def test_generators(self):
self.assert_errors('def foo(): yield', self.ALWAYS_MISSING | set(['generators']))


if __name__ == '__main__':
unittest.main()

0 comments on commit 597e20e

Please sign in to comment.