From db256d58824553c315f109d56efb777c9cf78ad0 Mon Sep 17 00:00:00 2001 From: rafa-be Date: Thu, 3 Oct 2024 14:15:01 +0200 Subject: [PATCH] Fixes `Future.done()` behavior (issue #24). Signed-off-by: rafa-be --- scaler/about.py | 2 +- scaler/client/future.py | 69 ++++++++++++++++++++++++++++++++++++----- tests/test_future.py | 13 +++++++- 3 files changed, 74 insertions(+), 10 deletions(-) diff --git a/scaler/about.py b/scaler/about.py index 89c6ad8..17ecd62 100644 --- a/scaler/about.py +++ b/scaler/about.py @@ -1 +1 @@ -__version__ = "1.8.5" +__version__ = "1.8.6" diff --git a/scaler/client/future.py b/scaler/client/future.py index c5379ad..8697c98 100644 --- a/scaler/client/future.py +++ b/scaler/client/future.py @@ -23,6 +23,7 @@ def __init__(self, task: Task, is_delayed: bool, group_task_id: Optional[bytes], self._result_object_id: Optional[bytes] = None self._result_ready_event = threading.Event() self._result_request_sent = False + self._result_received = False self._profiling_info: Optional[ProfileResult] = None @@ -42,6 +43,8 @@ def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ if self.done(): raise InvalidStateError(f"invalid future state: {self._state}") + self._state = "FINISHED" + if object_id is not None: self._result_object_id = object_id @@ -54,25 +57,74 @@ def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ self._result_ready_event.set() - def set_exception(self, exception: Optional[BaseException], profile_result: Optional[ProfileResult] = None) -> None: + def _set_result_or_exception( + self, + result: Optional[Any] = None, + exception: Optional[BaseException] = None, + profiling_info: Optional[ProfileResult] = None + ) -> None: with self._condition: # type: ignore[attr-defined] - if profile_result is not None: - self._profiling_info = profile_result + if self.cancelled(): + raise InvalidStateError(f"invalid future state: {self._state}") + + if self._result_received: + raise InvalidStateError("future already received object data.") + + if profiling_info is not None: + if self._profiling_info is not None: + raise InvalidStateError("cannot set profiling info twice.") + + self._profiling_info = profiling_info + + self._state = "FINISHED" + self._result_received = True + + if exception is not None: + assert result is None + self._exception = exception + for waiter in self._waiters: + waiter.add_exception(self) + else: + self._result = result + for waiter in self._waiters: + waiter.add_result(self) self._result_ready_event.set() + self._condition.notify_all() - return super().set_exception(exception) + self._invoke_callbacks() # type: ignore[attr-defined] - def result(self, timeout=None): + def set_result(self, result: Any, profiling_info: Optional[ProfileResult] = None) -> None: + self._set_result_or_exception(result=result, profiling_info=profiling_info) + + def set_exception(self, exception: Optional[BaseException], profiling_info: Optional[ProfileResult] = None) -> None: + self._set_result_or_exception(exception=exception, profiling_info=profiling_info) + + def result(self, timeout: Optional[float] = None) -> Any: self._result_ready_event.wait(timeout) with self._condition: # type: ignore[attr-defined] - # if it's delayed future, get the result when future.result() get called + # if it's delayed future, get the result when future.result() gets called if self._is_delayed: self._request_result_object() - # wait for - return super().result(timeout) + if not self._result_received: + self._condition.wait(timeout) + + return super().result() + + def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]: + self._result_ready_event.wait(timeout) + + with self._condition: # type: ignore[attr-defined] + # if it's delayed future, get the result when future.exception() gets called + if self._is_delayed: + self._request_result_object() + + if not self._result_received: + self._condition.wait(timeout) + + return super().exception() def cancel(self) -> bool: with self._condition: # type: ignore[attr-defined] @@ -88,6 +140,7 @@ def cancel(self) -> bool: self._connector.send(TaskCancel.new_msg(self._task_id)) self._state = "CANCELLED" + self._result_received = True self._result_ready_event.set() self._condition.notify_all() # type: ignore[attr-defined] diff --git a/tests/test_future.py b/tests/test_future.py index f7b7449..a6f6666 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -28,6 +28,7 @@ def test_callback(self): done_called_event = Event() def on_done_callback(fut): + self.assertTrue(fut.done()) self.assertAlmostEqual(fut.result(), 4.0) done_called_event.set() @@ -57,13 +58,21 @@ def test_state(self): def test_cancel(self): with Client(address=self.address) as client: fut = client.submit(math.sqrt, 100.0) - fut.cancel() + self.assertTrue(fut.cancel()) self.assertTrue(fut.cancelled()) + self.assertTrue(fut.done()) with self.assertRaises(CancelledError): fut.result() + fut = client.submit(math.sqrt, 16) + fut.result() + + # cancel() should fail on a completed future. + self.assertFalse(fut.cancel()) + self.assertFalse(fut.cancelled()) + def test_exception(self): with Client(address=self.address) as client: fut = client.submit(math.sqrt, "16") @@ -71,6 +80,8 @@ def test_exception(self): with self.assertRaises(TypeError): fut.result() + self.assertTrue(fut.done()) + self.assertIsInstance(fut.exception(), TypeError) def test_client_disconnected(self):