Skip to content

Commit

Permalink
settings context manager (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
ziofil authored Sep 5, 2024
1 parent d0ba778 commit 1a798b3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 63 deletions.
51 changes: 36 additions & 15 deletions mrmustard/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ class Settings:
"""

def __new__(cls): # singleton
if not hasattr(cls, "instance"):
cls.instance = super(Settings, cls).__new__(cls)
return cls.instance
if not hasattr(cls, "_instance"):
cls._instance = super(Settings, cls).__new__(cls)
return cls._instance

def __init__(self):
self._hbar: float = 1.0
self._hbar_locked: bool = False
self._seed: int = np.random.randint(0, 2**31 - 1)
self._complex_warning: bool = False
self.rng = np.random.default_rng(self._seed)
self._precision_bits_hermite_poly: int = 128
self._complex_warning: bool = False
Expand All @@ -70,7 +69,7 @@ def __init__(self):
self.STABLE_FOCK_CONVERSION: bool = False
"Whether to use the ``vanilla_stable`` function when computing Fock amplitudes (more stable, but slower). Default is False."

self.DEBUG: bool = False
self.DEBUG: bool = False # TODO: remove in MM 1.0
"Whether or not to print the vector of means and the covariance matrix alongside the html representation of a state. Default is False."

self.AUTOSHAPE_PROBABILITY: float = 0.99999
Expand All @@ -94,25 +93,22 @@ def __init__(self):
self.DISCRETIZATION_METHOD: str = "clenshaw"
"The method used to discretize the Wigner function. Can be ``clenshaw`` (better, default) or ``iterative`` (worse, faster)."

self.EQ_TRANSFORMATION_CUTOFF: int = 3 # enough for a full step of rec rel
self.EQ_TRANSFORMATION_CUTOFF: int = 3 # TODO: remove in MM 1.0
"The cutoff used when comparing two transformations via the Choi–Jamiolkowski isomorphism. Default is 3."

self.EQ_TRANSFORMATION_RTOL_FOCK: float = 1e-3
self.EQ_TRANSFORMATION_RTOL_FOCK: float = 1e-3 # TODO: remove in MM 1.0
"The relative tolerance used when comparing two transformations via the Choi–Jamiolkowski isomorphism. Default is 1e-3."

self.EQ_TRANSFORMATION_RTOL_GAUSS: float = 1e-6
self.EQ_TRANSFORMATION_RTOL_GAUSS: float = 1e-6 # TODO: remove in MM 1.0
"The relative tolerance used when comparing two transformations on Gaussian states. Default is 1e-6."

self.PRN_INTERNAL_CUTOFF: int = 50
"The cutoff used when computing the output of a PNR detection. Default is 50."

self.HOMODYNE_SQUEEZING: float = 10.0
self.HOMODYNE_SQUEEZING: float = 10.0 # TODO: remove in MM 1.0
"The value of squeezing for homodyne measurements. Default is 10.0."

self.PROGRESSBAR: bool = True
"Whether or not to display the progress bar when performing training. Default is True."

self.PNR_INTERNAL_CUTOFF: int = 50
self.PNR_INTERNAL_CUTOFF: int = 50 # TODO: remove in MM 1.0
"The cutoff used when computing the output of a PNR detection. Default is 50."

self.BS_FOCK_METHOD: str = "vanilla" # can be 'vanilla' or 'schwinger'
Expand All @@ -121,6 +117,31 @@ def __init__(self):
self.ATOL: float = 1e-8
"The absolute tolerance when comparing two values or arrays. Default is 1e-8."

self._original_values = self.__dict__.copy()

def __call__(self, **kwargs):
"allows for setting multiple settings at once and saving the original values"
disallowed = {
"COMPLEX_WARNING",
"HBAR",
"SEED",
"PRECISION_BITS_HERMITE_POLY",
"CACHE_DIR",
} & kwargs.keys()
if disallowed:
raise ValueError(f"Cannot change the value of {disallowed} using a context manager.")
self._original_values = self.__dict__.copy()
self.__dict__.update(kwargs)
return self

def __enter__(self):
"context manager enter method"
return self

def __exit__(self, exc_type, exc_value, traceback):
"context manager exit method that resets the settings to their original values"
self.__dict__.update(self._original_values)

@property
def COMPLEX_WARNING(self):
r"""Whether tensorflow's ``ComplexWarning``s should be raised when a complex is cast to a float. Default is ``False``."""
Expand All @@ -136,7 +157,7 @@ def COMPLEX_WARNING(self, value: bool):

@property
def HBAR(self):
r"""The value of the Planck constant. Default is ``2``.
r"""The value of the Planck constant. Default is ``1``.
Cannot be changed after its value is queried for the first time.
"""
Expand All @@ -146,7 +167,7 @@ def HBAR(self):
@HBAR.setter
def HBAR(self, value: float):
if value != self._hbar and self._hbar_locked:
raise ValueError("Cannot change the value of `settings.HBAR`.")
raise ValueError("Cannot change the value of `settings.HBAR` in the current session.")
self._hbar = value

@property
Expand Down
75 changes: 27 additions & 48 deletions tests/test_utils/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def test_init(self):

assert settings.HBAR == 1.0
assert settings.DEBUG is False
assert (
settings.AUTOSHAPE_PROBABILITY == 0.99999
) # capture at least 99.9% of the probability
assert settings.AUTOSHAPE_PROBABILITY == 0.99999
assert settings.AUTOCUTOFF_MAX_CUTOFF == 100
assert settings.AUTOCUTOFF_MIN_CUTOFF == 1
assert settings.CIRCUIT_DECIMALS == 3
Expand All @@ -46,60 +44,25 @@ def test_init(self):
assert settings.HOMODYNE_SQUEEZING == 10.0
assert settings.PRECISION_BITS_HERMITE_POLY == 128
assert settings.PROGRESSBAR is True
assert settings.BS_FOCK_METHOD == "vanilla" # can be 'vanilla' or 'schwinger'
assert settings.BS_FOCK_METHOD == "vanilla"

def test_setters(self):
settings = Settings()

ap0 = settings.AUTOSHAPE_PROBABILITY
settings.AUTOSHAPE_PROBABILITY = 0.1
assert settings.AUTOSHAPE_PROBABILITY == 0.1
settings.AUTOSHAPE_PROBABILITY = ap0

db0 = settings.DEBUG
settings.DEBUG = True
assert settings.DEBUG is True
settings.DEBUG = db0

dbsm0 = settings.BS_FOCK_METHOD
settings.BS_FOCK_METHOD = "schwinger"
assert settings.BS_FOCK_METHOD == "schwinger"
settings.BS_FOCK_METHOD = dbsm0

eqtc0 = settings.EQ_TRANSFORMATION_CUTOFF
settings.EQ_TRANSFORMATION_CUTOFF = 2
assert settings.EQ_TRANSFORMATION_CUTOFF == 2
settings.EQ_TRANSFORMATION_CUTOFF = eqtc0

pnr0 = settings.PNR_INTERNAL_CUTOFF
settings.PNR_INTERNAL_CUTOFF = False
assert settings.PNR_INTERNAL_CUTOFF is False
settings.PNR_INTERNAL_CUTOFF = pnr0

pb0 = settings.PROGRESSBAR
settings.PROGRESSBAR = False
assert settings.PROGRESSBAR is False
settings.PROGRESSBAR = pb0
cw = settings.COMPLEX_WARNING
settings.COMPLEX_WARNING = not cw
assert settings.COMPLEX_WARNING == (not cw)
settings.COMPLEX_WARNING = cw

s0 = settings.SEED
settings.SEED = None
assert settings.SEED is not None
settings.SEED = s0

hs0 = settings.HOMODYNE_SQUEEZING
settings.HOMODYNE_SQUEEZING = 20.1
assert settings.HOMODYNE_SQUEEZING == 20.1
settings.HOMODYNE_SQUEEZING = hs0

fock_rtol = settings.EQ_TRANSFORMATION_RTOL_FOCK
settings.EQ_TRANSFORMATION_RTOL_FOCK = 0.02
assert settings.EQ_TRANSFORMATION_RTOL_FOCK == 0.02
settings.EQ_TRANSFORMATION_RTOL_FOCK = fock_rtol

gauss_rtol = settings.EQ_TRANSFORMATION_RTOL_GAUSS
settings.EQ_TRANSFORMATION_RTOL_GAUSS = 0.02
assert settings.EQ_TRANSFORMATION_RTOL_GAUSS == 0.02
settings.EQ_TRANSFORMATION_RTOL_GAUSS = gauss_rtol
p0 = settings.PRECISION_BITS_HERMITE_POLY
settings.PRECISION_BITS_HERMITE_POLY = 256
assert settings.PRECISION_BITS_HERMITE_POLY == 256
settings.PRECISION_BITS_HERMITE_POLY = p0

assert settings.HBAR == 1.0
with pytest.raises(ValueError, match="Cannot change"):
Expand All @@ -112,7 +75,7 @@ def test_settings_seed_randomness_at_init(self):
"""Test that the random seed is set randomly as MM is initialized."""
settings = Settings()
seed0 = settings.SEED
del Settings.instance
del Settings._instance
settings = Settings()
seed1 = settings.SEED
assert seed0 != seed1
Expand Down Expand Up @@ -144,3 +107,19 @@ def test_complex_warnings(self, caplog):
settings.COMPLEX_WARNING = False
math.cast(1 + 1j, math.float64)
assert len(caplog.records) == 1

def test_context_manager(self):
"""Test that the context manager works correctly."""
settings = Settings()

with settings(AUTOSHAPE_PROBABILITY=0.1):
assert settings.AUTOSHAPE_PROBABILITY == 0.1
assert settings.AUTOSHAPE_PROBABILITY == 0.99999

def test_context_manager_disallowed(self):
"""Test that the context manager disallows changing some settings."""
settings = Settings()

with pytest.raises(ValueError, match="Cannot change"):
with settings(HBAR=0.5):
pass

0 comments on commit 1a798b3

Please sign in to comment.