diff --git a/python/example/mixtral_4D_benchmark/README.md b/python/example/mixtral_4D_benchmark/README.md
new file mode 100644
index 0000000..55e239e
--- /dev/null
+++ b/python/example/mixtral_4D_benchmark/README.md
@@ -0,0 +1,34 @@
+# veScale Mixtral Example
+
+## Overview
+
+This directory provides a 4D parallelism example of using veScale to run
+a [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model that is imported directly
+from HuggingFace without any model code modifications.
+
+
+## Run
+
+### Single Machine (8 cards)
+```
+torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- python/example/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
+```
+This starts an 8-card MFU benchmark for Mixtral with veScale, using dp=1 and tp=8.
+
+### Distributed Environment (2 machines, 16 cards)
+```
+# Launch this command on each of the two machines; set the rendezvous / master-address options required by your cluster.
+torchrun --nproc-per-node=8 --nnodes=2 python/example/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
+```
+This starts a 16-card MFU benchmark for Mixtral with veScale, using dp=2 and tp=8.
+
+### Options
+1. `--bsz`: the global batch size for one iteration. The default is 16.
+2. `--seqlen`: the sequence length of the input. The default is 256.
+3. `--dp`: the degree of data parallelism (DDP). The default is 1.
+4. `--tp`: the degree of tensor parallelism. The default is 8.
+
+
+## Caveats
+1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
+   in order to fine-tune Mixtral with your data.
diff --git a/python/example/mixtral_4D_benchmark/mixtral_train.py b/python/example/mixtral_4D_benchmark/mixtral_train.py
new file mode 100644
index 0000000..91e94eb
--- /dev/null
+++ b/python/example/mixtral_4D_benchmark/mixtral_train.py
@@ -0,0 +1,150 @@
+################################################################################
+#
+# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+################################################################################
+
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+
+from vescale.dtensor.device_mesh import DeviceMesh
+from vescale.dmodule import parallelize_module
+from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
+from vescale.optim.distributed_optimizer import DistributedOptimizer
+from vescale.initialize.deferred_init import deferred_init, is_deferred
+
+from transformers.models.mixtral.modeling_mixtral import MixtralModel
+from transformers.models.mixtral.configuration_mixtral import MixtralConfig
+from sharding_plan import mixtral_plan
+
+local_rank = int(os.environ["LOCAL_RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+rank = int(os.environ["RANK"])
+
+
+def estimate_mixtral(config, bsz, sqence_length):
+    embed = 4 * bsz * sqence_length * config.hidden_size
+    # MixtralMoE consists of 3 linear layers.
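+    # Each routed token passes through three expert linears (w1 and w3 project
+    # hidden_size -> intermediate_size, w2 projects back), and an (m, k) x (k, n)
+    # matmul costs roughly 2*m*k*n FLOPs, hence the 3 * 2 factor in `ff` below.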
+ ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sqence_length + attn_qkv = 2 * bsz * sqence_length * config.hidden_size * 3 * config.hidden_size + attn_mask = 2 * sqence_length * config.hidden_size + attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sqence_length + attn = attn_qkv + attn_mask + attn_proj + return embed + (ff + attn) * config.num_hidden_layers + + +def run_mixtral(args): + torch.random.manual_seed(777) + device_list = [ + list(range(i * args.tp, min((i + 1) * args.tp, world_size))) for i in range(max(world_size // args.tp, 1)) + ] + device_mesh = DeviceMesh("cuda", device_list, mesh_dim_names=("DP", "TP")) + torch.cuda.set_device(local_rank) + + mixtral_config = MixtralConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + num_hidden_layers=args.num_hidden_layers, + num_attention_heads=args.num_attention_heads, + num_key_value_heads=args.num_key_value_heads, + ) + + model_deferred = deferred_init(MixtralModel, mixtral_config) + + mixtral_model = parallelize_module( + model_deferred, + device_mesh["TP"], + mixtral_plan, + factory=True, + ) + + assert not is_deferred(mixtral_model) + + ddp_mixtral_model = DDP( + mixtral_model, + device_mesh["DP"], + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=True, + ) + + doptim = DistributedOptimizer( + torch.optim.Adam(mixtral_model.parameters(), lr=0.01), + models=[ddp_mixtral_model], + overlap_param_gather=True, + ) + + dataloader = [] + for iter in range(args.iter): + data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda() + dist.all_reduce(data, op=dist.ReduceOp.MAX) + dataloader.append(data) + + # =----- warmup -----= # + for _ in range(args.warmup): + data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda() + doptim.zero_grad() + ddp_mixtral_model(data).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + doptim.step() + + # =----- training ----= # + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for iter in range(args.iter): + doptim.zero_grad() + x = dataloader[iter] + ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + doptim.step() + end.record() + torch.cuda.synchronize() + exec_t = start.elapsed_time(end) / 1000 / args.iter + # masure mfu + if local_rank == 0: + # Note we are using FP32. The peak FLOPs of H100 is 59 TFLOPs. 
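+        # If you benchmark on different hardware or with a different dtype, replace the
+        # per-device peak-FLOPs constant below (59 TFLOPs here) with your device's figure.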
+ total_flops = 59 * (10**12) * device_mesh.ndevice + print(f"1 iter time: {exec_t}") + mixtral_flops = estimate_mixtral(mixtral_config, args.bsz, args.seqlen) + print(f"fwd mixtral flops: {mixtral_flops}") + # bwd ~= fwd * 2 + print("mfu:", mixtral_flops * 3 * args.dp * 100 / exec_t / total_flops) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--iter", type=int, default=10) + parser.add_argument("--vocab_size", type=int, default=32000) + parser.add_argument("--hidden_size", type=int, default=4096) + parser.add_argument("--intermediate_size", type=int, default=14336) + parser.add_argument("--num_hidden_layers", type=int, default=16) + parser.add_argument("--num_attention_heads", type=int, default=32) + parser.add_argument("--num_key_value_heads", type=int, default=8) + parser.add_argument("--bsz", type=int, default=16) + parser.add_argument("--seqlen", type=int, default=256) + parser.add_argument("--dp", type=int, default=1) + parser.add_argument("--tp", type=int, default=8) + return parser + + +if __name__ == "__main__": + parser = parse_args() + args = parser.parse_args() + run_mixtral(args) diff --git a/python/example/mixtral_4D_benchmark/sharding_plan.py b/python/example/mixtral_4D_benchmark/sharding_plan.py new file mode 100644 index 0000000..b8ae79e --- /dev/null +++ b/python/example/mixtral_4D_benchmark/sharding_plan.py @@ -0,0 +1,69 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +"""This file contain TP/SP sharding plans for Mixtral example code.""" + +from vescale.dtensor.placement_types import Replicate, Shard + + +param_sharding_plan = { + "embed_tokens.weight": [Replicate()], + r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"layers.\d+.self_attn.v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. 
+ r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], + r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], + r"layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "norm.weight": [Replicate()], +} + +fwd_resharding_plan = { + # TODO: buggy: attn mask is torch.Tensor, in training, it's a None + r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]}, + "embed_tokens.input": [[Replicate()]], + # No SP + # r"layers.\d+.input_layernorm.input": [[Replicate()]], + # r"layers.\d+.input_layernorm.output": [[Replicate()]], + # SP + r"layers.\d+.input_layernorm.input": [[Shard(1)]], + r"layers.\d+.input_layernorm.output": [[Shard(1)]], + r"layers.\d+.self_attn.input": [[Replicate()]], + r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, + r"layers.\d+.self_attn.o_proj.output": [[Replicate()]], + # No SP + # r"layers.\d+.post_attention_layernorm.input": [[Replicate()]], + # r"layers.\d+.post_attention_layernorm.output": [[Replicate()]], + # SP + r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]], + r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]], + r"layers.\d+.block_sparse_moe.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], + r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, + r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + "norm.input": [[Replicate()]], +} + +mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/python/example/nanogpt_4D_finetune/base_train.py b/python/example/nanogpt_4D_finetune/base_train.py index b3b97f1..70980ad 100644 --- a/python/example/nanogpt_4D_finetune/base_train.py +++ b/python/example/nanogpt_4D_finetune/base_train.py @@ -44,7 +44,6 @@ import time import math import pickle -from contextlib import nullcontext import numpy as np import torch @@ -138,7 +137,6 @@ device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast # note: float16 data type will automatically use a GradScaler ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] -ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype) # poor man's data loader data_dir = os.path.join("data", dataset) @@ -227,9 +225,7 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size): model.crop_block_size(block_size) model_args["block_size"] = block_size # so that the checkpoint will have the right value model.to(device) - -# initialize a GradScaler. 
If enabled=False scaler is a no-op -scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16")) +model.to(ptdtype) # optimizer optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) @@ -258,8 +254,7 @@ def estimate_loss(): losses = torch.zeros(eval_iters // factor).to(device) for k in range(eval_iters // factor): X, Y = get_batch(split, batch_size * factor, local_batch_size * factor) - with ctx: - logits, loss = model(X, Y) + logits, loss = model(X, Y) losses[k] = loss.item() / ddp_world_size if ddp: all_reduce(losses) @@ -345,20 +340,17 @@ def get_lr(it): # I really dislike that this bloats the code and forces us to repeat code # looking at the source of that context manager, it just toggles this variable model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1 - with ctx: - logits, loss = model(X, Y) - loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation + logits, loss = model(X, Y) + loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation # immediately async prefetch next batch while model is doing the forward pass on the GPU X, Y = get_batch("train") # backward pass, with gradient scaling if training in fp16 - scaler.scale(loss).backward() + loss.backward() # clip the gradient if grad_clip != 0.0: - scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) - # step the optimizer and scaler if training in fp16 - scaler.step(optimizer) - scaler.update() + # step the optimizer + optimizer.step() # flush the gradients as soon as we can, no need for this memory anymore optimizer.zero_grad(set_to_none=True) diff --git a/python/example/nanogpt_4D_finetune/config/finetune_shakespeare.py b/python/example/nanogpt_4D_finetune/config/finetune_shakespeare.py index 5e89e87..295e2d9 100644 --- a/python/example/nanogpt_4D_finetune/config/finetune_shakespeare.py +++ b/python/example/nanogpt_4D_finetune/config/finetune_shakespeare.py @@ -29,7 +29,7 @@ dataset = "shakespeare" init_from = "gpt2" # this is the smallest GPT-2 model -wandb_log = True # feel free to turn on +wandb_log = False # feel free to turn on wandb_project = f"{init_from}_finetune_{dataset}" # only save checkpoints if the validation loss improves diff --git a/python/example/nanogpt_4D_finetune/exp.py b/python/example/nanogpt_4D_finetune/exp.py index b373217..a4f66e0 100644 --- a/python/example/nanogpt_4D_finetune/exp.py +++ b/python/example/nanogpt_4D_finetune/exp.py @@ -47,42 +47,50 @@ def parse(log_fn, name=None): print(f'"{name}": {val_losses},') +GPU_CNT = 4 +DP_SIZES = [4, 2, 1] +SINGLE_GPU_RUN = "python3" +MULTI_GPU_RUN = "torchrun --standalone --nproc_per_node=4" +CONFIG = "config/finetune_shakespeare.py" +LOG_PREFIX = "" + + def run_exps(max_iters, dtypes, run=True): os.makedirs("logs", exist_ok=True) if run: for dtype in dtypes: dt = "bfloat16" if dtype == "bf16" else "float32" - cmd = f"python3 base_train.py config/finetune_shakespeare.py --compile=False --wandb_log=False --max_iters={max_iters} --dtype='{dt}'" - log_fn = f"logs/1gpu_{dtype}_max_iters_{max_iters}.log" + cmd = f"{SINGLE_GPU_RUN} base_train.py {CONFIG} --compile=False --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log" print(f"run {cmd} > {log_fn} 2> {log_fn}.err") os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") - for dp_size in [1, 2, 4]: - tp_size = 4 // dp_size + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size 
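+            # Every multi-GPU configuration keeps dp_size * tp_size == GPU_CNT (4 GPUs here).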
for dtype in dtypes: dt = "bfloat16" if dtype == "bf16" else "float32" - cmd = f"torchrun --standalone --nproc_per_node=4 finetune_4D.py config/finetune_shakespeare.py --compile=False --use_DO=True --wandb_log=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'" - log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --DDP_grads_in_fp32=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'" + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" print(f"run {cmd} > {log_fn} 2> {log_fn}.err") os.system(f"{cmd} > {log_fn} 2> {log_fn}.err") print("train_loss = {") for dtype in dtypes: - parse_train_loss(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") - for dp_size in [1, 2, 4]: - tp_size = 4 // dp_size - log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" parse_train_loss(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}") print("}") print("val_loss = {") for dtype in dtypes: - parse(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") - for dp_size in [1, 2, 4]: - tp_size = 4 // dp_size - log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" + parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}") + for dp_size in DP_SIZES: + tp_size = GPU_CNT // dp_size + log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log" parse(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}") print("}") if __name__ == "__main__": - run_exps(200, ["bf16", "fp32"], run=True) + run_exps(200, ["bf16"], run=True) diff --git a/python/example/nanogpt_4D_finetune/finetune_4D.py b/python/example/nanogpt_4D_finetune/finetune_4D.py index bb852a5..4201a22 100644 --- a/python/example/nanogpt_4D_finetune/finetune_4D.py +++ b/python/example/nanogpt_4D_finetune/finetune_4D.py @@ -271,6 +271,7 @@ def configure_optimizers(model, weight_decay, learning_rate, betas): models=[model], clip_grad=grad_clip, overlap_param_gather=False, + grad_to_fp32=DDP_grads_in_fp32, ) elif ddp: optimizer = BasicOptimizer(base_optimizer, models=model, grad_hook=GradOptimizerHookBase) diff --git a/python/vescale/__init__.py b/python/vescale/__init__.py index f76f345..242c534 100644 --- a/python/vescale/__init__.py +++ b/python/vescale/__init__.py @@ -48,15 +48,18 @@ "DeviceMesh", "init_device_mesh", "normalize_placements", + "from_local", + "to_local", "distribute_tensor", "redistribute_dtensor", + "vescale_all_gather", + "vescale_all_reduce", + "vescale_reduce_scatter", "Placement", "Partial", "Replicate", "Shard", "InterleavedShard", - "from_local", - "to_local", "deferred_init", "is_deferred", "materialize_dtensor", @@ -81,7 +84,7 @@ def wrapper(*args, **kwargs): torch.jit.script = deprecated_function -# dynamo utils +# dynamo utils # TODO: move this out of __init__ def switch_dtensor_for_torch_export(ep: torch.export.ExportedProgram): print(ep.graph_signature.parameters) if not isinstance(ep, torch.export.ExportedProgram): diff --git a/python/vescale/ddp/distributed_data_parallel.py b/python/vescale/ddp/distributed_data_parallel.py index 
b466f0c..c8b5325 100644 --- a/python/vescale/ddp/distributed_data_parallel.py +++ b/python/vescale/ddp/distributed_data_parallel.py @@ -116,7 +116,6 @@ def __init__( param_to_name = {} for name, param in self.module.named_parameters(): if param.requires_grad: - param.grad_added_to_main_grad = False param_to_name[param] = name dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype @@ -145,7 +144,6 @@ def __init__( # NOTE: maybe we shoule handle these code later when we need MOE parallel. for param in self.module.parameters(): if param.requires_grad and not getattr(param, "allreduce", True): - param.grad_added_to_main_grad = False dtype = torch.float if accumulate_allreduce_grads_in_fp32 else param.dtype param.main_grad = torch.zeros( param.data.shape, @@ -187,12 +185,8 @@ def param_hook(*unused): if param.requires_grad: if self.overlap_grad_reduce: assert param.grad is not None, "param.grad being None is not safe when overlap_grad_reduce is True" - # NOTE: it seems that there are some place where grad_added_to_main_grad is True. - # what will happen then? - - # TODO: remove grad_added_to_main_grad attribute. model_parallel_device_mesh, placements = None, None - if param.grad is not None and not param.grad_added_to_main_grad: + if param.grad is not None: if isinstance(param.data, DTensor): param.main_grad.add_(param.grad._local_tensor.data) # add DTensor's data model_parallel_device_mesh = param.grad._spec.mesh @@ -248,9 +242,6 @@ def zero_grad_buffer(self, zero_buffer: bool = True): When zero_buffer is set to True, the underlying grad buffer is zeroed out. """ - for param in self.module.parameters(): - if param.requires_grad: - param.grad_added_to_main_grad = False for grad_buffer in self.grad_buffers.values(): grad_buffer.reset(zero_buffer) for expert_grad in self.expert_grads: diff --git a/python/vescale/debug/__init__.py b/python/vescale/debug/__init__.py new file mode 100644 index 0000000..cb3ed3b --- /dev/null +++ b/python/vescale/debug/__init__.py @@ -0,0 +1,3 @@ +from .debug_log import DebugLogger + +__all__ = ["DebugLogger", "pdb"] diff --git a/python/vescale/debug/debug_log.py b/python/vescale/debug/debug_log.py new file mode 100644 index 0000000..5911e37 --- /dev/null +++ b/python/vescale/debug/debug_log.py @@ -0,0 +1,361 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import functools +import inspect +import os +from typing import Optional +import sys +from types import FrameType +from typing import Dict, Tuple, Union, Sequence + +import torch +from torch import distributed as dist +from typing import TYPE_CHECKING +from logging import Logger + +if TYPE_CHECKING: + from vescale.dtensor.op_schema import OpSchema, OpInfo + + +__all__ = [ + "DebugLogger", +] + + +class DebugLogger: + """ + Provides a centralized logging utility designed to support debugging vescale model in distributed computing environments. + It allows for selective logging based on rank and supports dynamic adjustment of debugging verbosity through environment variables. + + Attributes: + IS_DEBUG_MODE (Optional[bool]): Flag indicating if debug mode is enabled. Defaults to False. + _device_mesh: Placeholder for a device mesh API. Currently None and marked for future replacement. + _already_init (bool): Indicates whether initial setup has been completed to avoid redundant operations. + rank (Optional[int]): The rank of the current process within the distributed setup. + local_rank (Optional[int]): The local rank of the current process. Not currently used but reserved for future. + world_size (Optional[int]): The total number of processes in the distributed environment. To be set externally. + _rank_to_print (Tuple[int,...]): Specifies the ranks for which logging is enabled. Defaults to (-1, ), indicating none. + _loggeer (Optional[Logger]): The logger object used for debug output. If None, falls back to printing. + + Static Methods: + log(*args, **kwargs): Logs a message either to the console or through a specified logger if debug mode is on. + update_vescale_debug_mode_from_env(): Updates the IS_DEBUG_MODE flag based on the VESCALE_DEBUG_MODE environment variable. + set_vescale_debug_mode(on=True, *, rank_to_print=None, logger=None): Configures the debug mode, including which ranks should log messages. + log_communication(func, *args, **kwargs): Logs communication operations. + log_op(op_info: 'OpInfo'): Logs operations execution. + _init_values_(): Initializes necessary values for the logger, such as ranks and world size, if not already done. + + Usage: + Option 1: Define VESCALE_DEBUG_MODE as an environment variable at the beginning of the program. + ` For performance reasons, VESCALE_DEBUG_MODE should be at least set before calling vescale.parallelize_module. + + Option 2 (perferred way): Using set_vescale_debug_mode at any point of your program. + set_vescale_debug_mode also allows you to pass in a Python logging.Logger for each rank to distinguish logs from different ranks. + """ + + IS_DEBUG_MODE: Optional[bool] = False + # TODO replace by new devicemesh api + _device_mesh = None + _already_init: bool = False + rank: Optional[int] = None + local_rank: Optional[int] = None + world_size: Optional[int] = None # dist.get_world_size() + _rank_to_print: Tuple[int, ...] 
= (-1,) + _loggeer: Optional[Logger] = None + + @staticmethod + def log(*arg, **kwargs): + if DebugLogger._loggeer is None: + print(*arg, **kwargs) + else: + DebugLogger._loggeer.debug(*arg, **kwargs) + + @staticmethod + def update_vescale_debug_mode_from_env(): + DebugLogger.IS_DEBUG_MODE = os.getenv("VESCALE_DEBUG_MODE", "") == "1" + + @staticmethod + def set_vescale_debug_mode( + on: bool = True, *, rank_to_print: Optional[Union[int, Sequence[int]]] = None, logger=None + ): + if DebugLogger.IS_DEBUG_MODE != on: + os.environ["VESCALE_DEBUG_MODE"] = str(int(on)) + DebugLogger.IS_DEBUG_MODE = on + DebugLogger._rank_to_print = None + DebugLogger._loggeer = None + if not DebugLogger.IS_DEBUG_MODE: + DebugLogger.log("vescale debug mode is off") + return + DebugLogger.log("vescale debug mode is on") + if rank_to_print is None: + DebugLogger.log("rank_to_print is not set, using rank 0") + DebugLogger._rank_to_print = (0,) + return + elif isinstance(rank_to_print, int): + DebugLogger._rank_to_print = (rank_to_print,) + elif isinstance(rank_to_print, Sequence) and all(isinstance(i, int) for i in rank_to_print): + DebugLogger._rank_to_print = rank_to_print + else: + raise TypeError( + "expect rank_to_print to be either int or tuple/list of int" f"but get {type(rank_to_print)}" + ) + DebugLogger._loggeer = logger + + @staticmethod + def log_communication(func, *args, **kwargs) -> None: + DebugLogger._init_values_() + _CommunicationLogger.log_communication(func, *args, **kwargs) + + @staticmethod + def log_op(op_info: "OpInfo") -> None: + DebugLogger._init_values_() + _OperatorLogger.print_ops_execution(op_info) + + @staticmethod + def _init_values_(): + if DebugLogger._already_init: + return + DebugLogger._already_init = True + # TODO replace by new devicemesh api + DebugLogger.rank = dist.get_rank() + DebugLogger.world_size = dist.get_world_size() + if DebugLogger._rank_to_print == (-1,): + DebugLogger._rank_to_print = tuple(range(DebugLogger.world_size)) + + +class _CommunicationLogger: + _file_to_recoder = { + "/dtensor/_utils.py", + "/dtensor/redistribute.py", + "/dtensor/api.py", + "/dtensor/dtensor.py", + "/dmodule/_hook.py", + } + _func_to_exclude = {"", ""} + + @staticmethod + def _trace_to_coll_inject_point(): + result = [] + for frame_record in inspect.stack(): + frame = frame_record.frame + code = frame.f_code + co_name = code.co_name + if co_name in _CommunicationLogger._func_to_exclude: + continue + + co_filename = code.co_filename + for f in _CommunicationLogger._file_to_recoder: + if co_filename.endswith(f): + result.append(f"{f}::{co_name}") + break + return ", ".join(result) + + @staticmethod + def log_communication(func, *args, **kwargs): + DebugLogger._init_values_() + rank = DebugLogger.rank + if rank in DebugLogger._rank_to_print: + inject_point = _CommunicationLogger._trace_to_coll_inject_point() + sig = "" + bound_arguments = inspect.signature(func).bind(*args, **kwargs) + bound_arguments.apply_defaults() + for param_name, value in bound_arguments.arguments.items(): + if isinstance(value, torch.Tensor): + sig += f"\t{param_name}: {value.shape}\n" + elif "scatter_list" in param_name: + sig += f"\t{param_name}: [" + for i, item in enumerate(value): + if isinstance(item, torch.Tensor): + sig += f"{item.shape}, " + sig += "]\n" + else: + sig += f"\t{param_name}: {value}\n" + DebugLogger.log(f"[rank{rank}] {func.__name__} with stack: {inject_point}") + DebugLogger.log(f"\t{sig[1:]}") + + @staticmethod + def log_communication_decorator(): + """ + + print_coll_comm_signature prints 
out the collective communication, including: + collective_commucation_type, function signature, and the po + + Args: + rank_to_print: the rank that is going to DebugLogger.log out the debug info. -1 means prints on all ranks + + Example:: + + usage: used as decorator: + @print_coll_comm_signature(0) + def mesh_all_gather() + + output: + [rank0] mesh_all_gather at _reshard_to_replicate_with_pad_one_dim + tensor: torch.Size([40, 11]) + global_size: torch.Size([40, 88]) + mesh: DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + scatter_dim: 1 + mesh_dim: 0 + + """ + + def decorator(func): + if not DebugLogger.IS_DEBUG_MODE: # NOTE: put here for performance if no debug mode + return func + + @functools.wraps(func) + def wrapper(*args, **kwargs): + _CommunicationLogger.log_communication(func, *args, **kwargs) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +class _OperatorLogger: + op_map: Dict = dict() + + @staticmethod + def trace_to_forward() -> FrameType: + current_frame = sys._getframe() + while current_frame: + if current_frame.f_code.co_name == "forward": + break + current_frame = current_frame.f_back + return current_frame + + @staticmethod + def get_module_from_frame(frame) -> any: + return frame.f_locals.get("self", None) + + @staticmethod + def ops_info_printer(rank: int, frame: FrameType, module: object, op_info: "OpInfo"): + DebugLogger.log( + f"[rank{rank}] {module.__class__.__name__} forward() at {frame.f_code.co_filename}:{frame.f_lineno}" + ) + + # input + _OperatorLogger._print_input(op_info.schema.args_schema) + + # op + _OperatorLogger._print_op(op_info.schema) + + # output + _OperatorLogger._print_output(op_info.output_sharding) + + DebugLogger.log("\n") + + @staticmethod + def _print_op(op_schema: "OpSchema"): + DebugLogger.log(f"\t{op_schema}") + + @staticmethod + def _print_input(args_schema): + from vescale.dtensor.placement_types import DTensorSpec + + DebugLogger.log("\tinput: [") + for item in args_schema: + if isinstance(item, DTensorSpec): + _OperatorLogger.dt_spec_debug_formatter(item) + DebugLogger.log("\t]") + + @staticmethod + def _print_output(output_sharding): + from vescale.dtensor.placement_types import DTensorSpec + + output_spec = output_sharding.output_spec + DebugLogger.log("\toutput: [") + if isinstance(output_spec, DTensorSpec): + _OperatorLogger.dt_spec_debug_formatter(output_spec) + elif isinstance(output_spec, (list, tuple)): + for item in output_spec: + if isinstance(item, DTensorSpec): + _OperatorLogger.dt_spec_debug_formatter(item) + else: + DebugLogger.log(output_spec) + DebugLogger.log("\t]") + + @staticmethod + def dt_spec_debug_formatter(dt_spec): + """ + dt_spec_debug_formatter() pretty DebugLogger.log dtensor with TensorMeta, Placements, and DeviceMesh + + Args: + DTensorSpec + + Example:: + + DTensor( + TensorMeta(shape=torch.Size([104]), stride=(1,), dtype=torch.float32) + Placements:(Partial(reduce_op=RedOpType.SUM),) + DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + ) + """ + DebugLogger.log("\t\tDTensor(") + DebugLogger.log(f"\t\t\t{dt_spec.tensor_meta}") + DebugLogger.log(f"\t\t\tPlacements:{dt_spec.placements}") + DebugLogger.log(f"\t\t\t{dt_spec.mesh}") + DebugLogger.log("\t\t)") + + @staticmethod + def print_ops_execution(op_info: "OpInfo") -> None: + """ + print_ops_execution() prints out the executed ops during __torch_dispatch__, it prints out the metadata including: + DModule name, propagation stage, line# in source code, input/output, operators. 
+ + Args: + OpInfo, the operator that is going to dispatch + + Example:: + + [rank0] VeConv1D forward() at /vescale/python/vescale/model/audio/gpt2_audio.py:54 + input: [ + DTensor( + TensorMeta(shape=torch.Size([104]), stride=(1,), dtype=torch.float32) + Placements:(Partial(reduce_op=RedOpType.SUM),) + DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + ) + DTensor( + TensorMeta(shape=torch.Size([40, 88]), stride=(88, 1), dtype=torch.float32) + Placements:(Shard(dim=1),) + DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + ) + DTensor( + TensorMeta(shape=torch.Size([88, 104]), stride=(104, 1), dtype=torch.float32) + Placements:(Shard(dim=0),) + DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + ) + ] + Op(op=aten.addmm.default, args_sharding=Spec(P on (104,)), Spec(S(1) on (40, 88)), Spec(S(0) on (88, 104))@ mesh: (8,)) + output: [ + DTensor( + TensorMeta(shape=torch.Size([40, 104]), stride=(104, 1), dtype=torch.float32) + Placements:(Partial(reduce_op=RedOpType.SUM),) + DeviceMesh:([0, 1, 2, 3, 4, 5, 6, 7]) + ) + ] + """ + frame = _OperatorLogger.trace_to_forward() + if not frame: + return + module = _OperatorLogger.get_module_from_frame(frame) + rank = DebugLogger.rank + # -1 means DebugLogger.log on all ranks + if rank in DebugLogger._rank_to_print: + _OperatorLogger.ops_info_printer(rank, frame, module, op_info) diff --git a/python/vescale/dmodule/api.py b/python/vescale/dmodule/api.py index 14a1964..a2e1530 100644 --- a/python/vescale/dmodule/api.py +++ b/python/vescale/dmodule/api.py @@ -25,6 +25,7 @@ from vescale.dtensor.device_mesh import DeviceMesh, mesh_resources from vescale.dmodule._dmodule import DModule from vescale.dmodule.placements_interface import PlacementsInterface +from vescale.debug import DebugLogger __all__ = ["parallelize_module", "is_dmodule", "PlacementsInterface"] @@ -240,6 +241,9 @@ def forward(self, x): """ + # for performance, update debug env once here + DebugLogger.update_vescale_debug_mode_from_env() + if DModule.is_dmodule(module): warnings.warn(f"{module} is already parallelized `DModule`. 
Skip `parallelize_module`", UserWarning) return module diff --git a/python/vescale/dtensor/_collective_utils.py b/python/vescale/dtensor/_collective_utils.py index 472b675..ec3ccab 100644 --- a/python/vescale/dtensor/_collective_utils.py +++ b/python/vescale/dtensor/_collective_utils.py @@ -29,6 +29,7 @@ from vescale.dtensor.device_mesh import DeviceMesh, mesh_resources from vescale.dtensor.placement_types import DTensorSpec +from vescale.debug import DebugLogger logger = logging.getLogger(__name__) @@ -63,6 +64,9 @@ def mesh_scatter( Returns: A :class:`Work` object """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_scatter, output, scatter_list, mesh_dim, async_op) + # if rank is not part of mesh, simply return output if mesh.get_coordinate() is None: return output @@ -112,6 +116,9 @@ def mesh_all_to_all( mesh_dim: int = 0, async_op: bool = False, ) -> Optional[Work]: + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_all_to_all, output_tensor_list, input_tensor_list, mesh, mesh_dim, async_op) + # if rank is not part of mesh, simply return None if mesh.get_coordinate() is None: return None @@ -170,6 +177,9 @@ def mesh_broadcast( Returns: A :class:`Tensor` object """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_broadcast, tensor, mesh, mesh_dim, async_op) + # if rank is not part of mesh, simply return tensor, which should be an empty tensor if mesh.get_coordinate() is None: return tensor @@ -203,6 +213,10 @@ def mesh_reduce_scatter( First peform all_reduce on the tensor, then split the tensor at scatter_dim and scatter them to a device mesh dimension. """ + + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_reduce_scatter, tensor, mesh, reduce_op, scatter_dim, mesh_dim) + # if rank is not part of mesh, simply return tensor, which should be an empty tensor if mesh.get_coordinate() is None: return tensor @@ -230,6 +244,9 @@ def mesh_all_gather( all_gather all shards and return a tensor that is replicated on the previously sharded mesh dimension """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_all_gather, tensor, global_size, mesh, scatter_dim, mesh_dim) + # if rank is not part of mesh, simply return tensor, which should be an empty tensor if mesh.get_coordinate() is None: return tensor diff --git a/python/vescale/dtensor/_utils.py b/python/vescale/dtensor/_utils.py index 22dada9..71b167c 100644 --- a/python/vescale/dtensor/_utils.py +++ b/python/vescale/dtensor/_utils.py @@ -16,7 +16,6 @@ from torch._prims_common import ShapeType from vescale.dtensor.device_mesh import DeviceMesh -from vescale.dtensor.dtensor import DTensor from vescale.dtensor.placement_types import InterleavedShard, Partial, Placement, Replicate, Shard from vescale.dtensor._collective_utils import mesh_all_gather @@ -320,8 +319,8 @@ def is_zero_out_local_shard(mesh: DeviceMesh, placements: Sequence[Placement]) - return False -def _equal_meta_data(dt1: DTensor, dt2: DTensor, exact_device: bool) -> bool: - if type(dt1) is not DTensor or type(dt2) is not DTensor: +def _equal_meta_data(dt1, dt2, exact_device: bool) -> bool: + if type(dt1).__name__ != "DTensor" or type(dt2).__name__ != "DTensor": return False # check itself if exact_device and (dt1.device.type != dt2.device.type): @@ -368,7 +367,7 @@ def _equal_meta_data(dt1: DTensor, dt2: DTensor, exact_device: bool) -> bool: return True -def equal(dt1: DTensor, dt2: DTensor, exact_device: bool = True) -> bool: +def equal(dt1, dt2, exact_device: bool = True) -> bool: """ 
check if two DTensors are 'exactly' equal """ @@ -383,8 +382,8 @@ def equal(dt1: DTensor, dt2: DTensor, exact_device: bool = True) -> bool: def allclose( - dt1: DTensor, - dt2: DTensor, + dt1, + dt2, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False, @@ -409,7 +408,7 @@ def allclose( def compute_local_offset(global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]) -> Tuple[int, ...]: """ - Compute the offsets of a local shard of the given DTensor on its current + Compute the offsets of a local shard of the given "DTensor" on its current global rank. This is mostly used by distributed checkpointing to know the exact offsets of the local shard. """ diff --git a/python/vescale/dtensor/api.py b/python/vescale/dtensor/api.py index 34bf200..eabf812 100644 --- a/python/vescale/dtensor/api.py +++ b/python/vescale/dtensor/api.py @@ -8,238 +8,33 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ -import os -import warnings from typing import List, Optional, Sequence, Tuple, Union, cast import torch -import torch.distributed._functional_collectives as funcol import vescale.dtensor.random as random -from vescale.dtensor._collective_utils import mesh_broadcast, mesh_scatter -from vescale.dtensor._utils import compute_global_tensor_info, gather_local_tensor_shape +from vescale.dtensor._collective_utils import mesh_scatter from vescale.dtensor.device_mesh import DeviceMesh, mesh_resources -from vescale.dtensor.dtensor import DTensor +from vescale.dtensor.dtensor import DTensor, normalize_placements from vescale.dtensor.ops.utils import normalize_dims -from vescale.dtensor.placement_types import DTensorSpec, InterleavedShard, Placement, Replicate, Shard +from vescale.dtensor.placement_types import Placement, Replicate, Shard, InterleavedShard from vescale.dtensor.random import OffsetBasedRNGTracker, is_rng_supported_mesh from vescale.dtensor.redistribute import ( - Redistribute, _replicate_tensor, _scatter_tensor_by_shard, - redistribute_local_tensor, ) - __all__ = [ - "distribute_tensor", - "to_local", + "normalize_placements", "from_local", + "to_local", + "distribute_tensor", "redistribute_dtensor", - "normalize_placements", + "vescale_all_gather", + "vescale_all_reduce", + "vescale_reduce_scatter", ] -VESCALE_DISABLE_RUN_CHECK = os.environ.get("VESCALE_DISABLE_RUN_CHECK", "0") == "1" - - -def normalize_placements( - placements: Optional[Sequence[Placement]], mesh_ndim: int, *, tensor_ndim: int = 0, none_as_replicate: bool = False -) -> Optional[Tuple[Placement]]: - """ - normalize a placements to be valid. - """ - if placements is None: - return tuple(Replicate() for _ in range(mesh_ndim)) if none_as_replicate else None - - if len(placements) > mesh_ndim: - raise ValueError(f"`placements` (len={len(placements)}) have larger length than `mesh_ndim` ({mesh_ndim})!") - - if len(placements) < mesh_ndim: - warnings.warn( - "`placements` have less elements than `mesh_ndim`!. 
We will postpend Replicate placement to the end.", - UserWarning, - ) - placements = list(placements) + [Replicate()] * (mesh_ndim - len(placements)) - - for p in placements: - if not isinstance(p, Placement): - raise ValueError(f"Unsupported placements = {placements}!") - if isinstance(p, (Shard, InterleavedShard)) and p.dim < 0: - # normalize shard dim to be positive - p.dim += tensor_ndim - - return tuple(placements) - - -# NOTE [Autograd interaction between torch.Tensor] -# -# The autograd functions defined below are being used by the public -# facing APIs (i.e. from_local, to_local) to ensure our DTensor -# works together with torch.Tensor within autograd engine. This -# allows DistributedTensor to exist on part of the module hierarchy -# and still able to calculate gradients across the torch.Tensor and -# DistributedTensor boundary. -# As an example, we have the a module that consists of submodules -# A, B, and C, the execution flow would be like: -# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) -# -# Suppose I only want to make Module B be a sharded module with -# DistributedTensor params, we would need to make the following -# flow to work: -# -# input(torch.Tensor) -> Module A -# -> DTensor input -> Sharded Module B -> DTensor output -# -> output (torch.Tensor) -> Module C -> output (torch.Tensor) -# -# We need the conversion from Module A to DTensor input, which is -# `from_local`, and conversion from DTensor output to output, which -# is `to_local`, thus these two functions must be Autograd functions. -# -class _ToTorchTensor(torch.autograd.Function): - @staticmethod - def forward(ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]], async_output: bool): - ctx.dtensor_spec = input._spec - ctx.grad_placements = grad_placements - local_tensor = input._local_tensor - if not async_output and type(local_tensor) is funcol.AsyncCollectiveTensor: - # synchronously wait for any pending collectives to get the result tensor - local_tensor = local_tensor.trigger_wait() - if hasattr(local_tensor, "elem"): - local_tensor = local_tensor.elem # type: ignore[attr-defined] - # We need to return a fresh Tensor object there as autograd metadata - # will be inplaced into it. So we don't want to pollute the Tensor - # object stored in the _local_tensor of this DTensor. 
- return local_tensor.view_as(local_tensor) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - dtensor_spec = ctx.dtensor_spec - mesh = dtensor_spec.mesh - grad_placements = ctx.grad_placements - dtensor_meta = dtensor_spec.tensor_meta - - if grad_placements is not None: - grad_spec = DTensorSpec(mesh, grad_placements) - grad_output = redistribute_local_tensor(grad_output, grad_spec, dtensor_spec) - - _, tensor_stride = compute_global_tensor_info(grad_output, mesh, dtensor_spec.placements) - - return ( - DTensor( - grad_output, - mesh, - tuple(dtensor_spec.placements), - shape=dtensor_meta.shape, - dtype=dtensor_meta.dtype, - requires_grad=grad_output.requires_grad, - stride=tuple(tensor_stride), - ), - None, - None, - ) - - -class _FromTorchTensor(torch.autograd.Function): - @staticmethod - def forward( - ctx, - input: torch.Tensor, - device_mesh: DeviceMesh, - placements: Tuple[Placement, ...], - run_check: bool, - shape: Optional[torch.Size] = None, - stride: Optional[Tuple[int, ...]] = None, - support_uneven: bool = True, - async_input: bool = True, - ) -> "DTensor": - ctx.previous_placement = placements - ctx.previous_device_mesh = device_mesh - ctx.async_input = async_input - - # infer global shape and stride - if (shape is None) != (stride is None): - raise ValueError( - f"Found shape:{shape}, stride:{stride}.", - "Please pass both shape and stride at the same time!", - ) - elif shape and stride: # use given global shape and stride - tensor_shape, tensor_stride = torch.Size(shape), tuple(stride) - elif all( - p.is_replicate() or p.is_partial() for p in placements - ): # for all replicate/partial tensor, infer from local tensor - tensor_shape, tensor_stride = input.shape, input.stride() - else: # infer sharded global shape and stride - if support_uneven: # support uneven shard - meshdim_localtensor_shape = gather_local_tensor_shape(input, device_mesh, placements, shard_only=True) - assert meshdim_localtensor_shape is not None, "Out-of-mesh is impossible to support uneven sharding!" 
- global_shape, global_stride = compute_global_tensor_info( - input, device_mesh, placements, meshdim_localtensor_shape - ) - else: # assume even shard - global_shape, global_stride = compute_global_tensor_info(input, device_mesh, placements) - tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) - - # if global rank is not participating in the device mesh, we simply: - # - set the local tensor to an empty tensor - # - set global shape/stride as the global tensor - if device_mesh.get_coordinate() is None: - input = input.new_empty(0, requires_grad=input.requires_grad) - # runtime checking for in-mesh ranks - elif run_check: - # per placement check - for idx, placement in enumerate(placements): - if placement.is_replicate(): - # broadcast rank 0 tensor to all ranks - input = mesh_broadcast(input.contiguous(), device_mesh, mesh_dim=idx) - elif placement.is_interleaved_shard(): - if input.shape[placement.dim] % placement.interleaved_size != 0: - raise ValueError( - f"Tensor size at dim {placement.dim} is not divisible by {placement.interleaved_size}" - ) - # [conservative] global tensor_shape/tensor_stride should be the same across ranks - # meshdim_localtensor_shape = gather_local_tensor_shape( - # tensor_shape, device_mesh, placements, shard_only=False - # ) - # for stacked_local_shape in meshdim_localtensor_shape.values(): - # assert stacked_local_shape.count(stacked_local_shape[0]) == len( - # stacked_local_shape - # ), "The global tensor shape must be the same across ranks!" - - # We want a fresh Tensor object that shares memory with the input tensor - return DTensor( - input.view_as(input), - device_mesh, - placements, - shape=tensor_shape, - dtype=input.dtype, - requires_grad=input.requires_grad, - stride=tensor_stride, - ) - - @staticmethod - # type: ignore[override] - def backward(ctx, grad_output: "DTensor"): - previous_placement = ctx.previous_placement - previous_device_mesh = ctx.previous_device_mesh - async_input = ctx.async_input - - # reshard to the placement when creating DistributedTensor - # so that the gradient layout matches, and we could return - # local gradients directly - if grad_output.placements != previous_placement: - grad_output = Redistribute.apply(grad_output, previous_device_mesh, previous_placement, False) - - local_tensor = grad_output._local_tensor - if not async_input and type(local_tensor) is funcol.AsyncCollectiveTensor: - # synchronously wait for any pending collectives to get the result tensor - local_tensor = local_tensor.trigger_wait() - if hasattr(local_tensor, "elem"): - local_tensor = local_tensor.elem # type: ignore[attr-defined] - - # TODO: backward is also differentiable now, add a test - # to test higher level gradients. - return local_tensor.view_as(local_tensor), None, None, None, None, None, None, None - def from_local( local_tensor: torch.Tensor, @@ -251,7 +46,7 @@ def from_local( stride: Optional[Tuple[int, ...]] = None, support_uneven: bool = True, async_input: bool = True, -) -> "DTensor": +) -> DTensor: """ Create a :class:`DTensor` from a local torch.Tensor on each rank according to the `device_mesh` and `placements` specified. @@ -311,62 +106,20 @@ def from_local( - `from_local` is differentiable - the `requires_grad` of the created `DTensor` object will depend on if `local_tensor` requires_grad or not. 
""" - assert type(local_tensor) is not DTensor - assert type(getattr(local_tensor, "data", None)) is not DTensor - - # if same shape/dtype, no need to run_check, if not, must allgather - # the metadatas to check the size/dtype across ranks - # There should be no data communication unless there's replication - # strategy, where we broadcast the replication from the first rank - # in the mesh dimension - device_mesh = device_mesh or mesh_resources.get_current_mesh() - device_type = device_mesh.device_type - - # convert the local tensor to desired device base on device mesh's device_type - if device_type != local_tensor.device.type and not local_tensor.is_meta: - local_tensor = local_tensor.to(device_type) - - # validate placements - placements: Tuple[Placement] = normalize_placements( - placements, device_mesh.ndim, tensor_ndim=local_tensor.ndim, none_as_replicate=True - ) - - # TODO: fix later - # if any(p.is_partial() for p in placements if p is not None): - # warnings.warn( - # "DTensor.from_local(.., [Partial]) has no zero-out feature yet! Use Partial with caution.", UserWarning - # ) - - # `from_local` is differentiable, and the gradient of the dist tensor this function - # created should flow back the gradients to the local_tensor, so we call an autograd - # function to construct the dist tensor instead. - - if VESCALE_DISABLE_RUN_CHECK: - run_check = False - - if device_mesh.get_coordinate() is None and support_uneven: - warnings.warn( - "Out-of-mesh rank uses `DTensor.from_local` under uneven sharding support, which is impossible!" - " We set `support_uneven` as `False`!" - " If uneven sharding does happen, out-of-mesh rank can only assume even sharding, which disgrees with in-mesh ranks!", - UserWarning, - ) - support_uneven = False - - return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + return DTensor.from_local( local_tensor, device_mesh, placements, - run_check, - shape, - stride, - support_uneven, - async_input, + run_check=run_check, + shape=shape, + stride=stride, + support_uneven=support_uneven, + async_input=async_input, ) def to_local( - dtensor: "DTensor", + dtensor: DTensor, *, grad_placements: Optional[Sequence[Placement]] = None, async_output: bool = True, @@ -395,16 +148,14 @@ def to_local( .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned will depend on if the `DTensor` requires_grad or not. """ - if grad_placements is not None: - grad_placements = normalize_placements(grad_placements, dtensor.mesh.ndim, tensor_ndim=dtensor.ndim) - return _ToTorchTensor.apply(dtensor, grad_placements, async_output) + return dtensor.to_local(grad_placements=grad_placements, async_output=async_output) def distribute_tensor( tensor: torch.Tensor, device_mesh: Optional[DeviceMesh] = None, placements: Optional[Sequence[Placement]] = None, -) -> "DTensor": +) -> DTensor: """ Distribute a global `torch.Tensor` to the `device_mesh` according to the `placements` specified. The rank of `device_mesh` and `placements` must be the same. 
@@ -528,11 +279,11 @@ def distribute_tensor( def redistribute_dtensor( - dtensor: "DTensor", + dtensor: DTensor, device_mesh: Optional[DeviceMesh] = None, placements: Optional[Sequence[Placement]] = None, async_op: bool = True, -) -> "DTensor": +) -> DTensor: """ `redistribute_dtensor` performs necessary collective operations that redistribute the current DTensor from its current placements to a new placements, or from is current DeviceMesh @@ -557,24 +308,14 @@ def redistribute_dtensor( - `redistribute_dtensor` is differentiable (i.e., redistribute happen for both forward and backward) - This redistribute API currently only supports out of place redistribution, i.e. it always create a new DTensor object and leave the original one unchanged. """ - - # if device_mesh is not specified, use the current device_mesh - device_mesh = device_mesh or dtensor.device_mesh - - # check new placements for not specified - if placements is None: - raise RuntimeError("placements is needed for redistribute!") - - placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, tensor_ndim=dtensor.ndim) - - return Redistribute.apply(dtensor, device_mesh, placements, async_op) + return dtensor.redistribute(device_mesh, placements, async_op) def vescale_all_gather( - d_tensor: "DTensor", + d_tensor: DTensor, mesh_dims: Optional[Union[int, List[int]]] = None, async_op: bool = True, -) -> "DTensor": +) -> DTensor: """ all gather the DTensor along specified dimensions. Args: @@ -607,14 +348,14 @@ def vescale_all_gather( raise ValueError(f"DeviceDim {mesh_dim} is not sharded, cannot use it for all gather") dst_placements[mesh_dim] = Replicate() - return redistribute_dtensor(d_tensor, device_mesh, tuple(dst_placements), async_op) + return d_tensor.redistribute(device_mesh, tuple(dst_placements), async_op) def vescale_all_reduce( - d_tensor: "DTensor", + d_tensor: DTensor, mesh_dims: Optional[Union[int, List[int]]] = None, async_op: bool = True, -) -> "DTensor": +) -> DTensor: """ all reduce dtensor along given dimensions. @@ -640,16 +381,17 @@ def vescale_all_reduce( if mesh_dim not in dtensor_spec.sums: raise ValueError(f"MeshDim {mesh_dim} is not a reduction dimension, cannot use it for all reduce") dst_placements[mesh_dim] = Replicate() - return redistribute_dtensor(d_tensor, device_mesh, tuple(dst_placements), async_op) + + return d_tensor.redistribute(device_mesh, tuple(dst_placements), async_op) def vescale_reduce_scatter( - d_tensor: "DTensor", + d_tensor: DTensor, reduce_mesh_dims: Optional[Union[int, List[int]]] = None, scatter_dims: Union[int, List[int]] = None, mesh_dims: Union[int, List[int]] = None, async_op: bool = True, -) -> "DTensor": +) -> DTensor: """ reduce scatter a DTensor on a specified device mesh dimension. 
@@ -690,4 +432,4 @@ def vescale_reduce_scatter( for scatter_dim, mesh_dim in zip(scatter_dims, mesh_dims): dst_placements[mesh_dim] = Shard(scatter_dim) - return redistribute_dtensor(d_tensor, device_mesh, tuple(dst_placements), async_op) + return d_tensor.redistribute(device_mesh, tuple(dst_placements), async_op) diff --git a/python/vescale/dtensor/device_mesh.py b/python/vescale/dtensor/device_mesh.py index 51d2971..9e03727 100644 --- a/python/vescale/dtensor/device_mesh.py +++ b/python/vescale/dtensor/device_mesh.py @@ -29,6 +29,8 @@ new_group, ) +from vescale.debug import DebugLogger + logger = logging.getLogger(__name__) # only import numpy typing when type checking @@ -229,6 +231,8 @@ def __init__( _validate_mesh: bool = True, _init_process_groups: bool = True, ) -> None: + # for performance, update debug env once here + DebugLogger.update_vescale_debug_mode_from_env() # check args if mesh is None and pg is None: raise ValueError("Either `mesh` or `pg` must be provided!") diff --git a/python/vescale/dtensor/dtensor.py b/python/vescale/dtensor/dtensor.py index 054b8f9..2f813a8 100644 --- a/python/vescale/dtensor/dtensor.py +++ b/python/vescale/dtensor/dtensor.py @@ -8,24 +8,237 @@ # Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. ################################################################################ +import os import warnings from typing import Optional, Sequence, Tuple, List, Union from numbers import Number import torch +import torch.distributed._functional_collectives as funcol import vescale.dtensor.dispatch as op_dispatch -from vescale.dtensor.device_mesh import DeviceMesh -from vescale.dtensor.placement_types import DTensorSpec, Placement, Replicate, TensorMeta +from vescale.dtensor.device_mesh import DeviceMesh, mesh_resources +from vescale.dtensor.placement_types import DTensorSpec, TensorMeta, Placement, Replicate, Shard, InterleavedShard from vescale.dtensor.sharding_prop import ShardingPropagator +from vescale.dtensor.redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from vescale.dtensor._utils import compute_global_tensor_info, gather_local_tensor_shape +from vescale.dtensor._collective_utils import mesh_broadcast __all__ = ["DTensor"] + aten = torch.ops.aten +VESCALE_DISABLE_RUN_CHECK = os.environ.get("VESCALE_DISABLE_RUN_CHECK", "0") == "1" _OK_TO_USE_DATA_PTR = True +def normalize_placements( + placements: Optional[Sequence[Placement]], mesh_ndim: int, *, tensor_ndim: int = 0, none_as_replicate: bool = False +) -> Optional[Tuple[Placement]]: + """ + normalize a placements to be valid. + """ + if placements is None: + return tuple(Replicate() for _ in range(mesh_ndim)) if none_as_replicate else None + + if len(placements) > mesh_ndim: + raise ValueError(f"`placements` (len={len(placements)}) have larger length than `mesh_ndim` ({mesh_ndim})!") + + if len(placements) < mesh_ndim: + warnings.warn( + "`placements` have less elements than `mesh_ndim`!. We will postpend Replicate placement to the end.", + UserWarning, + ) + placements = list(placements) + [Replicate()] * (mesh_ndim - len(placements)) + + for p in placements: + if not isinstance(p, Placement): + raise ValueError(f"Unsupported placements = {placements}!") + if isinstance(p, (Shard, InterleavedShard)) and p.dim < 0: + # normalize shard dim to be positive + p.dim += tensor_ndim + + return tuple(placements) + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. 
from_local, to_local) to ensure our DTensor +# works together with torch.Tensor within autograd engine. This +# allows DistributedTensor to exist on part of the module hierarchy +# and still able to calculate gradients across the torch.Tensor and +# DistributedTensor boundary. +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DistributedTensor params, we would need to make the following +# flow to work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input -> Sharded Module B -> DTensor output +# -> output (torch.Tensor) -> Module C -> output (torch.Tensor) +# +# We need the conversion from Module A to DTensor input, which is +# `from_local`, and conversion from DTensor output to output, which +# is `to_local`, thus these two functions must be Autograd functions. + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int, ...]] = None, + support_uneven: bool = True, + async_input: bool = True, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + ctx.async_input = async_input + + # infer global shape and stride + if (shape is None) != (stride is None): + raise ValueError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time!", + ) + elif shape and stride: # use given global shape and stride + tensor_shape, tensor_stride = torch.Size(shape), tuple(stride) + elif all( + p.is_replicate() or p.is_partial() for p in placements + ): # for all replicate/partial tensor, infer from local tensor + tensor_shape, tensor_stride = input.shape, input.stride() + else: # infer sharded global shape and stride + if support_uneven: # support uneven shard + meshdim_localtensor_shape = gather_local_tensor_shape(input, device_mesh, placements, shard_only=True) + assert meshdim_localtensor_shape is not None, "Out-of-mesh is impossible to support uneven sharding!" 
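The uneven-shard branch above first gathers every rank's local shape and then reconstructs the global shape from them. A toy, single-process illustration of that reconstruction for a tensor sharded on dim 0; this is not the real `compute_global_tensor_info`, only the idea behind it:

```
# Illustrative sketch: recover a global shape from per-rank local shapes under uneven Shard(0).
from typing import List, Tuple

def infer_global_shape_shard0(local_shapes: List[Tuple[int, ...]]) -> Tuple[int, ...]:
    # Under Shard(0), dim 0 is split across the mesh dimension, so the global
    # size of dim 0 is the sum of the local sizes; the other dims must agree.
    dim0 = sum(s[0] for s in local_shapes)
    rest = local_shapes[0][1:]
    assert all(s[1:] == rest for s in local_shapes), "non-sharded dims must match"
    return (dim0, *rest)

# e.g. 10 rows unevenly split over 4 ranks as 3/3/2/2 local rows
assert infer_global_shape_shard0([(3, 8), (3, 8), (2, 8), (2, 8)]) == (10, 8)
```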
+ global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements, meshdim_localtensor_shape + ) + else: # assume even shard + global_shape, global_stride = compute_global_tensor_info(input, device_mesh, placements) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + + # if global rank is not participating in the device mesh, we simply: + # - set the local tensor to an empty tensor + # - set global shape/stride as the global tensor + if device_mesh.get_coordinate() is None: + input = input.new_empty(0, requires_grad=input.requires_grad) + # runtime checking for in-mesh ranks + elif run_check: + # per placement check + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + input = mesh_broadcast(input.contiguous(), device_mesh, mesh_dim=idx) + elif placement.is_interleaved_shard(): + if input.shape[placement.dim] % placement.interleaved_size != 0: + raise ValueError( + f"Tensor size at dim {placement.dim} is not divisible by {placement.interleaved_size}" + ) + # [conservative] global tensor_shape/tensor_stride should be the same across ranks + # meshdim_localtensor_shape = gather_local_tensor_shape( + # tensor_shape, device_mesh, placements, shard_only=False + # ) + # for stacked_local_shape in meshdim_localtensor_shape.values(): + # assert stacked_local_shape.count(stacked_local_shape[0]) == len( + # stacked_local_shape + # ), "The global tensor shape must be the same across ranks!" + + # We want a fresh Tensor object that shares memory with the input tensor + return DTensor( + input.view_as(input), + device_mesh, + placements, + shape=tensor_shape, + dtype=input.dtype, + requires_grad=input.requires_grad, + stride=tensor_stride, + ) + + @staticmethod + # type: ignore[override] + def backward(ctx, grad_output: "DTensor"): + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + async_input = ctx.async_input + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + grad_output = Redistribute.apply(grad_output, previous_device_mesh, previous_placement, False) + + local_tensor = grad_output._local_tensor + if not async_input and type(local_tensor) is funcol.AsyncCollectiveTensor: + # synchronously wait for any pending collectives to get the result tensor + local_tensor = local_tensor.trigger_wait() + if hasattr(local_tensor, "elem"): + local_tensor = local_tensor.elem # type: ignore[attr-defined] + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return local_tensor.view_as(local_tensor), None, None, None, None, None, None, None + + +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward(ctx, input: "DTensor", grad_placements: Optional[Sequence[Placement]], async_output: bool): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + if not async_output and type(local_tensor) is funcol.AsyncCollectiveTensor: + # synchronously wait for any pending collectives to get the result tensor + local_tensor = local_tensor.trigger_wait() + if hasattr(local_tensor, "elem"): + local_tensor = local_tensor.elem # type: ignore[attr-defined] + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. 
So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + if grad_placements is not None: + grad_spec = DTensorSpec(mesh, grad_placements) + grad_output = redistribute_local_tensor(grad_output, grad_spec, dtensor_spec) + + _, tensor_stride = compute_global_tensor_info(grad_output, mesh, dtensor_spec.placements) + + return ( + DTensor( + grad_output, + mesh, + tuple(dtensor_spec.placements), + shape=dtensor_meta.shape, + dtype=dtensor_meta.dtype, + requires_grad=grad_output.requires_grad, + stride=tuple(tensor_stride), + ), + None, + None, + ) + + +################ DTensor below ################ + + def _dispatch_torch_make_wrapper_subclass(*args, data_ptr, **kwargs): global _OK_TO_USE_DATA_PTR @@ -191,19 +404,59 @@ def from_local( support_uneven: bool = True, async_input: bool = True, ) -> "DTensor": - # we have to do this to avoid circle import. - from vescale.dtensor.api import from_local - # TODO: moving impl code here for performance, as here is on the critial path but api function is less used - return from_local( + + assert type(local_tensor) is not DTensor + assert type(getattr(local_tensor, "data", None)) is not DTensor + + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # validate placements + placements: Tuple[Placement] = normalize_placements( + placements, device_mesh.ndim, tensor_ndim=local_tensor.ndim, none_as_replicate=True + ) + + # TODO: fix later + # if any(p.is_partial() for p in placements if p is not None): + # warnings.warn( + # "DTensor.from_local(.., [Partial]) has no zero-out feature yet! Use Partial with caution.", UserWarning + # ) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + + if VESCALE_DISABLE_RUN_CHECK: + run_check = False + + if device_mesh.get_coordinate() is None and support_uneven: + warnings.warn( + "Out-of-mesh rank uses `DTensor.from_local` under uneven sharding support, which is impossible!" + " We set `support_uneven` as `False`!" 
+ " If uneven sharding does happen, out-of-mesh rank can only assume even sharding, which disgrees with in-mesh ranks!", + UserWarning, + ) + support_uneven = False + + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func local_tensor, device_mesh, placements, - run_check=run_check, - shape=shape, - stride=stride, - support_uneven=support_uneven, - async_input=async_input, + run_check, + shape, + stride, + support_uneven, + async_input, ) def to_local( @@ -212,21 +465,32 @@ def to_local( grad_placements: Optional[Sequence[Placement]] = None, async_output: bool = True, ) -> torch.Tensor: - from vescale.dtensor.api import to_local - # TODO: moving impl code here for performance, as here is on the critial path but api function is NEVER used - return to_local(self, grad_placements=grad_placements, async_output=async_output) + # NOTE: moving impl code here for performance, as here is on the critial path but api function is NEVER used + if grad_placements is not None: + grad_placements: Tuple[Placement] = normalize_placements( + grad_placements, self._spec.mesh.ndim, tensor_ndim=self.ndim + ) + + return _ToTorchTensor.apply(self, grad_placements, async_output) def redistribute( self, device_mesh: Optional[DeviceMesh] = None, placements: Optional[Sequence[Placement]] = None, async_op: bool = True, ) -> "DTensor": - from vescale.dtensor.api import redistribute_dtensor + # NOTE: moving impl code here for performance, as here is on the critial path but api function is rarely used + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self._spec.mesh + + # check new placements for not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + placements: Tuple[Placement] = normalize_placements(placements, device_mesh.ndim, tensor_ndim=self.ndim) - # TODO: moving impl code here for performance, as here is on the critial path but api function is rarely used - return redistribute_dtensor(self, device_mesh=device_mesh, placements=placements, async_op=async_op) + return Redistribute.apply(self, device_mesh, placements, async_op) def requires_grad_(self, mode=True): self._local_tensor.requires_grad_(mode) diff --git a/python/vescale/dtensor/ops/math_ops.py b/python/vescale/dtensor/ops/math_ops.py index 78a4ff4..09a2e61 100644 --- a/python/vescale/dtensor/ops/math_ops.py +++ b/python/vescale/dtensor/ops/math_ops.py @@ -14,7 +14,14 @@ import torch.distributed.distributed_c10d as c10d from vescale.dtensor import DeviceMesh -from vescale.dtensor.op_schema import OpSchema, OutputSharding, RuntimeSchemaInfo, OpStrategy, PlacementStrategy +from vescale.dtensor.op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, + OpStrategy, + PlacementStrategy, + TupleStrategy, +) from vescale.dtensor.ops.utils import ( as_list, generate_redistribute_costs, @@ -257,38 +264,25 @@ def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: @register_op_strategy([aten.topk.default]) -def topk(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: +def topk(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy: input_strategy = op_schema.args_schema[0] dim = op_schema.args_schema[2] if len(op_schema.args_schema) > 2 else -1 input_strategy = cast(OpStrategy, input_strategy) dim = cast(int, dim) dim = normalize_dim(dim, input_strategy.output_ndim) + input_placement_strategy = input_strategy.strategies[0] + input_src_spec = input_placement_strategy.output_spec - output_strategy = OpStrategy([]) - for 
input_placement_strategy in input_strategy.strategies: - redistribute_costs = [] - input_src_spec = input_placement_strategy.output_spec + output_target_spec = DTensorSpec(mesh=mesh, placements=input_src_spec.placements) + value_out_strategy = PlacementStrategy( + output_spec=output_target_spec, + ) - # make sure input is replicated along the sort dim - input_target_spec = DTensorSpec( - mesh=mesh, - placements=replicate_reduction_dims(input_src_spec.placements, [dim]), - tensor_meta=input_src_spec.tensor_meta, - ) - # TODO: change to vescale stype redistribution - redistribute_costs.append(generate_redistribute_costs(input_strategy, input_target_spec)) - output_target_spec = DTensorSpec( - mesh=mesh, - placements=input_target_spec.placements, - ) - output_strategy.strategies.append( - PlacementStrategy( - output_spec=output_target_spec, - input_specs=[input_target_spec], - redistribute_cost=redistribute_costs, - ) - ) + index_out_strategy = PlacementStrategy( + output_spec=output_target_spec, + ) + output_strategy = TupleStrategy(childs=[OpStrategy([value_out_strategy]), OpStrategy([index_out_strategy])]) return output_strategy diff --git a/python/vescale/dtensor/ops/tensor_ops.py b/python/vescale/dtensor/ops/tensor_ops.py index 5022a79..fb7d52c 100644 --- a/python/vescale/dtensor/ops/tensor_ops.py +++ b/python/vescale/dtensor/ops/tensor_ops.py @@ -10,6 +10,7 @@ from typing import List, Optional, Sequence, Tuple, cast +import copy import torch from torch.utils._python_dispatch import _get_current_dispatch_mode @@ -796,7 +797,7 @@ def unbind_rule(op_schema: OpSchema) -> OutputSharding: output_spec_list: List[DTensorSpec] = [] input_spec = cast(DTensorSpec, op_schema.args_schema[0]) ndim = input_spec.ndim - dim = cast(int, op_schema.args_schema[1]) + dim = cast(int, op_schema.args_schema[1] if len(op_schema.args_schema) > 1 else 0) dim = normalize_dim(dim, ndim) # TODO: tensor to unbind cannot have Partial @@ -887,3 +888,10 @@ def index_add_rule(op_schema: OpSchema) -> OutputSharding: def _prop_aten_alias(op_schema: OpSchema) -> OutputSharding: output_spec = cast(DTensorSpec, op_schema.args_schema[0]) return OutputSharding(output_spec=output_spec) + + +@register_prop_rule(aten.nonzero.default) +def _nonzero_prop(op_schema: OpSchema): + output_spec = cast(DTensorSpec, copy.deepcopy(op_schema.args_schema[0])) + output_spec.tensor_meta = None + return OutputSharding(output_spec=output_spec) diff --git a/python/vescale/dtensor/redistribute.py b/python/vescale/dtensor/redistribute.py index d7f4711..ec75515 100644 --- a/python/vescale/dtensor/redistribute.py +++ b/python/vescale/dtensor/redistribute.py @@ -440,7 +440,7 @@ def forward( # type: ignore[override] ctx, input: "dtensor.DTensor", device_mesh: DeviceMesh, - placements: List[Placement], + placements: Tuple[Placement], async_op: bool = True, ): current_spec = input._spec @@ -451,7 +451,7 @@ def forward( # type: ignore[override] if input._spec.placements == placements: return input - target_spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=input._spec.tensor_meta) + target_spec = DTensorSpec(device_mesh, placements, tensor_meta=input._spec.tensor_meta) local_tensor = input._local_tensor output = redistribute_local_tensor(local_tensor, current_spec, target_spec, async_op) diff --git a/python/vescale/dtensor/sharding_prop.py b/python/vescale/dtensor/sharding_prop.py index ccd5f2e..6367d6d 100644 --- a/python/vescale/dtensor/sharding_prop.py +++ b/python/vescale/dtensor/sharding_prop.py @@ -12,6 +12,7 @@ from typing import 
Callable, Dict, Optional, Sequence, Union, cast, List import torch +import copy from torch._ops import OpOverload from torch._subclasses import FakeTensorMode @@ -39,6 +40,7 @@ aten.native_dropout.default, # aten.native_layer_norm.default, aten.nll_loss_forward.default, + aten.topk.default, ] @@ -259,7 +261,7 @@ def spec_to_strategy(spec: object) -> object: for strategy in op_strategy.childs: assert isinstance(strategy, OpStrategy) output_strategy = self._select_strategy(strategy) - out_spec_list.append(output_strategy.output_spec) + out_spec_list.append(copy.deepcopy(output_strategy.output_spec)) if output_strategy.output_spec is None: fallback_prop = True diff --git a/python/vescale/optim/base_optimizer.py b/python/vescale/optim/base_optimizer.py index 634ea35..a50c1bb 100644 --- a/python/vescale/optim/base_optimizer.py +++ b/python/vescale/optim/base_optimizer.py @@ -164,7 +164,7 @@ def __init__( self.models = [self.models] if any(getattr(x, "use_distributed_optimizer", False) for x in self.models): - raise RuntimeError( + raise ValueError( "detected DDP with use_distributed_optimizer on, please consider use a distributed optimizer" ) diff --git a/python/vescale/optim/clip_grads.py b/python/vescale/optim/clip_grads.py index 77c511f..e1beb82 100644 --- a/python/vescale/optim/clip_grads.py +++ b/python/vescale/optim/clip_grads.py @@ -55,10 +55,11 @@ def clip_grad_norm_fp32( # Grads. grads = [] + grad_dtype = None for param in parameters: if param.grad is not None: - assert param.grad.type() == "torch.cuda.FloatTensor" grads.append(param.grad.detach()) + grad_dtype = grads[-1].dtype # Norm parameters. max_norm = float(max_norm) @@ -116,5 +117,7 @@ def clip_grad_norm_fp32( g.data.mul_(clip_coeff) else: multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff) + for g in grads: + g.to(grad_dtype) return total_norm diff --git a/python/vescale/optim/distributed_optimizer.py b/python/vescale/optim/distributed_optimizer.py index f9ca030..01e01c1 100644 --- a/python/vescale/optim/distributed_optimizer.py +++ b/python/vescale/optim/distributed_optimizer.py @@ -139,6 +139,8 @@ class DistributedOptimizer(OptimizerBase): that clipping is ignored if clip_grad == 0. overlap_param_gather: whether overlaping parameter all gathering with forward. By default, False. + grad_to_fp32: whether casting both the gradients and the optimizer + states to fp32. By default, True. optimizer_kwargs: used to initialize base optimizer instance when class is provided for `optimizer` argument. By default, None. 
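The new `grad_to_fp32` flag controls whether the sharded master copy of each float16/bfloat16 parameter (and, later, its gradient) is upcast to fp32. A small standalone sketch of the dtype effect, mirroring the `clone().float()` vs `clone()` branch introduced below:

```
# Conceptual sketch of what `grad_to_fp32` toggles for low-precision params
# (simplified; the real optimizer first slices params into flat 1-D ranges).
import torch

def make_main_param(shard_model_param: torch.Tensor, grad_to_fp32: bool = True) -> torch.Tensor:
    # With grad_to_fp32=True the sharded master weight is kept in fp32;
    # otherwise it stays in the model's low-precision dtype.
    return shard_model_param.clone().float() if grad_to_fp32 else shard_model_param.clone()

p16 = torch.randn(4, dtype=torch.bfloat16)
assert make_main_param(p16, grad_to_fp32=True).dtype == torch.float32
assert make_main_param(p16, grad_to_fp32=False).dtype == torch.bfloat16
```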
@@ -179,6 +181,7 @@ def __init__( models: Sequence[DDP], clip_grad: float = 0.0, overlap_param_gather: bool = False, + grad_to_fp32: bool = True, optimizer_kwargs: Dict[str, Any] = None, **kwargs, ): @@ -214,9 +217,7 @@ def __init__( elif self.data_parallel_group != m.data_parallel_group: raise RuntimeError("Detect model chunks of warious data-parallel process groups") if not all(x.use_distributed_optimizer for x in models): - print( - "You are using a distributed optimizer, it's suggested to set use_distributed_optimizer on for better performance" - ) + raise ValueError("Please open `use_distributed_optimizer` in DDP initialization for better performance.") param_dtype_cnt = {} main_param_dtype_cnt = 0 @@ -241,6 +242,7 @@ def __init__( self.clip_grad = clip_grad self.overlap_param_gather = overlap_param_gather + self.grad_to_fp32 = grad_to_fp32 # Model parameter sharding info for omnistore checkpointing self.param_to_name = {} @@ -270,7 +272,7 @@ def __init__( self.param_across_dp_ranks_info = {} # Mapping fp32 master weights and original fp16 weights # for mix percision training - self.param_to_origin_param_for_shard_fp32_from_float16_groups = {} + self.param_to_origin_param_for_shard_casted_float16_groups = {} for model_chunk in self.models: self.per_bucket_numel.append( @@ -294,7 +296,7 @@ def __init__( self.model_fp32_groups, self.shard_float16_groups, self.shard_fp32_groups, - self.shard_fp32_from_float16_groups, + self.shard_casted_float16_groups, ) = self.build_model_and_main_param_groups( self.model_gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges ) @@ -375,6 +377,9 @@ def __init__( self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges] self.optimizer.load_state_dict(self.optimizer.state_dict()) + # A flag indicates whether the `step` has been called. It will be reset after invoking `zero_grad`. + self.step_issued = False + def build_param_sharding_info_for_checkpoint(self, model: DDP, dtype, gbuf_world_all_ranges): param_world_index_map = model.grad_buffer_param_index_map[dtype] for param, param_world_indexes in param_world_index_map.items(): @@ -610,12 +615,12 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o # model_fp32_groups: original fp32 parameters # shard_float16_groups: shards of original float16 parameters # shard_fp32_groups: shards of original fp32 parameters - # shard_fp32_from_float16_groups: fp32 copy of float16 parameters + # shard_casted_float16_groups: fp32 copy of float16 parameters if cast_shard_param_to_fp32, otherwise, they are just copies model_float16_groups = [] model_fp32_groups = [] shard_float16_groups = [] shard_fp32_groups = [] - shard_fp32_from_float16_groups = [] + shard_casted_float16_groups = [] # Allocate (or slice) each group's param shard. 
for group_index, group_range in enumerate(opt_group_ranges): @@ -624,12 +629,12 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o model_fp32_params_this_group = [] shard_float16_params_this_group = [] shard_fp32_params_this_group = [] - shard_fp32_from_float16_params_this_group = [] + shard_casted_float16_params_this_group = [] model_float16_groups.append(model_float16_params_this_group) model_fp32_groups.append(model_fp32_params_this_group) shard_float16_groups.append(shard_float16_params_this_group) shard_fp32_groups.append(shard_fp32_params_this_group) - shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group) + shard_casted_float16_groups.append(shard_casted_float16_params_this_group) for model_param in group_range["params"]: assert model_param.requires_grad @@ -648,7 +653,10 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o model_param if not isinstance(model_param, DTensor) else model_param._local_tensor ) shard_model_param = model_param_tensor.detach().view(-1)[param_range.start : param_range.end] - shard_main_param = shard_model_param.clone().float() + if self.grad_to_fp32: + shard_main_param = shard_model_param.clone().float() + else: + shard_main_param = shard_model_param.clone() # copy sharded info from DTensor shard_model_param._spec = None if not isinstance(model_param, DTensor) else model_param._spec if hasattr(model_param, "shared"): @@ -662,8 +670,8 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o # Add to group. model_float16_params_this_group.append(model_param) shard_float16_params_this_group.append(shard_model_param) - shard_fp32_from_float16_params_this_group.append(shard_main_param) - self.param_to_origin_param_for_shard_fp32_from_float16_groups[shard_main_param] = model_param + shard_casted_float16_params_this_group.append(shard_main_param) + self.param_to_origin_param_for_shard_casted_float16_groups[shard_main_param] = model_param # fp32 params. elif model_param.type() == "torch.cuda.FloatTensor": model_param_tensor = ( @@ -692,7 +700,7 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o # changing group_range will implicitly change self.optimzer.param_groups. 
group_range["orig_group"]["params"] = [ *shard_fp32_params_this_group, - *shard_fp32_from_float16_params_this_group, + *shard_casted_float16_params_this_group, ] return ( @@ -700,7 +708,7 @@ def build_model_and_main_param_groups(self, model_gbuf_ranges, param_gbuf_map, o model_fp32_groups, shard_float16_groups, shard_fp32_groups, - shard_fp32_from_float16_groups, + shard_casted_float16_groups, ) def get_model_param_range_map(self, param): @@ -768,12 +776,13 @@ def state_dict(self): self.param_across_dp_ranks_info.get(param), ) # If it is mix percision training, we should save master fp32 weights - if not all(not group for group in self.shard_fp32_from_float16_groups): - for group in self.shard_fp32_from_float16_groups: + if not all(not group for group in self.shard_casted_float16_groups): + for group in self.shard_casted_float16_groups: for param in group: - original_param = self.param_to_origin_param_for_shard_fp32_from_float16_groups[param] + original_param = self.param_to_origin_param_for_shard_casted_float16_groups[param] name = self.param_to_name[original_param] - distributed_state[torch.float32][name]["shard_fp32_from_float16_groups"] = OptimizerStateSpec( + dtype = torch.float32 if self.grad_to_fp32 else param.dtype + distributed_state[dtype][name]["shard_casted_float16_groups"] = OptimizerStateSpec( self.param_global_shape_info[original_param], self.param_local_shape_info[original_param], self.param_global_offset_info[original_param], @@ -825,14 +834,15 @@ def load_state_dict(self, state_dict): self.optimizer.load_state_dict(optimizer_state) - if not all(not group for group in self.shard_fp32_from_float16_groups): - for group in self.shard_fp32_from_float16_groups: + if not all(not group for group in self.shard_casted_float16_groups): + for group in self.shard_casted_float16_groups: for param in group: - original_param = self.param_to_origin_param_for_shard_fp32_from_float16_groups[param] + original_param = self.param_to_origin_param_for_shard_casted_float16_groups[param] name = self.param_to_name[original_param] # The weights have been flatten into 1D and get range based on current rank (if necessary) # in the "resume optimizer state loop - param.copy_(state_dict[torch.float32][name]["shard_fp32_from_float16_groups"]) + dtype = torch.float32 if self.grad_to_fp32 else param.dtype + param.copy_(state_dict[dtype][name]["shard_casted_float16_groups"]) # state_dict['shard_fp32_from_float16_groups'] # optimizer_state['shard_fp32_from_float16_groups'] # TODO: Copy data for the main params. @@ -860,7 +870,7 @@ def zero_grad(self, set_to_none=True): self.model_fp32_groups, self.shard_float16_groups, # grad empty/unused here? self.shard_fp32_groups, # throws grad-access warning - self.shard_fp32_from_float16_groups, + self.shard_casted_float16_groups, ): for group in groups: _zero_grad_group_helper(group, set_to_none) @@ -877,9 +887,12 @@ def zero_grad(self, set_to_none=True): # pre-hook when this all-gather finishes (to ensure that the communication # kernels don't head-of-line block the compute kernels since we run with # CUDA_DEVICE_MAX_CONNECTIONS=1 to support sequence parallelism). - if self.overlap_param_gather: + # NOTE: we shouldn't issue param all-gather if runned before any `optim.step`. 
+ if self.overlap_param_gather and self.step_issued: self._dispatch_gather_model_params(all_gather_handle_index=0) + self.step_issued = False + def shard_buffer_among_dp_world(self, buffer: torch.Tensor): assert buffer.numel() % self.data_parallel_world_size == 0 shard_size = buffer.numel() // self.data_parallel_world_size @@ -1096,7 +1109,7 @@ def _get_model_and_main_params_data_float16(self): """ model_data = [] main_data = [] - for model_group, main_group in zip(self.shard_float16_groups, self.shard_fp32_from_float16_groups): + for model_group, main_group in zip(self.shard_float16_groups, self.shard_casted_float16_groups): for model_param, main_param in zip(model_group, main_group): model_data.append(model_param.data) main_data.append(main_param.data) @@ -1121,10 +1134,13 @@ def copy_group_grads(model_groups, shard_main_groups): model_grad = model_param.main_grad shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end] - shard_main_param.grad = shard_model_grad.float() + if self.grad_to_fp32: + shard_main_param.grad = shard_model_grad.float() + else: + shard_main_param.grad = shard_model_grad # Copy model groups to shard groups. - copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups) + copy_group_grads(self.model_float16_groups, self.shard_casted_float16_groups) copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups) def _copy_main_params_to_model_params(self): @@ -1153,7 +1169,7 @@ def copy_group_params(shard_main_groups, model_groups): shard_model_param.data.copy_(shard_main_param) # Copy shard groups to model groups. - copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups) + copy_group_params(self.shard_casted_float16_groups, self.model_float16_groups) copy_group_params(self.shard_fp32_groups, self.model_fp32_groups) def _copy_model_params_to_main_params(self): @@ -1177,7 +1193,7 @@ def copy_group_params(model_groups, shard_main_groups): shard_main_param.data.copy_(shard_model_param) # Copy model groups to shard groups. 
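The `step_issued` flag added above exists only to keep `zero_grad` from prefetching parameters before the optimizer has ever stepped. A toy state machine showing the guard's behavior (illustrative only; the real code dispatches the parameter all-gathers at this point):

```
# Toy version of the `step_issued` guard for overlapped parameter gathering.
class GatherGuard:
    def __init__(self, overlap_param_gather: bool):
        self.overlap_param_gather = overlap_param_gather
        self.step_issued = False
        self.gathers_dispatched = 0

    def zero_grad(self):
        # only prefetch params if a step has already produced new main params
        if self.overlap_param_gather and self.step_issued:
            self.gathers_dispatched += 1
        self.step_issued = False

    def step(self):
        self.step_issued = True

g = GatherGuard(overlap_param_gather=True)
g.zero_grad()          # before any step: no gather is issued
g.step()
g.zero_grad()          # after a step: the prefetch all-gather is issued
assert g.gathers_dispatched == 1
```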
- copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups) + copy_group_params(self.model_float16_groups, self.shard_casted_float16_groups) copy_group_params(self.model_fp32_groups, self.shard_fp32_groups) def clip_grad_norm(self, clip_grad): @@ -1217,6 +1233,7 @@ def step(self): for all_gather_handle_index in range(self.num_all_gather_handles): self._dispatch_gather_model_params(all_gather_handle_index) + self.step_issued = True return grad_norm def get_parameters(self): diff --git a/scripts/run_test.sh b/scripts/run_test.sh index dfa65b7..8fd7ec1 100755 --- a/scripts/run_test.sh +++ b/scripts/run_test.sh @@ -17,7 +17,7 @@ export PYTHONPATH # run test while IFS= read -r -d '' file -do +do pkill -9 python3 || true # ok if nothing to kill pytest -s "${file}" pkill -9 python3 || true diff --git a/test/common_dtensor.py b/test/common_dtensor.py index 3be3b90..f68d7d9 100644 --- a/test/common_dtensor.py +++ b/test/common_dtensor.py @@ -28,9 +28,10 @@ import torch.testing._internal.distributed.fake_pg as fake_pg from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten -import vescale -from vescale import DeviceMesh, Shard, Replicate, distribute_tensor, DTensor -from vescale.dtensor.placement_types import Placement, DTensorSpec +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.api import distribute_tensor +from vescale.dtensor import DTensor +from vescale.dtensor.placement_types import Placement, Shard, Replicate, DTensorSpec # add new skipped test exit code TEST_SKIPS["torch-version-2.2"] = TestSkip(90, "Need torch version bigger than 2.2") @@ -94,6 +95,8 @@ class RedistributeProfile: @contextmanager def redistribute_profiler() -> Generator[RedistributeProfile, None, None]: + import vescale + orig_redistribute_local_tensor = vescale.dtensor.redistribute.redistribute_local_tensor profile: RedistributeProfile = RedistributeProfile(num_calls=0) diff --git a/test/debug/model.py b/test/debug/model.py new file mode 100644 index 0000000..56fb7cc --- /dev/null +++ b/test/debug/model.py @@ -0,0 +1,59 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import torch +from torch import nn +from vescale.dtensor.placement_types import Replicate, Shard + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(4, 16) + self.gelu = torch.nn.GELU() + self.fc2 = nn.Linear(16, 4) + + def forward(self, x): + x = self.fc1(x) + x = self.gelu(x) + x = self.fc2(x) + return x + + +class Block(nn.Module): + def __init__(self): + super().__init__() + self.ln = nn.LayerNorm(4, bias=False) + self.mlp = MLP() + + def forward(self, x): + return self.mlp(self.ln(x)) + + +param_sharding_plan = { + "fc1.weight": [Shard(0)], + "fc1.bias": [Shard(0)], + "fc2.weight": [Shard(1)], + "fc2.bias": [Replicate()], +} + +fwd_resharding_plan = { + "fc1.input": [[Replicate()]], + "fc2.output": [[Replicate()]], +} + +sharding_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/test/debug/test_log_with_dtensor_test_base.py b/test/debug/test_log_with_dtensor_test_base.py new file mode 100644 index 0000000..80b24b2 --- /dev/null +++ b/test/debug/test_log_with_dtensor_test_base.py @@ -0,0 +1,123 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
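In the test MLP plan above, `fc1.weight` is `Shard(0)` (column-parallel) and `fc2.weight` is `Shard(1)` (row-parallel), so each TP rank produces a partial output that the `fc2.output: [[Replicate()]]` resharding is meant to combine. A quick single-process shape check of that pattern (biases ignored, 4 TP ranks assumed):

```
# Single-process sanity check of the column-then-row parallel linear pattern.
import torch

tp = 4
x = torch.randn(2, 4)                 # replicated input
w1_local = torch.randn(16 // tp, 4)   # one rank's Shard(0) slice of fc1.weight (16, 4)
h_local = x @ w1_local.t()            # each rank holds a column slice of the hidden state
w2_local = torch.randn(4, 16 // tp)   # one rank's Shard(1) slice of fc2.weight (4, 16)
y_partial = h_local @ w2_local.t()    # partial sums; resharding to Replicate would sum them
assert y_partial.shape == (2, 4)
```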
+# +################################################################################ + +import torch +from torch.testing._internal.common_utils import run_tests + +import logging +import io +import os +from vescale.dmodule.api import parallelize_module +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.debug import DebugLogger +from model import MLP, sharding_plan +from contextlib import redirect_stdout +from common_dtensor import DTensorTestBase, with_comms + + +class DModuleTestDebugLog(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_simple_std_out(self): + DebugLogger.set_vescale_debug_mode(rank_to_print=(0, 1, 2, 3)) + device_mesh = DeviceMesh("cuda", list(range(self.world_size))) + + with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + self.assertGreater(len("".join(out.split())), 100) + + @with_comms + def test_simple_std_out_without_set0(self): + os.environ["VESCALE_DEBUG_MODE"] = "1" + device_mesh = DeviceMesh("cuda", list(range(self.world_size))) + + with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + self.assertGreater(len("".join(out.split())), 100) + + @with_comms + def test_simple_std_out_without_set1(self): + device_mesh = DeviceMesh("cuda", list(range(self.world_size))) + + with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + os.environ["VESCALE_DEBUG_MODE"] = "1" + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + self.assertGreater(len("".join(out.split())), 100) + + @with_comms + def test_simple_only_rank1(self): + DebugLogger.set_vescale_debug_mode(rank_to_print=(1)) + device_mesh = DeviceMesh("cuda", list(range(self.world_size))) + + with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + if self.rank == 1: + self.assertGreater(len("".join(out.split())), 100) + else: + self.assertEqual("".join(out.split()), "") + + @with_comms + def test_simple_logging(self): + logger = logging.getLogger("test_simple_logging") + logger.setLevel(logging.DEBUG) + + log_filename = f"logging_sample_rank{self.rank}.log" + + fh = logging.FileHandler(log_filename, mode="w") + fh.setLevel(logging.DEBUG) + logger.addHandler(fh) + + DebugLogger.set_vescale_debug_mode(rank_to_print=(0, 1, 2, 3), logger=logger) + device_mesh = DeviceMesh("cuda", list(range(self.world_size))) + + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + + with open(log_filename) as file: + out = file.read() + out = "".join(out.split()) + self.assertGreater(len(out), 100) + + +if __name__ == "__main__": + run_tests() diff --git a/test/debug/test_log_with_torch_run.py b/test/debug/test_log_with_torch_run.py new file mode 100644 index 0000000..07b33a9 --- /dev/null +++ b/test/debug/test_log_with_torch_run.py @@ -0,0 
+1,51 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import os +import unittest +import subprocess + +dir_name = "torchrun_scripts" + +target_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), dir_name) + + +def run_script(fname, env=None): + current_dir = os.getcwd() + os.chdir(target_dir) + result = subprocess.run(["torchrun", "--standalone", "--nnodes=1", "--nproc-per-node=4", fname], env=env) + os.chdir(current_dir) + return result.returncode + + +class DebugTorchrunTestSuite(unittest.TestCase): + def test_simple_std_out(self): + self.assertEqual(run_script("simple_std_out.py"), 0) + + def test_simple_only_rank1(self): + self.assertEqual(run_script("simple_only_rank1.py"), 0) + + def test_simple_logging(self): + self.assertEqual(run_script("simple_logging.py"), 0) + + def test_simple_set_env(self): + my_env = os.environ + my_env["VESCALE_DEBUG_MODE"] = "1" + self.assertEqual(run_script("simple_std_out.py", my_env), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/debug/torchrun_scripts/simple_logging.py b/test/debug/torchrun_scripts/simple_logging.py new file mode 100644 index 0000000..f5f07e4 --- /dev/null +++ b/test/debug/torchrun_scripts/simple_logging.py @@ -0,0 +1,56 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
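The torchrun-based tests above enable the debug logger either through `DebugLogger.set_vescale_debug_mode` inside the script or by exporting `VESCALE_DEBUG_MODE=1` into the launched processes. A hedged sketch of the environment-variable route; `my_debug_script.py` is a placeholder name, not a file added by this patch, and copying the environment keeps the parent interpreter's `os.environ` untouched:

```
# Hedged sketch: launch a 4-process torchrun job with VESCALE_DEBUG_MODE enabled.
import os
import subprocess

env = os.environ.copy()          # copy so the parent process environment is left as-is
env["VESCALE_DEBUG_MODE"] = "1"  # picked up by DebugLogger.update_vescale_debug_mode_from_env()

subprocess.run(
    ["torchrun", "--standalone", "--nnodes=1", "--nproc-per-node=4", "my_debug_script.py"],
    env=env,
    check=True,
)
```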
+# +################################################################################ + +import logging +import torch +import os +import sys +from vescale.dmodule.api import parallelize_module +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.debug import DebugLogger + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from model import MLP, sharding_plan + +world_size = int(os.environ["WORLD_SIZE"]) +rank = int(os.environ["RANK"]) + +logger = logging.getLogger("test_simple_logging") +logger.setLevel(logging.DEBUG) + +log_filename = f"logging_sample_rank{rank}.log" + +fh = logging.FileHandler(log_filename, mode="w") +fh.setLevel(logging.DEBUG) +logger.addHandler(fh) + +DebugLogger.set_vescale_debug_mode(rank_to_print=(0, 1, 2, 3), logger=logger) +device_mesh = DeviceMesh("cuda", list(range(world_size))) + + +model = MLP() +dmodule = parallelize_module(model, device_mesh, sharding_plan) +input = torch.ones((4, 4, 4)) +output = dmodule(input).to_local() +output.sum().backward() + +with open(log_filename) as file: + out = file.read() + +out = "".join(out.split()) + +assert len(out) > 100 diff --git a/test/debug/torchrun_scripts/simple_only_rank1.py b/test/debug/torchrun_scripts/simple_only_rank1.py new file mode 100644 index 0000000..926eaf1 --- /dev/null +++ b/test/debug/torchrun_scripts/simple_only_rank1.py @@ -0,0 +1,49 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import io +import os +import sys +from vescale.dmodule.api import parallelize_module +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.debug import DebugLogger +from contextlib import redirect_stdout + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from model import MLP, sharding_plan + +world_size = int(os.environ["WORLD_SIZE"]) +rank = int(os.environ["RANK"]) + +DebugLogger.set_vescale_debug_mode(rank_to_print=(1,)) +device_mesh = DeviceMesh("cuda", list(range(world_size))) + + +with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + +out = "".join(out.split()) +if rank == 1: + assert len(out) > 100 +else: + assert len(out) == 0 diff --git a/test/debug/torchrun_scripts/simple_std_out.py b/test/debug/torchrun_scripts/simple_std_out.py new file mode 100644 index 0000000..b6eb537 --- /dev/null +++ b/test/debug/torchrun_scripts/simple_std_out.py @@ -0,0 +1,47 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import io +import os +import sys +from vescale.dmodule.api import parallelize_module +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.debug import DebugLogger +from contextlib import redirect_stdout + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from model import MLP, sharding_plan + +world_size = int(os.environ["WORLD_SIZE"]) +rank = int(os.environ["RANK"]) + +DebugLogger.set_vescale_debug_mode(rank_to_print=(0, 1, 2, 3)) +device_mesh = DeviceMesh("cuda", list(range(world_size))) + + +with io.StringIO() as buf, redirect_stdout(buf): + model = MLP() + dmodule = parallelize_module(model, device_mesh, sharding_plan) + input = torch.ones((4, 4, 4)) + output = dmodule(input).to_local() + output.sum().backward() + out = buf.getvalue() + +out = "".join(out.split()) + +assert len(out) > 100 diff --git a/test/dtensor/general/test_dtensor.py b/test/dtensor/general/test_dtensor.py index e05e143..74b3f22 100644 --- a/test/dtensor/general/test_dtensor.py +++ b/test/dtensor/general/test_dtensor.py @@ -13,7 +13,6 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from numpy.testing import assert_array_equal from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.testing._internal.common_utils import run_tests @@ -23,24 +22,6 @@ from vescale.dtensor.placement_types import Partial, Replicate, Shard -class DummyMLP(torch.nn.Module): - def __init__(self, device): - super().__init__() - self.net1 = torch.nn.Linear(5, 1024, device=device) - self.relu = torch.nn.ReLU() - self.net2 = torch.nn.Linear(1024, 4, device=device) - - def forward(self, x): - return self.net2(F.relu(self.net1(x))) - - def reset_parameters(self, *args, **kwargs): - with torch.no_grad(): - self.net1.weight.fill_(0.5) - self.net2.weight.fill_(1) - self.net1.bias.fill_(1.5) - self.net2.bias.fill_(1.2) - - class DTensorTest(DTensorTestBase): @property def world_size(self) -> int: diff --git a/test/dtensor/ops/test_math_ops.py b/test/dtensor/ops/test_math_ops.py index e2236be..a160c27 100644 --- a/test/dtensor/ops/test_math_ops.py +++ b/test/dtensor/ops/test_math_ops.py @@ -196,6 +196,14 @@ def test_topk(self): self.assertTrue(d_result.values.placements[0].is_shard(dim=shard_dim)) self.assertEqual(d_result.values.full_tensor(), local_result.values) + @with_comms + def test_topk_no_dim(self): + device_mesh = self.build_device_mesh() + tensor = torch.randn(8, 8) + dtensor = distribute_tensor(tensor, device_mesh, [Shard(1)]) + topk_no_dim = torch.topk(tensor, 2) + dtopk_no_dim = torch.topk(dtensor, 2) + @with_comms def test_topk_backward(self): device_mesh = self.build_device_mesh() diff --git a/test/dtensor/ops/test_tensor_ops.py b/test/dtensor/ops/test_tensor_ops.py index 62c156d..c34572b 100644 --- a/test/dtensor/ops/test_tensor_ops.py +++ b/test/dtensor/ops/test_tensor_ops.py @@ -478,6 
+478,29 @@ def test_stack(self): dx = torch.stack(dx) # torch.autograd.backward(dout, torch.ones_like(dout)) + @with_comms + def test_nonzero(self): + device_mesh = self.build_device_mesh() + x = torch.randint(0, 1, (4, 5, 6)) + out = torch.nonzero(x) + + d_x = distribute_tensor(x, device_mesh, [Replicate()]) + d_out = torch.nonzero(d_x) + + self.assertEqual(d_out.to_local(), out) + self.assertEqual(d_out.size(), d_out._local_tensor.size()) + + @with_comms + def test_unbind(self): + device_mesh = self.build_device_mesh() + x = torch.randint(0, 1, (4, 5, 6)) + d_x = distribute_tensor(x, device_mesh, [Replicate()]) + for dim in range(3): + out = torch.unbind(x, dim) + d_out = torch.unbind(d_x, dim) + for d_r, r in zip(d_out, out): + self.assertEqual(d_r.to_local(), r) + if __name__ == "__main__": run_tests() diff --git a/test/model/mixtral/README.md b/test/model/mixtral/README.md new file mode 100644 index 0000000..f9e1291 --- /dev/null +++ b/test/model/mixtral/README.md @@ -0,0 +1,7 @@ +This directory stores the tests for various components of the Mixtral 8x7B model. Specifically, there are: +1. MixtralAttentionBlock: `test/model/mixtral/test_mixtral_attention.py` +2. MixtralSparseMoeBlock: `test/model/mixtral/test_mixtral_sparse_moe.py` +3. MixtralRMSNorm: Same as Llama's RMSNorm. See `test/model/open_llama/test_rms_norm.py` instead. +4. MixtralDecoderLayer: `test/model/mixtral/test_mixtral_decoder_layer.py` + +More over, we also add an E2E test of a `small` Mixtral 8x7B model. You can see how to combine multiple parallel strategies (including TP/SP, DP and ZeRO 2+) to train a simple Mixtral network, refer to `test/model/mixtral/test_mixtral.py` for detail. \ No newline at end of file diff --git a/test/model/mixtral/sharding_plan.py b/test/model/mixtral/sharding_plan.py new file mode 100644 index 0000000..b9d7769 --- /dev/null +++ b/test/model/mixtral/sharding_plan.py @@ -0,0 +1,69 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +"""This file contain TP/SP sharding plans for all test cases.""" + +from vescale.dtensor.placement_types import Replicate, Shard + + +param_sharding_plan = { + "embed_tokens.weight": [Replicate()], + r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"layers.\d+.self_attn.q_proj.weight": [Shard(0)], + r"layers.\d+.self_attn.k_proj.weight": [Shard(0)], + r"layers.\d+.self_attn.v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. 
+ r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()], + r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()], + r"layers.\d+.self_attn.o_proj.weight": [Shard(1)], + r"layers.\d+.post_attention_layernorm.weight": [Replicate()], + r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()], + r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + "norm.weight": [Replicate()], +} + +fwd_resharding_plan = { + # TODO: buggy: attn mask is torch.Tensor, in training, it's a None + r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]}, + "embed_tokens.input": [[Replicate()]], + # No SP + # r"layers.\d+.input_layernorm.input": [[Replicate()]], + # r"layers.\d+.input_layernorm.output": [[Replicate()]], + # SP + r"layers.\d+.input_layernorm.input": [[Shard(1)]], + r"layers.\d+.input_layernorm.output": [[Shard(1)]], + r"layers.\d+.self_attn.input": [[Replicate()]], + r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, + r"layers.\d+.self_attn.o_proj.output": [[Replicate()]], + # No SP + # r"layers.\d+.post_attention_layernorm.input": [[Replicate()]], + # r"layers.\d+.post_attention_layernorm.output": [[Replicate()]], + # SP + r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]], + r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]], + r"layers.\d+.block_sparse_moe.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]], + r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, + r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + "norm.input": [[Replicate()]], +} + +mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan} diff --git a/test/model/mixtral/test_mixtral.py b/test/model/mixtral/test_mixtral.py new file mode 100644 index 0000000..2c14fb5 --- /dev/null +++ b/test/model/mixtral/test_mixtral.py @@ -0,0 +1,327 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import copy + +import torch +from torch.testing._internal.common_utils import ( + run_tests, +) +import torch.distributed as dist +from common_dtensor import DTensorTestBase, skip_unless_torch_gpu, with_comms + +from vescale.dtensor.dtensor import DTensor +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.placement_types import Replicate +from vescale.dmodule.api import parallelize_module +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.base_optimizer import BasicOptimizer +from vescale.optim.distributed_optimizer import DistributedOptimizer +from vescale.initialize.deferred_init import deferred_init + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralModel +from sharding_plan import mixtral_plan + + +torch.manual_seed(9999) + +vocab_size = 30 # default 32000 +hidden_size = 64 # default 4096 +# TODO: if changed to use default intermediate_size, accuracy error: 0.016 +intermediate_size = 128 # default 14336 +num_hidden_layers = 2 # default 32 +num_attention_heads = 16 # default 32 +num_key_value_heads = 8 # default 8 +attn_implementation = "eager" # options are ["eager", "sdpa", "flash_attention_2"] +bsz = 7 +seqlen = 9 + +mixtral_config = MixtralConfig( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, +) + + +class MixtralTPSPTest(DTensorTestBase): + @property + def world_size(self): + return 2 + + def gen_golden(self, mixtral_model, x): + outs = mixtral_model(x) + hidden_states = outs.last_hidden_state + hidden_states.sum().backward() + + def compare_model_weights_and_grads(self, base_model, model): + for name, base_param in base_model.named_parameters(): + param = model.get_parameter(name) + base_grad = base_param.grad.data + grad = param.grad + if isinstance(param, DTensor): + param = param.redistribute(param.device_mesh, [Replicate()], async_op=False)._local_tensor + + torch.testing.assert_close(param, base_param) + if isinstance(grad.data, DTensor): + grad = grad.data.redistribute(grad.data.device_mesh, [Replicate()], async_op=False)._local_tensor + torch.testing.assert_close(base_grad, grad, atol=1e-4, rtol=1e-4) + + @skip_unless_torch_gpu + @with_comms + def test_tp_sp(self): + device_mesh = DeviceMesh("cuda", range(self.world_size)) + mixtral_model = MixtralModel(mixtral_config).cuda() + base_mixtral_model = copy.deepcopy(mixtral_model) + + mixtral_model = parallelize_module( + mixtral_model, + device_mesh, + mixtral_plan, + factory=True, + ) + + token_ids = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + dist.all_reduce(token_ids, op=dist.ReduceOp.MAX) + base_token_ids = copy.deepcopy(token_ids) + outs = mixtral_model(token_ids) + + hidden_states = outs.last_hidden_state + hidden_states.to_local().sum().backward() + + mixtral_model.finish_grad_sync() + + self.gen_golden(base_mixtral_model, base_token_ids) + self.compare_model_weights_and_grads(base_mixtral_model, mixtral_model) + + @skip_unless_torch_gpu + @with_comms + def test_tp_sp_deferred(self): + device_mesh = DeviceMesh("cuda", range(self.world_size)) + mixtral_model = deferred_init(MixtralModel, mixtral_config) + + mixtral_model = parallelize_module( + mixtral_model, + device_mesh, + mixtral_plan, + 
factory=True, + ) + + token_ids = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + dist.all_reduce(token_ids, op=dist.ReduceOp.MAX) + outs = mixtral_model(token_ids) + + hidden_states = outs.last_hidden_state + hidden_states.to_local().sum().backward() + + mixtral_model.finish_grad_sync() + + +class Mixtral4DTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + def gen_golden( + self, + mixtral_model, + token_ids_batch_1_epoch_1, + token_ids_batch_2_epoch_1, + token_ids_batch_1_epoch_2, + token_ids_batch_2_epoch_2, + ): + optim = torch.optim.Adam(mixtral_model.parameters(), lr=0.01) + + # epoch 1 + optim.zero_grad() + outs = mixtral_model(token_ids_batch_1_epoch_1) + outs.last_hidden_state.sum().backward() + outs = mixtral_model(token_ids_batch_2_epoch_1) + outs.last_hidden_state.sum().backward() + + # manually reduce mean the grad + for param in mixtral_model.parameters(): + if param.grad is not None: + param.grad /= 2 + optim.step() + + # epoch 2 + optim.zero_grad() + outs = mixtral_model(token_ids_batch_1_epoch_2) + outs.last_hidden_state.sum().backward() + outs = mixtral_model(token_ids_batch_2_epoch_2) + outs.last_hidden_state.sum().backward() + + # manually reduce mean the grad + for param in mixtral_model.parameters(): + if param.grad is not None: + param.grad /= 2 + optim.step() + + def compare_model_weights(self, base_model, model): + for name, base_param in base_model.named_parameters(): + param = model.get_parameter(name) + if base_param.grad is None: + continue + if isinstance(param, DTensor): + param = param.redistribute(param.device_mesh, [Replicate()], async_op=False)._local_tensor + torch.testing.assert_close(param, base_param, atol=2e-4, rtol=2e-4) + + @skip_unless_torch_gpu + @with_comms + def test_tp_sp_ddp(self): + device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP")) + + mixtral_model = MixtralModel(mixtral_config).cuda() + base_mixtral_model = copy.deepcopy(mixtral_model) + + mixtral_model = parallelize_module( + mixtral_model, + device_mesh["TP"], + mixtral_plan, + factory=True, + ) + + ddp_mixtral_model = DDP( + mixtral_model, + device_mesh["DP"], + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=True, + use_distributed_optimizer=False, + ) + + optim = BasicOptimizer(torch.optim.Adam(mixtral_model.parameters(), lr=0.01), models=[ddp_mixtral_model]) + + token_ids_batch_1_epoch_1 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_2_epoch_1 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_1_epoch_2 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_2_epoch_2 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + dist.all_reduce(token_ids_batch_1_epoch_1, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_2_epoch_1, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_1_epoch_2, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_2_epoch_2, op=dist.ReduceOp.MAX) + base_token_ids_batch_1_epoch_1 = copy.deepcopy(token_ids_batch_1_epoch_1) + base_token_ids_batch_2_epoch_1 = copy.deepcopy(token_ids_batch_2_epoch_1) + base_token_ids_batch_1_epoch_2 = copy.deepcopy(token_ids_batch_1_epoch_2) + base_token_ids_batch_2_epoch_2 = copy.deepcopy(token_ids_batch_2_epoch_2) + + # epoch 1 + optim.zero_grad() + if self.rank in [0, 1]: + x = token_ids_batch_1_epoch_1 + else: + x = token_ids_batch_2_epoch_1 + ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + optim.step() + + # epoch 2 + 
optim.zero_grad() + if self.rank in [0, 1]: + x = token_ids_batch_1_epoch_2 + else: + x = token_ids_batch_2_epoch_2 + ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + optim.step() + + self.gen_golden( + base_mixtral_model, + base_token_ids_batch_1_epoch_1, + base_token_ids_batch_2_epoch_1, + base_token_ids_batch_1_epoch_2, + base_token_ids_batch_2_epoch_2, + ) + self.compare_model_weights(base_mixtral_model, mixtral_model) + + @skip_unless_torch_gpu + @with_comms + def test_tp_sp_ddp_doptim(self): + device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP")) + + mixtral_model = MixtralModel(mixtral_config).cuda() + base_mixtral_model = copy.deepcopy(mixtral_model) + + mixtral_model = parallelize_module( + mixtral_model, + device_mesh["TP"], + mixtral_plan, + factory=True, + ) + + ddp_mixtral_model = DDP( + mixtral_model, + device_mesh["DP"], + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=True, + use_distributed_optimizer=True, + ) + + doptim = DistributedOptimizer( + torch.optim.Adam(mixtral_model.parameters(), lr=0.01), + models=[ddp_mixtral_model], + overlap_param_gather=False, + ) + + token_ids_batch_1_epoch_1 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_2_epoch_1 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_1_epoch_2 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + token_ids_batch_2_epoch_2 = torch.randint(0, vocab_size, (bsz, seqlen)).cuda() + dist.all_reduce(token_ids_batch_1_epoch_1, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_2_epoch_1, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_1_epoch_2, op=dist.ReduceOp.MAX) + dist.all_reduce(token_ids_batch_2_epoch_2, op=dist.ReduceOp.MAX) + base_token_ids_batch_1_epoch_1 = copy.deepcopy(token_ids_batch_1_epoch_1) + base_token_ids_batch_2_epoch_1 = copy.deepcopy(token_ids_batch_2_epoch_1) + base_token_ids_batch_1_epoch_2 = copy.deepcopy(token_ids_batch_1_epoch_2) + base_token_ids_batch_2_epoch_2 = copy.deepcopy(token_ids_batch_2_epoch_2) + + # epoch 1 + doptim.zero_grad() + if self.rank in [0, 1]: + x = token_ids_batch_1_epoch_1 + else: + x = token_ids_batch_2_epoch_1 + ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + doptim.step() + + # epoch 2 + doptim.zero_grad() + if self.rank in [0, 1]: + x = token_ids_batch_1_epoch_2 + else: + x = token_ids_batch_2_epoch_2 + ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward() + ddp_mixtral_model.finish_grad_sync() + doptim.step() + + self.gen_golden( + base_mixtral_model, + base_token_ids_batch_1_epoch_1, + base_token_ids_batch_2_epoch_1, + base_token_ids_batch_1_epoch_2, + base_token_ids_batch_2_epoch_2, + ) + self.compare_model_weights(base_mixtral_model, mixtral_model) + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/mixtral/test_mixtral_attention.py b/test/model/mixtral/test_mixtral_attention.py new file mode 100644 index 0000000..effef51 --- /dev/null +++ b/test/model/mixtral/test_mixtral_attention.py @@ -0,0 +1,100 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import copy +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralAttention + +from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu + +torch.manual_seed(9999) + + +class MixtralAttentionBlockTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @skip_unless_torch_gpu + @with_comms + def test_tp_mixtral_attn( + self, + ): + bsz = 6 + seqlen = 18 + config = MixtralConfig() + hidden_size = config.hidden_size + + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + base_attn = MixtralAttention(config, 0).cuda() + attn = copy.deepcopy(base_attn) + + base_input = torch.rand(bsz, seqlen, hidden_size).cuda() + input = copy.deepcopy(base_input) + + # =---------------- baseline ----------------= # + base_output, _, _ = base_attn(base_input) + base_loss = base_output.mean() + base_loss.backward() + + # =---------------- vescale ----------------= # + param_sharding_plan = { + r"q_proj.weight": [Shard(0)], + r"k_proj.weight": [Shard(0)], + r"v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. + r"rotary_emb.cos_cached": [Replicate()], + r"rotary_emb.sin_cached": [Replicate()], + r"o_proj.weight": [Shard(1)], + } + fwd_resharding_plan = { + r"input": [[Replicate()]], + r"output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None}, + r"o_proj.output": [[Replicate()]], + } + + attn = parallelize_module( + attn, + device_mesh=device_mesh, + sharding_plan={"parameter": param_sharding_plan, "forward": fwd_resharding_plan}, + factory=True, + ) + output, _, _ = attn(input) + loss = output.mean() + loss.backward() + + torch.testing.assert_close(base_output, output._local_tensor) + torch.testing.assert_close(base_loss, loss._local_tensor) + for fc_name in ["q_proj", "k_proj", "v_proj", "o_proj"]: + base_param_grad = base_attn.get_parameter(f"{fc_name}.weight").grad + param_grad = ( + attn.get_parameter(f"{fc_name}.weight") + .grad.redistribute(device_mesh, [Replicate()], async_op=False) + ._local_tensor + ) + torch.testing.assert_close(base_param_grad, param_grad) + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/mixtral/test_mixtral_decoder_layer.py b/test/model/mixtral/test_mixtral_decoder_layer.py new file mode 100644 index 0000000..63adc6f --- /dev/null +++ b/test/model/mixtral/test_mixtral_decoder_layer.py @@ -0,0 +1,134 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import copy +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer + +from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu + +torch.manual_seed(9999) + + +class MixtralDecoderLayerTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @skip_unless_torch_gpu + @with_comms + def test_tp_mixtral_decoder( + self, + ): + bsz = 6 + seqlen = 18 + config = MixtralConfig() + hidden_size = config.hidden_size + + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + base_decoder = MixtralDecoderLayer(config, 0).cuda() + decoder = copy.deepcopy(base_decoder) + + base_input = torch.rand(bsz, seqlen, hidden_size).cuda() + input = copy.deepcopy(base_input) + + # =---------------- baseline ----------------= # + base_output = base_decoder(base_input)[0] + base_loss = base_output.mean() + base_loss.backward() + + # =---------------- vescale ----------------= # + param_sharding_plan = { + r"input_layernorm.weight": [Replicate()], # MixtralRMSNorm + r"self_attn.q_proj.weight": [Shard(0)], + r"self_attn.k_proj.weight": [Shard(0)], + r"self_attn.v_proj.weight": [Shard(0)], + # TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen. 
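+            # The rotary cos/sin caches are small position-only buffers identical on every rank,
+            # so the plan keeps them replicated across TP ranks instead of sharding them.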
+ r"self_attn.rotary_emb.cos_cached": [Replicate()], + r"self_attn.rotary_emb.sin_cached": [Replicate()], + r"self_attn.o_proj.weight": [Shard(1)], + r"post_attention_layernorm.weight": [Replicate()], + r"block_sparse_moe.gate.weight": [Replicate()], + r"block_sparse_moe.experts.\d+.w1.weight": [Shard(0)], + r"block_sparse_moe.experts.\d+.w3.weight": [Shard(0)], + r"block_sparse_moe.experts.\d+.w2.weight": [Shard(1)], + } + + fwd_resharding_plan = { + r"input": [[Replicate()]], + # No SP + # r"input_layernorm.input": [[Replicate()]], + # r"input_layernorm.output": [[Replicate()]], + # SP + r"input_layernorm.input": [[Shard(1)]], + r"input_layernorm.output": [[Shard(1)]], + # TODO: buggy: attn mask is torch.Tensor, in training, it's a None + r"self_attn.input": [[Replicate()]], + r"self_attn.output": { + "attn_output": [Replicate()], + "attn_weights": None, + "past_key_value": None, + }, + r"self_attn.o_proj.output": [[Replicate()]], + # No SP + # r"post_attention_layernorm.input": [[Replicate()]], + # r"post_attention_layernorm.output": [[Replicate()]], + # SP + r"post_attention_layernorm.input": [[Shard(1)]], + r"post_attention_layernorm.output": [[Shard(1)]], + r"block_sparse_moe.input": [[Replicate()]], + r"block_sparse_moe.gate.output": [[Replicate()]], + r"block_sparse_moe.output": { + "final_hidden_states": [Replicate()], + "router_logits": [Replicate()], + }, + r"block_sparse_moe.experts.\d+.w1.input": [[Replicate()]], + r"block_sparse_moe.experts.\d+.w3.input": [[Replicate()]], + r"block_sparse_moe.experts.\d+.w2.output": [[Replicate()]], + r"output": [[Replicate()]], + } + + decoder = parallelize_module( + decoder, + device_mesh=device_mesh, + sharding_plan={"parameter": param_sharding_plan, "forward": fwd_resharding_plan}, + factory=True, + ) + output = decoder(input)[0] + loss = output.mean() + loss.backward() + + torch.testing.assert_close(base_output, output._local_tensor) + torch.testing.assert_close(base_loss, loss._local_tensor) + for name, base_param in base_decoder.named_parameters(): + param = decoder.get_parameter(name) + if base_param.grad is None or param.grad is None: + continue + base_param_grad = base_param.grad + param_grad = param.grad.redistribute(device_mesh, [Replicate()], async_op=False)._local_tensor + torch.testing.assert_close(base_param_grad, param_grad) + + +if __name__ == "__main__": + run_tests() diff --git a/test/model/mixtral/test_mixtral_sparse_moe.py b/test/model/mixtral/test_mixtral_sparse_moe.py new file mode 100644 index 0000000..9f7bf33 --- /dev/null +++ b/test/model/mixtral/test_mixtral_sparse_moe.py @@ -0,0 +1,104 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ + +import copy +import torch +from torch.testing._internal.common_utils import run_tests + +from vescale.dtensor.device_mesh import DeviceMesh +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dmodule.api import parallelize_module + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu + +torch.manual_seed(9999) + + +class MixtralSparseMoeBlockTest(DTensorTestBase): + @property + def world_size(self): + return 4 + + @skip_unless_torch_gpu + @with_comms + def test_tp_moe( + self, + ): + bsz = 6 + seqlen = 18 + config = MixtralConfig() + hidden_size = config.hidden_size + + device_mesh = DeviceMesh(self.device_type, range(self.world_size)) + base_moe = MixtralSparseMoeBlock(config).cuda() + moe = copy.deepcopy(base_moe) + + base_input = torch.rand(bsz, seqlen, hidden_size).cuda() + input = copy.deepcopy(base_input) + + # =---------------- baseline ----------------= # + base_output, _ = base_moe(base_input) + base_loss = base_output.mean() + base_loss.backward() + + # =---------------- vescale ----------------= # + param_sharding_plan = { + r"gate.weight": [Replicate()], + r"experts.\d+.w1.weight": [Shard(0)], + r"experts.\d+.w3.weight": [Shard(0)], + r"experts.\d+.w2.weight": [Shard(1)], + } + fwd_resharding_plan = { + r"input": [[Replicate()]], + r"gate.output": [[Replicate()]], + r"output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]}, + r"experts.\d+.w1.input": [[Replicate()]], + r"experts.\d+.w3.input": [[Replicate()]], + r"experts.\d+.w2.output": [[Replicate()]], + } + + moe = parallelize_module( + moe, + device_mesh=device_mesh, + sharding_plan={"parameter": param_sharding_plan, "forward": fwd_resharding_plan}, + factory=True, + ) + output, _ = moe(input) + loss = output.mean() + loss.backward() + + torch.testing.assert_close(base_output, output._local_tensor) + torch.testing.assert_close(base_loss, loss._local_tensor) + for i in range(config.num_local_experts): + for fc_name in ["w1", "w2", "w3"]: + base_param = base_moe.get_parameter(f"experts.{i}.{fc_name}.weight") + param = moe.get_parameter(f"experts.{i}.{fc_name}.weight") + if param.grad is None or base_param.grad is None: + continue + base_param_grad = base_param.grad + param_grad = param.grad.redistribute(device_mesh, [Replicate()], async_op=False)._local_tensor + torch.testing.assert_close(base_param_grad, param_grad) + base_gate_grad = base_moe.get_parameter("gate.weight").grad + gate_grad = moe.get_parameter("gate.weight").grad._local_tensor + torch.testing.assert_close(base_gate_grad, gate_grad) + + +if __name__ == "__main__": + run_tests() diff --git a/test/parallel/ddp_optim/test_doptimizer.py b/test/parallel/ddp_optim/test_doptimizer.py index abd6c6a..1b4c262 100644 --- a/test/parallel/ddp_optim/test_doptimizer.py +++ b/test/parallel/ddp_optim/test_doptimizer.py @@ -85,7 +85,7 @@ def gen_golden_output(self, params_and_inputs): @with_comms @parametrize("overlap_grad_reduce", [True, False]) - @parametrize("use_distributed_optimizer", [True, False]) + @parametrize("use_distributed_optimizer", [True]) @parametrize("overlap_param_gather", [True, False]) @parametrize("use_optimizer_class", [True, False]) def test_distributed_optimizer( @@ -139,11 +139,7 @@ def test_distributed_optimizer( ) # epoch 1 - # NOTE: we 
can't invoke optimizer.zero_grad here. Because if overlap_param_gather is True, - # DOptimizer will try to all gather parameters. - for m in ve_optimizer.models: - m.zero_grad_buffer() - + ve_optimizer.zero_grad() x = params_and_inputs["batch1_epoch1"] if dist.get_rank() == 2 or dist.get_rank() == 3: x = params_and_inputs["batch2_epoch1"]