Allow flash attention 2 and upgrade to transformers 4.34.1 (#672)
* more special casing in tokenizer equivalence check
* fix addedtoken -> str
* add lazy load option
* add gc collect
* updates for the patch release
* add documentation for flash attention options
dakinggg authored Oct 24, 2023
1 parent 091ddca commit d72902a
Showing 9 changed files with 337 additions and 112 deletions.
11 changes: 9 additions & 2 deletions llmfoundry/__init__.py
@@ -4,6 +4,11 @@
import torch

try:
# Before importing any transformers models, we need to disable transformers flash attention if
# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
# gated import otherwise.
import transformers

from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
flash_attn_fn, scaled_multihead_dot_product_attention,
triton_flash_attn_fn)
flash_attn_fn, is_flash_v1_installed,
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
@@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
if is_flash_v1_installed():
transformers.utils.is_flash_attn_available = lambda: False

except ImportError as e:
try:
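
The gate above relies on `is_flash_v1_installed()`, which this change starts exporting from `llmfoundry.models.layers.attention`, to tell transformers that flash attention is unavailable whenever only flash-attn v1 is present; otherwise transformers' ungated flash-attention-2 import hard-errors. As a rough sketch (an assumption about the helper's shape, not a copy of the real implementation), such a check only needs to inspect the installed `flash_attn` version:

```python
# Sketch only: a plausible implementation of the version gate. The real
# helpers live in llmfoundry.models.layers.attention.
from packaging import version


def is_flash_v1_installed() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


def is_flash_v2_installed() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
```
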
22 changes: 20 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -24,8 +24,7 @@

from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights

try:
@@ -95,12 +94,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
False)
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ 'Please install flash_attn==2.3.2.')

config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)

# This is not how you are supposed to set this, but transformers currently only
# supports enabling flash attention 2 when using the from_pretrained API.
# We need to support it for both from_pretrained and from_config, so we have to
# set the private attribute here. This will just skip all of transformers'
# validation logic that it is ok to use flash attention 2, so we check
# whether it is installed above, and whether the chosen config supports it here.
# https://github.com/huggingface/transformers/issues/26878
config._flash_attn_2_enabled = use_flash_attention_2

# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
@@ -200,6 +215,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention

from llmfoundry.models.layers.llama_attention_monkeypatch import \
get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
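
The long comment in this hunk is the key to the change: transformers 4.34 only exposes Flash Attention 2 through the `use_flash_attention_2` keyword of `from_pretrained`, while LLM Foundry also constructs models via `from_config`, so the private `_flash_attn_2_enabled` flag is set on the config that both paths consume (see the linked transformers issue #26878). A hedged sketch of the two construction paths the comment contrasts (the model name is illustrative, and a flash-attn 2 install plus a CUDA device are assumed):

```python
from transformers import AutoConfig, AutoModelForCausalLM

name = 'meta-llama/Llama-2-7b-hf'  # example; any model with Flash Attention 2 support

# Public path: from_pretrained accepts the flag directly in transformers 4.34.
model = AutoModelForCausalLM.from_pretrained(name, use_flash_attention_2=True)

# from_config has no equivalent keyword, so the flag is set on the config
# object instead, which is what the hunk above does for both paths.
config = AutoConfig.from_pretrained(name)
config._flash_attn_2_enabled = True
model = AutoModelForCausalLM.from_config(config)
```
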
4 changes: 2 additions & 2 deletions llmfoundry/tokenizers/tiktoken.py
@@ -155,7 +155,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
return self.added_tokens_decoder[ids]
return str(self.added_tokens_decoder[ids])

return self._convert_id_to_token(ids)

@@ -171,7 +171,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
tokens.append(self.added_tokens_decoder[index])
tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)

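
The `str(...)` wrapping is the "fix addedtoken -> str" item from the commit message: with transformers 4.34, `added_tokens_decoder` maps token ids to `AddedToken` objects rather than plain strings, while callers of `convert_ids_to_tokens` expect strings. A small illustration of why the cast is sufficient (the token text is just an example):

```python
# AddedToken stringifies to its underlying content, so wrapping the decoder
# lookup in str(...) restores the previous return type.
from tokenizers import AddedToken

token = AddedToken('<|endoftext|>', lstrip=False, rstrip=False)
assert str(token) == '<|endoftext|>'
assert not isinstance(token, str)  # which is why the explicit cast is needed
```
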
62 changes: 55 additions & 7 deletions scripts/train/README.md
@@ -5,14 +5,15 @@ This README walks through pretraining and finetuning a large language model usin
#### Table of Contents
1. [Part 1: LLM Pretraining](#llmpretraining)
1. [Installation](#installation)
2. [Dataset Preparation](#datasetpreparation)
3. [How to start single and multi-node pretraining](#howtostartpretraining)
2. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Dataset Preparation](#datasetpreparation)
1. [How to start single and multi-node pretraining](#howtostartpretraining)
1. [Part 2: LLM Finetuning](#llmfinetuning)
1. [Using a dataset on the HuggingFace Hub](#hfdataset)
2. [Using a local dataset](#localdataset)
3. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
3. [FAQ: How many GPUs do I need to train a LLM?](#howmandygpus)
4. [FAQ: Optimizing Performance](#optimizingperformance)
1. [Using a local dataset](#localdataset)
1. [Using a StreamingDataset (MDS) formatted dataset locally or in an object store](#mdsdataset)
1. [Using Flash Attention](#flashattention)
1. [FAQ: How many GPUs do I need to train a LLM?](#howmanygpus)
1. [FAQ: Optimizing Performance](#optimizingperformance)

# Part 1: LLM Pretraining <a name="llmpretraining"></a>

@@ -332,6 +333,53 @@ train_loader:
...
```
# Using Flash Attention <a name="flashattention"></a>

Flash Attention is an optimized implementation of the attention mechanism, first introduced by [Dao et al.](https://github.com/Dao-AILab/flash-attention). There are three versions of Flash Attention that can be used with LLM Foundry: Flash Attention V1, Flash Attention V2, and a Triton implementation of Flash Attention. To start, we recommend using one of our [provided Docker images](../../README.md#mosaicml-docker-images) corresponding to the Flash Attention version you would like to use. The Triton implementation can be used with either Flash Attention V1 or V2. How you enable Flash Attention then depends on which model you are using.

For MPT, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: mpt_causal_lm
...
attn_config:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_impl: flash
...
```

If loading MPT from the HuggingFace Hub, you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: mosaicml/mpt-7b
...
config_overrides:
# Will use either V1 or V2 depending on what is installed
# "triton" will use the Triton implementation
attn_config:
attn_impl: flash
...
```

For any HuggingFace model that supports Flash Attention (e.g. Llama and Mistral), you can specify Flash Attention in your YAML like so:
```yaml
model:
name: hf_causal_lm
use_flash_attention_2: True # Will be automatically set to True if Flash Attention V2 is installed and the model supports it
...
```
HuggingFace models currently only support Flash Attention V2.

For Llama specifically, we have another option if you would like to use the Triton implementation of Flash Attention. You can specify this in your YAML like so:
```yaml
model:
name: hf_causal_lm
pretrained_model_name_or_path: meta-llama/Llama-2-7b-hf
attention_patch_type: triton
...
```

# FAQ: How many GPUs do I need to train a LLM? <a name="howmanygpus"></a>
This is a complicated question in general, but if we assume that you are using FSDP with `FULL_SHARD`,
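
Since the Flash Attention options documented in this README section depend on which flash-attn major version is present in the environment or Docker image, a quick way to check before launching a run is to query the installed package directly (a convenience snippet, not part of the repository):

```python
# Print the installed flash-attn version, if any.
try:
    import flash_attn
    print(f'flash-attn {flash_attn.__version__} is installed')
except ImportError:
    print('flash-attn is not installed')
```
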
8 changes: 8 additions & 0 deletions scripts/train/train.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import copy
import gc
import logging
import os
import sys
@@ -216,6 +217,12 @@ def main(cfg: DictConfig) -> Trainer:
os.environ[
'PYTORCH_CUDA_ALLOC_CONF'] = f'max_split_size_mb:{max_split_size_mb}'

# Set CUDA lazy loading
# This can save a bit of memory if not all modules are needed
cuda_load_lazy: bool = cfg.pop('cuda_load_lazy', True)
if cuda_load_lazy:
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'

# Set seed first
seed: int = pop_config(cfg, 'seed', must_exist=True)
reproducibility.seed_all(seed)
@@ -634,6 +641,7 @@ def main(cfg: DictConfig) -> Trainer:
print('Logging config')
log_config(logged_cfg)
torch.cuda.empty_cache()
gc.collect()

# Eval first if requested
if eval_first and trainer.state.timestamp.batch.value == 0:
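
The new `cuda_load_lazy` option works by exporting `CUDA_MODULE_LOADING=LAZY`, which CUDA 11.7+ reads when the CUDA context is created; setting it after the first CUDA call has no effect, which is presumably why it is done near the top of `main()`, next to the allocator configuration. A minimal sketch of the same ordering constraint (assumes a CUDA-enabled PyTorch build):

```python
import os

# Must be in the environment before the CUDA context exists; changing it
# after the first CUDA call does nothing for this process.
os.environ.setdefault('CUDA_MODULE_LOADING', 'LAZY')

import torch  # noqa: E402

if torch.cuda.is_available():
    torch.cuda.init()                  # CUDA context is created here
    x = torch.ones(1, device='cuda')   # modules now load on first use, not eagerly
```
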
9 changes: 5 additions & 4 deletions setup.py
@@ -49,7 +49,7 @@
install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'transformers>=4.34.1,<4.35',
'mosaicml-streaming>=0.6,<0.7',
'torch>=1.13.1,<2.1.1',
'datasets>=2.14.5,<2.15',
@@ -114,9 +114,10 @@
extra_deps['all-cpu'] = set(
dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
if key != 'gpu-flash2')
extra_deps['all-flash2'] = set(
dep for key, deps in extra_deps.items() for dep in deps if key != 'gpu')
if key not in {'gpu-flash2', 'all-cpu'})
extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
for dep in deps
if key not in {'gpu', 'all', 'all-cpu'})

setup(
name=_PACKAGE_NAME,
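
The reworked comprehensions keep the two flash-attn major versions out of each other's aggregate extras and stop the derived keys from feeding back into themselves: `all` now skips `gpu-flash2` and `all-cpu`, and `all-flash2` skips `gpu`, `all`, and `all-cpu`. A toy illustration of the composition (the dependency names below are placeholders, not the real lists from setup.py):

```python
extra_deps = {
    'gpu': ['flash-attn-v1-dep'],
    'gpu-flash2': ['flash-attn-v2-dep'],
    'cpu-extra': ['some-cpu-dep'],
}

extra_deps['all-cpu'] = set(dep for key, deps in extra_deps.items()
                            for dep in deps if 'gpu' not in key)
extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
                        if key not in {'gpu-flash2', 'all-cpu'})
extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
                               for dep in deps
                               if key not in {'gpu', 'all', 'all-cpu'})

assert 'flash-attn-v2-dep' not in extra_deps['all']         # 'all' stays on flash-attn v1
assert 'flash-attn-v1-dep' not in extra_deps['all-flash2']  # 'all-flash2' stays on v2
```
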
43 changes: 43 additions & 0 deletions tests/test_hf_conversion_script.py
@@ -138,6 +138,49 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
tokenizer1.__dict__['init_kwargs'].pop('auto_map', None)
tokenizer2.__dict__['init_kwargs'].pop('auto_map', None)

# Additional special tokens do not match between original tokenizer and loaded tokenizer due to transformers
# constructor differences
additional_special_tokens_1 = {
t if isinstance(t, str) else t.content
for t in tokenizer1.__dict__.pop('_additional_special_tokens', [])
}
additional_special_tokens_2 = {
t if isinstance(t, str) else t.content
for t in tokenizer2.__dict__.pop('_additional_special_tokens', [])
}
# Also pop it out of init_kwargs
tokenizer1.__dict__['init_kwargs'].pop('additional_special_tokens', None)
tokenizer2.__dict__['init_kwargs'].pop('additional_special_tokens', None)
tokenizer1.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
tokenizer2.__dict__['init_kwargs'].pop('added_tokens_decoder', None)
# If the additional special tokens are the same (or a subset of each other), or if one of them is empty, then we are good
assert additional_special_tokens_1.issubset(
additional_special_tokens_2) or additional_special_tokens_2.issubset(
additional_special_tokens_1)

# The special token attributes may be strings or they may be AddedToken objects, so we just check string values
# First check that they have the same attrs
assert tokenizer1.SPECIAL_TOKENS_ATTRIBUTES == tokenizer2.SPECIAL_TOKENS_ATTRIBUTES
# Then check that the values are the same
for special_token_attr in tokenizer1.SPECIAL_TOKENS_ATTRIBUTES:
# Skip additional_special_tokens because we already checked it above
if special_token_attr == 'additional_special_tokens':
continue

# The init_kwargs can change between the original tokenizer and the loaded tokenizer,
# so we just pop them
tokenizer1.__dict__['init_kwargs'].pop(special_token_attr, None)
tokenizer2.__dict__['init_kwargs'].pop(special_token_attr, None)

attr1 = tokenizer1.__dict__.pop('_' + special_token_attr, None)
attr2 = tokenizer2.__dict__.pop('_' + special_token_attr, None)
if attr1 is None and attr2 is None:
continue

attr_value1 = attr1 if isinstance(attr1, str) else attr1.content
attr_value2 = attr2 if isinstance(attr2, str) else attr2.content
assert attr_value1 == attr_value2

assert tokenizer1.__dict__ == tokenizer2.__dict__


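
The extra popping and the subset comparison make the equivalence check tolerant of differences that transformers 4.34 introduces when a tokenizer is reloaded: special-token attributes may come back as `AddedToken` objects instead of strings, `additional_special_tokens` may be ordered or populated differently, and `added_tokens_decoder` shows up as a new init kwarg. A hypothetical round-trip use of the helper (the tokenizer name and path are illustrative, not taken from the test):

```python
# Assumes check_hf_tokenizer_equivalence from this test module is in scope.
from transformers import AutoTokenizer

original = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
original.save_pretrained('/tmp/roundtrip-tokenizer')
reloaded = AutoTokenizer.from_pretrained('/tmp/roundtrip-tokenizer')

check_hf_tokenizer_equivalence(original, reloaded)  # raises AssertionError on mismatch
```
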