From 560098c2b55a312b60884a0b8dfac97f6e8139d8 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 16 Dec 2024 13:06:12 +0100 Subject: [PATCH 1/3] Message interface is hidden inside function call --- src/plumpy/process_comms.py | 28 +++++++++------------ src/plumpy/process_states.py | 11 ++++++--- src/plumpy/processes.py | 43 +++++++++++++++++++-------------- tests/rmq/test_process_comms.py | 4 +-- tests/test_processes.py | 8 +++--- 5 files changed, 49 insertions(+), 45 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index e615ee4a..3b1556fb 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -200,7 +200,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': result = await asyncio.wrap_future(future) return result - async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': """ Pause the process @@ -208,7 +208,7 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - msg = MessageBuilder.pause(text=msg) + msg = MessageBuilder.pause(text=msg_text) pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator @@ -229,7 +229,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult': """ Kill the process @@ -237,8 +237,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) :param msg: optional kill message :return: True if killed, False otherwise """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(text=msg_text) # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, msg) @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: """ return self._communicator.rpc_send(pid, MessageBuilder.status()) - def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ Pause the process @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - msg = MessageBuilder.pause(text=msg) + msg = MessageBuilder.pause(text=msg_text) return self._communicator.rpc_send(pid, msg) - def pause_all(self, msg: Any) -> None: + def pause_all(self, msg_text: Optional[str]) -> None: """ Pause all processes that are subscribed to the same communicator :param msg: an optional pause message """ + msg = MessageBuilder.pause(text=msg_text) self._communicator.broadcast_send(msg, subject=Intent.PAUSE) def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: @@ -401,28 +401,24 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future: """ Kill the process :param pid: the pid of the process to kill :param msg: optional kill message :return: a response future from the process to be killed - """ - if msg is None: - msg = MessageBuilder.kill() - + msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill) return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[MessageType]) -> None: + def kill_all(self, msg_text: Optional[str]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(msg_text) self._communicator.broadcast_send(msg, subject=Intent.KILL) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d369a1e9..931dbc5e 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - def __init__(self, msg: MessageType | None): + def __init__(self, msg_text: str | None): super().__init__() - if msg is None: - msg = MessageBuilder.kill() + msg = MessageBuilder.kill(text=msg_text) self.msg: MessageType = msg class PauseInterruption(Interruption): - pass + def __init__(self, msg_text: str | None): + super().__init__() + msg = MessageBuilder.pause(text=msg_text) + + self.msg: MessageType = msg # region Commands diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 0866ee41..d984e171 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -344,8 +344,7 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - msg = MessageBuilder.kill(text='Killed by future being cancelled') - if not self.kill(msg): + if not self.kill('Killed by future being cancelled'): self.logger.warning( 'Process<%s>: Failed to kill process on future cancel', self.pid, @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator @@ -990,16 +989,16 @@ def broadcast_receive( self.pid, subject, _comm, - body, + msg, ) # If we get a message we recognise then action it, otherwise ignore if subject == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if subject == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=body) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) if subject == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=body) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) return None def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: @@ -1071,7 +1070,7 @@ def transition_failed( ) self.transition_to(new_state) - def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: + def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1095,22 +1094,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.PauseInterruption(msg) + interrupt_exception = process_states.PauseInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - return self._do_pause(msg) + msg = MessageBuilder.pause(msg_text) + return self._do_pause(state_msg=msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) - call_with_super_check(self.on_pausing, state_msg) - call_with_super_check(self.on_paused, state_msg) + + if state_msg is None: + msg_text = '' + else: + msg_text = state_msg[MESSAGE_KEY] + + call_with_super_check(self.on_pausing, msg_text) + call_with_super_check(self.on_paused, msg_text) finally: self._pausing = None @@ -1125,7 +1131,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu """ if isinstance(exception, process_states.PauseInterruption): - do_pause = functools.partial(self._do_pause, str(exception)) + do_pause = functools.partial(self._do_pause, exception.msg) return futures.CancellableAction(do_pause, cookie=exception) if isinstance(exception, process_states.KillInterruption): @@ -1190,7 +1196,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac ) self.transition_to(new_state) - def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message @@ -1210,12 +1216,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] if self._stepping: # Ask the step function to pause by setting this flag and giving the # caller back a future - interrupt_exception = process_states.KillInterruption(msg) + interrupt_exception = process_states.KillInterruption(msg_text) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) + msg = MessageBuilder.kill(msg_text) new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index a6249d10..7a03fac4 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') - - sync_controller.kill_all(msg) + sync_controller.kill_all(msg_text='bang bang, I shot you down') await utils.wait_util(lambda: all([proc.killed() for proc in procs])) assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) diff --git a/tests/test_processes.py b/tests/test_processes.py index 7b21c463..bba80739 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,7 +10,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import MessageBuilder +from plumpy.process_comms import MESSAGE_KEY, MessageBuilder from plumpy.utils import AttributesFrozendict from tests import utils @@ -322,10 +322,10 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = MessageBuilder.kill(text='Farewell!') - proc.kill(msg) + msg_text = 'Farewell!' + proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg(), msg) + self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self): From a84b93c20913d3e7f3c20b41df5b275e9415fd1c Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 11:32:03 +0100 Subject: [PATCH 2/3] rali --- src/plumpy/process_comms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 3b1556fb..5db12d58 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -401,7 +401,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_kill: bool = False) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future: """ Kill the process @@ -409,7 +409,7 @@ def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None, force_ki :param msg: optional kill message :return: a response future from the process to be killed """ - msg = MessageBuilder.kill(text=msg_text, force_kill=force_kill) + msg = MessageBuilder.kill(text=msg_text) return self._communicator.rpc_send(pid, msg) def kill_all(self, msg_text: Optional[str]) -> None: From 05c36fa19514be2851f7e6886f3a327f0b711ed9 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 14:52:28 +0100 Subject: [PATCH 3/3] Rename MESSAGE_KEY to MESSAGE_TEXT_KEY --- src/plumpy/process_comms.py | 10 +++++----- src/plumpy/processes.py | 14 +++++++------- tests/test_processes.py | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 5db12d58..2d6b3bf4 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -28,7 +28,7 @@ ProcessStatus = Any INTENT_KEY = 'intent' -MESSAGE_KEY = 'message' +MESSAGE_TEXT_KEY = 'message' FORCE_KILL_KEY = 'force_kill' @@ -52,7 +52,7 @@ def play(cls, text: str | None = None) -> MessageType: """The play message send over communicator.""" return { INTENT_KEY: Intent.PLAY, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -60,7 +60,7 @@ def pause(cls, text: str | None = None) -> MessageType: """The pause message send over communicator.""" return { INTENT_KEY: Intent.PAUSE, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } @classmethod @@ -68,7 +68,7 @@ def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType: """The kill message send over communicator.""" return { INTENT_KEY: Intent.KILL, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, FORCE_KILL_KEY: force_kill, } @@ -77,7 +77,7 @@ def status(cls, text: str | None = None) -> MessageType: """The status message send over communicator.""" return { INTENT_KEY: Intent.STATUS, - MESSAGE_KEY: text, + MESSAGE_TEXT_KEY: text, } diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index d984e171..c12e185e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -54,7 +54,7 @@ from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper -from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType +from .process_comms import MESSAGE_TEXT_KEY, MessageBuilder, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -902,7 +902,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None: if msg is None: msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or '' + msg_txt = msg[MESSAGE_TEXT_KEY] or '' self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -963,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any: if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -996,9 +996,9 @@ def broadcast_receive( if subject == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if subject == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) if subject == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None)) return None def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: @@ -1113,7 +1113,7 @@ def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[proce if state_msg is None: msg_text = '' else: - msg_text = state_msg[MESSAGE_KEY] + msg_text = state_msg[MESSAGE_TEXT_KEY] call_with_super_check(self.on_pausing, msg_text) call_with_super_check(self.on_paused, msg_text) diff --git a/tests/test_processes.py b/tests/test_processes.py index bba80739..5d3184f2 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,7 +10,7 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import MESSAGE_KEY, MessageBuilder +from plumpy.process_comms import MESSAGE_TEXT_KEY, MessageBuilder from plumpy.utils import AttributesFrozendict from tests import utils @@ -325,7 +325,7 @@ def test_kill(self): msg_text = 'Farewell!' proc.kill(msg_text=msg_text) self.assertTrue(proc.killed()) - self.assertEqual(proc.killed_msg()[MESSAGE_KEY], msg_text) + self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text) self.assertEqual(proc.state, ProcessState.KILLED) def test_wait_continue(self):