Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issue when switching providers when disconnecting [APE-1460] #1700

Merged
merged 19 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-yaml

Expand All @@ -10,7 +10,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.1
hooks:
- id: black
name: black
Expand All @@ -21,7 +21,7 @@ repos:
- id: flake8

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
rev: v1.6.1
hooks:
- id: mypy
additional_dependencies: [
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
"hypothesis-jsonschema==0.19.0", # JSON Schema fuzzer extension
],
"lint": [
"black>=23.9.1,<24", # Auto-formatter and linter
"mypy>=1.5.1,<2", # Static type analyzer
"black>=23.10.1,<24", # Auto-formatter and linter
"mypy>=1.6.1,<2", # Static type analyzer
"types-PyYAML", # Needed due to mypy typeshed
"types-requests", # Needed due to mypy typeshed
"types-setuptools", # Needed due to mypy typeshed
Expand Down
68 changes: 39 additions & 29 deletions src/ape/api/networks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
from functools import partial
from pathlib import Path
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
)

from eth_account import Account as EthAccount
from eth_account._utils.legacy_transactions import (
Expand Down Expand Up @@ -430,7 +441,9 @@ def get_network(self, network_name: str) -> "NetworkAPI":

raise NetworkNotFoundError(network_name, ecosystem=self.name, options=self.networks)

def get_network_data(self, network_name: str) -> Dict:
def get_network_data(
self, network_name: str, provider_filter: Optional[Collection[str]] = None
) -> Dict:
"""
Get a dictionary of data about providers in the network.

Expand All @@ -439,6 +452,8 @@ def get_network_data(self, network_name: str) -> Dict:

Args:
network_name (str): The name of the network to get provider data from.
provider_filter (Optional[Collection[str]]): Optionally filter the providers
by name.

Returns:
dict: A dictionary containing the providers in a network.
Expand All @@ -456,6 +471,9 @@ def get_network_data(self, network_name: str) -> Dict:
data["explorer"] = str(network.explorer.name)

for provider_name in network.providers:
if provider_filter and provider_name not in provider_filter:
continue

provider_data: Dict = {"name": str(provider_name)}

# Only add isDefault key when True
Expand Down Expand Up @@ -573,22 +591,22 @@ def push_provider(self):
if must_connect:
self._provider.connect()

provider_id = self.get_provider_id(self._provider)
if provider_id is None:
connection_id = self._provider.connection_id
if connection_id is None:
raise ProviderNotConnectedError()

self.provider_stack.append(provider_id)
self.disconnect_map[provider_id] = self._disconnect_after
if provider_id in self.connected_providers:
self.provider_stack.append(connection_id)
self.disconnect_map[connection_id] = self._disconnect_after
if connection_id in self.connected_providers:
# Using already connected instance
if must_connect:
# Disconnect if had to connect to check chain ID
self._provider.disconnect()

self._provider = self.connected_providers[provider_id]
self._provider = self.connected_providers[connection_id]
else:
# Adding provider for the first time. Retain connection.
self.connected_providers[provider_id] = self._provider
self.connected_providers[connection_id] = self._provider

self.network_manager.active_provider = self._provider
return self._provider
Expand All @@ -598,24 +616,28 @@ def pop_provider(self):
return

# Clear last provider
exiting_provider_id = self.provider_stack.pop()
current_id = self.provider_stack.pop()

# Disconnect the provider in same cases.
if self.disconnect_map[exiting_provider_id]:
if self.disconnect_map[current_id]:
if provider := self.network_manager.active_provider:
provider.disconnect()

del self.disconnect_map[current_id]
if current_id in self.connected_providers:
del self.connected_providers[current_id]

if not self.provider_stack:
self.network_manager.active_provider = None
return

# Reset the original active provider
previous_provider_id = self.provider_stack[-1]
if previous_provider_id == exiting_provider_id:
prior_id = self.provider_stack[-1]
if prior_id == current_id:
# Active provider is not changing
return

if previous_provider := self.connected_providers[previous_provider_id]:
if previous_provider := self.connected_providers[prior_id]:
self.network_manager.active_provider = previous_provider

def disconnect_all(self):
Expand All @@ -628,13 +650,6 @@ def disconnect_all(self):
self.network_manager.active_provider = None
self.connected_providers = {}

@classmethod
def get_provider_id(cls, provider: "ProviderAPI") -> Optional[str]:
if not provider.is_connected:
return None

return f"{provider.network_choice}:-{provider.chain_id}"


class NetworkAPI(BaseInterfaceModel):
"""
Expand Down Expand Up @@ -865,14 +880,9 @@ def get_provider(

if provider_name in self.providers:
provider = self.providers[provider_name](provider_settings=provider_settings)

provider_id = ProviderContextManager.get_provider_id(provider)
if not provider_id:
# Provider not yet connected
return provider

if provider_id in ProviderContextManager.connected_providers:
return ProviderContextManager.connected_providers[provider_id]
if provider.connection_id in ProviderContextManager.connected_providers:
# Likely multi-chain testing or utilizing multiple on-going connections.
return ProviderContextManager.connected_providers[provider.connection_id]

return provider

Expand Down
41 changes: 39 additions & 2 deletions src/ape/api/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class ProviderAPI(BaseInterfaceModel):
network: NetworkAPI
"""A reference to the network this provider provides."""

provider_settings: dict
provider_settings: Dict = {}
"""The settings for the provider, as overrides to the configuration."""

data_folder: Path
Expand Down Expand Up @@ -184,6 +184,37 @@ def ws_uri(self) -> Optional[str]:
"""
return None

@property
def settings(self) -> PluginConfig:
"""
The combination of settings from ``ape-config.yaml`` and ``.provider_settings``.
"""
CustomConfig = self.config.__class__
data = {**self.config.dict(), **self.provider_settings}
return CustomConfig.parse_obj(data)

@property
def connection_id(self) -> Optional[str]:
"""
A connection ID to uniquely identify and manage multiple
connections to providers, especially when working with multiple
providers of the same type, like multiple Geth --dev nodes.
"""

try:
chain_id = self.chain_id
except Exception:
if chain_id := self.settings.get("chain_id"):
pass

else:
# A connection is required to obtain a chain ID for this provider.
return None

# NOTE: If other provider settings are different, ``.update_settings()``
# should be called.
return f"{self.network_choice}:{chain_id}"

@abstractmethod
def update_settings(self, new_settings: dict):
"""
Expand Down Expand Up @@ -1536,7 +1567,8 @@ def _create_trace_frame(self, evm_frame: EvmTraceFrame) -> TraceFrame:
raw=evm_frame.dict(),
)

def _make_request(self, endpoint: str, parameters: List) -> Any:
def _make_request(self, endpoint: str, parameters: Optional[List] = None) -> Any:
parameters = parameters or []
coroutine = self.web3.provider.make_request(RPCEndpoint(endpoint), parameters)
result = run_until_complete(coroutine)

Expand Down Expand Up @@ -1712,6 +1744,11 @@ def _stdout_logger(self) -> Logger:
def _stderr_logger(self) -> Logger:
return self._get_process_output_logger("stderr", self.stderr_logs_path)

@property
def connection_id(self) -> Optional[str]:
cmd_id = ",".join(self.build_command())
return f"{self.network_choice}:{cmd_id}"

def _get_process_output_logger(self, name: str, path: Path):
logger = getLogger(f"{self.name}_{name}_subprocessProviderLogger")
path.parent.mkdir(parents=True, exist_ok=True)
Expand Down
31 changes: 27 additions & 4 deletions src/ape/managers/networks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from functools import cached_property
from typing import Dict, Iterator, List, Optional, Set, Union
from typing import Collection, Dict, Iterator, List, Optional, Set, Union

import yaml

Expand Down Expand Up @@ -501,15 +501,33 @@ def network_data(self) -> Dict:
Returns:
dict
"""
return self.get_network_data()

def get_network_data(
self,
ecosystem_filter: Optional[Collection[str]] = None,
network_filter: Optional[Collection[str]] = None,
provider_filter: Optional[Collection[str]] = None,
):
data: Dict = {"ecosystems": []}

for ecosystem_name in self:
ecosystem_data = self._get_ecosystem_data(ecosystem_name)
if ecosystem_filter and ecosystem_name not in ecosystem_filter:
continue

ecosystem_data = self._get_ecosystem_data(
ecosystem_name, network_filter=network_filter, provider_filter=provider_filter
)
data["ecosystems"].append(ecosystem_data)

return data

def _get_ecosystem_data(self, ecosystem_name: str) -> Dict:
def _get_ecosystem_data(
self,
ecosystem_name: str,
network_filter: Optional[Collection[str]] = None,
provider_filter: Optional[Collection[str]] = None,
) -> Dict:
ecosystem = self[ecosystem_name]
ecosystem_data: Dict = {"name": str(ecosystem_name)}

Expand All @@ -520,16 +538,21 @@ def _get_ecosystem_data(self, ecosystem_name: str) -> Dict:
ecosystem_data["networks"] = []

for network_name in getattr(self, ecosystem_name).networks:
network_data = ecosystem.get_network_data(network_name)
if network_filter and network_name not in network_filter:
continue

network_data = ecosystem.get_network_data(network_name, provider_filter=provider_filter)
ecosystem_data["networks"].append(network_data)

return ecosystem_data

@property
# TODO: Remove in 0.7
def networks_yaml(self) -> str:
"""
Get a ``yaml`` ``str`` representing all the networks
in all the ecosystems.
**NOTE**: Deprecated.

View the result via CLI command ``ape networks list --format yaml``.

Expand Down
4 changes: 2 additions & 2 deletions src/ape/utils/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ def __getattr__(self, name: str) -> Any:

# The error message mentions the alternative mappings,
# such as a contract-type map.
message = f"No attribute with name '{name}'."
message = f"'{repr(self)}' has no attribute '{name}'"
if extras_checked:
extras_str = ", ".join(extras_checked)
message = f"{message} Also checked '{extras_str}'."
message = f"{message}. Also checked '{extras_str}'"

raise ApeAttributeError(message)

Expand Down
Loading
Loading