From 82431b25a32c240d873e0860ab99d39b880d32f6 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 29 Nov 2024 16:22:25 +0000 Subject: [PATCH] Better support for notebooks --- src/anemoi/inference/metadata.py | 23 +++++++++++------------ src/anemoi/inference/runner.py | 4 +--- src/anemoi/inference/runners/simple.py | 11 +++++++++++ 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index f521825..967adab 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -629,12 +629,11 @@ def constant_forcings_inputs(self, context, input_state): remaining_mask = [i for i, _ in remaining] remaining = [name for _, name in remaining] - result.append( - context.create_constant_coupled_forcings( - remaining, - remaining_mask, - ) - ) + forcing = context.create_constant_coupled_forcings(remaining, remaining_mask) + + if forcing is not None: + # SimpleRunner does not support dynamic forcings + result.append(forcing) return result @@ -676,12 +675,12 @@ def dynamic_forcings_inputs(self, context, input_state): remaining_mask = [i for i, _ in remaining] remaining = [name for _, name in remaining] - result.append( - context.create_dynamic_coupled_forcings( - remaining, - remaining_mask, - ) - ) + forcing = context.create_dynamic_coupled_forcings(remaining, remaining_mask) + + if forcing is not None: + # SimpleRunner does not support dynamic forcings + result.append(forcing) + return result def boundary_forcings_inputs(self, context, input_state): diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index b1c6b95..b5cba45 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -145,9 +145,6 @@ def add_initial_forcings_to_input_state(self, input_state): # TODO: Check for user provided forcings for source in self.constant_forcings_inputs: - if source is None: - # When the constants are already in the input state - continue LOG.info("Constant forcings input: %s %s (%s)", source, source.variables, dates) arrays = source.load_forcings(input_state, dates) for name, forcing in zip(source.variables, arrays): @@ -362,6 +359,7 @@ def add_dynamic_forcings_to_input_tensor(self, input_tensor_torch, state, date, # batch is always 1 for source in self.dynamic_forcings_inputs: + forcings = source.load_forcings(state, [date]) # shape: (variables, dates, values) forcings = np.squeeze(forcings, axis=1) # Drop the dates dimension diff --git a/src/anemoi/inference/runners/simple.py b/src/anemoi/inference/runners/simple.py index f0dd5c9..1907d8d 100644 --- a/src/anemoi/inference/runners/simple.py +++ b/src/anemoi/inference/runners/simple.py @@ -52,4 +52,15 @@ def create_dynamic_computed_forcings(self, variables, mask): return result def create_constant_coupled_forcings(self, variables, mask): + # This runner does not support coupled forcings + # there are supposed to be already in the state dictionary + # of managed by the user. + LOG.warning("Coupled forcings are not supported by this runner: %s", variables) + return None + + def create_dynamic_coupled_forcings(self, variables, mask): + # This runner does not support coupled forcings + # there are supposed to be already in the state dictionary + # of managed by the user. + LOG.warning("Coupled forcings are not supported by this runner: %s", variables) return None