Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DTensor&DModule&DDP&Examples] feature updates and new examples #35

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
/**/*.tar.gz
/**/*.json.gz
/**/*.log
/**/.DS_Store
*_checkpoint_dir

# pre-commit config
./.pre-commit-config.yaml
Expand Down
32 changes: 32 additions & 0 deletions examples/llama2_4D_finetune/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Finetune a Llama2 3b model in 4D parallelism using veScale

## Overview

Finetune a pretrained llama2_3b model on a small Shakespeare dataset.
Dropout is set to 0 for this model, thus no randomness is involved during finetuning.
The reason for choosing llama2_3b instead of the 7b one is that it fits in 1 GPU so that we can check the correctness of veScale.

## Prerequisite

```
pip3 install sentencepiece
```

## Run

```
cd data/shakespeare/ && python3 prepare.py && cd ../..
torchrun --standalone --nproc_per_node={GPU_CNT} llama_train.py --dp={dp_size} --tp={tp_size} --max_iters={max_iters}
```

## Experiments

Like nanoGPT, we finetune the model with a constant learning rate `3e-5` and set `grad_clip = 1`.
The model state as well as the gradients and the optimizer states are in `bf16`.

![](./figures/llama2_3b_train_losses.jpg)


## Caveats

1. Currently, it does not works with `transformers==4.38.2`. The error happens when doing a backward step, the `aten._scaled_dot_product_efficient_attention` operator outputs the error message: `attn_bias: wrong shape (head dimension)`.
57 changes: 57 additions & 0 deletions examples/llama2_4D_finetune/data/shakespeare/prepare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
################################################################################
# Copyright (c) 2022 Andrej Karpathy

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
################################################################################
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates.
################################################################################

import os
import requests
import numpy as np
from transformers import LlamaTokenizer

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), "input.txt")
if not os.path.exists(input_file_path):
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w", encoding="utf-8") as f:
f.write(requests.get(data_url).text)

with open(input_file_path, encoding="utf-8") as f:
data = f.read()
n = len(data)
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

# tokenize with llama2 tokenizer
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b")
train_ids = tokenizer.encode(train_data)
val_ids = tokenizer.encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), "train.bin"))
val_ids.tofile(os.path.join(os.path.dirname(__file__), "val.bin"))

# train.bin has 318,905 tokens
# val.bin has 37,782 tokens
9 changes: 9 additions & 0 deletions examples/llama2_4D_finetune/data/shakespeare/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

# tiny shakespeare

Tiny shakespeare, of the good old char-rnn fame :)

After running `prepare.py`:

- train.bin has 318,905 tokens
- val.bin has 37,782 tokens
64 changes: 64 additions & 0 deletions examples/llama2_4D_finetune/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
from typing import Optional

import numpy as np
import torch

from vescale.dtensor.device_mesh import DeviceMesh
from vescale import distribute_tensor
from vescale.dtensor.placement_types import Replicate
from vescale.dtensor import empty as d_empty


class DataLoader:
def __init__(self, dataset: str, seqlen: int, mesh: Optional[DeviceMesh] = None, dp_rank: int = 0):
self.data_dir = os.path.join("data", dataset)
self.seqlen = seqlen
self.mesh = mesh
self.dp_rank = dp_rank
if mesh is not None:
self.device_type = mesh.device_type
else:
self.device_type = "cuda"

def get_batch(self, split, bsz, lbsz):
# We recreate np.memmap every batch to avoid a memory leak, as per
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
if split == "train":
data = np.memmap(os.path.join(self.data_dir, "train.bin"), dtype=np.uint16, mode="r")
else:
data = np.memmap(os.path.join(self.data_dir, "val.bin"), dtype=np.uint16, mode="r")
if self.mesh is not None:
ix = d_empty((bsz,), device_mesh=self.mesh, placements=[Replicate()])
else:
ix = torch.empty((bsz,), device="cuda")
ix = torch.randint_like(ix, len(data) - self.seqlen, dtype=torch.int64)
if self.mesh is not None:
ix = ix.to_local()
if self.mesh is None or self.mesh.get_rank() == 0:
print(f"sum(ix) {sum(ix)}")
ix = torch.split(ix, lbsz)[self.dp_rank]
x = torch.stack([torch.from_numpy((data[i : i + self.seqlen]).astype(np.int64)) for i in ix])
y = torch.stack([torch.from_numpy((data[i + 1 : i + 1 + self.seqlen]).astype(np.int64)) for i in ix])
x, y = x.to(self.device_type), y.to(self.device_type)
if self.mesh is not None:
x = distribute_tensor(x, self.mesh["TP"], [Replicate()])
y = distribute_tensor(y, self.mesh["TP"], [Replicate()])
return x, y
98 changes: 98 additions & 0 deletions examples/llama2_4D_finetune/exp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import os
import re


def parse_train_loss(log_fn, name=None):
lines = open(log_fn).readlines()
train_losses = []
for line in lines:
if "loss" in line and "iter" in line:
token = line.split()[line.split().index("loss") + 1]
train_loss = float(token)
train_losses.append(train_loss)
if name is None:
name = log_fn
print(f'"{name}": {train_losses},')


def parse(log_fn, name=None):
lines = open(log_fn).readlines()
val_losses = []
for line in lines:
if "val_loss" in line:
token = line.split()[line.split().index("val_loss:") + 1]
val_loss = float(token)
val_losses.append(val_loss)
if name is None:
name = log_fn
print(f'"{name}": {val_losses},')


GPU_CNT = 4
DP_SIZES = [1, 2, 4]
# DP_SIZES = [4]
SINGLE_GPU_RUN = "python3"
MULTI_GPU_RUN = f"torchrun --standalone --nproc_per_node={GPU_CNT}"
CODE = "llama_train.py"
LOG_PREFIX = "llama2"
TRAIN_BIN_PATH = "data/shakespeare/train.bin"


def run_exps(max_iters, dtypes, run=True):
if not os.path.isfile(TRAIN_BIN_PATH):
os.system(f"cd data/shakespeare/ && python3 prepare.py && cd ../..")
os.makedirs("logs", exist_ok=True)
if run:
for dtype in dtypes:
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"{SINGLE_GPU_RUN} {CODE} --dp=1 --tp=1 --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log"
# print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
# os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"{MULTI_GPU_RUN} {CODE} --dp={dp_size} --tp={tp_size} --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")

print("train_loss = {")
for dtype in dtypes:
parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse_train_loss(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}")
print("}")

# print("val_loss = {")
# for dtype in dtypes:
# # parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
# for dp_size in DP_SIZES:
# tp_size = GPU_CNT // dp_size
# log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
# parse(log_fn, f"{GPU_CNT}GPU_DP{dp_size}_TP{tp_size}_{dtype}")
# print("}")


if __name__ == "__main__":
run_exps(100000, ["bf16"], run=True)
# run_exps(10, ["bf16"], run=False)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading