Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add s3 checkpoint syncing #1010

Merged
merged 10 commits into from
Sep 23, 2023
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
* Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, logging via [WandB](https://wandb.ai/site), and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

## News
**[8/10/2023]** We now support checkpointing with AWS S3! Activate with the `s3_path` config option (for more detail, see [the PR](https://github.com/EleutherAI/gpt-neox/pull/1010))

**[9/20/2023]** As of https://github.com/EleutherAI/gpt-neox/pull/1035, we have deprecated Flash Attention 0.x and 1.x, and migrated support to Flash Attention 2.x. We don't believe this will cause problems, but if you have a specific use-case that requires old flash support using the latest GPT-NeoX, please raise an issue.

**[8/10/2023]** We have experimental support for LLaMA 2 and Flash Attention v2 supported in our [math-lm](https://github.com/EleutherAI/math-lm) project that will be upstreamed later this month.
Expand Down
18 changes: 17 additions & 1 deletion configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 1d20559
Default = aa6c176

current git hash of repository

Expand Down Expand Up @@ -1169,6 +1169,22 @@ Training Arguments



- **s3_path**: str

Default = None

Path to s3 bucket for saving checkpoints.



- **s3_chunk_size**: int

Default = 104857600

The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.



- **config_files**: dict

Default = None
Expand Down
128 changes: 126 additions & 2 deletions megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@
"""Input/output checkpointing."""

import json
import math
import os
import re
import shutil
import time
import random
import sys
import numpy as np

try:
import boto3
except ModuleNotFoundError:
print("For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3")
try:
import hf_transfer
except ModuleNotFoundError:
print("For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer")
import torch
from glob import glob

Expand Down Expand Up @@ -137,6 +147,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None
)


def get_checkpoint_tag(iteration: int) -> str:
return f"global_step{iteration}"


def delete_old_checkpoints(save_dir, n_to_keep):
if torch.distributed.get_rank() == 0:
ckpt_dir_regex = r"global_step[\d]*"
Expand Down Expand Up @@ -188,7 +202,7 @@ def save_ds_checkpoint(iteration, model, neox_args):
sd["checkpoint_validation_logits"] = logits

# checkpoint folder name
tag = f"global_step{iteration}"
tag = get_checkpoint_tag(iteration)

# save checkpoint
model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)
Expand All @@ -203,6 +217,111 @@ def save_ds_checkpoint(iteration, model, neox_args):
f.write(config_data)
else:
json.dump(config_data, f)
def multiprocessing_starmap(func, args, num_processes=None):
"""Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
Args:
func (Callable): Function to call
args (Iterable): Iterable of arguments to pass to `func`
num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
"""
import multiprocessing
num_processes = num_processes or (multiprocessing.cpu_count() - 1)
with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool:
process_pool.starmap(func, args)
process_pool.terminate()
process_pool.join()
del process_pool


def _upload(
file_path: str,
s3_key: str,
chunk_size: int = 104_857_600,
max_files: int = 64,
parallel_failures: int = 63,
max_retries: int = 5,
):
"""Upload local file to S3 using `hf_transfer` library
Args:
file_path (str): Local filename to upload
s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
chunk_size (int, optional): Chunk size to use for multipart upload.
Defaults to 100MiB = 104_857_600
max_files (int, optional): Number of open file handles, which determines
the maximum number of parallel downloads. Defaults to 64
parallel_failures (int, optional): Number of maximum failures of different
chunks in parallel (cannot exceed max_files). Defaults to 63
max_retries (int, optional): Number of retries for each chunk. Defaults to 5
"""
s3 = boto3.client('s3')
bucket = s3_key.split("s3://")[1].split("/")[0]
key = s3_key.split(bucket)[1].lstrip("/")

# 1. Init multipart upload and obtain unique upload identifier
upload = s3.create_multipart_upload(
ACL="bucket-owner-full-control",
Bucket=bucket,
Key=key,
)
upload_id = upload["UploadId"]

# 2. Generate presigned URLs for each part
file_size = os.stat(file_path).st_size
urls = []
nb_parts = math.ceil(file_size / chunk_size)
for part_number in range(1, nb_parts + 1):
params = {
"Bucket": bucket,
"Key": key,
"PartNumber": part_number,
"UploadId": upload_id,
}
urls.append(
s3.generate_presigned_url(
ClientMethod="upload_part", Params=params, ExpiresIn=86400
)
)

# 3. Upload parts in parallel
responses = hf_transfer.multipart_upload(
file_path=file_path,
parts_urls=urls,
chunk_size=chunk_size,
max_files=max_files,
parallel_failures=parallel_failures,
max_retries=max_retries,
)

# 4. Complete multipart upload request with ETag values
etag_with_parts = []
for part_number, header in enumerate(responses):
etag = header.get("etag")
etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
parts = {"Parts": etag_with_parts}
s3.complete_multipart_upload(
Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
)


def upload_checkpoint(iteration, neox_args):
local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration))
local_checkpoint_list = sorted(filter(
lambda x: os.path.isfile(x),
[str(p) for p in Path(local_checkpoint_path).rglob("*")],
))
remote_checkpoint_path = os.path.join(
neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration))
remote_checkpoint_list = [
os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path))
for local_checkpoint in local_checkpoint_list
]
inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list))

print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...")
start = time.time()
multiprocessing_starmap(_upload, inputs)
total_time = time.time() - start
print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s")


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
Expand All @@ -213,6 +332,11 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
else:
raise ValueError("Must be using deepspeed to use neox")

torch.distributed.barrier()
upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
if upload_to_s3:
upload_checkpoint(iteration, neox_args)

# Wait so everyone is done (necessary)
torch.distributed.barrier()
if neox_args.keep_last_n_checkpoints is not None:
Expand All @@ -233,7 +357,7 @@ def load_checkpoint(
if neox_args.finetune:
load_optim_and_scheduler = False
if iteration is not None:
tag = f"global_step{iteration}"
tag = get_checkpoint_tag(iteration)
else:
tag = None
checkpoint_name, state_dict = model.load_checkpoint(
Expand Down
10 changes: 10 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,16 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Output directory to save checkpoints to.
"""

s3_path: str = None
"""
Path to s3 bucket for saving checkpoints.
"""

s3_chunk_size: int = 104_857_600
"""
The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
"""

config_files: dict = None
"""
Store of original config files mapping config filename to file contents
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements-s3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
hf-transfer>=0.1.3
boto3