From 35673d50c86a246316006bbb7c602d25b86b1d49 Mon Sep 17 00:00:00 2001
From: haileyschoelkopf
Date: Fri, 22 Dec 2023 14:41:24 +0000
Subject: [PATCH] add lm-eval v0.4.0

---
 eval_tasks/eval_adapter.py    | 109 ++++++++++++++++++++--------------
 requirements/requirements.txt |   2 +-
 2 files changed, 65 insertions(+), 46 deletions(-)

diff --git a/eval_tasks/eval_adapter.py b/eval_tasks/eval_adapter.py
index e0a32797d..f7f66b794 100644
--- a/eval_tasks/eval_adapter.py
+++ b/eval_tasks/eval_adapter.py
@@ -13,19 +13,8 @@
 # limitations under the License.
 
 from megatron.utils import is_local_main, print_rank_0
-import best_download
-
-# patch best_download (eval harness downloader) to only happen on the first local rank
-fn = best_download.download_file
-
-
-def _download_file(*args, **kwargs):
-    if is_local_main():
-        fn(*args, **kwargs)
-
-
-best_download.download_file = _download_file
+import copy
 
 import os
 import sys
 import dataclasses
@@ -38,13 +27,13 @@ def _download_file(*args, **kwargs):
 import torch
 import torch.nn.functional as F
 
-from lm_eval.models.gpt2 import GPT2LM
-from lm_eval import tasks, evaluator, utils, base
+from lm_eval.models.huggingface import HFLM
+from lm_eval import tasks, evaluator, utils, api
 from megatron.text_generation_utils import generate_samples_from_prompt
 from megatron import mpu
 
 
-class EvalHarnessAdapter(GPT2LM):
+class EvalHarnessAdapter(HFLM):
     """
     An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks.
 
@@ -56,13 +45,13 @@ class EvalHarnessAdapter(GPT2LM):
     """
 
     def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
-        self.cache_hook = base.CacheHook(None)
-        self.model = model
+        self.cache_hook = api.model.CacheHook(None)
+        self._model = model
         self.neox_args = neox_args
         self.tokenizer = neox_args.tokenizer
         self._device = torch.device(f"cuda:{neox_args.local_rank}")
         self._eot_token_id = neox_args.tokenizer.eod_id
-        self._max_length = neox_args.max_position_embeddings // 2
+        self._max_length = neox_args.max_position_embeddings
         self._max_gen_toks = 128
         self._vocab_size = neox_args.padded_vocab_size
 
@@ -94,8 +83,6 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
             generate_samples_from_prompt,
             neox_args=neox_args,
             model=model,
-            maximum_tokens=self._max_gen_toks,
-            temperature=0.0,
         )
 
     @property
@@ -123,15 +110,23 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_encode(self, string: str):
+    @property
+    def rank(self):
+        return 0
+
+    @property
+    def world_size(self):
+        return 1
+
+    def tok_encode(self, string: str, **kwargs):
         return self.tokenizer.encode(string)
 
-    def tok_decode(self, tokens):
+    def tok_decode(self, tokens, **kwargs):
         return self.tokenizer.decode(tokens)
 
-    def greedy_until(self, requests):
+    def generate_until(self, requests):
         """
-        Greedy until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
+        Generate until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
         the eval harness dispatches requests to the model, and the model does argmax generation, the results
         of which are returned to the eval harness to evaluate.
@@ -143,19 +138,46 @@ def greedy_until(self, requests):
         self.model.module.inference_mode(use_cache=True)  # tell model to cache kv pairs
         res = []
 
+        # get only the args from each Instance object
+        reqs = [req.args for req in requests]
+
         def _collate(x):
             toks = self.tokenizer.encode(x[0])
             return (len(toks), x[0])
 
-        reord = utils.Reorderer(requests, _collate)
-        for context, until in tqdm(reord.get_reordered(), "Running greedy generation"):
-            if isinstance(until, str):
-                until = [until]
+        reord = utils.Reorderer(reqs, _collate)
+        for context, gen_kwargs in tqdm(reord.get_reordered(), "Running greedy generation"):
+            if isinstance(gen_kwargs, dict):
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                # fall back to the EOT token below if no stop sequences were passed
+                until = kwargs.pop("until", None)
+                if isinstance(until, str):
+                    until = [until]
+                elif until is not None and not isinstance(until, list):
+                    raise ValueError(
+                        f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                    )
+            else:
+                raise ValueError(
+                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+
+            if "do_sample" in kwargs.keys():
+                kwargs.pop("do_sample")
+
             stop_tokens = [self.tokenizer.encode(i) for i in until]
             cont = self.generate(
                 text=context,
                 stop_tokens=stop_tokens,
                 recompute=self.neox_args.recompute,
+                maximum_tokens=max_gen_toks,
+                **kwargs,
             )
             if cont:
                 s = cont[0]["text"] or ""
@@ -166,7 +188,7 @@ def _collate(x):
                 s = s.split(term)[0]
 
             # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
+            self.cache_hook.add_partial("generate_until", (context, until), s)
 
             res.append(s)
 
@@ -366,7 +388,6 @@ def run_eval(
         eval_tasks=None,
         num_fewshot=0,
         bootstrap_iters=2,
-        description_dict=None,
         use_cache=True,
         name="neox",
         limit=None,
@@ -385,8 +406,12 @@ def run_eval(
                 "winogrande",
                 "mathqa",
                 "pubmedqa",
+                "triviaqa",
             ]
 
+        # register all the default tasks bundled with the lm-evaluation-harness repository
+        tasks.initialize_tasks()
+
         # Returns a list containing all values of the task registry that
        # match at least one of the patterns
         import fnmatch
@@ -413,31 +438,25 @@ def pattern_match(patterns, source_list):
         task_dict = tasks.get_task_dict(eval_tasks)
 
         lm = self
 
-        if use_cache:
-            # TODO(jon-tow): Append a subset of `neox_args` to the cache database
-            # name arg to distinguish model runs that use different configurations.
-            lm = base.CachingLM(lm, "lm_cache/" + name + ".db")
-
-        results = evaluator.evaluate(
-            lm=lm,
-            task_dict=tasks.get_task_dict(eval_tasks),
-            description_dict=description_dict,
+        results = evaluator.simple_evaluate(
+            model=lm,
+            tasks=eval_tasks,
             num_fewshot=num_fewshot,
             limit=limit,
             bootstrap_iters=bootstrap_iters,
+            use_cache="lm_cache/" + name + ".db" if use_cache else None,
+            # TODO: Append a subset of `neox_args` to the cache database
+            # name arg to distinguish model runs that use different configurations.
         )
 
-        results["config"] = {
+        results["config"].update({
             "model": name,
             "model_args": dataclasses.asdict(self.neox_args),
             "num_fewshot": num_fewshot,
             "batch_size": self.batch_size,
             "device": str(self.device),
-            "no_cache": not use_cache,
-            "limit": limit,
-            "bootstrap_iters": bootstrap_iters,
-            "description_dict": description_dict,
-        }
+        })
 
         if was_training:
             self.model.train()
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 137da4d81..807a55974 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -3,7 +3,7 @@ git+https://github.com/EleutherAI/DeeperSpeed.git@b9260436e7da3e297fc6bedfd27d9e
 ftfy>=6.0.1
 git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
 huggingface_hub>=0.11.0
-lm_eval==0.3.0
+git+https://github.com/EleutherAI/lm-evaluation-harness.git@main#egg=lm_eval
 mpi4py>=3.0.3
 numpy>=1.22.0
 pybind11>=2.6.2
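
Usage sketch (illustrative, not part of the patch): under lm-eval v0.4.0 the old
`evaluator.evaluate` call plus a separately constructed `base.CachingLM` wrapper
collapses into a single `evaluator.simple_evaluate` call, with caching folded into
the `use_cache` path argument, and the bundled tasks must be registered via
`tasks.initialize_tasks()` before lookup. The model and task names below are
assumptions chosen only for the example.

    # Minimal sketch assuming lm-eval v0.4.0; the model and task names are illustrative.
    from lm_eval import evaluator, tasks
    from lm_eval.models.huggingface import HFLM

    tasks.initialize_tasks()  # register the tasks bundled with the harness

    results = evaluator.simple_evaluate(
        model=HFLM(pretrained="EleutherAI/pythia-160m"),  # any lm_eval LM subclass works here
        tasks=["lambada_openai", "piqa"],
        num_fewshot=0,
        use_cache="lm_cache/example.db",  # or None to disable response caching
    )
    print(results["results"])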