From 11b689515e501d1da26e9aa8344da195dccf0881 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Thu, 9 Jan 2025 10:09:04 -0800
Subject: [PATCH] Bump 3rdparty/Megatron-LM from `2da43ef` to `65720c8` (#579)
Bumps [3rdparty/Megatron-LM](https://github.com/NVIDIA/Megatron-LM) from
`2da43ef` to `65720c8`.
Commits:
- `65720c8` Merge branch 'ko3n1g/chore/fix-local-generator-script' into 'main'
- `c8d12e6` ADLR/megatron-lm!2519 - chore: Fix local generator script
- `ab171c5` Merge branch 'ko3n1g/ci/use-torchrun' into 'main'
- `6e09dd4` ADLR/megatron-lm!2507 - ci: Use torchrun
- `df28200` Merge branch 'generate_fix' into 'main'
- `342e359` ADLR/megatron-lm!2370 - Make generate function only return results for newly ...
- `15517f6` Merge branch 'ko3n1g/ci/update-nightlies' into 'main'
- `c383fe9` ADLR/megatron-lm!2511 - ci: Update golden values of nightlies
- `86e5481` Merge branch 'video_training' into 'main'
- `82a6dfd` ADLR/megatron-lm!2500 - Video training
- Additional commits viewable in compare view
Dependabot will resolve any conflicts with this PR as long as you don't
alter it yourself. You can also trigger a rebase manually by commenting
`@dependabot rebase`.
[//]: # (dependabot-automerge-start)
[//]: # (dependabot-automerge-end)
---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits
that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after
your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge
and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating
it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all
of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop
Dependabot creating any more for this major version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop
Dependabot creating any more for this minor version (unless you reopen
the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop
Dependabot creating any more for this dependency (unless you reopen the
PR or upgrade to it yourself)
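
Beyond the submodule bump, this patch removes the ESM2-specific attention overrides (`ESM2DotProductAttention` and `ESM2TEDotProductAttention`) along with their `core_attention_override` wiring, and instead adds a `softmax_scale: float = 1.0` field to `ESM2GenericConfig`, which the bumped Megatron-LM is expected to consume directly. A minimal sketch of the new configuration path, assuming `ESM2Config` inherits the new field; the construction mirrors the removed unit-test fixture and is illustrative rather than a verified invocation:

```python
from bionemo.esm2.api import ESM2Config
from bionemo.testing import megatron_parallel_state_utils

# Set up model-parallel state the same way the removed attention test fixture did.
with megatron_parallel_state_utils.distributed_model_parallel_state():
    config = ESM2Config(seq_length=20, hidden_size=4, num_attention_heads=4)
    # The softmax scale of 1.0 that ESM2TEDotProductAttention used to hard-code is
    # now a plain config field instead of a custom core-attention class.
    assert config.softmax_scale == 1.0
```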
---------
Signed-off-by: dependabot[bot]
Signed-off-by: Peter St. John
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Peter St. John
---
3rdparty/Megatron-LM | 2 +-
.../src/bionemo/esm2/model/attention.py | 365 ------------------
.../src/bionemo/esm2/model/model.py | 4 +-
.../src/bionemo/esm2/run/config_models.py | 3 -
.../bionemo/esm2/model/test_attention.py | 120 ------
5 files changed, 2 insertions(+), 492 deletions(-)
delete mode 100644 sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
delete mode 100644 sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_attention.py
diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM
index 2da43ef4c1..65720c87ba 160000
--- a/3rdparty/Megatron-LM
+++ b/3rdparty/Megatron-LM
@@ -1 +1 @@
-Subproject commit 2da43ef4c1b9e76f03b7567360cf7390e877f1b6
+Subproject commit 65720c87ba9c9d0ae8c90b1ffdbdccd2d51b1bc1
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
deleted file mode 100644
index 63d93448e1..0000000000
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
+++ /dev/null
@@ -1,365 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import os
-from typing import Callable, Optional, Sequence, Union
-
-import torch
-from megatron.core import parallel_state, tensor_parallel
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
-from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.parallel_state import (
- get_context_parallel_global_ranks,
- get_context_parallel_group,
- get_tensor_model_parallel_group,
-)
-from megatron.core.tensor_parallel import get_cuda_rng_tracker
-from megatron.core.transformer.dot_product_attention import DotProductAttention
-from megatron.core.transformer.enums import AttnMaskType
-from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import get_te_version, is_te_min_version
-from torch import Tensor
-
-
-__all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")
-
-
-class ESM2TEDotProductAttention(TEDotProductAttention):
- """ESM2-Specific transformer engine core attention.
-
- Override the softmax_scale to 1.0 to match the ESM2 implementation while keeping the rest from the original TEDotProductAttention.
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- layer_number: int,
- attn_mask_type: AttnMaskType,
- attention_type: str,
- attention_dropout: float | None = None,
- softmax_scale: float = 1.0,
- k_channels: int | None = None,
- v_channels: int | None = None,
- cp_comm_type: str = "p2p",
- ):
- """Initialize ESM2TEDotProductAttention."""
- self.config = config
- self.te_forward_mask_type = False
- self.qkv_format: str = "sbhd"
-
- if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
- raise ValueError(
- f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
- f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
- f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
- f"setting query key layer scaling via argument, so these two must match."
- )
-
- extra_kwargs = {}
- if is_te_min_version("0.11.0"):
- extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
- elif self.config.num_query_groups != self.config.num_attention_heads:
- raise ValueError(
- f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
- f"use a newer version of Transformer Engine. "
- f"(num_query_groups ({self.config.num_query_groups}) != "
- f"num_attention_heads ({self.config.num_attention_heads}))"
- )
-
- if is_te_min_version("0.10.0"):
- extra_kwargs["attention_type"] = attention_type
- # older versions don't need attention_type
-
- if is_te_min_version("0.12.0", check_equality=False):
- self.te_forward_mask_type = True
-
- # Only Transformer-Engine version >= 1.0.0 supports context parallelism
- if is_te_min_version("1.0.0"):
- if getattr(TEDotProductAttention, "cp_stream") is None:
- TEDotProductAttention.cp_stream = torch.cuda.Stream()
- extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
- extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
- extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
- if is_te_min_version("1.10.0"):
- if cp_comm_type is None:
- extra_kwargs["cp_comm_type"] = "p2p"
- else:
- extra_kwargs["cp_comm_type"] = cp_comm_type
- else:
- assert (
- self.config.context_parallel_size == 1
- ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
-
- if self.config.deterministic_mode:
- if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
- raise RuntimeError(
- "deterministic_mode is on and we are using DotProductAttention from "
- "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
- f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
- )
-
- if config.window_size is not None:
- # Check version
- assert is_te_min_version("1.2.0"), (
- f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" "sliding window attention."
- )
- extra_kwargs["window_size"] = config.window_size
-
- if is_te_min_version("1.10.0"):
- # TE 1.10.0 introduces the ability to set the different k and v channels
- kv_channels = (
- (k_channels, v_channels)
- if k_channels is not None and v_channels is not None
- else self.config.kv_channels
- )
- else:
- kv_channels = self.config.kv_channels
-
- extra_kwargs["softmax_scale"] = softmax_scale
-
- super(TEDotProductAttention, self).__init__(
- num_attention_heads=self.config.num_attention_heads,
- kv_channels=kv_channels,
- attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
- attn_mask_type=attn_mask_type.name,
- sequence_parallel=self.config.sequence_parallel,
- tp_size=self.config.tensor_model_parallel_size,
- get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
- tp_group=get_tensor_model_parallel_group(check_initialized=False),
- layer_number=layer_number,
- **extra_kwargs,
- )
-
-
-class ESM2DotProductAttention(DotProductAttention):
- """ESM2-Specific core attention.
-
- Region where selective activation recomputation is applied.
- This region is memory intensive but less compute intensive which
- makes activation checkpointing more efficient for LLMs (20B+).
- See Reducing Activation Recomputation in Large Transformer Models:
- https://arxiv.org/abs/2205.05198 for more details.
-
- We use the following notation:
- h: hidden size
- n: number of attention heads
- p: number of tensor model parallel partitions
- b: batch size
- s: sequence length
- """
-
- def __init__(
- self,
- config: TransformerConfig,
- layer_number: int,
- attn_mask_type: AttnMaskType,
- attention_type: str,
- attention_dropout: Optional[float] = None,
- ) -> None:
- """Initializes the Attention class.
-
- Args:
- config: The configuration object for the transformer.
- layer_number: The layer number of the attention module.
- attn_mask_type: The type of attention mask to be used.
- attention_type: The type of attention mechanism.
- attention_dropout: The dropout rate for attention weights. Defaults to None.
- """
- super().__init__(
- config=config,
- layer_number=layer_number,
- attn_mask_type=attn_mask_type,
- attention_type=attention_type,
- attention_dropout=attention_dropout,
- )
-
- def forward(
- self,
- query: Tensor,
- key: Tensor,
- value: Tensor,
- attention_mask: Tensor,
- attn_mask_type: Optional[AttnMaskType] = None,
- packed_seq_params: Optional[PackedSeqParams] = None,
- ):
- """Forward pass of the ESM2DotProductAttention module.
-
- Args:
- query: The query tensor of shape [sq, b, np, hn].
- key: The key tensor of shape [sk, b, ng, hn].
- value: The value tensor of shape [sk, b, ng, hn].
- attention_mask: The attention mask tensor of shape [b, np, sq, sk].
- attn_mask_type: The attention mask type, currently unused. Defaults to None.
- packed_seq_params: The packed sequence parameters. These are used for context parallelism and would
- need to be supported here before context parallelism can be enabled. Defaults to None.
-
- Returns:
- Tensor: The context tensor of shape [sq, b, hp].
- """
- if packed_seq_params is not None:
- raise ValueError(
- "Packed sequence is not supported by DotProductAttention. " "Please use TEDotProductAttention instead."
- )
-
- # ===================================
- # Raw attention scores. [b, n/p, s, s]
- # ===================================
-
- # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
- # This is a noop for normal attention where ng == np. When using group query attention this
- # creates a view that has the keys and values virtually repeated along their dimension to
- # match the number of queries.
-
- # attn_mask_type is not used.
- if (np_ng := self.num_attention_heads_per_partition // self.num_query_groups_per_partition) > 1:
- key = key.repeat_interleave(np_ng, dim=2)
- value = value.repeat_interleave(np_ng, dim=2)
-
- # [b, np, sq, sk]
- b, np, sq, sk = query.size(1), query.size(2), query.size(0), key.size(0)
-
- # [sq, b, np, hn] -> [sq, b * np, hn]
- # This will be a simple view when doing normal attention, but in group query attention
- # the key and value tensors are repeated to match the queries so you can't use simple strides
- # to extract the queries.
- query = query.reshape(sq, b * np, -1)
- # [sk, b, np, hn] -> [sk, b * np, hn]
- key = key.view(sk, b * np, -1)
-
- # preallocating input tensor: [b * np, sq, sk]
- matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
- (b * np, sq, sk),
- query.dtype,
- "mpu",
- )
-
- # Raw attention scores. [b * np, sq, sk]
- matmul_result = torch.baddbmm(
- matmul_input_buffer,
- query.transpose(0, 1), # [b * np, sq, hn]
- key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
- beta=0.0,
- alpha=(1.0 / self.norm_factor) if self.config.normalize_attention_scores else 1.0,
- )
-
- # change view to [b, np, sq, sk]
- attention_scores = matmul_result.view(b, np, sq, sk)
-
- # ===========================
- # Attention probs and dropout
- # ===========================
-
- # attention scores and attention mask [b, np, sq, sk]
- # ESM2 Customization
- if self.config.use_esm_attention:
- # NOTE: the slicing here makes attention_mask the same shape as the extended attention
- # mask in ESM2. Filling masked entries with the dtype's minimum value (e.g. -3.4028e+38
- # for float32) mirrors ESM2's masking approach, which forces the softmax of attention
- # scores for masked entries to be close to 0. Using the dtype's minimum instead of -inf
- # stays numerically stable in the special case where an entire sequence is masked.
- min_val = torch.finfo(attention_scores.dtype).min
-
- attention_probs: Tensor = self.esm2_scale_mask_softmax(
- attention_scores.masked_fill(attention_mask[:, :, 0:1, :].to(bool), min_val)
- )
- # END ESM2 Customization
- else:
- attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
-
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
-
- if not self.config.sequence_parallel:
- with tensor_parallel.get_cuda_rng_tracker().fork():
- attention_probs = self.attention_dropout(attention_probs)
- else:
- attention_probs = self.attention_dropout(attention_probs)
-
- # =========================
- # Context layer. [sq, b, hp]
- # =========================
-
- # value -> context layer.
- # [sk, b, np, hn] --> [b, np, sq, hn]
-
- # context layer shape: [b, np, sq, hn]
- b, np, sq, hn = value.size(1), value.size(2), query.size(0), value.size(3)
-
- # change view [sk, b * np, hn]
- value = value.view(value.size(0), b * np, -1)
-
- # change view [b * np, sq, sk]
- attention_probs = attention_probs.view(b * np, sq, -1)
-
- # matmul: [b * np, sq, hn]
- context = torch.bmm(attention_probs, value.transpose(0, 1))
-
- # change view [b, np, sq, hn]
- context = context.view(b, np, sq, hn)
-
- # [b, np, sq, hn] --> [sq, b, np, hn]
- context = context.permute(2, 0, 1, 3).contiguous()
-
- # [sq, b, np, hn] --> [sq, b, hp]
- context = context.view(sq, b, self.hidden_size_per_partition)
-
- return context
-
- def esm2_scale_mask_softmax(
- self,
- input: Tensor,
- mask: Optional[Tensor] = None,
- scale: Optional[Union[float, int]] = None,
- mask_func: Optional[Callable] = None,
- ) -> Tensor:
- """Scale Mask Softmax function.
-
- Args:
- input: Tensor of shape (Batch, NP, SK, SQ). The input may or may not have already
- had a mask applied to it.
- mask: If a mask is to be applied, it will go here.
- scale: A scale factor that will be applied before the softmax.
- mask_func: An optional function to apply to the mask. If None, it is assumed that
- the input already had the mask applied to it.
-
- Returns:
- probs: Tensor of normalized probabilities after the softmax has been applied,
- of shape (Batch, NP, SK, SQ).
- """
- if self.attn_mask_type.name != "padding":
- raise ValueError(
- f"self.attn_mask_type: {self.attn_mask_type} is not 'padding'. "
- "Only 'padding' type is supported currently."
- )
-
- original_dtype = input.dtype # Store original dtype
- if (
- original_dtype == torch.float16 or original_dtype == torch.bfloat16
- ) and self.config.attention_softmax_in_fp32:
- input = input.float() # Convert to float32 for softmax
-
- if scale is not None:
- input = input * scale # Apply scaling
-
- if mask is not None and mask_func is not None:
- input = mask_func(input, mask) # Apply mask function if provided
-
- probs = torch.nn.functional.softmax(input, dim=-1) # Apply softmax
-
- if self.config.attention_softmax_in_fp32 and original_dtype in (torch.float16, torch.bfloat16):
- probs = probs.to(original_dtype) # Convert back to original dtype if necessary
-
- return probs
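The masking note in the deleted `esm2_scale_mask_softmax` path above points out that masked scores are filled with the dtype's minimum value rather than -inf so the softmax stays well defined even when an entire sequence is masked. A minimal, standalone sketch of that behaviour (the shapes here are illustrative and not tied to the deleted class):

```python
import torch

scores = torch.zeros(1, 4)                 # one query row, four key positions
mask = torch.ones(1, 4, dtype=torch.bool)  # every key position masked

min_val = torch.finfo(scores.dtype).min
safe = torch.softmax(scores.masked_fill(mask, min_val), dim=-1)
naive = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(safe)   # finite, uniform probabilities: the fully masked row stays well defined
print(naive)  # all NaN: exp(-inf) sums to zero, so normalization divides by zero
```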
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
index d0999c2773..b9c82ed258 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -35,7 +35,6 @@
from torch.optim import Optimizer
from bionemo.esm2.data.tokenizer import BioNeMoESMTokenizer
-from bionemo.esm2.model.attention import ESM2DotProductAttention, ESM2TEDotProductAttention
from bionemo.esm2.model.embedding import ESM2Embedding
from bionemo.llm.api import MegatronLossType
from bionemo.llm.model.biobert.model import BioBertConfig, MegatronBioBertModel, PositionEmbeddingKinds
@@ -294,6 +293,7 @@ class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
bias_activation_fusion: bool = True # True degrades accuracy slightly, but is faster.
activation_func: Callable = F.gelu # esm_gelu_func # ESM2 MLP
init_method_std: float = 0.02
+ softmax_scale: float = 1.0
# embedding
token_dropout: bool = True
@@ -346,13 +346,11 @@ def __post_init__(self):
super().__post_init__()
if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
self.apply_query_key_layer_scaling = False
- self.core_attention_override = ESM2TEDotProductAttention
elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
logging.warning(
"BiobertSpecOption.esm2_bert_layer_local_spec is depreciated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
)
self.apply_query_key_layer_scaling = True
- self.core_attention_override = ESM2DotProductAttention
else:
raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py
index 5ba6739164..ac21820875 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/run/config_models.py
@@ -25,7 +25,6 @@
from bionemo.esm2.data.datamodule import ESMDataModule
from bionemo.esm2.data.dataset import RandomMaskStrategy
from bionemo.esm2.data.tokenizer import get_tokenizer
-from bionemo.esm2.model.attention import ESM2DotProductAttention, ESM2TEDotProductAttention
from bionemo.esm2.model.model import ESM2Config
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.run.config_models import (
@@ -188,14 +187,12 @@ def validate_and_set_attention_and_scaling(self):
)
if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
self.apply_query_key_layer_scaling = False
- self.core_attention_override = ESM2TEDotProductAttention
elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
logging.warning(
"BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. "
"Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
)
self.apply_query_key_layer_scaling = True
- self.core_attention_override = ESM2DotProductAttention
return self
def model_validator(self, global_cfg: MainConfig) -> MainConfig:
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_attention.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_attention.py
deleted file mode 100644
index 6383a04b65..0000000000
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_attention.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: LicenseRef-Apache2
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import math
-
-import pytest
-import torch
-from megatron.core.transformer.enums import AttnMaskType
-
-from bionemo.esm2.api import ESM2Config
-from bionemo.esm2.model.attention import ESM2DotProductAttention, ESM2TEDotProductAttention
-from bionemo.testing import megatron_parallel_state_utils
-
-
-@pytest.fixture(scope="module")
-def config():
- with megatron_parallel_state_utils.distributed_model_parallel_state():
- yield ESM2Config(
- seq_length=20,
- hidden_size=4,
- num_attention_heads=4,
- attention_dropout=0.1,
- use_esm_attention=True,
- )
-
-
-@pytest.fixture(scope="module")
-def local_attention_layer(config: ESM2Config) -> ESM2DotProductAttention:
- return ESM2DotProductAttention(
- config=config,
- layer_number=0,
- attn_mask_type=AttnMaskType.padding,
- attention_type="normal",
- ).eval()
-
-
-@pytest.fixture(scope="module")
-def attention_layer(config: ESM2Config) -> ESM2TEDotProductAttention:
- return ESM2TEDotProductAttention(
- config=config,
- layer_number=0,
- attn_mask_type=AttnMaskType.padding,
- attention_type="self",
- ).eval()
-
-
-def test_init(attention_layer, config):
- assert attention_layer.config.use_esm_attention
- assert attention_layer.config == config
-
-
-@pytest.mark.skip(reason="Not implemented yet for transformer engine")
-def test_forward(attention_layer, config):
- batch_size = 2
- sequence_length = config.seq_length
- hidden_size = config.hidden_size
- device = torch.device("cuda")
-
- query = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device)
- key = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device)
- value = torch.randn(sequence_length, batch_size, 1, hidden_size, device=device)
- random_ints = torch.randint(0, 2, (batch_size, 1, sequence_length, sequence_length), device=device)
- attention_mask = ((random_ints + torch.transpose(random_ints, dim0=2, dim1=3)) / 2).to(
- dtype=torch.bool
- ) # symmetric mask tensor
-
- if isinstance(attention_layer, ESM2TEDotProductAttention):
- raise NotImplementedError("TE requires reshaped input and is not implemented yet")
- else:
- output = attention_layer(query, key, value, attention_mask)
- assert output.shape == (sequence_length, batch_size, hidden_size)
-
-
-@pytest.mark.skip(reason="Not implemented yet for transformer engine")
-@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.half])
-def test_attention_with_mask(attention_layer, dtype):
- sequence_length_val = 3
- sequence_length_query = 1
- batch_size = 2
- emb_dim = 4
- device = torch.device("cuda")
-
- # query and key such that the dot prod is an all-ones tensor
- query = torch.ones(batch_size, sequence_length_query, 1, emb_dim, device=device, dtype=dtype) / math.sqrt(emb_dim)
- key = torch.ones(batch_size, sequence_length_val, 1, emb_dim, device=device, dtype=dtype) / math.sqrt(emb_dim)
-
- query = query.transpose(0, 1)
- key = key.transpose(0, 1)
-
- attention_mask = torch.zeros(batch_size, 1, 1, sequence_length_val, device=device, dtype=dtype)
- attention_mask[0, :, :, 2:] = 1 # average first two tensors in val
- attention_mask[1, :, :, 1:] = 1 # select first item from val
-
- values = torch.stack([torch.arange(sequence_length_val)] * batch_size).to(device=device, dtype=dtype) + 1.0
- values = torch.stack([values] * emb_dim, dim=2).unsqueeze(2).transpose(0, 1)
-
- assert values.shape == (sequence_length_val, batch_size, 1, emb_dim)
-
- # softmax makes the first batch row the average of the first two value tensors ((ones + twos) / 2) and the second batch row just the ones
- if isinstance(attention_layer, ESM2TEDotProductAttention):
- raise NotImplementedError("TE requires reshaped input and is not implemented yet")
- else:
- output = attention_layer(query, key, values, attention_mask)
- expected_output = torch.tensor(
- [[[1.5000, 1.5000, 1.5000, 1.5000], [1.0000, 1.0000, 1.0000, 1.0000]]], device=device, dtype=dtype
- )
- assert torch.equal(output, expected_output)