Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use mypy #252

Merged
merged 14 commits into from
Sep 20, 2023
33 changes: 33 additions & 0 deletions .github/workflows/type.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: Type check

# Controls when the action will run. Triggers the workflow on push or pull request
# events but only for the main and develop branches.
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main, develop ]

jobs:
mypy:
name: Run mypy
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- name: Set up latest Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip'

- name: Install dependencies
run: |
pip install --upgrade pip
pip install -r requirements.txt

- name: mypy
run: |
make type

8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ pylint:
@echo ""
@echo "pylint passed ⚙️"

.PHONY: type
type:
@echo "$(REVERSE)Running$(RESET) $(BOLD)mypy$(RESET)..."
@mypy --version
@mypy curvesim
@echo ""
@echo "Type checks passed 🏆"

.PHONY: coverage
coverage:
pytest --cov -n auto --cov-report=
Expand Down
7 changes: 7 additions & 0 deletions changelog.d/20230919_093613_chanhosuh_use_mypy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Added
-----

- Add type checks via mypy to the CI (#252).
- `StateLog` is derived from a `Log` interface class in-line
with other "template" classes. This also helps avoid
circular dependencies.
1 change: 1 addition & 0 deletions curvesim/_order_book/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from curvesim.pool import CurveMetaPool


# pylint: disable=too-many-locals
def order_book(pool, i, j, *, width=0.1, resolution=10**23, use_fee=True, show=True):
"""
Computes and optionally plots an orderbook representation of exchange rates
Expand Down
6 changes: 4 additions & 2 deletions curvesim/iterators/price_samplers/price_volume.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Iterator

from curvesim.logging import get_logger
from curvesim.price_data import get
from curvesim.templates.price_samplers import PriceSample, PriceSampler
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(
self.volumes = volumes.set_axis(assets.symbol_pairs, axis="columns")

@override
def __iter__(self) -> PriceVolumeSample:
def __iter__(self) -> Iterator[PriceVolumeSample]:
"""
Yields
-------
Expand All @@ -86,7 +88,7 @@ def __iter__(self) -> PriceVolumeSample:
prices = prices.to_dict()
volumes = volumes.to_dict()

yield PriceVolumeSample(price_timestamp, prices, volumes)
yield PriceVolumeSample(price_timestamp, prices, volumes) # type:ignore

def total_volumes(self):
"""
Expand Down
7 changes: 4 additions & 3 deletions curvesim/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from contextlib import contextmanager
from logging.handlers import QueueHandler, QueueListener
from typing import Dict, List

# -- convenient parameters to adjust for debugging -- #
DEFAULT_LEVEL = "info"
Expand Down Expand Up @@ -41,7 +42,7 @@
# FIXME: need a function to update the config after module initialization
HANDLERS = ["console", "file"] if USE_LOG_FILE else ["console"]

CUSTOM_LOGGING_CONFIG = {
CUSTOM_LOGGING_CONFIG: dict = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
Expand Down Expand Up @@ -74,8 +75,8 @@
}

# 3rd party loggers that we want to largely ignore
silenced_loggers = ["matplotlib", "asyncio", "rlp", "parso", "web3"]
configured_loggers = CUSTOM_LOGGING_CONFIG["loggers"]
silenced_loggers: List[str] = ["matplotlib", "asyncio", "rlp", "parso", "web3"]
configured_loggers: Dict[str, dict] = CUSTOM_LOGGING_CONFIG["loggers"]
for name in silenced_loggers:
configured_loggers[name] = {
"handlers": HANDLERS,
Expand Down
6 changes: 5 additions & 1 deletion curvesim/metrics/state_log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from pandas import DataFrame, concat

from curvesim.metrics.base import PoolMetric
from curvesim.templates import Log
from curvesim.utils import override

from .pool_parameters import get_pool_parameters
from .pool_state import get_pool_state


class StateLog:
class StateLog(Log):
"""
Logger that records simulation/pool state throughout each simulation run and
computes metrics at the end of each run.
Expand All @@ -29,6 +31,7 @@ def __init__(self, pool, metrics):
self.state_per_run = get_pool_parameters(pool)
self.state_per_trade = []

@override
def update(self, **kwargs):
"""Records pool state and any keyword arguments provided."""

Expand All @@ -47,6 +50,7 @@ def get_logs(self):
**state_per_trade,
}

@override
def compute_metrics(self):
"""Computes metrics from the accumulated log data."""

Expand Down
8 changes: 5 additions & 3 deletions curvesim/pipelines/simple/strategy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Type

from curvesim.logging import get_logger
from curvesim.metrics.state_log import StateLog
from curvesim.templates import Strategy
from curvesim.templates import Log, Strategy, Trader

from .trader import SimpleArbitrageur

