Commit

add lm-eval v0.4.0

haileyschoelkopf committed Dec 22, 2023
1 parent 9283eff commit 35673d5
Showing 2 changed files with 65 additions and 46 deletions.
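This commit ports the NeoX eval adapter from lm-eval v0.3.0 to the v0.4.0 API: GPT2LM becomes HFLM, greedy_until becomes generate_until, and evaluator.evaluate gives way to evaluator.simple_evaluate with explicit task registration. A minimal sketch of the new harness entry points, assuming a plain lm-eval v0.4.0 install; the "gpt2" checkpoint and task name are illustrative stand-ins, not part of the commit:

from lm_eval import tasks, evaluator
from lm_eval.models.huggingface import HFLM

tasks.initialize_tasks()  # v0.4.0 requires registering the bundled tasks before lookup

lm = HFLM(pretrained="gpt2")  # any HF model; this commit subclasses HFLM instead
results = evaluator.simple_evaluate(
    model=lm,
    tasks=["lambada_openai"],
    num_fewshot=0,
    limit=10,  # small smoke-test subset
)
print(results["results"])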
109 changes: 64 additions & 45 deletions eval_tasks/eval_adapter.py
@@ -13,19 +13,8 @@
# limitations under the License.

from megatron.utils import is_local_main, print_rank_0
-import best_download
-
-# patch best_download (eval harness downloader) to only happen on the first local rank
-fn = best_download.download_file
-
-
-def _download_file(*args, **kwargs):
-    if is_local_main():
-        fn(*args, **kwargs)
-
-
-best_download.download_file = _download_file

+import copy
import os
import sys
import dataclasses
@@ -38,13 +27,13 @@ def _download_file(*args, **kwargs):
import torch
import torch.nn.functional as F

-from lm_eval.models.gpt2 import GPT2LM
-from lm_eval import tasks, evaluator, utils, base
+from lm_eval.models.huggingface import HFLM
+from lm_eval import tasks, evaluator, utils, api
from megatron.text_generation_utils import generate_samples_from_prompt
from megatron import mpu


-class EvalHarnessAdapter(GPT2LM):
+class EvalHarnessAdapter(HFLM):
"""
An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks.
Expand All @@ -56,13 +45,13 @@ class EvalHarnessAdapter(GPT2LM):
"""

    def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
-        self.cache_hook = base.CacheHook(None)
-        self.model = model
+        self.cache_hook = api.model.CacheHook(None)
+        self._model = model
        self.neox_args = neox_args
        self.tokenizer = neox_args.tokenizer
        self._device = torch.device(f"cuda:{neox_args.local_rank}")
        self._eot_token_id = neox_args.tokenizer.eod_id
-        self._max_length = neox_args.max_position_embeddings // 2
+        self._max_length = neox_args.max_position_embeddings
        self._max_gen_toks = 128
        self._vocab_size = neox_args.padded_vocab_size

@@ -94,8 +83,6 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
            generate_samples_from_prompt,
            neox_args=neox_args,
            model=model,
-            maximum_tokens=self._max_gen_toks,
-            temperature=0.0,
        )

    @property
@@ -123,15 +110,23 @@ def batch_size(self):
    def device(self):
        return self._device

-    def tok_encode(self, string: str):
+    @property
+    def rank(self):
+        return 0
+
+    @property
+    def world_size(self):
+        return 1
+
+    def tok_encode(self, string: str, **kwargs):
        return self.tokenizer.encode(string)

-    def tok_decode(self, tokens):
+    def tok_decode(self, tokens, **kwargs):
        return self.tokenizer.decode(tokens)

-    def greedy_until(self, requests):
+    def generate_until(self, requests):
        """
-        Greedy until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
+        Generate until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
        the eval harness dispatches requests to the model, and the model does argmax generation, the results of which
        are returned to the eval harness to evaluate.
@@ -143,19 +138,46 @@ def greedy_until(self, requests):
        self.model.module.inference_mode(use_cache=True)  # tell model to cache kv pairs
        res = []

+        # get only the args from each Instance object
+        reqs = [req.args for req in requests]
+
        def _collate(x):
            toks = self.tokenizer.encode(x[0])
            return (len(toks), x[0])

-        reord = utils.Reorderer(requests, _collate)
-        for context, until in tqdm(reord.get_reordered(), "Running greedy generation"):
-            if isinstance(until, str):
-                until = [until]
+        reord = utils.Reorderer(reqs, _collate)
+        for context, gen_kwargs in tqdm(reord.get_reordered(), "Running greedy generation"):
+            until = None  # fall back to the EOT token below if no stop strings are set
+            if isinstance(gen_kwargs, dict):
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in kwargs.keys():
+                    until = kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [until]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+
+            if "do_sample" in kwargs.keys():
+                kwargs.pop("do_sample")

            stop_tokens = [self.tokenizer.encode(i) for i in until]
            cont = self.generate(
                text=context,
                stop_tokens=stop_tokens,
                recompute=self.neox_args.recompute,
+                maximum_tokens=max_gen_toks,
+                **kwargs,
            )
            if cont:
                s = cont[0]["text"] or ""
@@ -166,7 +188,7 @@ def _collate(x):
                    s = s.split(term)[0]

            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
+            self.cache_hook.add_partial("generate_until", (context, until), s)

            res.append(s)
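
For context, the requests that reach generate_until in v0.4.0 are Instance objects rather than bare (context, until) tuples, which is why the loop above unpacks req.args. A hedged illustration of the shape being consumed; the constructor fields follow lm_eval.api.instance.Instance in v0.4.0, and the prompt and settings are invented:

from lm_eval.api.instance import Instance

inst = Instance(
    request_type="generate_until",
    doc={},  # source document; empty here purely for illustration
    arguments=("Q: 2 + 2 =\nA:", {"until": ["\n"], "max_gen_toks": 16}),
    idx=0,
)
context, gen_kwargs = inst.args  # exactly the pair the adapter's loop iterates over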

@@ -366,7 +388,6 @@ def run_eval(
        eval_tasks=None,
        num_fewshot=0,
        bootstrap_iters=2,
-        description_dict=None,
        use_cache=True,
        name="neox",
        limit=None,
@@ -385,8 +406,12 @@
"winogrande",
"mathqa",
"pubmedqa",
"triviaqa",
]

# register all the default tasks bundled with lm-evaluation-harness repository
tasks.initialize_tasks()

# Returns a list containing all values of the task registry that
# match at least one of the patterns
import fnmatch
@@ -413,31 +438,25 @@ def pattern_match(patterns, source_list):
        task_dict = tasks.get_task_dict(eval_tasks)

        lm = self
-        if use_cache:
-            # TODO(jon-tow): Append a subset of `neox_args` to the cache database
-            # name arg to distinguish model runs that use different configurations.
-            lm = base.CachingLM(lm, "lm_cache/" + name + ".db")

-        results = evaluator.evaluate(
-            lm=lm,
-            task_dict=tasks.get_task_dict(eval_tasks),
-            description_dict=description_dict,
+        results = evaluator.simple_evaluate(
+            model=lm,
+            tasks=eval_tasks,
            num_fewshot=num_fewshot,
            limit=limit,
            bootstrap_iters=bootstrap_iters,
+            use_cache="lm_cache/" + name + ".db" if use_cache else None,
+            # TODO: Append a subset of `neox_args` to the cache database
+            # name arg to distinguish model runs that use different configurations.
        )

results["config"] = {
results["config"].update({
"model": name,
"model_args": dataclasses.asdict(self.neox_args),
"num_fewshot": num_fewshot,
"batch_size": self.batch_size,
"device": str(self.device),
"no_cache": not use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict,
}
})

        if was_training:
            self.model.train()
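Taken together, a hypothetical invocation of the updated adapter from a NeoX evaluation script might look as follows; model, forward_step_fn, and neox_args are assumed to come from the usual NeoX setup, and only the run_eval signature is taken from the diff:

adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size=None)
results = adapter.run_eval(
    eval_tasks=["hellaswag", "triviaqa"],  # task names, resolved by the v0.4.0 registry
    num_fewshot=0,
    use_cache=False,  # skip the per-run sqlite response cache
    limit=None,  # None evaluates every example
)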
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -3,7 +3,7 @@ git+https://github.com/EleutherAI/DeeperSpeed.git@b9260436e7da3e297fc6bedfd27d9e
ftfy>=6.0.1
git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
huggingface_hub>=0.11.0
-lm_eval==0.3.0
+git+https://github.com/EleutherAI/lm-evaluation-harness.git@main#egg=lm_eval
mpi4py>=3.0.3
numpy>=1.22.0
pybind11>=2.6.2
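With the harness now pinned to the repository's main branch rather than a fixed PyPI release, a quick standard-library check (illustrative) confirms which version actually resolved after installation:

from importlib.metadata import version

print(version("lm_eval"))  # expected to report 0.4.x once the git pin installs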
