diff --git a/src/ape/managers/project.py b/src/ape/managers/project.py index 9759f07487..e4ea3b0b85 100644 --- a/src/ape/managers/project.py +++ b/src/ape/managers/project.py @@ -882,6 +882,69 @@ def remove(self, package_id: str, version: str): manifest_file.unlink(missing_ok=True) +def _version_to_options(version: str) -> tuple[str, ...]: + if version.startswith("v"): + # with the v, without + return (version, version[1:]) + + elif version and version[0].isnumeric(): + # without the v, and with. + return (version, f"v{version}") + + return (version,) + + +class DependencyVersionMap(dict[str, "ProjectManager"]): + """ + A mapping of versions to dependencies. + This class exists to allow both v-prefixed versions + as well none v-prefixed versions. + """ + + def __init__(self, name: str): + self._name = name + + @log_instead_of_fail(default="") + def __repr__(self) -> str: + keys = ",".join(list(self.keys())) + return f"<{self._name} versions='{keys}'>" + + def __contains__(self, version: Any) -> bool: + if not isinstance(version, str): + return False + + options = _version_to_options(version) + return any(dict.__contains__(self, v) for v in options) # type: ignore + + def __getitem__(self, version: str) -> "ProjectManager": + options = _version_to_options(version) + for vers in options: + if not dict.__contains__(self, vers): # type: ignore + continue + + # Found. + return dict.__getitem__(self, vers) # type: ignore + + raise KeyError(version) + + def get( # type: ignore + self, version: str, default: Optional["ProjectManager"] = None + ) -> Optional["ProjectManager"]: + options = _version_to_options(version) + for vers in options: + if not dict.__contains__(self, vers): # type: ignore + continue + + # Found. + return dict.get(self, vers) # type: ignore + + return default + + def extend(self, data: dict): + for key, val in data.items(): + self[key] = val + + class DependencyManager(BaseManager): """ Manage dependencies for an Ape project. @@ -912,22 +975,23 @@ def __len__(self) -> int: # NOTE: Using the config value keeps use lazy and fast. return len(self.project.config.dependencies) - def __getitem__(self, name: str) -> dict[str, "ProjectManager"]: - result: dict[str, "ProjectManager"] = {} + def __getitem__(self, name: str) -> DependencyVersionMap: + result = DependencyVersionMap(name) + + # Always ensure the specified are included, even if not yet installed. + if versions := {d.version: d.project for d in self._get_specified(name=name)}: + result.extend(versions) + + # Add other dependencies of the same package (different versions) + # that are also installed. for dependency in self.installed: if dependency.name != name: continue - result[dependency.version] = dependency.project - - if result: - return result - - # Try installing specified. - if versions := {d.version: d.project for d in self._get_specified(name=name)}: - return versions + if dependency.version not in result: + result[dependency.version] = dependency.project - return {} + return result def __contains__(self, name: str) -> bool: for dependency in self.installed: @@ -1164,12 +1228,7 @@ def get_dependency( Returns: class:`~ape.managers.project.Dependency` """ - version_options = [version] - if version.startswith("v"): - version_options.append(version[1:]) - elif version and version[0].isnumeric(): - # All try a v-prefix if using a numeric-like version. - version_options.append(f"v{version}") + version_options = _version_to_options(version) # Also try the lower of the name # so ``OpenZeppelin`` would give you ``openzeppelin``. diff --git a/tests/functional/test_dependencies.py b/tests/functional/test_dependencies.py index ad7f69e384..7e305495c2 100644 --- a/tests/functional/test_dependencies.py +++ b/tests/functional/test_dependencies.py @@ -163,6 +163,18 @@ def test_get_versions(project_with_downloaded_dependencies): assert len(actual) == 2 +def test_getitem_and_contains_and_get(project_with_downloaded_dependencies): + dm = project_with_downloaded_dependencies.dependencies + name = "openzeppelin" + versions = dm[name] + assert "3.1.0" in versions + assert "v3.1.0" in versions # Also allows v-prefix. + assert ( + versions["3.1.0"] == versions["v3.1.0"] == versions.get("3.1.0") == versions.get("v3.1.0") + ) + assert isinstance(versions["3.1.0"], LocalProject) + + def test_add(project): with project.isolate_in_tempdir() as tmp_project: contracts_path = tmp_project.path / "src" diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index b423141b12..998804a478 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -974,6 +974,7 @@ def test_default_network_name_when_not_set_and_no_local_uses_only( with project.temp_config(networks={"custom": [net]}): ecosystem = project.network_manager.get_ecosystem(ecosystem_name) + ecosystem._default_network = None actual = ecosystem.default_network_name if actual == LOCAL_NETWORK_NAME: