Skip to content

Commit

Permalink
610 atol and rtol cannot be specified in solver options (#612)
Browse files Browse the repository at this point in the history
* added rtol, atol, to interface

* added tests for optimize and calibration

* passes linter

* atol and rtol moved to inside solver_options dict, updated tests

* formatting and linting passing
  • Loading branch information
augeorge authored Sep 18, 2024
1 parent 221c424 commit 08cf9fc
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 18 deletions.
45 changes: 36 additions & 9 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def ensemble_sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -121,6 +121,10 @@ def ensemble_sample(
"""
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
if dirichlet_alpha is None:
dirichlet_alpha = torch.ones(len(model_paths_or_jsons))
Expand All @@ -138,7 +142,9 @@ def ensemble_sample(
raise ValueError("num_samples must be a positive integer")

def wrapped_model():
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
solution = model(
torch.as_tensor(start_time),
torch.as_tensor(end_time),
Expand Down Expand Up @@ -233,7 +239,7 @@ def ensemble_calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -280,6 +286,10 @@ def ensemble_calibrate(
if not (isinstance(num_iterations, int) and num_iterations > 0):
raise ValueError("num_iterations must be a positive integer")

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

def autoguide(model):
guide = pyro.infer.autoguide.AutoGuideList(model)
guide.append(
Expand Down Expand Up @@ -314,7 +324,9 @@ def autoguide(model):
def wrapped_model():
obs = condition(data=_data)(_noise_model)

with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
solution = model(
torch.as_tensor(start_time),
torch.as_tensor(data_timepoints[-1]),
Expand Down Expand Up @@ -384,7 +396,8 @@ def sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -449,6 +462,10 @@ def sample(

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
model = CompiledDynamics.load(model_path_or_json)

Expand Down Expand Up @@ -492,7 +509,9 @@ def sample(

def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
with contextlib.ExitStack() as stack:
for handler in intervention_handlers:
stack.enter_context(handler)
Expand Down Expand Up @@ -602,7 +621,8 @@ def calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
- solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
- start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -668,6 +688,10 @@ def calibrate(

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

pyro.clear_param_store()

model = CompiledDynamics.load(model_path_or_json)
Expand Down Expand Up @@ -740,7 +764,9 @@ def wrapped_model():
obs = condition(data=_data)(_noise_model)

with StaticBatchObservation(data_timepoints, observation=obs):
with TorchDiffEq(method=solver_method, options=solver_options):
with TorchDiffEq(
rtol=rtol, atol=atol, method=solver_method, options=solver_options
):
with contextlib.ExitStack() as stack:
for handler in intervention_handlers:
stack.enter_context(handler)
Expand Down Expand Up @@ -834,7 +860,8 @@ def optimize(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down
7 changes: 6 additions & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def __init__(
self.u_bounds = u_bounds
self.risk_bound = risk_bound # used for defining penalty
warnings.simplefilter("always", UserWarning)
self.rtol = self.solver_options.pop("rtol", 1e-7) # default = 1e-7
self.atol = self.solver_options.pop("atol", 1e-9) # default = 1e-9

def __call__(self, x):
if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
Expand Down Expand Up @@ -144,7 +146,10 @@ def propagate_uncertainty(self, x):
def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(
method=self.solver_method, options=self.solver_options
rtol=self.rtol,
atol=self.atol,
method=self.solver_method,
options=self.solver_options,
):
with contextlib.ExitStack() as stack:
for handler in static_parameter_intervention_handlers:
Expand Down
51 changes: 43 additions & 8 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,56 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
"num_iterations": 2,
}

RTOL = [1e-6, 1e-4]
ATOL = [1e-8, 1e-6]


@pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_sample_no_interventions(
sample_method, model, start_time, end_time, logging_step_size, num_samples
sample_method,
model,
start_time,
end_time,
logging_step_size,
num_samples,
rtol,
atol,
):
model_url = model.url

with pyro.poutine.seed(rng_seed=0):
result1 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]
with pyro.poutine.seed(rng_seed=0):
result2 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

result3 = sample_method(
model_url, end_time, logging_step_size, num_samples, start_time=start_time
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

for result in [result1, result2, result3]:
Expand Down Expand Up @@ -364,8 +391,10 @@ def test_calibrate_no_kwargs(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_calibrate_deterministic(
model_fixture, start_time, end_time, logging_step_size
model_fixture, start_time, end_time, logging_step_size, rtol, atol
):
model_url = model_fixture.url
(
Expand All @@ -381,6 +410,7 @@ def test_calibrate_deterministic(
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"deterministic_learnable_parameters": deterministic_learnable_parameters,
"solver_options": {"rtol": rtol, "atol": atol},
**CALIBRATE_KWARGS,
}

Expand All @@ -400,7 +430,10 @@ def test_calibrate_deterministic(
assert torch.allclose(param_value, param_sample_2[param_name])

result = sample(
*sample_args, **sample_kwargs, inferred_parameters=inferred_parameters
*sample_args,
**sample_kwargs,
inferred_parameters=inferred_parameters,
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

check_result_sizes(result, start_time, end_time, logging_step_size, 1)
Expand Down Expand Up @@ -563,7 +596,9 @@ def test_output_format(
@pytest.mark.parametrize("start_time", START_TIMES)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
def test_optimize(model_fixture, start_time, end_time, num_samples):
@pytest.mark.parametrize("rtol", RTOL)
@pytest.mark.parametrize("atol", ATOL)
def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol):
logging_step_size = 1.0
model_url = model_fixture.url

Expand All @@ -581,7 +616,7 @@ def __call__(self, x):
optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
"solver_options": {"step_size": 0.1},
"solver_options": {"step_size": 0.1, "rtol": rtol, "atol": atol},
"start_time": start_time,
"n_samples_ouu": int(2),
"maxiter": 1,
Expand Down

0 comments on commit 08cf9fc

Please sign in to comment.