From 21a5fcf8f9c883e84571ab6bc2c2fe57f7edbb87 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 15 Nov 2023 14:13:42 -0800 Subject: [PATCH 01/20] update requirements.txt --- distributed/tensor_parallelism/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/tensor_parallelism/requirements.txt b/distributed/tensor_parallelism/requirements.txt index f7b8148247..c6b283a441 100644 --- a/distributed/tensor_parallelism/requirements.txt +++ b/distributed/tensor_parallelism/requirements.txt @@ -1,6 +1,6 @@ # Python dependencies required for running the example --pre ---extra-index-url https://download.pytorch.org/whl/nightly/cu113 ---extra-index-url https://download.pytorch.org/whl/nightly/cu116 -torch >= 1.14.0.dev0; sys_platform == "linux" \ No newline at end of file +--extra-index-url https://download.pytorch.org/whl/nightly/cu118 +--extra-index-url https://download.pytorch.org/whl/nightly/cu121 +torch >= 2.2.0.dev0; sys_platform == "linux" From f962b605feac794ed2509621cbb138d1e2b1b3e1 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 15 Nov 2023 15:43:13 -0800 Subject: [PATCH 02/20] add torchrun support, move to init_device_mesh --- .../tensor_parallelism/run_twod_parallel.sh | 1 + .../two_d_parallel_example.py | 60 ++++++++++++++----- distributed/tensor_parallelism/utils.py | 3 + 3 files changed, 50 insertions(+), 14 deletions(-) create mode 100644 distributed/tensor_parallelism/run_twod_parallel.sh diff --git a/distributed/tensor_parallelism/run_twod_parallel.sh b/distributed/tensor_parallelism/run_twod_parallel.sh new file mode 100644 index 0000000000..7a0de15053 --- /dev/null +++ b/distributed/tensor_parallelism/run_twod_parallel.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 --rdzv_endpoint="localhost:5973" two_d_parallel_example.py diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 5c28db5adf..7c98ef94ae 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -12,7 +12,13 @@ ) from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp -from utils import cleanup, setup, ToyModel +# updated imports +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._tensor import DTensor, Replicate, sharding_prop +from torch.distributed._tensor.device_mesh import init_device_mesh +import os + +from utils import cleanup, torchrun_setup, ToyModel try: from torch.distributed.tensor.parallel import ( SequenceParallel @@ -54,24 +60,47 @@ """ -def demo_2d(rank, args): +def demo_2d(args): """ Main body of the demo of a basic version of tensor parallel by using PyTorch native APIs. 
""" - print(f"Running basic Megatron style TP example on rank {rank}.") - setup(rank, args.world_size) + torchrun_setup() + + + _rank = int(os.environ["RANK"]) + _local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(_local_rank) + _world_size = int(os.environ["WORLD_SIZE"]) + _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + def rank_print(msg): + if _rank==0: + print(f"{msg}") + + print(f"Running basic Megatron style TP example on rank {_rank}.") + assert ( - args.world_size % args.tp_size == 0 - ), "World size needs to be divisible by TP size" + _world_size % args.tp_size == 0 + ), f"World size {_world_size} needs to be divisible by TP size {args.tp_size}" + device = f"cuda" # :{_local_rank}" # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh( - "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size) - ) + + + + dp_size = _world_size // args.tp_size + + device_mesh = init_device_mesh(device, (dp_size, args.tp_size)) + assert device_mesh is not None, "unable to create valid device mesh" + rank_print(f"Device Mesh created: {device_mesh=}") + + + + # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) + model = ToyModel().cuda(_rank) # Create a optimizer for the parallelized module. LR = 0.25 optimizer = torch.optim.SGD(model.parameters(), lr=LR) @@ -83,8 +112,10 @@ def demo_2d(rank, args): assert ( enable_2d_with_fsdp() ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0" - dp_pg = device_mesh.get_dim_groups()[0] - model = FSDP(model, process_group=dp_pg) + # dp_pg = device_mesh.get_dim_groups()[0] + # rank_print(f"{dp_pg=}") + # dist.barrier() + model = FSDP(model, device_mesh = device_mesh) # Perform a num of iterations of forward/backward # and optimizations for the sharded module. @@ -123,5 +154,6 @@ def demo_2d(rank, args): "PyTorch doesn't have Sequence Parallelism available," " need nightly build." 
) - else: - mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) + #else: + #mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) + demo_2d(args) diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py index a55f85c026..8eab89af59 100644 --- a/distributed/tensor_parallelism/utils.py +++ b/distributed/tensor_parallelism/utils.py @@ -15,6 +15,9 @@ def setup(rank, world_size): dist.init_process_group("nccl", rank=rank, world_size=world_size) torch.cuda.set_device(rank) +def torchrun_setup(): + """we use torchrun for init so no params needed here""" + dist.init_process_group("nccl") def cleanup(): dist.destroy_process_group() From bc3c1dd73eafd2a3ed1c3c236818622350703408 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 15 Nov 2023 20:40:02 -0800 Subject: [PATCH 03/20] update twod fully working --- distributed/tensor_parallelism/original.py | 127 ++++++++++++++++++ .../two_d_parallel_example.py | 29 ++-- distributed/tensor_parallelism/utils.py | 2 +- 3 files changed, 141 insertions(+), 17 deletions(-) create mode 100644 distributed/tensor_parallelism/original.py diff --git a/distributed/tensor_parallelism/original.py b/distributed/tensor_parallelism/original.py new file mode 100644 index 0000000000..5c28db5adf --- /dev/null +++ b/distributed/tensor_parallelism/original.py @@ -0,0 +1,127 @@ +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from torch.distributed._tensor import DeviceMesh +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, +) +from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp + +from utils import cleanup, setup, ToyModel +try: + from torch.distributed.tensor.parallel import ( + SequenceParallel + ) + SP_AVAILABLE = True +except BaseException as e: + pass + + +""" +This is the script to test 2D Parallel which combines Tensor/Sequence +parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model +in the SPMD style. We show an E2E working flow from forward, backward +and optimization. + +We enabled Fully Sharded Data Parallel + Tensor Parallel in +separate parallel dimensions: + Data Parallel across hosts + Tensor Parallel within each host + + We use a simple diagram to illustrate below: + +====================================================================== +------------ ------------ ------------ ------------ +| Host 1 | | Host 2 | | | | Host N | +| 8 GPUs | | 8 GPUs | | | | 8 GPUs | +| | | | | ... | | | +| (TP) | | (TP) | | | | (TP) | +|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7| +| | | | | | | .., 8N-1]| +| | | | | | | | +------------ ------------ ------------ ------------ +FSDP: +[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1] +====================================================================== + +More details can be seen in the slide: +https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ +""" + + +def demo_2d(rank, args): + """ + Main body of the demo of a basic version of tensor parallel by using + PyTorch native APIs. + """ + print(f"Running basic Megatron style TP example on rank {rank}.") + setup(rank, args.world_size) + assert ( + args.world_size % args.tp_size == 0 + ), "World size needs to be divisible by TP size" + + # create a sharding plan based on the given world_size. 
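# For illustration (hypothetical sizes): with world_size=8 and tp_size=2,
# torch.arange(world_size).view(-1, tp_size) lays the ranks out as a 4 x 2 grid:
#   torch.arange(0, 8).view(-1, 2)
#   # tensor([[0, 1],
#   #         [2, 3],
#   #         [4, 5],
#   #         [6, 7]])
# Each row of length tp_size is one tensor parallel group (mesh dim 1), and
# each column is one data parallel group (mesh dim 0).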
+ device_mesh = DeviceMesh( + "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size) + ) + + # create model and move it to GPU with id rank + model = ToyModel().cuda(rank) + # Create a optimizer for the parallelized module. + LR = 0.25 + optimizer = torch.optim.SGD(model.parameters(), lr=LR) + # Parallelize the module based on the given Parallel Style. + parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() + model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1) + + # We need to register hooks for TP + FSDP integration. + assert ( + enable_2d_with_fsdp() + ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0" + dp_pg = device_mesh.get_dim_groups()[0] + model = FSDP(model, process_group=dp_pg) + + # Perform a num of iterations of forward/backward + # and optimizations for the sharded module. + for i in range(args.iter_nums): + # For TP, input needs to be same across all TP ranks. + # while for SP, input can be different across all ranks. + # Setting the random seed is to mimic the behavior of dataloader. + dp_rank = ( + rank + if args.run_seq_parallel + else dist.get_rank(dp_pg) + ) + torch.manual_seed(i + dp_rank) + inp = torch.rand(20, 10).cuda(rank) + output = model(inp) + output.sum().backward() + optimizer.step() + + cleanup() + + +if __name__ == "__main__": + n_gpus = torch.cuda.device_count() + parser = argparse.ArgumentParser() + # This is passed in via cmd + parser.add_argument("--world_size", type=int, default=n_gpus) + parser.add_argument("--iter_nums", type=int, default=10) + parser.add_argument("--run_seq_parallel", type=bool, default=False) + parser.add_argument("--tp_size", type=int, default=2) + args = parser.parse_args() + # The main entry point is called directly without using subprocess + if n_gpus < 4: + print("Requires at least 4 GPUs to run.") + elif not SP_AVAILABLE: + print( + "PyTorch doesn't have Sequence Parallelism available," + " need nightly build." + ) + else: + mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 7c98ef94ae..c12e03a44a 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -91,13 +91,12 @@ def rank_print(msg): dp_size = _world_size // args.tp_size - device_mesh = init_device_mesh(device, (dp_size, args.tp_size)) + device_mesh = init_device_mesh(device, (dp_size, args.tp_size), mesh_dim_names=("dp","tp")) assert device_mesh is not None, "unable to create valid device mesh" - rank_print(f"Device Mesh created: {device_mesh=}") - - - + rank_print(f"Device Mesh created: {device_mesh=}") + tp_mesh = device_mesh["tp"] + dp_mesh = device_mesh["dp"] # create model and move it to GPU with id rank model = ToyModel().cuda(_rank) @@ -106,16 +105,11 @@ def rank_print(msg): optimizer = torch.optim.SGD(model.parameters(), lr=LR) # Parallelize the module based on the given Parallel Style. parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1) + model = parallelize_module(model, tp_mesh, parallel_style) - # We need to register hooks for TP + FSDP integration. - assert ( - enable_2d_with_fsdp() - ), "FSDP 2D hook is not registered. 
Please use PyTorch with version >= 2.0" - # dp_pg = device_mesh.get_dim_groups()[0] - # rank_print(f"{dp_pg=}") - # dist.barrier() - model = FSDP(model, device_mesh = device_mesh) + + + model = FSDP(model, device_mesh = dp_mesh) # Perform a num of iterations of forward/backward # and optimizations for the sharded module. @@ -123,16 +117,19 @@ def rank_print(msg): # For TP, input needs to be same across all TP ranks. # while for SP, input can be different across all ranks. # Setting the random seed is to mimic the behavior of dataloader. - dp_rank = ( + dp_rank = _rank + '''( rank if args.run_seq_parallel else dist.get_rank(dp_pg) ) + ''' torch.manual_seed(i + dp_rank) - inp = torch.rand(20, 10).cuda(rank) + inp = torch.rand(20, 10).cuda(_rank) output = model(inp) output.sum().backward() optimizer.step() + rank_print(f"tp iter {i}") cleanup() diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py index 8eab89af59..4c2c96c444 100644 --- a/distributed/tensor_parallelism/utils.py +++ b/distributed/tensor_parallelism/utils.py @@ -9,7 +9,7 @@ def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" + os.environ["MASTER_PORT"] = "12359" # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) From 11a3bb22dbadfb726d541988dfda97db557f6547 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 16 Nov 2023 10:19:08 -0800 Subject: [PATCH 04/20] ensure proper dp group seeding for synth data --- .../two_d_parallel_example.py | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index c12e03a44a..e949ca9095 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -2,7 +2,6 @@ import torch import torch.distributed as dist -import torch.multiprocessing as mp from torch.distributed._tensor import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -10,9 +9,8 @@ PairwiseParallel, parallelize_module, ) -from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp -# updated imports + from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor import DTensor, Replicate, sharding_prop from torch.distributed._tensor.device_mesh import init_device_mesh @@ -36,8 +34,8 @@ We enabled Fully Sharded Data Parallel + Tensor Parallel in separate parallel dimensions: - Data Parallel across hosts - Tensor Parallel within each host + Data Parallel ("dp") across hosts + Tensor Parallel ("tp") within each host We use a simple diagram to illustrate below: @@ -75,6 +73,7 @@ def demo_2d(args): _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) def rank_print(msg): + """helper function to print only on rank 0""" if _rank==0: print(f"{msg}") @@ -84,13 +83,15 @@ def rank_print(msg): _world_size % args.tp_size == 0 ), f"World size {_world_size} needs to be divisible by TP size {args.tp_size}" - device = f"cuda" # :{_local_rank}" - # create a sharding plan based on the given world_size. - + device = f"cuda" + # create a sharding plan based on the given world_size. dp_size = _world_size // args.tp_size + # Create a device mesh with 2 dimensions. + # First dim is the data parallel dimension + # and second dim is the tensor parallel dimension. 
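# For illustration, assuming a 4-GPU run (so dp_size=2 and tp_size=2), the
# init_device_mesh call below builds the 2-D mesh
#   [[0, 1],
#    [2, 3]]
# and the named dimensions can then be sliced per rank:
#   device_mesh["tp"]  -> ranks {0, 1} or {2, 3}   (tensor parallel group)
#   device_mesh["dp"]  -> ranks {0, 2} or {1, 3}   (data parallel / FSDP group)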
device_mesh = init_device_mesh(device, (dp_size, args.tp_size), mesh_dim_names=("dp","tp")) assert device_mesh is not None, "unable to create valid device mesh" @@ -98,39 +99,45 @@ def rank_print(msg): tp_mesh = device_mesh["tp"] dp_mesh = device_mesh["dp"] + # To support identical inputs for TP groups, we need the dp process group + dp_pg = device_mesh.get_dim_groups()[0] + + # For TP, input needs to be same across all TP ranks. + # while for SP, input can be different across all ranks. + # We will use dp_rank for setting the random seed + # to mimic the behavior of the dataloader. + dp_rank = _rank if args.run_seq_parallel else dist.get_rank(dp_pg) + + # create model and move it to GPU with id rank model = ToyModel().cuda(_rank) - # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) + # Create an optimizer for the parallelized module. + lr = 3e-3 + rank_print(f"Creating AdamW optimizer with learning rate {lr}") + optimizer = torch.optim.AdamW(model.parameters(), lr=lr) + # Parallelize the module based on the given Parallel Style. parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() model = parallelize_module(model, tp_mesh, parallel_style) - - + # Init FSDP using the dp device mesh model = FSDP(model, device_mesh = dp_mesh) + + # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. for i in range(args.iter_nums): - # For TP, input needs to be same across all TP ranks. - # while for SP, input can be different across all ranks. - # Setting the random seed is to mimic the behavior of dataloader. - dp_rank = _rank - '''( - rank - if args.run_seq_parallel - else dist.get_rank(dp_pg) - ) - ''' + # seeding to ensure idential inputs for TP pairs (when running TP) torch.manual_seed(i + dp_rank) inp = torch.rand(20, 10).cuda(_rank) + output = model(inp) output.sum().backward() optimizer.step() - rank_print(f"tp iter {i}") + rank_print(f"2D iter {i} complete") + rank_print(f"2D training successfully completed!") cleanup() From 9cebdf0674dfcb518f68bf769443c1e62ebff4a9 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 16 Nov 2023 12:54:36 -0800 Subject: [PATCH 05/20] swiglu model added --- .../two_d_parallel_example.py | 31 ++++++++++++++++--- distributed/tensor_parallelism/utils.py | 23 +++++++++++++- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index e949ca9095..92d59ab11f 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -8,6 +8,8 @@ from torch.distributed.tensor.parallel import ( PairwiseParallel, parallelize_module, + ColwiseParallel, + RowwiseParallel, ) @@ -16,7 +18,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -from utils import cleanup, torchrun_setup, ToyModel +from utils import cleanup, torchrun_setup, ToyModel, MLP_swiglu try: from torch.distributed.tensor.parallel import ( SequenceParallel @@ -111,6 +113,11 @@ def rank_print(msg): # create model and move it to GPU with id rank model = ToyModel().cuda(_rank) + + _mlp_dim = 1024 + mlp_model = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_rank) + + rank_print(f"{mlp_model=}") # Create an optimizer for the parallelized module. 
lr = 3e-3 rank_print(f"Creating AdamW optimizer with learning rate {lr}") @@ -118,19 +125,33 @@ def rank_print(msg): # Parallelize the module based on the given Parallel Style. parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - model = parallelize_module(model, tp_mesh, parallel_style) + auto_model = parallelize_module(model, tp_mesh, parallel_style) + + # custom parallelization for the swiglu MLP model + mlp_tp_model = parallelize_module(module = mlp_model, + device_mesh = tp_mesh, + parallelize_plan = { + "in_proj": ColwiseParallel(), + "gate_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, + tp_mesh_dim=1, + ) - # Init FSDP using the dp device mesh - model = FSDP(model, device_mesh = dp_mesh) + rank_print(f" after parallelization {mlp_tp_model=}") + # Init FSDP using the dp device mesh + model = FSDP(mlp_tp_model, device_mesh = dp_mesh) # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. + rank_print(f"\nStarting 2D training...") for i in range(args.iter_nums): # seeding to ensure idential inputs for TP pairs (when running TP) torch.manual_seed(i + dp_rank) - inp = torch.rand(20, 10).cuda(_rank) + inp = torch.rand(2, _mlp_dim).cuda(_rank) + # inp = torch.rand(20,10).cuda(_rank) output = model(inp) output.sum().backward() diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py index 4c2c96c444..35ae55740f 100644 --- a/distributed/tensor_parallelism/utils.py +++ b/distributed/tensor_parallelism/utils.py @@ -5,7 +5,7 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn - +import torch.nn.functional as F def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" @@ -32,3 +32,24 @@ def __init__(self): def forward(self, x): return self.net2(self.relu(self.net1(x))) + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class MLP_swiglu(nn.Module): + def __init__(self, mlp_dim: int= 1024) -> None: + super().__init__() + hidden_dim = 4 * mlp_dim + scaled_hidden = int(2 * hidden_dim / 3) + rounded_hidden = find_multiple(scaled_hidden, 256) + + self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.in_proj(x)) * self.gate_proj(x) + x = self.out_proj(x) + return x From 244788359071c9d93bc6113d6351bfa2e6bc409c Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 16 Nov 2023 13:34:14 -0800 Subject: [PATCH 06/20] sequential running of custom, auto, seq parallel models --- .../two_d_parallel_example.py | 72 +++++++++++++++---- 1 file changed, 57 insertions(+), 15 deletions(-) diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 92d59ab11f..4386d4b033 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -112,36 +112,40 @@ def rank_print(msg): # create model and move it to GPU with id rank - model = ToyModel().cuda(_rank) + base_model_tp = ToyModel().cuda(_rank) + base_model_sp = ToyModel().cuda(_rank) _mlp_dim = 1024 - mlp_model = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_rank) + base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_rank) + + - rank_print(f"{mlp_model=}") - # Create an optimizer for 
the parallelized module. - lr = 3e-3 - rank_print(f"Creating AdamW optimizer with learning rate {lr}") - optimizer = torch.optim.AdamW(model.parameters(), lr=lr) # Parallelize the module based on the given Parallel Style. - parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - auto_model = parallelize_module(model, tp_mesh, parallel_style) + # parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() + auto_tp_model = parallelize_module(base_model_tp, tp_mesh, PairwiseParallel()) + + sequence_p_model = parallelize_module(base_model_sp, tp_mesh, SequenceParallel()) # custom parallelization for the swiglu MLP model - mlp_tp_model = parallelize_module(module = mlp_model, + custom_tp_model = parallelize_module(module = base_model_swiglu, device_mesh = tp_mesh, parallelize_plan = { "in_proj": ColwiseParallel(), "gate_proj": ColwiseParallel(), "out_proj": RowwiseParallel(), }, - tp_mesh_dim=1, ) - rank_print(f" after parallelization {mlp_tp_model=}") + rank_print(f"after parallelization {custom_tp_model=}") # Init FSDP using the dp device mesh - model = FSDP(mlp_tp_model, device_mesh = dp_mesh) + sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) + + # Create an optimizer for the parallelized module. + lr = 3e-3 + rank_print(f"Creating AdamW optimizer with learning rate {lr}") + optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) # Training loop: # Perform a num of iterations of forward/backward @@ -153,12 +157,50 @@ def rank_print(msg): inp = torch.rand(2, _mlp_dim).cuda(_rank) # inp = torch.rand(20,10).cuda(_rank) - output = model(inp) + output = sharded_model(inp) + output.sum().backward() + optimizer.step() + rank_print(f"2D iter {i} complete") + + rank_print(f"custom 2D training successfully completed!") + + rank_print(f"starting auto parallel example...") + + sharded_model = FSDP(auto_tp_model, device_mesh = dp_mesh, use_orig_params=True) + lr = 3e-3 + rank_print(f"Creating AdamW optimizer with learning rate {lr}") + optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) + + rank_print(f"\nStarting 2D training of auto-parallelized model...") + for i in range(args.iter_nums): + # seeding to ensure identical inputs for TP pairs (when running TP) + torch.manual_seed(i + dp_rank) + inp = torch.rand(20, 10).cuda(_rank) + # inp = torch.rand(20,10).cuda(_rank) + + output = sharded_model(inp) output.sum().backward() optimizer.step() rank_print(f"2D iter {i} complete") + rank_print(f"Pairwise Parallel training successfully completed!") + + sharded_model = FSDP(sequence_p_model, device_mesh = dp_mesh, use_orig_params=True) + lr = 3e-3 + rank_print(f"Creating AdamW optimizer for seq parallel model with learning rate {lr}") + optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) + + rank_print(f"\nStarting Sequence Parallel training...") + for i in range(args.iter_nums): + # seeding to ensure different inputs for sequence parallel + torch.manual_seed(i + _rank) + inp = torch.rand(20, 10).cuda(_rank) + + output = sharded_model(inp) + output.sum().backward() + optimizer.step() + rank_print(f"Sequence Parallel iter {i} complete") - rank_print(f"2D training successfully completed!") + rank_print(f"Sequence Parallel training successfully completed!") cleanup() From a388c204a9be30e201a4d017bf0dcc0bf126fc39 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 16 Nov 2023 18:50:09 -0800 Subject: [PATCH 07/20] streamline to 2D TP only for two_d_parallel example --- 
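For reference, a quick sketch of how the MLP_swiglu module kept by this patch
sizes its hidden layer when mlp_dim=1024 (the value used in the example):

    mlp_dim = 1024
    hidden_dim = 4 * mlp_dim                    # 4096
    scaled_hidden = int(2 * hidden_dim / 3)     # 2730
    # find_multiple(2730, 256) rounds up to the next multiple of 256
    rounded_hidden = scaled_hidden + 256 - (scaled_hidden % 256)   # 2816
    # so in_proj and gate_proj are 1024 -> 2816, and out_proj is 2816 -> 1024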
.../tensor_parallelism/run_twod_parallel.sh | 2 +- .../two_d_parallel_example.py | 99 ++++--------------- 2 files changed, 18 insertions(+), 83 deletions(-) diff --git a/distributed/tensor_parallelism/run_twod_parallel.sh b/distributed/tensor_parallelism/run_twod_parallel.sh index 7a0de15053..1c173c372c 100644 --- a/distributed/tensor_parallelism/run_twod_parallel.sh +++ b/distributed/tensor_parallelism/run_twod_parallel.sh @@ -1 +1 @@ -torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=101 --rdzv_endpoint="localhost:5973" two_d_parallel_example.py +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" two_d_parallel_example.py diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 4386d4b033..644967968e 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -18,14 +18,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -from utils import cleanup, torchrun_setup, ToyModel, MLP_swiglu -try: - from torch.distributed.tensor.parallel import ( - SequenceParallel - ) - SP_AVAILABLE = True -except BaseException as e: - pass +from utils import cleanup, torchrun_setup, MLP_swiglu """ @@ -67,15 +60,16 @@ def demo_2d(args): """ torchrun_setup() - + # understand world topology _rank = int(os.environ["RANK"]) _local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(_local_rank) _world_size = int(os.environ["WORLD_SIZE"]) _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + torch.cuda.set_device(_local_rank) + def rank_print(msg): - """helper function to print only on rank 0""" + """helper function to print only on global rank 0""" if _rank==0: print(f"{msg}") @@ -93,7 +87,7 @@ def rank_print(msg): # Create a device mesh with 2 dimensions. # First dim is the data parallel dimension - # and second dim is the tensor parallel dimension. + # Second dim is the tensor parallel dimension. device_mesh = init_device_mesh(device, (dp_size, args.tp_size), mesh_dim_names=("dp","tp")) assert device_mesh is not None, "unable to create valid device mesh" @@ -108,26 +102,16 @@ def rank_print(msg): # while for SP, input can be different across all ranks. # We will use dp_rank for setting the random seed # to mimic the behavior of the dataloader. - dp_rank = _rank if args.run_seq_parallel else dist.get_rank(dp_pg) + dp_rank = dist.get_rank(dp_pg) # create model and move it to GPU with id rank - base_model_tp = ToyModel().cuda(_rank) - base_model_sp = ToyModel().cuda(_rank) - _mlp_dim = 1024 - base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_rank) - + base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_local_rank) - # Parallelize the module based on the given Parallel Style. 
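# A note on the custom plan kept just below: this is the standard Megatron-style
# split for an MLP. ColwiseParallel shards in_proj and gate_proj over their
# output features, so each TP rank computes only its slice of the hidden
# activations; RowwiseParallel shards out_proj over its input features, so those
# slices are consumed locally and a single all-reduce at the end reassembles the
# full output. Roughly:
#   y = all_reduce(out_proj_shard(silu(in_proj_shard(x)) * gate_proj_shard(x)))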
- # parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - auto_tp_model = parallelize_module(base_model_tp, tp_mesh, PairwiseParallel()) - - sequence_p_model = parallelize_module(base_model_sp, tp_mesh, SequenceParallel()) - - # custom parallelization for the swiglu MLP model + # Custom parallelization plan for the swiglu MLP model custom_tp_model = parallelize_module(module = base_model_swiglu, device_mesh = tp_mesh, parallelize_plan = { @@ -137,12 +121,12 @@ def rank_print(msg): }, ) - rank_print(f"after parallelization {custom_tp_model=}") + rank_print(f"Model after parallelization {custom_tp_model=}\n") # Init FSDP using the dp device mesh sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) - # Create an optimizer for the parallelized module. + # Create an optimizer for the parallelized and sharded model. lr = 3e-3 rank_print(f"Creating AdamW optimizer with learning rate {lr}") optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) @@ -151,76 +135,27 @@ def rank_print(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. rank_print(f"\nStarting 2D training...") - for i in range(args.iter_nums): - # seeding to ensure idential inputs for TP pairs (when running TP) - torch.manual_seed(i + dp_rank) - inp = torch.rand(2, _mlp_dim).cuda(_rank) - # inp = torch.rand(20,10).cuda(_rank) - - output = sharded_model(inp) - output.sum().backward() - optimizer.step() - rank_print(f"2D iter {i} complete") - - rank_print(f"custom 2D training successfully completed!") - - rank_print(f"starting auto parallel example...") - - sharded_model = FSDP(auto_tp_model, device_mesh = dp_mesh, use_orig_params=True) - lr = 3e-3 - rank_print(f"Creating AdamW optimizer with learning rate {lr}") - optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) + num_iterations = 10 + batch_size = 2 - rank_print(f"\nStarting 2D training of auto-parallelized model...") - for i in range(args.iter_nums): - # seeding to ensure identical inputs for TP pairs (when running TP) + for i in range(num_iterations): + # seeding with dp_rank to ensure identical inputs for TP groups torch.manual_seed(i + dp_rank) - inp = torch.rand(20, 10).cuda(_rank) - # inp = torch.rand(20,10).cuda(_rank) + inp = torch.rand(batch_size, _mlp_dim).cuda(_rank) output = sharded_model(inp) output.sum().backward() optimizer.step() rank_print(f"2D iter {i} complete") - rank_print(f"Pairwise Parallel training successfully completed!") - sharded_model = FSDP(sequence_p_model, device_mesh = dp_mesh, use_orig_params=True) - lr = 3e-3 - rank_print(f"Creating AdamW optimizer for seq parallel model with learning rate {lr}") - optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) - - rank_print(f"\nStarting Sequence Parallel training...") - for i in range(args.iter_nums): - # seeding to ensure different inputs for sequence parallel - torch.manual_seed(i + _rank) - inp = torch.rand(20, 10).cuda(_rank) - - output = sharded_model(inp) - output.sum().backward() - optimizer.step() - rank_print(f"Sequence Parallel iter {i} complete") + rank_print(f"2D training successfully completed!") - rank_print(f"Sequence Parallel training successfully completed!") cleanup() if __name__ == "__main__": - n_gpus = torch.cuda.device_count() parser = argparse.ArgumentParser() # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) - 
parser.add_argument("--run_seq_parallel", type=bool, default=False) parser.add_argument("--tp_size", type=int, default=2) args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 4: - print("Requires at least 4 GPUs to run.") - elif not SP_AVAILABLE: - print( - "PyTorch doesn't have Sequence Parallelism available," - " need nightly build." - ) - #else: - #mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) demo_2d(args) From 842c3f03878d6789b0cb745b4dc17466df8de332 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Fri, 17 Nov 2023 17:48:06 -0800 Subject: [PATCH 08/20] sequence parallel working...needs init_device_mesh update --- .../run_sequence_parallel.sh | 1 + .../tensor_parallelism/run_tensor_parallel.sh | 1 + .../sequence_parallel_example.py | 89 +++++++++++++------ 3 files changed, 62 insertions(+), 29 deletions(-) create mode 100644 distributed/tensor_parallelism/run_sequence_parallel.sh create mode 100644 distributed/tensor_parallelism/run_tensor_parallel.sh diff --git a/distributed/tensor_parallelism/run_sequence_parallel.sh b/distributed/tensor_parallelism/run_sequence_parallel.sh new file mode 100644 index 0000000000..1c35e7ddab --- /dev/null +++ b/distributed/tensor_parallelism/run_sequence_parallel.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" sequence_parallel_example.py diff --git a/distributed/tensor_parallelism/run_tensor_parallel.sh b/distributed/tensor_parallelism/run_tensor_parallel.sh new file mode 100644 index 0000000000..05c5e34880 --- /dev/null +++ b/distributed/tensor_parallelism/run_tensor_parallel.sh @@ -0,0 +1 @@ +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" tensor_parallel_example.py diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 666713295f..07c022b34f 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,11 +1,17 @@ import argparse - +import os import torch -import torch.multiprocessing as mp -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import parallelize_module -from utils import cleanup, setup, ToyModel +from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard + +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) +from utils import cleanup, ToyModel, torchrun_setup try: from torch.distributed.tensor.parallel import ( @@ -33,51 +39,76 @@ """ -def demo_sp(rank, args): +def demo_sp(args): """ Main body of the demo of a basic version of sequence parallel by using PyTorch native APIs. """ - print(f"Running SP example on rank {rank}.") - setup(rank, args.world_size) + torchrun_setup() + + # understand world topology + _rank = int(os.environ["RANK"]) + _local_rank = int(os.environ["LOCAL_RANK"]) + _world_size = int(os.environ["WORLD_SIZE"]) + _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + torch.cuda.set_device(_local_rank) + + def rank_print(msg): + """helper function to print only on global rank 0""" + if _rank==0: + print(f"{msg}") + + print(f"Running basic Megatron style TP example on rank {_rank}.") # create a sharding plan based on the given world_size. 
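# Note on the environment-variable reads above: they assume the script is
# launched with torchrun (as in the run_*.sh helpers), which sets RANK,
# LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE, MASTER_ADDR and MASTER_PORT for
# every worker, so dist.init_process_group("nccl") can rendezvous without any
# explicit arguments. A hypothetical single-node, 2-GPU launch would be:
#   torchrun --nnodes=1 --nproc_per_node=2 sequence_parallel_example.py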
- device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) + device_mesh = DeviceMesh("cuda", torch.arange(0, _world_size)) + + device = f"cuda" + #device_mesh = init_device_mesh(device, torch.arange(0,_world_size)) # , mesh_dim_names=("sp",)) + assert device_mesh is not None, "unable to create valid device mesh" + + rank_print(f"Device Mesh created: {device_mesh=}") + # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) + model = ToyModel().cuda(_rank) + + # Custom parallelization plan for the model + sp_model = parallelize_module(module = model, + device_mesh = device_mesh, + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(0)), + "net1": RowwiseParallel(input_layouts=Shard(0)), + }, + ) + + # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - model = parallelize_module(model, device_mesh, SequenceParallel()) + lr = 0.25 + optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr) + # Perform a num of iterations of forward/backward # and optimizations for the sharded module. - for _ in range(args.iter_nums): + num_iters = 10 + rank_print(f"Sequence Parallel training starting...") + + for i in range(num_iters): # For SP, input can be different across all ranks. - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) + inp = torch.rand(20, 10).cuda(_rank) + output = sp_model(inp) output.sum().backward() optimizer.step() + rank_print(f"Sequence Parallel iter {i} completed") + rank_print(f"Sequence Parallel training completed!") cleanup() if __name__ == "__main__": - n_gpus = torch.cuda.device_count() parser = argparse.ArgumentParser() # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) args = parser.parse_args() # The main entry point is called directly without using subprocess - if n_gpus < 2: - print("Requires at least 2 GPUs to run.") - elif not SP_AVAILABLE: - print( - "PyTorch doesn't have Sequence Parallelism available," - " need nightly build." - ) - else: - mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True) + demo_sp(args,) From 3aa1c5344b59bdbe58a26fca56a23a126230425e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 20 Nov 2023 20:15:53 -0800 Subject: [PATCH 09/20] seq parallel now using init_device_mesh --- distributed/tensor_parallelism/sequence_parallel_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 07c022b34f..6b47329a43 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -61,11 +61,10 @@ def rank_print(msg): print(f"Running basic Megatron style TP example on rank {_rank}.") - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh("cuda", torch.arange(0, _world_size)) + # create a mesh based on the given world_size. 
device = f"cuda" - #device_mesh = init_device_mesh(device, torch.arange(0,_world_size)) # , mesh_dim_names=("sp",)) + device_mesh = init_device_mesh(device_type = device,mesh_shape = (_world_size,)) assert device_mesh is not None, "unable to create valid device mesh" rank_print(f"Device Mesh created: {device_mesh=}") From b54e2ec98286ed8604bdce7d1dc2028d6726344e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 20 Nov 2023 20:48:33 -0800 Subject: [PATCH 10/20] tp and sp examples all working and updated --- .../sequence_parallel_example.py | 6 +- .../tensor_parallel_example.py | 88 +++++++++++++------ 2 files changed, 63 insertions(+), 31 deletions(-) diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 6b47329a43..be6379c8a4 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -59,9 +59,9 @@ def rank_print(msg): if _rank==0: print(f"{msg}") - print(f"Running basic Megatron style TP example on rank {_rank}.") + print(f"Running basic Megatron style Sequence Parallel example on rank {_rank}.") - # create a mesh based on the given world_size. + # create a device mesh based on the given world_size. device = f"cuda" device_mesh = init_device_mesh(device_type = device,mesh_shape = (_world_size,)) @@ -78,7 +78,7 @@ def rank_print(msg): device_mesh = device_mesh, parallelize_plan = { "net1": ColwiseParallel(input_layouts=Shard(0)), - "net1": RowwiseParallel(input_layouts=Shard(0)), + "net2": RowwiseParallel(output_layouts=Shard(0)), }, ) diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 18133d8eea..15b5f0abe2 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,11 +1,17 @@ import argparse - +import os import torch -import torch.multiprocessing as mp -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module -from utils import cleanup, setup, ToyModel +from torch.distributed._tensor.device_mesh import init_device_mesh +from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard + +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) +from utils import cleanup, ToyModel, torchrun_setup """ @@ -40,48 +46,74 @@ """ -def demo_tp(rank, args): +def demo_tp(args): """ Main body of the demo of a basic version of tensor parallel by using PyTorch native APIs. """ - print(f"Running basic Megatron style TP example on rank {rank}.") - setup(rank, args.world_size) + torchrun_setup() - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) + # understand world topology + _rank = int(os.environ["RANK"]) + _local_rank = int(os.environ["LOCAL_RANK"]) + _world_size = int(os.environ["WORLD_SIZE"]) + _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + + torch.cuda.set_device(_local_rank) + + def rank_print(msg): + """helper function to print only on global rank 0""" + if _rank==0: + print(f"{msg}") + + print(f"Running basic Megatron style TP example on rank {_rank}.") - # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) - # Create a optimizer for the parallelized module. 
- LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - model = parallelize_module(model, device_mesh, PairwiseParallel()) + # create a device mesh based on the given world_size. + + device = f"cuda" + device_mesh = init_device_mesh(device_type = device,mesh_shape = (_world_size,)) + assert device_mesh is not None, "unable to create valid device mesh" + + rank_print(f"Device Mesh created: {device_mesh=}") + + # create model and move it to GPU with id rank + tp_model = ToyModel().cuda(_rank) + + # Create an optimizer for the parallelized module. + lr = 0.25 + optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr) + + # Custom parallelization plan for the model + tp_model = parallelize_module(module = tp_model, + device_mesh = device_mesh, + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + }, + ) # Perform a num of iterations of forward/backward # and optimizations for the sharded module. - for i in range(args.iter_nums): + num_iters = 10 + rank_print(f"Tensor Parallel training starting...") + + for i in range(num_iters): # For TP, input needs to be same across all TP ranks. # Setting the random seed is to mimic the behavior of dataloader. torch.manual_seed(i) - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) + inp = torch.rand(20, 10).cuda(_rank) + output = tp_model(inp) output.sum().backward() optimizer.step() + rank_print(f"Tensor Parallel iter {i} completed") + + rank_print(f"Tensor Parallel training completed!") cleanup() if __name__ == "__main__": - n_gpus = torch.cuda.device_count() parser = argparse.ArgumentParser() # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 2: - print("Requires at least 2 GPUs to run.") - else: - mp.spawn(demo_tp, args=(args,), nprocs=args.world_size, join=True) + demo_tp(args) From 4889e3bc86e078b972266295ee4d54cbdc373186 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 13:28:11 -0800 Subject: [PATCH 11/20] updates from code review --- distributed/tensor_parallelism/run_example.sh | 13 ++ .../run_sequence_parallel.sh | 1 - .../tensor_parallelism/run_tensor_parallel.sh | 1 - .../tensor_parallelism/run_twod_parallel.sh | 1 - .../sequence_parallel_example.py | 99 +++++------- .../tensor_parallel_example.py | 111 ++++++------- .../two_d_parallel_example.py | 153 ++++++++---------- distributed/tensor_parallelism/utils.py | 5 +- 8 files changed, 177 insertions(+), 207 deletions(-) create mode 100644 distributed/tensor_parallelism/run_example.sh delete mode 100644 distributed/tensor_parallelism/run_sequence_parallel.sh delete mode 100644 distributed/tensor_parallelism/run_tensor_parallel.sh delete mode 100644 distributed/tensor_parallelism/run_twod_parallel.sh diff --git a/distributed/tensor_parallelism/run_example.sh b/distributed/tensor_parallelism/run_example.sh new file mode 100644 index 0000000000..aebe77d985 --- /dev/null +++ b/distributed/tensor_parallelism/run_example.sh @@ -0,0 +1,13 @@ + +# To run samples: +# bash run_example.sh file_to_run.py num_gpus +# where file_to_run = example to launch. Default = 'two_d_parallel_example.py' +# num_gpus = num local gpus to use (must be at least 2). 
Default =4 + +# samples to run include: +# sequence_parallel_example.py +# tensor_parallel_example.py +# two_d_parallel_example.py + +echo "Launching ${1:-two_d_parallel_example.py} with ${2:-4} gpus" +torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-two_d_parallel_example.py} diff --git a/distributed/tensor_parallelism/run_sequence_parallel.sh b/distributed/tensor_parallelism/run_sequence_parallel.sh deleted file mode 100644 index 1c35e7ddab..0000000000 --- a/distributed/tensor_parallelism/run_sequence_parallel.sh +++ /dev/null @@ -1 +0,0 @@ -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" sequence_parallel_example.py diff --git a/distributed/tensor_parallelism/run_tensor_parallel.sh b/distributed/tensor_parallelism/run_tensor_parallel.sh deleted file mode 100644 index 05c5e34880..0000000000 --- a/distributed/tensor_parallelism/run_tensor_parallel.sh +++ /dev/null @@ -1 +0,0 @@ -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" tensor_parallel_example.py diff --git a/distributed/tensor_parallelism/run_twod_parallel.sh b/distributed/tensor_parallelism/run_twod_parallel.sh deleted file mode 100644 index 1c173c372c..0000000000 --- a/distributed/tensor_parallelism/run_twod_parallel.sh +++ /dev/null @@ -1 +0,0 @@ -torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=101 --rdzv_endpoint="localhost:5973" two_d_parallel_example.py diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index be6379c8a4..bf9185443a 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,4 +1,3 @@ -import argparse import os import torch @@ -6,12 +5,11 @@ from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.distributed.tensor.parallel import ( - PairwiseParallel, parallelize_module, ColwiseParallel, RowwiseParallel, ) -from utils import cleanup, ToyModel, torchrun_setup +from utils import ToyModel try: from torch.distributed.tensor.parallel import ( @@ -39,75 +37,58 @@ """ -def demo_sp(args): - """ - Main body of the demo of a basic version of sequence parallel by using - PyTorch native APIs. - """ - torchrun_setup() - # understand world topology - _rank = int(os.environ["RANK"]) - _local_rank = int(os.environ["LOCAL_RANK"]) - _world_size = int(os.environ["WORLD_SIZE"]) - _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - - torch.cuda.set_device(_local_rank) - - def rank_print(msg): - """helper function to print only on global rank 0""" - if _rank==0: - print(f"{msg}") +""" +Main body of the demo of a basic version of sequence parallel by using +PyTorch native APIs. +""" - print(f"Running basic Megatron style Sequence Parallel example on rank {_rank}.") +_rank = int(os.environ["RANK"]) - # create a device mesh based on the given world_size. - device = f"cuda" - device_mesh = init_device_mesh(device_type = device,mesh_shape = (_world_size,)) - assert device_mesh is not None, "unable to create valid device mesh" +def rank_print(msg): + """helper function to print only on global rank 0""" + if _rank==0: + print(f"{msg}") - rank_print(f"Device Mesh created: {device_mesh=}") +print(f"Running basic Megatron style Sequence Parallel example on rank {_rank}.") +# create a device mesh based on the given world_size. 
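# A note on the sequence parallel plan used further down in this file: passing
# input_layouts=Shard(0) to ColwiseParallel declares that each rank arrives with
# a different slice of the input along dim 0, so the style all-gathers those
# slices before the column-parallel matmul; output_layouts=Shard(0) on
# RowwiseParallel replaces the usual all-reduce of partial results with a
# reduce-scatter, handing each rank back only its dim-0 slice. A rough sketch:
#   x_local (Shard(0)) --all-gather--> x --in_proj (colwise)--> h (Shard(-1))
#   h --relu--> h --out_proj (rowwise) + reduce-scatter--> y_local (Shard(0))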
+_device = f"cuda" +device_mesh = init_device_mesh(device_type = _device,mesh_shape = (int(os.environ["WORLD_SIZE"]),)) - # create model and move it to GPU with id rank - model = ToyModel().cuda(_rank) +rank_print(f"Device Mesh created: {device_mesh=}") - # Custom parallelization plan for the model - sp_model = parallelize_module(module = model, - device_mesh = device_mesh, - parallelize_plan = { - "net1": ColwiseParallel(input_layouts=Shard(0)), - "net2": RowwiseParallel(output_layouts=Shard(0)), - }, - ) +# create model and move it to GPU. Init_device_mesh has already assigned gpu ids... +model = ToyModel().to(_device) - # Create a optimizer for the parallelized module. - lr = 0.25 - optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr) +# Custom parallelization plan for the model +sp_model = parallelize_module(module = model, + device_mesh = device_mesh, + parallelize_plan = { + "net1": ColwiseParallel(input_layouts=Shard(0)), + "net2": RowwiseParallel(output_layouts=Shard(0)), + }, +) - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - num_iters = 10 - rank_print(f"Sequence Parallel training starting...") +# Create a optimizer for the parallelized module. +lr = 0.25 +optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True) - for i in range(num_iters): - # For SP, input can be different across all ranks. - inp = torch.rand(20, 10).cuda(_rank) - output = sp_model(inp) - output.sum().backward() - optimizer.step() - rank_print(f"Sequence Parallel iter {i} completed") - rank_print(f"Sequence Parallel training completed!") - cleanup() +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +num_iters = 10 +rank_print(f"Sequence Parallel training starting...") +for i in range(num_iters): + # For SP, input can be different across all ranks. + inp = torch.rand(20, 10,device=_device) + output = sp_model(inp) + output.sum().backward() + optimizer.step() + rank_print(f"Sequence Parallel iter {i} completed") -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # This is passed in via cmd - args = parser.parse_args() - # The main entry point is called directly without using subprocess - demo_sp(args,) +rank_print(f"Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 15b5f0abe2..c19f35c217 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,4 +1,4 @@ -import argparse + import os import torch @@ -6,12 +6,11 @@ from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.distributed.tensor.parallel import ( - PairwiseParallel, parallelize_module, ColwiseParallel, RowwiseParallel, ) -from utils import cleanup, ToyModel, torchrun_setup +from utils import ToyModel """ @@ -46,74 +45,66 @@ """ -def demo_tp(args): - """ - Main body of the demo of a basic version of tensor parallel by using - PyTorch native APIs. 
- """ - torchrun_setup() - - # understand world topology - _rank = int(os.environ["RANK"]) - _local_rank = int(os.environ["LOCAL_RANK"]) - _world_size = int(os.environ["WORLD_SIZE"]) - _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - - torch.cuda.set_device(_local_rank) - def rank_print(msg): - """helper function to print only on global rank 0""" - if _rank==0: - print(f"{msg}") +""" +Main body of the demo of a basic version of tensor parallel by using +PyTorch native APIs. +""" - print(f"Running basic Megatron style TP example on rank {_rank}.") +# understand world topology +_rank = int(os.environ["RANK"]) +_local_rank = int(os.environ["LOCAL_RANK"]) +_world_size = int(os.environ["WORLD_SIZE"]) +#_local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - # create a device mesh based on the given world_size. - device = f"cuda" - device_mesh = init_device_mesh(device_type = device,mesh_shape = (_world_size,)) - assert device_mesh is not None, "unable to create valid device mesh" +def rank_print(msg): + """helper function to print only on global rank 0""" + if _rank==0: + print(f"{msg}") - rank_print(f"Device Mesh created: {device_mesh=}") +print(f"Running basic Megatron style TP example on rank {_rank}.") +assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus" - # create model and move it to GPU with id rank - tp_model = ToyModel().cuda(_rank) - # Create an optimizer for the parallelized module. - lr = 0.25 - optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr) - # Custom parallelization plan for the model - tp_model = parallelize_module(module = tp_model, - device_mesh = device_mesh, - parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), - }, - ) - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - num_iters = 10 - rank_print(f"Tensor Parallel training starting...") +# create a device mesh based on the given world_size. - for i in range(num_iters): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(20, 10).cuda(_rank) - output = tp_model(inp) - output.sum().backward() - optimizer.step() - rank_print(f"Tensor Parallel iter {i} completed") +_device = f"cuda" +device_mesh = init_device_mesh(device_type = _device,mesh_shape = (_world_size,)) +assert device_mesh is not None, "unable to create valid device mesh" - rank_print(f"Tensor Parallel training completed!") +rank_print(f"Device Mesh created: {device_mesh=}") - cleanup() +# create model and move it to GPU - init_device_mesh has already mapped GPU ids. +tp_model = ToyModel().to(_device) +# Create an optimizer for the parallelized module. +lr = 0.25 +optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True) -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # This is passed in via cmd - args = parser.parse_args() - demo_tp(args) +# Custom parallelization plan for the model +tp_model = parallelize_module(module = tp_model, + device_mesh = device_mesh, + parallelize_plan = { + "net1": ColwiseParallel(), + "net2": RowwiseParallel(), + }, +) +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +num_iters = 10 +rank_print(f"Tensor Parallel training starting...") + +for i in range(num_iters): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. 
+ torch.manual_seed(i) + inp = torch.rand(20, 10, device=_device) + output = tp_model(inp) + output.sum().backward() + optimizer.step() + rank_print(f"Tensor Parallel iter {i} completed") + +rank_print(f"Tensor Parallel training completed!") diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 644967968e..2a053b77a6 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -1,4 +1,3 @@ -import argparse import torch import torch.distributed as dist @@ -6,7 +5,6 @@ from torch.distributed._tensor import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.tensor.parallel import ( - PairwiseParallel, parallelize_module, ColwiseParallel, RowwiseParallel, @@ -18,7 +16,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -from utils import cleanup, torchrun_setup, MLP_swiglu +from utils import MLP_swiglu """ @@ -53,109 +51,98 @@ """ -def demo_2d(args): - """ - Main body of the demo of a basic version of tensor parallel by using - PyTorch native APIs. - """ - torchrun_setup() - # understand world topology - _rank = int(os.environ["RANK"]) - _local_rank = int(os.environ["LOCAL_RANK"]) - _world_size = int(os.environ["WORLD_SIZE"]) - _local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - - torch.cuda.set_device(_local_rank) +""" +Main body of the demo of a basic version of tensor parallel by using +PyTorch native APIs. +""" +tp_size = 2 - def rank_print(msg): - """helper function to print only on global rank 0""" - if _rank==0: - print(f"{msg}") - print(f"Running basic Megatron style TP example on rank {_rank}.") +# understand world topology +_rank = int(os.environ["RANK"]) +_local_rank = int(os.environ["LOCAL_RANK"]) +_world_size = int(os.environ["WORLD_SIZE"]) +_local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - assert ( - _world_size % args.tp_size == 0 - ), f"World size {_world_size} needs to be divisible by TP size {args.tp_size}" - device = f"cuda" +def rank_print(msg): + """helper function to print only on global rank 0""" + if _rank==0: + print(f"{msg}") - # create a sharding plan based on the given world_size. +print(f"Running basic Megatron style TP example on rank {_rank}.") +assert ( + _world_size % tp_size == 0 +), f"World size {_world_size} needs to be divisible by TP size {tp_size}" - dp_size = _world_size // args.tp_size - # Create a device mesh with 2 dimensions. - # First dim is the data parallel dimension - # Second dim is the tensor parallel dimension. - device_mesh = init_device_mesh(device, (dp_size, args.tp_size), mesh_dim_names=("dp","tp")) - assert device_mesh is not None, "unable to create valid device mesh" +_device = f"cuda" - rank_print(f"Device Mesh created: {device_mesh=}") - tp_mesh = device_mesh["tp"] - dp_mesh = device_mesh["dp"] +# create a sharding plan based on the given world_size. - # To support identical inputs for TP groups, we need the dp process group - dp_pg = device_mesh.get_dim_groups()[0] +dp_size = _world_size // tp_size - # For TP, input needs to be same across all TP ranks. - # while for SP, input can be different across all ranks. - # We will use dp_rank for setting the random seed - # to mimic the behavior of the dataloader. - dp_rank = dist.get_rank(dp_pg) +# Create a device mesh with 2 dimensions. +# First dim is the data parallel dimension +# Second dim is the tensor parallel dimension. 
+device_mesh = init_device_mesh(_device, (dp_size, tp_size), mesh_dim_names=("dp","tp")) +rank_print(f"Device Mesh created: {device_mesh=}") +tp_mesh = device_mesh["tp"] +dp_mesh = device_mesh["dp"] - # create model and move it to GPU with id rank - _mlp_dim = 1024 - base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).cuda(_local_rank) +# To support identical inputs for TP groups, we need the dp process group +dp_pg = device_mesh.get_dim_groups()[0] +# For TP, input needs to be same across all TP ranks. +# while for SP, input can be different across all ranks. +# We will use dp_rank for setting the random seed +# to mimic the behavior of the dataloader. +dp_rank = dist.get_rank(dp_pg) - # Custom parallelization plan for the swiglu MLP model - custom_tp_model = parallelize_module(module = base_model_swiglu, - device_mesh = tp_mesh, - parallelize_plan = { - "in_proj": ColwiseParallel(), - "gate_proj": ColwiseParallel(), - "out_proj": RowwiseParallel(), - }, - ) +# create model and move it to GPU with id rank +_mlp_dim = 1024 +base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to(_device) - rank_print(f"Model after parallelization {custom_tp_model=}\n") - # Init FSDP using the dp device mesh - sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) - # Create an optimizer for the parallelized and sharded model. - lr = 3e-3 - rank_print(f"Creating AdamW optimizer with learning rate {lr}") - optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr) +# Custom parallelization plan for the swiglu MLP model +custom_tp_model = parallelize_module(module = base_model_swiglu, + device_mesh = tp_mesh, + parallelize_plan = { + "in_proj": ColwiseParallel(), + "gate_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, +) - # Training loop: - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - rank_print(f"\nStarting 2D training...") - num_iterations = 10 - batch_size = 2 +rank_print(f"Model after parallelization {custom_tp_model=}\n") - for i in range(num_iterations): - # seeding with dp_rank to ensure identical inputs for TP groups - torch.manual_seed(i + dp_rank) - inp = torch.rand(batch_size, _mlp_dim).cuda(_rank) +# Init FSDP using the dp device mesh +sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) - output = sharded_model(inp) - output.sum().backward() - optimizer.step() - rank_print(f"2D iter {i} complete") +# Create an optimizer for the parallelized and sharded model. +lr = 3e-3 +rank_print(f"Creating AdamW optimizer with learning rate {lr}") +optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True) - rank_print(f"2D training successfully completed!") +# Training loop: +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. 
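Seeding with dp_rank rather than the global rank is what keeps the random input identical inside each TP group while still varying across data parallel replicas. A minimal sketch of the seeds that result, reusing the assumed 8-rank, tp_size=2 layout from the sketch above (for that contiguous ("dp", "tp") mesh, dp_rank equals global_rank // tp_size):

# Sketch only; world_size=8, tp_size=2 and iteration i=0 are assumed values.
tp_size, world_size, i = 2, 8, 0
for global_rank in range(world_size):
    dp_rank = global_rank // tp_size          # index along the "dp" dimension
    print(f"rank {global_rank}: dp_rank={dp_rank}, seed={i + dp_rank}")
# ranks 0 and 1 share seed 0, ranks 2 and 3 share seed 1, and so on: each TP pair
# draws the same input tensor, while different DP replicas draw different data.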
+rank_print(f"\nStarting 2D training...") +num_iterations = 10 +batch_size = 2 - cleanup() +for i in range(num_iterations): + # seeding with dp_rank to ensure identical inputs for TP groups + torch.manual_seed(i + dp_rank) + inp = torch.rand(batch_size, _mlp_dim, device=_device) + output = sharded_model(inp) + output.sum().backward() + optimizer.step() + rank_print(f"2D iter {i} complete") -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # This is passed in via cmd - parser.add_argument("--tp_size", type=int, default=2) - args = parser.parse_args() - demo_2d(args) +rank_print(f"2D training successfully completed!") diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py index 35ae55740f..3923eccd69 100644 --- a/distributed/tensor_parallelism/utils.py +++ b/distributed/tensor_parallelism/utils.py @@ -17,10 +17,11 @@ def setup(rank, world_size): def torchrun_setup(): """we use torchrun for init so no params needed here""" - dist.init_process_group("nccl") + #dist.init_process_group("nccl") def cleanup(): - dist.destroy_process_group() + #dist.destroy_process_group() + pass class ToyModel(nn.Module): From b2151787b1bfdc1f71fa661088da4d994e9ee637 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 18:21:24 -0800 Subject: [PATCH 12/20] remove utils.py. Sample models created in example files --- .../sequence_parallel_example.py | 19 ++++++- .../tensor_parallel_example.py | 19 ++++++- .../two_d_parallel_example.py | 26 ++++++++- distributed/tensor_parallelism/utils.py | 56 ------------------- 4 files changed, 56 insertions(+), 64 deletions(-) delete mode 100644 distributed/tensor_parallelism/utils.py diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index bf9185443a..eda3231a93 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,5 +1,7 @@ import os import torch +import torch.nn as nn +import torch.nn.functional as F from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard @@ -9,7 +11,7 @@ ColwiseParallel, RowwiseParallel, ) -from utils import ToyModel + try: from torch.distributed.tensor.parallel import ( @@ -37,6 +39,17 @@ """ +class ToyModel(nn.Module): + """ MLP based model """ + def __init__(self): + super().__init__() + self.in_proj = nn.Linear(10, 32) + self.relu = nn.ReLU() + self.out_proj = nn.Linear(32, 5) + + def forward(self, x): + return self.out_proj(self.relu(self.in_proj(x))) + """ Main body of the demo of a basic version of sequence parallel by using @@ -67,8 +80,8 @@ def rank_print(msg): sp_model = parallelize_module(module = model, device_mesh = device_mesh, parallelize_plan = { - "net1": ColwiseParallel(input_layouts=Shard(0)), - "net2": RowwiseParallel(output_layouts=Shard(0)), + "in_proj": ColwiseParallel(input_layouts=Shard(0)), + "out_proj": RowwiseParallel(output_layouts=Shard(0)), }, ) diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index c19f35c217..99a0014782 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,6 +1,8 @@ import os import torch +import torch.nn as nn +import torch.nn.functional as F from torch.distributed._tensor.device_mesh import init_device_mesh from 
torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard @@ -10,7 +12,8 @@ ColwiseParallel, RowwiseParallel, ) -from utils import ToyModel + + """ @@ -45,6 +48,16 @@ """ +class ToyModel(nn.Module): + """ MLP based model """ + def __init__(self): + super(ToyModel, self).__init__() + self.in_proj = nn.Linear(10, 32) + self.relu = nn.ReLU() + self.out_proj = nn.Linear(32, 5) + + def forward(self, x): + return self.out_proj(self.relu(self.in_proj(x))) """ Main body of the demo of a basic version of tensor parallel by using @@ -88,8 +101,8 @@ def rank_print(msg): tp_model = parallelize_module(module = tp_model, device_mesh = device_mesh, parallelize_plan = { - "net1": ColwiseParallel(), - "net2": RowwiseParallel(), + "in_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), }, ) # Perform a num of iterations of forward/backward diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py index 2a053b77a6..8615d44e27 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/two_d_parallel_example.py @@ -1,6 +1,8 @@ import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from torch.distributed._tensor import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -16,7 +18,6 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -from utils import MLP_swiglu """ @@ -49,8 +50,30 @@ More details can be seen in the slide: https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ """ +def find_multiple(n: int, k: int) -> int: + """ function to find resizing multiple for SwiGLU MLP """ + if n % k == 0: + return n + return n + k - (n % k) +class MLP_swiglu(nn.Module): + """ SwiGLU to showcase a Llama style MLP model """ + + def __init__(self, mlp_dim: int= 1024) -> None: + super().__init__() + hidden_dim = 4 * mlp_dim + scaled_hidden = int(2 * hidden_dim / 3) + rounded_hidden = find_multiple(scaled_hidden, 256) + + self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.in_proj(x)) * self.gate_proj(x) + x = self.out_proj(x) + return x """ Main body of the demo of a basic version of tensor parallel by using @@ -107,7 +130,6 @@ def rank_print(msg): base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to(_device) - # Custom parallelization plan for the swiglu MLP model custom_tp_model = parallelize_module(module = base_model_swiglu, device_mesh = tp_mesh, diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py deleted file mode 100644 index 3923eccd69..0000000000 --- a/distributed/tensor_parallelism/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -import argparse -import os - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.nn.functional as F - -def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12359" - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - -def torchrun_setup(): - """we use torchrun for init so no params needed here""" - #dist.init_process_group("nccl") - -def cleanup(): - 
#dist.destroy_process_group() - pass - - -class ToyModel(nn.Module): - def __init__(self): - super(ToyModel, self).__init__() - self.net1 = nn.Linear(10, 32) - self.relu = nn.ReLU() - self.net2 = nn.Linear(32, 5) - - def forward(self, x): - return self.net2(self.relu(self.net1(x))) - -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - -class MLP_swiglu(nn.Module): - def __init__(self, mlp_dim: int= 1024) -> None: - super().__init__() - hidden_dim = 4 * mlp_dim - scaled_hidden = int(2 * hidden_dim / 3) - rounded_hidden = find_multiple(scaled_hidden, 256) - - self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) - self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) - self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = F.silu(self.in_proj(x)) * self.gate_proj(x) - x = self.out_proj(x) - return x From 242c3280910f095a8275c8a6ccb85fc7b2467bb0 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 19:44:57 -0800 Subject: [PATCH 13/20] remove originals.py, leftover imports, various updates from code review feedback. --- ...parallel_example.py => fsdp_tp_example.py} | 37 +++-- distributed/tensor_parallelism/original.py | 127 ------------------ distributed/tensor_parallelism/run_example.sh | 12 +- .../sequence_parallel_example.py | 42 +++--- .../tensor_parallel_example.py | 42 +++--- 5 files changed, 57 insertions(+), 203 deletions(-) rename distributed/tensor_parallelism/{two_d_parallel_example.py => fsdp_tp_example.py} (85%) delete mode 100644 distributed/tensor_parallelism/original.py diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py similarity index 85% rename from distributed/tensor_parallelism/two_d_parallel_example.py rename to distributed/tensor_parallelism/fsdp_tp_example.py index 8615d44e27..e185fb08db 100644 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -14,10 +14,9 @@ from torch.distributed._shard.sharded_tensor import ShardedTensor -from torch.distributed._tensor import DTensor, Replicate, sharding_prop from torch.distributed._tensor.device_mesh import init_device_mesh import os - +import logging """ @@ -81,37 +80,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ tp_size = 2 +logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logger = logging.getLogger(__name__) + # understand world topology _rank = int(os.environ["RANK"]) -_local_rank = int(os.environ["LOCAL_RANK"]) _world_size = int(os.environ["WORLD_SIZE"]) -_local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) - -def rank_print(msg): +# +def rank_log(msg): """helper function to print only on global rank 0""" if _rank==0: - print(f"{msg}") + logger.info(f" {msg}") -print(f"Running basic Megatron style TP example on rank {_rank}.") +print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") assert ( _world_size % tp_size == 0 ), f"World size {_world_size} needs to be divisible by TP size {tp_size}" -_device = f"cuda" - # create a sharding plan based on the given world_size. - dp_size = _world_size // tp_size # Create a device mesh with 2 dimensions. # First dim is the data parallel dimension # Second dim is the tensor parallel dimension. 
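Before the device mesh changes below, one aside on the MLP_swiglu model this example parallelizes: with the default mlp_dim of 1024, the hidden size works out as in the worked example here. This only spells out the find_multiple arithmetic already shown above, with no new behavior.

# Worked example of the SwiGLU sizing; mlp_dim=1024 is the default in MLP_swiglu.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

mlp_dim = 1024
hidden_dim = 4 * mlp_dim                             # 4096
scaled_hidden = int(2 * hidden_dim / 3)              # 2730
rounded_hidden = find_multiple(scaled_hidden, 256)   # 2816, the next multiple of 256
print(hidden_dim, scaled_hidden, rounded_hidden)
# in_proj and gate_proj are therefore (1024 -> 2816) and out_proj is (2816 -> 1024).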
-device_mesh = init_device_mesh(_device, (dp_size, tp_size), mesh_dim_names=("dp","tp")) +device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp","tp")) -rank_print(f"Device Mesh created: {device_mesh=}") +rank_log(f"Device Mesh created: {device_mesh=}") tp_mesh = device_mesh["tp"] dp_mesh = device_mesh["dp"] @@ -127,7 +124,7 @@ def rank_print(msg): # create model and move it to GPU with id rank _mlp_dim = 1024 -base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to(_device) +base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda") # Custom parallelization plan for the swiglu MLP model @@ -140,31 +137,31 @@ def rank_print(msg): }, ) -rank_print(f"Model after parallelization {custom_tp_model=}\n") +rank_log(f"Model after parallelization {custom_tp_model=}\n") # Init FSDP using the dp device mesh sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) # Create an optimizer for the parallelized and sharded model. lr = 3e-3 -rank_print(f"Creating AdamW optimizer with learning rate {lr}") +rank_log(f"Creating AdamW optimizer with learning rate {lr}") optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True) # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. -rank_print(f"\nStarting 2D training...") +rank_log(f"\nStarting 2D training...") num_iterations = 10 batch_size = 2 for i in range(num_iterations): # seeding with dp_rank to ensure identical inputs for TP groups torch.manual_seed(i + dp_rank) - inp = torch.rand(batch_size, _mlp_dim, device=_device) + inp = torch.rand(batch_size, _mlp_dim, device="cuda") output = sharded_model(inp) output.sum().backward() optimizer.step() - rank_print(f"2D iter {i} complete") + rank_log(f"2D iter {i} complete") -rank_print(f"2D training successfully completed!") +rank_log(f"2D training successfully completed!") diff --git a/distributed/tensor_parallelism/original.py b/distributed/tensor_parallelism/original.py deleted file mode 100644 index 5c28db5adf..0000000000 --- a/distributed/tensor_parallelism/original.py +++ /dev/null @@ -1,127 +0,0 @@ -import argparse - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from torch.distributed._tensor import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.tensor.parallel import ( - PairwiseParallel, - parallelize_module, -) -from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp - -from utils import cleanup, setup, ToyModel -try: - from torch.distributed.tensor.parallel import ( - SequenceParallel - ) - SP_AVAILABLE = True -except BaseException as e: - pass - - -""" -This is the script to test 2D Parallel which combines Tensor/Sequence -parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model -in the SPMD style. We show an E2E working flow from forward, backward -and optimization. - -We enabled Fully Sharded Data Parallel + Tensor Parallel in -separate parallel dimensions: - Data Parallel across hosts - Tensor Parallel within each host - - We use a simple diagram to illustrate below: - -====================================================================== ------------- ------------ ------------ ------------ -| Host 1 | | Host 2 | | | | Host N | -| 8 GPUs | | 8 GPUs | | | | 8 GPUs | -| | | | | ... 
| | | -| (TP) | | (TP) | | | | (TP) | -|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7| -| | | | | | | .., 8N-1]| -| | | | | | | | ------------- ------------ ------------ ------------ -FSDP: -[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1] -====================================================================== - -More details can be seen in the slide: -https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ -""" - - -def demo_2d(rank, args): - """ - Main body of the demo of a basic version of tensor parallel by using - PyTorch native APIs. - """ - print(f"Running basic Megatron style TP example on rank {rank}.") - setup(rank, args.world_size) - assert ( - args.world_size % args.tp_size == 0 - ), "World size needs to be divisible by TP size" - - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh( - "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size) - ) - - # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) - # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1) - - # We need to register hooks for TP + FSDP integration. - assert ( - enable_2d_with_fsdp() - ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0" - dp_pg = device_mesh.get_dim_groups()[0] - model = FSDP(model, process_group=dp_pg) - - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - for i in range(args.iter_nums): - # For TP, input needs to be same across all TP ranks. - # while for SP, input can be different across all ranks. - # Setting the random seed is to mimic the behavior of dataloader. - dp_rank = ( - rank - if args.run_seq_parallel - else dist.get_rank(dp_pg) - ) - torch.manual_seed(i + dp_rank) - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) - output.sum().backward() - optimizer.step() - - cleanup() - - -if __name__ == "__main__": - n_gpus = torch.cuda.device_count() - parser = argparse.ArgumentParser() - # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) - parser.add_argument("--run_seq_parallel", type=bool, default=False) - parser.add_argument("--tp_size", type=int, default=2) - args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 4: - print("Requires at least 4 GPUs to run.") - elif not SP_AVAILABLE: - print( - "PyTorch doesn't have Sequence Parallelism available," - " need nightly build." - ) - else: - mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) diff --git a/distributed/tensor_parallelism/run_example.sh b/distributed/tensor_parallelism/run_example.sh index aebe77d985..c8d431505b 100644 --- a/distributed/tensor_parallelism/run_example.sh +++ b/distributed/tensor_parallelism/run_example.sh @@ -1,13 +1,13 @@ # To run samples: -# bash run_example.sh file_to_run.py num_gpus -# where file_to_run = example to launch. Default = 'two_d_parallel_example.py' -# num_gpus = num local gpus to use (must be at least 2). Default =4 +# bash run_example.sh {file_to_run.py} {num_gpus} +# where file_to_run = example to launch. 
Default = 'fsdp_tp_example.py' +# num_gpus = num local gpus to use (must be at least 2). Default = 4 # samples to run include: # sequence_parallel_example.py # tensor_parallel_example.py -# two_d_parallel_example.py +# fsdp_tp_example.py -echo "Launching ${1:-two_d_parallel_example.py} with ${2:-4} gpus" -torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-two_d_parallel_example.py} +echo "Launching ${1:-fsdp_tp_example.py} with ${2:-4} gpus" +torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-fsdp_tp_example.py} diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index eda3231a93..d90d36ee2c 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -2,9 +2,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +import logging from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed._tensor import DeviceMesh, Shard from torch.distributed.tensor.parallel import ( parallelize_module, @@ -13,13 +14,8 @@ ) -try: - from torch.distributed.tensor.parallel import ( - SequenceParallel - ) - SP_AVAILABLE = True -except BaseException as e: - pass +logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logger = logging.getLogger(__name__) """ @@ -38,7 +34,6 @@ in the end of the second linear layer. """ - class ToyModel(nn.Module): """ MLP based model """ def __init__(self): @@ -56,25 +51,22 @@ def forward(self, x): PyTorch native APIs. """ -_rank = int(os.environ["RANK"]) +# create a device mesh based on the given world_size. +device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (int(os.environ["WORLD_SIZE"]),)) +_rank = device_mesh.get_rank() -def rank_print(msg): - """helper function to print only on global rank 0""" +def rank_log(msg): + """helper function to log only on global rank 0""" if _rank==0: - print(f"{msg}") - -print(f"Running basic Megatron style Sequence Parallel example on rank {_rank}.") - -# create a device mesh based on the given world_size. -_device = f"cuda" -device_mesh = init_device_mesh(device_type = _device,mesh_shape = (int(os.environ["WORLD_SIZE"]),)) + logger.info(f" {msg}") -rank_print(f"Device Mesh created: {device_mesh=}") +print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.") +rank_log(f"Device Mesh created: {device_mesh=}") # create model and move it to GPU. Init_device_mesh has already assigned gpu ids... -model = ToyModel().to(_device) +model = ToyModel().to("cuda") # Custom parallelization plan for the model sp_model = parallelize_module(module = model, @@ -94,14 +86,14 @@ def rank_print(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_print(f"Sequence Parallel training starting...") +rank_log(f"Sequence Parallel training starting...") for i in range(num_iters): # For SP, input can be different across all ranks. 
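That per-rank freedom is exactly what the Shard(0) input and output layouts buy in the sequence parallel plan: each rank feeds only its slice of dimension 0 and receives only the matching slice of the output back. The sketch below is a math-only, single-process illustration using an nn.Sequential stand-in with the same 10 -> 32 -> 5 shapes as ToyModel; the real example does this with DTensor redistributions (one all-gather on the way in, one reduce-scatter on the way out).

# Math-only sketch of the Shard(0) layout; world_size=2 and the layer shapes are
# assumed from the ToyModel above. No communication or DTensor is involved here.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))

world_size = 2
per_rank_inputs = [torch.rand(20, 10) for _ in range(world_size)]   # different per rank

full_input = torch.cat(per_rank_inputs, dim=0)       # what the input all-gather forms
full_output = model(full_input)
per_rank_outputs = full_output.chunk(world_size, 0)  # each rank keeps its own slice

for r, out in enumerate(per_rank_outputs):
    print(f"rank {r} keeps output shard of shape {tuple(out.shape)}")  # (20, 5)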
- inp = torch.rand(20, 10,device=_device) + inp = torch.rand(20, 10,device="cuda") output = sp_model(inp) output.sum().backward() optimizer.step() - rank_print(f"Sequence Parallel iter {i} completed") + rank_log(f"Sequence Parallel iter {i} completed") -rank_print(f"Sequence Parallel training completed!") +rank_log(f"Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 99a0014782..c57b3bbfc5 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,11 +1,10 @@ - import os import torch import torch.nn as nn import torch.nn.functional as F from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed._tensor import DeviceMesh from torch.distributed.tensor.parallel import ( parallelize_module, @@ -13,7 +12,10 @@ RowwiseParallel, ) +import logging +logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logger = logging.getLogger(__name__) """ @@ -64,34 +66,24 @@ def forward(self, x): PyTorch native APIs. """ -# understand world topology -_rank = int(os.environ["RANK"]) -_local_rank = int(os.environ["LOCAL_RANK"]) +# create a device mesh based on the given world_size. _world_size = int(os.environ["WORLD_SIZE"]) -#_local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) +device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (_world_size,)) +_rank = device_mesh.get_rank() - -def rank_print(msg): +def rank_log(msg): """helper function to print only on global rank 0""" if _rank==0: - print(f"{msg}") + logger.info(f" {msg}") -print(f"Running basic Megatron style TP example on rank {_rank}.") +print(f"Starting PyTorch TP example on rank {_rank}.") assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus" +rank_log(f"Device Mesh created: {device_mesh=}") - -# create a device mesh based on the given world_size. - -_device = f"cuda" -device_mesh = init_device_mesh(device_type = _device,mesh_shape = (_world_size,)) -assert device_mesh is not None, "unable to create valid device mesh" - -rank_print(f"Device Mesh created: {device_mesh=}") - -# create model and move it to GPU - init_device_mesh has already mapped GPU ids. -tp_model = ToyModel().to(_device) +# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. +tp_model = ToyModel().to("cuda") # Create an optimizer for the parallelized module. lr = 0.25 @@ -108,16 +100,16 @@ def rank_print(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_print(f"Tensor Parallel training starting...") +rank_log(f"Tensor Parallel training starting...") for i in range(num_iters): # For TP, input needs to be same across all TP ranks. # Setting the random seed is to mimic the behavior of dataloader. 
torch.manual_seed(i) - inp = torch.rand(20, 10, device=_device) + inp = torch.rand(20, 10, device="cuda") output = tp_model(inp) output.sum().backward() optimizer.step() - rank_print(f"Tensor Parallel iter {i} completed") + rank_log(f"Tensor Parallel iter {i} completed") -rank_print(f"Tensor Parallel training completed!") +rank_log(f"Tensor Parallel training completed!") From 2f4a08307c687edd7e0cc218b3a378095a9407ec Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 20:01:36 -0800 Subject: [PATCH 14/20] code linting via ruff --- distributed/tensor_parallelism/fsdp_tp_example.py | 6 ++---- .../tensor_parallelism/sequence_parallel_example.py | 7 +++---- distributed/tensor_parallelism/tensor_parallel_example.py | 6 ++---- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index e185fb08db..1017a7051e 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor import DeviceMesh from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.tensor.parallel import ( parallelize_module, @@ -13,7 +12,6 @@ ) -from torch.distributed._shard.sharded_tensor import ShardedTensor from torch.distributed._tensor.device_mesh import init_device_mesh import os import logging @@ -150,7 +148,7 @@ def rank_log(msg): # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. -rank_log(f"\nStarting 2D training...") +rank_log("\nStarting 2D training...") num_iterations = 10 batch_size = 2 @@ -164,4 +162,4 @@ def rank_log(msg): optimizer.step() rank_log(f"2D iter {i} complete") -rank_log(f"2D training successfully completed!") +rank_log("2D training successfully completed!") diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index d90d36ee2c..203d2afa05 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,11 +1,10 @@ import os import torch import torch.nn as nn -import torch.nn.functional as F import logging from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed._tensor import DeviceMesh, Shard +from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( parallelize_module, @@ -86,7 +85,7 @@ def rank_log(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log(f"Sequence Parallel training starting...") +rank_log("Sequence Parallel training starting...") for i in range(num_iters): # For SP, input can be different across all ranks. 
@@ -96,4 +95,4 @@ def rank_log(msg): optimizer.step() rank_log(f"Sequence Parallel iter {i} completed") -rank_log(f"Sequence Parallel training completed!") +rank_log("Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index c57b3bbfc5..1b2bc073e0 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,10 +1,8 @@ import os import torch import torch.nn as nn -import torch.nn.functional as F from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed._tensor import DeviceMesh from torch.distributed.tensor.parallel import ( parallelize_module, @@ -100,7 +98,7 @@ def rank_log(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log(f"Tensor Parallel training starting...") +rank_log("Tensor Parallel training starting...") for i in range(num_iters): # For TP, input needs to be same across all TP ranks. @@ -112,4 +110,4 @@ def rank_log(msg): optimizer.step() rank_log(f"Tensor Parallel iter {i} completed") -rank_log(f"Tensor Parallel training completed!") +rank_log("Tensor Parallel training completed!") From 742966ba1cbec46081f913f1697d24d376c3248e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 20:03:58 -0800 Subject: [PATCH 15/20] code formatting via ruff --- .../tensor_parallelism/fsdp_tp_example.py | 37 +++++++++++-------- .../sequence_parallel_example.py | 31 ++++++++++------ .../tensor_parallel_example.py | 31 ++++++++++------ 3 files changed, 62 insertions(+), 37 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 1017a7051e..adb134dd06 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -1,4 +1,3 @@ - import torch import torch.distributed as dist import torch.nn as nn @@ -47,17 +46,19 @@ More details can be seen in the slide: https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ """ + + def find_multiple(n: int, k: int) -> int: - """ function to find resizing multiple for SwiGLU MLP """ + """function to find resizing multiple for SwiGLU MLP""" if n % k == 0: return n return n + k - (n % k) class MLP_swiglu(nn.Module): - """ SwiGLU to showcase a Llama style MLP model """ + """SwiGLU to showcase a Llama style MLP model""" - def __init__(self, mlp_dim: int= 1024) -> None: + def __init__(self, mlp_dim: int = 1024) -> None: super().__init__() hidden_dim = 4 * mlp_dim scaled_hidden = int(2 * hidden_dim / 3) @@ -72,13 +73,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.out_proj(x) return x + """ Main body of the demo of a basic version of tensor parallel by using PyTorch native APIs. 
""" tp_size = 2 -logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) logger = logging.getLogger(__name__) @@ -86,12 +90,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: _rank = int(os.environ["RANK"]) _world_size = int(os.environ["WORLD_SIZE"]) + # def rank_log(msg): """helper function to print only on global rank 0""" - if _rank==0: + if _rank == 0: logger.info(f" {msg}") + print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") assert ( _world_size % tp_size == 0 @@ -104,7 +110,7 @@ def rank_log(msg): # Create a device mesh with 2 dimensions. # First dim is the data parallel dimension # Second dim is the tensor parallel dimension. -device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp","tp")) +device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) rank_log(f"Device Mesh created: {device_mesh=}") tp_mesh = device_mesh["tp"] @@ -126,19 +132,20 @@ def rank_log(msg): # Custom parallelization plan for the swiglu MLP model -custom_tp_model = parallelize_module(module = base_model_swiglu, - device_mesh = tp_mesh, - parallelize_plan = { - "in_proj": ColwiseParallel(), - "gate_proj": ColwiseParallel(), - "out_proj": RowwiseParallel(), - }, +custom_tp_model = parallelize_module( + module=base_model_swiglu, + device_mesh=tp_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(), + "gate_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, ) rank_log(f"Model after parallelization {custom_tp_model=}\n") # Init FSDP using the dp device mesh -sharded_model = FSDP(custom_tp_model, device_mesh = dp_mesh, use_orig_params=True) +sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True) # Create an optimizer for the parallelized and sharded model. lr = 3e-3 diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 203d2afa05..069e981b74 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -13,7 +13,9 @@ ) -logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) logger = logging.getLogger(__name__) @@ -33,8 +35,10 @@ in the end of the second linear layer. """ + class ToyModel(nn.Module): - """ MLP based model """ + """MLP based model""" + def __init__(self): super().__init__() self.in_proj = nn.Linear(10, 32) @@ -51,15 +55,19 @@ def forward(self, x): """ # create a device mesh based on the given world_size. 
-device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (int(os.environ["WORLD_SIZE"]),)) +device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) +) _rank = device_mesh.get_rank() + def rank_log(msg): """helper function to log only on global rank 0""" - if _rank==0: + if _rank == 0: logger.info(f" {msg}") + print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.") rank_log(f"Device Mesh created: {device_mesh=}") @@ -68,12 +76,13 @@ def rank_log(msg): model = ToyModel().to("cuda") # Custom parallelization plan for the model -sp_model = parallelize_module(module = model, - device_mesh = device_mesh, - parallelize_plan = { - "in_proj": ColwiseParallel(input_layouts=Shard(0)), - "out_proj": RowwiseParallel(output_layouts=Shard(0)), - }, +sp_model = parallelize_module( + module=model, + device_mesh=device_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(input_layouts=Shard(0)), + "out_proj": RowwiseParallel(output_layouts=Shard(0)), + }, ) @@ -89,7 +98,7 @@ def rank_log(msg): for i in range(num_iters): # For SP, input can be different across all ranks. - inp = torch.rand(20, 10,device="cuda") + inp = torch.rand(20, 10, device="cuda") output = sp_model(inp) output.sum().backward() optimizer.step() diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 1b2bc073e0..cdf70f3799 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -12,7 +12,9 @@ import logging -logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) logger = logging.getLogger(__name__) @@ -49,7 +51,8 @@ class ToyModel(nn.Module): - """ MLP based model """ + """MLP based model""" + def __init__(self): super(ToyModel, self).__init__() self.in_proj = nn.Linear(10, 32) @@ -59,6 +62,7 @@ def __init__(self): def forward(self, x): return self.out_proj(self.relu(self.in_proj(x))) + """ Main body of the demo of a basic version of tensor parallel by using PyTorch native APIs. @@ -67,16 +71,20 @@ def forward(self, x): # create a device mesh based on the given world_size. 
_world_size = int(os.environ["WORLD_SIZE"]) -device_mesh = init_device_mesh(device_type = "cuda",mesh_shape = (_world_size,)) +device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) _rank = device_mesh.get_rank() + def rank_log(msg): """helper function to print only on global rank 0""" - if _rank==0: + if _rank == 0: logger.info(f" {msg}") + print(f"Starting PyTorch TP example on rank {_rank}.") -assert _world_size % 2 == 0, f"TP examples require even number of GPUs, but got {_world_size} gpus" +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" rank_log(f"Device Mesh created: {device_mesh=}") @@ -88,12 +96,13 @@ def rank_log(msg): optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True) # Custom parallelization plan for the model -tp_model = parallelize_module(module = tp_model, - device_mesh = device_mesh, - parallelize_plan = { - "in_proj": ColwiseParallel(), - "out_proj": RowwiseParallel(), - }, +tp_model = parallelize_module( + module=tp_model, + device_mesh=device_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, ) # Perform a num of iterations of forward/backward # and optimizations for the sharded module. From 7da71bcdbc1adba0d54535a4efd6c90f0567b746 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 21 Nov 2023 20:31:25 -0800 Subject: [PATCH 16/20] move rank_log to utils.py, update example files --- .../tensor_parallelism/fsdp_tp_example.py | 21 +++++++------------ .../sequence_parallel_example.py | 16 +++++--------- .../tensor_parallel_example.py | 15 +++++-------- distributed/tensor_parallelism/utils.py | 6 ++++++ 4 files changed, 23 insertions(+), 35 deletions(-) create mode 100644 distributed/tensor_parallelism/utils.py diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index adb134dd06..4d0f6f0ab1 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -14,7 +14,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os import logging - +from utils import rank_log """ This is the script to test 2D Parallel which combines Tensor/Sequence @@ -91,13 +91,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: _world_size = int(os.environ["WORLD_SIZE"]) -# -def rank_log(msg): - """helper function to print only on global rank 0""" - if _rank == 0: - logger.info(f" {msg}") - - print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") assert ( _world_size % tp_size == 0 @@ -112,7 +105,7 @@ def rank_log(msg): # Second dim is the tensor parallel dimension. device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) -rank_log(f"Device Mesh created: {device_mesh=}") +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") tp_mesh = device_mesh["tp"] dp_mesh = device_mesh["dp"] @@ -142,20 +135,20 @@ def rank_log(msg): }, ) -rank_log(f"Model after parallelization {custom_tp_model=}\n") +rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n") # Init FSDP using the dp device mesh sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True) # Create an optimizer for the parallelized and sharded model. 
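The ordering above is the heart of the 2D composition: parallelize_module is applied first on the "tp" sub-mesh, and FSDP then wraps the TP-parallelized module over the "dp" sub-mesh with use_orig_params=True. A minimal helper capturing that order, assuming only a 2-D mesh named ("dp", "tp") and a plan keyed by the module's own submodule names (apply_fsdp_tp is a made-up name for this sketch):

# Sketch of the wrapping order used above; apply_fsdp_tp is a hypothetical helper.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import parallelize_module


def apply_fsdp_tp(model, mesh_2d, tp_plan):
    # 1) shard parameters across each TP group first
    tp_model = parallelize_module(model, mesh_2d["tp"], tp_plan)
    # 2) then shard and sync across the data parallel dimension
    return FSDP(tp_model, device_mesh=mesh_2d["dp"], use_orig_params=True)

The AdamW optimizer below is then constructed over the wrapped model's parameters.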
lr = 3e-3 -rank_log(f"Creating AdamW optimizer with learning rate {lr}") +rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}") optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True) # Training loop: # Perform a num of iterations of forward/backward # and optimizations for the sharded module. -rank_log("\nStarting 2D training...") +rank_log(_rank, logger, "\nStarting 2D training...") num_iterations = 10 batch_size = 2 @@ -167,6 +160,6 @@ def rank_log(msg): output = sharded_model(inp) output.sum().backward() optimizer.step() - rank_log(f"2D iter {i} complete") + rank_log(_rank, logger, f"2D iter {i} complete") -rank_log("2D training successfully completed!") +rank_log(_rank, logger, "2D training successfully completed!") diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 069e981b74..309729f83f 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -12,6 +12,7 @@ RowwiseParallel, ) +from utils import rank_log logging.basicConfig( format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO @@ -61,16 +62,9 @@ def forward(self, x): _rank = device_mesh.get_rank() - -def rank_log(msg): - """helper function to log only on global rank 0""" - if _rank == 0: - logger.info(f" {msg}") - - print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.") -rank_log(f"Device Mesh created: {device_mesh=}") +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") # create model and move it to GPU. Init_device_mesh has already assigned gpu ids... model = ToyModel().to("cuda") @@ -94,7 +88,7 @@ def rank_log(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log("Sequence Parallel training starting...") +rank_log(_rank, logger, "Sequence Parallel training starting...") for i in range(num_iters): # For SP, input can be different across all ranks. @@ -102,6 +96,6 @@ def rank_log(msg): output = sp_model(inp) output.sum().backward() optimizer.step() - rank_log(f"Sequence Parallel iter {i} completed") + rank_log(_rank, logger, f"Sequence Parallel iter {i} completed") -rank_log("Sequence Parallel training completed!") +rank_log(_rank, logger, "Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index cdf70f3799..d3b0bb6448 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -11,6 +11,7 @@ ) import logging +from utils import rank_log logging.basicConfig( format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO @@ -75,18 +76,12 @@ def forward(self, x): _rank = device_mesh.get_rank() -def rank_log(msg): - """helper function to print only on global rank 0""" - if _rank == 0: - logger.info(f" {msg}") - - print(f"Starting PyTorch TP example on rank {_rank}.") assert ( _world_size % 2 == 0 ), f"TP examples require even number of GPUs, but got {_world_size} gpus" -rank_log(f"Device Mesh created: {device_mesh=}") +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. 
tp_model = ToyModel().to("cuda") @@ -107,7 +102,7 @@ def rank_log(msg): # Perform a num of iterations of forward/backward # and optimizations for the sharded module. num_iters = 10 -rank_log("Tensor Parallel training starting...") +rank_log(_rank, logger, "Tensor Parallel training starting...") for i in range(num_iters): # For TP, input needs to be same across all TP ranks. @@ -117,6 +112,6 @@ def rank_log(msg): output = tp_model(inp) output.sum().backward() optimizer.step() - rank_log(f"Tensor Parallel iter {i} completed") + rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") -rank_log("Tensor Parallel training completed!") +rank_log(_rank, logger, "Tensor Parallel training completed!") diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py new file mode 100644 index 0000000000..3a25aa11c5 --- /dev/null +++ b/distributed/tensor_parallelism/utils.py @@ -0,0 +1,6 @@ + + +def rank_log(_rank, logger, msg): + """helper function to log only on global rank 0""" + if _rank == 0: + logger.info(f" {msg}") From 836f7986526c41c0b3bd466e626f29aca6b47ce5 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 22 Nov 2023 12:18:35 -0800 Subject: [PATCH 17/20] move logging imports and config to log_utils, update examples with new import --- distributed/tensor_parallelism/fsdp_tp_example.py | 10 ++-------- distributed/tensor_parallelism/log_utils.py | 14 ++++++++++++++ .../sequence_parallel_example.py | 9 ++------- .../tensor_parallelism/tensor_parallel_example.py | 8 ++------ distributed/tensor_parallelism/utils.py | 6 ------ 5 files changed, 20 insertions(+), 27 deletions(-) create mode 100644 distributed/tensor_parallelism/log_utils.py delete mode 100644 distributed/tensor_parallelism/utils.py diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 4d0f6f0ab1..92828df5cf 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -13,8 +13,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -import logging -from utils import rank_log +from log_utils import rank_log, get_logger """ This is the script to test 2D Parallel which combines Tensor/Sequence @@ -79,12 +78,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: PyTorch native APIs. 
""" tp_size = 2 - -logging.basicConfig( - format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO -) -logger = logging.getLogger(__name__) - +logger = get_logger() # understand world topology _rank = int(os.environ["RANK"]) diff --git a/distributed/tensor_parallelism/log_utils.py b/distributed/tensor_parallelism/log_utils.py new file mode 100644 index 0000000000..611b25a412 --- /dev/null +++ b/distributed/tensor_parallelism/log_utils.py @@ -0,0 +1,14 @@ +import logging + +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) + +def get_logger(): + return logging.getLogger(__name__) + + +def rank_log(_rank, logger, msg): + """helper function to log only on global rank 0""" + if _rank == 0: + logger.info(f" {msg}") diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 309729f83f..aa943a7304 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,7 +1,6 @@ import os import torch import torch.nn as nn -import logging from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed._tensor import Shard @@ -12,12 +11,7 @@ RowwiseParallel, ) -from utils import rank_log - -logging.basicConfig( - format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO -) -logger = logging.getLogger(__name__) +from log_utils import rank_log, get_logger """ @@ -54,6 +48,7 @@ def forward(self, x): Main body of the demo of a basic version of sequence parallel by using PyTorch native APIs. """ +logger = get_logger() # create a device mesh based on the given world_size. device_mesh = init_device_mesh( diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index d3b0bb6448..f25252b998 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -10,13 +10,8 @@ RowwiseParallel, ) -import logging -from utils import rank_log -logging.basicConfig( - format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO -) -logger = logging.getLogger(__name__) +from log_utils import rank_log, get_logger """ @@ -68,6 +63,7 @@ def forward(self, x): Main body of the demo of a basic version of tensor parallel by using PyTorch native APIs. """ +logger = get_logger() # create a device mesh based on the given world_size. 
_world_size = int(os.environ["WORLD_SIZE"]) diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py deleted file mode 100644 index 3a25aa11c5..0000000000 --- a/distributed/tensor_parallelism/utils.py +++ /dev/null @@ -1,6 +0,0 @@ - - -def rank_log(_rank, logger, msg): - """helper function to log only on global rank 0""" - if _rank == 0: - logger.info(f" {msg}") From 2de01443d6cc7fabece0c4354d1a0ebebd8cc81d Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 22 Nov 2023 13:43:10 -0800 Subject: [PATCH 18/20] add gpu verification, update run_python_examples.sh --- distributed/tensor_parallelism/fsdp_tp_example.py | 9 ++++++++- distributed/tensor_parallelism/log_utils.py | 8 ++++++++ .../tensor_parallelism/sequence_parallel_example.py | 8 +++++++- .../tensor_parallelism/tensor_parallel_example.py | 10 +++++++++- run_python_examples.sh | 9 ++++----- 5 files changed, 36 insertions(+), 8 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 92828df5cf..9c32d1038b 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -1,3 +1,4 @@ +import sys import torch import torch.distributed as dist import torch.nn as nn @@ -13,7 +14,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh import os -from log_utils import rank_log, get_logger +from log_utils import rank_log, get_logger, verify_min_gpu_count """ This is the script to test 2D Parallel which combines Tensor/Sequence @@ -46,6 +47,12 @@ https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ """ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit(0) + def find_multiple(n: int, k: int) -> int: """function to find resizing multiple for SwiGLU MLP""" diff --git a/distributed/tensor_parallelism/log_utils.py b/distributed/tensor_parallelism/log_utils.py index 611b25a412..f16d46526d 100644 --- a/distributed/tensor_parallelism/log_utils.py +++ b/distributed/tensor_parallelism/log_utils.py @@ -1,4 +1,5 @@ import logging +import torch logging.basicConfig( format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO @@ -12,3 +13,10 @@ def rank_log(_rank, logger, msg): """helper function to log only on global rank 0""" if _rank == 0: logger.info(f" {msg}") + + +def verify_min_gpu_count(min_gpus: int = 2) -> bool: + """ verification that we have at least 2 gpus to run dist examples """ + has_cuda = torch.cuda.is_available() + gpu_count = torch.cuda.device_count() + return has_cuda and gpu_count >= min_gpus diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index aa943a7304..6a9de413bb 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,4 +1,5 @@ import os +import sys import torch import torch.nn as nn @@ -11,7 +12,7 @@ RowwiseParallel, ) -from log_utils import rank_log, get_logger +from log_utils import rank_log, get_logger, verify_min_gpu_count """ @@ -29,6 +30,11 @@ now is different so that we need one all-gather for input and one reduce-scatter in the end of the second linear layer. 
""" +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit(0) class ToyModel(nn.Module): diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index f25252b998..bc8325d5d7 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,4 +1,5 @@ import os +import sys import torch import torch.nn as nn @@ -11,7 +12,9 @@ ) -from log_utils import rank_log, get_logger +from log_utils import rank_log, get_logger, verify_min_gpu_count + + """ @@ -45,6 +48,11 @@ Parallelism APIs in this example to show users how to use them. """ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit(0) class ToyModel(nn.Module): """MLP based model""" diff --git a/run_python_examples.sh b/run_python_examples.sh index 1b45a281cf..c933cf8f65 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -63,8 +63,8 @@ function distributed() { start python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed" python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed" - python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed" - python ddp/main.py || error "ddp example failed" + python tensor_parallelism/fsdp_tp_parallel_example.py || error "2D parallel example failed" + python ddp/main.py || error "ddp example failed" } function fast_neural_style() { @@ -96,7 +96,7 @@ function mnist() { python main.py --epochs 1 --dry-run || error "mnist example failed" } function mnist_forward_forward() { - start + start python main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed" } @@ -212,9 +212,8 @@ function clean() { function run_all() { # cpp dcgan - # distributed - fast_neural_style distributed + fast_neural_style imagenet mnist mnist_forward_forward From 77fe3d8ae1f463dd46d9af35d112011ca787f60b Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 22 Nov 2023 13:56:35 -0800 Subject: [PATCH 19/20] update min gpu = 4 for fsdp+tp --- distributed/tensor_parallelism/fsdp_tp_example.py | 4 ++-- distributed/tensor_parallelism/sequence_parallel_example.py | 2 +- distributed/tensor_parallelism/tensor_parallel_example.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 9c32d1038b..0c52aaabbb 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -47,11 +47,11 @@ https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ """ -_min_gpu_count = 2 +_min_gpu_count = 4 if not verify_min_gpu_count(min_gpus=_min_gpu_count): print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. 
Exiting.") - sys.exit(0) + sys.exit() def find_multiple(n: int, k: int) -> int: diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 6a9de413bb..ea1c76bf41 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -34,7 +34,7 @@ if not verify_min_gpu_count(min_gpus=_min_gpu_count): print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit(0) + sys.exit() class ToyModel(nn.Module): diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index bc8325d5d7..91f0625a69 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -52,7 +52,7 @@ if not verify_min_gpu_count(min_gpus=_min_gpu_count): print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit(0) + sys.exit() class ToyModel(nn.Module): """MLP based model""" From 5f4a5d33b5f98a6daf2905f221fee057537edf94 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 22 Nov 2023 14:20:53 -0800 Subject: [PATCH 20/20] move gpu check to top of examples, but before import init_device_mesh to clear CI --- .../tensor_parallelism/fsdp_tp_example.py | 20 +++++++++++-------- .../sequence_parallel_example.py | 19 ++++++++++++------ .../tensor_parallel_example.py | 19 +++++++++--------- run_python_examples.sh | 2 +- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py index 0c52aaabbb..bccd811d82 100644 --- a/distributed/tensor_parallelism/fsdp_tp_example.py +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -11,11 +11,21 @@ RowwiseParallel, ) - -from torch.distributed._tensor.device_mesh import init_device_mesh import os from log_utils import rank_log, get_logger, verify_min_gpu_count + +# ---- GPU check ------------ +_min_gpu_count = 4 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from torch.distributed._tensor.device_mesh import init_device_mesh + + """ This is the script to test 2D Parallel which combines Tensor/Sequence parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model @@ -47,12 +57,6 @@ https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ """ -_min_gpu_count = 4 - -if not verify_min_gpu_count(min_gpus=_min_gpu_count): - print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. 
Exiting.") - sys.exit() - def find_multiple(n: int, k: int) -> int: """function to find resizing multiple for SwiGLU MLP""" diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index ea1c76bf41..3324d28d4a 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn -from torch.distributed._tensor.device_mesh import init_device_mesh from torch.distributed._tensor import Shard from torch.distributed.tensor.parallel import ( @@ -15,6 +14,19 @@ from log_utils import rank_log, get_logger, verify_min_gpu_count +# ---- GPU check ------------ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + + +from torch.distributed._tensor.device_mesh import init_device_mesh + + + """ This is the script to test Sequence Parallel(SP) on a toy model in a Megetron-LM SPMD style. We show an E2E working flow from forward, @@ -30,11 +42,6 @@ now is different so that we need one all-gather for input and one reduce-scatter in the end of the second linear layer. """ -_min_gpu_count = 2 - -if not verify_min_gpu_count(min_gpus=_min_gpu_count): - print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit() class ToyModel(nn.Module): diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 91f0625a69..2731e8046b 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -3,17 +3,24 @@ import torch import torch.nn as nn -from torch.distributed._tensor.device_mesh import init_device_mesh - from torch.distributed.tensor.parallel import ( parallelize_module, ColwiseParallel, RowwiseParallel, ) - from log_utils import rank_log, get_logger, verify_min_gpu_count +# ---- GPU check ------------ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from torch.distributed._tensor.device_mesh import init_device_mesh + @@ -48,12 +55,6 @@ Parallelism APIs in this example to show users how to use them. """ -_min_gpu_count = 2 - -if not verify_min_gpu_count(min_gpus=_min_gpu_count): - print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") - sys.exit() - class ToyModel(nn.Module): """MLP based model""" diff --git a/run_python_examples.sh b/run_python_examples.sh index c933cf8f65..a9ff393e80 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -63,7 +63,7 @@ function distributed() { start python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed" python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed" - python tensor_parallelism/fsdp_tp_parallel_example.py || error "2D parallel example failed" + python tensor_parallelism/fsdp_tp_example.py || error "2D parallel example failed" python ddp/main.py || error "ddp example failed" }