diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 3bfe70b..80f75da 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -77,6 +77,10 @@ def typed_variables(self): def diagnostic_variables(self): return self._metadata.diagnostic_variables + @property + def prognostic_variables(self): + return self._metadata.prognostic_variables + @property def prognostic_output_mask(self): return self._metadata.prognostic_output_mask diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index 838bcbb..f521825 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -36,6 +36,10 @@ def _remove_full_paths(x): class frozendict(dict): + def __deepcopy__(self, memo): + # As this is a frozendict, we can return the same object + return self + def __setitem__(self, key, value): raise TypeError("frozendict is immutable") diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 2856763..e8a74b0 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -55,7 +55,7 @@ def __init__( checkpoint, *, accumulations=True, - device: str, + device: str = "cuda", precision: str = None, report_error=False, allow_nans=None, # can be True of False @@ -116,7 +116,7 @@ def run(self, *, input_state, lead_time): lead_time = to_timedelta(lead_time) - # This may be used but Ouput objects to compute the step + # This may be used but Output objects to compute the step self.lead_time = lead_time self.time_step = self.checkpoint.timestep