fix!: variable scaling, pressure level scalings only applied in specific circumstances #52

Open: wants to merge 37 commits into develop. The diff below shows changes from 9 of those commits.

Commits (37):
- 511ed18 first version of refactor of variable scaling (sahahner, Dec 27, 2024)
- 7ddf6d6 config training changes (sahahner, Dec 27, 2024)
- 3ddeccc avoid multiple scaling (sahahner, Dec 27, 2024)
- be4602c docstring and explain variable reference (sahahner, Dec 31, 2024)
- 195af07 fix to config for pressure level scaler (mc4117, Dec 31, 2024)
- 2644c18 instantiating scalars as a list (mc4117, Dec 31, 2024)
- 718fc57 preparing for tendency losses (mc4117, Dec 31, 2024)
- a34ac02 Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ… (mc4117, Dec 31, 2024)
- b91af11 log the variable level scaling information as before (sahahner, Jan 2, 2025)
- c22c50b adding tendency scaler to additional scalers (pinnstorm, Jan 8, 2025)
- 1f4a532 reformatting (pinnstorm, Jan 8, 2025)
- 2843d98 updating description in configs (pinnstorm, Jan 8, 2025)
- c978871 updating var-tendency-scaler spec (pinnstorm, Jan 12, 2025)
- f56f9b2 updating training/default config (pinnstorm, Jan 12, 2025)
- be90000 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 12, 2025)
- e474ae9 updating training/default.yaml (pinnstorm, Jan 13, 2025)
- f005f84 updating training/default.yaml (pinnstorm, Jan 13, 2025)
- 7cdccc5 first try at tests (mc4117, Jan 17, 2025)
- 61e7933 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 17, 2025)
- 462bb34 variable name and level from mars metadata (sahahner, Jan 17, 2025)
- 960a602 Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ… (sahahner, Jan 17, 2025)
- af10173 get variable group and level in utils file (sahahner, Jan 17, 2025)
- 395cd6f empty line (sahahner, Jan 17, 2025)
- 1f53a82 convert test for new structure. pressure level and general variable s… (sahahner, Jan 17, 2025)
- 3747959 more plausible check for availability of mars metadata (sahahner, Jan 17, 2025)
- 68cd6e3 update to tendency tests (still not working) (mc4117, Jan 17, 2025)
- d3a7c29 Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ… (mc4117, Jan 17, 2025)
- d6e127a tendency scaler tests now working (mc4117, Jan 20, 2025)
- fd29cbc [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 20, 2025)
- 8bff68b change function into class, extracting variable group and name (sahahner, Jan 22, 2025)
- 4c7cbc1 Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ… (sahahner, Jan 22, 2025)
- 7d8c76d correct function call (sahahner, Jan 22, 2025)
- d928b30 correct typo in test (sahahner, Jan 22, 2025)
- bb054ce incorporate comments (sahahner, Jan 22, 2025)
- d0046fa introduce base class for all loss scalings (sahahner, Jan 22, 2025)
- a03d6ba type checking check after all imports (sahahner, Jan 22, 2025)
- aa7f558 comment: explanation about variable groups in config file (sahahner, Jan 22, 2025)
training/src/anemoi/training/config/training/default.yaml (26 additions, 20 deletions)

@@ -50,7 +50,7 @@ training_loss:
   # Available scalars include:
   # - 'variable': See `variable_loss_scaling` for more information
   # - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function
-  scalars: ['variable', 'loss_weights_mask']
+  scalars: ['variable', 'variable_pressure_level', 'loss_weights_mask']

   ignore_nans: False
@@ -109,33 +109,39 @@ lr:
 # Variable loss scaling
 # 'variable' must be included in `scalars` in the losses for this to be applied.
 variable_loss_scaling:
+  variable_groups:
+    default: sfc
+    pl: [q, t, u, v, w, z]
   default: 1
-  pl:
-    q: 0.6 #1
-    t: 6 #1
-    u: 0.8 #0.5
-    v: 0.5 #0.33
-    w: 0.001
-    z: 12 #1
-  sfc:
-    sp: 10
-    10u: 0.1
-    10v: 0.1
-    2d: 0.5
-    tp: 0.025
-    cp: 0.0025
+  q: 0.6 #1
+  t: 6 #1
+  u: 0.8 #0.5
+  v: 0.5 #0.33
+  w: 0.001
+  z: 12 #1
+  sp: 10
+  10u: 0.1
+  10v: 0.1
+  2d: 0.5
+  tp: 0.025
+  cp: 0.0025
+additional_scalars:
+  # pressure level scalar
+  - _target_: anemoi.training.train.scaling.ReluVariableLevelScaler
+    group: pl
+    y_intercept: 0.2
+    slope: 0.001
+    scale_dim: -1 # dimension on which scaling applied
+    name: "variable_pressure_level"
+  # norm tendency scalar (scaling loss function by the normalised tendency values)
+  #- _target_: anemoi.training.data.scaling.NormTendencyScaler

 metrics:
 - z_500
 - t_850
 - u_850
 - v_850

-pressure_level_scaler:
-  _target_: anemoi.training.data.scaling.ReluPressureLevelScaler
-  minimum: 0.2
-  slope: 0.001

 node_loss_weights:
   _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
   target_nodes: ${graph.data}
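The additional_scalars entry above replaces the removed top-level pressure_level_scaler block. As a rough sketch of the weighting it configures, assuming ReluVariableLevelScaler keeps the semantics of the old ReluPressureLevelScaler (a linear ramp in the pressure level, clipped from below; the function name here is illustrative, not the PR's API):

def relu_level_weight(level: int, y_intercept: float = 0.2, slope: float = 0.001) -> float:
    """Per-level loss weight: a linear ramp in the pressure level, clipped from below."""
    return max(y_intercept, slope * level)

print(relu_level_weight(850))  # ~0.85
print(relu_level_weight(100))  # 0.2 (clipped at the y-intercept)

Combined with the group values above, t_850 would then carry a weight of roughly 6 * 0.85 = 5.1, matching how the deleted get_variable_scaling method (see the forecaster.py diff below) multiplied the per-variable value by the level scaler.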
training/src/anemoi/training/data/datamodule.py (5 additions, 0 deletions)

@@ -73,6 +73,10 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
     def statistics(self) -> dict:
         return self.ds_train.statistics

+    @cached_property
+    def statistics_tendencies(self) -> dict:
+        return self.ds_train.statistics_tendencies
+
     @cached_property
     def metadata(self) -> dict:
         return self.ds_train.metadata
@@ -183,6 +187,7 @@ def _get_dataset(
             rollout=r,
             multistep=self.config.training.multistep_input,
             timeincrement=self.timeincrement,
+            timestep=self.config.data.timestep,
             shuffle=shuffle,
             grid_indices=self.grid_indices,
             label=label,
training/src/anemoi/training/data/dataset.py (13 additions, 0 deletions)

@@ -41,6 +41,7 @@ def __init__(
         rollout: int = 1,
         multistep: int = 1,
         timeincrement: int = 1,
+        timestep: str = "6h",
         shuffle: bool = True,
         label: str = "generic",
         effective_bs: int = 1,

@@ -57,6 +58,8 @@
             length of rollout window, by default 12
         timeincrement : int, optional
             time increment between samples, by default 1
+        timestep : str, optional
+            the time frequency of the samples, by default '6h'
         multistep : int, optional
             collate (t-1, ... t - multistep) into the input state vector, by default 1
         shuffle : bool, optional

@@ -73,6 +76,7 @@

self.rollout = rollout
self.timeincrement = timeincrement
self.timestep = timestep
self.grid_indices = grid_indices

# lazy init
Expand Down Expand Up @@ -104,6 +108,15 @@ def statistics(self) -> dict:
"""Return dataset statistics."""
return self.data.statistics

@cached_property
def statistics_tendencies(self) -> dict:
"""Return dataset tendency statistics."""
# The statistics_tendencies are lazily loaded
self.data.statistics_tendencies = (
self.data.statistics_tendencies(self.timestep) if callable(self.data.statistics_tendencies) else None
)
return self.data.statistics_tendencies

@cached_property
def metadata(self) -> dict:
"""Return dataset metadata."""
Expand Down
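A note on the property above: on the wrapped dataset, statistics_tendencies may be a callable taking a timestep string, and the property resolves it in place exactly once, falling back to None when the dataset provides no tendency statistics. A minimal self-contained sketch of that resolve-once pattern, with illustrative names (LazyTendencyStats, data) rather than the real anemoi classes:

from functools import cached_property
from typing import Optional

class LazyTendencyStats:
    """Illustrative stand-in for the dataset wrapper above."""

    def __init__(self, data, timestep: str = "6h") -> None:
        self.data = data  # stands in for the underlying anemoi dataset
        self.timestep = timestep

    @cached_property
    def statistics_tendencies(self) -> Optional[dict]:
        stats = getattr(self.data, "statistics_tendencies", None)
        # Call the lazy loader once with the configured timestep;
        # otherwise report that no tendency statistics are available.
        return stats(self.timestep) if callable(stats) else None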
training/src/anemoi/training/data/scaling.py (0 additions, 79 deletions)

This file was deleted.
training/src/anemoi/training/train/forecaster.py (22 additions, 42 deletions)

@@ -15,7 +15,6 @@
 from typing import Optional
 from typing import Union

-import numpy as np
 import pytorch_lightning as pl
 import torch
 from hydra.utils import instantiate

@@ -31,6 +30,7 @@
 from anemoi.models.interface import AnemoiModelInterface
 from anemoi.training.losses.utils import grad_scaler
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.train.scaling import GeneralVariableLossScaler
 from anemoi.training.utils.jsonify import map_config_to_primitives
 from anemoi.training.utils.masks import Boolean1DMask
 from anemoi.training.utils.masks import NoOutputMask

@@ -48,6 +48,7 @@ def __init__(
         config: DictConfig,
         graph_data: HeteroData,
         statistics: dict,
+        statistics_tendencies: dict,
         data_indices: IndexCollection,
         metadata: dict,
         supporting_arrays: dict,

@@ -95,10 +96,25 @@
         self.latlons_data = graph_data[config.graph.data].x
         self.node_weights = self.get_node_weights(config, graph_data)
         self.node_weights = self.output_mask.apply(self.node_weights, dim=0, fill_value=0.0)
+        self.statistics_tendencies = statistics_tendencies

         self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

-        variable_scaling = self.get_variable_scaling(config, data_indices)
+        variable_scaling = GeneralVariableLossScaler(
+            config.training.variable_loss_scaling,
+            data_indices,
+        ).get_variable_scaling()
+
+        # Instantiate the pressure level scaling class with the training configuration
+        config_container = OmegaConf.to_container(config.training.additional_scalars, resolve=False)
+        if isinstance(config_container, list):
+            scalar = [
+                instantiate(
+                    scalar_config,
+                    scaling_config=config.training.variable_loss_scaling,
+                    data_indices=data_indices,
+                )
+                for scalar_config in config_container
+            ]

         self.internal_metric_ranges, self.val_metric_ranges = self.get_val_metric_ranges(config, data_indices)

@@ -118,8 +134,11 @@ def __init__(
         self.scalars = {
             "variable": (-1, variable_scaling),
             "loss_weights_mask": ((-2, -1), torch.ones((1, 1))),
-            "limited_area_mask": (2, limited_area_mask),
+            "limited_area_mask": (2, limited_area_mask)
         }
+        # add additional user-defined scalars
+        [self.scalars.update({scale.name: (scale.scale_dim, scale.get_variable_scaling())}) for scale in scalar]

         self.updated_loss_mask = False

         self.loss = self.get_loss_function(config.training.training_loss, scalars=self.scalars, **loss_kwargs)
@@ -299,45 +318,6 @@ def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) ->

return metric_ranges, metric_ranges_validation

@staticmethod
def get_variable_scaling(
config: DictConfig,
data_indices: IndexCollection,
) -> torch.Tensor:
variable_loss_scaling = (
np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32)
* config.training.variable_loss_scaling.default
)
pressure_level = instantiate(config.training.pressure_level_scaler)

LOGGER.info(
"Pressure level scaling: use scaler %s with slope %.4f and minimum %.2f",
type(pressure_level).__name__,
pressure_level.slope,
pressure_level.minimum,
)

for key, idx in data_indices.internal_model.output.name_to_index.items():
split = key.split("_")
if len(split) > 1 and split[-1].isdigit():
# Apply pressure level scaling
if split[0] in config.training.variable_loss_scaling.pl:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.pl[
split[0]
] * pressure_level.scaler(
int(split[-1]),
)
else:
LOGGER.debug("Parameter %s was not scaled.", key)
else:
# Apply surface variable scaling
if key in config.training.variable_loss_scaling.sfc:
variable_loss_scaling[idx] = config.training.variable_loss_scaling.sfc[key]
else:
LOGGER.debug("Parameter %s was not scaled.", key)

return torch.from_numpy(variable_loss_scaling)

     @staticmethod
     def get_node_weights(config: DictConfig, graph_data: HeteroData) -> torch.Tensor:
         node_weighting = instantiate(config.training.node_loss_weights)
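Taken together, the forecaster now builds its loss scalars in two steps: GeneralVariableLossScaler produces the per-variable weights, and each entry of training.additional_scalars is Hydra-instantiated with the variable-scaling config and data indices, contributing a (scale_dim, tensor) pair under its configured name. A sketch of that assembly as a standalone function (the helper name and plain-dict return are illustrative, not the PR's API):

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

def build_additional_scalars(config: DictConfig, data_indices) -> dict:
    """Collect user-defined loss scalars keyed by their configured name."""
    scalars = {}
    raw = OmegaConf.to_container(config.training.additional_scalars, resolve=False)
    for scalar_config in raw if isinstance(raw, list) else []:
        scaler = instantiate(
            scalar_config,
            scaling_config=config.training.variable_loss_scaling,
            data_indices=data_indices,
        )
        # Each scaler declares the dimension it scales and its weight tensor.
        scalars[scaler.name] = (scaler.scale_dim, scaler.get_variable_scaling())
    return scalars

An explicit loop like this also avoids the list comprehension used purely for its side effect in the diff above.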