Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/trmapper #78

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions src/anemoi/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
from typing import Optional

import einops
import torch
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.typing import PairTensor

try:
from flash_attn import flash_attn_func as attn_func
except ImportError:
from flash_attn.layers.rotary import RotaryEmbedding
from torch.nn.functional import scaled_dot_product_attention as attn_func

_FLASH_ATTENTION_AVAILABLE = False
Expand All @@ -27,6 +30,7 @@

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.models.layers.utils import AutocastLayerNorm

LOGGER = logging.getLogger(__name__)

Expand All @@ -42,6 +46,8 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
qk_norm: bool = False,
rotary_embeddings: bool = False,
):
super().__init__()

Expand All @@ -55,20 +61,36 @@ def __init__(
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.qk_norm = qk_norm
self.rotary_embeddings = rotary_embeddings

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.lin_q = nn.Linear(embed_dim, embed_dim, bias=bias)
self.lin_k = nn.Linear(embed_dim, embed_dim, bias=bias)
self.lin_v = nn.Linear(embed_dim, embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
query, key, value = self.lin_qkv(x).chunk(3, -1)
if self.qk_norm:
self.q_norm = AutocastLayerNorm(self.head_dim, bias=False)
self.k_norm = AutocastLayerNorm(self.head_dim, bias=False)

if self.rotary_embeddings: # find alternative implementation
assert _FLASH_ATTENTION_AVAILABLE, "Rotary embeddings require flash attention"
self.rotary_emb = RotaryEmbedding(dim=self.head_dim)

def attention_computation(
self,
query: Tensor,
key: Tensor,
value: Tensor,
shapes: list,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
) -> Tensor:
if model_comm_group:
assert (
model_comm_group.size() == 1 or batch_size == 1
Expand All @@ -83,16 +105,28 @@ def forward(
)
for t in (query, key, value)
)

query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
dropout_p = self.dropout_p if self.training else 0.0

if self.qk_norm:
query = self.q_norm(query)
key = self.k_norm(key)

if _FLASH_ATTENTION_AVAILABLE:
query, key, value = (
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
if self.rotary_embeddings: # can this be done in a better way?
key = key.unsqueeze(-3)
value = value.unsqueeze(-3)
keyvalue = torch.cat((key, value), dim=-3)
query, keyvalue = self.rotary_emb(
query, keyvalue, max_seqlen=max(keyvalue.shape[1], query.shape[1])
) # assumption seq const
key = keyvalue[:, :, 0, ...]
value = keyvalue[:, :, 1, ...]
out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
else:
Expand All @@ -103,10 +137,29 @@ def forward(
is_causal=False,
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
return self.projection(out)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
query = self.lin_q(x)
key = self.lin_k(x)
value = self.lin_v(x)
return self.attention_computation(query, key, value, shapes, batch_size, model_comm_group)

out = self.projection(out)

return out
class MultiHeadCrossAttention(MultiHeadSelfAttention):
"""Multi Head Cross Attention Pytorch Layer."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def forward(
self, x: PairTensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
query = self.lin_q(x[1])
key = self.lin_k(x[0])
value = self.lin_v(x[0])
return self.attention_computation(query, key, value, shapes, batch_size, model_comm_group)
50 changes: 49 additions & 1 deletion src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from anemoi.models.distributed.khop_edges import sort_edges_1hop_chunks
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.models.layers.attention import MultiHeadCrossAttention
from anemoi.models.layers.attention import MultiHeadSelfAttention
from anemoi.models.layers.conv import GraphConv
from anemoi.models.layers.conv import GraphTransformerConv
Expand Down Expand Up @@ -105,6 +106,53 @@ def forward(
return x


class TransformerMapperBlock(TransformerProcessorBlock):
"""Transformer mapper block with MultiHeadCrossAttention and MLPs."""

def __init__(
self,
num_channels: int,
hidden_dim: int,
num_heads: int,
activation: str,
window_size: int,
dropout_p: float = 0.0,
):
super().__init__(
num_channels=num_channels,
hidden_dim=hidden_dim,
num_heads=num_heads,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
)

self.attention = MultiHeadCrossAttention(
num_heads=num_heads,
embed_dim=num_channels,
window_size=window_size,
bias=False,
is_causal=False,
dropout_p=dropout_p,
)

self.layer_norm_src = nn.LayerNorm(num_channels)

def forward(
self,
x: OptPairTensor,
shapes: list,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
) -> Tensor:
# Need to be out of place for gradient propagation
x_src = self.layer_norm_src(x[0])
x_dst = self.layer_norm1(x[1])
x_dst = x_dst + self.attention((x_src, x_dst), shapes, batch_size, model_comm_group=model_comm_group)
x_dst = x_dst + self.mlp(self.layer_norm2(x_dst))
return (x_src, x_dst), None # logic expects return of edge_attr


class GraphConvBaseBlock(BaseBlock):
"""Message passing block with MLPs for node embeddings."""

Expand Down Expand Up @@ -180,7 +228,7 @@ def __ini__(
**kwargs,
):
super().__init__(
self,
self, # is this correct?
in_channels=in_channels,
out_channels=out_channels,
mlp_extra_layers=mlp_extra_layers,
Expand Down
Loading
Loading