Attempt to DID-parallelize a loop split in Python.
wujingyue committed Dec 2, 2024
1 parent 39f2809 commit 50ccaac
Showing 7 changed files with 84 additions and 6 deletions.
9 changes: 9 additions & 0 deletions csrc/python_frontend/fusion_definition.cpp
@@ -314,6 +314,15 @@ void FusionDefinition::finalizeSchedule(
// Users can access schedule objects after scheduling the fusion.
}

void FusionDefinition::setupMultideviceSchedule() {
prev_fusion_ = FusionGuard::getCurFusion();
FusionGuard::setCurFusion(preschedFusion());
}

void FusionDefinition::finalizeMultideviceSchedule() {
FusionGuard::setCurFusion(prev_fusion_);
}

void FusionDefinition::print(std::ostream& os) const {
if (id().has_value()) {
os << "\ndef nvfuser_fusion_id" << id().value();
2 changes: 2 additions & 0 deletions csrc/python_frontend/fusion_definition.h
@@ -184,6 +184,8 @@ class NVF_API FusionDefinition : public FusionState {
//! Finalized use scheduling of a fusion
//! resets FusionGuard, lowers IR to a kernel, compiles kernel
NVF_API void finalizeSchedule(const at::ArrayRef<c10::IValue>& inputs);
NVF_API void setupMultideviceSchedule();
NVF_API void finalizeMultideviceSchedule();
//! Prints a python function representing the definition
NVF_API void print(std::ostream& os) const;
//! Executes a fusion if a valid definition or cache lookup occurred prior
19 changes: 18 additions & 1 deletion csrc/python_frontend/python_bindings.cpp
@@ -1091,6 +1091,12 @@ void initNvFuserPythonBindings(PyObject* module) {
// Mark the end of a schedule
inst::Trace::instance()->endEvent(nullptr);
})
.def(
"_setup_multidevice_schedule",
[](FusionDefinition& self) { self.setupMultideviceSchedule(); })
.def(
"_finalize_multidevice_schedule",
[](FusionDefinition& self) { self.finalizeMultideviceSchedule(); })
.def("inputs", [](FusionDefinition& self) { return self.inputs(); })
.def("outputs", [](FusionDefinition& self) { return self.outputs(); })
.def("extents", [](FusionDefinition& self) { return self.extents(); })
@@ -3596,7 +3602,6 @@ void initNvFuserPythonBindings(PyObject* module) {
},
py::arg("tensor"),
py::arg("mesh"));
//! experimental API for multidevice support
nvf_sched.def(
"parallelize",
[](FusionDefinition::SchedOperators& self,
@@ -3683,6 +3688,18 @@
py::arg("dim"),
py::arg("factor"),
py::arg("inner_split") = true);
nvf_sched.def(
"set_allocation_as_loop",
[](FusionDefinition::SchedOperators& self, Tensor arg) {
FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop");
NVF_CHECK(
self.validUse(),
"Attempting to use a SchedOperators Op prior to definition!");
FusionDefinition* fd = self.fusion_definition;
auto* tv = fd->getFusionState(arg.index)->template as<TensorView>();
tv->setAllocationDomain(tv->getLoopDomain(), true);
},
py::arg("arg"));
nvf_sched.def(
"cache_after",
[](FusionDefinition::SchedOperators& self,
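
For context, the new `set_allocation_as_loop` schedule op copies a tensor's loop domain into its allocation domain, so a DID-parallelized loop split also determines how the tensor is allocated on each device. A minimal sketch of the intended use from a `multidevice_schedule`, assuming a tensor `tv` and a device count `d` (both placeholders):

    def multidevice_schedule(self):
        mesh = self.sched._create_device_mesh(range(d))  # d is a placeholder device count
        self.sched._set_device_mesh(tv, mesh)
        # Outer-split dim 0 by d and bind the outer extent to the mesh.
        self.sched.split(tv, 0, d, False)
        self.sched.parallelize(tv, 0, nvfuser.ParallelType.mesh_x)
        # Make the allocation domain follow the (now sharded) loop domain.
        self.sched.set_allocation_as_loop(tv)
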
2 changes: 2 additions & 0 deletions nvfuser/__init__.py
@@ -284,7 +284,9 @@ def execute(
#
# Note: there's a plan to embed multidevice schedules into FusionDefinition
# as annotating nodes. This may eventually replace `multidevice_schedule`.
self._setup_multidevice_schedule()
self.multidevice_schedule()
self._finalize_multidevice_schedule()

# If schedule is defined by child class and schedule is not defined for
# inputs, make a schedule.
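
The `_setup_multidevice_schedule` / `_finalize_multidevice_schedule` pair added above points FusionGuard at the pre-scheduled fusion before the user's `multidevice_schedule` runs and restores the previous fusion afterwards. A rough Python sketch of that guard pattern (illustrative only; the real guard lives in C++, and the calls above are not actually wrapped in a context manager):

    from contextlib import contextmanager

    @contextmanager
    def _multidevice_schedule_guard(fd):
        # Illustrative only: mirrors setupMultideviceSchedule/finalizeMultideviceSchedule.
        fd._setup_multidevice_schedule()         # point FusionGuard at the pre-scheduled fusion
        try:
            yield
        finally:
            fd._finalize_multidevice_schedule()  # restore the previously guarded fusion
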
4 changes: 4 additions & 0 deletions tests/python/mpi_fixtures.py
@@ -38,6 +38,10 @@ def local_rank(self):
def barrier(self):
self._communicator.barrier()

def shard_tensor(self, t: torch.Tensor, dim: int) -> torch.Tensor:
assert t.is_cpu
return t.tensor_split(self.size, dim)[self.rank].cuda(self.local_rank)


@pytest.fixture(scope="session")
def mpi_test():
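
The `shard_tensor` helper splits a CPU tensor into `size` chunks along `dim` and moves this rank's chunk to its local GPU. A small illustration, assuming two ranks and that this process is rank 1 on cuda:1:

    import torch

    t = torch.arange(8.0)              # unsharded CPU tensor
    size, rank, local_rank = 2, 1, 1   # assumed values for the example
    shard = t.tensor_split(size, 0)[rank].cuda(local_rank)
    # shard is tensor([4., 5., 6., 7.]) on cuda:1
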
46 changes: 46 additions & 0 deletions tests/python/test_communication.py
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import torch

import mpi_fixtures
import nvfuser
from nvfuser import DataType, FusionDefinition


mpi_test = mpi_fixtures.mpi_test


@pytest.mark.mpi
def test_allgather(mpi_test):
num_devices = mpi_test.size
rank = mpi_test.rank

unsharded = torch.randn(num_devices * 4)
sharded = mpi_test.shard_tensor(unsharded, 0)

class Model(FusionDefinition):
def definition(self):
self.inp = self.define_tensor(
(num_devices * 4,), contiguity=True, dtype=DataType.Float
)
self.out = self.ops.set(self.inp)
self.add_output(self.out)

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(num_devices))
self.sched._set_device_mesh(self.inp, mesh)
self.sched._set_device_mesh(self.out, mesh)

self.sched.split(self.inp, 0, num_devices, False)
self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
self.sched.set_allocation_as_loop(self.inp)

self.sched.split(self.out, 0, num_devices, False)
self.sched.set_allocation_as_loop(self.out)

fd = Model()
outputs = fd.execute([sharded])
torch.testing.assert_close(outputs[0].cpu(), unsharded)
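
In this test, the input's first dimension is outer-split by `num_devices` and the outer extent is parallelized on `mesh_x`, so each rank allocates only its shard; the output keeps the same split but is not DID-parallelized, so each rank allocates the full extent, and materializing it is expected to lower to an allgather. A rough shape walkthrough, assuming num_devices == 2:

    # inp (logical [8]) --split(0, 2, inner_split=False)--> loop domain [2, 4]
    #   axis 0 is ParallelType.mesh_x, so each rank allocates a [4] shard
    # out (logical [8]) --split(0, 2, inner_split=False)--> loop domain [2, 4]
    #   axis 0 is not DID-parallelized, so each rank allocates the full [2, 4]
    # Going from the sharded inp to the replicated out is the allgather being tested.
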
8 changes: 3 additions & 5 deletions tests/python/test_multidevice.py
@@ -35,10 +35,8 @@ def test_pointwise(mpi_test):
num_devices = mpi_test.size
rank = mpi_test.rank

torch.cuda.set_device(mpi_test.local_rank)

unsharded_input = torch.randn(num_devices, 4, device="cuda")
sharded_input = unsharded_input[rank : rank + 1]
unsharded_input = torch.randn(num_devices, 4)
sharded_input = mpi_test.shard_tensor(unsharded_input, 0)

class Model(FusionDefinition):
def definition(self):
@@ -58,7 +56,7 @@ def multidevice_schedule(self):

fd = Model()
outputs = fd.execute([sharded_input])
torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2)
torch.testing.assert_close(outputs[0].cpu(), unsharded_input.relu() * 2)


@pytest.mark.mpi
