diff --git a/src/scores/experimental/__init__.py b/src/scores/experimental/__init__.py
new file mode 100644
index 00000000..aae8cb74
--- /dev/null
+++ b/src/scores/experimental/__init__.py
@@ -0,0 +1,28 @@
+"""
+Experimental features in `scores`
+
+All APIs here are subject to change and may be moved into the main `scores` namespace.
+
+Use `api` to access `scores`; the context manager `source` can then swap it for another wrapper.
+
+Example:
+    >>> import scores.experimental
+    >>> scores.experimental.api.continuous  # Wrapper around `scores.continuous`
+    >>> with scores.experimental.source('NEWAPI'):
+    ...     scores.experimental.api.continuous  # Wrapper provided by the 'NEWAPI' api
+"""
+
+from scores.experimental.wrapper import APIWrapper
+from scores.experimental.context import APIChange as source
+
+# Acts as the standard api for scores
+api = APIWrapper()
+
+## As PyTorch couples the gradient with the tensor, issues arise with methods that use xarray.
+## Not used at the moment.
+# try:
+#     from scores.experimental.pytorch import PyTorch
+#     # PyTorch api
+#     pytorch = PyTorch()
+# except (ImportError, ModuleNotFoundError):
+#     pass
diff --git a/src/scores/experimental/context.py b/src/scores/experimental/context.py
new file mode 100644
index 00000000..4a340f4a
--- /dev/null
+++ b/src/scores/experimental/context.py
@@ -0,0 +1,30 @@
+"""
+Context managers for `scores`
+"""
+from scores import experimental
+
+class APIChange:
+    def __init__(self, api_name: str):
+        """Temporarily change the api of `scores.experimental` to the named wrapper
+
+        Args:
+            api_name (str):
+                Name of the `api` to change to
+
+        Example:
+            >>> import scores.experimental
+            >>> with scores.experimental.source('NEWAPI'):
+            ...     scores.experimental.api.continuous  # Wrapper provided by the 'NEWAPI' api
+        """
+        if not hasattr(experimental, api_name):
+            raise AttributeError(f"`scores.experimental` has no attribute {api_name!r}.")
+        self.api_name = api_name
+        self._recorded_api = None
+
+    def __enter__(self):
+        """Swap in the named api"""
+        self._recorded_api = experimental.api
+        setattr(experimental, 'api', getattr(experimental, self.api_name))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        setattr(experimental, 'api', self._recorded_api)
diff --git a/src/scores/experimental/pytorch.py b/src/scores/experimental/pytorch.py
new file mode 100644
index 00000000..ba16b68d
--- /dev/null
+++ b/src/scores/experimental/pytorch.py
@@ -0,0 +1,66 @@
+"""
+`scores` `torch` wrapper
+"""
+
+
+from typing import Union
+import torch
+import xarray as xr
+
+from scores.experimental.wrapper import APIWrapper
+
+class PyTorch(APIWrapper):
+    """
+    Specific wrapper to allow use of `scores` with `torch`.
+    """
+    def __call__(self, prediction: torch.Tensor, target: torch.Tensor, dataset_select: Union[str, None] = None, **kwargs) -> torch.Tensor:
+        """
+        Call the underlying function, handling `torch.Tensor` inputs and outputs
+
+        Due to the `torch` implementation, only operations on Tensors can be used.
+
+        Args:
+            prediction (torch.Tensor):
+                Predicted Tensor
+            target (torch.Tensor):
+                Target Tensor
+            dataset_select (Union[str, None], optional):
+                Variable to select if the score is an `xarray.Dataset`. Defaults to None.
+
+        Raises:
+            ValueError:
+                If the score is an `xr.Dataset` and `dataset_select` is not given.
+            ValueError:
+                If `dataset_select` is given but the score is not an `xr.Dataset`.
+
+        Returns:
+            (torch.Tensor):
+                The calculated score
+        """
+
+        ## Note: tensors cannot be detached here, so this will not work
+        ## with scores that use xarray.
+
+        # if hasattr(prediction, 'detach'):
+        #     prediction = prediction.detach()
+        #     target = target.detach()
+
+        # prediction = xr.DataArray(prediction.numpy())
+        # target = xr.DataArray(target.numpy())
+
+        score = super().__call__(prediction, target, **kwargs)
+
+        if isinstance(score, xr.Dataset):
+            if dataset_select is None:
+                raise ValueError(
+                    "Returned score is an 'xarray.Dataset' which cannot be converted to a Tensor.\n"
+                    "Set `dataset_select` to a variable name to select that variable."
+                )
+            score = score[dataset_select]
+
+        elif dataset_select is not None:
+            raise ValueError("`dataset_select` was not None, but the score was not an 'xarray.Dataset'.")
+
+        if isinstance(score, xr.DataArray):
+            score = score.values
+        return torch.Tensor(score)
diff --git a/src/scores/experimental/wrapper.py b/src/scores/experimental/wrapper.py
new file mode 100644
index 00000000..d8cd09b2
--- /dev/null
+++ b/src/scores/experimental/wrapper.py
@@ -0,0 +1,90 @@
+"""
+API wrapper around an underlying module, in this case `scores`.
+
+Allows manipulation of data prior to calling the underlying `scores` function.
+"""
+
+
+from typing import Any, Callable, Union
+import scores
+
+class APIWrapper:
+    """
+    Base api wrapper of `scores` for use with other frameworks
+    """
+    def __init__(self, function: Callable = scores):
+        """Base wrapper for api control of `scores`
+
+        Args:
+            function (Callable, optional):
+                Function or module to wrap. Provides access to underlying attributes.
+                Defaults to the `scores` package.
+        """
+        if function is None:
+            function = scores
+        self.function = function
+
+    def help(self):
+        """Get help for the underlying function"""
+        return help(self.function)
+
+    def __getattr__(self, key: str) -> Union[Callable, Any]:
+        """Get an underlying attribute from `self.function`, wrapped in this class
+
+        Args:
+            key (str):
+                Attribute name to find
+
+        Raises:
+            AttributeError:
+                If the wrapped function has no attribute `key`
+
+        Returns:
+            (Union[Callable, Any]):
+                Underlying attribute
+        """
+        if key == "function":
+            raise AttributeError(f"{self} has no attribute {key!r}")
+
+        if not hasattr(self.function, key):
+            raise AttributeError(f"{self.function} has no attribute {key!r}")
+
+        new_func = getattr(self.function, key)
+
+        return self.__class__(new_func)
+
+    def __dir__(self):
+        """Get `dir` of the underlying function"""
+        return dir(self.function)
+
+    def __repr__(self):
+        """Get `repr` of the underlying function"""
+        return repr(self.function)
+
+    def __call__(self, *args, **kwargs) -> Any:
+        """
+        Call the underlying function
+
+        Raises:
+            TypeError:
+                If the wrapped function cannot be called
+        """
+        if not callable(self.function):
+            raise TypeError(f"{type(self.function)}: {self.function} cannot be called.")
+        return self.function(*args, **kwargs)
+
+    def callback(self, **callback_kwargs):
+        """
+        Create a callback function that also passes along the given `callback_kwargs`.
+
+        Args:
+            **callback_kwargs (Any):
+                Extra kwargs to pass when the callback is called
+        """
+        def callback_decorator(*args, **kwargs):
+            # Merge without mutating the stored kwargs so repeated
+            # calls of the callback remain independent.
+            merged_kwargs = {**callback_kwargs, **kwargs}
+            return self.__call__(*args, **merged_kwargs)
+
+        return callback_decorator
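
Usage sketch (not part of the diff): how the `APIWrapper` attribute chaining and the `source` context manager are intended to fit together. It assumes `scores.continuous.mse` is available in the installed `scores` package; the `myapi` name is purely illustrative and is registered on the module only so that `source` can find it.

import numpy as np
import xarray as xr

import scores.experimental
from scores.experimental import APIWrapper

fcst = xr.DataArray(np.random.rand(10))
obs = xr.DataArray(np.random.rand(10))

# Attribute access walks down the wrapped `scores` package, returning a new
# APIWrapper at each step; calling the final wrapper dispatches to the function.
mse = scores.experimental.api.continuous.mse(fcst, obs)  # assumes scores.continuous.mse exists

# Register a hypothetical alternative wrapper on the module, then swap it in
# temporarily with the `source` context manager.
scores.experimental.myapi = APIWrapper()  # `myapi` is illustrative only
with scores.experimental.source('myapi'):
    mse_again = scores.experimental.api.continuous.mse(fcst, obs)
# On exit, `scores.experimental.api` is restored to the original wrapper.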
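
The `callback` helper pre-binds keyword arguments for later calls. A minimal sketch, assuming `scores.continuous.mse` accepts a `preserve_dims` keyword:

import numpy as np
import xarray as xr

import scores.experimental

fcst = xr.DataArray(np.random.rand(4, 3), dims=("time", "station"))
obs = xr.DataArray(np.random.rand(4, 3), dims=("time", "station"))

# `callback` returns a plain callable with the given kwargs pre-bound;
# positional arguments supplied later are forwarded alongside them.
mse_per_station = scores.experimental.api.continuous.mse.callback(preserve_dims="station")
per_station_scores = mse_per_station(fcst, obs)  # one value per station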