Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Message interface is hidden inside function call #301

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ProcessStatus = Any

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
MESSAGE_TEXT_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


Expand All @@ -52,23 +52,23 @@ def play(cls, text: str | None = None) -> MessageType:
"""The play message send over communicator."""
return {
INTENT_KEY: Intent.PLAY,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def pause(cls, text: str | None = None) -> MessageType:
"""The pause message send over communicator."""
return {
INTENT_KEY: Intent.PAUSE,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}

@classmethod
def kill(cls, text: str | None = None, force_kill: bool = False) -> MessageType:
"""The kill message send over communicator."""
return {
INTENT_KEY: Intent.KILL,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
FORCE_KILL_KEY: force_kill,
}

Expand All @@ -77,7 +77,7 @@ def status(cls, text: str | None = None) -> MessageType:
"""The status message send over communicator."""
return {
INTENT_KEY: Intent.STATUS,
MESSAGE_KEY: text,
MESSAGE_TEXT_KEY: text,
}


Expand Down Expand Up @@ -200,15 +200,15 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus':
result = await asyncio.wrap_future(future)
return result

async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
async def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Pause the process

:param pid: the pid of the process to pause
:param msg: optional pause message
:return: True if paused, False otherwise
"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

pause_future = self._communicator.rpc_send(pid, msg)
# rpc_send return a thread future from communicator
Expand All @@ -229,16 +229,15 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> 'ProcessResult':
"""
Kill the process

:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, msg)
Expand Down Expand Up @@ -364,7 +363,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future:
"""
return self._communicator.rpc_send(pid, MessageBuilder.status())

def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
def pause_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Pause the process

Expand All @@ -373,16 +372,17 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu
:return: a response future from the process to be paused

"""
msg = MessageBuilder.pause(text=msg)
msg = MessageBuilder.pause(text=msg_text)

return self._communicator.rpc_send(pid, msg)

def pause_all(self, msg: Any) -> None:
def pause_all(self, msg_text: Optional[str]) -> None:
"""
Pause all processes that are subscribed to the same communicator

:param msg: an optional pause message
"""
msg = MessageBuilder.pause(text=msg_text)
self._communicator.broadcast_send(msg, subject=Intent.PAUSE)

def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future:
Expand All @@ -401,28 +401,24 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg_text: Optional[str] = None) -> kiwipy.Future:
"""
Kill the process

:param pid: the pid of the process to kill
:param msg: optional kill message
:return: a response future from the process to be killed

"""
if msg is None:
msg = MessageBuilder.kill()

msg = MessageBuilder.kill(text=msg_text)
return self._communicator.rpc_send(pid, msg)

def kill_all(self, msg: Optional[MessageType]) -> None:
def kill_all(self, msg_text: Optional[str]) -> None:
"""
Kill all processes that are subscribed to the same communicator

:param msg: an optional pause message
"""
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(msg_text)

self._communicator.broadcast_send(msg, subject=Intent.KILL)

Expand Down
11 changes: 7 additions & 4 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ class Interruption(Exception): # noqa: N818


class KillInterruption(Interruption):
def __init__(self, msg: MessageType | None):
def __init__(self, msg_text: str | None):
super().__init__()
if msg is None:
msg = MessageBuilder.kill()
msg = MessageBuilder.kill(text=msg_text)

self.msg: MessageType = msg


class PauseInterruption(Interruption):
pass
def __init__(self, msg_text: str | None):
super().__init__()
msg = MessageBuilder.pause(text=msg_text)

self.msg: MessageType = msg


# region Commands
Expand Down
47 changes: 27 additions & 20 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_comms import MESSAGE_KEY, MessageBuilder, MessageType
from .process_comms import MESSAGE_TEXT_KEY, MessageBuilder, MessageType
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -344,8 +344,7 @@ def init(self) -> None:

def try_killing(future: futures.Future) -> None:
if future.cancelled():
msg = MessageBuilder.kill(text='Killed by future being cancelled')
if not self.kill(msg):
if not self.kill('Killed by future being cancelled'):
self.logger.warning(
'Process<%s>: Failed to kill process on future cancel',
self.pid,
Expand Down Expand Up @@ -903,7 +902,7 @@ def on_kill(self, msg: Optional[MessageType]) -> None:
if msg is None:
msg_txt = ''
else:
msg_txt = msg[MESSAGE_KEY] or ''
msg_txt = msg[MESSAGE_TEXT_KEY] or ''

self.set_status(msg_txt)
self.future().set_exception(exceptions.KilledError(msg_txt))
Expand Down Expand Up @@ -944,7 +943,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: kiwipy.Communicator, msg: MessageType) -> Any:
"""
Coroutine called when the process receives a message from the communicator

Expand All @@ -964,9 +963,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if intent == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=msg)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -976,7 +975,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: kiwipy.Communicator, msg: MessageType, sender: Any, subject: Any, correlation_id: Any
) -> Optional[kiwipy.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -990,16 +989,16 @@ def broadcast_receive(
self.pid,
subject,
_comm,
body,
msg,
)

# If we get a message we recognise then action it, otherwise ignore
if subject == process_comms.Intent.PLAY:
return self._schedule_rpc(self.play)
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
return self._schedule_rpc(self.pause, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
if subject == process_comms.Intent.KILL:
return self._schedule_rpc(self.kill, msg=body)
return self._schedule_rpc(self.kill, msg_text=msg.get(process_comms.MESSAGE_TEXT_KEY, None))
return None

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
Expand Down Expand Up @@ -1071,7 +1070,7 @@ def transition_failed(
)
self.transition_to(new_state)

def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]:
def pause(self, msg_text: Optional[str] = None) -> Union[bool, futures.CancellableAction]:
"""Pause the process.

:param msg: an optional message to set as the status. The current status will be saved in the private
Expand All @@ -1095,22 +1094,29 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.PauseInterruption(msg)
interrupt_exception = process_states.PauseInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._pausing = self._interrupt_action
# Try to interrupt the state
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

return self._do_pause(msg)
msg = MessageBuilder.pause(msg_text)
return self._do_pause(state_msg=msg)

def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool:
def _do_pause(self, state_msg: Optional[MessageType], next_state: Optional[process_states.State] = None) -> bool:
"""Carry out the pause procedure, optionally transitioning to the next state first"""
try:
if next_state is not None:
self.transition_to(next_state)
call_with_super_check(self.on_pausing, state_msg)
call_with_super_check(self.on_paused, state_msg)

if state_msg is None:
msg_text = ''
else:
msg_text = state_msg[MESSAGE_TEXT_KEY]

call_with_super_check(self.on_pausing, msg_text)
call_with_super_check(self.on_paused, msg_text)
finally:
self._pausing = None

Expand All @@ -1125,7 +1131,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu

"""
if isinstance(exception, process_states.PauseInterruption):
do_pause = functools.partial(self._do_pause, str(exception))
do_pause = functools.partial(self._do_pause, exception.msg)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to add a comment that exception.msg is an attribute of PauseInterruption of type MessageType

return futures.CancellableAction(do_pause, cookie=exception)

if isinstance(exception, process_states.KillInterruption):
Expand Down Expand Up @@ -1190,7 +1196,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
)
self.transition_to(new_state)

