Skip to content

Commit

Permalink
Fix bug that caused import errors when flake8 was executed as a stand…
Browse files Browse the repository at this point in the history
…alone.
  • Loading branch information
atollk authored Aug 25, 2023
1 parent 6a9a698 commit 6154605
Show file tree
Hide file tree
Showing 16 changed files with 55 additions and 22 deletions.
43 changes: 24 additions & 19 deletions flake8_import_restrictions/imports_submodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,32 @@ def imports_submodule(
:return None, if an error occurs, e.g. the given file is not part of any directory in sys.path. Otherwise,
a bool is returned that is True if and only if the imported object is a module.
"""
if level > 0:
try:
filename = os.path.dirname(_rel_to_sys_path(filename))
except TypeError:
return None
package = ".".join(filename.split(os.path.sep))
else:
package = None

old_sys_path = sys.path
try:
parent = importlib.import_module("." * level + from_, package)
except (ImportError, TypeError):
return None
if not hasattr(parent, import_):
sys.path += [os.getcwd()]
if level > 0:
try:
filename = os.path.dirname(_rel_to_sys_path(filename))
except TypeError:
return None
package = ".".join(filename.split(os.path.sep))
else:
package = None

try:
importlib.import_module(
"." * level + from_ + "." + import_, package
)
except ImportError:
return False
return isinstance(getattr(parent, import_), types.ModuleType)
parent = importlib.import_module("." * level + from_, package)
except (ImportError, TypeError):
return None
if not hasattr(parent, import_):
try:
importlib.import_module(
"." * level + from_ + "." + import_, package
)
except ImportError:
return False
return isinstance(getattr(parent, import_), types.ModuleType)
finally:
sys.path = old_sys_path


def _rel_to_sys_path(path: str) -> str:
Expand Down
1 change: 1 addition & 0 deletions tests/test_imports_submodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_absolute_1():


def test_absolute_2():
assert imports_submodule(FILE1, 0, "tests.resources", "a")
assert imports_submodule(FILE1, 0, "tests.resources", "b")
assert not imports_submodule(FILE1, 0, "tests.resources.b", "B")
assert not imports_submodule(FILE1, 0, "tests.resources.b", "C")
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/test_i2041.py → tests/test_imr241.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,17 @@ def test_fail_2(self):
"""
result = self.run_flake8(code)
self.assert_error_at(result, "IMR241", 2, 1)

def test_issue_14(self):
main_code = """
from test.test2 import testmodule
from test import test2
"""
files = {
"main.py": main_code,
"test/__init__.py": "",
"test/test2/__init__.py": "",
"test/test2/testmodule.py": "",
}
result = self.run_flake8_multifile(files)
assert result == []
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
19 changes: 16 additions & 3 deletions tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import dataclasses
import re
import textwrap
from typing import List

from typing import List, Dict
import pytest_flake8_path
import pytest


Expand All @@ -27,9 +27,22 @@ def error_code(self) -> str:
raise NotImplementedError

@pytest.fixture(autouse=True)
def _flake8dir(self, flake8_path):
def _flake8dir(self, flake8_path: pytest_flake8_path.Flake8Path):
self.flake8_path = flake8_path

def run_flake8_multifile(self, files: Dict[str, str]):
for fname, code in files.items():
(self.flake8_path / fname).parent.mkdir(parents=True, exist_ok=True)
(self.flake8_path / fname).write_text(textwrap.dedent(code))
args = [f"--{self.error_code().lower()}_include=*", "--select=IMR"]
result = self.flake8_path.run_flake8(args)
reports = [
ReportedMessage.from_raw(report) for report in result.out_lines
]
return [
report for report in reports if report.code == self.error_code()
]

def run_flake8(self, code: str) -> List[ReportedMessage]:
(self.flake8_path / "example.py").write_text(textwrap.dedent(code))
args = [f"--{self.error_code().lower()}_include=*", "--select=IMR"]
Expand Down

0 comments on commit 6154605

Please sign in to comment.