Skip to content

Commit

Permalink
fixed how callable list is read in 'ouu.py'
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 16, 2024
1 parent d5ef867 commit 646f1a3
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ def __init__(
risk_bound: float = 0.0,
):
self.model = model
self.interventions = (
[interventions] if not isinstance(interventions, list) else interventions
)
self.interventions = interventions
self.qoi = qoi
self.risk_measure = risk_measure
self.num_samples = num_samples
Expand Down Expand Up @@ -139,7 +137,11 @@ def propagate_uncertainty(self, x):
intervention_list = [
deepcopy(self.fixed_static_parameter_interventions)
]
intervention_list.extend(self.interventions(torch.from_numpy(x)))
intervention_list.extend(
[self.interventions(torch.from_numpy(x))]
if not isinstance(self.interventions(torch.from_numpy(x)), list)
else self.interventions(torch.from_numpy(x))
)
static_parameter_interventions = combine_static_parameter_interventions(
intervention_list
)
Expand Down

0 comments on commit 646f1a3

Please sign in to comment.