def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]:
def kill(self, msg_text: Optional[str] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
:param msg: An optional kill message
Expand All @@ -1210,12 +1216,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]
if self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg)
interrupt_exception = process_states.KillInterruption(msg_text)
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

msg = MessageBuilder.kill(msg_text)
new_state = self._create_state_instance(process_states.ProcessState.KILLED, msg=msg)
self.transition_to(new_state)
return True
Expand Down
4 changes: 1 addition & 3 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

sync_controller.kill_all(msg)
sync_controller.kill_all(msg_text='bang bang, I shot you down')
await utils.wait_util(lambda: all([proc.killed() for proc in procs]))
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

Expand Down
8 changes: 4 additions & 4 deletions tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import plumpy
from plumpy import BundleKeys, Process, ProcessState
from plumpy.process_comms import MessageBuilder
from plumpy.process_comms import MESSAGE_TEXT_KEY, MessageBuilder
from plumpy.utils import AttributesFrozendict
from tests import utils

Expand Down Expand Up @@ -322,10 +322,10 @@ def run(self, **kwargs):
def test_kill(self):
proc: Process = utils.DummyProcess()

msg = MessageBuilder.kill(text='Farewell!')
proc.kill(msg)
msg_text = 'Farewell!'
proc.kill(msg_text=msg_text)
self.assertTrue(proc.killed())
self.assertEqual(proc.killed_msg(), msg)
self.assertEqual(proc.killed_msg()[MESSAGE_TEXT_KEY], msg_text)
self.assertEqual(proc.state, ProcessState.KILLED)

def test_wait_continue(self):
Expand Down
Loading