Skip to content

Commit

Permalink
Add eval loader to eval script (#742)
Browse files Browse the repository at this point in the history
* Add eval loader to eval script

* small input tests

* updates

* fix typing and formatting

* fixes, add tests

* remove circular dependency

* tests pass

* nits + small fixes

* add metrics at the end, refactor to put icl/gauntlet as helpers

* NOT

* metrics instead of models, add unit tests
  • Loading branch information
aspfohl authored Nov 30, 2023
1 parent 1191267 commit 3100859
Show file tree
Hide file tree
Showing 11 changed files with 469 additions and 165 deletions.
32 changes: 12 additions & 20 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader

LOADER_NAME_TO_FUNCTION = {
'text': build_text_dataloader,
'text_denoising': build_text_denoising_dataloader,
'finetuning': build_finetuning_dataloader,
}


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
Expand All @@ -22,23 +28,9 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
if cfg.name not in LOADER_NAME_TO_FUNCTION:
allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
raise ValueError(f'Expected dataloader name to be one of {allowed}' +
f' but found name "{cfg.name}" in config: {cfg}')

return LOADER_NAME_TO_FUNCTION[cfg.name](cfg, tokenizer, device_batch_size)
81 changes: 81 additions & 0 deletions llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, GlobalLRScaling,
HuggingFaceCheckpointer, LayerFreezing,
MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
from llmfoundry.optim.scheduler import InverseSquareRootWithWarmupScheduler
Expand All @@ -42,6 +44,85 @@
log = logging.getLogger(__name__)


def build_evaluators(
eval_loader_config: Optional[Union[DictConfig, ListConfig]],
icl_tasks_config: Optional[Union[str, ListConfig]],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
*,
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
icl_seq_len: int,
icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:

evaluators = []
if eval_loader_config is not None:
evaluators = build_eval_loaders(
eval_loader_config,
tokenizer,
device_eval_batch_size,
)

logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config,
eval_gauntlet_config,
tokenizer,
device_eval_batch_size,
icl_seq_len,
icl_subset_num_batches,
)
evaluators.extend(icl_evaluators)

return evaluators, logger_keys, eval_gauntlet_callback


def build_eval_loaders(
eval_loader_config: Union[DictConfig, ListConfig],
tokenizer: PreTrainedTokenizerBase,
device_eval_batch_size: int,
) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, ListConfig):
eval_configs: ListConfig = eval_loader_config
is_multi_eval = True
else:
eval_configs = ListConfig([eval_loader_config])
is_multi_eval = False

for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer,
device_eval_batch_size)
eval_loader: Evaluator = Evaluator(
label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
dataloader=eval_dataloader,
# Load the eval data to fail fast. metrics will get added
# later in add_metrics_to_eval_loaders, after the model is loaded
metric_names=[],
)
evaluators.append(eval_loader)
return evaluators


def add_metrics_to_eval_loaders(
evaluators: List[Evaluator],
metrics: Dict[str, Metric],
) -> List[Evaluator]:
metric_names = list(metrics.keys())
eval_loaders, other_evaluators = [], []
for evaluator in evaluators:
if evaluator.metric_names == []:
evaluator.metric_names = metric_names
eval_loaders.append(evaluator)
else:
other_evaluators.append(evaluator)

# Put the base eval_loaders first
return eval_loaders + other_evaluators


def build_icl_data_and_gauntlet(
icl_tasks_config: Union[str, ListConfig],
eval_gauntlet_config: Optional[Union[str, DictConfig]],
Expand Down
53 changes: 39 additions & 14 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import pandas as pd
import torch
Expand All @@ -21,13 +21,14 @@

from llmfoundry.models import MPTForCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.utils.builders import (build_icl_data_and_gauntlet,
build_logger, build_tokenizer)
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device


def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
num_retries: int) -> Optional[ComposerModel]:
num_retries: int) -> ComposerModel:
try:
from peft import PeftModel
except ImportError as e:
Expand All @@ -43,7 +44,8 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
}

retries = 0
while retries < num_retries:
composer_model_wrapper = None
while retries < num_retries and composer_model_wrapper is None:
try:
trust_remote_code = model_cfg.get('trust_remote_code', True)
use_auth_token = model_cfg.get('use_auth_token', False)
Expand All @@ -58,7 +60,6 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,

composer_model_wrapper = COMPOSER_MODEL_REGISTRY[model_cfg.name](
peft_model, tokenizer)
return composer_model_wrapper
except Exception as e:
retries += 1
if retries >= num_retries:
Expand All @@ -68,19 +69,21 @@ def load_peft_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

assert composer_model_wrapper is not None
return composer_model_wrapper


def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
fsdp_config: Optional[Dict],
num_retries: int) -> Optional[ComposerModel]:
fsdp_config: Optional[Dict], num_retries: int) -> ComposerModel:
init_context = process_init_device(model_cfg, fsdp_config)

retries = 0
composer_model = None
with init_context:
while retries < num_retries:
while retries < num_retries and composer_model is None:
try:
composer_model = COMPOSER_MODEL_REGISTRY[model_cfg.name](
model_cfg, tokenizer)
return composer_model
except Exception as e:
retries += 1
if retries >= num_retries:
Expand All @@ -90,6 +93,9 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
f'Got exception {str(e)} while loading model {model_cfg.name}. {num_retries-retries} retries remaining'
)

assert composer_model is not None
return composer_model


