Skip to content

Commit

Permalink
fix upstream time collisions in ChiRho (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamWitty authored Oct 30, 2023
1 parent 5e67d92 commit b3e5a4e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ install_requires =
jupyter
torch >= 1.8.0
mira @ git+https://github.com/indralab/[email protected]
chirho @ git+https://github.com/BasisResearch/chirho@f44731d416c20cbf1615147f5e0ba6ef3afed78d
chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc
sympytorch
torchdiffeq

Expand Down
5 changes: 3 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
START_TIMES = [0.0, 1.0, 2.0]
END_TIMES = [3.0, 4.0, 5.0]

LOGGING_STEP_SIZES = [0.09]
LOGGING_STEP_SIZES = [0.1]

NUM_SAMPLES = [2]

Expand Down Expand Up @@ -68,7 +68,8 @@ def check_result_sizes(
if k[:5] == "state":
assert v.shape == (
num_samples,
len(torch.arange(start_time, end_time, logging_step_size)),
len(torch.arange(start_time, end_time, logging_step_size))
- 1, # Does not include start_time
)
else:
assert v.shape == (num_samples,)
Expand Down
7 changes: 3 additions & 4 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def test_simulate_no_interventions(
):
result = simulate(url, start_time, end_time, logging_step_size, num_samples)
assert isinstance(result, dict)

check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)


Expand All @@ -43,7 +42,7 @@ def test_simulate_with_static_interventions(
intervened_state_1 = {k: v + 1 for k, v in initial_state.items()}
intervened_state_2 = {k: v + 2 for k, v in initial_state.items()}

intervention_time_1 = (end_time + start_time) / 2.000001 # Midpoint
intervention_time_1 = (end_time + start_time) / 2 # Midpoint
intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point
static_interventions = {
intervention_time_1: intervened_state_1,
Expand Down Expand Up @@ -82,7 +81,7 @@ def test_simulate_with_dynamic_interventions(
intervened_state_1 = {k: v + 1 for k, v in initial_state.items()}
intervened_state_2 = {k: v + 2 for k, v in initial_state.items()}

intervention_time_1 = (end_time + start_time) / 2.000001 # Midpoint
intervention_time_1 = (end_time + start_time) / 2 # Midpoint
intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point

def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs):
Expand Down Expand Up @@ -128,7 +127,7 @@ def test_simulate_with_static_and_dynamic_interventions(
intervened_state_1 = {k: v + 1 for k, v in initial_state.items()}
intervened_state_2 = {k: v + 2 for k, v in initial_state.items()}

intervention_time_1 = (end_time + start_time) / 2.000001 # Midpoint
intervention_time_1 = (end_time + start_time) / 2 # Midpoint
intervention_time_2 = (end_time + intervention_time_1) / 2 # 3/4 point

def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs):
Expand Down

0 comments on commit b3e5a4e

Please sign in to comment.