Reconfigure ESM to work with the fine-tuning implementation in biobert
jstjohn committed Aug 22, 2024
1 parent 0b17cbb commit 2aeab63
Showing 5 changed files with 34 additions and 126 deletions.
117 changes: 13 additions & 104 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -15,15 +15,12 @@


import math
import os
import tarfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Literal, Optional, Sequence, Type

import torch
import torch.distributed
from megatron.core import parallel_state, tensor_parallel
from megatron.core import tensor_parallel
from megatron.core.models.bert.bert_lm_head import BertLMHead
from megatron.core.models.bert.pooler import Pooler
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
@@ -33,18 +30,14 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.lightning import get_vocab_size
from nemo.lightning.megatron_parallel import MegatronLossReduction
from nemo.lightning import io
from torch import Tensor
from torch.optim import Optimizer

from bionemo.core.model.config import BionemoModelConfig
from bionemo.esm2.model.attention import ESM2DotProductAttention
from bionemo.esm2.model.embedding import ESM2Embedding
from bionemo.llm.model.biobert.model import MegatronBioBertModel
from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption, get_biobert_spec
from bionemo.llm.model.loss import BERTMLMLossWithReduction
from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping
from bionemo.llm.model.biobert.model import BioBertGenericConfig, MegatronBioBertModel
from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption


__all__: Sequence[str] = (
@@ -75,6 +68,7 @@ def __init__(
add_binary_head=True,
return_embeddings=False,
use_full_attention_mask=False,
include_hiddens: bool = False,
) -> None:
"""Initialize the ESM2 model.
@@ -120,6 +114,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.add_binary_head = add_binary_head
self.return_embeddings = return_embeddings
self.include_hiddens = include_hiddens

# megatron core pipelining currently depends on model type
self.model_type = ModelType.encoder_or_decoder
@@ -219,8 +214,8 @@ def esm_gelu_func(x: Tensor) -> Tensor:
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


@dataclass
class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
@dataclass(kw_only=True)
class ESM2Config(BioBertGenericConfig[ESM2Model], io.IOMixin):
"""Configuration class for ESM2 model.
Attributes:
@@ -258,11 +253,12 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
return_only_hidden_states: Whether to return only hidden states.
"""

model_cls = ESM2Model
num_layers: int = 33 # 650M
hidden_size: int = 1280 # 650M
num_attention_heads: int = 20
ffn_hidden_size: int = 4 * 1280 # Transformer FFN hidden size. Usually 4 * hidden_size.
hidden_dropout: int = 0 # ESM2 removes dropout from hidden layers and attention
hidden_dropout: float = 0 # ESM2 removes dropout from hidden layers and attention
attention_dropout: float = 0.0 # ESM2 does not use attention dropout
apply_residual_connection_post_layernorm: bool = False # TODO: farhadr False is new default, True was BERT pub.
layernorm_epsilon: float = 1.0e-5
@@ -291,7 +287,7 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
rotary_percent: float = 1.0
seq_len_interpolation_factor: Optional[float] = None
seq_length: int = 1024
biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_local_spec.value
biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_local_spec

# TODO: Move this to better places?
get_attention_mask_from_fusion: bool = False
@@ -302,95 +298,8 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig):
# TODO (@skothenhill,@jstjohn) come up with a nice way of doing fine-tuning checkpoint loading,
# where some acceptable layers (eg lm_head) may or may not be absent from the model, and others
# (like a new head) may be new and missing from the initial checkpoint.
nemo1_ckpt_path: Optional[Path] = None
nemo1_ckpt_path: str | None = None
# TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
# things as part of the workflow for inference and fine-tuning.
return_only_hidden_states: bool = False # return logits

def configure_model(self, tokenizer) -> ESM2Model:
"""Configures the ESM2Model with the given tokenizer.
Args:
tokenizer: The tokenizer to be used.
Returns:
An instance of ESM2Model configured with the specified parameters.
"""
vp_size = self.virtual_pipeline_model_parallel_size
if vp_size:
p_size = self.pipeline_model_parallel_size
assert (
self.num_layers // p_size
) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

# The local specs all require the standard full attention mask. For transformer engine only the NVTE_FLASH_ATTN=0
# option requires this full attention mask.
use_full_attention_mask: bool = os.getenv("NVTE_FLASH_ATTN") == "0" or self.biobert_spec_option in {
BiobertSpecOption.bert_layer_local_spec,
BiobertSpecOption.bert_layer_local_spec_with_qk_ln,
BiobertSpecOption.esm2_bert_layer_local_spec,
}

do_next_sentence = False

model = ESM2Model(
self,
transformer_layer_spec=get_biobert_spec(
self.biobert_spec_option,
qk_layernorm=self.qk_layernorm,
core_attention=ESM2DotProductAttention,
),
num_tokentypes=2 if do_next_sentence else 0,
vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
max_sequence_length=self.seq_length,
tokenizer=tokenizer,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
return_embeddings=False,
pre_process=parallel_state.is_pipeline_first_stage(),
post_process=parallel_state.is_pipeline_last_stage(), # set to False for inference
add_binary_head=do_next_sentence,
use_full_attention_mask=use_full_attention_mask,
)
# TODO (@skothenhill) this is a hack to load the old checkpoint.
# This should be removed once we have a proper checkpoint conversion
# see NeMo/nemo/collections/llm/gpt/model/mixtral.py for how we should do it.
# We should eventually have an adapter for nemo1 checkpoints, HF checkpoints (at least for ESM2 @georgea)
# and an adapter may also be the right way to handle expected missing/extra keys when importing
# a checkpoint for fine-tuning (eg ignore missing lm_head, if not there in model, etc).
if self.nemo1_ckpt_path is not None:
te_mapping = self.biobert_spec_option in {
BiobertSpecOption.bert_layer_with_transformer_engine_spec,
BiobertSpecOption.bert_layer_with_transformer_engine_and_qk_ln_spec,
}
with tarfile.open(self.nemo1_ckpt_path, "r") as old_ckpt:
ckpt_file = old_ckpt.extractfile("./model_weights.ckpt")
old_weights = torch.load(ckpt_file)
new_state_dict_from_old = {}
for k, v in old_weights.items():
if "word_embeddings" in k:
print(k)
new_key = nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=te_mapping)
new_state_dict_from_old[new_key] = v
# TE adds non-null ._extra_state objects to layers, which store some kind of buffer bits
# so we need to allow those to pass through if we're loading from bionemo1 which did not
# use TE.
model.load_state_dict(new_state_dict_from_old, strict=not te_mapping)

# TODO (@jstjohn) come up with a cleaner way in the biobert module to return hidden states.
# maybe a suite of options like hugging face has so a user can ask for several or only one thing.
if self.return_only_hidden_states:
# this applies the final layernorm in the encoder to the hidden states which was
# the default in nemo1.
model.post_process = False
model.encoder.post_process = True
model.encoder.post_layer_norm = True
return model

def get_loss_reduction_class(self) -> Type[MegatronLossReduction]: # noqa: D102
# You could optionally return a different loss reduction class here based on the config settings.
return BERTMLMLossWithReduction
core_attention_override: Type[torch.nn.Module] | None = ESM2DotProductAttention
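The net effect in this file is that ESM2Config no longer carries its own configure_model(); it inherits the generic one from BioBertGenericConfig and only pins ESM2-specific defaults plus core_attention_override. A minimal usage sketch under that assumption (the seq_length override is purely illustrative; the tokenizer name is borrowed from the test file below):

```python
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

from bionemo.esm2.model.model import ESM2Config

# kw_only=True on the dataclass means any override must be passed by keyword.
config = ESM2Config(seq_length=1024)

# Tokenizer name taken from the test changed in this commit.
tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D")

# configure_model() is now the generic BioBertGenericConfig implementation; it
# forwards core_attention_override (ESM2DotProductAttention by default) into
# get_biobert_spec() when building the Megatron module.
model = config.configure_model(tokenizer)
```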
@@ -148,7 +148,7 @@ def test_esm2_650m_checkpoint(esm2_model):

def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data):
device = "cuda"

assert esm2_650M_config_w_ckpt.core_attention_override is not None
tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D")
tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to(device)
input_ids = tokens["input_ids"]
@@ -183,7 +183,7 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data):

# configure the model to return hiddens
esm2_650M_config_hiddens = deepcopy(esm2_650M_config_w_ckpt)
esm2_650M_config_hiddens.return_only_hidden_states = True
esm2_650M_config_hiddens.mutate_hparam("return_only_hidden_states", True)
model = esm2_650M_config_hiddens.configure_model(tokenizer).to(device)
model.eval()
hiddens = model(input_ids, attention_mask)
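Note that the test no longer flips return_only_hidden_states by direct attribute assignment; it goes through mutate_hparam. A minimal sketch of that pattern, assuming mutate_hparam records and applies the new value as the updated test implies (a bare config stands in for the 650M checkpoint fixture):

```python
from copy import deepcopy

from bionemo.esm2.model.model import ESM2Config

base_config = ESM2Config()

# With io.IOMixin in the config's bases, hyperparameter changes made after
# construction are routed through mutate_hparam (as in the updated test) so the
# tracked hyperparameters stay consistent with the live object.
hiddens_config = deepcopy(base_config)
hiddens_config.mutate_hparam("return_only_hidden_states", True)
```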
18 changes: 9 additions & 9 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
@@ -24,7 +24,6 @@
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
@@ -128,7 +127,7 @@ def __init__( # noqa: D107
transformer_layer_spec: spec_utils.ModuleSpec,
vocab_size: int,
max_sequence_length: int,
tokrnizer: Optional[AutoTokenizer] = None,
tokenizer: Optional[AutoTokenizer] = None,
pre_process: bool = True,
post_process: bool = True,
fp16_lm_cross_entropy: bool = False,
@@ -407,12 +406,9 @@ def override_mutate_possibly_extra_mutated_fiddle(
setattr(target_cfg, f, getattr(source_cfg, f))


@dataclass
@dataclass(kw_only=True)
class BioBertGenericConfig(
Generic[MegatronBioBertModelT],
MegatronBioNeMoTrainableModelConfig[MegatronBioBertModelT, MegatronLossReduction],
TransformerConfig,
# Do not add iomixin here
):
"""Config class for BioBert model, responsible for the partial configuration of Transformer models.
@@ -421,7 +417,6 @@ class BioBertGenericConfig(
`configure_model()` is ultimately called by the LightningModule using PTL lightning module hooks.
"""

model_cls: Type[MegatronBioBertModelT] = MegatronBioBertModel
# From megatron.core.models.gpt.bert_model.GPTModel
fp16_lm_cross_entropy: bool = False
parallel_output: bool = True
@@ -457,6 +452,7 @@ class BioBertGenericConfig(
override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS)
return_only_hidden_states: bool = False
include_hiddens: bool = False # Include hidden layers in the output of the model
core_attention_override: Type[torch.nn.Module] | None = field(default_factory=lambda: None)

def configure_model(self, tokenizer) -> MegatronBioBertModelT: # noqa: D102
vp_size = self.virtual_pipeline_model_parallel_size
@@ -496,11 +492,15 @@ def configure_model(self, tokenizer) -> MegatronBioBertModelT:  # noqa: D102

model = self.model_cls(
self,
transformer_layer_spec=get_biobert_spec(self.biobert_spec_option, qk_layernorm=self.qk_layernorm),
transformer_layer_spec=get_biobert_spec(
self.biobert_spec_option,
qk_layernorm=self.qk_layernorm,
core_attention=self.core_attention_override,
),
num_tokentypes=2 if do_next_sentence else 0,
vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
max_sequence_length=self.seq_length,
tokrnizer=tokenizer,
tokenizer=tokenizer,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
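The generic configure_model() now threads core_attention_override into get_biobert_spec(), which is exactly the hook ESM2Config uses in the first file. A sketch of how another config subclass could use the same hook; the class names are hypothetical and the attention stub stands in for a real Megatron-compatible implementation:

```python
from dataclasses import dataclass
from typing import Optional, Type

import torch

from bionemo.llm.model.biobert.model import BioBertGenericConfig, MegatronBioBertModel


class MyCoreAttention(torch.nn.Module):
    """Placeholder for a custom core-attention class (something like
    ESM2DotProductAttention); a real one must match what the chosen
    transformer layer spec expects."""


@dataclass(kw_only=True)
class MyBioBertConfig(BioBertGenericConfig[MegatronBioBertModel]):
    # Class-level default, mirroring how ESM2Config sets model_cls above.
    model_cls = MegatronBioBertModel
    # Picked up by the generic configure_model() and passed through as
    # get_biobert_spec(..., core_attention=...).
    core_attention_override: Optional[Type[torch.nn.Module]] = MyCoreAttention
```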
@@ -15,7 +15,7 @@


from enum import Enum
from typing import Optional, Sequence
from typing import Optional, Sequence, Type

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
@@ -60,7 +60,7 @@ class BiobertSpecOption(str, Enum):
def get_biobert_spec( # noqa: D417
biobert_spec_option: BiobertSpecOption,
qk_layernorm: bool = False,
core_attention: Optional[Module] = None,
core_attention: Optional[Type[Module]] = None,
) -> spec_utils.ModuleSpec:
"""Get the spec for the Biobert model.
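The only change in this file tightens the core_attention annotation from Optional[Module] to Optional[Type[Module]], making explicit that a class, not an instance, is expected (the resulting ModuleSpec instantiates it later). A call sketch under that reading, reusing names already imported elsewhere in this commit; the spec option and qk_layernorm value shown are just the ESM2 defaults:

```python
from bionemo.esm2.model.attention import ESM2DotProductAttention
from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption, get_biobert_spec

# Pass the attention class itself, not an instantiated module.
spec = get_biobert_spec(
    BiobertSpecOption.esm2_bert_layer_local_spec,
    qk_layernorm=False,
    core_attention=ESM2DotProductAttention,
)
```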
17 changes: 8 additions & 9 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/config.py
@@ -13,23 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Generic, Type
from dataclasses import dataclass, field
from typing import Type

from megatron.core.transformer import TransformerConfig
from nemo.lightning import io

from bionemo.core.model.config import BionemoModelConfig, BionemoTrainableModelConfig, Loss, Model


class MegatronBioNeMoModelConfig(Generic[Model], BionemoModelConfig[Model], TransformerConfig):
@dataclass(kw_only=True)
class MegatronBioNeMoModelConfig(BionemoModelConfig[Model], TransformerConfig, io.NeedsIOMixin):
"""A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

model_cls: Type[Model]
model_cls: Type[Model] = field(init=False)


class MegatronBioNeMoTrainableModelConfig(
Generic[Model, Loss], BionemoTrainableModelConfig[Model, Loss], MegatronBioNeMoModelConfig[Model], io.NeedsIOMixin
):
"""A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""

model_cls: Type[Model]
@dataclass(kw_only=True)
class MegatronBioNeMoTrainableModelConfig(MegatronBioNeMoModelConfig[Model], BionemoTrainableModelConfig[Model, Loss]):
pass
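With the explicit Generic[...] boilerplate dropped (the type parameters still flow through BionemoModelConfig[Model] and BionemoTrainableModelConfig[Model, Loss]) and the classes turned into kw_only dataclasses, model_cls becomes a field(init=False) that concrete configs pin rather than pass to __init__. A minimal sketch with hypothetical names, mirroring the pattern ESM2Config uses in the first file of this commit:

```python
from dataclasses import dataclass

from nemo.lightning.megatron_parallel import MegatronLossReduction

from bionemo.llm.model.biobert.model import MegatronBioBertModel
from bionemo.llm.model.config import MegatronBioNeMoTrainableModelConfig


@dataclass(kw_only=True)
class MyModelConfig(MegatronBioNeMoTrainableModelConfig[MegatronBioBertModel, MegatronLossReduction]):
    # model_cls is field(init=False) in the base, so a subclass supplies it as a
    # class-level attribute instead of a constructor argument. (A real config
    # would also implement configure_model and the loss-reduction hook.)
    model_cls = MegatronBioBertModel
```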
