Skip to content

Commit

Permalink
Add pybind for Communicator (#3686)
Browse files Browse the repository at this point in the history
This is for #3091 and #3092, limitations of mpi4py that are hard to get
rid of. This PR adds the bare minimum to get the sizes and the ranks.
The next PRs will expose more methods in Communicator.
  • Loading branch information
wujingyue authored Jan 10, 2025
1 parent 78db5e1 commit 5251964
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 5 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ if(BUILD_PYTHON)
# nvfuser python API sources
set(NVFUSER_PYTHON_SRCS)
list(APPEND NVFUSER_PYTHON_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/communicator_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp
Expand Down
42 changes: 42 additions & 0 deletions csrc/python_frontend/communicator_bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <python_frontend/python_bindings.h>

#include <multidevice/communicator.h>

namespace nvfuser::python_frontend {

// Registers the `Communicator` class on the `nvfuser` Python module.
//
// Exposed Python API:
//   Communicator.instance()  -> the singleton communicator (static)
//   comm.size()              -> number of processes in the communicator
//   comm.rank()              -> device ID of the current process
//   comm.local_size()        -> number of processes within the node
//   comm.local_rank()        -> in-node rank of the current process
void bindCommunicator(py::module& nvfuser) {
  // py::nodelete is necessary because Communicator has no publicly accessible
  // destructor, so pybind11 must never try to delete an instance:
  // https://pybind11.readthedocs.io/en/stable/advanced/classes.html#non-public-destructors
  py::class_<Communicator, std::unique_ptr<Communicator, py::nodelete>>
      communicator(nvfuser, "Communicator");

  // `getInstance` is a singleton accessor that is called on the class itself
  // (`nvfuser.Communicator.instance()`), so it has to be registered with
  // `def_static`. Registering it with `def` would turn it into an instance
  // method and make the class-level call fail from Python.
  //
  // `return_value_policy::reference` keeps ownership on the C++ side: Python
  // only borrows the singleton.
  communicator.def_static(
      "instance",
      &Communicator::getInstance,
      "Returns the singleton communicator instance.",
      py::return_value_policy::reference);

  communicator.def(
      "size",
      &Communicator::size,
      "Returns the number of processes in the communicator.");

  // The Python-facing name is `rank`, backed by `Communicator::deviceId`.
  communicator.def(
      "rank",
      &Communicator::deviceId,
      "Returns the device ID associated with the current process.");

  communicator.def(
      "local_size",
      &Communicator::local_size,
      "Returns the number of processes within the node.");

  communicator.def(
      "local_rank",
      &Communicator::local_rank,
      "Returns the in-node rank associated with the current process.");
}

} // namespace nvfuser::python_frontend
2 changes: 2 additions & 0 deletions csrc/python_frontend/python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3616,6 +3616,8 @@ void initNvFuserPythonBindings(PyObject* module) {
py::return_value_policy::reference);

bindSchedule(fusion_def);

bindCommunicator(nvfuser);
}

void cleanup() {
Expand Down
4 changes: 4 additions & 0 deletions csrc/python_frontend/python_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
#include <visibility.h>

namespace nvfuser::python_frontend {

// Entry point that sets up the nvFuser Python bindings on `module`. Among
// other things, it invokes bindSchedule and bindCommunicator.
NVF_API void initNvFuserPythonBindings(PyObject* module);

// Registers the `Communicator` class (instance, size, rank, local_size,
// local_rank) on the `nvfuser` Python module. Defined in
// communicator_bindings.cpp.
void bindCommunicator(py::module& nvfuser);

// Adds bindings to the `FusionDefinition` Python class.
// NOTE(review): presumably the scheduling API, judging by the name; the
// definition is not visible from this header.
void bindSchedule(py::class_<FusionDefinition>& fusion_def);

// NOTE(review): presumably releases state held by the bindings; confirm
// against the definition in python_bindings.cpp.
NVF_API void cleanup();

} // namespace nvfuser::python_frontend
11 changes: 6 additions & 5 deletions tests/python/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@


@pytest.mark.mpi
def test_sizes_and_ranks(mpi_test):
def test_sizes_and_ranks():
comm = nvfuser.Communicator.instance()
size, rank, local_size, local_rank = (
mpi_test.size,
mpi_test.rank,
mpi_test.local_size,
mpi_test.local_rank,
comm.size(),
comm.rank(),
comm.local_size(),
comm.local_rank(),
)
assert size > 0
assert rank >= 0 and rank < size
Expand Down

0 comments on commit 5251964

Please sign in to comment.