diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index dbf3978d3..4ea1d6957 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -23,7 +23,9 @@ def test_sample_no_interventions( url, start_time, end_time, logging_step_size, num_samples ): - result = sample(url, start_time, end_time, logging_step_size, num_samples) + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) assert isinstance(result, dict) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -51,14 +53,16 @@ def test_sample_with_static_interventions( intervened_result = sample( url, - start_time, end_time, logging_step_size, num_samples, + start_time=start_time, static_interventions=static_interventions, ) - result = sample(url, start_time, end_time, logging_step_size, num_samples) + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -97,14 +101,16 @@ def intervention_event_fn_2(time: torch.Tensor, *args, **kwargs): intervened_result = sample( url, - start_time, end_time, logging_step_size, num_samples, + start_time=start_time, dynamic_interventions=dynamic_interventions, ) - result = sample(url, start_time, end_time, logging_step_size, num_samples) + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -139,15 +145,17 @@ def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs): intervened_result = sample( url, - start_time, end_time, logging_step_size, num_samples, + start_time=start_time, static_interventions=static_interventions, dynamic_interventions=dynamic_interventions, ) - result = sample(url, start_time, end_time, logging_step_size, num_samples) + result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + ) check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)