Skip to content

Commit

Permalink
Merge pull request #266 from curveresearch/refactor-utils
Browse files Browse the repository at this point in the history
Refactor utils
  • Loading branch information
chanhosuh authored Nov 6, 2023
2 parents bc63e51 + 0040675 commit 6ec7a60
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 78 deletions.
5 changes: 5 additions & 0 deletions changelog.d/20231031_152630_chanhosuh_refactor_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Added
-----
- Created utilities for managing datetime, addresses, and enums.
- Created a constants module to house common constants / enums (currently only `Chain` and `Env`).

44 changes: 44 additions & 0 deletions curvesim/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""
Constants and Enum types used in Curvesim.
"""
from enum import Enum


class StrEnum(str, Enum):
"""
Custom string enum type since the builtin `StrEnum` is not available
until Python 3.11.
"""

def __str__(self):
"""
Regular Enum's __str__ is the name, rather than the value,
e.g.
>>> str(Chain.MAINNET)
'Chain.MAINNET'
so we need to explicitly use the value.
This behaves like the builtin `StrEnum` (available in 3.11).
"""
return str.__str__(self)


class Chain(StrEnum):
"""Identifiers for chains & layer 2s."""

MAINNET = "mainnet"
ARBITRUM = "arbitrum"
OPTIMISM = "optimism"
FANTOM = "fantom"
AVALANCHE = "avalanche"
MATIC = "matic"
XDAI = "xdai"


class Env(StrEnum):
"""Names for different API environments."""

PROD = "prod"
STAGING = "staging"
28 changes: 19 additions & 9 deletions curvesim/pool_data/queries/metadata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""
Functions to get pool metadata for Curve pools.
"""
from typing import Optional, Union

from curvesim.constants import Chain, Env
from curvesim.network.subgraph import pool_snapshot_sync, symbol_address_sync
from curvesim.network.web3 import underlying_coin_info_sync
from curvesim.pool_data.metadata import PoolMetaData
from curvesim.utils import get_event_loop
from curvesim.utils import Address, get_event_loop, to_address


def from_address(address, chain, env="prod", end_ts=None):
Expand Down Expand Up @@ -56,28 +59,35 @@ def from_symbol(symbol, chain, env):


def get_metadata(
address,
chain="mainnet",
env="prod",
end_ts=None,
address: Union[str, Address],
chain: Union[str, Chain] = Chain.MAINNET,
env: Union[str, Env] = Env.PROD,
end_ts: Optional[int] = None,
):
"""
Pulls pool state and metadata from daily snapshot.
Parameters
----------
address : str
Pool address prefixed with “0x”.
address : str, Address
Pool address in proper checksum hexadecimal format.
chain : str
chain : str, Chain
Chain/layer2 identifier, e.g. “mainnet”, “arbitrum”, “optimism".
end_ts : int, optional
Datetime cutoff, given as Unix timestamp, to pull last snapshot before.
The default value is current datetime, which will pull the most recent snapshot.
Returns
-------
:class:`~curvesim.pool_data.metadata.PoolMetaDataInterface`
"""
# TODO: validate function arguments
address = to_address(address)
chain = Chain(chain)
env = Env(env)

metadata_dict = from_address(address, chain, env=env, end_ts=end_ts)
metadata = PoolMetaData(metadata_dict)

Expand Down
4 changes: 2 additions & 2 deletions curvesim/pool_data/queries/pool_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from curvesim.pool_data.metadata import PoolMetaDataInterface
from curvesim.utils import get_event_loop, get_pairs

from .metadata import get_metadata
from .metadata import Chain, get_metadata

logger = get_logger(__name__)

