Skip to content

Commit

Permalink
Cleanup utils interface, create subpackage
Browse files Browse the repository at this point in the history
  • Loading branch information
chanhosuh committed Oct 18, 2023
1 parent 6d83be5 commit 373d28e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 67 deletions.
80 changes: 13 additions & 67 deletions curvesim/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
"""Utlity functions for general usage in Curvesim."""
__all__ = [
"get_env_var",
"get_pairs",
"dataclass",
"get_event_loop",
"cache",
"override",
"datetime",
]

import asyncio
import functools
import inspect
import os
import re
import sys
from dataclasses import dataclass as _dataclass
from itertools import combinations

from dotenv import load_dotenv

from curvesim.exceptions import CurvesimException, MissingEnvVarError
from curvesim.exceptions import MissingEnvVarError

from .decorators import cache, override

load_dotenv()

Expand Down Expand Up @@ -58,69 +67,6 @@ def get_env_var(var_name, default=_NOT_VALUE):
return var_value


def cache(user_function, /):
"""
Simple lightweight unbounded cache. Sometimes called "memoize".
Returns the same as lru_cache(maxsize=None), creating a thin wrapper
around a dictionary lookup for the function arguments. Because it
never needs to evict old values, this is smaller and faster than
lru_cache() with a size limit.
The cache is threadsafe so the wrapped function can be used in
multiple threads.
----
This isn't in functools until python 3.9, so we copy over the
implementation as in:
https://github.com/python/cpython/blob/3.11/Lib/functools.py#L648
"""
return functools.lru_cache(maxsize=None)(user_function)


def override(method):
"""
Method decorator to signify and check a method overrides a method
in a super class.
Implementation taken from https://stackoverflow.com/a/14631397/1175053
"""
stack = inspect.stack()
base_classes = re.search(r"class.+\((.+)\)\s*\:", stack[2][4][0]).group(1)

# handle multiple inheritance
base_classes = [s.strip() for s in base_classes.split(",")]
if not base_classes:
raise CurvesimException("override decorator: unable to determine base class")

# stack[0]=overrides, stack[1]=inside class def'n, stack[2]=outside class def'n
derived_class_locals = stack[2][0].f_locals

# replace each class name in base_classes with the actual class type
for i, base_class in enumerate(base_classes):

if "." not in base_class:
base_classes[i] = derived_class_locals[base_class]

else:
components = base_class.split(".")

# obj is either a module or a class
obj = derived_class_locals[components[0]]

for c in components[1:]:
assert inspect.ismodule(obj) or inspect.isclass(obj)
obj = getattr(obj, c)

base_classes[i] = obj

if not any(hasattr(cls, method.__name__) for cls in base_classes):
raise CurvesimException(
f'Overridden method "{method.__name__}" was not found in any super class.'
)
return method


def get_pairs(arg):
"""
Get sorted pairwise combinations of an iterable.
Expand Down
68 changes: 68 additions & 0 deletions curvesim/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import functools
import inspect
import re

from curvesim.exceptions import CurvesimException


def cache(user_function, /):
"""
Simple lightweight unbounded cache. Sometimes called "memoize".
Returns the same as lru_cache(maxsize=None), creating a thin wrapper
around a dictionary lookup for the function arguments. Because it
never needs to evict old values, this is smaller and faster than
lru_cache() with a size limit.
The cache is threadsafe so the wrapped function can be used in
multiple threads.
----
This isn't in functools until python 3.9, so we copy over the
implementation as in:
https://github.com/python/cpython/blob/3.11/Lib/functools.py#L648
"""
return functools.lru_cache(maxsize=None)(user_function)


def override(method):
"""
Method decorator to signify and check a method overrides a method
in a super class.
Implementation taken from https://stackoverflow.com/a/14631397/1175053
"""
stack = inspect.stack()
base_classes = re.search(r"class.+\((.+)\)\s*\:", stack[2][4][0]).group(1)

# handle multiple inheritance
base_classes = [s.strip() for s in base_classes.split(",")]
if not base_classes:
raise CurvesimException("override decorator: unable to determine base class")

# stack[0]=overrides, stack[1]=inside class def'n, stack[2]=outside class def'n
derived_class_locals = stack[2][0].f_locals

# replace each class name in base_classes with the actual class type
for i, base_class in enumerate(base_classes):

if "." not in base_class:
base_classes[i] = derived_class_locals[base_class]

else:
components = base_class.split(".")

# obj is either a module or a class
obj = derived_class_locals[components[0]]

for c in components[1:]:
assert inspect.ismodule(obj) or inspect.isclass(obj)
obj = getattr(obj, c)

base_classes[i] = obj

if not any(hasattr(cls, method.__name__) for cls in base_classes):
raise CurvesimException(
f'Overridden method "{method.__name__}" was not found in any super class.'
)
return method

0 comments on commit 373d28e

Please sign in to comment.