From 5665aba555c798c48a5cd230facc20aa3727f573 Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 31 May 2024 20:31:34 -0500 Subject: [PATCH] fix: issue where call coverage didn't increment (#2105) --- src/ape/pytest/coverage.py | 47 ++++++++++------ tests/functional/conftest.py | 8 ++- tests/functional/test_coverage.py | 89 ++++++++++++++++++++++++++++--- 3 files changed, 122 insertions(+), 22 deletions(-) diff --git a/src/ape/pytest/coverage.py b/src/ape/pytest/coverage.py index 39fc300418..d21a928377 100644 --- a/src/ape/pytest/coverage.py +++ b/src/ape/pytest/coverage.py @@ -7,6 +7,7 @@ from ethpm_types.source import ContractSource from ape.logging import logger +from ape.managers import ProjectManager from ape.pytest.config import ConfigWrapper from ape.types import ( ContractFunctionPath, @@ -22,8 +23,8 @@ class CoverageData(ManagerAccessMixin): - def __init__(self, base_path: Path, sources: Iterable[ContractSource]): - self.base_path = base_path + def __init__(self, project: ProjectManager, sources: Iterable[ContractSource]): + self.project = project self.sources = list(sources) self._report: Optional[CoverageReport] = None self._init_coverage_profile() # Inits self._report. @@ -43,7 +44,7 @@ def _init_coverage_profile( self, ) -> CoverageReport: # source_id -> pc(s) -> times hit - project_coverage = CoverageProject(name=self.config_manager.name or "__local__") + project_coverage = CoverageProject(name=self.project.name or "__local__") for src in self.sources: source_cov = project_coverage.include(src) @@ -60,11 +61,11 @@ def _init_coverage_profile( timestamp = get_current_timestamp_ms() report = CoverageReport( projects=[project_coverage], - source_folders=[self.local_project.contracts_folder], + source_folders=[self.project.contracts_folder], timestamp=timestamp, ) - # Remove emptys. + # Remove empties. for project in report.projects: project.sources = [x for x in project.sources if len(x.statements) > 0] @@ -74,7 +75,11 @@ def _init_coverage_profile( def cover( self, src_path: Path, pcs: Iterable[int], inc_fn_hits: bool = True ) -> tuple[set[int], list[str]]: - source_id = str(get_relative_path(src_path.absolute(), self.base_path)) + if hasattr(self.project, "path"): + source_id = str(get_relative_path(src_path.absolute(), self.project.path)) + else: + source_id = str(src_path) + if source_id not in self.report.sources: # The source is not tracked for coverage. return set(), [] @@ -120,14 +125,27 @@ def cover( class CoverageTracker(ManagerAccessMixin): - def __init__(self, config_wrapper: ConfigWrapper): + def __init__( + self, + config_wrapper: ConfigWrapper, + project: Optional[ProjectManager] = None, + output_path: Optional[Path] = None, + ): self.config_wrapper = config_wrapper - sources = self.local_project._contract_sources + self._project = project or self.local_project + + if output_path: + self._output_path = output_path + elif hasattr(self._project, "manifest_path"): + # Local project. + self._output_path = self._project.manifest_path.parent + else: + self._output_path = Path.cwd() + + sources = self._project._contract_sources self.data: Optional[CoverageData] = ( - CoverageData(self.local_project.path, sources) - if self.config_wrapper.track_coverage - else None + CoverageData(self._project, sources) if self.config_wrapper.track_coverage else None ) @property @@ -180,7 +198,7 @@ def cover( for src in project.sources: # NOTE: We will allow this check to skip if there is no source is the # traceback. This helps increment methods that are missing from the source map. - path = self.local_project.contracts_folder / src.source_id + path = self._project.path / src.source_id if source_path is not None and path != source_path: continue @@ -279,7 +297,6 @@ def show_session_coverage(self) -> bool: # Reports are set in ape-config.yaml. reports = self.config_wrapper.ape_test_config.coverage.reports - out_folder = self.local_project.manifest_path.parent if reports.terminal: verbose = ( reports.terminal.get("verbose", False) @@ -308,9 +325,9 @@ def show_session_coverage(self) -> bool: click.echo() if self.config_wrapper.xml_coverage: - self.data.report.write_xml(out_folder) + self.data.report.write_xml(self._output_path) if value := self.config_wrapper.html_coverage: verbose = value.get("verbose", False) if isinstance(value, dict) else False - self.data.report.write_html(out_folder, verbose=verbose) + self.data.report.write_html(self._output_path, verbose=verbose) return True diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index 19a0fc2ed8..47785204d2 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -680,6 +680,8 @@ def mock_compiler(mocker): mock.name = "mock" mock.ext = ".__mock__" mock.tracked_settings = [] + mock.ast = None + mock.pcmap = None def mock_compile(paths, project=None, settings=None): settings = settings or {} @@ -691,10 +693,14 @@ def mock_compile(paths, project=None, settings=None): code = HexBytes(123).hex() data = { "contractName": name, - "abi": [], + "abi": mock.abi, "deploymentBytecode": code, "sourceId": f"{project.contracts_folder.name}/{path.name}", } + if ast := mock.ast: + data["ast"] = ast + if pcmap := mock.pcmap: + data["pcmap"] = pcmap # Check for mocked overrides overrides = mock.overrides diff --git a/tests/functional/test_coverage.py b/tests/functional/test_coverage.py index 2bfaf191da..2b9eb5a68c 100644 --- a/tests/functional/test_coverage.py +++ b/tests/functional/test_coverage.py @@ -1,10 +1,13 @@ from pathlib import Path import pytest +from ethpm_types import MethodABI from ethpm_types.source import ContractSource, Source +import ape from ape.pytest.config import ConfigWrapper from ape.pytest.coverage import CoverageData, CoverageTracker +from ape.types import SourceTraceback from ape.types.coverage import ( ContractCoverage, ContractSourceCoverage, @@ -172,7 +175,7 @@ def contract_source(self, vyper_contract_type, src): @pytest.fixture(scope="class") def coverage_data(self, project, contract_source): - return CoverageData(project.path, (contract_source,)) + return CoverageData(project, (contract_source,)) def test_report(self, coverage_data): actual = coverage_data.report @@ -184,13 +187,87 @@ class TestCoverageTracker: def pytest_config(self, mocker): return mocker.MagicMock() - @pytest.fixture(scope="class") + @pytest.fixture def config_wrapper(self, pytest_config): return ConfigWrapper(pytest_config) - def test_data(self, pytest_config): - tracker = CoverageTracker(pytest_config) + @pytest.fixture + def tracker(self, pytest_config): + return CoverageTracker(pytest_config) + + def test_data(self, tracker): assert tracker.data is not None - actual = tracker.data.base_path - expected = tracker.local_project.path + actual = tracker.data.project + expected = tracker.local_project assert actual == expected + + def test_cover(self, mocker, pytest_config, compilers, mock_compiler): + """ + Ensure coverage of a call works. + """ + filestem = "atest" + filename = f"{filestem}.__mock__" + fn_name = "_a_method" + + # Set up the mock compiler. + mock_compiler.abi = [MethodABI(name=fn_name)] + mock_compiler.ast = { + "src": "0:112:0", + "name": filename, + "end_lineno": 7, + "lineno": 1, + "ast_type": "Module", + } + mock_compiler.pcmap = {"0": {"location": (1, 7, 1, 7)}} + mock_contract = mocker.MagicMock() + mock_contract.name = filename + mock_statement = mocker.MagicMock() + mock_statement.pcs = {20} + mock_statement.hit_count = 0 + mock_function = mocker.MagicMock() + mock_function.name = fn_name + mock_function.statements = [mock_statement] + mock_contract.functions = [mock_function] + mock_contract.statements = [mock_statement] + + def init_profile(source_cov, src): + source_cov.contracts = [mock_contract] + + mock_compiler.init_coverage_profile.side_effect = init_profile + + stmt = {"type": "dev: Cannot send ether to non-payable function", "pcs": [20]} + fn_name = "_a_method" + tb_data = { + "statements": [stmt], + "closure": {"name": fn_name, "full_name": f"{fn_name}()"}, + "depth": 0, + } + + with ape.Project.create_temporary_project() as tmp: + # Create a source file. + file = tmp.path / "contracts" / filename + file.parent.mkdir(exist_ok=True, parents=True) + file.write_text("testing") + + # Ensure the TB refers to this source. + tb_data["source_path"] = f"{tmp.path}/contracts/{filename}" + call_tb = SourceTraceback.model_validate([tb_data]) + + try: + # Hack in our mock compiler. + _ = compilers.registered_compilers # Ensure cache is exists. + compilers.__dict__["registered_compilers"][mock_compiler.ext] = mock_compiler + + # Ensure our coverage tracker is using our new tmp project w/ the new src + # as well is set _after_ our new compiler plugin is added. + tracker = CoverageTracker(pytest_config, project=tmp) + + tracker.cover(call_tb, contract=filestem, function=f"{fn_name}()") + assert mock_statement.hit_count > 0 + + finally: + if ( + "registered_compilers" in compilers.__dict__ + and mock_compiler.ext in compilers.__dict__["registered_compilers"] + ): + del compilers.__dict__["registered_compilers"][mock_compiler.ext]