Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support configuring test account initial balance #2173

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/userguides/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def test_my_method(owner, receiver):
...
```

You can configure your accounts by changing the `mnemonic` or `number_of_accounts` settings in the `test` section of your `ape-config.yaml` file:
You can configure your accounts by changing the `mnemonic`, `number_of_accounts`, and `balance` in the `test` section of your `ape-config.yaml` file:

```yaml
test:
mnemonic: test test test test test test test test test test test junk
number_of_accounts: 5
balance: 100_000 ETH
```

If you are running tests against `anvil`, your generated test accounts may not correspond to the `anvil`'s default generated accounts despite using the same mnemonic. In such a case, you are able to specify a custom derivation path in `ape-config.yaml`:
Expand Down
2 changes: 2 additions & 0 deletions src/ape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from ape.utils.process import JoinableQueue, spawn
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
Expand All @@ -82,6 +83,7 @@
"DEFAULT_LIVE_NETWORK_BASE_FEE_MULTIPLIER",
"DEFAULT_LOCAL_TRANSACTION_ACCEPTANCE_TIMEOUT",
"DEFAULT_NUMBER_OF_TEST_ACCOUNTS",
"DEFAULT_TEST_ACCOUNT_BALANCE",
"DEFAULT_TEST_CHAIN_ID",
"DEFAULT_TEST_MNEMONIC",
"DEFAULT_TEST_HD_PATH",
Expand Down
1 change: 1 addition & 0 deletions src/ape/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DEFAULT_TEST_MNEMONIC = "test test test test test test test test test test test junk"
DEFAULT_TEST_HD_PATH = "m/44'/60'/0'/0"
DEFAULT_TEST_CHAIN_ID = 1337
DEFAULT_TEST_ACCOUNT_BALANCE = int(10e21) # 10,000 Ether (in Wei)
GeneratedDevAccount = namedtuple("GeneratedDevAccount", ("address", "private_key"))
"""
An account key-pair generated from the test mnemonic. Set the test mnemonic
Expand Down
77 changes: 41 additions & 36 deletions src/ape_node/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from eth_pydantic_types import HexBytes
from eth_typing import HexStr
from eth_utils import add_0x_prefix, to_hex, to_wei
from eth_utils import add_0x_prefix, to_hex
from evmchains import get_random_rpc
from geth.chain import initialize_chain
from geth.process import BaseGethProcess
Expand All @@ -21,17 +21,15 @@
from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI
from ape.logging import LogLevel, logger
from ape.types import SnapshotID
from ape.utils import (
from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented
from ape.utils.process import JoinableQueue, spawn
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_CHAIN_ID,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
ZERO_ADDRESS,
JoinableQueue,
generate_dev_accounts,
log_instead_of_fail,
raises_not_implemented,
spawn,
)
from ape_ethereum.provider import (
DEFAULT_HOSTNAME,
Expand Down Expand Up @@ -98,7 +96,7 @@ def __init__(
mnemonic: str = DEFAULT_TEST_MNEMONIC,
number_of_accounts: int = DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
chain_id: int = DEFAULT_TEST_CHAIN_ID,
initial_balance: Union[str, int] = to_wei(10000, "ether"),
initial_balance: Union[str, int] = DEFAULT_TEST_ACCOUNT_BALANCE,
executable: Optional[str] = None,
auto_disconnect: bool = True,
extra_funded_accounts: Optional[list[str]] = None,
Expand Down Expand Up @@ -131,8 +129,9 @@ def __init__(
self._clean()

geth_kwargs["dev_mode"] = True
hd_path = hd_path or DEFAULT_TEST_HD_PATH
accounts = generate_dev_accounts(
mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path or DEFAULT_TEST_HD_PATH
mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path
)
addresses = [a.address for a in accounts]
addresses.extend(extra_funded_accounts or [])
Expand All @@ -152,20 +151,22 @@ def from_uri(cls, uri: str, data_folder: Path, **kwargs):
port = parsed_uri.port if parsed_uri.port is not None else DEFAULT_PORT
mnemonic = kwargs.get("mnemonic", DEFAULT_TEST_MNEMONIC)
number_of_accounts = kwargs.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS)
balance = kwargs.get("initial_balance", DEFAULT_TEST_ACCOUNT_BALANCE)
extra_accounts = [
HexBytes(a).hex().lower() for a in kwargs.get("extra_funded_accounts", [])
]

return cls(
data_folder,
hostname=parsed_uri.host,
port=port,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
executable=kwargs.get("executable"),
auto_disconnect=kwargs.get("auto_disconnect", True),
executable=kwargs.get("executable"),
extra_funded_accounts=extra_accounts,
hd_path=kwargs.get("hd_path", DEFAULT_TEST_HD_PATH),
hostname=parsed_uri.host,
initial_balance=balance,
mnemonic=mnemonic,
number_of_accounts=number_of_accounts,
port=port,
)

@property
Expand Down Expand Up @@ -325,33 +326,15 @@ def connect(self):
self.start()

def start(self, timeout: int = 20):
# NOTE: Using JSON mode to ensure types can be passed as CLI args.
test_config = self.config_manager.get_config("test").model_dump(mode="json")

# Allow configuring a custom executable besides your $PATH geth.
if self.settings.executable is not None:
test_config["executable"] = self.settings.executable

test_config["ipc_path"] = self.ipc_path
test_config["auto_disconnect"] = self._test_runner is None or test_config.get(
"disconnect_providers_after", True
)

# Include extra accounts to allocated funds to at genesis.
extra_accounts = self.settings.ethereum.local.get("extra_funded_accounts", [])
extra_accounts.extend(self.provider_settings.get("extra_funded_accounts", []))
extra_accounts = list({HexBytes(a).hex().lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts

process = GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)
process.connect(timeout=timeout)
geth_dev = self._create_process()
geth_dev.connect(timeout=timeout)
if not self.web3.is_connected():
process.disconnect()
geth_dev.disconnect()
raise ConnectionError("Unable to connect to locally running geth.")
else:
self.web3.middleware_onion.inject(geth_poa_middleware, layer=0)

self._process = process
self._process = geth_dev

# For subprocess-provider
if self._process is not None and (process := self._process.proc):
Expand All @@ -366,6 +349,28 @@ def start(self, timeout: int = 20):
spawn(self.consume_stdout_queue)
spawn(self.consume_stderr_queue)

def _create_process(self) -> GethDevProcess:
# NOTE: Using JSON mode to ensure types can be passed as CLI args.
test_config = self.config_manager.get_config("test").model_dump(mode="json")

# Allow configuring a custom executable besides your $PATH geth.
if self.settings.executable is not None:
test_config["executable"] = self.settings.executable

test_config["ipc_path"] = self.ipc_path
test_config["auto_disconnect"] = self._test_runner is None or test_config.get(
"disconnect_providers_after", True
)

# Include extra accounts to allocated funds to at genesis.
extra_accounts = self.settings.ethereum.local.get("extra_funded_accounts", [])
extra_accounts.extend(self.provider_settings.get("extra_funded_accounts", []))
extra_accounts = list({HexBytes(a).hex().lower() for a in extra_accounts})
test_config["extra_funded_accounts"] = extra_accounts
test_config["initial_balance"] = self.test_config.balance

return GethDevProcess.from_uri(self.uri, self.data_dir, **test_config)

def disconnect(self):
# Must disconnect process first.
if self._process is not None:
Expand Down
42 changes: 31 additions & 11 deletions src/ape_test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from ape import plugins
from ape.api import PluginConfig
from ape.api.networks import LOCAL_NETWORK_NAME
from ape.utils import DEFAULT_NUMBER_OF_TEST_ACCOUNTS, DEFAULT_TEST_HD_PATH, DEFAULT_TEST_MNEMONIC
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.testing import (
DEFAULT_NUMBER_OF_TEST_ACCOUNTS,
DEFAULT_TEST_ACCOUNT_BALANCE,
DEFAULT_TEST_HD_PATH,
DEFAULT_TEST_MNEMONIC,
)
from ape_test.accounts import TestAccount, TestAccountContainer
from ape_test.provider import EthTesterProviderConfig, LocalProvider

Expand Down Expand Up @@ -110,41 +116,55 @@ class CoverageConfig(PluginConfig):


class ApeTestConfig(PluginConfig):
mnemonic: str = DEFAULT_TEST_MNEMONIC
balance: int = DEFAULT_TEST_ACCOUNT_BALANCE
"""
The mnemonic to use when generating the test accounts.
The starting-balance of every test account in Wei (NOT Ether).
"""

number_of_accounts: NonNegativeInt = DEFAULT_NUMBER_OF_TEST_ACCOUNTS
coverage: CoverageConfig = CoverageConfig()
"""
The number of test accounts to generate in the provider.
Configuration related to coverage reporting.
"""

disconnect_providers_after: bool = True
"""
Set to ``False`` to keep providers connected at the end of the test run.
"""

gas: GasConfig = GasConfig()
"""
Configuration related to gas reporting.
"""

coverage: CoverageConfig = CoverageConfig()
hd_path: str = DEFAULT_TEST_HD_PATH
"""
Configuration related to coverage reporting.
The hd_path to use when generating the test accounts.
"""

disconnect_providers_after: bool = True
mnemonic: str = DEFAULT_TEST_MNEMONIC
"""
Set to ``False`` to keep providers connected at the end of the test run.
The mnemonic to use when generating the test accounts.
"""

hd_path: str = DEFAULT_TEST_HD_PATH
number_of_accounts: NonNegativeInt = DEFAULT_NUMBER_OF_TEST_ACCOUNTS
"""
The hd_path to use when generating the test accounts.
The number of test accounts to generate in the provider.
"""

provider: EthTesterProviderConfig = EthTesterProviderConfig()
"""
Settings for the provider.
"""

@field_validator("balance", mode="before")
@classmethod
def validate_balance(cls, value):
return (
value
if isinstance(value, int)
else ManagerAccessMixin.conversion_manager.convert(value, int)
)


@plugins.register(plugins.Config)
def config_class():
Expand Down
4 changes: 3 additions & 1 deletion src/ape_test/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ def tester(self):
return

hd_path = (self.config.hd_path or DEFAULT_TEST_HD_PATH).rstrip("/")
state_overrides = {"balance": self.test_config.balance}
self._evm_backend = PyEVMBackend.from_mnemonic(
genesis_state_overrides=state_overrides,
hd_path=hd_path,
mnemonic=self.config.mnemonic,
num_accounts=self.config.number_of_accounts,
hd_path=hd_path,
)
endpoints = {**API_ENDPOINTS}
endpoints["eth"] = merge(endpoints["eth"], {"chainId": static_return(chain_id)})
Expand Down
14 changes: 14 additions & 0 deletions tests/functional/geth/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,17 @@ def test_trace_approach_config(project):
with project.temp_config(node=node_cfg):
provider = project.network_manager.ethereum.local.get_provider("node")
assert provider.call_trace_approach is TraceApproach.GETH_STRUCT_LOG_PARSE


def test_start(mocker, convert, project, geth_provider):
amount = convert("100_000 ETH", int)
spy = mocker.spy(GethDevProcess, "from_uri")

with project.temp_config(test={"balance": amount}):
try:
geth_provider.start()
except Exception:
pass # Exceptions are fine here.

actual = spy.call_args[1]["balance"]
assert actual == amount
22 changes: 20 additions & 2 deletions tests/functional/test_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from requests import HTTPError
from web3.exceptions import ContractPanicError

from ape import convert
from ape.exceptions import (
APINotImplementedError,
BlockNotFoundError,
Expand All @@ -19,9 +20,10 @@
TransactionNotFoundError,
)
from ape.types import LogFilter
from ape.utils import DEFAULT_TEST_CHAIN_ID
from ape.utils.testing import DEFAULT_TEST_ACCOUNT_BALANCE, DEFAULT_TEST_CHAIN_ID
from ape_ethereum.provider import WEB3_PROVIDER_URI_ENV_VAR_NAME, Web3Provider, _sanitize_web3_url
from ape_ethereum.transactions import TransactionStatusEnum, TransactionType
from ape_test import LocalProvider


def test_uri(eth_tester_provider):
Expand Down Expand Up @@ -200,7 +202,7 @@ def test_provider_get_balance(project, networks, accounts):
balance = networks.provider.get_balance(accounts.test_accounts[0].address)

assert type(balance) is int
assert balance == 1000000000000000000000000
assert balance == DEFAULT_TEST_ACCOUNT_BALANCE


def test_set_timestamp(ethereum):
Expand Down Expand Up @@ -475,3 +477,19 @@ def disconnect(self):
finally:
if WEB3_PROVIDER_URI_ENV_VAR_NAME in os.environ:
del os.environ[WEB3_PROVIDER_URI_ENV_VAR_NAME]


def test_account_balance_state(project, eth_tester_provider, owner):
amount = convert("100_000 ETH", int)

with project.temp_config(test={"balance": amount}):
# NOTE: Purposely using a different instance of the provider
# for better testing isolation.
provider = LocalProvider(
name="test",
network=eth_tester_provider.network,
request_header=eth_tester_provider.request_header,
)
provider.connect()
bal = provider.get_balance(owner.address)
assert bal == amount
11 changes: 11 additions & 0 deletions tests/functional/test_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ape_test import ApeTestConfig


class TestApeTestConfig:
def test_balance_set_from_currency_str(self):
curr_val = "10 Eth"
data = {"balance": curr_val}
cfg = ApeTestConfig.model_validate(data)
actual = cfg.balance
expected = 10_000_000_000_000_000_000 # 10 ETH in WEI
assert actual == expected
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

INITIAL_BALANCE = 1_000_1 * 10**18


@pytest.fixture(scope="session")
def alice(accounts):
Expand Down Expand Up @@ -27,10 +29,10 @@ def start_block_number(chain):

def test_isolation_first(alice, bob, chain, start_block_number):
assert chain.provider.get_block("latest").number == start_block_number
assert bob.balance == 1_000_001 * 10**18
assert bob.balance == INITIAL_BALANCE
alice.transfer(bob, "1 ether")


def test_isolation_second(bob, chain, start_block_number):
assert chain.provider.get_block("latest").number == start_block_number
assert bob.balance == 1_000_001 * 10**18
assert bob.balance == INITIAL_BALANCE
Loading