Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

settings context manager #477

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions mrmustard/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@ class Settings:
"""

def __new__(cls): # singleton
if not hasattr(cls, "instance"):
cls.instance = super(Settings, cls).__new__(cls)
return cls.instance
if not hasattr(cls, "_instance"):
cls._instance = super(Settings, cls).__new__(cls)
return cls._instance

def __init__(self):
self._hbar: float = 1.0
self._hbar_locked: bool = False
self._seed: int = np.random.randint(0, 2**31 - 1)
self._complex_warning: bool = False
self.rng = np.random.default_rng(self._seed)
self._precision_bits_hermite_poly: int = 128
self._complex_warning: bool = False
Expand All @@ -70,7 +69,7 @@ def __init__(self):
self.STABLE_FOCK_CONVERSION: bool = False
"Whether to use the ``vanilla_stable`` function when computing Fock amplitudes (more stable, but slower). Default is False."

self.DEBUG: bool = False
self.DEBUG: bool = False # TODO: remove in MM 1.0
"Whether or not to print the vector of means and the covariance matrix alongside the html representation of a state. Default is False."

self.AUTOSHAPE_PROBABILITY: float = 0.99999
Expand All @@ -94,25 +93,22 @@ def __init__(self):
self.DISCRETIZATION_METHOD: str = "clenshaw"
"The method used to discretize the Wigner function. Can be ``clenshaw`` (better, default) or ``iterative`` (worse, faster)."

self.EQ_TRANSFORMATION_CUTOFF: int = 3 # enough for a full step of rec rel
self.EQ_TRANSFORMATION_CUTOFF: int = 3 # TODO: remove in MM 1.0
"The cutoff used when comparing two transformations via the Choi–Jamiolkowski isomorphism. Default is 3."

self.EQ_TRANSFORMATION_RTOL_FOCK: float = 1e-3
self.EQ_TRANSFORMATION_RTOL_FOCK: float = 1e-3 # TODO: remove in MM 1.0
"The relative tolerance used when comparing two transformations via the Choi–Jamiolkowski isomorphism. Default is 1e-3."

self.EQ_TRANSFORMATION_RTOL_GAUSS: float = 1e-6
self.EQ_TRANSFORMATION_RTOL_GAUSS: float = 1e-6 # TODO: remove in MM 1.0
"The relative tolerance used when comparing two transformations on Gaussian states. Default is 1e-6."

self.PRN_INTERNAL_CUTOFF: int = 50
"The cutoff used when computing the output of a PNR detection. Default is 50."

self.HOMODYNE_SQUEEZING: float = 10.0
self.HOMODYNE_SQUEEZING: float = 10.0 # TODO: remove in MM 1.0
"The value of squeezing for homodyne measurements. Default is 10.0."

self.PROGRESSBAR: bool = True
"Whether or not to display the progress bar when performing training. Default is True."

self.PNR_INTERNAL_CUTOFF: int = 50
self.PNR_INTERNAL_CUTOFF: int = 50 # TODO: remove in MM 1.0
"The cutoff used when computing the output of a PNR detection. Default is 50."

self.BS_FOCK_METHOD: str = "vanilla" # can be 'vanilla' or 'schwinger'
Expand All @@ -121,6 +117,31 @@ def __init__(self):
self.ATOL: float = 1e-8
"The absolute tolerance when comparing two values or arrays. Default is 1e-8."

self._original_values = self.__dict__.copy()

def __call__(self, **kwargs):
"allows for setting multiple settings at once and saving the original values"
disallowed = {
"COMPLEX_WARNING",
"HBAR",
"SEED",
"PRECISION_BITS_HERMITE_POLY",
"CACHE_DIR",
} & kwargs.keys()
if disallowed:
raise ValueError(f"Cannot change the value of {disallowed} using a context manager.")
self._original_values = self.__dict__.copy()
self.__dict__.update(kwargs)
return self

def __enter__(self):
"context manager enter method"
return self

def __exit__(self, exc_type, exc_value, traceback):
"context manager exit method that resets the settings to their original values"
self.__dict__.update(self._original_values)

@property
def COMPLEX_WARNING(self):
r"""Whether tensorflow's ``ComplexWarning``s should be raised when a complex is cast to a float. Default is ``False``."""
Expand All @@ -136,7 +157,7 @@ def COMPLEX_WARNING(self, value: bool):

@property
def HBAR(self):
r"""The value of the Planck constant. Default is ``2``.
r"""The value of the Planck constant. Default is ``1``.