Expand All @@ -17,8 +19,8 @@ class SimpleStrategy(Strategy): # pylint: disable=too-few-public-methods
Class for creating state logger instances.
"""

trader_class = SimpleArbitrageur
state_log_class = StateLog
trader_class: Type[Trader] = SimpleArbitrageur
log_class: Type[Log] = StateLog

def _get_trader_inputs(self, sample):
return (sample.prices,)
10 changes: 6 additions & 4 deletions curvesim/pipelines/vol_limited_arb/strategy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Type

from curvesim.logging import get_logger
from curvesim.metrics.state_log import StateLog
from curvesim.metrics.state_log.log import StateLog
from curvesim.pipelines.vol_limited_arb.trader import VolumeLimitedArbitrageur
from curvesim.templates import Strategy
from curvesim.templates import Log, Strategy, Trader

logger = get_logger(__name__)

Expand All @@ -11,8 +13,8 @@ class VolumeLimitedStrategy(Strategy):
Computes and executes volume-limited arbitrage trades at each timestep.
"""

trader_class = VolumeLimitedArbitrageur
state_log_class = StateLog
trader_class: Type[Trader] = VolumeLimitedArbitrageur
log_class: Type[Log] = StateLog

def __init__(self, metrics, vol_mult):
"""
Expand Down
6 changes: 4 additions & 2 deletions curvesim/pool/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""


from curvesim.pool.snapshot import SnapshotMixin
from typing import Optional, Type

from curvesim.pool.snapshot import Snapshot, SnapshotMixin


class Pool(SnapshotMixin):
Expand All @@ -21,7 +23,7 @@ class Pool(SnapshotMixin):

# need to configure in derived class otherwise the
# snapshotting will not work
snapshot_class = None
snapshot_class: Optional[Type[Snapshot]] = None

@property
def name(self):
Expand Down
4 changes: 2 additions & 2 deletions curvesim/pool/cryptoswap/calcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_y(A: int, gamma: int, xp: List[int], D: int, j: int) -> List[int]:
if n_coins == 2:
y_out: List[int] = [factory_2_coin.newton_y(A, gamma, xp, D, j), 0]
elif n_coins == 3:
y_out: List[int] = tricrypto_ng.get_y(A, gamma, xp, D, j)
y_out = tricrypto_ng.get_y(A, gamma, xp, D, j)
else:
raise CurvesimValueError("More than 3 coins is not supported.")

Expand All @@ -93,7 +93,7 @@ def get_alpha(
#
# CAUTION: need to be wary of off-by-one errors from integer division.
ma_half_time = ceil(ma_half_time * 1000 / 694)
alpha: int = tricrypto_ng.wad_exp(
alpha = tricrypto_ng.wad_exp(
-1 * ((block_timestamp - last_prices_timestamp) * 10**18 // ma_half_time)
)
else:
Expand Down
15 changes: 5 additions & 10 deletions curvesim/pool/cryptoswap/calcs/factory_2_coin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ def lp_price(virtual_price: int, price_oracle: List[int]) -> int:
int
Liquidity redeemable per LP token in units of token 0.
"""
price_oracle: int = price_oracle[0]

return 2 * virtual_price * _sqrt_int(price_oracle) // 10**18
p: int = price_oracle[0]
return 2 * virtual_price * _sqrt_int(p) // 10**18


