Commit

Point to composer.callback.Generate (#631)
* Point to composer.callback.Generate

* small fixes

* Add builder test for generate callback

* " -> '

* doc formatting

* Add test, add deprecation warning

* call assert

* Update llmfoundry/callbacks/generate_callback.py

Co-authored-by: Daniel King <[email protected]>

* add key test

* use mpt_causal_lm

* Update llmfoundry/callbacks/generate_callback.py

Co-authored-by: Daniel King <[email protected]>

* test updates

* mock inside

* formatting

* fix style

---------

Co-authored-by: Daniel King <[email protected]>
aspfohl and dakinggg authored Oct 12, 2023
1 parent 8e4c30a commit db2233e
Showing 4 changed files with 181 additions and 113 deletions.
121 changes: 16 additions & 105 deletions llmfoundry/callbacks/generate_callback.py
@@ -1,119 +1,30 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Periodically log generations to wandb from a set of prompts."""
-from typing import Any, List, Union, cast
-
-import torch
-import wandb
-from composer.core import Callback, State, get_precision_context
-from composer.loggers import Logger, WandBLogger
-from composer.utils import dist, ensure_tuple
+"""Deprecated Generate callback.
+
+Please use composer.callbacks.Generate instead.
+"""
+import warnings
+from typing import Any, List, Union
+
+from composer.callbacks import Generate as ComposerGenerate
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 
-class Generate(Callback):
+class Generate(ComposerGenerate):
 
     def __init__(self, prompts: List[str], batch_log_interval: int,
                  **kwargs: Any):
-        """Periodically log generations to wandb from a set of prompts.
-
-        In the main view for a run, there will be a table that will show the _last_ logged generations.
-        To compare previous iterations of the generations, you need to
-        1. Click on the run
-        2. Click on "artifacts" in the menu on the left side of the screen
-        3. Click on one of the artifacts called "predictions"
-        4. Click on the "files" tab
-        5. Click on "predictions.table.json"
-        6. On the left hand side, there are different versions of the table produced throughout training. Select one of these.
-        7. Now, when you hover over other versions, there will be a "compare" button, which will allow you to compare the currently
-           selected version to the version you add via compare.
-
-        Args:
-            prompts (List[str]): The list of prompts you would like to produce generations for
-            batch_log_interval (int): The interval (in batches) at which this callback runs
-            kwargs: All kwargs well be passed along to the call to generate. This is for things like `do_sample`, `top_p`, etc
-        """
-        self.prompts = prompts
-        self.batch_log_interval = batch_log_interval
-        self.generate_kwargs = kwargs
-        self.wandb_logger = None
-
-    def init(self, state: State, logger: Logger):
-        if dist.get_global_rank() == 0:
-            for destination in ensure_tuple(logger.destinations):
-                if isinstance(destination, WandBLogger):
-                    self.wandb_logger = destination
-
-    def batch_checkpoint(self, state: State, logger: Logger) -> None:
-        if (state.timestamp.batch.value % self.batch_log_interval) == 0:
-            self.generate(state, logger)
-
-    def generate(self, state: State, logger: Logger) -> None:
-        model = state.model
-        original_mode = model.training
-        model.eval()
-        tokenizer = cast(Tokenizer, state.model.tokenizer)
-        device = state.device
-
-        if not hasattr(model.model, 'generate'):
-            raise ValueError(
-                f'Cannot generate from model {model.model.__class__.__name__} because it does not have a `generate` method'
-            )
-
-        # stash the original original value of padding_side because generation requires left padding
-        original_padding_side = tokenizer.padding_side
-        tokenizer.padding_side = 'left'
-        if tokenizer.pad_token_id is None:
-            tokenizer.pad_token_id = tokenizer.eos_token_id
-        tokenized_input = tokenizer(self.prompts,
-                                    return_tensors='pt',
-                                    padding=True)
-
-        for k, v in tokenized_input.items():
-            tokenized_input[k] = device.tensor_to_device(v)
-
-        # dummy forward call needed for FSDP to work consistently
-        dummy_input = torch.tensor([[0]], dtype=torch.long)
-        dummy_input = device.tensor_to_device(dummy_input)
-        with get_precision_context(state.precision):
-            with torch.no_grad():
-                assert isinstance(model.model, torch.nn.Module)
-                _ = model.model(input_ids=dummy_input)
-
-            output_token_ids = model.model.generate(  # type: ignore
-                input_ids=tokenized_input['input_ids'],
-                attention_mask=tokenized_input['attention_mask'],
-                synced_gpus=True,
-                **self.generate_kwargs,
-            )
-
-        if dist.get_global_rank() == 0:
-            if self.wandb_logger is not None:
-                assert wandb.run is not None, 'wandb should have started run'
-
-                artifact = wandb.Artifact('generate_samples_' +
-                                          str(wandb.run.id),
-                                          type='predictions')
-
-                rows = []
-                for i in range(len(self.prompts)):
-                    prompt = self.prompts[i]
-                    output_tokens = output_token_ids[i][
-                        tokenized_input['input_ids'].shape[1]:]
-                    output_text = tokenizer.decode(output_tokens,
-                                                   skip_special_tokens=True)
-
-                    rows.append([prompt, output_text])
-
-                text_table = wandb.Table(data=rows,
-                                         columns=['prompt', 'generation'])
-                artifact.add(text_table, 'predictions')
-                wandb.log_artifact(artifact)
-                wandb.log({'generations': text_table},
-                          step=state.timestamp.batch.value)
-
-        tokenizer.padding_side = original_padding_side
-        model.train(mode=original_mode)
+
+        warnings.warn(
+            ('Accessing llmfoundry.callbacks.generate_callback.Generate '
+             'is deprecated and will be removed in a future release. '
+             'Please use composer.callbacks.Generate instead.'),
+            DeprecationWarning,
+        )
+
+        interval = f'{batch_log_interval}ba'
+        super().__init__(prompts=prompts, interval=interval, **kwargs)
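
For reference, a minimal sketch (not part of this commit) of constructing the Composer callback directly instead of the deprecated llmfoundry shim. The prompts/interval arguments and the forwarded generation kwargs mirror what the diff above passes through; the prompt text and interval value here are illustrative only.

from composer.callbacks import Generate

# Replaces llmfoundry.callbacks.generate_callback.Generate(prompts, batch_log_interval=100, ...):
# the interval is now a Composer time string, and extra kwargs are forwarded to generate().
generate_cb = Generate(
    prompts=['hello'],  # prompts to periodically generate from during training
    interval='100ba',  # was batch_log_interval=100 in the deprecated callback
    max_new_tokens=5,  # example generation kwarg, as in the test below
)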
29 changes: 22 additions & 7 deletions llmfoundry/utils/builders.py
@@ -3,13 +3,14 @@
 
 import logging
 import os
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from composer import algorithms
-from composer.callbacks import (EarlyStopper, LRMonitor, MemoryMonitor,
-                                OptimizerMonitor, RuntimeEstimator,
-                                SpeedMonitor)
+from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
+                                MemoryMonitor, OptimizerMonitor,
+                                RuntimeEstimator, SpeedMonitor)
 from composer.core import Algorithm, Callback, Evaluator
 from composer.datasets.in_context_learning_evaluation import \
     get_icl_task_dataloader
@@ -26,9 +27,9 @@
 from torch.optim.optimizer import Optimizer
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
-from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, Generate,
-                                  GlobalLRScaling, HuggingFaceCheckpointer,
-                                  LayerFreezing, MonolithicCheckpointSaver,
+from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
+                                  HuggingFaceCheckpointer, LayerFreezing,
+                                  MonolithicCheckpointSaver,
                                   ScheduledGarbageCollector)
 from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
                               DecoupledLionW, DecoupledLionW_8bit)
@@ -90,7 +91,21 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
            'log_optimizer_metrics', True),)
     elif name == 'generate_callback':
         prompts = kwargs.pop('prompts')
-        return Generate(prompts=list(prompts), **kwargs)
+        interval = kwargs.pop('interval', None)
+        # Generate callback used to be batch_log_interval, so this is for backwards compatibility
+        if interval is None:
+            batch_log_interval: str = kwargs.pop('batch_log_interval', '')
+            if batch_log_interval:
+                interval = f'{batch_log_interval}ba'
+                warnings.warn(
+                    ('generate_callback.batch_log_interval is deprecated and will be removed in a future release.'
+                     f'Please use interval: {interval}'),
+                    DeprecationWarning,
+                )
+            else:
+                raise KeyError(
+                    '"interval" must be specified with generate callback')
+        return Generate(prompts=list(prompts), interval=interval, **kwargs)
     elif name == 'global_lr_scaling':
         return GlobalLRScaling(**kwargs)
     elif name == 'layer_freezing':
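
As a usage sketch (not part of this commit), these are the two kwarg shapes build_callback now accepts for 'generate_callback'; the values mirror the tests below and are illustrative only.

from llmfoundry.utils.builders import build_callback

# Preferred: specify a Composer time string directly.
cb = build_callback('generate_callback', {
    'prompts': ['hello'],
    'interval': '10ba',
})

# Backwards compatible: batch_log_interval is converted to '10ba'
# and a DeprecationWarning is emitted.
cb = build_callback('generate_callback', {
    'prompts': ['hello'],
    'batch_log_interval': 10,
})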
51 changes: 50 additions & 1 deletion tests/test_builders.py
@@ -1,11 +1,15 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import unittest.mock as mock
+from typing import Union
+
 import pytest
+from composer.callbacks import Generate
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
-from llmfoundry.utils.builders import build_tokenizer
+from llmfoundry.utils.builders import build_callback, build_tokenizer
 
 
 @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [
@@ -29,3 +33,48 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict):
         assert tokenizer.model_max_length == tokenizer_kwargs[
             'model_max_length']
     assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+
+def test_build_callback_fails():
+    with pytest.raises(ValueError):
+        build_callback('nonexistent_callback', {})
+
+
+@pytest.mark.parametrize(
+    'interval_key,interval_value',
+    [('interval', '10ba'), ('batch_log_interval', 10)],
+)
+def test_build_generate_callback(
+    interval_key: str,
+    interval_value: Union[str, int],
+):
+
+    with mock.patch.object(Generate, '__init__',
+                           autospec=True) as mock_generate:
+        mock_generate.return_value = None
+        build_callback(
+            'generate_callback', {
+                'prompts': ['hello'],
+                interval_key: interval_value,
+                'foo': 'bar',
+                'something': 'else',
+            })
+
+        assert mock_generate.call_count == 1
+        _, _, kwargs = mock_generate.mock_calls[0]
+        assert kwargs['prompts'] == ['hello']
+        assert kwargs['interval'] == '10ba'
+        assert kwargs['something'] == 'else'
+        assert kwargs['foo'] == 'bar'
+
+
+def test_build_generate_callback_unspecified_interval():
+    with pytest.raises(KeyError):
+        with mock.patch.object(Generate, '__init__',
+                               autospec=True) as mock_generate:
+            mock_generate.return_value = None
+            build_callback('generate_callback', {
+                'prompts': ['hello'],
+                'foo': 'bar',
+                'something': 'else',
+            })
93 changes: 93 additions & 0 deletions tests/test_hf_mpt_gen.py
@@ -1,16 +1,22 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from pathlib import Path
 from typing import Any, Dict
+from unittest.mock import Mock
+
 import pytest
+from composer.callbacks import Generate as ComposerGenerate
 from composer.core.precision import get_precision_context
 from composer.trainer import Trainer
 from composer.utils import get_device, reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.data.finetuning import build_finetuning_dataloader
 from llmfoundry.utils import build_tokenizer
+from tests.data_utils import make_tiny_ft_dataset
 
 
 @pytest.mark.gpu
@@ -72,3 +78,90 @@ def test_init_hfhub_mpt(device: str, attn_impl: str):
 
 def test_init_hfhub_mpt_cpu():
     test_init_hfhub_mpt(device='cpu', attn_impl='torch')
+
+
+@pytest.mark.gpu
+def test_mpt_generate_callback(tmpdir: Path):
+    composer_device = get_device('gpu')
+    reproducibility.seed_all(42)
+    max_seq_len = 128
+
+    # testing dataset and dataloader
+    dataset_size = 5
+
+    tiny_dataset_path = tmpdir / 'test-ift-data-small'
+    tiny_dataset_path.mkdir()
+    tiny_dataset_file = tiny_dataset_path / 'train.jsonl'
+    make_tiny_ft_dataset(path=str(tiny_dataset_file), size=dataset_size)
+
+    dataloader_cfg = DictConfig({
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_dataset_path),
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 4,
+        'pin_memory': False,
+        'prefetch_factor': 2,
+        'persistent_workers': False,
+        'timeout': 0
+    })
+
+    # build tokenizer
+    tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {})
+
+    # build mpt model
+    model_config = DictConfig({
+        'name': 'mpt_causal_lm',
+        'config_overrides': {
+            'd_model': 128,
+            'n_heads': 4,
+            'n_layers': 2,
+            'expansion_ratio': 2,
+        },
+    })
+    model = COMPOSER_MODEL_REGISTRY[model_config.name](model_config, tokenizer)
+    model = composer_device.module_to_device(model)
+
+    # generate callback
+    prompts = [
+        'The best banana bread recipe is',
+        '2+2=',
+        'how much wood could a woodchuck chuck',
+    ]
+    gen_interval = 1
+    generate = ComposerGenerate(
+        prompts,
+        interval=f'{gen_interval}ba',
+        max_new_tokens=5,
+        batch_size=len(prompts),
+        use_cache=True,
+    )
+    generate.generate = Mock(wraps=generate.generate, autospec=True)
+
+    # build trainer
+    device_batch_size = 1
+    train_dataloader = build_finetuning_dataloader(
+        dataloader_cfg,
+        tokenizer,
+        device_batch_size,
+    )
+
+    trainer = Trainer(
+        model=model,
+        train_dataloader=train_dataloader,
+        device=composer_device,
+        max_duration=f'{gen_interval}ba',
+        callbacks=[generate],
+    )
+    trainer.logger.log_table = Mock()
+    trainer.fit()
+
+    generate.generate.assert_called_once()
+    trainer.logger.log_table.assert_called_once()