def evaluate_model(
model_cfg: DictConfig,
Expand All @@ -100,6 +106,7 @@ def evaluate_model(
max_seq_len: int,
device_eval_batch_size: int,
eval_gauntlet_config: Optional[Union[str, DictConfig]],
eval_loader_config: Optional[Union[DictConfig, ListConfig]],
fsdp_config: Optional[Dict],
num_retries: int,
loggers_cfg: Dict[str, Any],
Expand All @@ -118,9 +125,15 @@ def evaluate_model(
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
max_seq_len, icl_subset_num_batches)
evaluators, logger_keys, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=max_seq_len,
icl_subset_num_batches=icl_subset_num_batches,
)

callbacks = []
if eval_gauntlet_callback is not None:
Expand All @@ -143,6 +156,11 @@ def evaluate_model(
composer_model = load_model(model_cfg.model, tokenizer, fsdp_config,
num_retries)

# Now add the eval metrics
if eval_loader_config is not None:
train_metrics = composer_model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
eval_gauntlet_df = pd.DataFrame(
columns=['model_name'] +
Expand Down Expand Up @@ -186,7 +204,7 @@ def evaluate_model(
return (trainer, logger_keys, eval_gauntlet_callback, eval_gauntlet_df)


def main(cfg: DictConfig):
def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
om.resolve(cfg)
model_configs: ListConfig = pop_config(cfg, 'models', must_exist=True)
eval_gauntlet_config: Optional[Union[str, DictConfig]] = pop_config(
Expand Down Expand Up @@ -228,6 +246,8 @@ def main(cfg: DictConfig):
default_value='debug')

# Optional Evaluation Parameters with default values
eval_loader_config: Optional[Union[DictConfig, ListConfig]] = pop_config(
cfg, 'eval_loader', must_exist=False, default_value=None)
seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17)
dist_timeout: Union[float, int] = pop_config(cfg,
'dist_timeout',
Expand Down Expand Up @@ -274,6 +294,7 @@ def main(cfg: DictConfig):
eval_gauntlet_df = None
models_df = None
composite_scores = None
trainers = []
for model_cfg in model_configs:
(trainer, logger_keys, eval_gauntlet_callback,
eval_gauntlet_df) = evaluate_model(
Expand All @@ -285,13 +306,15 @@ def main(cfg: DictConfig):
max_seq_len=max_seq_len,
device_eval_batch_size=device_eval_batch_size,
eval_gauntlet_config=eval_gauntlet_config,
eval_loader_config=eval_loader_config,
fsdp_config=fsdp_config,
num_retries=num_retries,
loggers_cfg=loggers_cfg,
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
icl_subset_num_batches=icl_subset_num_batches)
trainers.append(trainer)

if eval_gauntlet_callback is not None:
composite_scores = eval_gauntlet_callback.eval_after_all(
Expand Down Expand Up @@ -330,6 +353,8 @@ def main(cfg: DictConfig):
assert models_df is not None
print(models_df.to_markdown(index=False))

return trainers, eval_gauntlet_df


def calculate_markdown_results(logger_keys: List[str], trainer: Trainer,
benchmark_to_taxonomy: Dict[str, str],
Expand Down
52 changes: 17 additions & 35 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import torch
from composer import Trainer
from composer.core import Evaluator
from composer.core.callback import Callback
from composer.loggers import MosaicMLLogger
from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR,
Expand All @@ -26,10 +25,11 @@
from llmfoundry import (COMPOSER_MODEL_REGISTRY, ComposerHFCausalLM,
MPTForCausalLM)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.utils.builders import (build_algorithm, build_callback,
build_icl_data_and_gauntlet,
build_logger, build_optimizer,
build_scheduler, build_tokenizer)
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_algorithm, build_callback,
build_evaluators, build_logger,
build_optimizer, build_scheduler,
build_tokenizer)
from llmfoundry.utils.config_utils import (log_config, pop_config,
process_init_device,
update_batch_size_info)
Expand Down Expand Up @@ -526,31 +526,16 @@ def main(cfg: DictConfig) -> Trainer:

## Evaluation
print('Building eval loader...')
evaluators = []
eval_loaders = []
if eval_loader_config is not None:
is_multi_eval = isinstance(eval_loader_config, ListConfig)
eval_configs = eval_loader_config if is_multi_eval else [
eval_loader_config
]
for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer,
device_eval_batch_size)
eval_loader = Evaluator(
label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
dataloader=eval_dataloader,
metric_names=[], # we will add these after model is created
)
eval_loaders.append(eval_loader)

eval_gauntlet_callback = None

if icl_tasks_config is not None:
icl_evaluators, _, eval_gauntlet_callback = build_icl_data_and_gauntlet(
icl_tasks_config, eval_gauntlet_config, tokenizer,
device_eval_batch_size, icl_seq_len if icl_seq_len else max_seq_len,
icl_subset_num_batches)
evaluators.extend(icl_evaluators)
eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
evaluators, _, eval_gauntlet_callback = build_evaluators(
eval_loader_config,
icl_tasks_config,
eval_gauntlet_config,
tokenizer=tokenizer,
device_eval_batch_size=device_eval_batch_size,
icl_seq_len=eval_icl_seq_len,
icl_subset_num_batches=icl_subset_num_batches,
)

if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
Expand Down Expand Up @@ -581,11 +566,8 @@ def main(cfg: DictConfig) -> Trainer:

# Now add the eval metrics
if eval_loader_config is not None:
assert model.train_metrics is not None
eval_metric_names = list(model.train_metrics.keys())
for eval_loader in eval_loaders:
eval_loader.metric_names = eval_metric_names
evaluators.insert(0, eval_loader) # Put the base eval_loaders first
train_metrics = model.get_metrics(is_train=True)
evaluators = add_metrics_to_eval_loaders(evaluators, train_metrics)

# Build the Trainer
print('Building trainer...')
Expand Down
Loading

0 comments on commit 3100859

Please sign in to comment.