[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) #1201

Merged (20 commits) on Nov 22, 2023

Changes from all commits

Commits
21a5fcf
update requirements.txt
lessw2020 Nov 15, 2023
f962b60
add torchrun support, move to init_device_mesh
lessw2020 Nov 15, 2023
bc3c1dd
update twod fully working
lessw2020 Nov 16, 2023
11a3bb2
ensure proper dp group seeding for synth data
lessw2020 Nov 16, 2023
9cebdf0
swiglu model added
lessw2020 Nov 16, 2023
2447883
sequential running of custom, auto, seq parallel models
lessw2020 Nov 16, 2023
a388c20
streamline to 2D TP only for two_d_parallel example
lessw2020 Nov 17, 2023
842c3f0
sequence parallel working...needs init_device_mesh update
lessw2020 Nov 18, 2023
3aa1c53
seq parallel now using init_device_mesh
lessw2020 Nov 21, 2023
b54e2ec
tp and sp examples all working and updated
lessw2020 Nov 21, 2023
4889e3b
updates from code review
lessw2020 Nov 21, 2023
b215178
remove utils.py. Sample models created in example files
lessw2020 Nov 22, 2023
242c328
remove originals.py, leftover imports, various updates from code revi…
lessw2020 Nov 22, 2023
2f4a083
code linting via ruff
lessw2020 Nov 22, 2023
742966b
code formatting via ruff
lessw2020 Nov 22, 2023
7da71bc
move rank_log to utils.py, update example files
lessw2020 Nov 22, 2023
836f798
move logging imports and config to log_utils, update examples with ne…
lessw2020 Nov 22, 2023
2de0144
add gpu verification, update run_python_examples.sh
lessw2020 Nov 22, 2023
77fe3d8
update min gpu = 4 for fsdp+tp
lessw2020 Nov 22, 2023
5f4a5d3
move gpu check to top of examples, but before import init_device_mesh…
lessw2020 Nov 22, 2023
170 changes: 170 additions & 0 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -0,0 +1,170 @@
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh


"""
This script demonstrates 2D parallelism, which combines Tensor/Sequence
Parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an end-to-end working flow of forward pass,
backward pass, and optimization.

We enable Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
Data Parallel ("dp") across hosts
Tensor Parallel ("tp") within each host

The diagram below illustrates the setup:

======================================================================
------------ ------------ ------------ ------------
| Host 1 | | Host 2 | | | | Host N |
| 8 GPUs | | 8 GPUs | | | | 8 GPUs |
| | | | | ... | | |
| (TP) | | (TP) | | | | (TP) |
|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7|
| | | | | | | .., 8N-1]|
| | | | | | | |
------------ ------------ ------------ ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================

More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def find_multiple(n: int, k: int) -> int:
"""function to find resizing multiple for SwiGLU MLP"""
if n % k == 0:
return n
return n + k - (n % k)


class MLP_swiglu(nn.Module):
"""SwiGLU to showcase a Llama style MLP model"""

def __init__(self, mlp_dim: int = 1024) -> None:
super().__init__()
hidden_dim = 4 * mlp_dim
scaled_hidden = int(2 * hidden_dim / 3)
rounded_hidden = find_multiple(scaled_hidden, 256)

self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.silu(self.in_proj(x)) * self.gate_proj(x)
x = self.out_proj(x)
return x


"""
Main body of the demo: 2D parallelism (FSDP + TP) using
PyTorch native APIs.
"""
tp_size = 2
logger = get_logger()

# understand world topology
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])


print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
assert (
_world_size % tp_size == 0
), f"World size {_world_size} needs to be divisible by TP size {tp_size}"


# compute the data parallel size from the given world_size and tp_size.
dp_size = _world_size // tp_size

# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension
# Second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]

# To support identical inputs for TP groups, we need the dp process group
dp_pg = device_mesh.get_dim_groups()[0]

# For TP, the input needs to be the same across all TP ranks,
# while for SP, the input can be different across all ranks.
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader.
dp_rank = dist.get_rank(dp_pg)


# create the model and move it to the GPU for this rank
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(
module=base_model_swiglu,
device_mesh=tp_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(),
"gate_proj": ColwiseParallel(),
"out_proj": RowwiseParallel(),
},
)

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

# Training loop:
# Perform a number of iterations of forward/backward passes
# and optimizer steps for the sharded module.
rank_log(_rank, logger, "\nStarting 2D training...")
num_iterations = 10
batch_size = 2

for i in range(num_iterations):
# seeding with dp_rank to ensure identical inputs for TP groups
torch.manual_seed(i + dp_rank)
inp = torch.rand(batch_size, _mlp_dim, device="cuda")

output = sharded_model(inp)
output.sum().backward()
optimizer.step()
rank_log(_rank, logger, f"2D iter {i} complete")

rank_log(_rank, logger, "2D training successfully completed!")
22 changes: 22 additions & 0 deletions distributed/tensor_parallelism/log_utils.py
@@ -0,0 +1,22 @@
import logging
import torch

logging.basicConfig(
format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)

def get_logger():
return logging.getLogger(__name__)


def rank_log(_rank, logger, msg):
"""helper function to log only on global rank 0"""
if _rank == 0:
logger.info(f" {msg}")


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
""" verification that we have at least 2 gpus to run dist examples """
has_cuda = torch.cuda.is_available()
gpu_count = torch.cuda.device_count()
return has_cuda and gpu_count >= min_gpus
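
A brief usage sketch (not part of the PR) of these helpers, assuming the caller is launched with torchrun so that the RANK environment variable is set:

import os

from log_utils import get_logger, rank_log, verify_min_gpu_count

if verify_min_gpu_count(min_gpus=2):
    logger = get_logger()
    _rank = int(os.environ["RANK"])          # populated by torchrun
    rank_log(_rank, logger, "only global rank 0 logs this message")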
6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/requirements.txt
@@ -1,6 +1,6 @@
# Python dependencies required for running the example

--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cu113
--extra-index-url https://download.pytorch.org/whl/nightly/cu116
torch >= 1.14.0.dev0; sys_platform == "linux"
--extra-index-url https://download.pytorch.org/whl/nightly/cu118
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
torch >= 2.2.0.dev0; sys_platform == "linux"
13 changes: 13 additions & 0 deletions distributed/tensor_parallelism/run_example.sh
@@ -0,0 +1,13 @@

# To run samples:
# bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to launch. Default = 'fsdp_tp_example.py'
# num_gpus = num local gpus to use (must be at least 2). Default = 4

# samples to run include:
# sequence_parallel_example.py
# tensor_parallel_example.py
# fsdp_tp_example.py

echo "Launching ${1:-fsdp_tp_example.py} with ${2:-4} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-fsdp_tp_example.py}
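
For reference, a hedged sketch (not part of the PR) of the same launch done from Python; it assumes torchrun's underlying entry point, torch.distributed.run.main, accepts the flags that run_example.sh passes on the command line.

# Hypothetical programmatic equivalent of run_example.sh, assuming 4 local GPUs
# and fsdp_tp_example.py in the working directory; torchrun is the console
# entry point for torch.distributed.run.
import torch.distributed.run as distrun

distrun.main([
    "--nnodes=1",
    "--nproc_per_node=4",
    "--rdzv_id=101",
    "--rdzv_endpoint=localhost:5972",
    "fsdp_tp_example.py",
])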
148 changes: 87 additions & 61 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,19 +1,30 @@
import argparse

import os
import sys
import torch
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed._tensor import Shard

from torch.distributed.tensor.parallel import (
parallelize_module,
ColwiseParallel,
RowwiseParallel,
)

from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 2

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
sys.exit()
# ---------------------------

from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from utils import cleanup, setup, ToyModel

try:
from torch.distributed.tensor.parallel import (
SequenceParallel
)
SP_AVAILABLE = True
except BaseException as e:
pass
from torch.distributed._tensor.device_mesh import init_device_mesh



"""
@@ -33,51 +44,66 @@
"""


def demo_sp(rank, args):
"""
Main body of the demo of a basic version of sequence parallel by using
PyTorch native APIs.
"""
print(f"Running SP example on rank {rank}.")
setup(rank, args.world_size)

# create a sharding plan based on the given world_size.
device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

# create model and move it to GPU with id rank
model = ToyModel().cuda(rank)
# Create a optimizer for the parallelized module.
LR = 0.25
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Parallelize the module based on the given Parallel Style.
model = parallelize_module(model, device_mesh, SequenceParallel())

# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
for _ in range(args.iter_nums):
# For SP, input can be different across all ranks.
inp = torch.rand(20, 10).cuda(rank)
output = model(inp)
output.sum().backward()
optimizer.step()

cleanup()


if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
parser = argparse.ArgumentParser()
# This is passed in via cmd
parser.add_argument("--world_size", type=int, default=n_gpus)
parser.add_argument("--iter_nums", type=int, default=10)
args = parser.parse_args()
# The main entry point is called directly without using subprocess
if n_gpus < 2:
print("Requires at least 2 GPUs to run.")
elif not SP_AVAILABLE:
print(
"PyTorch doesn't have Sequence Parallelism available,"
" need nightly build."
)
else:
mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
class ToyModel(nn.Module):
"""MLP based model"""

def __init__(self):
super().__init__()
self.in_proj = nn.Linear(10, 32)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(32, 5)

def forward(self, x):
return self.out_proj(self.relu(self.in_proj(x)))


"""
Main body of the demo: a basic version of sequence parallelism using
PyTorch native APIs.
"""
logger = get_logger()

# create a device mesh based on the given world_size.
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)

_rank = device_mesh.get_rank()

print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

# create the model and move it to GPU; init_device_mesh has already assigned the GPU ids.
model = ToyModel().to("cuda")

# Custom parallelization plan for the model
sp_model = parallelize_module(
module=model,
device_mesh=device_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
},
)


# Create an optimizer for the parallelized module.
lr = 0.25
optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)


# Perform a number of iterations of forward/backward passes
# and optimizer steps for the sharded module.
num_iters = 10
rank_log(_rank, logger, "Sequence Parallel training starting...")

for i in range(num_iters):
# For SP, input can be different across all ranks.
inp = torch.rand(20, 10, device="cuda")
output = sp_model(inp)
output.sum().backward()
optimizer.step()
rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

rank_log(_rank, logger, "Sequence Parallel training completed!")