Cannot be changed after its value is queried for the first time.
"""
Expand All @@ -146,7 +167,7 @@ def HBAR(self):
@HBAR.setter
def HBAR(self, value: float):
if value != self._hbar and self._hbar_locked:
raise ValueError("Cannot change the value of `settings.HBAR`.")
raise ValueError("Cannot change the value of `settings.HBAR` in the current session.")
self._hbar = value

@property
Expand Down
75 changes: 27 additions & 48 deletions tests/test_utils/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def test_init(self):

assert settings.HBAR == 1.0
assert settings.DEBUG is False
assert (
settings.AUTOSHAPE_PROBABILITY == 0.99999
) # capture at least 99.9% of the probability
assert settings.AUTOSHAPE_PROBABILITY == 0.99999
assert settings.AUTOCUTOFF_MAX_CUTOFF == 100
assert settings.AUTOCUTOFF_MIN_CUTOFF == 1
assert settings.CIRCUIT_DECIMALS == 3
Expand All @@ -46,60 +44,25 @@ def test_init(self):
assert settings.HOMODYNE_SQUEEZING == 10.0
assert settings.PRECISION_BITS_HERMITE_POLY == 128
assert settings.PROGRESSBAR is True
assert settings.BS_FOCK_METHOD == "vanilla" # can be 'vanilla' or 'schwinger'
assert settings.BS_FOCK_METHOD == "vanilla"

def test_setters(self):
settings = Settings()

ap0 = settings.AUTOSHAPE_PROBABILITY
settings.AUTOSHAPE_PROBABILITY = 0.1
assert settings.AUTOSHAPE_PROBABILITY == 0.1
settings.AUTOSHAPE_PROBABILITY = ap0

db0 = settings.DEBUG
settings.DEBUG = True
assert settings.DEBUG is True
settings.DEBUG = db0

dbsm0 = settings.BS_FOCK_METHOD
settings.BS_FOCK_METHOD = "schwinger"
assert settings.BS_FOCK_METHOD == "schwinger"
settings.BS_FOCK_METHOD = dbsm0

eqtc0 = settings.EQ_TRANSFORMATION_CUTOFF
settings.EQ_TRANSFORMATION_CUTOFF = 2
assert settings.EQ_TRANSFORMATION_CUTOFF == 2
settings.EQ_TRANSFORMATION_CUTOFF = eqtc0

pnr0 = settings.PNR_INTERNAL_CUTOFF
settings.PNR_INTERNAL_CUTOFF = False
assert settings.PNR_INTERNAL_CUTOFF is False
settings.PNR_INTERNAL_CUTOFF = pnr0

pb0 = settings.PROGRESSBAR
settings.PROGRESSBAR = False
assert settings.PROGRESSBAR is False
settings.PROGRESSBAR = pb0
cw = settings.COMPLEX_WARNING
settings.COMPLEX_WARNING = not cw
assert settings.COMPLEX_WARNING == (not cw)
settings.COMPLEX_WARNING = cw

s0 = settings.SEED
settings.SEED = None
assert settings.SEED is not None
settings.SEED = s0

hs0 = settings.HOMODYNE_SQUEEZING
settings.HOMODYNE_SQUEEZING = 20.1
assert settings.HOMODYNE_SQUEEZING == 20.1
settings.HOMODYNE_SQUEEZING = hs0

fock_rtol = settings.EQ_TRANSFORMATION_RTOL_FOCK
settings.EQ_TRANSFORMATION_RTOL_FOCK = 0.02
assert settings.EQ_TRANSFORMATION_RTOL_FOCK == 0.02
settings.EQ_TRANSFORMATION_RTOL_FOCK = fock_rtol

gauss_rtol = settings.EQ_TRANSFORMATION_RTOL_GAUSS
settings.EQ_TRANSFORMATION_RTOL_GAUSS = 0.02
assert settings.EQ_TRANSFORMATION_RTOL_GAUSS == 0.02
settings.EQ_TRANSFORMATION_RTOL_GAUSS = gauss_rtol
p0 = settings.PRECISION_BITS_HERMITE_POLY
settings.PRECISION_BITS_HERMITE_POLY = 256
assert settings.PRECISION_BITS_HERMITE_POLY == 256
settings.PRECISION_BITS_HERMITE_POLY = p0

assert settings.HBAR == 1.0
with pytest.raises(ValueError, match="Cannot change"):
Expand All @@ -112,7 +75,7 @@ def test_settings_seed_randomness_at_init(self):
"""Test that the random seed is set randomly as MM is initialized."""
settings = Settings()
seed0 = settings.SEED
del Settings.instance
del Settings._instance
settings = Settings()
seed1 = settings.SEED
assert seed0 != seed1
Expand Down Expand Up @@ -144,3 +107,19 @@ def test_complex_warnings(self, caplog):
settings.COMPLEX_WARNING = False
math.cast(1 + 1j, math.float64)
assert len(caplog.records) == 1

def test_context_manager(self):
"""Test that the context manager works correctly."""
settings = Settings()

with settings(AUTOSHAPE_PROBABILITY=0.1):
assert settings.AUTOSHAPE_PROBABILITY == 0.1
assert settings.AUTOSHAPE_PROBABILITY == 0.99999

def test_context_manager_disallowed(self):
"""Test that the context manager disallows changing some settings."""
settings = Settings()

with pytest.raises(ValueError, match="Cannot change"):
with settings(HBAR=0.5):
pass
Loading