Skip to content

Commit

Permalink
Fix exception in MPFuture.__del__() (#555)
Browse files Browse the repository at this point in the history
This PR addresses the bug reported in #552 - or, at least, it should, since we cannot reproduce the problem locally.
  • Loading branch information
justheuristic authored Feb 14, 2023
1 parent 7d1bb7d commit 8c98caa
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions hivemind/utils/mpfuture.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class MPFuture(base.Future, Generic[ResultType]):
_active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively

def __init__(self, *, use_lock: bool = True):
self._maybe_initialize_mpfuture_backend()

self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
self._shared_state_code = SharedBytes.next()
self._state_cache: Dict[State, State] = {}
Expand All @@ -105,11 +107,6 @@ def __init__(self, *, use_lock: bool = True):
self._state, self._result, self._exception = base.PENDING, None, None
self._use_lock = use_lock

if self._origin_pid != MPFuture._active_pid:
with MPFuture._initialization_lock:
if self._origin_pid != MPFuture._active_pid:
# note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
self._initialize_mpfuture_backend()
assert self._uid not in MPFuture._active_futures
MPFuture._active_futures[self._uid] = ref(self)
self._sender_pipe = MPFuture._global_sender_pipe
Expand Down Expand Up @@ -151,16 +148,23 @@ async def _event_setter():
self._loop.run_until_complete(_event_setter())

@classmethod
def _initialize_mpfuture_backend(cls):
def _maybe_initialize_mpfuture_backend(cls):
pid = os.getpid()
logger.debug(f"Initializing MPFuture backend for pid {pid}")

receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
cls._active_pid, cls._active_futures = pid, {}
cls._pipe_waiter_thread = threading.Thread(
target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
)
cls._pipe_waiter_thread.start()
if pid != MPFuture._active_pid:
with MPFuture._initialization_lock:
if pid != MPFuture._active_pid:
# note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
logger.debug(f"Initializing MPFuture backend for pid {pid}")

receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
cls._active_pid, cls._active_futures = pid, {}
cls._pipe_waiter_thread = threading.Thread(
target=cls._process_updates_in_background,
args=[receiver_pipe],
name=f"{__name__}.BACKEND",
daemon=True,
)
cls._pipe_waiter_thread.start()

@staticmethod
def reset_backend():
Expand Down Expand Up @@ -296,7 +300,7 @@ def __await__(self):
raise asyncio.CancelledError()

def __del__(self):
if getattr(self, "_origin_pid", None) == os.getpid():
if getattr(self, "_origin_pid", None) == os.getpid() and MPFuture._active_futures is not None:
MPFuture._active_futures.pop(self._uid, None)
if getattr(self, "_aio_event", None):
self._aio_event.set()
Expand Down

0 comments on commit 8c98caa

Please sign in to comment.