Skip to content

Commit

Permalink
MAINT: move KnownReparameterisation and ReparameterisationDict
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-will committed Oct 26, 2024
1 parent 99cd120 commit 26ac71d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 54 deletions.
58 changes: 4 additions & 54 deletions nessai/reparameterisations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,71 +8,21 @@
"""

import logging
from dataclasses import dataclass, field
from typing import Any

from ..utils.entry_points import get_entry_points
from .angle import Angle, AnglePair, ToCartesian
from .base import Reparameterisation
from .combined import CombinedReparameterisation
from .discrete import Dequantise
from .null import NullReparameterisation
from .rescale import Rescale, RescaleToBounds, ScaleAndShift
from .utils import get_reparameterisation
from .utils import (
ReparameterisationDict,
get_reparameterisation,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class KnownReparameterisation:
"""Dataclass to store the reparameterisation class and keyword arguments"""

name: str
class_fn: Reparameterisation
keyword_arguments: dict[str:Any] = field(default_factory=dict)


class ReparameterisationDict(dict):
"""Dictionary of reparameterisations
This dictionary is used to store the known reparameterisations and
provides a method to add new reparameterisations.
"""

def add_reparameterisation(self, name, class_fn, keyword_arguments=None):
"""Add a new reparameterisation to the dictionary
Parameters
----------
name : str
Name of the reparameterisation.
class_fn : Reparameterisation
Reparameterisation class.
keyword_arguments : dict, optional
Keyword arguments for the reparameterisation.
"""
if keyword_arguments is None:
keyword_arguments = {}
if name in self:
raise ValueError(f"Reparameterisation {name} already exists")
self[name] = KnownReparameterisation(name, class_fn, keyword_arguments)

def add_external_reparameterisations(self, group):
entry_points = get_entry_points(group)
for ep in entry_points.values():
reparam = ep.load()
if not isinstance(reparam, KnownReparameterisation):
raise RuntimeError(
f"Invalid external reparameterisation: {reparam}"
)
elif reparam.name in self:
raise ValueError(
f"Reparameterisation {reparam.name} already exists"
)
logger.debug(f"Adding external reparameterisation: {reparam}")
self[reparam.name] = reparam


default_reparameterisations = ReparameterisationDict()
default_reparameterisations.add_reparameterisation("default", RescaleToBounds)
default_reparameterisations.add_reparameterisation(
Expand Down
56 changes: 56 additions & 0 deletions nessai/reparameterisations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,66 @@
"""

import copy
import logging
from dataclasses import dataclass, field
from typing import Any

from ..utils.entry_points import get_entry_points
from .base import Reparameterisation
from .null import NullReparameterisation

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class KnownReparameterisation:
"""Dataclass to store the reparameterisation class and keyword arguments"""

name: str
class_fn: Reparameterisation
keyword_arguments: dict[str:Any] = field(default_factory=dict)


class ReparameterisationDict(dict):
"""Dictionary of reparameterisations
This dictionary is used to store the known reparameterisations and
provides a method to add new reparameterisations.
"""

def add_reparameterisation(self, name, class_fn, keyword_arguments=None):
"""Add a new reparameterisation to the dictionary
Parameters
----------
name : str
Name of the reparameterisation.
class_fn : Reparameterisation
Reparameterisation class.
keyword_arguments : dict, optional
Keyword arguments for the reparameterisation.
"""
if keyword_arguments is None:
keyword_arguments = {}
if name in self:
raise ValueError(f"Reparameterisation {name} already exists")
self[name] = KnownReparameterisation(name, class_fn, keyword_arguments)

def add_external_reparameterisations(self, group):
entry_points = get_entry_points(group)
for ep in entry_points.values():
reparam = ep.load()
if not isinstance(reparam, KnownReparameterisation):
raise RuntimeError(
f"Invalid external reparameterisation: {reparam}"
)
elif reparam.name in self:
raise ValueError(
f"Reparameterisation {reparam.name} already exists"
)
logger.debug(f"Adding external reparameterisation: {reparam}")
self[reparam.name] = reparam


def get_reparameterisation(reparameterisation, defaults=None):
"""Function to get a reparameterisation class from a name
Expand Down

0 comments on commit 26ac71d

Please sign in to comment.