Merge branch 'main' into fused-rope
Quentin-Anthony authored Dec 23, 2023
2 parents 7f3ae33 + 1148a0f commit a3572c7
Showing 12 changed files with 272 additions and 152 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1 +1 @@
* @EleutherAI/pm-gptneo
* @Quentin-Anthony
39 changes: 24 additions & 15 deletions .github/workflows/coverity_scan.yml
@@ -23,29 +23,38 @@ jobs:

steps:
- uses: actions/checkout@v2
with:
path: gpt-neox

- name: Install utils
run: |
apt update -y && apt upgrade -y
apt install curl jq wget -y
sudo apt update -y && sudo apt upgrade -y
sudo apt install curl jq wget -y
- name: Coverity Download
run: |
wget https://scan.coverity.com/download/linux64 --post-data "token=$COVERITY_TOKEN&project=EleutherAI%2Fgpt-neox" -O coverity_tool.tgz
$GITHUB_WORKSPACE/bin/cov-configure --python
$GITHUB_WORKSPACE/bin/cov-configure --gcc
wget https://scan.coverity.com/download/linux64 --post-data "token=$COVERITY_TOKEN&project=$COVERITY_PROJECT" -O coverity_tool.tgz --no-verbose
mkdir $GITHUB_WORKSPACE/coverity && tar xvf coverity_tool.tgz -C $GITHUB_WORKSPACE/coverity --strip-components=1
$GITHUB_WORKSPACE/coverity/bin/cov-configure --python
$GITHUB_WORKSPACE/coverity/bin/cov-configure --gcc
- name: Coverity Scan
- name: Coverity Scan and Upload
run: |
set -x
$GITHUB_WORKSPACE/bin/cov-build --dir cov-int --no-command --fs-capture-search $GITHUB_WORKSPACE
- name: Coverity Upload
run: |
pushd $GITHUB_WORKSPACE
cd $GITHUB_WORKSPACE/gpt-neox
$GITHUB_WORKSPACE/coverity/bin/cov-build --dir $GITHUB_WORKSPACE/cov-int --no-command --fs-capture-search ./
popd
tar caf build-results.bz2 cov-int
curl --form token=$COV_PASSPHRASE \
curl --form token=$COVERITY_TOKEN \
--form email=$COV_USER \
--form file=@GITHUB_WORKSPACE/build-results.bz2 \
--form version="Version" \
--form description="Build" \
https://scan.coverity.com/builds?project=EleutherAI%2Fgpt-neox
--form file=@build-results.bz2 \
--form version="${{ inputs.build_version }}" \
--form description="${{ inputs.build_description }}" \
https://scan.coverity.com/builds?project=$COVERITY_PROJECT
- name: Upload Scan Build as Artifact
uses: actions/upload-artifact@v3
with:
name: coverity-build-${{ github.sha }}
path: build-results.bz2
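For reference, the upload step is a plain multipart POST to scan.coverity.com. Below is a hypothetical Python equivalent of the curl call above, illustrative only — the workflow itself uses curl, and only the form fields and env var names (COVERITY_TOKEN, COVERITY_PROJECT, COV_USER) are taken from the diff:

```python
# Hypothetical Python equivalent of the workflow's curl upload step.
# Env var names mirror the workflow; BUILD_VERSION/BUILD_DESCRIPTION are
# assumed stand-ins for the workflow's inputs.
import os
import requests

with open("build-results.bz2", "rb") as f:
    resp = requests.post(
        "https://scan.coverity.com/builds",
        params={"project": os.environ["COVERITY_PROJECT"]},
        data={
            "token": os.environ["COVERITY_TOKEN"],
            "email": os.environ["COV_USER"],
            "version": os.environ.get("BUILD_VERSION", "Version"),
            "description": os.environ.get("BUILD_DESCRIPTION", "Build"),
        },
        files={"file": f},
    )
resp.raise_for_status()
```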
190 changes: 105 additions & 85 deletions README.md

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 589323d
Default = 79befef

current git hash of repository

@@ -261,6 +261,14 @@ Model Arguments



- **use_qk_layernorm**: bool

Default = False

Use QK Normalization (apply layernorm to queries and keys before attention; see https://arxiv.org/abs/2302.05442)



- **layernorm_epsilon**: float

Default = 1e-05
File renamed without changes.
143 changes: 103 additions & 40 deletions eval_tasks/eval_adapter.py
@@ -13,19 +13,8 @@
# limitations under the License.

from megatron.utils import is_local_main, print_rank_0
import best_download

# patch best_download (eval harness downloader) to only happen on the first local rank
fn = best_download.download_file


def _download_file(*args, **kwargs):
if is_local_main():
fn(*args, **kwargs)


best_download.download_file = _download_file

import copy
import os
import sys
import dataclasses
@@ -38,13 +27,13 @@ def _download_file(*args, **kwargs):
import torch
import torch.nn.functional as F

from lm_eval.models.gpt2 import GPT2LM
from lm_eval import tasks, evaluator, utils, base
from lm_eval.models.huggingface import HFLM
from lm_eval import tasks, evaluator, utils, api
from megatron.text_generation_utils import generate_samples_from_prompt
from megatron import mpu


class EvalHarnessAdapter(GPT2LM):
class EvalHarnessAdapter(HFLM):
"""
An adapter to run NeoX models on LM Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness) tasks.
@@ -56,13 +45,13 @@ class EvalHarnessAdapter(GPT2LM):
"""

def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
self.cache_hook = base.CacheHook(None)
self.model = model
self.cache_hook = api.model.CacheHook(None)
self._model = model
self.neox_args = neox_args
self.tokenizer = neox_args.tokenizer
self._device = torch.device(f"cuda:{neox_args.local_rank}")
self._eot_token_id = neox_args.tokenizer.eod_id
self._max_length = neox_args.max_position_embeddings // 2
self._max_length = neox_args.max_position_embeddings
self._max_gen_toks = 128
self._vocab_size = neox_args.padded_vocab_size

@@ -94,8 +83,6 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None):
generate_samples_from_prompt,
neox_args=neox_args,
model=model,
maximum_tokens=self._max_gen_toks,
temperature=0.0,
)

@property
@@ -123,15 +110,23 @@ def batch_size(self):
def device(self):
return self._device

def tok_encode(self, string: str):
@property
def rank(self):
return 0

@property
def world_size(self):
return 1

def tok_encode(self, string: str, **kwargs):
return self.tokenizer.encode(string)

def tok_decode(self, tokens):
def tok_decode(self, tokens, **kwargs):
return self.tokenizer.decode(tokens)

def greedy_until(self, requests):
def generate_until(self, requests):
"""
Greedy until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
Generate until is lm_eval harness' way to say "do greedy generation" - necessary for some tasks.
the eval harness dispatches requests to the model, and the model does argmax generation, the results of which
are returned to the eval harness to evaluate.
@@ -143,19 +138,46 @@ def greedy_until(self, requests):
self.model.module.inference_mode(use_cache=True) # tell model to cache kv pairs
res = []

# get only the args from each Instance object
reqs = [req.args for req in requests]

def _collate(x):
toks = self.tokenizer.encode(x[0])
return (len(toks), x[0])

reord = utils.Reorderer(requests, _collate)
for context, until in tqdm(reord.get_reordered(), "Running greedy generation"):
if isinstance(until, str):
until = [until]
reord = utils.Reorderer(reqs, _collate)
for context, gen_kwargs in tqdm(reord.get_reordered(), "Running greedy generation"):
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks

if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")

stop_tokens = [self.tokenizer.encode(i) for i in until]
cont = self.generate(
text=context,
stop_tokens=stop_tokens,
recompute=self.neox_args.recompute,
maximum_tokens=max_gen_toks,
**kwargs,
)
if cont:
s = cont[0]["text"] or ""
@@ -166,7 +188,7 @@ def _collate(x):
s = s.split(term)[0]

# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), s)
self.cache_hook.add_partial("generate_until", (context, until), s)

res.append(s)
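The loop above unpacks the new harness request format: each Instance's args is a (context, gen_kwargs) pair. A minimal mock of that unpacking, with Instance stubbed by a namedtuple and a hypothetical prompt and kwargs:

```python
# Minimal mock of the (context, gen_kwargs) request shape consumed above;
# Instance is stubbed here and the prompt/kwargs are hypothetical.
from collections import namedtuple

Instance = namedtuple("Instance", ["args"])
requests = [Instance(args=("Q: 2+2=\nA:", {"until": ["\n"], "max_gen_toks": 32}))]

for context, gen_kwargs in (req.args for req in requests):
    kwargs = dict(gen_kwargs)
    until = kwargs.pop("until", [])           # stop strings
    max_gen_toks = kwargs.pop("max_gen_toks", 128)
    kwargs.pop("do_sample", None)             # greedy generation ignores this
    print(repr(context), until, max_gen_toks)
```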

@@ -366,7 +388,6 @@ def run_eval(
eval_tasks=None,
num_fewshot=0,
bootstrap_iters=2,
description_dict=None,
use_cache=True,
name="neox",
limit=None,
@@ -385,8 +406,12 @@
"winogrande",
"mathqa",
"pubmedqa",
"triviaqa"
]

# register all the default tasks bundled with lm-evaluation-harness repository
tasks.initialize_tasks()

# Returns a list containing all values of the task registry that
# match at least one of the patterns
import fnmatch
@@ -401,6 +426,8 @@ def pattern_match(patterns, source_list):
eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS)
print(f"Found tasks: {eval_tasks}")

assert len(eval_tasks) > 0, "Must run at least one task"
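The pattern_match helper (its body is elided in this diff) glob-matches requested task names against the registry. A plausible sketch consistent with the fnmatch import and the call above, using hypothetical task names:

```python
# Plausible sketch of pattern_match, consistent with the fnmatch import
# above; the task names below are hypothetical.
import fnmatch

def pattern_match(patterns, source_list):
    task_names = set()
    for pattern in patterns:
        task_names.update(fnmatch.filter(source_list, pattern))
    return sorted(task_names)

all_tasks = ["lambada_openai", "lambada_standard", "hellaswag"]
print(pattern_match(["lambada_*", "hellaswag"], all_tasks))
```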

# **HACK INCOMING**:
# first get task dict on local main rank
# the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading.
@@ -413,31 +440,67 @@ def pattern_match(patterns, source_list):
task_dict = tasks.get_task_dict(eval_tasks)

lm = self

if use_cache:
# TODO(jon-tow): Append a subset of `neox_args` to the cache database
# name arg to distinguish model runs that use different configurations.
lm = base.CachingLM(lm, "lm_cache/" + name + ".db")
use_cache = 'lm_cache/neox' + '_dp_rank' + str(self._dp_rank) + '_dp_group' + str(self._dp_group) + '.db'
print(f"Using cache at {use_cache}...")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
# TODO: Append a subset of `neox_args` to the cache database
# name arg to distinguish model runs that use different configurations.
)
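Each data-parallel rank gets its own SQLite cache file so that concurrent ranks never write to the same database. A standalone sketch of the naming scheme; in the real adapter, dp_rank and dp_group come from megatron's mpu state:

```python
# Standalone sketch of the per-rank cache naming used above; dp_rank and
# dp_group would be derived from mpu in the real adapter.
def cache_db_path(dp_rank: int, dp_group: int, base: str = "lm_cache/neox") -> str:
    # One db per (rank, group) pair avoids concurrent writes to a single file.
    return f"{base}_dp_rank{dp_rank}_dp_group{dp_group}.db"

assert cache_db_path(0, 1) == "lm_cache/neox_dp_rank0_dp_group1.db"
```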

# from simple_evaluate:
# override fewshot values for all tasks we can
for task_name in task_dict.keys():
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
group, task_obj = task_obj
if task_obj is None:
continue

config = task_obj._config

if num_fewshot is not None:
if config["num_fewshot"] == 0:
utils.eval_logger.info(
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
)
else:
default_num_fewshot = config["num_fewshot"]
if not default_num_fewshot:
utils.eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)

task_obj._config["num_fewshot"] = num_fewshot

results = evaluator.evaluate(
lm=lm,
task_dict=tasks.get_task_dict(eval_tasks),
description_dict=description_dict,
num_fewshot=num_fewshot,
limit=limit,
task_dict=task_dict,
limit=limit,
bootstrap_iters=bootstrap_iters,
log_samples=False,
)

results["config"] = {
"model": name,
"model_args": dataclasses.asdict(self.neox_args),
"num_fewshot": num_fewshot,
"batch_size": self.batch_size,
"device": str(self.device),
"no_cache": not use_cache,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"description_dict": description_dict,
}
results["git_hash"] = utils.get_git_commit_hash()

print(results.keys())
for task_name in task_dict.keys():
if "alias" in results["results"][task_name]:
results["results"][task_name].pop("alias")

if was_training:
self.model.train()
6 changes: 2 additions & 4 deletions megatron/logging.py
@@ -92,17 +92,15 @@ def get_flops(neox_args, iter_time_s) -> float:
hidden_size = neox_args.hidden_size
num_layers = neox_args.num_layers
ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
flops_calc1 = (
flops_per_iteration = (
24
* ckpt_activations_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (1.0 + (seq_len / (6.0 * hidden_size)))
* (1.0 + (seq_len / (6.0 * hidden_size)) + (vocab_size / (16.0 * num_layers * hidden_size)))
)
flops_calc2 = vocab_size / (16.0 * num_layers * hidden_size)
flops_per_iteration = flops_calc1 + flops_calc2
return flops_per_iteration / (iter_time_s * world_size)


15 changes: 15 additions & 0 deletions megatron/model/transformer.py
@@ -288,6 +288,16 @@ def __init__(
neox_args.num_attention_heads, world_size
)
self.pos_emb = neox_args.pos_emb
self.use_qk_layernorm = neox_args.use_qk_layernorm
if self.use_qk_layernorm:
norm, eps = get_norm(neox_args)
self.qk_layernorm = norm(
[
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
],
eps=eps,
)

# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
@@ -644,6 +654,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
mixed_x_layer, 3
)

# QK Normalization https://arxiv.org/abs/2302.05442
if self.use_qk_layernorm:
query_layer = self.qk_layernorm(query_layer)
key_layer = self.qk_layernorm(key_layer)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
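For context, a minimal self-contained sketch of QK-Norm (arXiv:2302.05442) as applied above: layernorm over the per-head query/key projections before the attention scores are formed. Shapes are hypothetical, and plain torch.nn.LayerNorm stands in for the module built by get_norm(neox_args):

```python
# Minimal QK-Norm sketch; shapes are hypothetical and nn.LayerNorm stands in
# for the norm class returned by get_norm(neox_args).
import torch
import torch.nn as nn

heads, head_dim, seq, batch = 8, 64, 16, 2
qk_layernorm = nn.LayerNorm([heads, head_dim])

# [seq, batch, heads, head_dim], matching megatron's attention layout
query = torch.randn(seq, batch, heads, head_dim)
key = torch.randn(seq, batch, heads, head_dim)

# Normalizing q and k bounds the scale of q @ k^T, stabilizing attention
# logits in large models.
query, key = qk_layernorm(query), qk_layernorm(key)
```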