diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
index 794a318768..3bb328f636 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -15,15 +15,12 @@
 
 
 import math
-import os
-import tarfile
 from dataclasses import dataclass
-from pathlib import Path
 from typing import Callable, Literal, Optional, Sequence, Type
 
 import torch
 import torch.distributed
-from megatron.core import parallel_state, tensor_parallel
+from megatron.core import tensor_parallel
 from megatron.core.models.bert.bert_lm_head import BertLMHead
 from megatron.core.models.bert.pooler import Pooler
 from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
@@ -33,18 +30,14 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.utils import get_linear_layer
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
-from nemo.lightning import get_vocab_size
-from nemo.lightning.megatron_parallel import MegatronLossReduction
+from nemo.lightning import io
 from torch import Tensor
 from torch.optim import Optimizer
 
-from bionemo.core.model.config import BionemoModelConfig
 from bionemo.esm2.model.attention import ESM2DotProductAttention
 from bionemo.esm2.model.embedding import ESM2Embedding
-from bionemo.llm.model.biobert.model import MegatronBioBertModel
-from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption, get_biobert_spec
-from bionemo.llm.model.loss import BERTMLMLossWithReduction
-from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping
+from bionemo.llm.model.biobert.model import BioBertGenericConfig, MegatronBioBertModel
+from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption
 
 
 __all__: Sequence[str] = (
@@ -75,6 +68,7 @@ def __init__(
         add_binary_head=True,
         return_embeddings=False,
         use_full_attention_mask=False,
+        include_hiddens: bool = False,
     ) -> None:
         """Initialize the ESM2 model.
 
@@ -120,6 +114,7 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.add_binary_head = add_binary_head
         self.return_embeddings = return_embeddings
+        self.include_hiddens = include_hiddens
 
         # megatron core pipelining currently depends on model type
        self.model_type = ModelType.encoder_or_decoder
@@ -219,8 +214,8 @@ def esm_gelu_func(x: Tensor) -> Tensor:
     return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
-@dataclass
-class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
+@dataclass(kw_only=True)
+class ESM2Config(BioBertGenericConfig[ESM2Model], io.IOMixin):
     """Configuration class for ESM2 model.
 
     Attributes:
@@ -258,11 +253,12 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
         return_only_hidden_states: Whether to return only hidden states.
     """
 
+    model_cls = ESM2Model
     num_layers: int = 33  # 650M
     hidden_size: int = 1280  # 650M
     num_attention_heads: int = 20
     ffn_hidden_size: int = 4 * 1280  # Transformer FFN hidden size. Usually 4 * hidden_size.
-    hidden_dropout: int = 0  # ESM2 removes dropout from hidden layers and attention
+    hidden_dropout: float = 0  # ESM2 removes dropout from hidden layers and attention
     attention_dropout: float = 0.0  # ESM2 does not use attention dropout
     apply_residual_connection_post_layernorm: bool = False  # TODO: farhadr False is new default, True was BERT pub.
     layernorm_epsilon: float = 1.0e-5
@@ -291,7 +287,7 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
     rotary_percent: float = 1.0
     seq_len_interpolation_factor: Optional[float] = None
     seq_length: int = 1024
-    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_local_spec.value
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_local_spec
 
     # TODO: Move this to better places?
     get_attention_mask_from_fusion: bool = False
@@ -302,95 +298,8 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
     # TODO (@skothenhill,@jstjohn) come up with a nice way of doing fine-tuning checkpoint loading,
     # where some acceptible layers (eg lm_head) may or may not be absent from the model, and others
     # (like a new head) may be new and missing from the initial checkpoint.
-    nemo1_ckpt_path: Optional[Path] = None
+    nemo1_ckpt_path: str | None = None
     # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
     # things as part of the workflow for inference and fine-tuning.
     return_only_hidden_states: bool = False  # return logits
-
-    def configure_model(self, tokenizer) -> ESM2Model:
-        """Configures the ESM2Model with the given tokenizer.
-
-        Args:
-            tokenizer: The tokenizer to be used.
-
-        Returns:
-            An instance of ESM2Model configured with the specified parameters.
-        """
-        vp_size = self.virtual_pipeline_model_parallel_size
-        if vp_size:
-            p_size = self.pipeline_model_parallel_size
-            assert (
-                self.num_layers // p_size
-            ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."
-
-        # The local specs all require the standard full attention mask. For transformer engine only the NVTE_FLASH_ATTN=0
-        # option requires this full attention mask.
-        use_full_attention_mask: bool = os.getenv("NVTE_FLASH_ATTN") == "0" or self.biobert_spec_option in {
-            BiobertSpecOption.bert_layer_local_spec,
-            BiobertSpecOption.bert_layer_local_spec_with_qk_ln,
-            BiobertSpecOption.esm2_bert_layer_local_spec,
-        }
-
-        do_next_sentence = False
-
-        model = ESM2Model(
-            self,
-            transformer_layer_spec=get_biobert_spec(
-                self.biobert_spec_option,
-                qk_layernorm=self.qk_layernorm,
-                core_attention=ESM2DotProductAttention,
-            ),
-            num_tokentypes=2 if do_next_sentence else 0,
-            vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
-            max_sequence_length=self.seq_length,
-            tokenizer=tokenizer,
-            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
-            parallel_output=self.parallel_output,
-            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
-            position_embedding_type=self.position_embedding_type,
-            rotary_percent=self.rotary_percent,
-            seq_len_interpolation_factor=self.seq_len_interpolation_factor,
-            return_embeddings=False,
-            pre_process=parallel_state.is_pipeline_first_stage(),
-            post_process=parallel_state.is_pipeline_last_stage(),  # set to False for inference
-            add_binary_head=do_next_sentence,
-            use_full_attention_mask=use_full_attention_mask,
-        )
-        # TODO (@skothenhill) this is a hack to load the old checkpoint.
-        # This should be removed once we have a proper checkpoint conversion
-        # see NeMo/nemo/collections/llm/gpt/model/mixtral.py for how we should do it.
-        # We should eventually have an adapter for nemo1 checkpoints, HF checkpoints (at least for ESM2 @georgea)
-        # and an adapter may also be the right way to handle expected missing/extra keys when importing
-        # a checkpoint for fine-tuning (eg ignore misisng lm_head, if not there in model, etc).
-        if self.nemo1_ckpt_path is not None:
-            te_mapping = self.biobert_spec_option in {
-                BiobertSpecOption.bert_layer_with_transformer_engine_spec,
-                BiobertSpecOption.bert_layer_with_transformer_engine_and_qk_ln_spec,
-            }
-            with tarfile.open(self.nemo1_ckpt_path, "r") as old_ckpt:
-                ckpt_file = old_ckpt.extractfile("./model_weights.ckpt")
-                old_weights = torch.load(ckpt_file)
-                new_state_dict_from_old = {}
-                for k, v in old_weights.items():
-                    if "word_embeddings" in k:
-                        print(k)
-                    new_key = nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=te_mapping)
-                    new_state_dict_from_old[new_key] = v
-                # TE adds non-null ._extra_state objects to layers, which store some kind of buffer bits
-                # so we need to allow those to pass through if we're loading from bionemo1 which did not
-                # use TE.
-                model.load_state_dict(new_state_dict_from_old, strict=not te_mapping)
-
-        # TODO (@jstjohn) come up with a cleaner way in the biobert module to return hidden states.
-        # maybe a suite of options like hugging face has so a user can ask for several or only one thing.
-        if self.return_only_hidden_states:
-            # this applies the final layernorm in the encoder to the hidden states which was
-            # the default in nemo1.
-            model.post_process = False
-            model.encoder.post_process = True
-            model.encoder.post_layer_norm = True
-        return model
-
-    def get_loss_reduction_class(self) -> Type[MegatronLossReduction]:  # noqa: D102
-        # You could optionally return a different loss reduction class here based on the config settings.
-        return BERTMLMLossWithReduction
+    core_attention_override: Type[torch.nn.Module] | None = ESM2DotProductAttention
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
index 089a4caf0f..fabffea69d 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py
@@ -148,7 +148,7 @@ def test_esm2_650m_checkpoint(esm2_model):
 
 def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data):
     device = "cuda"
-
+    assert esm2_650M_config_w_ckpt.core_attention_override is not None
     tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D")
     tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to(device)
     input_ids = tokens["input_ids"]
@@ -183,7 +183,7 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data):
 
     # configure the model to return hiddens
     esm2_650M_config_hiddens = deepcopy(esm2_650M_config_w_ckpt)
-    esm2_650M_config_hiddens.return_only_hidden_states = True
+    esm2_650M_config_hiddens.mutate_hparam("return_only_hidden_states", True)
     model = esm2_650M_config_hiddens.configure_model(tokenizer).to(device)
     model.eval()
     hiddens = model(input_ids, attention_mask)
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
index 0ef6656ccf..af41ce13cc 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
@@ -24,7 +24,6 @@
     Any,
     Callable,
     Dict,
-    Generic,
     List,
     Literal,
     Optional,
@@ -128,7 +127,7 @@ def __init__(  # noqa: D107
         transformer_layer_spec: spec_utils.ModuleSpec,
         vocab_size: int,
         max_sequence_length: int,
-        tokrnizer: Optional[AutoTokenizer] = None,
+        tokenizer: Optional[AutoTokenizer] = None,
         pre_process: bool = True,
         post_process: bool = True,
         fp16_lm_cross_entropy: bool = False,
@@ -407,12 +406,9 @@ def override_mutate_possibly_extra_mutated_fiddle(
         setattr(target_cfg, f, getattr(source_cfg, f))
 
 
-@dataclass
+@dataclass(kw_only=True)
 class BioBertGenericConfig(
-    Generic[MegatronBioBertModelT],
     MegatronBioNeMoTrainableModelConfig[MegatronBioBertModelT, MegatronLossReduction],
-    TransformerConfig,
-    # Do not add iomixin here
 ):
     """Config class for BioBert model, responsible for the partial configuration of Transformer models.
 
@@ -421,7 +417,6 @@ class BioBertGenericConfig(
     `configure_model()` is ultimately called by the LightningModule using PTL lightning module hooks.
""" - model_cls: Type[MegatronBioBertModelT] = MegatronBioBertModel # From megatron.core.models.gpt.bert_model.GPTModel fp16_lm_cross_entropy: bool = False parallel_output: bool = True @@ -457,6 +452,7 @@ class BioBertGenericConfig( override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS) return_only_hidden_states: bool = False include_hiddens: bool = False # Include hidden layers in the output of the model + core_attention_override: Type[torch.nn.Module] | None = field(default_factory=lambda: None) def configure_model(self, tokenizer) -> MegatronBioBertModelT: # noqa: D102 vp_size = self.virtual_pipeline_model_parallel_size @@ -496,11 +492,15 @@ def configure_model(self, tokenizer) -> MegatronBioBertModelT: # noqa: D102 model = self.model_cls( self, - transformer_layer_spec=get_biobert_spec(self.biobert_spec_option, qk_layernorm=self.qk_layernorm), + transformer_layer_spec=get_biobert_spec( + self.biobert_spec_option, + qk_layernorm=self.qk_layernorm, + core_attention=self.core_attention_override, + ), num_tokentypes=2 if do_next_sentence else 0, vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), max_sequence_length=self.seq_length, - tokrnizer=tokenizer, + tokenizer=tokenizer, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, parallel_output=self.parallel_output, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py index 0be173ff5f..0288114505 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/transformer_specs.py @@ -15,7 +15,7 @@ from enum import Enum -from typing import Optional, Sequence +from typing import Optional, Sequence, Type from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_layer_norm import FusedLayerNorm @@ -60,7 +60,7 @@ class BiobertSpecOption(str, Enum): def get_biobert_spec( # noqa: D417 biobert_spec_option: BiobertSpecOption, qk_layernorm: bool = False, - core_attention: Optional[Module] = None, + core_attention: Optional[Type[Module]] = None, ) -> spec_utils.ModuleSpec: """Get the spec for the Biobert model. diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py index 0b79fcd8d0..6ed115b9e9 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/config.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 
-from typing import Generic, Type
+from dataclasses import dataclass, field
+from typing import Type
 
 from megatron.core.transformer import TransformerConfig
 from nemo.lightning import io
@@ -21,15 +22,13 @@
 from bionemo.core.model.config import BionemoModelConfig, BionemoTrainableModelConfig, Loss, Model
 
 
-class MegatronBioNeMoModelConfig(Generic[Model], BionemoModelConfig[Model], TransformerConfig):
+@dataclass(kw_only=True)
+class MegatronBioNeMoModelConfig(BionemoModelConfig[Model], TransformerConfig, io.NeedsIOMixin):
     """A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""
 
-    model_cls: Type[Model]
+    model_cls: Type[Model] = field(init=False)
 
 
-class MegatronBioNeMoTrainableModelConfig(
-    Generic[Model, Loss], BionemoTrainableModelConfig[Model, Loss], MegatronBioNeMoModelConfig[Model], io.NeedsIOMixin
-):
-    """A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""
-
-    model_cls: Type[Model]
+@dataclass(kw_only=True)
+class MegatronBioNeMoTrainableModelConfig(MegatronBioNeMoModelConfig[Model], BionemoTrainableModelConfig[Model, Loss]):
+    pass
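
Reviewer note (illustration only, not part of the patch): a minimal sketch of how the refactored config path is expected to be exercised after this change. It assumes every field in the `ESM2Config` MRO has a default (so the kw_only dataclass can be constructed with no arguments), a CUDA device, and the same `AutoTokenizer` construction used in `test_model.py`.

    from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

    from bionemo.esm2.model.model import ESM2Config

    # kw_only dataclass; the defaults above correspond to the 650M ESM2 layout.
    config = ESM2Config()
    # core_attention_override defaults to ESM2DotProductAttention and is threaded
    # through get_biobert_spec() by the shared BioBertGenericConfig.configure_model().
    tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D")
    model = config.configure_model(tokenizer).to("cuda")
    model.eval()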