perf: massive performance improvements for project / sources manager (A…
antazoey authored Jul 24, 2024
1 parent 489e059 commit 09cb802
Showing 10 changed files with 128 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/test.yaml
@@ -100,6 +100,9 @@ jobs:
- name: Run Integration Tests
run: pytest tests/integration -m "not fuzzing" -s --cov=src --cov-append -n auto --dist loadgroup

- name: Run Performance Tests
run: pytest tests/performance -s

fuzzing:
runs-on: ubuntu-latest

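Note: the new step runs the benchmark suite added under tests/performance (see the new test file at the end of this diff). The same command, pytest tests/performance -s, works locally once the pytest-benchmark dev dependency added in setup.py below is installed.
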
1 change: 1 addition & 0 deletions setup.py
@@ -15,6 +15,7 @@
"pytest-xdist>=3.6.1,<4", # Multi-process runner
"pytest-cov>=4.0.0,<5", # Coverage analyzer plugin
"pytest-mock", # For creating mocks
"pytest-benchmark", # For performance tests
"pytest-timeout>=2.2.0,<3", # For avoiding timing out during tests
"hypothesis>=6.2.0,<7.0", # Strategy-based fuzzer
"hypothesis-jsonschema==0.19.0", # JSON Schema fuzzer extension
89 changes: 65 additions & 24 deletions src/ape/managers/project.py
@@ -57,6 +57,9 @@ class SourceManager(BaseManager):

_path_cache: Optional[list[Path]] = None

# perf: calculating paths from source Ids can be expensive.
_path_to_source_id: dict[Path, str] = {}

def __init__(
self,
root_path: Path,
@@ -113,7 +116,7 @@ def get(self, source_id: str) -> Optional[Source]:
text: Union[str, dict]
if path.is_file():
try:
text = path.read_text()
text = path.read_text(encoding="utf8")
except Exception:
continue

@@ -158,7 +161,7 @@ def __contains_path(self, source_path: Path) -> bool:

return False

@property
@cached_property
def _all_files(self) -> list[Path]:
try:
contracts_folder = self.get_contracts_path()
@@ -306,8 +309,23 @@ def find_in_dir(dir_path: Path, path: Path) -> Optional[Path]:
relative_path = input_path.relative_to(input_path.anchor)
return find_in_dir(self.root_path, relative_path)

def refresh(self):
"""
Reset file-caches to handle session-changes.
(Typically not needed to be called by users).
"""
(self.__dict__ or {}).pop("_all_files", None)
self._path_to_source_id = {}
self._path_cache = None

def _get_source_id(self, path: Path) -> str:
return _path_to_source_id(path, self.root_path)
if src_id := self._path_to_source_id.get(path):
return src_id

# Cache because this can be expensive.
src_id = _path_to_source_id(path, self.root_path)
self._path_to_source_id[path] = src_id
return src_id

def _get_path(self, source_id: str) -> Path:
return self.root_path / source_id
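Note: two caching strategies are combined above: _all_files moves from a plain @property to @cached_property (computed once per instance and stored in the instance __dict__), and path-to-source-id lookups are memoized in _path_to_source_id, with refresh() invalidating both. A minimal standalone sketch of the same idiom, using hypothetical names rather than Ape's actual classes:

    from functools import cached_property
    from pathlib import Path


    class FileIndex:
        def __init__(self, root: Path):
            self.root = root
            self._source_id_cache: dict[Path, str] = {}

        @cached_property
        def all_files(self) -> list[Path]:
            # Expensive directory scan; runs once per instance until refreshed.
            return [p for p in self.root.rglob("*.*") if p.is_file()]

        def source_id(self, path: Path) -> str:
            # Memoize the path -> source-id computation.
            if (cached := self._source_id_cache.get(path)) is not None:
                return cached
            src_id = str(path.relative_to(self.root))
            self._source_id_cache[path] = src_id
            return src_id

        def refresh(self) -> None:
            # Drop the cached_property value and the memo dict, mirroring SourceManager.refresh().
            self.__dict__.pop("all_files", None)
            self._source_id_cache = {}

This only illustrates the pattern; the real SourceManager also resets _path_cache and derives source IDs via _path_to_source_id(path, self.root_path).
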
@@ -350,7 +368,9 @@ def __iter__(self) -> Iterator[str]:

yield ct.name

def get(self, name: str, compile_missing: bool = True) -> Optional[ContractContainer]:
def get(
self, name: str, compile_missing: bool = True, check_for_changes: bool = True
) -> Optional[ContractContainer]:
"""
Get a contract by name.
@@ -359,23 +379,34 @@ def get(self, name: str, compile_missing: bool = True) -> Optional[ContractConta
compile_missing (bool): Set to ``False`` to not attempt compiling
if the contract can't be found. Note: modified sources are
re-compiled regardless of this flag.
check_for_changes (bool): Set to ``False`` if avoiding checking
for changes.
Returns:
ContractContainer | None
"""

existing_types = self.project.manifest.contract_types or {}
if contract_type := existing_types.get(name):
source_id = contract_type.source_id or ""
ext = get_full_extension(source_id)
contract_type = existing_types.get(name)

if not contract_type:
if compile_missing:
self._compile_missing_contracts(self.sources.paths)
return self.get(name, compile_missing=False)

return None

source_id = contract_type.source_id or ""
source_found = source_id in self.sources

if not check_for_changes and source_found:
return ContractContainer(contract_type)

# Allow us to still get previously-compiled contracts if don't
# have the compiler plugin installed at this time.
if ext not in self.compiler_manager.registered_compilers:
return ContractContainer(contract_type)
ext = get_full_extension(source_id)
if ext not in self.compiler_manager.registered_compilers:
return ContractContainer(contract_type)

elif source_id in self.sources and self._detect_change(source_id):
# Previous cache is outdated.
if source_found:
if check_for_changes and self._detect_change(source_id):
compiled = {
ct.name: ct
for ct in self.compiler_manager.compile(source_id, project=self.project)
@@ -386,12 +417,9 @@ def get(self, name: str, compile_missing: bool = True) -> Optional[ContractConta
if name in compiled:
return ContractContainer(compiled[name])

elif source_id in self.sources:
# Cached and already compiled.
return ContractContainer(contract_type)
return ContractContainer(contract_type)

if compile_missing:
# Try again after compiling all missing.
self._compile_missing_contracts(self.sources.paths)
return self.get(name, compile_missing=False)
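
Note: the reworked get() short-circuits when check_for_changes=False and the source is already known, returning the cached ContractContainer without touching the filesystem or the compiler; only the first lookup of a name (or a lookup after an actual source change) pays for change detection and recompilation. A hedged usage sketch, assuming a loaded project and hypothetical contract names:

    # First lookup: change detection runs and the source is recompiled if needed.
    container = project.contracts.get("MyContract")

    # Hot path: reuse the cached contract type, skipping change detection entirely.
    container = project.contracts.get("MyContract", check_for_changes=False)

    # Unknown names can still trigger compilation of all missing sources.
    maybe = project.contracts.get("NotCompiledYet", compile_missing=True)

The project.contracts accessor is the same one LocalProject.get_contract uses further down in this diff; the contract names are placeholders.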

@@ -2025,8 +2053,9 @@ def __init__(
manifest_path: Optional[Path] = None,
config_override: Optional[dict] = None,
) -> None:
# A local project uses a special manifest.
self._session_source_change_check: set[str] = set()
self.path = Path(path).resolve()
# A local project uses a special manifest.
self.manifest_path = manifest_path or self.path / ".build" / "__local__.json"
manifest = self.load_manifest()

@@ -2375,14 +2404,17 @@ def load_manifest(self) -> PackageManifest:
return manifest

def get_contract(self, name: str) -> Any:
if name in dir(self):
return self.__getattribute__(name)
if name in self._session_source_change_check:
check_for_changes = False
else:
check_for_changes = True
self._session_source_change_check.add(name)

elif contract := self.contracts.get(name):
contract = self.contracts.get(name, check_for_changes=check_for_changes)
if contract:
contract.base_path = self.path
return contract

return None
return contract

def update_manifest(self, **kwargs):
# Update the manifest in memory.
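
Note: get_contract now remembers, per session, which contract names have already been change-checked (via the _session_source_change_check set added to __init__ above), so only the first lookup of a name runs change detection. A brief sketch of the resulting behavior, assuming project is a loaded local project (e.g. Project(path) as in the conftest fixture below) and reusing the contract name from the new performance test:

    c1 = project.get_contract("Other")   # first call this session: change detection runs
    c2 = project.get_contract("Other")   # later calls: check skipped, cached type reused

    # After editing files mid-session, clear the session caches so changes are seen again.
    project.refresh_sources()
    c3 = project.get_contract("Other")   # change detection runs once more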
@@ -2461,6 +2493,15 @@ def reload_config(self):
self._clear_cached_config()
_ = self.config

def refresh_sources(self):
"""
Check for file-changes. Typically, you don't need to call this method.
This method exists for when changing files mid-session, you can "refresh"
and Ape will know about the changes.
"""
self._session_source_change_check = set()
self.sources.refresh()

def _clear_cached_config(self):
if "config" in self.__dict__:
del self.__dict__["config"]
59 changes: 29 additions & 30 deletions src/ape/utils/os.py
@@ -11,6 +11,8 @@
from typing import Any, Optional, Union


# TODO: This method is no longer needed since the dropping of 3.9
# Delete in Ape 0.9 release.
def is_relative_to(path: Path, target: Path) -> bool:
"""
Search a path and determine its relevancy.
@@ -22,15 +24,7 @@ def is_relative_to(path: Path, target: Path) -> bool:
Returns:
bool: ``True`` if the path is relative to the target path or ``False``.
"""
if hasattr(path, "is_relative_to"):
# NOTE: Only available ``>=3.9``
return target.is_relative_to(path) # type: ignore

else:
try:
return target.relative_to(path) is not None
except ValueError:
return False
return target.is_relative_to(path)
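
Note: since Path.is_relative_to is available on all supported Python versions, the helper now reduces to the standard-library call (hence the removal TODO above). For reference, Path("/a/b/c").is_relative_to(Path("/a")) is True and Path("/a/b").is_relative_to(Path("/x")) is False; mind that the wrapper's argument order is is_relative_to(path, target), which returns target.is_relative_to(path).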


def get_relative_path(target: Path, anchor: Path) -> Path:
@@ -52,15 +46,17 @@ def get_relative_path(target: Path, anchor: Path) -> Path:
if not anchor.is_absolute():
raise ValueError("'anchor' must be an absolute path.")

anchor_copy = Path(str(anchor))
levels_deep = 0
while not is_relative_to(anchor_copy, target):
levels_deep += 1
anchor_copy = anchor_copy.parent
# Calculate common prefix length
common_parts = 0
for target_part, anchor_part in zip(target.parts, anchor.parts):
if target_part == anchor_part:
common_parts += 1
else:
break

return Path("/".join(".." for _ in range(levels_deep))).joinpath(
str(target.relative_to(anchor_copy))
)
# Calculate the relative path
relative_parts = [".."] * (len(anchor.parts) - common_parts) + list(target.parts[common_parts:])
return Path(*relative_parts)
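
Note: the rewrite replaces the loop that repeatedly climbed the anchor's parents with a single pass over the shared path prefix. A quick worked example with hypothetical absolute paths (the function requires absolute inputs, per the checks above):

    target = Path("/projects/demo/contracts/Token.vy")
    anchor = Path("/projects/demo/scripts/deploy")
    # common prefix: ("/", "projects", "demo") -> 3 of anchor's 5 parts match,
    # so two ".." segments are emitted, then the remaining parts of target.
    get_relative_path(target, anchor)   # Path("../../contracts/Token.vy")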


def get_all_files_in_directory(
@@ -98,13 +94,15 @@ def get_all_files_in_directory(
elif pattern is not None:
pattern_obj = pattern

# is dir
result: list[Path] = []
for file in (p for p in path.rglob("*.*") if p.is_file()):
if (max_files is None or max_files is not None and len(result) < max_files) and (
pattern_obj is None or pattern_obj.match(file.name)
):
result.append(file)
append_result = result.append # Local variable for faster access
for file in path.rglob("*.*"):
if not file.is_file() or (pattern_obj is not None and not pattern_obj.match(file.name)):
continue

append_result(file)
if max_files is not None and len(result) >= max_files:
break

return result
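
Note: the rewritten loop skips non-files and non-matching names with a single continue and stops walking as soon as max_files results have been collected, rather than filtering the full listing afterwards. A hedged usage sketch (parameter names inferred from the code shown; the full signature sits outside this hunk):

    # Collect at most 100 Vyper sources under ./contracts.
    files = get_all_files_in_directory(Path("contracts"), pattern=r".*\.vy$", max_files=100)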

@@ -170,17 +168,18 @@ def get_full_extension(path: Union[Path, str]) -> str:
return ""

path = Path(path)
if path.is_dir():
if path.is_dir() or path.suffix == "":
return ""

parts = path.name.split(".")
start_idx = 2 if path.name.startswith(".") else 1
name = path.name
parts = name.split(".")

# NOTE: Handles when given just `.hiddenFile` since slice indices
# may exceed their bounds.
suffix = ".".join(parts[start_idx:])
if len(parts) > 2 and name.startswith("."):
return "." + ".".join(parts[2:])
elif len(parts) > 1:
return "." + ".".join(parts[1:])

return f".{suffix}" if suffix and f".{suffix}" != f"{path.name}" else ""
return ""
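
Note: the new branch-based logic keeps multi-part ("full") extensions while treating a leading dot as part of a hidden file's name rather than as an extension separator. Expected results, derived from the code above with illustrative file names:

    get_full_extension(Path("Token.vy"))       # ".vy"
    get_full_extension(Path("Token.t.sol"))    # ".t.sol"
    get_full_extension(Path(".hiddenFile"))    # ""  (no suffix at all)
    get_full_extension(Path(".hidden.json"))   # ".json"
    get_full_extension(Path("contracts"))      # ""  (no dot, or an existing directory)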


@contextmanager
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -17,6 +17,7 @@
from ape.contracts import ContractContainer
from ape.exceptions import APINotImplementedError, ProviderNotConnectedError, UnknownSnapshotError
from ape.logging import LogLevel, logger
from ape.managers.project import Project
from ape.pytest.config import ConfigWrapper
from ape.pytest.gas import GasTracker
from ape.types import AddressType, CurrencyValue
@@ -684,3 +685,8 @@ def vyper_contract_container(vyper_contract_type) -> ContractContainer:
@pytest.fixture(scope="session")
def shared_contracts_folder():
return SHARED_CONTRACTS_FOLDER


@pytest.fixture
def project_with_contracts(with_dependencies_project_path):
return Project(with_dependencies_project_path)
5 changes: 5 additions & 0 deletions tests/functional/test_dependencies.py
@@ -644,6 +644,11 @@ def test_compile(self, project):
'[{"name":"foo","type":"fallback", "stateMutability":"nonpayable"}]',
encoding="utf8",
)

# Since we are adding a file mid-session, we have to refresh so
# it's picked up. Users typically don't have to do this.
dependency.project.refresh_sources()

result = dependency.compile()
assert len(result) == 1
assert result["CCC"].name == "CCC"
11 changes: 6 additions & 5 deletions tests/functional/test_project.py
@@ -18,11 +18,6 @@
from tests.conftest import skip_if_plugin_installed


@pytest.fixture
def project_with_contracts(with_dependencies_project_path):
return Project(with_dependencies_project_path)


@pytest.fixture
def tmp_project(with_dependencies_project_path):
real_project = Project(with_dependencies_project_path)
@@ -383,6 +378,10 @@ def test_load_contracts_after_deleting_same_named_contract(tmp_project, compiler
# Goodbye.
init_contract.unlink()

# Since we are changing files mid-session, we need to refresh the project.
# Typically, users don't have to do this.
tmp_project.refresh_sources()

result = tmp_project.load_contracts()
assert "foo" not in result # Was deleted.
# Also ensure it is gone from paths.
@@ -391,6 +390,8 @@
# Create a new contract with the same name.
new_contract = tmp_project.contracts_folder / "bar.__mock__"
new_contract.write_text("BAZ", encoding="utf8")
tmp_project.refresh_sources()

mock_compiler.overrides = {"contractName": "foo"}
result = tmp_project.load_contracts()
assert "foo" in result
6 changes: 3 additions & 3 deletions tests/integration/cli/test_test.py
@@ -178,7 +178,7 @@ def test_test(setup_pytester, integ_project, pytester, eth_tester_provider):
from ape.logging import logger

logger.set_level("DEBUG")
result = pytester.runpytest_subprocess()
result = pytester.runpytest_subprocess(timeout=30)
try:
result.assert_outcomes(passed=passed, failed=failed), "\n".join(result.outlines)
except ValueError:
@@ -189,7 +189,7 @@
def test_uncaught_txn_err(setup_pytester, integ_project, pytester, eth_tester_provider):
_ = eth_tester_provider # Ensure using EthTester for this test.
setup_pytester(integ_project)
result = pytester.runpytest_subprocess()
result = pytester.runpytest_subprocess(timeout=30)
expected = """
contract_in_test.setNumber(5, sender=owner)
E ape.exceptions.ContractLogicError: Transaction failed.
@@ -280,7 +280,7 @@ def test_gas_when_estimating(geth_provider, setup_pytester, integ_project, pytes
geth_account.transfer(geth_account, "1 wei") # Force a clean block.
with integ_project.temp_config(**cfg):
passed, failed = setup_pytester(integ_project)
result = pytester.runpytest_subprocess()
result = pytester.runpytest_subprocess(timeout=30)
run_gas_test(result, passed, failed)


Empty file added tests/performance/__init__.py
10 changes: 10 additions & 0 deletions tests/performance/test_project.py
@@ -0,0 +1,10 @@
def test_get_contract(benchmark, project_with_contracts):
benchmark.pedantic(
lambda *args, **kwargs: project_with_contracts.get_contract(*args, **kwargs),
args=(("Other",),),
rounds=5,
warmup_rounds=1, # It's always slower the first time, a little bit.
)
stats = benchmark.stats
median = stats.get("median")
assert median < 0.0002
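
Note: the benchmark uses pytest-benchmark's pedantic mode to call get_contract("Other") for five rounds with one warm-up round and asserts the median stays under 0.0002 seconds, acting as a regression guard for the caching added in this commit. For comparison, a simpler non-pedantic variant of the same fixture could look roughly like this sketch:

    def test_get_contract_simple(benchmark, project_with_contracts):
        # Call the target repeatedly and let pytest-benchmark collect the timings.
        benchmark(project_with_contracts.get_contract, "Other")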
