Commit
Merge branch 'main' into 1003-rope-casting
dashstander committed Sep 25, 2023
2 parents 89fc788 + e431ff5 commit 7db3cdf
Showing 6 changed files with 161 additions and 5 deletions.
README.md: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
* Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, logging via [WandB](https://wandb.ai/site), and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

## News
**[8/10/2023]** We now support checkpointing with AWS S3! Activate with the `s3_path` config option (for more detail, see [the PR](https://github.com/EleutherAI/gpt-neox/pull/1010))
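For example, enabling this in a NeoX YAML config might look like `"s3_path": "s3://my-bucket/neox-checkpoints"` (the bucket name here is a hypothetical placeholder), optionally alongside `"s3_chunk_size"`; both options are documented in `configs/neox_arguments.md` below.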

**[9/20/2023]** As of https://github.com/EleutherAI/gpt-neox/pull/1035, we have deprecated Flash Attention 0.x and 1.x and migrated support to Flash Attention 2.x. We don't believe this will cause problems, but if you have a specific use case that requires the old Flash Attention versions with the latest GPT-NeoX, please raise an issue.

**[8/10/2023]** We have experimental support for LLaMA 2 and Flash Attention v2 supported in our [math-lm](https://github.com/EleutherAI/math-lm) project that will be upstreamed later this month.
configs/neox_arguments.md: 17 additions & 1 deletion
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 1d20559
Default = 1f832c1

current git hash of repository

@@ -1169,6 +1169,22 @@ Training Arguments



- **s3_path**: str

Default = None

Path to s3 bucket for saving checkpoints.



- **s3_chunk_size**: int

Default = 104857600

The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
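Note that the chunk size bounds the largest file that can be uploaded: S3 multipart uploads allow at most 10,000 parts, so the default 100 MiB chunk covers individual checkpoint files up to roughly 1 TB; larger shards would require a larger `s3_chunk_size`.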



- **config_files**: dict

Default = None
megatron/checkpointing.py: 126 additions & 2 deletions
@@ -18,13 +18,23 @@
"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
    import boto3
except ModuleNotFoundError:
    print("For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3")
try:
    import hf_transfer
except ModuleNotFoundError:
    print("For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer")
import torch
from glob import glob
from pathlib import Path  # Path is used by upload_checkpoint below

@@ -137,6 +147,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None
    )


def get_checkpoint_tag(iteration: int) -> str:
    return f"global_step{iteration}"
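For example, iteration 5000 yields the tag `global_step5000`, which becomes the checkpoint folder name and matches the `global_step[\d]*` pattern used by `delete_old_checkpoints` below.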


def delete_old_checkpoints(save_dir, n_to_keep):
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
@@ -188,7 +202,7 @@ def save_ds_checkpoint(iteration, model, neox_args):
sd["checkpoint_validation_logits"] = logits

# checkpoint folder name
tag = f"global_step{iteration}"
tag = get_checkpoint_tag(iteration)

# save checkpoint
model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)
@@ -203,6 +217,111 @@ def save_ds_checkpoint(iteration, model, neox_args):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)


def multiprocessing_starmap(func, args, num_processes=None):
    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
    Args:
        func (Callable): Function to call
        args (Iterable): Iterable of arguments to pass to `func`
        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
    """
    import multiprocessing
    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
    with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool:
        process_pool.starmap(func, args)
        process_pool.terminate()
        process_pool.join()
        del process_pool
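As a quick illustration (not part of this diff), a minimal sketch of how the helper above is meant to be called, with a throwaway function and hypothetical arguments; note the `__main__` guard that the `spawn` context requires:

```python
# Illustrative sketch only: each tuple in the iterable is unpacked into the
# function's positional arguments, mirroring how upload_checkpoint() below
# passes (local_path, remote_key, chunk_size) triples.
from megatron.checkpointing import multiprocessing_starmap  # assumes gpt-neox is on the path

def _square(label, value):
    print(f"{label}: {value ** 2}")

if __name__ == "__main__":
    # the spawn context re-imports the calling module, so guard the entry point
    multiprocessing_starmap(_square, [("a", 2), ("b", 3)], num_processes=2)
```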


def _upload(
    file_path: str,
    s3_key: str,
    chunk_size: int = 104_857_600,
    max_files: int = 64,
    parallel_failures: int = 63,
    max_retries: int = 5,
):
    """Upload local file to S3 using `hf_transfer` library
    Args:
        file_path (str): Local filename to upload
        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
        chunk_size (int, optional): Chunk size to use for multipart upload.
            Defaults to 100MiB = 104_857_600
        max_files (int, optional): Number of open file handles, which determines
            the maximum number of parallel downloads. Defaults to 64
        parallel_failures (int, optional): Number of maximum failures of different
            chunks in parallel (cannot exceed max_files). Defaults to 63
        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
    """
    s3 = boto3.client('s3')
    bucket = s3_key.split("s3://")[1].split("/")[0]
    key = s3_key.split(bucket)[1].lstrip("/")

    # 1. Init multipart upload and obtain unique upload identifier
    upload = s3.create_multipart_upload(
        ACL="bucket-owner-full-control",
        Bucket=bucket,
        Key=key,
    )
    upload_id = upload["UploadId"]

    # 2. Generate presigned URLs for each part
    file_size = os.stat(file_path).st_size
    urls = []
    nb_parts = math.ceil(file_size / chunk_size)
    for part_number in range(1, nb_parts + 1):
        params = {
            "Bucket": bucket,
            "Key": key,
            "PartNumber": part_number,
            "UploadId": upload_id,
        }
        urls.append(
            s3.generate_presigned_url(
                ClientMethod="upload_part", Params=params, ExpiresIn=86400
            )
        )

    # 3. Upload parts in parallel
    responses = hf_transfer.multipart_upload(
        file_path=file_path,
        parts_urls=urls,
        chunk_size=chunk_size,
        max_files=max_files,
        parallel_failures=parallel_failures,
        max_retries=max_retries,
    )

    # 4. Complete multipart upload request with ETag values
    etag_with_parts = []
    for part_number, header in enumerate(responses):
        etag = header.get("etag")
        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
    parts = {"Parts": etag_with_parts}
    s3.complete_multipart_upload(
        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
    )
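To make the bucket/key parsing and the part arithmetic above concrete, a small sketch with hypothetical values (no AWS calls are made):

```python
import math

# Hypothetical destination for one checkpoint shard, mirroring the parsing in _upload()
s3_key = "s3://my-bucket/ckpts/neox/global_step1000/mp_rank_00_model_states.pt"
bucket = s3_key.split("s3://")[1].split("/")[0]  # "my-bucket"
key = s3_key.split(bucket)[1].lstrip("/")        # "ckpts/neox/global_step1000/mp_rank_00_model_states.pt"

chunk_size = 104_857_600                      # default s3_chunk_size, 100 MiB
file_size = 2 * 1024**3                       # pretend the shard is 2 GiB on disk
nb_parts = math.ceil(file_size / chunk_size)  # -> 21 presigned part-upload URLs
print(bucket, key, nb_parts)
```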


def upload_checkpoint(iteration, neox_args):
    local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration))
    local_checkpoint_list = sorted(filter(
        lambda x: os.path.isfile(x),
        [str(p) for p in Path(local_checkpoint_path).rglob("*")],
    ))
    remote_checkpoint_path = os.path.join(
        neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration))
    remote_checkpoint_list = [
        os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path))
        for local_checkpoint in local_checkpoint_list
    ]
    inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list))

    print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...")
    start = time.time()
    multiprocessing_starmap(_upload, inputs)
    total_time = time.time() - start
    print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s")


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
@@ -213,6 +332,11 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # make sure all ranks have finished writing before rank 0 uploads to S3
    torch.distributed.barrier()
    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
    if upload_to_s3:
        upload_checkpoint(iteration, neox_args)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
@@ -233,7 +357,7 @@ def load_checkpoint(
    if neox_args.finetune:
        load_optim_and_scheduler = False
    if iteration is not None:
        tag = f"global_step{iteration}"
        tag = get_checkpoint_tag(iteration)
    else:
        tag = None
    checkpoint_name, state_dict = model.load_checkpoint(
megatron/learning_rates.py: 4 additions & 2 deletions
@@ -71,7 +71,8 @@ def get_lr(self):

        num_iters_ = num_iters_ - self.warmup_iter
        if self.decay_style == "linear":
            lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
            end_iter_ = self.end_iter - self.warmup_iter
            lr = self.start_lr * (end_iter_ - num_iters_) / end_iter_
        elif self.decay_style == "cosine":
            end_iter_ = self.end_iter - self.warmup_iter
            lr = self.min_lr + (
@@ -81,7 +82,8 @@
            )
        elif self.decay_style == "exponential":
            # exp(-0.693) = 1/2
            lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
            end_iter = self.end_iter - self.warmup_iter
            lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter)
        else:
            lr = self.start_lr
        return max(lr, self.min_lr)
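The cosine branch already measured decay progress from the end of warmup; this change makes the linear and exponential branches do the same. A quick check of the linear branch with hypothetical hyperparameters shows why it matters:

```python
# Hypothetical schedule: 10,000 total iterations, 1,000 warmup, peak LR 1e-4
start_lr, warmup_iter, end_iter = 1e-4, 1_000, 10_000
num_iters_ = end_iter - warmup_iter  # value of num_iters_ at the final iteration

old_lr = start_lr * (end_iter - num_iters_) / end_iter    # 1e-5: never fully anneals
end_iter_ = end_iter - warmup_iter
new_lr = start_lr * (end_iter_ - num_iters_) / end_iter_  # 0.0, then clamped up to min_lr
print(old_lr, new_lr)
```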
megatron/neox_arguments/neox_args.py: 10 additions & 0 deletions
@@ -793,6 +793,16 @@ class NeoXArgsTraining(NeoXArgsTemplate):
    Output directory to save checkpoints to.
    """

    s3_path: str = None
    """
    Path to s3 bucket for saving checkpoints.
    """

    s3_chunk_size: int = 104_857_600
    """
    The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
    """

    config_files: dict = None
    """
    Store of original config files mapping config filename to file contents
requirements/requirements-s3.txt: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
hf-transfer>=0.1.3
boto3
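These are the optional dependencies referenced by the try/except import hints in `megatron/checkpointing.py` above; installing them (for example via `pip install -r requirements/requirements-s3.txt`) enables the S3 upload path.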