# pylint: disable-next=too-many-locals
Expand Down Expand Up @@ -109,10 +108,6 @@ def newton_y( # noqa: complexity: 11
frac: int = x[k] * 10**18 // D
assert 10**16 <= frac <= 10**20

y: int = D // n_coins
K0_i: int = 10**18
S_i: int = 0

x_sorted: List[int] = x.copy()
x_sorted[i] = 0
x_sorted = sorted(x_sorted, reverse=True) # From high to low
Expand Down Expand Up @@ -164,7 +159,7 @@ def newton_y( # noqa: complexity: 11

diff: int = abs(y - y_prev)
if diff < max(convergence_limit, y // 10**14):
frac: int = y * 10**18 // D
frac = y * 10**18 // D
assert 10**16 <= frac <= 10**20 # dev: unsafe value for y
return int(y)

Expand All @@ -176,7 +171,7 @@ def newton_D( # noqa: complexity: 13
ANN: int,
gamma: int,
x_unsorted: List[int],
) -> List[int]:
) -> int:
"""
Finding the `D` invariant using Newton's method.

Expand Down Expand Up @@ -239,7 +234,7 @@ def newton_D( # noqa: complexity: 13
if diff * 10**14 < max(10**16, D):
# Test that we are safe with the next newton_y
for _x in x:
frac: int = _x * 10**18 // D
frac = _x * 10**18 // D
if frac < 10**16 or frac > 10**20:
raise CalculationError("Unsafe value for x[i]")
return int(D)
Expand Down
29 changes: 13 additions & 16 deletions curvesim/pool/cryptoswap/calcs/tricrypto_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def get_y( # noqa: complexity: 18
assert 10**17 <= D <= 10**15 * 10**18

frac: int = 0
for k in range(3):
if k != i:
frac = x[k] * 10**18 // D
for _k in range(3):
if _k != i:
frac = x[_k] * 10**18 // D
assert 10**16 <= frac <= 10**20, "Unsafe values x[i]"
# if above conditions are met, x[k] > 0
# if above conditions are met, x[_k] > 0

j: int = 0
k: int = 0
Expand Down Expand Up @@ -99,7 +99,7 @@ def get_y( # noqa: complexity: 18
_c: int = gamma2 * (-1 * _c_neg) // D * ANN // 27 // A_MULTIPLIER
_c *= -1
else:
_c: int = gamma2 * _c_neg // D * ANN // 27 // A_MULTIPLIER
_c = gamma2 * _c_neg // D * ANN // 27 // A_MULTIPLIER
c += _c

# (10**18 + gamma)**2/27
Expand Down Expand Up @@ -151,19 +151,19 @@ def get_y( # noqa: complexity: 18
if b_is_neg:
delta0: int = -(_3ac // -b) - b
else:
delta0: int = _3ac // b - b
delta0 = _3ac // b - b

# 9*a*c/b - 2*b - 27*a**2/b*d/b
if b_is_neg:
delta1: int = -(3 * _3ac // -b) - 2 * b - 27 * a**2 // -b * d // -b
else:
delta1: int = 3 * _3ac // b - 2 * b - 27 * a**2 // b * d // b
delta1 = 3 * _3ac // b - 2 * b - 27 * a**2 // b * d // b

# delta1**2 + 4*delta0**2/b*delta0
if b_is_neg:
sqrt_arg: int = delta1**2 - (4 * delta0**2 // -b * delta0)
else:
sqrt_arg: int = delta1**2 + 4 * delta0**2 // b * delta0
sqrt_arg = delta1**2 + 4 * delta0**2 // b * delta0

sqrt_val: int = 0
if sqrt_arg > 0:
Expand All @@ -187,13 +187,13 @@ def get_y( # noqa: complexity: 18
if second_cbrt < 0:
C1: int = -(b_cbrt * b_cbrt // 10**18 * -second_cbrt // 10**18)
else:
C1: int = b_cbrt * b_cbrt // 10**18 * second_cbrt // 10**18
C1 = b_cbrt * b_cbrt // 10**18 * second_cbrt // 10**18

# (b + b*delta0/C1 - C1)/3
if sign(b * delta0) != sign(C1):
root_K0: int = (b + -(b * delta0 // -C1) - C1) // 3
else:
root_K0: int = (b + b * delta0 // C1 - C1) // 3
root_K0 = (b + b * delta0 // C1 - C1) // 3

# D*D/27/x_k*D/x_j*root_K0/a
root: int = D * D // 27 // x_k * D // x_j * root_K0 // a
Expand Down Expand Up @@ -293,7 +293,7 @@ def _newton_y( # noqa: complexity: 11 # pylint: disable=duplicate-code,too-man

diff: int = abs(y - y_prev)
if diff < max(convergence_limit, y // 10**14):
frac: int = y * 10**18 // D
frac = y * 10**18 // D
assert 10**16 <= frac <= 10**20 # dev: unsafe value for y
return int(y)

Expand Down Expand Up @@ -447,9 +447,6 @@ def newton_D( # pylint: disable=too-many-locals
D_minus: int = 0
D_prev: int = 0

diff: int = 0
frac: int = 0

D = mpz(D)

for _ in range(255):
Expand All @@ -466,7 +463,7 @@ def newton_D( # pylint: disable=too-many-locals
_g1k0 = gamma + 10**18 # <--------- safe to do unsafe_add.

# The following operations can safely be unsafe.
_g1k0: int = abs(_g1k0 - K0) + 1
_g1k0 = abs(_g1k0 - K0) + 1

# D / (A * N**N) * _g1k0**2 / gamma**2
# mul1 = 10**18 * D / gamma * _g1k0 / gamma * _g1k0 * A_MULTIPLIER / ANN
Expand Down Expand Up @@ -516,7 +513,7 @@ def newton_D( # pylint: disable=too-many-locals
else:
D = (D_minus - D_plus) // 2

diff = abs(D - D_prev)
diff: int = abs(D - D_prev)
# Could reduce precision for gas efficiency here:
if diff * 10**14 < max(10**16, D):
# Test that we are safe with the next get_y
Expand Down
Loading
Loading