Skip to content

Commit

Permalink
fix: issue where v prefix mattered in dependency mapping (#2111)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jun 4, 2024
1 parent 4b69ce3 commit c37ca14
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 17 deletions.
93 changes: 76 additions & 17 deletions src/ape/managers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,69 @@ def remove(self, package_id: str, version: str):
manifest_file.unlink(missing_ok=True)


def _version_to_options(version: str) -> tuple[str, ...]:
if version.startswith("v"):
# with the v, without
return (version, version[1:])

elif version and version[0].isnumeric():
# without the v, and with.
return (version, f"v{version}")

return (version,)


class DependencyVersionMap(dict[str, "ProjectManager"]):
"""
A mapping of versions to dependencies.
This class exists to allow both v-prefixed versions
as well none v-prefixed versions.
"""

def __init__(self, name: str):
self._name = name

@log_instead_of_fail(default="<DependencyVersionMap>")
def __repr__(self) -> str:
keys = ",".join(list(self.keys()))
return f"<{self._name} versions='{keys}'>"

def __contains__(self, version: Any) -> bool:
if not isinstance(version, str):
return False

options = _version_to_options(version)
return any(dict.__contains__(self, v) for v in options) # type: ignore

def __getitem__(self, version: str) -> "ProjectManager":
options = _version_to_options(version)
for vers in options:
if not dict.__contains__(self, vers): # type: ignore
continue

# Found.
return dict.__getitem__(self, vers) # type: ignore

raise KeyError(version)

def get( # type: ignore
self, version: str, default: Optional["ProjectManager"] = None
) -> Optional["ProjectManager"]:
options = _version_to_options(version)
for vers in options:
if not dict.__contains__(self, vers): # type: ignore
continue

# Found.
return dict.get(self, vers) # type: ignore

return default

def extend(self, data: dict):
for key, val in data.items():
self[key] = val


class DependencyManager(BaseManager):
"""
Manage dependencies for an Ape project.
Expand Down Expand Up @@ -912,22 +975,23 @@ def __len__(self) -> int:
# NOTE: Using the config value keeps use lazy and fast.
return len(self.project.config.dependencies)

def __getitem__(self, name: str) -> dict[str, "ProjectManager"]:
result: dict[str, "ProjectManager"] = {}
def __getitem__(self, name: str) -> DependencyVersionMap:
result = DependencyVersionMap(name)

# Always ensure the specified are included, even if not yet installed.
if versions := {d.version: d.project for d in self._get_specified(name=name)}:
result.extend(versions)

# Add other dependencies of the same package (different versions)
# that are also installed.
for dependency in self.installed:
if dependency.name != name:
continue

result[dependency.version] = dependency.project

if result:
return result

# Try installing specified.
if versions := {d.version: d.project for d in self._get_specified(name=name)}:
return versions
if dependency.version not in result:
result[dependency.version] = dependency.project

return {}
return result

def __contains__(self, name: str) -> bool:
for dependency in self.installed:
Expand Down Expand Up @@ -1164,12 +1228,7 @@ def get_dependency(
Returns:
class:`~ape.managers.project.Dependency`
"""
version_options = [version]
if version.startswith("v"):
version_options.append(version[1:])
elif version and version[0].isnumeric():
# All try a v-prefix if using a numeric-like version.
version_options.append(f"v{version}")
version_options = _version_to_options(version)

# Also try the lower of the name
# so ``OpenZeppelin`` would give you ``openzeppelin``.
Expand Down
12 changes: 12 additions & 0 deletions tests/functional/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ def test_get_versions(project_with_downloaded_dependencies):
assert len(actual) == 2


def test_getitem_and_contains_and_get(project_with_downloaded_dependencies):
dm = project_with_downloaded_dependencies.dependencies
name = "openzeppelin"
versions = dm[name]
assert "3.1.0" in versions
assert "v3.1.0" in versions # Also allows v-prefix.
assert (
versions["3.1.0"] == versions["v3.1.0"] == versions.get("3.1.0") == versions.get("v3.1.0")
)
assert isinstance(versions["3.1.0"], LocalProject)


def test_add(project):
with project.isolate_in_tempdir() as tmp_project:
contracts_path = tmp_project.path / "src"
Expand Down
1 change: 1 addition & 0 deletions tests/functional/test_ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ def test_default_network_name_when_not_set_and_no_local_uses_only(

with project.temp_config(networks={"custom": [net]}):
ecosystem = project.network_manager.get_ecosystem(ecosystem_name)
ecosystem._default_network = None
actual = ecosystem.default_network_name

if actual == LOCAL_NETWORK_NAME:
Expand Down

0 comments on commit c37ca14

Please sign in to comment.