Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
#47 implement different step functions for learning tendencies and l…
Browse files Browse the repository at this point in the history
…earning state

Co-authored-by: Jakob Schloer <[email protected]>
  • Loading branch information
Rilwan-Adewoyin committed Sep 4, 2024
1 parent 758ec18 commit 8359ad1
Showing 1 changed file with 104 additions and 24 deletions.
128 changes: 104 additions & 24 deletions src/anemoi/training/train/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
config: DictConfig,
graph_data: HeteroData,
statistics: dict,
statistics_tendencies: dict,
data_indices: IndexCollection,
metadata: dict,
) -> None:
Expand All @@ -57,6 +58,8 @@ def __init__(
Graph object
statistics : dict
Statistics of the training data
statistics_tendencies : dict
Statistics of the training data tendencies
data_indices : IndexCollection
Indices of the training data,
metadata : dict
Expand All @@ -69,12 +72,21 @@ def __init__(

self.model = AnemoiModelInterface(
statistics=statistics,
statistics_tendencies=statistics_tendencies,
data_indices=data_indices,
metadata=metadata,
graph_data=graph_data,
config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))),
)

# Flexible stepping function definition
self.step_functions = {
"residual": self._step_residual,
"tendency": self._step_tendency,
}
self.prediction_mode = "tendency" if self.model.tendency_mode else "residual"
LOGGER.info("Using stepping mode: %s", self.prediction_mode)

self.data_indices = data_indices

self.save_hyperparameters()
Expand All @@ -84,8 +96,16 @@ def __init__(

self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

# TODO (rilwan-ade): restructure this so that as the feature weighting - it can be configurable loaded in from a "get_loss_scaling" function
# use method in other branch
tendency_variance = (
torch.from_numpy(self.model.statistics_tendencies["stdev"][self.data_indices.data.output.full])
if self.model.tendency_mode
else None
)

self.metric_ranges, loss_scaling = self.metrics_loss_scaling(config, data_indices)
self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling)
self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling, tendency_variances=tendency_variance)
self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True)

if config.training.loss_gradient_scaling:
Expand Down Expand Up @@ -187,8 +207,7 @@ def advance_input(
x[:, -1, :, :, self.data_indices.model.input.forcing] = batch[
:,
self.multi_step + rollout_step,
:,
:,
...,
self.data_indices.data.input.forcing,
]
return x
Expand All @@ -198,10 +217,19 @@ def _step(
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
in_place_proc: bool = True,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
return self.step_functions[self.prediction_mode](batch, batch_idx, validation_mode, in_place_proc)

def _step_residual(
self,
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
in_place_proc: bool = True,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
batch = self.model.pre_processors(batch) # normalized in-place
batch = self.model.pre_processors_state(batch, in_place=in_place_proc) # normalized in-place
metrics = {}

# start rollout
Expand All @@ -210,6 +238,7 @@ def _step(
y_preds = []
for rollout_step in range(self.rollout):
# prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
# if rollout_step > 0: torch.cuda.empty_cache() # uncomment if rollout fails with OOM
y_pred = self(x)

y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
Expand All @@ -219,38 +248,89 @@ def _step(
x = self.advance_input(x, y_pred, batch, rollout_step)

if validation_mode:
metrics_next, y_preds_next = self.calculate_val_metrics(
y_pred,
y,
rollout_step,
enable_plot=self.enable_plot,
)
metrics_next, y_preds_next = self.calculate_val_metrics(y_pred, y, rollout_step, enable_plot=self.enable_plot)
metrics.update(metrics_next)
y_preds.extend(y_preds_next)

# scale loss
loss *= 1.0 / self.rollout

return loss, metrics, y_preds

def calculate_val_metrics(
def _step_tendency(
self,
y_pred: torch.Tensor,
y: torch.Tensor,
rollout_step: int,
enable_plot: bool = False,
) -> tuple[dict, list]:
batch: torch.Tensor,
batch_idx: int,
validation_mode: bool = False,
in_place_proc: bool = True,
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
metrics = {}

# x ( non-processed)
x = batch[:, 0 : self.multi_step, ..., self.data_indices.data.input.full] # (bs, multi_step, latlon, nvar)

y_preds = []
y_postprocessed = self.model.post_processors(y, in_place=False)
y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False)
for mkey, indices in self.metric_ranges.items():
metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics(
y_pred_postprocessed[..., indices],
y_postprocessed[..., indices],
for rollout_step in range(self.rollout):

# normalise inputs
x_in = self.model.pre_processors_state(x, in_place=False, data_index=self.data_indices.data.input.full)

# prediction (normalized tendency)
tendency_pred = self(x_in)

# re-construct non-processed predicted state
y_pred = self.model.add_tendency_to_state(x[:, -1, ...], tendency_pred)

# Target is full state
y_target = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]

# calculate loss
loss += checkpoint(
self.loss,
self.model.pre_processors_state(y_pred, in_place=False, data_index=self.data_indices.data.output.full),
self.model.pre_processors_state(y_target, in_place=False, data_index=self.data_indices.data.output.full),
use_reentrant=False,
)
# TODO: We should try that too
# loss += checkpoint(self.loss, y_pred, y_target, use_reentrant=False)

# advance input using non-processed x, y_pred and batch
x = self.advance_input(x, y_pred, batch, rollout_step)

if validation_mode:
# calculate_val_metrics requires processed inputs
metrics_next, _ = self.calculate_val_metrics(
None,
None,
rollout_step,
self.enable_plot,
y_pred_postprocessed=y_pred,
y_postprocessed=y_target,
)

metrics.update(metrics_next)

y_preds.extend(y_pred)

# scale loss
loss *= 1.0 / self.rollout

return loss, metrics, y_preds

def calculate_val_metrics(self, y_pred, y, rollout_step, enable_plot=False, y_pred_postprocessed=None, y_postprocessed=None):
metrics = {}
y_preds = []
if y_postprocessed is None:
y_postprocessed = self.model.post_processors_state(y, in_place=False)
if y_pred_postprocessed is None:
y_pred_postprocessed = self.model.post_processors_state(y_pred, in_place=False)

for mkey, indices in self.metric_ranges.items():
metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics(y_pred_postprocessed[..., indices], y_postprocessed[..., indices])

if enable_plot:
y_preds.append(y_pred)
y_preds.append(y_pred_postprocessed)
return metrics, y_preds

def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
Expand Down

0 comments on commit 8359ad1

Please sign in to comment.