diff --git a/multiworld/world_communicator.py b/multiworld/world_communicator.py index 049db33..fb727b6 100644 --- a/multiworld/world_communicator.py +++ b/multiworld/world_communicator.py @@ -20,7 +20,7 @@ import asyncio import concurrent.futures import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable import torch.distributed as dist from torch import Tensor @@ -62,29 +62,38 @@ class WorldCommunicator: def __init__(self, world_manager: WorldManager): """Initialize a class instance.""" self._world_manager = world_manager + self._world_to_send_fn: dict[str, Callable] = {} + self._world_to_recv_fn: dict[str, Callable] = {} self._broken_world: dict[str, bool] = {} self._loop = asyncio.get_running_loop() def __del__(self): """Cleanup the class instance.""" + del self._world_to_send_fn + del self._world_to_recv_fn del self._broken_world - def add_world(self, world_name) -> None: + def add_world(self, world_name: str, backend: str) -> None: """Add a new world to the world communicator. This method shouldn't be called directly by a user program. WorldManager will use this method. """ + self._set_functions(world_name, backend) + self._broken_world[world_name] = False - def remove_world(self, world_name) -> None: + def remove_world(self, world_name: str) -> None: """Remove a world from the world communicator. This method shouldn't be called directly by a user program. WorldManager will use this method. """ logger.debug(f"remove world {world_name}") + + self._reset_functions(world_name) + try: self._broken_world[world_name] = True except KeyError: @@ -104,6 +113,38 @@ def is_broken(self, world_name: str) -> bool: logger.debug(f"check if world {world_name} is broken") return self._broken_world.get(world_name, True) + def _set_functions(self, world_name: str, backend: str) -> None: + if backend == "nccl": + self._world_to_send_fn[world_name] = dist.isend + self._world_to_recv_fn[world_name] = dist.irecv + else: + self._world_to_send_fn[world_name] = dist.send + self._world_to_recv_fn[world_name] = dist.recv + + def _reset_functions(self, world_name: str) -> None: + try: + del self._world_to_send_fn[world_name] + except KeyError: + pass + + try: + del self._world_to_recv_fn[world_name] + except KeyError: + pass + + def _get_fn(self, world_name: str, op: str) -> Callable: + try: + match op: + case "send": + return self._world_to_send_fn[world_name] + case "recv": + return self._world_to_recv_fn[world_name] + case _: + raise KeyError() + except KeyError: + err_msg = f"function for {op} not found" + raise BrokenWorldException(world_name, err_msg) + async def _wait_work(self, work: Work, world_name: str) -> None: """Do busy-waiting for work to be done. @@ -112,7 +153,7 @@ async def _wait_work(self, work: Work, world_name: str) -> None: """ while not work.is_completed(): if self._broken_world[world_name]: - raise BrokenWorldException(world_name, "watchdog raised the exception") + raise BrokenWorldException(world_name, "exception raised by watchdog") await asyncio.sleep(0) async def send( @@ -130,11 +171,12 @@ async def send( BrokenWorldException: An error that occurs when the world is broken due to worker, node or network failure. """ + fn = self._get_fn(world_name, "send") try: with concurrent.futures.ThreadPoolExecutor() as pool: work = await self._loop.run_in_executor( pool, - dist.isend, + fn, tensor, dst, None, @@ -144,7 +186,8 @@ async def send( except RuntimeError as e: self._handle_error(e, world_name) - await self._wait_work(work, world_name) + if isinstance(work, Work): + await self._wait_work(work, world_name) async def recv( self, tensor: Tensor, src: int, world_name: str = DEFAULT_WORLD_NAME @@ -161,11 +204,12 @@ async def recv( BrokenWorldException: An error that occurs when the world is broken due to worker, node or network failure. """ + fn = self._get_fn(world_name, "recv") try: with concurrent.futures.ThreadPoolExecutor() as pool: work = await self._loop.run_in_executor( pool, - dist.irecv, + fn, tensor, src, None, @@ -175,7 +219,8 @@ async def recv( except RuntimeError as e: self._handle_error(e, world_name) - await self._wait_work(work, world_name) + if isinstance(work, Work): + await self._wait_work(work, world_name) async def broadcast( self, tensor: Tensor, src: int, world_name: str = DEFAULT_WORLD_NAME diff --git a/multiworld/world_manager.py b/multiworld/world_manager.py index e22a195..e48e7b1 100644 --- a/multiworld/world_manager.py +++ b/multiworld/world_manager.py @@ -139,7 +139,7 @@ async def initialize_world( addr: host name or IP address. port: Port number. """ - self.add_world(world_name) + self.add_world(world_name, backend) loop = asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor() as pool: @@ -158,19 +158,18 @@ async def initialize_world( store = self._worlds_stores[world_name] self._event_q.put((store, world_name, rank, world_size)) - def add_world(self, world_name, world=None): + def add_world(self, world_name: str, backend: str) -> None: """Add a new world to the world manager.""" if world_name in dist_c10d._worlds: raise ValueError(f"World {world_name} already exists.") - if world is None: - world = dist_c10d._World(world_name) + world = dist_c10d._World(world_name) dist_c10d._worlds[world_name] = world - self._communicator.add_world(world_name) + self._communicator.add_world(world_name, backend) - def remove_world(self, world_name): + def remove_world(self, world_name: str) -> None: """Remove a world from the world manager.""" if world_name not in dist_c10d._worlds: raise ValueError(f"World {world_name} does not exist.")