Expand All @@ -22,7 +22,7 @@ def get_pool_volume(
metadata_or_address: Union[PoolMetaDataInterface, str],
days: int = 60,
end: Optional[int] = None,
chain: Optional[str] = "mainnet",
chain: Union[str, Chain] = "mainnet",
) -> DataFrame:
"""
Gets historical daily volume for each pair of coins traded in a Curve pool.
Expand Down
84 changes: 17 additions & 67 deletions curvesim/utils.py → curvesim/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
"""Utlity functions for general usage in Curvesim."""
__all__ = [
"get_env_var",
"get_pairs",
"dataclass",
"get_event_loop",
"cache",
"override",
"datetime",
"is_address",
"to_address",
"Address",
]

import asyncio
import functools
import inspect
import os
import re
import sys
from dataclasses import dataclass as _dataclass
from itertools import combinations

from dotenv import load_dotenv

from curvesim.exceptions import CurvesimException, MissingEnvVarError
from curvesim.exceptions import MissingEnvVarError

from .address import Address, is_address, to_address
from .decorators import cache, override

load_dotenv()

Expand Down Expand Up @@ -58,69 +71,6 @@ def get_env_var(var_name, default=_NOT_VALUE):
return var_value


def cache(user_function, /):
"""
Simple lightweight unbounded cache. Sometimes called "memoize".
Returns the same as lru_cache(maxsize=None), creating a thin wrapper
around a dictionary lookup for the function arguments. Because it
never needs to evict old values, this is smaller and faster than
lru_cache() with a size limit.
The cache is threadsafe so the wrapped function can be used in
multiple threads.
----
This isn't in functools until python 3.9, so we copy over the
implementation as in:
https://github.com/python/cpython/blob/3.11/Lib/functools.py#L648
"""
return functools.lru_cache(maxsize=None)(user_function)


def override(method):
"""
Method decorator to signify and check a method overrides a method
in a super class.
Implementation taken from https://stackoverflow.com/a/14631397/1175053
"""
stack = inspect.stack()
base_classes = re.search(r"class.+\((.+)\)\s*\:", stack[2][4][0]).group(1)

# handle multiple inheritance
base_classes = [s.strip() for s in base_classes.split(",")]
if not base_classes:
raise CurvesimException("override decorator: unable to determine base class")

# stack[0]=overrides, stack[1]=inside class def'n, stack[2]=outside class def'n
derived_class_locals = stack[2][0].f_locals

# replace each class name in base_classes with the actual class type
for i, base_class in enumerate(base_classes):

if "." not in base_class:
base_classes[i] = derived_class_locals[base_class]

else:
components = base_class.split(".")

# obj is either a module or a class
obj = derived_class_locals[components[0]]

for c in components[1:]:
assert inspect.ismodule(obj) or inspect.isclass(obj)
obj = getattr(obj, c)

base_classes[i] = obj

if not any(hasattr(cls, method.__name__) for cls in base_classes):
raise CurvesimException(
f'Overridden method "{method.__name__}" was not found in any super class.'
)
return method


def get_pairs(arg):
"""
Get sorted pairwise combinations of an iterable.
Expand Down
50 changes: 50 additions & 0 deletions curvesim/utils/address.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
Utility functions for working with EVM-like account addresses.
This module provides custom versions of common address operations to ensure
standard handling of EVM addresses.
"""
__all__ = ["to_address", "is_address", "Address"]

from typing import NewType

from eth_typing import ChecksumAddress
from eth_utils import is_checksum_address, is_normalized_address, to_checksum_address

Address = NewType("Address", str)


def to_address(address_string: str) -> Address:
"""
Convert given string to an EVM-like checksum address.
Parameters
----------
address_string: str
string to convert to checksum address.
Raises
------
ValueError
string cannot be coerced to a checksum address
"""
checksum_address: ChecksumAddress = to_checksum_address(address_string)
return Address(checksum_address)


def is_address(address_string: str, checksum=True) -> bool:
"""
Check if given string is properly formatted as an
EVM-like address, optionally as a checksum address.
Parameters
----------
address_string: str
string to check for proper address formatting
checksum: bool, default=True
verify it is properly checksummed
"""
if checksum: # pylint: disable=no-else-return
return is_checksum_address(address_string)
else:
return is_normalized_address(address_string)
56 changes: 56 additions & 0 deletions curvesim/utils/datetime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Utility functions for working with datetimes in UTC format.
This module provides custom versions of common datetime operations to ensure
returned datetimes are always in the UTC timezone and are not naive.
The main goal is to prevent potential issues related to timezone-aware and
timezone-naive datetime objects.
Recommended style for importing is:
from curvesim.utils import datetime
datetime.now()
This syntactically substitutes the `datetime.datetime` class methods with
this module's functions.
"""
from datetime import datetime, timezone


def now() -> datetime:
"""
Customized version of `datetime.datetime.now` to ensure
UTC timezone and format.
Returns
-------
datetime
Current UTC datetime with hour, minute, second, and microsecond set to 0.
This datetime is timezone-aware and is not naive.
"""
utc_dt = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
return utc_dt


def fromtimestamp(timestamp: int) -> datetime:
"""
Customized version of `datetime.datetime.fromtimestamp` to ensure
UTC timezone and format.
Parameters
----------
timestamp : int
Unix timestamp to be converted to datetime.
Returns
-------
datetime
The UTC datetime representation of the given timestamp.
This datetime is timezone-aware and is not naive.
"""
utc_timestamp = datetime.fromtimestamp(timestamp, tz=timezone.utc)
return utc_timestamp
Loading

0 comments on commit 6ec7a60

Please sign in to comment.