Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
#47 implement seperate post_processing objects for state and tendency
Browse files Browse the repository at this point in the history
  • Loading branch information
Rilwan-Adewoyin committed Sep 4, 2024
1 parent 6c306b3 commit a97b4d4
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 68 deletions.
78 changes: 53 additions & 25 deletions src/anemoi/training/config/data/zarr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,66 @@ diagnostic:
- tp
- cp

normalizer:
default: "mean-std"
min-max:
max:
- "sdor"
- "slor"
- "z"
none:
- "cos_latitude"
- "cos_longitude"
- "sin_latitude"
- "sin_longitude"
- "cos_julian_day"
- "cos_local_time"
- "sin_julian_day"
- "sin_local_time"
- "insolation"
- "lsm"
normalizers:
state:
default: "mean-std"
min-max:
max:
- "sdor"
- "slor"
- "z"
none:
- "cos_latitude"
- "cos_longitude"
- "sin_latitude"
- "sin_longitude"
- "cos_julian_day"
- "cos_local_time"
- "sin_julian_day"
- "sin_local_time"
- "cos_solar_zenith_angle"
- "lsm"

tendency:
default: "mean-std"
min-max:
max:
- "sdor"
- "slor"
- "z"
none:
- "cos_latitude"
- "cos_longitude"
- "sin_latitude"
- "sin_longitude"
- "cos_julian_day"
- "cos_local_time"
- "sin_julian_day"
- "sin_local_time"
- "cos_solar_zenith_angle"
- "lsm"


imputer:
default: "none"

# processors including imputers and normalizers are applied in order of definition
processors:
# example_imputer:
# _target_: anemoi.models.preprocessing.imputer.InputImputer
# _convert_: all
# config: ${data.imputer}
normalizer:
_target_: anemoi.models.preprocessing.normalizer.InputNormalizer
_convert_: all
config: ${data.normalizer}
# _target_: anemoi.models.preprocessing.imputer.InputImputer
# _convert_: all
# config: ${data.imputer}
state:
normalizer:
_target_: anemoi.models.preprocessing.normalizer.InputNormalizer
_convert_: all
config: ${data.normalizers.state}

tendency:
normalizer:
_target_: anemoi.models.preprocessing.normalizer.InputNormalizer
_convert_: all
config: ${data.normalizers.tendency}

# Values set in the code
num_features: null # number of features in the forecast state
9 changes: 9 additions & 0 deletions src/anemoi/training/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def _check_resolution(self, resolution: str) -> None:
def statistics(self) -> dict:
return self.ds_train.statistics

@cached_property
def statistics_tendencies(self) -> dict:
# This is just a quick fix to work with datasets without stored tendency
# statistics. This should be caught in anemoi-datasets.
if self.config.training.tendency_mode:
return self.ds_train.statistics_tendencies
return None

@cached_property
def metadata(self) -> dict:
return self.ds_train.metadata
Expand Down Expand Up @@ -165,6 +173,7 @@ def _get_dataset(
rollout=r,
multistep=self.config.training.multistep_input,
timeincrement=self.timeincrement,
timestep=self.config.data.timestep,
model_comm_group_rank=self.model_comm_group_rank,
model_comm_group_id=self.model_comm_group_id,
model_comm_num_groups=self.model_comm_num_groups,
Expand Down
55 changes: 12 additions & 43 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def __init__(self, config: OmegaConf) -> None:
self.config = config
self.save_basedir = config.hardware.paths.plots
self.plot_frequency = config.diagnostics.plot.frequency
self.post_processors = None
self.pre_processors = None
self.post_processors_state = None
self.pre_processors_state = None
self.latlons = None
init_plot_settings()

Expand Down Expand Up @@ -195,40 +195,9 @@ def _eval(
pl_module: pl.LightningModule,
batch: torch.Tensor,
) -> None:
loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False)
# NB! the batch is already normalized in-place - see pl_model.validation_step()
metrics = {}

# start rollout
x = batch[
:,
0 : pl_module.multi_step,
...,
pl_module.data_indices.data.input.full,
] # (bs, multi_step, latlon, nvar)
assert (
batch.shape[1] >= self.rollout + pl_module.multi_step
), "Batch length not sufficient for requested rollout length!"

with torch.no_grad():
for rollout_step in range(self.rollout):
y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
y = batch[
:,
pl_module.multi_step + rollout_step,
...,
pl_module.data_indices.data.output.full,
] # target, shape = (bs, latlon, nvar)
# y includes the auxiliary variables, so we must leave those out when computing the loss
loss += pl_module.loss(y_pred, y)

x = pl_module.advance_input(x, y_pred, batch, rollout_step)

metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step)
metrics.update(metrics_next)

# scale loss
loss *= 1.0 / self.rollout
loss, metrics, _ = pl_module._step(batch, validation_mode=True, in_place_proc=False)

self._log(pl_module, loss, metrics, batch.shape[0])

def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None:
Expand Down Expand Up @@ -533,9 +502,9 @@ def _plot(
# When running in Async mode, it might happen that in the last epoch these tensors
# have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA
# but internal ones would be on the cpu), The lines below allow to address this problem
if self.post_processors is None:
if self.post_processors_state is None:
# Copy to be used across all the training cycle
self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
self.post_processors_state = copy.deepcopy(pl_module.model.post_processors_state).cpu()
if self.latlons is None:
self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
local_rank = pl_module.local_rank
Expand All @@ -546,9 +515,9 @@ def _plot(
...,
pl_module.data_indices.data.output.full,
].cpu()
data = self.post_processors(input_tensor).numpy()
data = self.post_processors_state(input_tensor).numpy()

output_tensor = self.post_processors(
output_tensor = self.post_processors_state(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
).numpy()
Expand Down Expand Up @@ -624,9 +593,9 @@ def _plot(
if self.pre_processors is None:
# Copy to be used across all the training cycle
self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu()
if self.post_processors is None:
if self.post_processors_state is None:
# Copy to be used across all the training cycle
self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu()
self.post_processors_state = copy.deepcopy(pl_module.model.post_processors_state).cpu()
if self.latlons is None:
self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
local_rank = pl_module.local_rank
Expand All @@ -637,8 +606,8 @@ def _plot(
...,
pl_module.data_indices.data.output.full,
].cpu()
data = self.post_processors(input_tensor).numpy()
output_tensor = self.post_processors(
data = self.post_processors_state(input_tensor).numpy()
output_tensor = self.post_processors_state(
torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])),
in_place=False,
).numpy()
Expand Down

0 comments on commit a97b4d4

Please sign in to comment.