Mamba update #254

Open · wants to merge 71 commits into main

Commits (71)
32c596d
allow inputs_embeds in model input args
jmercat Apr 10, 2024
9e8abdf
linted
jmercat Apr 11, 2024
7aa33c3
few fixes
jmercat Apr 15, 2024
7739486
update Mamba compatibility
jmercat Apr 18, 2024
d468181
MambaParams default values
jmercat Apr 18, 2024
b14ef50
add mamba dataclass args
jmercat Apr 18, 2024
35e2488
small fix in hf_config
jmercat Apr 18, 2024
762f0a9
set the correct kind of model in OpenLMModel
jmercat Apr 18, 2024
7fb747f
get_input_embeddings for mamba
jmercat Apr 23, 2024
aebb9d5
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
3ca3986
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
c21129e
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
f5ae4d4
[wip] Mamba kwargs
jmercat Apr 23, 2024
f89e36c
[wip] Mamba optional inputs
jmercat Apr 23, 2024
a8fda98
[wip] Mamba inputs propagated
jmercat Apr 23, 2024
8a88e0d
[wip] Mamba output fix
jmercat Apr 23, 2024
813cf72
Filter keys keywords in load model (remove some layers from the state…
jmercat Apr 24, 2024
81b6aac
Mamba import optional and its dependencies too
jmercat Apr 24, 2024
8b5d641
Mamba import optional and its dependencies too
jmercat Apr 25, 2024
9525c5c
check for state_dict in checkpoint when no epoch
jmercat Apr 25, 2024
7c62e96
rotary inv_freq not a buffer anymore
jmercat Apr 25, 2024
bc64b0f
rotary inv_freq not a buffer anymore -> remove it from state dicts wh…
jmercat Apr 25, 2024
d2d08b6
loosen panda dep
jmercat Apr 25, 2024
3e0ce49
avoid masking if the attention_mask is masking only the right padding
jmercat Apr 30, 2024
741d717
handle none mask in hf_model
jmercat Apr 30, 2024
5186928
Filter keys keywords in load model (remove some layers from the state…
jmercat May 1, 2024
9bac984
Fix hf_model loss
jmercat May 1, 2024
b5b1cfb
Prevent wrong config to fail
jmercat May 1, 2024
f301fb6
fix mask shift
jmercat May 4, 2024
b84ddbe
Update boto requirement in requirements.txt (#284)
sedrick-keh-tri Jul 1, 2024
a14133f
remove xformers from requirements
sedrick-keh-tri Jul 11, 2024
3a3998e
full purge
sedrick-keh-tri Jul 11, 2024
e29256f
Merge pull request #1 from jmercat/purge-xformers
sedrick-keh-tri Jul 11, 2024
87bbb85
mbm configs
sedrick-keh-tri Jul 14, 2024
69c9235
Update ci.yml (#293)
Vaishaal Jul 17, 2024
c0f1319
HF integration (#291)
sedrick-keh-tri Jul 17, 2024
11d2028
Create open_lm_1b_swiglutorch.json
sedrick-keh-tri Jul 19, 2024
9bb92ef
use llm-foundry for ICL metrics (#287)
ysharma1126 Jul 23, 2024
d908188
allow inputs_embeds in model input args
jmercat Apr 10, 2024
ad34fd3
linted
jmercat Apr 11, 2024
93a8117
few fixes
jmercat Apr 15, 2024
bef9e8d
update Mamba compatibility
jmercat Apr 18, 2024
a8be17c
MambaParams default values
jmercat Apr 18, 2024
a286021
add mamba dataclass args
jmercat Apr 18, 2024
b7fdd10
small fix in hf_config
jmercat Apr 18, 2024
47105c8
set the correct kind of model in OpenLMModel
jmercat Apr 18, 2024
a940f32
get_input_embeddings for mamba
jmercat Apr 23, 2024
d99bcec
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
fb4613d
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
046134d
[wip] add inputs_embeddings in mamba class
jmercat Apr 23, 2024
7b9ad58
[wip] Mamba kwargs
jmercat Apr 23, 2024
a3f86a5
[wip] Mamba optional inputs
jmercat Apr 23, 2024
e59c73e
[wip] Mamba inputs propagated
jmercat Apr 23, 2024
75b983f
[wip] Mamba output fix
jmercat Apr 23, 2024
9ceeaa2
Filter keys keywords in load model (remove some layers from the state…
jmercat Apr 24, 2024
9c41b76
Mamba import optional and its dependencies too
jmercat Apr 24, 2024
1f46f1f
Mamba import optional and its dependencies too
jmercat Apr 25, 2024
ed27e51
check for state_dict in checkpoint when no epoch
jmercat Apr 25, 2024
5fd16fd
rotary inv_freq not a buffer anymore
jmercat Apr 25, 2024
eba0181
rotary inv_freq not a buffer anymore -> remove it from state dicts wh…
jmercat Apr 25, 2024
e79644c
loosen panda dep
jmercat Apr 25, 2024
f1226d9
avoid masking if the attention_mask is masking only the right padding
jmercat Apr 30, 2024
4f6d45f
handle none mask in hf_model
jmercat Apr 30, 2024
80199f7
Filter keys keywords in load model (remove some layers from the state…
jmercat May 1, 2024
412896b
Fix hf_model loss
jmercat May 1, 2024
1ce3442
Prevent wrong config to fail
jmercat May 1, 2024
b401e75
fix mask shift
jmercat May 4, 2024
a1d6394
black formatting and rebase on main
jmercat Jul 24, 2024
dd97111
Merge remote-tracking branch 'origin/mamba_update' into mamba_update
jmercat Jul 24, 2024
1d6efab
change model_norm to norm_type
jmercat Jul 24, 2024
97cd5d4
more flexible OpenLM model config input
jmercat Jul 24, 2024
Files changed
15 changes: 15 additions & 0 deletions .dockerignore
@@ -2,3 +2,18 @@ venv
wandb
logs
checkpoints
data
results
preproc_data
logs
wandb
*.pt
.pytest_cache
.vscode
.git
tmp
attention_logs
not_val_shard
attention_logs
training/eval_data

2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -43,7 +43,7 @@ jobs:
- name: Install
run: |
mkdir $HOME/.aws/
wget https://gist.githubusercontent.com/Vaishaal/f109bfab6a194a93040ae2a19b6be251/raw/7d8026ae234d77ba1ca29b1f9d114c6780308ae4/dummy_creds -O $HOME/.aws/credentials
wget https://gist.githubusercontent.com/Vaishaal/f109bfab6a194a93040ae2a19b6be251/raw/d6caf98a52f6889981d5fbd1707edd815b834161/dummy_creds -O $HOME/.aws/credentials
sudo apt-get update
python3 -m venv .env
source .env/bin/activate
2 changes: 2 additions & 0 deletions .gitignore
@@ -143,3 +143,5 @@ tests/assets/source_*/*
secrets.env
checkpoints/
experiments/
external/
preproc_data/
38 changes: 2 additions & 36 deletions open_lm/attention.py
@@ -2,7 +2,6 @@

import torch
from torch.nn import functional as F
import xformers.ops as xops


def get_rectangular_causal_mask(shape, q_seq_len, k_seq_len, device, dtype):
@@ -63,31 +62,6 @@ def apply_attention_mask_(bias, attention_mask, queries_dtype):
bias.mul_(~torch.all(bias == min_dtype, dim=-1, keepdim=True))


def xformers_attn(queries, keys, values, is_causal, attention_mask=None):
# xformers assumes q, k, v are [batch, seq_len, heads, embed_dim]
# We assume that queries match the last part of the key / value sequences
# see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
# we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
# sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1

# If queries have shape [batch, 1, heads, dim] it means there is only one query in the sequence.
# In this case, there is no notion of causal masking, so we can just set the mask to None.
# This is actually needed to get the desired behavior with seq_len=1.
bias = None
if is_causal and queries.shape[1] == keys.shape[1] and attention_mask is None:
bias = xops.LowerTriangularMask()
elif is_causal and (queries.shape[1] > 1 or attention_mask is not None):
# Build causal mask that assumes queries are in the end of the sequence.
batch, q_seq_len, heads, _ = queries.shape
k_seq_len = keys.shape[1]
bias = get_rectangular_causal_mask((batch, heads), q_seq_len, k_seq_len, queries.device, queries.dtype)
if attention_mask is not None:
apply_attention_mask_(bias, attention_mask, queries_dtype=queries.dtype)
elif not is_causal and attention_mask is not None:
raise NotImplementedError("attention_mask with is_causal=False is not yet implemented.")
return xops.memory_efficient_attention(queries, keys, values, attn_bias=bias)


def torch_attn(queries, keys, values, is_causal, attention_mask=None):
# Need to call contiguous in torch >=2.1, otherwise later calls to .view() fail.
# Possibly related: https://github.com/pytorch/pytorch/issues/110213 - behavior of scaled_dot_product_attention
@@ -111,7 +85,7 @@ def torch_attn(queries, keys, values, is_causal, attention_mask=None):
if attention_mask is None:
bias = None
# If we only have one query, assume we don't need to be in causal mode (can attend to all keys).
if queries.shape == 1:
if queries.shape[1] == 1:
is_causal = False
else:
if not is_causal:
@@ -196,15 +170,7 @@ def get_attn_func(
alpha=None,
):
if attn_name == "auto":
return xformers_attn if torch.cuda.is_available() else torch_attn
elif attn_name == "xformers_attn":
return xformers_attn
elif attn_name == "xformers_attn_variable_length":
# Upon changing the input sequence length, xformers attention changes
# the stride dimension of the output tensor. This makes future calls to
# .view() that collapses last two dimensions fail. One thus needs to
# call .contiguous() on the output tensor. [#188]
return lambda *args, **kwargs: xformers_attn(*args, **kwargs).contiguous()
return torch_attn
elif attn_name == "torch_attn":
return torch_attn
elif attn_name == "custom_attn":
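
Note: with xformers removed, attn_name="auto" now resolves to torch_attn on both CPU and GPU. A minimal sketch of calling it directly, assuming the [batch, seq_len, heads, head_dim] layout mentioned in the removed xformers comments and that the output keeps that layout:

import torch
from open_lm.attention import torch_attn

q = torch.randn(2, 16, 8, 64)   # [batch, seq_len, heads, head_dim] (assumed layout)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)
out = torch_attn(q, k, v, is_causal=True)
print(out.shape)                # expected: torch.Size([2, 16, 8, 64])
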
3 changes: 3 additions & 0 deletions open_lm/hf/__init__.py
@@ -0,0 +1,3 @@
from .configuration_openlm import OpenLMConfig
from .modeling_openlm import OpenLMForCausalLM
from .tokenization_openlm import OpenLMTokenizerFast
24 changes: 24 additions & 0 deletions open_lm/hf/configuration_openlm.py
@@ -0,0 +1,24 @@
# Follows OLMo's HF template

"""
OpenLM configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from open_lm.model import Params

logger = logging.get_logger(__name__)


class OpenLMConfig(PretrainedConfig):
model_type = "openlm"

def __init__(self, **kwargs):
kwargs["architectures"] = ["OpenLMForCausalLM"]
super().__init__(**kwargs)


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("openlm", OpenLMConfig)
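
Note: a minimal sketch of what this registration enables; the extra kwargs are assumed to mirror open_lm Params fields and are illustrative only.

from transformers import AutoConfig
from open_lm.hf import OpenLMConfig

# Arbitrary kwargs are stored on the config and later copied onto Params when
# building the open_lm model config; the field names here are assumptions.
config = OpenLMConfig(hidden_dim=512, n_layers=4, n_heads=8, vocab_size=50432)
print(config.model_type)      # "openlm"
print(config.architectures)   # ["OpenLMForCausalLM"]

# Because of AutoConfig.register above, the auto API resolves the same class.
assert isinstance(AutoConfig.for_model("openlm"), OpenLMConfig)
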
194 changes: 194 additions & 0 deletions open_lm/hf/modeling_openlm.py
@@ -0,0 +1,194 @@
# Follows OLMo's HF template

import logging
from dataclasses import fields
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from open_lm.model import Params, Transformer
from open_lm.norms import get_norm_class
from open_lm.attention import get_attn_func

from .configuration_openlm import OpenLMConfig

log = logging.getLogger(__name__)


def create_model_config_from_pretrained_config(config: OpenLMConfig):
"""
Utility function to build an open_lm Params object from an OpenLMConfig.
"""

kwargs = {}
for field in fields(Params):
if hasattr(config, field.name):
kwargs[field.name] = getattr(config, field.name)

model_config = Params(**kwargs)

if hasattr(config, "norm_type"):
model_config.norm_type = get_norm_class(config.norm_type)

if hasattr(config, "attn_name"):
model_config.attn_func = get_attn_func(config.attn_name)

return model_config


class OpenLMForCausalLM(PreTrainedModel):
"""
Extremely barebones HF model wrapper.
"""

config_class = OpenLMConfig
base_model_prefix = "model"

def __init__(self, config: OpenLMConfig, model: Optional[Transformer] = None):
super().__init__(config)

if not model:
self.model_config = create_model_config_from_pretrained_config(config)
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
self.model_config.init_device = "cpu"
self.model = Transformer(self.model_config)

else:
self.model = model

def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[
Cache
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is not None:
log.warning("inputs_embeds is set but OpenLM does not support it yet")
if attention_bias is not None:
log.warning("attention_bias is et but OpenLM does not support it yet")
if use_cache is None:
use_cache = True
if output_attentions:
raise ValueError("output_attentions is not yet supported in OpenLM")
if output_hidden_states:
raise ValueError("output_hidden_states is not yet supported in OpenLM")

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
# print("outer past_key_values: ", type(past_key_values))
# if past_key_values is not None:
# print(len(past_key_values), type(past_key_values[0]))
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
)

logits = outputs[0]
past_key_values = outputs[2]
hidden_states = None

loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.model_config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=hidden_states,
)

def can_generate(self) -> bool:
return True

def prepare_inputs_for_generation(
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
):
if past_key_values is not None:
if isinstance(past_key_values[0][1], int):
# This assumes that the second item of past key values is the length of the past (this is the case for linear attention)
past_length = past_key_values[0][1]
else:
# This assumes that the first item of past key values is a list of all the past keys, thus the
# shape 1 is the length of the past (this is the case for attention without window)
past_length = past_key_values[0][0].shape[1]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]

model_inputs = {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.pop("use_cache", True),
}
return model_inputs

def get_input_embeddings(self) -> torch.nn.Module:
return self.model.tok_embeddings

def set_input_embeddings(self, value: torch.nn.Module):
self.model.tok_embeddings = value

def get_output_embeddings(self):
if self.model_config.weight_tying:
return self.model.tok_embeddings
else:
return self.model.output

def set_output_embeddings(self, value: torch.nn.Module):
if self.model_config.weight_tying:
self.model.tok_embeddings = value
else:
self.model.output = value

def tie_weights(self):
"""
Copied from OLMo (description below). I removed it and the results just became garbage, so this pass is needed.
This function is intentionally left as a no-op.
Weight tying is handled as follows:
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
Therefore, there is no need to explicitly tie the weights in this function.
"""
pass

def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> torch.nn.Embedding:
raise NotImplementedError


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OpenLMConfig, OpenLMForCausalLM)
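
Note: for orientation, a standalone sketch (with made-up tensor shapes) of the past-length trimming performed in prepare_inputs_for_generation: once a KV cache exists, only the input IDs beyond the cached length are kept. The [1, 7, 8, 64] cache layout is an assumption for illustration.

import torch

input_ids = torch.arange(10).unsqueeze(0)            # [1, 10] prompt tokens so far
cached_keys = torch.randn(1, 7, 8, 64)               # keys cached for 7 positions
past_key_values = [(cached_keys, torch.randn(1, 7, 8, 64))]

past_length = past_key_values[0][0].shape[1]         # 7 (the non-windowed attention case)
if input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    remove_prefix_length = input_ids.shape[1] - 1
print(input_ids[:, remove_prefix_length:])           # tensor([[7, 8, 9]])
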
18 changes: 18 additions & 0 deletions open_lm/hf/tokenization_openlm.py
@@ -0,0 +1,18 @@
# Follows OLMo's HF template

from transformers import AutoTokenizer, PreTrainedTokenizerFast

from open_lm.hf.configuration_openlm import OpenLMConfig


class OpenLMTokenizerFast(PreTrainedTokenizerFast):
# Note: OpenLM's tokenizer is already a wrapper around huggingface. This is potentially unnecessary.
pass

# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
# # This is required to make the implementation complete.
# pass


# Register the tokenizer class so that it is available for transformer pipelines, auto-loading etc.
AutoTokenizer.register(OpenLMConfig, fast_tokenizer_class=OpenLMTokenizerFast)
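
Note: taken together, the three registrations appear intended to support standard transformers auto-loading. A hedged end-to-end sketch; the checkpoint path is hypothetical and assumes a directory containing an "openlm" config, weights, and tokenizer files.

from transformers import AutoModelForCausalLM, AutoTokenizer

import open_lm.hf  # noqa: F401  (importing runs the Auto* registrations above)

path = "path/to/openlm-checkpoint"  # hypothetical local checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
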
9 changes: 8 additions & 1 deletion open_lm/main.py
@@ -102,7 +102,7 @@ def get_state_dict(name):
return sd


def load_model(args, model, different_seed=False):
def load_model(args, model, different_seed=False, filter_keys=None):
checkpoint = pt_load(args.resume, map_location="cpu")
if "epoch" in checkpoint:
if not different_seed and "shard_shuffle_seed" in checkpoint:
@@ -121,11 +121,15 @@ def load_model(args, model, different_seed=False, filter_keys=None):
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint["epoch"]
sd = checkpoint["state_dict"]
# remove inv_freq from the state dict if it exists
sd = {k: v for k, v in sd.items() if "inv_freq" not in k}
global_step = checkpoint.get("step", None)
if next(iter(sd.items()))[0].startswith("module"):
sd = {k[len("module.") :]: v for k, v in sd.items()}
if "_orig_mod" in next(iter(sd.items()))[0]:
sd = {k.replace("_orig_mod.", ""): v for k, v in sd.items()}
if filter_keys is not None:
sd = {k: v for k, v in sd.items() if not any([x in k for x in filter_keys])}
if args.fsdp:
model.load_state_dict(sd)
elif args.distributed:
@@ -137,6 +141,9 @@
# loading a bare (model only) checkpoint for fine-tune or evaluation
start_epoch, global_step = 0, 0
pretrained_seed = None
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
checkpoint = {k: v for k, v in checkpoint.items() if "inv_freq" not in k}
model.load_state_dict(checkpoint)
logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
return start_epoch, global_step, pretrained_seed
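
Note: the state-dict handling added to load_model reduces to two substring filters before load_state_dict. A minimal sketch with hypothetical key names:

sd = {
    "tok_embeddings.weight": 0,
    "layers.0.attention.pos_embed.inv_freq": 1,
    "layers.0.feed_forward.w12.weight": 2,
}
sd = {k: v for k, v in sd.items() if "inv_freq" not in k}           # always dropped now
filter_keys = ["feed_forward"]                                       # hypothetical caller choice
sd = {k: v for k, v in sd.items() if not any(x in k for x in filter_keys)}
print(sorted(sd))  # ['tok_embeddings.weight']
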