Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Create model base class #79

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-models/compare/0.3.0...HEAD)

- Add synchronisation workflow

### Added

- New AnemoiModelEncProcDecHierarchical class available in models [#37](https://github.com/ecmwf/anemoi-models/pull/37)
Expand All @@ -26,6 +24,8 @@ Keep it human-readable, your future self will thank you!
- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)
- Add synchronisation workflow
- Refactor base functionality of `AnemoiEncProcDecModel` into an abstract base class [#79](https://github.com/ecmwf/anemoi-models/pull/79)

### Changed
- Bugfixes for CI
Expand Down
255 changes: 255 additions & 0 deletions src/anemoi/models/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
# (C) Copyright 2024 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import logging
from abc import ABC
from abc import abstractmethod
from typing import Optional

import einops
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import NamedNodesAttributes

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)


class BaseAnemoiEncProcDecModel(nn.Module, ABC):
    """Abstract base class for encoder-processor-decoder graph models.

    The base class owns everything that does not depend on the concrete
    encoder/processor/decoder implementation: reading graph and model
    parameters from the configuration, deriving channel counts and prognostic
    indices, instantiating the output bounding layers and driving the
    ``encode -> process -> decode`` sequence in :meth:`forward`.

    Subclasses must implement :meth:`instantiate_encoder`,
    :meth:`instantiate_processor`, :meth:`instantiate_decoder`,
    :meth:`encode`, :meth:`process` and :meth:`decode`.
    """

    # Attributes populated during ``__init__`` by the helper methods below.
    _graph_data: HeteroData
    _graph_name_data: str
    _graph_name_hidden: str
    multi_step: int
    num_channels: int
    node_attributes: NamedNodesAttributes
    num_input_channels: int
    num_output_channels: int
    _internal_input_idx: list[int]
    _internal_output_idx: list[int]
    boundings: nn.ModuleList

    def __init__(
        self,
        *,
        model_config: DotDict,
        data_indices: IndexCollection,
        graph_data: HeteroData,
    ) -> None:
        """Initializes the graph neural network.

        Parameters
        ----------
        model_config : DotDict
            Model configuration
        data_indices : IndexCollection
            Data indices
        graph_data : HeteroData
            Graph definition
        """
        super().__init__()

        # Graph/model parameters and index bookkeeping must be in place before
        # the subclass hooks run, since they are called with the config only.
        self.set_graph_parameters(graph_data, model_config)
        self.set_model_parameters(model_config)

        self._calculate_shapes_and_indices(data_indices)
        self._assert_matching_indices(data_indices)

        self.instantiate_encoder(model_config)
        self.instantiate_processor(model_config)
        self.instantiate_decoder(model_config)
        self.instantiate_boundings(model_config, data_indices)

    def set_graph_parameters(self, graph_data: HeteroData, model_config: DotDict) -> None:
        """Set the graph derived attributes inside the model.

        Parameters
        ----------
        graph_data : HeteroData
            Graph definition, stored as ``self._graph_data``
        model_config : DotDict
            Model configuration; ``graph.data`` and ``graph.hidden`` provide the
            node-set names and ``model.trainable_parameters.hidden`` is passed
            to ``NamedNodesAttributes``
        """
        self._graph_data = graph_data
        self._graph_name_data = model_config.graph.data
        self._graph_name_hidden = model_config.graph.hidden
        self.node_attributes = NamedNodesAttributes(
            model_config.model.trainable_parameters.hidden,
            self._graph_data,
        )

    def set_model_parameters(self, model_config: DotDict) -> None:
        """Set the model specific parameters based on the config file.

        Parameters
        ----------
        model_config : DotDict
            Model configuration; reads ``training.multistep_input`` and
            ``model.num_channels``
        """
        self.multi_step = model_config.training.multistep_input
        self.num_channels = model_config.model.num_channels

    def _calculate_shapes_and_indices(self, data_indices: IndexCollection) -> None:
        """Derive channel counts and prognostic indices from the data indices.

        Parameters
        ----------
        data_indices : IndexCollection
            Data indices; only the ``internal_model`` input/output entries are read
        """
        internal = data_indices.internal_model
        self.num_input_channels = len(internal.input)
        self.num_output_channels = len(internal.output)
        self._internal_input_idx = internal.input.prognostic
        self._internal_output_idx = internal.output.prognostic

    def _assert_matching_indices(self, data_indices: IndexCollection) -> None:
        """Check that the prognostic input and output indices are consistent.

        Parameters
        ----------
        data_indices : IndexCollection
            Data indices used to cross-check the stored prognostic indices

        Raises
        ------
        AssertionError
            If the number of prognostic output indices differs from the number
            of full output indices minus the diagnostic ones, or if the input
            and output prognostic index lists have different lengths.
        """
        output = data_indices.internal_model.output
        num_non_diagnostic = len(output.full) - len(output.diagnostic)

        # The message is a plain string (not a tuple) so the AssertionError is readable.
        assert len(self._internal_output_idx) == num_non_diagnostic, (
            f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and "
            f"the internal output indices excluding diagnostic variables "
            f"({num_non_diagnostic})"
        )
        assert len(self._internal_input_idx) == len(
            self._internal_output_idx,
        ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

    @abstractmethod
    def instantiate_encoder(self, model_config: DotDict) -> None:
        """Create the encoder; called from ``__init__`` with the model configuration."""

    @abstractmethod
    def instantiate_processor(self, model_config: DotDict) -> None:
        """Create the processor; called from ``__init__`` with the model configuration."""

    @abstractmethod
    def instantiate_decoder(self, model_config: DotDict) -> None:
        """Create the decoder; called from ``__init__`` with the model configuration."""

    def instantiate_boundings(self, model_config: DotDict, data_indices: IndexCollection) -> None:
        """Instantiate the model output bounding functions.

        Bounding functions constrain outputs (e.g., to ensure outputs like TP
        are positive definite). One module is instantiated per entry of
        ``model_config.model.bounding``; if that key is absent no bounding is
        applied.

        Parameters
        ----------
        model_config : DotDict
            Model configuration
        data_indices : IndexCollection
            Data indices; supplies the output ``name_to_index`` mapping
        """
        name_to_index = data_indices.internal_model.output.name_to_index
        self.boundings = nn.ModuleList(
            [
                instantiate(cfg, name_to_index=name_to_index)
                for cfg in getattr(model_config.model, "bounding", [])
            ]
        )

    def _run_mapper(
        self,
        mapper: nn.Module,
        data: tuple[Tensor, ...],
        batch_size: int,
        shard_shapes: tuple[tuple[int, int], tuple[int, int]],
        model_comm_group: Optional[ProcessGroup] = None,
        use_reentrant: bool = False,
    ) -> Tensor:
        """Run mapper with activation checkpoint.

        Parameters
        ----------
        mapper : nn.Module
            Which mapper to run
        data : tuple[Tensor, ...]
            tuple of data to pass in
        batch_size: int,
            Batch size
        shard_shapes : tuple[tuple[int, int], tuple[int, int]]
            Shard shapes for the data
        model_comm_group : ProcessGroup
            model communication group, specifies which GPUs work together
            in one model instance
        use_reentrant : bool, optional
            Use reentrant, by default False

        Returns
        -------
        Tensor
            Mapped data
        """
        return checkpoint(
            mapper,
            data,
            batch_size=batch_size,
            shard_shapes=shard_shapes,
            model_comm_group=model_comm_group,
            use_reentrant=use_reentrant,
        )

    @abstractmethod
    def encode(
        self,
        x: tuple[Tensor, Tensor],
        batch_size: int,
        shard_shapes: tuple[tuple[int, int], tuple[int, int]],
        model_comm_group: Optional[ProcessGroup] = None,
    ) -> tuple[Tensor, Tensor]:
        """Encode ``(data, hidden)`` inputs into ``(data_latent, hidden_latent)``.

        ``forward`` passes ``shard_shapes`` as a ``(data, hidden)`` pair.
        """

    @abstractmethod
    def process(
        self,
        x: Tensor,
        batch_size: int,
        shard_shapes: tuple[int, int],
        model_comm_group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """Process the hidden latent representation.

        ``forward`` passes the hidden shard shapes and adds a skip connection
        around this call, so the result must be addable to the input ``x``.
        """

    @abstractmethod
    def decode(
        self,
        x: tuple[Tensor, Tensor],
        batch_size: int,
        shard_shapes: tuple[tuple[int, int], tuple[int, int]],
        model_comm_group: Optional[ProcessGroup] = None,
    ) -> Tensor:
        """Decode ``(hidden_latent, data_latent)`` into the flattened output.

        ``forward`` passes ``shard_shapes`` as a ``(hidden, data)`` pair and
        rearranges the result from ``(batch ensemble grid) vars``.
        """

    def bound_output(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every configured bounding function to ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Model output

        Returns
        -------
        torch.Tensor
            Output after all bounding functions have been applied
        """
        for bounding in self.boundings:
            # bounding performed in the order specified in the config file
            x = bounding(x)

        return x

    def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
        """Run the encode-process-decode pipeline.

        Parameters
        ----------
        x : Tensor
            Input with dimensions ``(batch, time, ensemble, grid, vars)``
        model_comm_group : ProcessGroup, optional
            model communication group, specifies which GPUs work together
            in one model instance

        Returns
        -------
        Tensor
            Output with dimensions ``(batch, ensemble, grid, vars)``, cast to
            the dtype of ``x``
        """
        batch_size = x.shape[0]
        ensemble_size = x.shape[2]

        # add data positional info (lat/lon)
        x_data = torch.cat(
            (
                einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
                self.node_attributes(self._graph_name_data, batch_size=batch_size),
            ),
            dim=-1,  # feature dimension
        )

        x_hidden = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

        # get shard shapes
        shard_shapes_data = get_shape_shards(x_data, 0, model_comm_group)
        shard_shapes_hidden = get_shape_shards(x_hidden, 0, model_comm_group)

        x_data_latent, x_hidden_latent = self.encode(
            (x_data, x_hidden),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_data, shard_shapes_hidden),
            model_comm_group=model_comm_group,
        )

        x_hidden_latent_proc = self.process(x_hidden_latent, batch_size, shard_shapes_hidden, model_comm_group)

        # add skip connection (hidden -> hidden)
        x_hidden_latent_proc = x_hidden_latent_proc + x_hidden_latent

        # Run decoder
        x_out = self.decode(
            (x_hidden_latent_proc, x_data_latent),
            batch_size=batch_size,
            shard_shapes=(shard_shapes_hidden, shard_shapes_data),
            model_comm_group=model_comm_group,
        )

        # clone so the in-place residual update below works on a fresh tensor
        x_out = (
            einops.rearrange(
                x_out,
                "(batch ensemble grid) vars -> batch ensemble grid vars",
                batch=batch_size,
                ensemble=ensemble_size,
            )
            .to(dtype=x.dtype)
            .clone()
        )

        # residual connection (just for the prognostic variables), using the last input time step
        x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

        # bound_output takes only the tensor; passing extra arguments raises TypeError
        x_out = self.bound_output(x_out)

        return x_out
Loading
Loading