diff --git a/indica/workflows/bayes_workflow_example.py b/indica/workflows/bayes_workflow_example.py index a73b0cef..8e511f5e 100644 --- a/indica/workflows/bayes_workflow_example.py +++ b/indica/workflows/bayes_workflow_example.py @@ -2,6 +2,7 @@ import flatdict import numpy as np from scipy.stats import loguniform +from typing import Dict, Any from indica.bayesmodels import BayesModels from indica.bayesmodels import get_uniform @@ -196,7 +197,7 @@ def setup_plasma( self.save_phantom_profiles() def setup_models(self, diagnostics: list): - self.models = {} + self.models: Dict[str, Any] = {} for diag in diagnostics: if diag == "smmh1": los_transform = self.transforms[diag] @@ -216,9 +217,8 @@ def setup_models(self, diagnostics: list): # passes=2, # ) los_transform.set_equilibrium(self.equilibrium) - smmh = Interferometry(name=diag) - smmh.set_los_transform(los_transform) - self.models[diag] = smmh + model = Interferometry(name=diag) + model.set_los_transform(los_transform) elif diag == "xrcs": los_transform = self.transforms[diag] @@ -228,27 +228,24 @@ def setup_models(self, diagnostics: list): if diag in self.data.keys(): window = self.data[diag]["spectra"].wavelength.values - xrcs = HelikeSpectrometer( + model = HelikeSpectrometer( name="xrcs", window_masks=[slice(0.394, 0.396)], window=window, ) - xrcs.set_los_transform(los_transform) - # self.models[diag] = xrcs + model.set_los_transform(los_transform) elif diag == "efit": efit = EquilibriumReconstruction(name="efit") - self.models[diag] = efit elif diag == "cxff_pi": transform = self.transforms[diag] transform.set_equilibrium(self.equilibrium) - cxrs = ChargeExchange(name=diag, element="ar") - cxrs.set_transect_transform(transform) - self.models[diag] = cxrs - + model = ChargeExchange(name=diag, element="ar") + model.set_transect_transform(transform) else: raise ValueError(f"{diag} not found in setup_models") + self.models[diag] = model def setup_opt_data(self, phantoms=False, **kwargs): if not hasattr(self, "plasma"):