diff --git a/.github/workflows/type.yml b/.github/workflows/type.yml new file mode 100644 index 000000000..cf637e25c --- /dev/null +++ b/.github/workflows/type.yml @@ -0,0 +1,33 @@ +name: Type check + +# Controls when the action will run. Triggers the workflow on push or pull request +# events but only for the main and develop branches. +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + mypy: + name: Run mypy + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up latest Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: 'pip' + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -r requirements.txt + + - name: mypy + run: | + make type + diff --git a/Makefile b/Makefile index 3440d1f78..ed22ea332 100644 --- a/Makefile +++ b/Makefile @@ -121,6 +121,14 @@ pylint: @echo "" @echo "pylint passed ⚙️" +.PHONY: type +type: + @echo "$(REVERSE)Running$(RESET) $(BOLD)mypy$(RESET)..." + @mypy --version + @mypy curvesim + @echo "" + @echo "Type checks passed 🏆" + .PHONY: coverage coverage: pytest --cov -n auto --cov-report= diff --git a/changelog.d/20230919_093613_chanhosuh_use_mypy.rst b/changelog.d/20230919_093613_chanhosuh_use_mypy.rst new file mode 100644 index 000000000..b569565c6 --- /dev/null +++ b/changelog.d/20230919_093613_chanhosuh_use_mypy.rst @@ -0,0 +1,7 @@ +Added +----- + +- Add type checks via mypy to the CI (#252). +- `StateLog` is derived from a `Log` interface class in-line + with other "template" classes. This also helps avoid + circular dependencies. diff --git a/curvesim/_order_book/__init__.py b/curvesim/_order_book/__init__.py index 9dd1cd564..45383c886 100644 --- a/curvesim/_order_book/__init__.py +++ b/curvesim/_order_book/__init__.py @@ -9,6 +9,7 @@ from curvesim.pool import CurveMetaPool +# pylint: disable=too-many-locals def order_book(pool, i, j, *, width=0.1, resolution=10**23, use_fee=True, show=True): """ Computes and optionally plots an orderbook representation of exchange rates diff --git a/curvesim/iterators/price_samplers/price_volume.py b/curvesim/iterators/price_samplers/price_volume.py index c3912d576..add02223c 100644 --- a/curvesim/iterators/price_samplers/price_volume.py +++ b/curvesim/iterators/price_samplers/price_volume.py @@ -1,3 +1,5 @@ +from typing import Iterator + from curvesim.logging import get_logger from curvesim.price_data import get from curvesim.templates.price_samplers import PriceSample, PriceSampler @@ -68,7 +70,7 @@ def __init__( self.volumes = volumes.set_axis(assets.symbol_pairs, axis="columns") @override - def __iter__(self) -> PriceVolumeSample: + def __iter__(self) -> Iterator[PriceVolumeSample]: """ Yields ------- @@ -86,7 +88,7 @@ def __iter__(self) -> PriceVolumeSample: prices = prices.to_dict() volumes = volumes.to_dict() - yield PriceVolumeSample(price_timestamp, prices, volumes) + yield PriceVolumeSample(price_timestamp, prices, volumes) # type:ignore def total_volumes(self): """ diff --git a/curvesim/logging.py b/curvesim/logging.py index 4808c0df0..986fe2826 100644 --- a/curvesim/logging.py +++ b/curvesim/logging.py @@ -9,6 +9,7 @@ import os from contextlib import contextmanager from logging.handlers import QueueHandler, QueueListener +from typing import Dict, List # -- convenient parameters to adjust for debugging -- # DEFAULT_LEVEL = "info" @@ -41,7 +42,7 @@ # FIXME: need a function to update the config after module initialization HANDLERS = ["console", "file"] if USE_LOG_FILE else ["console"] -CUSTOM_LOGGING_CONFIG = { +CUSTOM_LOGGING_CONFIG: dict = { "version": 1, "disable_existing_loggers": False, "formatters": { @@ -74,8 +75,8 @@ } # 3rd party loggers that we want to largely ignore -silenced_loggers = ["matplotlib", "asyncio", "rlp", "parso", "web3"] -configured_loggers = CUSTOM_LOGGING_CONFIG["loggers"] +silenced_loggers: List[str] = ["matplotlib", "asyncio", "rlp", "parso", "web3"] +configured_loggers: Dict[str, dict] = CUSTOM_LOGGING_CONFIG["loggers"] for name in silenced_loggers: configured_loggers[name] = { "handlers": HANDLERS, diff --git a/curvesim/metrics/state_log/log.py b/curvesim/metrics/state_log/log.py index f816846ca..614f89712 100644 --- a/curvesim/metrics/state_log/log.py +++ b/curvesim/metrics/state_log/log.py @@ -5,12 +5,14 @@ from pandas import DataFrame, concat from curvesim.metrics.base import PoolMetric +from curvesim.templates import Log +from curvesim.utils import override from .pool_parameters import get_pool_parameters from .pool_state import get_pool_state -class StateLog: +class StateLog(Log): """ Logger that records simulation/pool state throughout each simulation run and computes metrics at the end of each run. @@ -29,6 +31,7 @@ def __init__(self, pool, metrics): self.state_per_run = get_pool_parameters(pool) self.state_per_trade = [] + @override def update(self, **kwargs): """Records pool state and any keyword arguments provided.""" @@ -47,6 +50,7 @@ def get_logs(self): **state_per_trade, } + @override def compute_metrics(self): """Computes metrics from the accumulated log data.""" diff --git a/curvesim/pipelines/simple/strategy.py b/curvesim/pipelines/simple/strategy.py index c75138774..35aca1069 100644 --- a/curvesim/pipelines/simple/strategy.py +++ b/curvesim/pipelines/simple/strategy.py @@ -1,6 +1,8 @@ +from typing import Type + from curvesim.logging import get_logger from curvesim.metrics.state_log import StateLog -from curvesim.templates import Strategy +from curvesim.templates import Log, Strategy, Trader from .trader import SimpleArbitrageur @@ -17,8 +19,8 @@ class SimpleStrategy(Strategy): # pylint: disable=too-few-public-methods Class for creating state logger instances. """ - trader_class = SimpleArbitrageur - state_log_class = StateLog + trader_class: Type[Trader] = SimpleArbitrageur + log_class: Type[Log] = StateLog def _get_trader_inputs(self, sample): return (sample.prices,) diff --git a/curvesim/pipelines/vol_limited_arb/strategy.py b/curvesim/pipelines/vol_limited_arb/strategy.py index e8c127615..697c3d23b 100644 --- a/curvesim/pipelines/vol_limited_arb/strategy.py +++ b/curvesim/pipelines/vol_limited_arb/strategy.py @@ -1,7 +1,9 @@ +from typing import Type + from curvesim.logging import get_logger -from curvesim.metrics.state_log import StateLog +from curvesim.metrics.state_log.log import StateLog from curvesim.pipelines.vol_limited_arb.trader import VolumeLimitedArbitrageur -from curvesim.templates import Strategy +from curvesim.templates import Log, Strategy, Trader logger = get_logger(__name__) @@ -11,8 +13,8 @@ class VolumeLimitedStrategy(Strategy): Computes and executes volume-limited arbitrage trades at each timestep. """ - trader_class = VolumeLimitedArbitrageur - state_log_class = StateLog + trader_class: Type[Trader] = VolumeLimitedArbitrageur + log_class: Type[Log] = StateLog def __init__(self, metrics, vol_mult): """ diff --git a/curvesim/pool/base.py b/curvesim/pool/base.py index f72bf6e77..d6ce508fe 100644 --- a/curvesim/pool/base.py +++ b/curvesim/pool/base.py @@ -4,7 +4,9 @@ """ -from curvesim.pool.snapshot import SnapshotMixin +from typing import Optional, Type + +from curvesim.pool.snapshot import Snapshot, SnapshotMixin class Pool(SnapshotMixin): @@ -21,7 +23,7 @@ class Pool(SnapshotMixin): # need to configure in derived class otherwise the # snapshotting will not work - snapshot_class = None + snapshot_class: Optional[Type[Snapshot]] = None @property def name(self): diff --git a/curvesim/pool/cryptoswap/calcs/__init__.py b/curvesim/pool/cryptoswap/calcs/__init__.py index fcf3dab78..7c1efbc4c 100644 --- a/curvesim/pool/cryptoswap/calcs/__init__.py +++ b/curvesim/pool/cryptoswap/calcs/__init__.py @@ -71,7 +71,7 @@ def get_y(A: int, gamma: int, xp: List[int], D: int, j: int) -> List[int]: if n_coins == 2: y_out: List[int] = [factory_2_coin.newton_y(A, gamma, xp, D, j), 0] elif n_coins == 3: - y_out: List[int] = tricrypto_ng.get_y(A, gamma, xp, D, j) + y_out = tricrypto_ng.get_y(A, gamma, xp, D, j) else: raise CurvesimValueError("More than 3 coins is not supported.") @@ -93,7 +93,7 @@ def get_alpha( # # CAUTION: need to be wary of off-by-one errors from integer division. ma_half_time = ceil(ma_half_time * 1000 / 694) - alpha: int = tricrypto_ng.wad_exp( + alpha = tricrypto_ng.wad_exp( -1 * ((block_timestamp - last_prices_timestamp) * 10**18 // ma_half_time) ) else: diff --git a/curvesim/pool/cryptoswap/calcs/factory_2_coin.py b/curvesim/pool/cryptoswap/calcs/factory_2_coin.py index e64fc2532..1d5b52841 100644 --- a/curvesim/pool/cryptoswap/calcs/factory_2_coin.py +++ b/curvesim/pool/cryptoswap/calcs/factory_2_coin.py @@ -79,9 +79,8 @@ def lp_price(virtual_price: int, price_oracle: List[int]) -> int: int Liquidity redeemable per LP token in units of token 0. """ - price_oracle: int = price_oracle[0] - - return 2 * virtual_price * _sqrt_int(price_oracle) // 10**18 + p: int = price_oracle[0] + return 2 * virtual_price * _sqrt_int(p) // 10**18 # pylint: disable-next=too-many-locals @@ -109,10 +108,6 @@ def newton_y( # noqa: complexity: 11 frac: int = x[k] * 10**18 // D assert 10**16 <= frac <= 10**20 - y: int = D // n_coins - K0_i: int = 10**18 - S_i: int = 0 - x_sorted: List[int] = x.copy() x_sorted[i] = 0 x_sorted = sorted(x_sorted, reverse=True) # From high to low @@ -164,7 +159,7 @@ def newton_y( # noqa: complexity: 11 diff: int = abs(y - y_prev) if diff < max(convergence_limit, y // 10**14): - frac: int = y * 10**18 // D + frac = y * 10**18 // D assert 10**16 <= frac <= 10**20 # dev: unsafe value for y return int(y) @@ -176,7 +171,7 @@ def newton_D( # noqa: complexity: 13 ANN: int, gamma: int, x_unsorted: List[int], -) -> List[int]: +) -> int: """ Finding the `D` invariant using Newton's method. @@ -239,7 +234,7 @@ def newton_D( # noqa: complexity: 13 if diff * 10**14 < max(10**16, D): # Test that we are safe with the next newton_y for _x in x: - frac: int = _x * 10**18 // D + frac = _x * 10**18 // D if frac < 10**16 or frac > 10**20: raise CalculationError("Unsafe value for x[i]") return int(D) diff --git a/curvesim/pool/cryptoswap/calcs/tricrypto_ng.py b/curvesim/pool/cryptoswap/calcs/tricrypto_ng.py index bb47a6937..df6002eb7 100644 --- a/curvesim/pool/cryptoswap/calcs/tricrypto_ng.py +++ b/curvesim/pool/cryptoswap/calcs/tricrypto_ng.py @@ -51,11 +51,11 @@ def get_y( # noqa: complexity: 18 assert 10**17 <= D <= 10**15 * 10**18 frac: int = 0 - for k in range(3): - if k != i: - frac = x[k] * 10**18 // D + for _k in range(3): + if _k != i: + frac = x[_k] * 10**18 // D assert 10**16 <= frac <= 10**20, "Unsafe values x[i]" - # if above conditions are met, x[k] > 0 + # if above conditions are met, x[_k] > 0 j: int = 0 k: int = 0 @@ -99,7 +99,7 @@ def get_y( # noqa: complexity: 18 _c: int = gamma2 * (-1 * _c_neg) // D * ANN // 27 // A_MULTIPLIER _c *= -1 else: - _c: int = gamma2 * _c_neg // D * ANN // 27 // A_MULTIPLIER + _c = gamma2 * _c_neg // D * ANN // 27 // A_MULTIPLIER c += _c # (10**18 + gamma)**2/27 @@ -151,19 +151,19 @@ def get_y( # noqa: complexity: 18 if b_is_neg: delta0: int = -(_3ac // -b) - b else: - delta0: int = _3ac // b - b + delta0 = _3ac // b - b # 9*a*c/b - 2*b - 27*a**2/b*d/b if b_is_neg: delta1: int = -(3 * _3ac // -b) - 2 * b - 27 * a**2 // -b * d // -b else: - delta1: int = 3 * _3ac // b - 2 * b - 27 * a**2 // b * d // b + delta1 = 3 * _3ac // b - 2 * b - 27 * a**2 // b * d // b # delta1**2 + 4*delta0**2/b*delta0 if b_is_neg: sqrt_arg: int = delta1**2 - (4 * delta0**2 // -b * delta0) else: - sqrt_arg: int = delta1**2 + 4 * delta0**2 // b * delta0 + sqrt_arg = delta1**2 + 4 * delta0**2 // b * delta0 sqrt_val: int = 0 if sqrt_arg > 0: @@ -187,13 +187,13 @@ def get_y( # noqa: complexity: 18 if second_cbrt < 0: C1: int = -(b_cbrt * b_cbrt // 10**18 * -second_cbrt // 10**18) else: - C1: int = b_cbrt * b_cbrt // 10**18 * second_cbrt // 10**18 + C1 = b_cbrt * b_cbrt // 10**18 * second_cbrt // 10**18 # (b + b*delta0/C1 - C1)/3 if sign(b * delta0) != sign(C1): root_K0: int = (b + -(b * delta0 // -C1) - C1) // 3 else: - root_K0: int = (b + b * delta0 // C1 - C1) // 3 + root_K0 = (b + b * delta0 // C1 - C1) // 3 # D*D/27/x_k*D/x_j*root_K0/a root: int = D * D // 27 // x_k * D // x_j * root_K0 // a @@ -293,7 +293,7 @@ def _newton_y( # noqa: complexity: 11 # pylint: disable=duplicate-code,too-man diff: int = abs(y - y_prev) if diff < max(convergence_limit, y // 10**14): - frac: int = y * 10**18 // D + frac = y * 10**18 // D assert 10**16 <= frac <= 10**20 # dev: unsafe value for y return int(y) @@ -447,9 +447,6 @@ def newton_D( # pylint: disable=too-many-locals D_minus: int = 0 D_prev: int = 0 - diff: int = 0 - frac: int = 0 - D = mpz(D) for _ in range(255): @@ -466,7 +463,7 @@ def newton_D( # pylint: disable=too-many-locals _g1k0 = gamma + 10**18 # <--------- safe to do unsafe_add. # The following operations can safely be unsafe. - _g1k0: int = abs(_g1k0 - K0) + 1 + _g1k0 = abs(_g1k0 - K0) + 1 # D / (A * N**N) * _g1k0**2 / gamma**2 # mul1 = 10**18 * D / gamma * _g1k0 / gamma * _g1k0 * A_MULTIPLIER / ANN @@ -516,7 +513,7 @@ def newton_D( # pylint: disable=too-many-locals else: D = (D_minus - D_plus) // 2 - diff = abs(D - D_prev) + diff: int = abs(D - D_prev) # Could reduce precision for gas efficiency here: if diff * 10**14 < max(10**16, D): # Test that we are safe with the next get_y diff --git a/curvesim/pool/cryptoswap/pool.py b/curvesim/pool/cryptoswap/pool.py index 2f6cdcf81..8f3310e89 100644 --- a/curvesim/pool/cryptoswap/pool.py +++ b/curvesim/pool/cryptoswap/pool.py @@ -3,12 +3,12 @@ """ import time from math import isqrt, prod -from typing import List +from typing import List, Tuple, Type from curvesim.exceptions import CalculationError, CryptoPoolError, CurvesimValueError from curvesim.logging import get_logger from curvesim.pool.base import Pool -from curvesim.pool.snapshot import CurveCryptoPoolBalanceSnapshot +from curvesim.pool.snapshot import CurveCryptoPoolBalanceSnapshot, Snapshot from .calcs import ( factory_2_coin, @@ -30,7 +30,7 @@ class CurveCryptoPool(Pool): # pylint: disable=too-many-instance-attributes """Cryptoswap implementation in Python.""" - snapshot_class = CurveCryptoPoolBalanceSnapshot + snapshot_class: Type[Snapshot] = CurveCryptoPoolBalanceSnapshot __slots__ = ( "A", @@ -58,7 +58,8 @@ class CurveCryptoPool(Pool): # pylint: disable=too-many-instance-attributes "not_adjusted", ) - def __init__( # pylint: disable=too-many-locals,too-many-arguments + # pylint: disable-next=too-many-locals,too-many-arguments,too-many-branches + def __init__( self, A: int, gamma: int, @@ -81,7 +82,7 @@ def __init__( # pylint: disable=too-many-locals,too-many-arguments xcp_profit=10**18, xcp_profit_a=10**18, virtual_price=None, - ): + ) -> None: """ Parameters ---------- @@ -279,7 +280,7 @@ def _get_xcp(self, D: int) -> int: ] return geometric_mean(x) - def _increment_timestamp(self, blocks=1, timestamp=None): + def _increment_timestamp(self, blocks=1, timestamp=None) -> None: """Update the internal clock used to mimic the block timestamp.""" if timestamp: self._block_timestamp = timestamp @@ -297,7 +298,7 @@ def _tweak_price( # noqa: complexity: 12 p_i: int, new_D: int, K0_prev: int = 0, - ): + ) -> None: """ Applies several kinds of updates: - EMA price update: price_oracle @@ -444,7 +445,7 @@ def _tweak_price( # noqa: complexity: 12 self.D = D_unadjusted self.virtual_price = virtual_price - def _claim_admin_fees(self): + def _claim_admin_fees(self) -> None: # no gulping logic needed for the python code xcp_profit: int = self.xcp_profit xcp_profit_a: int = self.xcp_profit_a @@ -518,7 +519,7 @@ def get_dy(self, i: int, j: int, dx: int) -> int: return dy - def get_y(self, i, j, x, xp): + def get_y(self, i: int, j: int, x: int, xp: List[int]) -> int: r""" Calculate x[j] if one makes x[i] = x. @@ -578,7 +579,7 @@ def _fee(self, xp: List[int]) -> int: K: int = 10**18 for _x in xp: K = K * n_coins * _x // _sum_xp - f: int = fee_gamma * 10**18 // (fee_gamma + 10**18 - K) + f = fee_gamma * 10**18 // (fee_gamma + 10**18 - K) return (self.mid_fee * f + self.out_fee * (10**18 - f)) // 10**18 # pylint: disable-next=too-many-locals @@ -588,7 +589,7 @@ def _exchange( j: int, dx: int, min_dy: int, - ) -> int: + ) -> Tuple[int, int]: assert i != j, "Indices must be different" assert i < self.n, "Index out of bounds" assert j < self.n, "Index out of bounds" @@ -657,7 +658,7 @@ def exchange( j: int, dx: int, min_dy: int = 0, - ) -> int: + ) -> Tuple[int, int]: """ Swap `dx` amount of the `i`-th coin for the `j`-th coin. @@ -689,7 +690,7 @@ def exchange_underlying( j: int, dx: int, min_dy: int = 0, - ) -> int: + ) -> Tuple[int, int]: """ In the vyper contract, this exchanges using ETH instead of WETH. In Curvesim, this is the same as `exchange`. @@ -889,7 +890,7 @@ def _calc_withdraw_one_coin( i: int, update_D: bool, calc_price: bool, - ) -> (int, int, int, List[int]): + ) -> Tuple[int, int, int, List[int]]: token_supply: int = self.tokens assert token_amount <= token_supply # dev: token amount more than supply assert i < self.n # dev: coin out of range @@ -915,9 +916,7 @@ def _calc_withdraw_one_coin( if i == 0: dy: int = (xp[i] - y) // precisions[i] else: - dy: int = ( - (xp[i] - y) * PRECISION // (precisions[i] * self.price_scale[i - 1]) - ) + dy = (xp[i] - y) * PRECISION // (precisions[i] * self.price_scale[i - 1]) xp[i] = y # FIXME: update for n coins @@ -982,9 +981,9 @@ def lp_price(self) -> int: price: int = factory_2_coin.lp_price(virtual_price, price_oracle) elif self.n == 3: # 3-coin vyper contract uses cached packed oracle prices instead of internal_price_oracle() - virtual_price: int = self.virtual_price - price_oracle: List[int] = self._price_oracle - price: int = tricrypto_ng.lp_price(virtual_price, price_oracle) + virtual_price = self.virtual_price + price_oracle = self._price_oracle + price = tricrypto_ng.lp_price(virtual_price, price_oracle) else: raise CalculationError("LP price calc doesn't support more than 3 coins") @@ -1004,7 +1003,7 @@ def internal_price_oracle(self) -> List[int]: last_prices: List[int] = self.last_prices elif self.n == 3: # 3-coin vyper contract caps every "last price" that gets fed to the EMA - last_prices: List[int] = self.last_prices.copy() + last_prices = self.last_prices.copy() price_scale: List[int] = self.price_scale last_prices = [ min(last_p, 2 * storage_p) diff --git a/curvesim/pool/snapshot.py b/curvesim/pool/snapshot.py index 1cb1acd90..afe60538d 100644 --- a/curvesim/pool/snapshot.py +++ b/curvesim/pool/snapshot.py @@ -1,9 +1,47 @@ from abc import ABC, abstractmethod from contextlib import contextmanager +from typing import Optional, Type from curvesim.exceptions import SnapshotError +class Snapshot(ABC): + """ + This class allows customization of snapshot logic, i.e. + controls how partial states are produced and restored. + """ + + @classmethod + @abstractmethod + def create(cls, pool): + """ + Create a snapshot of the pool's state. + + Parameters + ----------- + pool: object + The object whose state we are saving as a snapshot. + + Returns + ------- + Snapshot + The saved data from the pool state. + """ + raise NotImplementedError + + @abstractmethod + def restore(self, pool): + """ + Update the pool's state using the snapshot data. + + Parameters + ----------- + pool: object + The object whose state we have saved as a snapshot. + """ + raise NotImplementedError + + class SnapshotMixin: """ Pool mixin to allow ability to snapshot partial states @@ -13,7 +51,7 @@ class SnapshotMixin: implements the `Snapshot` interface. """ - snapshot_class = None + snapshot_class: Optional[Type[Snapshot]] = None def get_snapshot(self): """Saves the pool's partial state.""" @@ -58,43 +96,6 @@ def use_snapshot_context(self): self.revert_to_snapshot(snapshot) -class Snapshot(ABC): - """ - This class allows customization of snapshot logic, i.e. - controls how partial states are produced and restored. - """ - - @classmethod - @abstractmethod - def create(cls, pool): - """ - Create a snapshot of the pool's state. - - Parameters - ----------- - pool: object - The object whose state we are saving as a snapshot. - - Returns - ------- - Snapshot - The saved data from the pool state. - """ - raise NotImplementedError - - @abstractmethod - def restore(self, pool): - """ - Update the pool's state using the snapshot data. - - Parameters - ----------- - pool: object - The object whose state we have saved as a snapshot. - """ - raise NotImplementedError - - class CurvePoolBalanceSnapshot(Snapshot): """Snapshot that saves pool balances and admin balances.""" diff --git a/curvesim/pool/stableswap/metapool.py b/curvesim/pool/stableswap/metapool.py index d1ad819fe..a62478caf 100644 --- a/curvesim/pool/stableswap/metapool.py +++ b/curvesim/pool/stableswap/metapool.py @@ -2,11 +2,12 @@ Mainly a module to house the `MetaPool`, a metapool stableswap implementation in Python. """ from math import prod +from typing import Type from gmpy2 import mpz from curvesim.exceptions import CurvesimValueError -from curvesim.pool.snapshot import CurveMetaPoolBalanceSnapshot +from curvesim.pool.snapshot import CurveMetaPoolBalanceSnapshot, Snapshot from ..base import Pool @@ -16,7 +17,7 @@ class CurveMetaPool(Pool): # pylint: disable=too-many-instance-attributes Basic stableswap metapool implementation in Python. """ - snapshot_class = CurveMetaPoolBalanceSnapshot + snapshot_class: Type[Snapshot] = CurveMetaPoolBalanceSnapshot __slots__ = ( "A", diff --git a/curvesim/pool/stableswap/pool.py b/curvesim/pool/stableswap/pool.py index 18a0fcb14..33e31169e 100644 --- a/curvesim/pool/stableswap/pool.py +++ b/curvesim/pool/stableswap/pool.py @@ -2,11 +2,12 @@ Mainly a module to house the `Pool`, a basic stableswap implementation in Python. """ from math import prod +from typing import Type from gmpy2 import mpz from curvesim.exceptions import CurvesimValueError -from curvesim.pool.snapshot import CurvePoolBalanceSnapshot +from curvesim.pool.snapshot import CurvePoolBalanceSnapshot, Snapshot from ..base import Pool @@ -16,7 +17,7 @@ class CurvePool(Pool): # pylint: disable=too-many-instance-attributes Basic stableswap implementation in Python. """ - snapshot_class = CurvePoolBalanceSnapshot + snapshot_class: Type[Snapshot] = CurvePoolBalanceSnapshot __slots__ = ( "A", diff --git a/curvesim/templates/__init__.py b/curvesim/templates/__init__.py index 31b16de4d..15cb2ad44 100644 --- a/curvesim/templates/__init__.py +++ b/curvesim/templates/__init__.py @@ -3,6 +3,7 @@ """ __all__ = [ + "Log", "Trader", "Strategy", "SimAssets", @@ -14,6 +15,7 @@ "PriceSampler", ] +from .log import Log from .param_samplers import ParameterSampler from .price_samplers import PriceSample, PriceSampler from .sim_assets import SimAssets diff --git a/curvesim/templates/log.py b/curvesim/templates/log.py new file mode 100644 index 000000000..9cea44923 --- /dev/null +++ b/curvesim/templates/log.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class Log(ABC): + @abstractmethod + def update(self, **kwargs): + """Updates log data with event data.""" + + @abstractmethod + def compute_metrics(self): + """Computes metrics from the accumulated log data.""" diff --git a/curvesim/templates/strategy.py b/curvesim/templates/strategy.py index eb944af05..9677975dd 100644 --- a/curvesim/templates/strategy.py +++ b/curvesim/templates/strategy.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod +from typing import Optional, Type from curvesim.logging import get_logger +from .log import Log +from .trader import Trader + logger = get_logger(__name__) @@ -9,14 +13,14 @@ class Strategy(ABC): """ A Strategy defines the trading approach used during each step of a simulation. It executes the trades using an injected `Trader` class and then logs the - changes using the injected `StateLog` class. + changes using the injected `Log` class. Class Attributes ---------------- - trader_class : :class:`~curvesim.pipelines.templates.Trader` + trader_class : :class:`~curvesim.templates.Trader` Class for creating trader instances. - state_log_class : :class:`~curvesim.metrics.StateLog` - Class for creating state logger instances. + log_class : :class:`~curvesim.templates.Log` + Class for creating log instances. Attributes ---------- @@ -26,8 +30,8 @@ class Strategy(ABC): # These classes should be injected in child classes # to create the desired behavior. - trader_class = None - state_log_class = None + trader_class: Optional[Type[Trader]] = None + log_class: Optional[Type[Log]] = None def __init__(self, metrics): """ @@ -63,7 +67,7 @@ def __call__(self, pool, parameters, price_sampler): """ # pylint: disable=not-callable trader = self.trader_class(pool) - state_log = self.state_log_class(pool, self.metrics) + log = self.log_class(pool, self.metrics) parameters = parameters or "no parameter changes" logger.info("[%s] Simulating with %s", pool.symbol, parameters) @@ -74,9 +78,9 @@ def __call__(self, pool, parameters, price_sampler): pool.prepare_for_trades(sample.timestamp) trader_args = self._get_trader_inputs(sample) trade_data = trader.process_time_sample(*trader_args) - state_log.update(price_sample=sample, trade_data=trade_data) + log.update(price_sample=sample, trade_data=trade_data) - return state_log.compute_metrics() + return log.compute_metrics() @abstractmethod def _get_trader_inputs(self, sample): diff --git a/requirements.txt b/requirements.txt index 020cf166f..43bd29783 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,16 +2,16 @@ # "Core" dependencies used for the package. # These bounds should match the ones in `setup.cfg` -numpy==1.25.1 -pandas==2.0.3 -scipy==1.11.1 +numpy==1.25.2 +pandas==2.1.0 +scipy==1.11.2 gmpy2==2.1.5 -matplotlib==3.7.2 -web3==6.6.1 +matplotlib==3.7.3 +web3==6.9.0 requests==2.31.0 -tenacity==8.2.2 +tenacity==8.2.3 python-dotenv==1.0.0 -altair==5.0.1 +altair==5.1.1 # Dependencies used for setting up a venv for development. # These should be pinned to specific versions. black==22.6.0 @@ -28,38 +28,39 @@ coverage==7.0.5 titanoboa==0.1.7 Sphinx==5.3.0 Pallets-Sphinx-Themes==2.0.3 +mypy==1.5.1 ## The following requirements were added by pip freeze: aiohttp==3.8.5 aiosignal==1.3.1 alabaster==0.7.13 -asttokens==2.2.1 -async-timeout==4.0.2 +asttokens==2.4.0 +async-timeout==4.0.3 attrs==23.1.0 Babel==2.12.1 -bitarray==2.7.6 +bitarray==2.8.1 cached-property==1.5.2 -certifi==2023.5.7 +certifi==2023.7.22 charset-normalizer==3.2.0 -click==8.1.6 +click==8.1.7 click-log==0.4.0 contourpy==1.1.0 cycler==0.11.0 -cytoolz==0.12.1 -dill==0.3.6 +cytoolz==0.12.2 +dill==0.3.7 docutils==0.19 -eth-abi==4.1.0 +eth-abi==4.2.1 eth-account==0.9.0 eth-bloom==2.0.0 eth-hash==0.5.2 eth-keyfile==0.6.1 eth-keys==0.4.0 eth-rlp==0.3.0 -eth-stdlib==0.2.6 +eth-stdlib==0.2.7 eth-typing==3.4.0 -eth-utils==2.2.0 -exceptiongroup==1.1.2 +eth-utils==2.2.1 +exceptiongroup==1.1.3 execnet==2.0.2 -fonttools==4.41.0 +fonttools==4.42.1 frozenlist==1.4.0 hexbytes==0.3.1 idna==3.4 @@ -68,9 +69,9 @@ importlib-metadata==6.8.0 iniconfig==2.0.0 isort==5.12.0 Jinja2==3.1.2 -jsonschema==4.18.4 +jsonschema==4.19.0 jsonschema-specifications==2023.7.1 -kiwisolver==1.4.4 +kiwisolver==1.4.5 lazy-object-proxy==1.9.0 lru-dict==1.2.0 markdown-it-py==3.0.0 @@ -78,44 +79,43 @@ MarkupSafe==2.1.3 mccabe==0.7.0 mdurl==0.1.2 multidict==6.0.4 -mypy-extensions==0.4.4 +mypy-extensions==1.0.0 packaging==23.1 parsimonious==0.9.0 -pathspec==0.11.1 +pathspec==0.11.2 Pillow==10.0.0 -platformdirs==3.9.1 -pluggy==1.2.0 -protobuf==4.23.4 +platformdirs==3.10.0 +pluggy==1.3.0 +protobuf==4.24.3 py==1.11.0 py-ecc==6.0.0 -py-evm==0.7.0a3 +py-evm==0.7.0a4 pycodestyle==2.9.1 pycryptodome==3.18.0 pyethash==0.1.27 pyflakes==2.5.0 -Pygments==2.15.1 -pyparsing==3.0.9 +Pygments==2.16.1 +pyparsing==3.1.1 python-dateutil==2.8.2 -pytz==2023.3 +pytz==2023.3.post1 pyunormalize==15.0.0 -referencing==0.30.0 -regex==2023.6.3 -rich==13.4.2 +referencing==0.30.2 +regex==2023.8.8 +rich==13.5.2 rlp==3.0.0 -rpds-py==0.9.2 -safe-pysha3==1.0.4 +rpds-py==0.10.3 semantic-version==2.10.0 six==1.16.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 -sphinxcontrib-applehelp==1.0.4 -sphinxcontrib-devhelp==1.0.2 -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-applehelp==1.0.7 +sphinxcontrib-devhelp==1.0.5 +sphinxcontrib-htmlhelp==2.0.4 sphinxcontrib-jsmath==1.0.1 -sphinxcontrib-qthelp==1.0.3 -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-qthelp==1.0.6 +sphinxcontrib-serializinghtml==1.1.9 tomli==2.0.1 -tomlkit==0.11.8 +tomlkit==0.12.1 toolz==0.12.0 trie==2.1.1 typing_extensions==4.7.1 diff --git a/requirements_dev.txt b/requirements_dev.txt index 3551d4a42..eed86f369 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -20,3 +20,5 @@ titanoboa==0.1.7 # 0.1.6 for python versions before 3.10 sphinx==5.3.0 Pallets-Sphinx-Themes==2.0.3 + +mypy==1.5.1 diff --git a/setup.cfg b/setup.cfg index 1784caa91..3b44297a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,24 @@ filterwarnings = ignore::DeprecationWarning:altair.*: ignore::UserWarning:vyper.*: +[mypy] +warn_unused_configs = True + +[mypy-altair] +ignore_missing_imports = True +[mypy-pandas] +ignore_missing_imports = True +[mypy-gmpy2] +ignore_missing_imports = True +[mypy-matplotlib] +ignore_missing_imports = True +[mypy-matplotlib.pyplot] +ignore_missing_imports = True +[mypy-scipy] +ignore_missing_imports = True +[mypy-scipy.optimize] +ignore_missing_imports = True + [bumpversion:file:curvesim/version.py] [metadata]