Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental API Changing #80

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/scores/experimental/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Experimental features in `scores`

All api's in here are subject to change, and may be moved into the main `scores` namespace.

Can use .api to access scores, which then the context manager `source` can change.

Example:
>>> import scores.experimental
>>> from scores.experimental import api as scores_api
>>> with scores.experimental.source('NEWAPI'):
scores_api.api.continuous # Will be wrapper around said function
"""

from scores.experimental.wrapper import APIWrapper
from scores.experimental.context import APIChange as source

# Acts as standard api for scores
api = APIWrapper()

## As PyTorch connects the gradient with the array, issues arise with methods using xarray
## Not using at the moment
# try:
# from scores.experimental.pytorch import PyTorch
# # Pytorch api
# pytorch = PyTorch()
# except (ImportError, ModuleNotFoundError):
# pass
30 changes: 30 additions & 0 deletions src/scores/experimental/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
Context managers for `scores`
"""
from scores import experimental

class APIChange:
def __init__(self, api_name: str):
"""Change the api of `scores` as defined

Args:
api_name (str):
Name of `api` to change to

Example:
>>> import scores.experimental
>>> with scores.experimental.source('NEWAPI'):
scores.experimental.api.continuous # Will be wrapper around said function
"""
if not hasattr(experimental, api_name):
raise AttributeError(f"`scores.experimental` has no attribute {api_name}.")
self.api_name = api_name
self._recorded_api = None

def __enter__(self):
"""Change api"""
self._recorded_api = experimental.api
setattr(experimental, 'api', getattr(experimental, self.api_name))

def __exit__(self, exc_type, exc_val, exc_tb):
setattr(experimental, 'api', self._recorded_api)
67 changes: 67 additions & 0 deletions src/scores/experimental/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""
`scores` `torch` wrapper
"""


from typing import Any, Union
import torch
import xarray as xr

from scores.experimental.wrapper import APIWrapper

class PyTorch(APIWrapper):
"""
Specific wrapper to allow use of `scores` with `torch`.
"""
def __call__(self, prediction: torch.Tensor, target: torch.Tensor, dataset_select: Union[str, None] = None, **kwargs) -> torch.Tensor:
"""
Call underlying function, handling `torch.Tensor`

Due to `torch` implementation only operations on Tensor can be used.

Args:
prediction (torch.Tensor):
Predicted Tensor
target (torch.Tensor):
Target Tensor
dataset_select (Union[str, None], optional):
Variable to select, if score is an `xarray.Dataset`. Defaults to None.

Raises:
ValueError:
If score was a xr.Dataset and `dataset_select` not given.
ValueError:
If `dataset_select` given, and score not a xr.Dataset

Returns:
(torch.Tensor):
score as calculated
"""

## Note: Cannot detach tensors
## will not work with scores using xarray

# if hasattr(prediction, 'detach'):
# prediction = prediction.detach()
# target = target.detach()

# prediction = xr.DataArray(prediction.numpy())
# target = xr.DataArray(target.numpy())

score = super().__call__(prediction, target, **kwargs)

if isinstance(score, xr.Dataset):
if dataset_select is None:
raise ValueError(
f"Returned score is an 'xarray.Dataset' which cannot be parsed to a Tensor.\n"
"Set, `dataset_select` to a variable to select that variable."
)
score = score[dataset_select]

elif dataset_select:
raise ValueError(f"`dataset_select` was not None, but data was not a 'xarray.Dataset'.")

if isinstance(score, xr.DataArray):
score = score.values
return torch.Tensor(score)

90 changes: 90 additions & 0 deletions src/scores/experimental/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
API Wrapper around an underlying module, in this case `scores`.

Allows manipulation of data prior to calling base `scores` function.
"""


from typing import Any, Callable, Union
import scores

class APIWrapper:
"""
Base api wrapper of `scores` for use with other frameworks
"""
def __init__(self, function: Callable = scores):
"""Base wrapper for api control of `scores`

Args:
function (Callable, optional):
Function to wrap. Provides access to underlying attributes.
Defaults to `scores`.
"""
if function is None:
function = scores
self.function = function

def help(self):
"""Get help for underlying function"""
return help(self.function)

def __getattr__(self, key: str) -> Union[Callable, Any]:
"""Get underlying attribute from self.function

Args:
key (str):
Attribute name to find

Raises:
AttributeError:
If function has no attribute key

Returns:
(Union[Callable, Any]):
Underlying attribute
"""
if key == "function":
raise AttributeError(f"{self} has no attribute {key!r}")

if not hasattr(self.function, key):
raise AttributeError(f"{self.function} has no attribute {key!r}")

new_func = getattr(self.function, key)

return self.__class__(new_func)

def __dir__(self):
"""Get `__dir__` of underlying function"""
return self.function.__dir__()

def __repr__(self):
"""repr"""
return repr(self.function)

def __call__(self, *args, **kwargs) -> Any:
"""
Call underlying function

Raises:
TypeError:
If function cannot be called
"""
if not hasattr(self.function, '__call__'):
raise TypeError(f"{type(self.function)}: {self.function} can not be called.")
return self.function(*args, **kwargs)

def callback(self, **callback_kwargs):
"""
Create a callback function where any passed `callback_kwargs` are also passed.

Args:
**callback_kwargs (Any):
Extra kwargs to pass when callbacked
"""
callback_kwargs: dict = callback_kwargs

def callback_decorator(*args, **kwargs):
callback_kwargs.update(kwargs)
return self.__call__(*args, **callback_kwargs)

return callback_decorator
Loading