refactor: update collective operations' signatures (#31)
To make the collective operations' signatures consistent with torch's, the signatures are
restructured: rank arguments (dst/src) now precede the world name, world_name becomes an
optional trailing parameter defaulting to dist.DEFAULT_WORLD_NAME, and all_reduce/reduce
gain an op parameter.
myungjin authored Jul 16, 2024
1 parent 3fbada0 commit eb72720
Showing 3 changed files with 36 additions and 16 deletions.
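
In short, rank arguments now come before the world name, and the world name becomes optional. A minimal before/after sketch of a send call (the tensor, destination rank dst, and the world name "world1" are placeholders, not values from this commit):

    # Before this commit: the world name preceded the peer rank.
    await world_communicator.send(tensor, "world1", dst)

    # After this commit: the peer rank comes first, mirroring
    # torch.distributed.send(tensor, dst); the world name is optional.
    await world_communicator.send(tensor, dst, "world1")
    await world_communicator.send(tensor, dst)  # default world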
examples/multiworld_asyncio.py (4 changes: 2 additions & 2 deletions)
@@ -86,7 +86,7 @@ async def send_data(world_name, rank, size, backend):
     tensor = _prepare_tensor(rank, backend)
 
     try:
-        await world_communicator.send(tensor, world_name, rank_to_send)
+        await world_communicator.send(tensor, rank_to_send, world_name)
     except Exception as e:
         print(f"caught an exception: {e}")
         print("terminate sending")
@@ -114,7 +114,7 @@ async def receive_data(world_communicator, backend, worlds):
         tensor = _prepare_tensor(0, backend)
 
         try:
-            await world_communicator.recv(tensor, world, 1)
+            await world_communicator.recv(tensor, 1, world)
         except Exception as e:
             print(f"caught an exception: {e}")
             worlds.remove(world)
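Because world_name now defaults to dist.DEFAULT_WORLD_NAME, callers that only talk to the default world can drop the argument entirely; a hedged sketch built on the calls above:

    # New order: tensor, peer rank, then the (now optional) world name.
    await world_communicator.send(tensor, rank_to_send, world_name)
    await world_communicator.recv(tensor, 1, world)

    # In the default world the name can be omitted altogether.
    await world_communicator.send(tensor, rank_to_send)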
examples/resnet_multiworld.py (4 changes: 2 additions & 2 deletions)
@@ -281,7 +281,7 @@ async def run_leader(world_communicator, world_size, backend):
         # Send the image to the worker
         try:
             await world_communicator.send(
-                image_tensor, f"world{worker_idx}", WORKER_RANK
+                image_tensor, WORKER_RANK, f"world{worker_idx}"
             )
         except Exception as e:
             print(f"Caught an except while sending image: {e}")
@@ -298,7 +298,7 @@ async def run_leader(world_communicator, world_size, backend):
         )
         try:
             await world_communicator.recv(
-                predicted_class_tensor, f"world{worker_idx}", WORKER_RANK
+                predicted_class_tensor, WORKER_RANK, f"world{worker_idx}"
             )
         except Exception as e:
             print(f"Caught an except while receiving predicted class: {e}")
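Since world_name is now the trailing parameter, call sites that target a specific world can also pass it by keyword to keep longer argument lists readable; a sketch reusing the names from this example (image_tensor, predicted_class_tensor, WORKER_RANK):

    await world_communicator.send(
        image_tensor, WORKER_RANK, world_name=f"world{worker_idx}"
    )
    await world_communicator.recv(
        predicted_class_tensor, WORKER_RANK, world_name=f"world{worker_idx}"
    )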
multiworld/world_communicator.py (44 changes: 32 additions & 12 deletions)
@@ -95,7 +95,9 @@ async def _wait_work(self, work: Work, world_name: str) -> None:
                 raise BrokenWorldException(f"{world_name}")
             await asyncio.sleep(0)
 
-    async def send(self, tensor: Tensor, world_name: str, dst: int) -> None:
+    async def send(
+        self, tensor: Tensor, dst: int, world_name: str = dist.DEFAULT_WORLD_NAME
+    ) -> None:
         """Send a tensor to a destination in a world."""
         try:
             work = dist.isend(tensor, dst=dst, name=world_name)
@@ -104,7 +106,9 @@ async def send(self, tensor: Tensor, world_name: str, dst: int) -> None:
 
         await self._wait_work(work, world_name)
 
-    async def recv(self, tensor: Tensor, world_name: str, src: int) -> None:
+    async def recv(
+        self, tensor: Tensor, src: int, world_name: str = dist.DEFAULT_WORLD_NAME
+    ) -> None:
         """Receive a tensor from a specific rank in a world."""
         try:
             work = dist.irecv(tensor, src=src, name=world_name)
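The new send/recv signatures line up positionally with their torch counterparts; only the trailing world_name is multiworld-specific. A comparison sketch (assuming dist is torch.distributed with an initialized process group):

    # torch.distributed point-to-point calls:
    dist.send(tensor, dst)   # signature: send(tensor, dst, group=None, tag=0)
    dist.recv(tensor, src)   # signature: recv(tensor, src=None, group=None, tag=0)

    # WorldCommunicator after this commit, same argument order:
    await world_communicator.send(tensor, dst)  # world_name defaults
    await world_communicator.recv(tensor, src)  # to dist.DEFAULT_WORLD_NAME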
@@ -113,7 +117,9 @@ async def recv(self, tensor: Tensor, world_name: str, src: int) -> None:
 
         await self._wait_work(work, world_name)
 
-    async def broadcast(self, tensor: Tensor, world_name: str, src: int) -> None:
+    async def broadcast(
+        self, tensor: Tensor, src: int, world_name: str = dist.DEFAULT_WORLD_NAME
+    ) -> None:
         """Broadcast a tensor to the world from a source (src)."""
         try:
             work = dist.broadcast(tensor, src, async_op=True, name=world_name)
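broadcast follows the same pattern: src is now the second positional argument, as in torch.distributed.broadcast(tensor, src). A minimal sketch (rank 0 as source and the world name "world1" are assumptions):

    # Rank 0 broadcasts its tensor to every rank in the default world.
    await world_communicator.broadcast(tensor, 0)

    # Broadcast inside an explicitly named world.
    await world_communicator.broadcast(tensor, 0, world_name="world1")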
@@ -122,29 +128,43 @@ async def broadcast(self, tensor: Tensor, world_name: str, src: int) -> None:
 
         await self._wait_work(work, world_name)
 
-    async def all_reduce(self, tensor: Tensor, world_name: str) -> None:
+    async def all_reduce(
+        self,
+        tensor: Tensor,
+        op: dist.ReduceOp = dist.ReduceOp.SUM,
+        world_name: str = dist.DEFAULT_WORLD_NAME,
+    ) -> None:
         """Do all-reduce for a given tensor in a world."""
         try:
-            work = dist.all_reduce(tensor, async_op=True, name=world_name)
+            work = dist.all_reduce(tensor, op, async_op=True, name=world_name)
         except RuntimeError as e:
             self._handle_error(e, world_name)
 
         await self._wait_work(work, world_name)
 
-    async def reduce(self, tensor: Tensor, world_name: str, dst: int) -> None:
+    async def reduce(
+        self,
+        tensor: Tensor,
+        dst: int,
+        op: dist.ReduceOp = dist.ReduceOp.SUM,
+        world_name: str = dist.DEFAULT_WORLD_NAME,
+    ) -> None:
         """Do reduce for a given tensor in a world.
 
         The rank is a receiver of the final result.
         """
         try:
-            work = dist.reduce(tensor, dst, async_op=True, name=world_name)
+            work = dist.reduce(tensor, dst, op, async_op=True, name=world_name)
         except RuntimeError as e:
             self._handle_error(e, world_name)
 
         await self._wait_work(work, world_name)
 
     async def all_gather(
-        self, tensors: list[Tensor], tensor: Tensor, world_name: str
+        self,
+        tensors: list[Tensor],
+        tensor: Tensor,
+        world_name: str = dist.DEFAULT_WORLD_NAME,
     ) -> None:
         """Do all-gather for a given tensor in a world."""
         try:
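all_reduce and reduce now accept an op argument defaulting to dist.ReduceOp.SUM, again matching torch. A sketch of a non-default reduction (assuming dist exposes torch's ReduceOp, as the new defaults imply):

    # Element-wise sum across all ranks (the default op).
    await world_communicator.all_reduce(tensor)

    # Element-wise max instead.
    await world_communicator.all_reduce(tensor, op=dist.ReduceOp.MAX)

    # Reduce onto rank 0 only; dst precedes op, as in torch's reduce.
    await world_communicator.reduce(tensor, 0, op=dist.ReduceOp.MAX)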
@@ -157,9 +177,9 @@ async def all_gather(
     async def gather(
         self,
         tensor: Tensor,
-        world_name: str,
-        dst: int,
         gather_list: list[Tensor] = None,
+        dst: int = 0,
+        world_name: str = dist.DEFAULT_WORLD_NAME,
     ) -> None:
         """Do gather for a list of tensors in a world."""
         try:
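gather moves gather_list directly after tensor and gives dst a default of 0, matching torch.distributed.gather(tensor, gather_list=None, dst=0). A sketch; my_rank, world_size, and the torch import are assumed context, not part of this commit:

    import torch

    # The destination rank passes a list that receives one tensor per rank.
    if my_rank == 0:
        outputs = [torch.zeros_like(tensor) for _ in range(world_size)]
        await world_communicator.gather(tensor, outputs)  # dst defaults to 0
    else:
        await world_communicator.gather(tensor)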
@@ -178,9 +198,9 @@ async def gather(
     async def scatter(
         self,
         tensor: Tensor,
-        world_name: str,
-        src: int,
         scatter_list: list[Tensor] = None,
+        src: int = 0,
+        world_name: str = dist.DEFAULT_WORLD_NAME,
     ) -> None:
         """Do scatter for a list of tensors from a source (src) in a world."""
         try:
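scatter gets the symmetric treatment: scatter_list follows tensor and src defaults to 0, matching torch.distributed.scatter(tensor, scatter_list=None, src=0). A sketch under the same assumed context:

    # The source rank supplies one tensor per rank; every rank receives into tensor.
    if my_rank == 0:
        inputs = [torch.full_like(tensor, r) for r in range(world_size)]
        await world_communicator.scatter(tensor, inputs)  # src defaults to 0
    else:
        await world_communicator.scatter(tensor)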
