diff --git a/nessai/reparameterisations/__init__.py b/nessai/reparameterisations/__init__.py index f14bdfe0..d71b1c70 100644 --- a/nessai/reparameterisations/__init__.py +++ b/nessai/reparameterisations/__init__.py @@ -8,71 +8,21 @@ """ import logging -from dataclasses import dataclass, field -from typing import Any -from ..utils.entry_points import get_entry_points from .angle import Angle, AnglePair, ToCartesian from .base import Reparameterisation from .combined import CombinedReparameterisation from .discrete import Dequantise from .null import NullReparameterisation from .rescale import Rescale, RescaleToBounds, ScaleAndShift -from .utils import get_reparameterisation +from .utils import ( + ReparameterisationDict, + get_reparameterisation, +) logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class KnownReparameterisation: - """Dataclass to store the reparameterisation class and keyword arguments""" - - name: str - class_fn: Reparameterisation - keyword_arguments: dict[str:Any] = field(default_factory=dict) - - -class ReparameterisationDict(dict): - """Dictionary of reparameterisations - - This dictionary is used to store the known reparameterisations and - provides a method to add new reparameterisations. - """ - - def add_reparameterisation(self, name, class_fn, keyword_arguments=None): - """Add a new reparameterisation to the dictionary - - Parameters - ---------- - name : str - Name of the reparameterisation. - class_fn : Reparameterisation - Reparameterisation class. - keyword_arguments : dict, optional - Keyword arguments for the reparameterisation. - """ - if keyword_arguments is None: - keyword_arguments = {} - if name in self: - raise ValueError(f"Reparameterisation {name} already exists") - self[name] = KnownReparameterisation(name, class_fn, keyword_arguments) - - def add_external_reparameterisations(self, group): - entry_points = get_entry_points(group) - for ep in entry_points.values(): - reparam = ep.load() - if not isinstance(reparam, KnownReparameterisation): - raise RuntimeError( - f"Invalid external reparameterisation: {reparam}" - ) - elif reparam.name in self: - raise ValueError( - f"Reparameterisation {reparam.name} already exists" - ) - logger.debug(f"Adding external reparameterisation: {reparam}") - self[reparam.name] = reparam - - default_reparameterisations = ReparameterisationDict() default_reparameterisations.add_reparameterisation("default", RescaleToBounds) default_reparameterisations.add_reparameterisation( diff --git a/nessai/reparameterisations/utils.py b/nessai/reparameterisations/utils.py index 990ee01b..472ff6b3 100644 --- a/nessai/reparameterisations/utils.py +++ b/nessai/reparameterisations/utils.py @@ -4,10 +4,66 @@ """ import copy +import logging +from dataclasses import dataclass, field +from typing import Any +from ..utils.entry_points import get_entry_points from .base import Reparameterisation from .null import NullReparameterisation +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class KnownReparameterisation: + """Dataclass to store the reparameterisation class and keyword arguments""" + + name: str + class_fn: Reparameterisation + keyword_arguments: dict[str:Any] = field(default_factory=dict) + + +class ReparameterisationDict(dict): + """Dictionary of reparameterisations + + This dictionary is used to store the known reparameterisations and + provides a method to add new reparameterisations. + """ + + def add_reparameterisation(self, name, class_fn, keyword_arguments=None): + """Add a new reparameterisation to the dictionary + + Parameters + ---------- + name : str + Name of the reparameterisation. + class_fn : Reparameterisation + Reparameterisation class. + keyword_arguments : dict, optional + Keyword arguments for the reparameterisation. + """ + if keyword_arguments is None: + keyword_arguments = {} + if name in self: + raise ValueError(f"Reparameterisation {name} already exists") + self[name] = KnownReparameterisation(name, class_fn, keyword_arguments) + + def add_external_reparameterisations(self, group): + entry_points = get_entry_points(group) + for ep in entry_points.values(): + reparam = ep.load() + if not isinstance(reparam, KnownReparameterisation): + raise RuntimeError( + f"Invalid external reparameterisation: {reparam}" + ) + elif reparam.name in self: + raise ValueError( + f"Reparameterisation {reparam.name} already exists" + ) + logger.debug(f"Adding external reparameterisation: {reparam}") + self[reparam.name] = reparam + def get_reparameterisation(reparameterisation, defaults=None): """Function to get a reparameterisation class from a name