Skip to content

Commit

Permalink
Add in_place option for Allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 30, 2024
1 parent 312b492 commit 760578c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
3 changes: 3 additions & 0 deletions ndsl/comm/comm_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,6 @@ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
@abc.abstractmethod
def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
...

def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T:
...
5 changes: 5 additions & 0 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def all_reduce_per_element(
):
self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def all_reduce_per_element_in_place(
self, quantity: Quantity, op: ReductionOperator
):
self.comm.Allreduce_inplace(quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with send_buffer(numpy_module.zeros, sendbuf) as send:
with recv_buffer(numpy_module.zeros, recvbuf) as recv:
Expand Down
14 changes: 11 additions & 3 deletions ndsl/comm/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,16 @@ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T:
)
return self._comm.allreduce(sendobj, self._op_mapping[op])

def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T:
def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"allreduce on rank %s with operator %s", self._comm.Get_rank(), op
"Allreduce on rank %s with operator %s", self._comm.Get_rank(), op
)
return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op])

def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T:
ndsl_log.debug(
"Allreduce (in place) on rank %s with operator %s",
self._comm.Get_rank(),
op,
)
return self._comm.Allreduce(sendobj, recvobj, self._op_mapping[op])
return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op])

0 comments on commit 760578c

Please sign in to comment.