-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[emulator] feat: veScale correctness emulator (#45)
This pull request contains **veScale Correctness Emulator** that emulates the results from multiple devices execution on a single device. ## Why veScale Correctness Emulator? - Modern Frameworks promise **Single-Device Abstraction** for **nD Parallelism**. But it is still missing a critical component that can verify the ***correctness*** of **Single-Device Abstraction of nD Parallelism**. For example, there are differences between the loss curve of single device training and loss curves of 3D parallelism training. - How do we know the difference is *correct*? To what extent is it *correct*? - "Correct" differences come from nD Parallelism - Communication difference (e.g., ring allreduce) - Compute difference (e.g., matmul) - Hardware difference (e.g. FP16) - "Incorrect" differences come from bugs in - User configuration - User model code - System implementation code - Data loader - Model checkpoint - Random seed and offset ## What is veScale Correctness Emulator? - **veScale Correctness Emulator** verifies nD prarllelism correctness by emulating nD parallel training on a single device, - **veScale Correctness Emulator** isolates correctness at different layers and seperates differences come from nD parallelism with differences come from bugs. - **veScale Correctness Emulator** achieves bitwise correctness in three levels: NCCL collectives, mesh collectives, and DTensor. ### NCCL Emulation We are using the NCCL version 2.19.3 code as a reference for our emulation implementation. The code can be found at [NVIDIA/nccl](https://github.com/NVIDIA/nccl/tree/v2.19.3-1). **veScale Correctness Emulator** can perfectly emulate NCCL collective APIs' results. This is achieved by implementing the same NCCL collective algorithms and modeling NCCL's computation order via calculating the correct chunk size. ### Collective APIs Emulation These are standalone collective APIs which emulate the results from collective APIs of NCCL on a single device. Supported APIs: - `all_reduce` - `all_gather` - `reduce_scatter` - `all_to_all` ### Mesh Collective APIs Emulation These are standalone mesh collective APIs which emulate the results from mesh collective APIs of PyTorch on a single device. Supported APIs: - `mesh_all_reduce` - `mesh_all_gather` - `mesh_reduce_scatter` - `mesh_all_to_all` - `mesh_broadcast` - `mesh_scatter` ### DTensor Redistribution Function Emulation These are standalone DTensor redistribution functions which emulate the results from DTensor redistribution functions of PyTorch on a single device. - `R2R` - `R2S` - `S2R` - `P2R` Comming soon: A full list of emulator DTensor redistribution functions will be added to support nD parallelisms including DP, TP, SP, PP, EP, and OP. ## How does veScale Correctness Emulator work? **veScale Correctness Emulator** achieves bitwise correctness in emulating NCCL collectives APIs results. This is done by implementing the same NCCL collective algorithms and modeling NCCL's algorithm and protocol selection function and chunk size calculation process to ensure the same computation order as NCCL. Based on the emulation functions for NCCL collectives, **veScale Correctness Emulator** implements a global-view emulator `ProcessGroup` and `DeviceMesh` that contain all the process groups in the enviroment, while PyTorch's `ProcessGroup` and `DeviceMesh` only view process groups related to the current ranks. Aided by the global-view emulator `ProcessGroup` and `DeviceMesh`, **veScale Correctness Emulator** can emulate the results of collective APIs, mesh collective APIs, and DTensor redistribution functions on a single device.
- Loading branch information
1 parent
70db7e7
commit e439aa9
Showing
38 changed files
with
5,698 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
|
||
from typing import Callable, Tuple, Dict, Any | ||
from functools import wraps | ||
|
||
TestFunc = Callable[[object], object] | ||
|
||
|
||
# wrapper to initialize comms (process group) within emulator | ||
def with_comms_emulator(func: TestFunc) -> TestFunc: | ||
assert func is not None | ||
|
||
@wraps(func) # pyre-ignore[6] | ||
def wrapper(self, *args: Tuple[object], **kwargs: Dict[str, Any]) -> None: # type: ignore[misc] | ||
# launch | ||
self.init_emulator_pg() | ||
func(self, *args, **kwargs) # type: ignore[misc] | ||
self.destroy_emulator_pg() | ||
|
||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
################################################################################ | ||
# | ||
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
################################################################################ | ||
|
||
|
||
import os | ||
import torch | ||
import torch.distributed as dist | ||
from torch.testing._internal.common_utils import ( | ||
instantiate_parametrized_tests, | ||
parametrize, | ||
run_tests, | ||
) | ||
|
||
import vescale | ||
from vescale.emulator.distributed import ProcessGroup, dump_nccl_graph_for_pg | ||
from vescale.emulator.reduce_kernel import ReduceOp | ||
|
||
from vescale.emulator.all_gather import expand_tensor_list | ||
from vescale.emulator.reduce_scatter import contract_tensor_list | ||
from common_dtensor import DTensorTestBase, with_comms | ||
from emulator.common_emulator import with_comms_emulator | ||
from vescale.emulator.utils import emulator_reduce_op_to_torch | ||
|
||
|
||
class TestDistributed(DTensorTestBase): | ||
@property | ||
def world_size(self) -> int: | ||
return 4 | ||
|
||
def init_emulator_pg(self): | ||
torch.manual_seed(0) | ||
backend = "nccl" | ||
world_size = self.world_size | ||
|
||
vescale.emulator.distributed.init_process_group(backend=backend, world_size=world_size, rank=0) | ||
vescale.emulator.distributed.set_rank(0) | ||
self.pg: ProcessGroup = vescale.emulator.distributed._world.default_pg | ||
self.torch_pg = torch.distributed.distributed_c10d._get_default_group() | ||
dump_nccl_graph_for_pg(self.pg, self.torch_pg, self.rank) | ||
|
||
def destroy_emulator_pg(self): | ||
vescale.emulator.distributed.destroy_process_group() | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
def test_process_group(self): | ||
ground_truth_pg_group_ranks = [{0: 0, 1: 1, 2: 2, 3: 3}, {0: 0, 2: 1}, {1: 0, 3: 1}, {0: 0, 1: 1}, {2: 0, 3: 1}] | ||
for count, value in enumerate(vescale.emulator.distributed._world.pg_group_ranks.values()): | ||
self.assertEqual(value, ground_truth_pg_group_ranks[count]) | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
# @parametrize("reduce_op", [ReduceOp.SUM, ReduceOp.PRODUCT, ReduceOp.MAX, ReduceOp.MIN]) | ||
@parametrize("reduce_op", [ReduceOp.SUM]) | ||
@parametrize("nelement", [1, 1024, 1024 * 1024]) | ||
def test_all_reduce(self, nelement, reduce_op): | ||
nranks = self.pg.size() | ||
tree_structure = [[0, 1], [2, 3]] | ||
torch_rank = self.rank | ||
device = f"cuda:{torch_rank}" | ||
|
||
input_file = "input_distributed.pt" | ||
if self.rank == 0: | ||
# To ensure all ranks have the same input | ||
input_list = [] | ||
for i in range(nranks): | ||
input_list.append(torch.randn((nelement,), device="cuda")) | ||
torch.save(input_list, input_file) | ||
dist.barrier() | ||
|
||
data_list = torch.load(input_file) | ||
data_list = [data.to(device) for data in data_list] | ||
ground_truth = [data_list[rank].clone().to(device) if rank == torch_rank else [] for rank in range(nranks)] | ||
torch_reduce_op = emulator_reduce_op_to_torch(reduce_op) | ||
|
||
torch.distributed.all_reduce(ground_truth[torch_rank], torch_reduce_op) | ||
self.pg.all_reduce(data_list, op=reduce_op, tree_structure=tree_structure) | ||
|
||
self.assertTrue(torch.equal(data_list[torch_rank], ground_truth[torch_rank])) | ||
|
||
if self.rank == 0: | ||
if os.path.exists(input_file): | ||
os.remove(input_file) | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
@parametrize("nelement", [1, 1024, 1024 * 1024]) | ||
def test_all_gather(self, nelement): | ||
nranks = self.pg.size() | ||
torch_rank = self.rank | ||
device = f"cuda:{torch_rank}" | ||
|
||
input_file = "input_distributed.pt" | ||
if self.rank == 0: | ||
# To ensure all ranks have the same input | ||
input_list = [] | ||
for i in range(nranks): | ||
input_list.append(torch.randn((nelement,), device="cuda")) | ||
torch.save(input_list, input_file) | ||
dist.barrier() | ||
|
||
data_list = torch.load(input_file) | ||
data_list = [data.to(device) for data in data_list] | ||
ground_truth_list = [torch.zeros(nelement).to(device) for _ in range(nranks)] | ||
output_list = expand_tensor_list(data_list) | ||
|
||
torch.distributed.all_gather(ground_truth_list, data_list[torch_rank]) | ||
self.pg.all_gather(output_list, data_list) | ||
|
||
for gt, data in zip(ground_truth_list, data_list): | ||
self.assertTrue(torch.equal(gt, data)) | ||
|
||
if self.rank == 0: | ||
if os.path.exists(input_file): | ||
os.remove(input_file) | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
# @parametrize("reduce_op", [ReduceOp.SUM, ReduceOp.PRODUCT, ReduceOp.MAX, ReduceOp.MIN]) | ||
@parametrize("reduce_op", [ReduceOp.SUM]) | ||
@parametrize("nelement", [1, 1024, 1024 * 1024]) | ||
def test_reduce_scatter(self, nelement, reduce_op): | ||
nranks = self.pg.size() | ||
torch_rank = self.rank | ||
device = f"cuda:{torch_rank}" | ||
|
||
input_file = "input_distributed.pt" | ||
if self.rank == 0: | ||
# To ensure all ranks have the same input | ||
input_list = [] | ||
for i in range(nranks): | ||
input_list.append([]) | ||
for j in range(nranks): | ||
input_list[i].append(torch.randn((nelement,), device="cuda")) | ||
torch.save(input_list, input_file) | ||
dist.barrier() | ||
|
||
data_list = torch.load(input_file) | ||
data_list = [[elem.to(device) for elem in data] for data in data_list] | ||
ground_truth = torch.zeros(nelement).to(device) | ||
outputs = contract_tensor_list(data_list) | ||
torch_reduce_op = emulator_reduce_op_to_torch(reduce_op) | ||
|
||
torch.distributed.reduce_scatter(ground_truth, data_list[torch_rank], torch_reduce_op) | ||
|
||
self.pg.reduce_scatter(outputs, data_list, op=reduce_op) | ||
|
||
result = outputs[torch_rank] | ||
self.assertTrue(torch.equal(result, ground_truth)) | ||
|
||
if self.rank == 0: | ||
if os.path.exists(input_file): | ||
os.remove(input_file) | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
@parametrize("nelement", [1, 1024, 1024 * 1024]) | ||
def test_all_to_all(self, nelement): | ||
nranks = self.pg.size() | ||
torch_rank = self.rank | ||
device = f"cuda:{torch_rank}" | ||
|
||
input_file = "input_distributed.pt" | ||
if self.rank == 0: | ||
# To ensure all ranks have the same input | ||
input_list = [] | ||
for i in range(nranks): | ||
input_list.append([]) | ||
for j in range(nranks): | ||
input_list[i].append(torch.randn((nelement,), device="cuda")) | ||
torch.save(input_list, input_file) | ||
dist.barrier() | ||
|
||
data_list = torch.load(input_file) | ||
outputs_list = [] | ||
ground_truth_list = [] | ||
for i in range(nranks): | ||
outputs_list.append([]) | ||
for j in range(nranks): | ||
data_list[i][j] = data_list[i][j].to(device) | ||
outputs_list[i].append((torch.zeros(nelement)).to(device)) | ||
ground_truth_list.append((torch.zeros(nelement)).to(device)) | ||
|
||
torch.distributed.all_to_all(ground_truth_list, data_list[torch_rank]) | ||
self.pg.all_to_all(outputs_list, data_list) | ||
|
||
for gt, output in zip(ground_truth_list, outputs_list[torch_rank]): | ||
self.assertTrue(torch.equal(gt, output)) | ||
|
||
if self.rank == 0: | ||
if os.path.exists(input_file): | ||
os.remove(input_file) | ||
|
||
|
||
instantiate_parametrized_tests(TestDistributed) | ||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
################################################################################ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
################################################################################ | ||
# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. | ||
################################################################################ | ||
|
||
import os | ||
|
||
import numpy as np | ||
from common_dtensor import ( | ||
DTensorTestBase, # skip_unless_torch_gpu, | ||
with_comms, | ||
) | ||
from typing import List, cast | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torch.distributed._functional_collectives as funcol | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
import vescale | ||
from vescale.dtensor.dtensor import DTensor | ||
from vescale.dtensor.placement_types import Placement, Replicate, Shard | ||
|
||
from vescale.emulator.device_mesh import dump_nccl_graph_for_mesh | ||
from vescale.emulator.distributed import ProcessGroup, dump_nccl_graph_for_pg | ||
from vescale.emulator.comm_api import distribute_tensor, redistribute_dtensor | ||
from vescale.emulator.device_mesh import DeviceMesh | ||
from vescale.emulator.emulator_instrumentation import EmulatorInstrumentation | ||
from emulator.common_emulator import with_comms_emulator | ||
|
||
|
||
class DistMatrixOpsTest(DTensorTestBase): | ||
@property | ||
def world_size(self) -> int: | ||
return 4 | ||
|
||
def init_emulator_pg(self): | ||
torch.manual_seed(0) | ||
backend = "nccl" | ||
world_size = self.world_size | ||
|
||
vescale.emulator.distributed.init_process_group(backend=backend, world_size=world_size, rank=0) | ||
vescale.emulator.distributed.set_rank(0) | ||
# dump default process group | ||
self.pg: ProcessGroup = vescale.emulator.distributed._world.default_pg | ||
self.torch_pg = torch.distributed.distributed_c10d._get_default_group() | ||
dump_nccl_graph_for_pg(self.pg, self.torch_pg, self.rank) | ||
|
||
# dump for other process groups | ||
mesh_tensor = list(range(world_size)) | ||
self.vescale_mesh = vescale.dtensor.device_mesh.DeviceMesh(self.device_type, mesh_tensor) | ||
self.mesh = DeviceMesh(self.device_type, mesh_tensor) | ||
dump_nccl_graph_for_mesh(self.mesh, self.vescale_mesh) | ||
|
||
def destroy_emulator_pg(self): | ||
vescale.emulator.distributed.destroy_process_group() | ||
|
||
@with_comms | ||
@with_comms_emulator | ||
def test_mm(self): | ||
device_mesh = self.mesh | ||
vescale_device_mesh = vescale.dtensor.device_mesh.DeviceMesh(self.device_type, list(range(self.world_size))) | ||
device = f"cuda:{self.rank}" | ||
replica_spec = Replicate() | ||
|
||
input_file = "input_dtensors.pt" | ||
if self.rank == 0: | ||
t1 = torch.randn(12, 8, requires_grad=True).cuda() | ||
t2 = torch.randn(8, 12, requires_grad=True).cuda() | ||
torch.save((t1, t2), input_file) | ||
dist.barrier() | ||
|
||
t1, t2 = torch.load(input_file) | ||
t1 = t1.to(device) | ||
t2 = t2.to(device) | ||
t1_list = [t1.clone().detach().requires_grad_() for _ in range(self.world_size)] | ||
t2_list = [t2.clone().detach().requires_grad_() for _ in range(self.world_size)] | ||
|
||
def test_placement_comb(placements1: List[Placement], placements2: List[Placement]) -> None: | ||
dt1_list = distribute_tensor(t1_list, device_mesh, placements1) | ||
dt2_list = distribute_tensor(t2_list, device_mesh, placements2) | ||
|
||
# Emulator replace the given pytorch function to accpet lists of tensors as input | ||
func_list = ["mm"] | ||
indices = [(0, 1)] | ||
with EmulatorInstrumentation(torch, func_list, indices): | ||
dist_res_list = torch.mm(dt1_list, dt2_list) | ||
dist_res_list = redistribute_dtensor(dist_res_list, device_mesh, [replica_spec]) | ||
|
||
dt1 = vescale.distribute_tensor(t1.clone().detach().requires_grad_(), vescale_device_mesh, placements1) | ||
dt2 = vescale.distribute_tensor(t2.clone().detach().requires_grad_(), vescale_device_mesh, placements2) | ||
dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(vescale_device_mesh, [replica_spec]) | ||
|
||
for dist_res_emu in dist_res_list: | ||
self.assertTrue(torch.equal(dist_res.to_local(), dist_res_emu.to_local())) | ||
|
||
shard_specs_comb = [ | ||
(Shard(dim=0), Replicate()), | ||
(Shard(dim=1), Shard(dim=0)), | ||
(Replicate(), Shard(dim=1)), | ||
(Replicate(), Replicate()), | ||
] | ||
|
||
for spec in shard_specs_comb: | ||
test_placement_comb([spec[0]], [spec[1]]) | ||
|
||
if self.rank == 0: | ||
if os.path.exists(input_file): | ||
os.remove(input_file) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
Oops, something went wrong.