Skip to content

Commit

Permalink
Typing
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 31, 2024
1 parent 4077928 commit 0f91a32
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
32 changes: 16 additions & 16 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def param_value_objective(
param_name: List[str],
start_time: List[torch.Tensor],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]]:
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)
if len(param_value) < param_size and param_value[0] is None:
param_value = [None for _ in param_name]
Expand All @@ -19,13 +19,13 @@ def param_value_objective(

def intervention_generator(
x: torch.Tensor,
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == param_size, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if start_time[count].item() in static_parameter_interventions:
static_parameter_interventions[start_time[count].item()].update(
Expand All @@ -47,18 +47,18 @@ def intervention_generator(
def start_time_objective(
param_name: List[str],
param_value: List[Intervention],
) -> Callable[[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]]:
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)

def intervention_generator(
x: torch.Tensor,
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == param_size, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if x[count].item() in static_parameter_interventions:
static_parameter_interventions[x[count].item()].update(
Expand All @@ -76,7 +76,7 @@ def intervention_generator(
def start_time_param_value_objective(
param_name: List[str],
param_value: List[Intervention] = [None],
) -> Callable[[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]]:
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
param_size = len(param_name)
if len(param_value) < param_size and param_value[0] is None:
param_value = [None for _ in param_name]
Expand All @@ -87,13 +87,13 @@ def start_time_param_value_objective(

def intervention_generator(
x: torch.Tensor,
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == param_size * 2, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size * 2}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if x[count * 2].item() in static_parameter_interventions:
static_parameter_interventions[x[count * 2].item()].update(
Expand All @@ -116,21 +116,21 @@ def intervention_generator(

def intervention_func_combinator(
intervention_funcs: List[
Callable[[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]]
Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
],
intervention_func_lengths: List[int],
) -> Callable[[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]]:
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
assert len(intervention_funcs) == len(intervention_func_lengths)

total_length = sum(intervention_func_lengths)

# Note: This only works for combining static parameter interventions.
def intervention_generator(
x: torch.Tensor,
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == total_length
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]] = [None] * len(
interventions: List[Dict[float, Dict[str, Intervention]]] = [None] * len(
intervention_funcs
)
i = 0
Expand All @@ -145,9 +145,9 @@ def intervention_generator(


def combine_static_parameter_interventions(
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]]
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
interventions: List[Dict[float, Dict[str, Intervention]]]
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for intervention in interventions:
for key, value in intervention.items():
if key in static_parameter_interventions:
Expand Down
6 changes: 2 additions & 4 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def optimize(
qoi: Callable,
risk_bound: float,
static_parameter_interventions: Callable[
[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]
[torch.Tensor], Dict[float, Dict[str, Intervention]]
],
objfun: Callable,
initial_guess_interventions: List[float],
Expand All @@ -783,9 +783,7 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
Expand Down
4 changes: 1 addition & 3 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ class computeRisk:
def __init__(
self,
model: Callable,
interventions: Callable[
[torch.Tensor], Dict[torch.Tensor, Dict[str, Intervention]]
],
interventions: Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]],
qoi: Callable,
end_time: float,
logging_step_size: float,
Expand Down

0 comments on commit 0f91a32

Please sign in to comment.