From 1117eeb07aa93c8cf32a92fb7d27a74205155fea Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 27 Nov 2024 14:08:14 +0100 Subject: [PATCH 01/29] amend from rebase --- tests/test_communications.py | 2 ++ tests/test_expose.py | 36 ++++++++++++++++++++++++++++++++++++ tests/test_processes.py | 3 +++ 3 files changed, 41 insertions(+) diff --git a/tests/test_communications.py b/tests/test_communications.py index f7e04255..37177d6e 100644 --- a/tests/test_communications.py +++ b/tests/test_communications.py @@ -4,6 +4,8 @@ import pytest from kiwipy import CommunicatorHelper +import pytest +from kiwipy import CommunicatorHelper from plumpy.communications import LoopCommunicator diff --git a/tests/test_expose.py b/tests/test_expose.py index 0f6f8087..f48ce32e 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -8,6 +8,42 @@ from plumpy.processes import Process +def validator_function(input, port): + pass + + +class BaseNamespaceProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('top') + spec.input('namespace.sub_one') + spec.input('namespace.sub_two') + spec.inputs['namespace'].valid_type = (int, float) + spec.inputs['namespace'].validator = validator_function + + +class BaseProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a', valid_type=str, default='a') + spec.input('b', valid_type=str, default='b') + spec.inputs.dynamic = True + spec.inputs.valid_type = str + + +class ExposeProcess(NewLoopProcess): + @classmethod + def define(cls, spec): + super().define(spec) + spec.expose_inputs(BaseProcess, namespace='base.name.space') + spec.input('c', valid_type=int, default=1) + spec.input('d', valid_type=int, default=2) + spec.inputs.dynamic = True + spec.inputs.valid_type = int + + def validator_function(input, port): pass diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..bc500688 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -10,6 +10,9 @@ import pytest from tests import utils +import plumpy +import pytest + import plumpy from plumpy import BundleKeys, Process, ProcessState from plumpy.process_comms import KILL_MSG, MESSAGE_KEY From b82791dd436c86af1b1cb20dd5131fe931133ff0 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 30 Nov 2024 00:33:29 +0100 Subject: [PATCH 02/29] Add default MESSAGE_KEY to None value and FORCE_KILL_KEY --- src/plumpy/process_comms.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 293c680b..13aa5fb3 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -31,7 +31,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' - +FORCE_KILL_KEY = 'force_kill' class Intent: """Intent constants for a process message""" @@ -42,10 +42,10 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} +PAUSE_MSG = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +PLAY_MSG = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} +KILL_MSG = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} +STATUS_MSG = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} TASK_KEY = 'task' TASK_ARGS = 'args' From d4c0489b96c55ffa39a5be11f26487d6165f2bc4 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sat, 30 Nov 2024 23:45:12 +0100 Subject: [PATCH 03/29] Alias MessageType for message passing --- src/plumpy/process_comms.py | 27 +++++++++++++-------------- src/plumpy/process_states.py | 4 +++- src/plumpy/processes.py | 23 ++++++++++++++++------- tests/test_processes.py | 2 ++ 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 13aa5fb3..1d280334 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -41,11 +41,12 @@ class Intent: KILL: str = 'kill' STATUS: str = 'status' +MessageType = dict[str, Any] -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} -PLAY_MSG = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -KILL_MSG = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} -STATUS_MSG = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} +PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} +KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} +STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} TASK_KEY = 'task' TASK_ARGS = 'args' @@ -197,7 +198,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +206,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = copy.copy(KILL_MSG) # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -372,7 +372,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,11 +381,10 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = copy.copy(KILL_MSG) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def kill_all(self, msg: Optional[Any]) -> None: """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..10ebfdab 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -8,6 +8,8 @@ import yaml from yaml.loader import Loader +from plumpy.process_comms import MessageType + try: import tblib @@ -402,7 +404,7 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: class Killed(State): LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[str]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..25a8f78e 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -47,6 +47,7 @@ from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] @@ -320,7 +321,9 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): + msg = copy.copy(KILL_MSG) + msg[MESSAGE_KEY] = 'Killed by future being cancelled' + if not self.kill(msg): self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) self._future.add_done_callback(try_killing) @@ -857,10 +860,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -915,7 +923,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -1071,7 +1079,8 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + __import__('ipdb').set_trace() + self.transition_to(process_states.ProcessState.KILLED, exception) return True finally: self._killing = None @@ -1125,7 +1134,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac """ self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message diff --git a/tests/test_processes.py b/tests/test_processes.py index bc500688..6481273e 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -4,6 +4,8 @@ import asyncio import copy import enum +from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from test import utils import unittest import kiwipy From c5a195c9d66d69aed6d09883c63bad71a4a178b2 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 00:10:00 +0100 Subject: [PATCH 04/29] Simplify _create_state_instance so it only need to do real create --- src/plumpy/base/state_machine.py | 120 ++++++++++++++++++++----------- src/plumpy/processes.py | 5 +- 2 files changed, 80 insertions(+), 45 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..556760c0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -8,7 +8,20 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + List, + Optional, + Sequence, + Set, + Type, + Union, + cast, +) from plumpy.futures import Future @@ -60,10 +73,10 @@ def __init__( super().__init__(self._format_msg()) def _format_msg(self) -> str: - msg = [f'{self.initial_state} -> {self.final_state}'] + msg = [f"{self.initial_state} -> {self.final_state}"] if self.traceback_str is not None: msg.append(self.traceback_str) - return '\n'.join(msg) + return "\n".join(msg) def event( @@ -71,16 +84,16 @@ def event( to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" - if from_states != '*': + if from_states != "*": if inspect.isclass(from_states): from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f'from_states: {from_states}') - if to_states != '*': + raise TypeError(f"from_states: {from_states}") + if to_states != "*": if inspect.isclass(to_states): to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore - raise TypeError(f'to_states: {to_states}') + raise TypeError(f"to_states: {to_states}") def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: evt_label = wrapped.__name__ @@ -89,14 +102,20 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: def transition(self: Any, *a: Any, **kw: Any) -> Any: initial = self._state - if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore - raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}') + if from_states != "*" and not any( + isinstance(self._state, state) for state in from_states + ): # type: ignore + raise EventError( + evt_label, f"Event {evt_label} invalid in state {initial.LABEL}" + ) result = wrapped(self, *a, **kw) if not (result is False or isinstance(result, Future)): - if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore + if to_states != "*" and not any( + isinstance(self._state, state) for state in to_states + ): # type: ignore if self._state == initial: - raise EventError(evt_label, 'Machine did not transition') + raise EventError(evt_label, "Machine did not transition") raise EventError( evt_label, @@ -142,7 +161,7 @@ def label(self) -> LABEL_TYPE: def enter(self) -> None: """Entering the state""" - def execute(self) -> Optional['State']: + def execute(self) -> Optional["State"]: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. @@ -152,9 +171,9 @@ def execute(self) -> Optional['State']: def exit(self) -> None: """Exiting the state""" if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}") - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State": return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: @@ -211,7 +230,7 @@ def get_states(cls) -> Sequence[Type[State]]: if cls.STATES is not None: return cls.STATES - raise RuntimeError('States not defined') + raise RuntimeError("States not defined") @classmethod def initial_state_label(cls) -> LABEL_TYPE: @@ -229,7 +248,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: def __ensure_built(cls) -> None: try: # Check if it's already been built (and therefore sealed) - if cls.__getattribute__(cls, 'sealed'): + if cls.__getattribute__(cls, "sealed"): return except AttributeError: pass @@ -253,7 +272,9 @@ def __init__(self) -> None: self.__ensure_built() self._state: Optional[State] = None self._exception_handler = None # Note this appears to never be used - self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) + self.set_debug( + (not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG"))) + ) self._transitioning = False self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} @@ -262,7 +283,7 @@ def init(self) -> None: """Called after entering initial state in `__call__` method of `StateMachineMeta`""" def __str__(self) -> str: - return f'<{self.__class__.__name__}> ({self.state})' + return f"<{self.__class__.__name__}> ({self.state})" def create_initial_state(self) -> State: return self.get_state_class(self.initial_state_label())(self) @@ -273,7 +294,9 @@ def state(self) -> Optional[LABEL_TYPE]: return None return self._state.LABEL - def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + def add_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: """ Add a callback to be called on a particular state event hook. The callback should have form fn(state_machine, hook, state) @@ -283,8 +306,10 @@ def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE """ self._event_callbacks.setdefault(hook, []).append(callback) - def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: - if getattr(self, '_closed', False): + def remove_state_event_callback( + self, hook: Hashable, callback: EVENT_CALLBACK_TYPE + ) -> None: + if getattr(self, "_closed", False): # if the process is closed, then all callbacks have already been removed return None try: @@ -308,8 +333,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._transitioning = True - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) + if not isinstance(new_state, State): + # Make sure we have a state instance + new_state = self._create_state_instance(new_state, *args, **kwargs) + label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -320,7 +347,10 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._enter_next_state(new_state) except StateEntryFailed as exception: # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + if not isinstance(exception.state, State): + new_state = self._create_state_instance( + exception.state, *exception.args, **exception.kwargs + ) label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -338,7 +368,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -358,7 +392,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: - raise ValueError(f'{state_label} is not a valid state') + raise ValueError(f"{state_label} is not a valid state") def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -367,11 +401,15 @@ def _exit_current_state(self, next_state: State) -> None: # in which case check the new state is the initial state if self._state is None: if next_state.label != self.initial_state_label(): - raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") + raise RuntimeError( + f"Cannot enter state '{next_state}' as the initial state" + ) return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError( + f"Cannot transition from {self._state.LABEL} to {next_state.label}" + ) self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.do_exit() @@ -383,20 +421,16 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) - - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: + def _create_state_instance( + self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + ) -> State: + # build from state class if inspect.isclass(state) and issubclass(state, State): - return state + state_cls = state + else: + try: + state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object + except KeyError: + raise ValueError(f"{state} is not a valid state") - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') + return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 25a8f78e..6eef55af 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -865,7 +865,8 @@ def on_kill(self, msg: Optional[MessageType]) -> None: if msg is None: msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or '' + # msg_txt = msg[MESSAGE_KEY] or '' + msg_txt = msg self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -1079,7 +1080,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state - __import__('ipdb').set_trace() + # __import__('ipdb').set_trace() self.transition_to(process_states.ProcessState.KILLED, exception) return True finally: From 8db66756bd1e96456e5bd4d6f74655139f577a93 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 00:34:41 +0100 Subject: [PATCH 05/29] Furthur simplipy _create_state_instant only create state from class --- src/plumpy/base/state_machine.py | 31 +++++++++++++++++++------------ src/plumpy/processes.py | 10 +++++----- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 556760c0..3397c40d 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -20,7 +20,6 @@ Set, Type, Union, - cast, ) from plumpy.futures import Future @@ -325,8 +324,18 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + def transition_to( + self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any + ) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state + is not yet instantiated, which will happened for states not in the expect path such as + pause and kill. + """ + assert ( + not self._transitioning + ), "Cannot call transition_to when already transitioning state" initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -389,6 +398,10 @@ def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic + # because the label is defined after the state and required to be know before calling this function. + # This method should be replaced by `_create_state_instance`. + # aiida-core using this method for its Waiting state override. try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: @@ -422,15 +435,9 @@ def _enter_next_state(self, next_state: State) -> None: self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) def _create_state_instance( - self, state: Union[Hashable, Type[State]], *args: Any, **kwargs: Any + self, state_cls: type[State], *args: Any, **kwargs: Any ) -> State: - # build from state class - if inspect.isclass(state) and issubclass(state, State): - state_cls = state - else: - try: - state_cls = self.get_states_map()[cast(Hashable, state)] # pylint: disable=unsubscriptable-object - except KeyError: - raise ValueError(f"{state} is not a valid state") + if state_cls.LABEL not in self.get_states_map(): + raise ValueError(f"{state_cls.LABEL} is not a valid state") return state_cls(self, *args, **kwargs) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 6eef55af..a4b3b017 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -831,7 +831,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + raise StateEntryFailed(process_states.Finished, result, False) self.future().set_result(self.outputs) @@ -1016,7 +1016,7 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + self.transition_to(process_states.Excepted, exception, trace) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1081,7 +1081,7 @@ def do_kill(_next_state: process_states.State) -> Any: try: # Ignore the next state # __import__('ipdb').set_trace() - self.transition_to(process_states.ProcessState.KILLED, exception) + self.transition_to(process_states.Killed, exception) return True finally: self._killing = None @@ -1133,7 +1133,7 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + self.transition_to(process_states.Excepted, exception, trace_back) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1161,7 +1161,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + self.transition_to(process_states.Killed, msg) return True @property From 74d048dd7c259e631d9a14ef5e869b91c485086b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 02:29:46 +0100 Subject: [PATCH 06/29] Killed state all through passing msg --- src/plumpy/base/state_machine.py | 19 +-- src/plumpy/process_comms.py | 5 +- src/plumpy/process_states.py | 186 +++++++++++++++------- src/plumpy/processes.py | 264 ++++++++++++++++++++++--------- 4 files changed, 336 insertions(+), 138 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 3397c40d..a371084a 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -205,7 +205,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -325,13 +325,14 @@ def on_terminated(self) -> None: """Called when a terminal state is entered""" def transition_to( - self, new_state: Union[State, Type[State]], *args: Any, **kwargs: Any + self, new_state: Union[State, Type[State]], **kwargs: Any ) -> None: """Transite to the new state. - The new target state will be create lazily when the state - is not yet instantiated, which will happened for states not in the expect path such as - pause and kill. + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) """ assert ( not self._transitioning @@ -344,7 +345,7 @@ def transition_to( if not isinstance(new_state, State): # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) + new_state = self._create_state_instance(new_state, **kwargs) label = new_state.LABEL @@ -358,7 +359,7 @@ def transition_to( # Make sure we have a state instance if not isinstance(exception.state, State): new_state = self._create_state_instance( - exception.state, *exception.args, **exception.kwargs + exception.state, **exception.kwargs ) label = new_state.LABEL self._exit_current_state(new_state) @@ -435,9 +436,9 @@ def _enter_next_state(self, next_state: State) -> None: self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) def _create_state_instance( - self, state_cls: type[State], *args: Any, **kwargs: Any + self, state_cls: type[State], **kwargs: Any ) -> State: if state_cls.LABEL not in self.get_states_map(): raise ValueError(f"{state_cls.LABEL} is not a valid state") - return state_cls(self, *args, **kwargs) + return state_cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 1d280334..9e1e4110 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -386,12 +386,15 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = copy.copy(KILL_MSG) + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 10ebfdab..46f29d8f 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import sys +import copy import traceback from enum import Enum from types import TracebackType @@ -8,7 +9,7 @@ import yaml from yaml.loader import Loader -from plumpy.process_comms import MessageType +from plumpy.process_comms import KILL_MSG, MessageType try: import tblib @@ -50,7 +51,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = copy.copy(KILL_MSG) + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -64,9 +70,9 @@ class Command(persistence.Savable): pass -@auto_persist('msg') +@auto_persist("msg") class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -75,10 +81,13 @@ class Pause(Command): pass -@auto_persist('msg', 'data') +@auto_persist("msg", "data") class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -86,7 +95,7 @@ def __init__( self.data = data -@auto_persist('result') +@auto_persist("result") class Stop(Command): def __init__(self, result: Any, successful: bool) -> None: super().__init__() @@ -94,9 +103,9 @@ def __init__(self, result: Any, successful: bool) -> None: self.successful = successful -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Continue(Command): - CONTINUE_FN = 'continue_fn' + CONTINUE_FN = "continue_fn" def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): super().__init__() @@ -104,11 +113,15 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) @@ -135,7 +148,7 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist('in_state') +@auto_persist("in_state") class State(state_machine.State, persistence.Savable): @property def process(self) -> state_machine.StateMachine: @@ -144,7 +157,9 @@ def process(self) -> state_machine.StateMachine: """ return self.state_machine - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -152,33 +167,41 @@ def interrupt(self, reason: Any) -> None: pass -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Created(State): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} - RUN_FN = 'run_fn' + RUN_FN = "run_fn" - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + return self.create_state( + ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs + ) -@auto_persist('args', 'kwargs') +@auto_persist("args", "kwargs") class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { @@ -189,15 +212,17 @@ class Running(State): ProcessState.EXCEPTED, } - RUN_FN = 'run_fn' # The key used to store the function to run - COMMAND = 'command' # The key used to store an upcoming command + RUN_FN = "run_fn" # The key used to store the function to run + COMMAND = "command" # The key used to store an upcoming command # Class level defaults _command: Union[None, Kill, Stop, Wait, Continue] = None _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn @@ -205,17 +230,23 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + self._command = persistence.Savable.load( + saved_state[self.COMMAND], load_context + ) # type: ignore def interrupt(self, reason: Any) -> None: pass @@ -255,18 +286,24 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.create_state( + ProcessState.FINISHED, command.result, command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.create_state( + ProcessState.WAITING, command.continue_fn, command.msg, command.data + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.create_state( + ProcessState.RUNNING, command.continue_fn, *command.args + ) else: - raise ValueError('Unrecognised command') + raise ValueError("Unrecognised command") return cast(State, state) # casting from base.State to process.State -@auto_persist('msg', 'data') +@auto_persist("msg", "data") class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { @@ -277,19 +314,19 @@ class Waiting(State): ProcessState.FINISHED, } - DONE_CALLBACK = 'DONE_CALLBACK' + DONE_CALLBACK = "DONE_CALLBACK" _interruption = None def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: - state_info += f' ({self.msg})' + state_info += f" ({self.msg})" return state_info def __init__( self, - process: 'Process', + process: "Process", done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, data: Optional[Any] = None, @@ -300,12 +337,16 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -331,12 +372,14 @@ async def execute(self) -> State: # type: ignore if result == NULL: next_state = self.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.create_state( + ProcessState.RUNNING, self.done_callback, result + ) return cast(State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: - assert self._waiting_future is not None, 'Not yet waiting' + assert self._waiting_future is not None, "Not yet waiting" if self._waiting_future.done(): return @@ -345,13 +388,23 @@ def resume(self, value: Any = NULL) -> None: class Excepted(State): + """ + Excepted state, can optionally provide exception and trace_back + + :param exception: The exception instance + :param trace_back: An optional exception traceback + """ + LABEL = ProcessState.EXCEPTED - EXC_VALUE = 'ex_value' - TRACEBACK = 'traceback' + EXC_VALUE = "ex_value" + TRACEBACK = "traceback" def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + process: "Process", + exception: Optional[BaseException], + trace_back: Optional[TracebackType] = None, ): """ :param process: The associated process @@ -363,16 +416,22 @@ def __init__( self.traceback = trace_back def __str__(self) -> str: - exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] - return super().__str__() + f'({exception})' + exception = traceback.format_exception_only( + type(self.exception) if self.exception else None, self.exception + )[0] + return super().__str__() + f"({exception})" - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: + def save_instance_state( + self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext + ) -> None: super().save_instance_state(out_state, save_context) out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: - out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + out_state[self.TRACEBACK] = "".join(traceback.format_tb(self.traceback)) - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: super().load_instance_state(saved_state, load_context) self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -383,32 +442,53 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self.traceback = None - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[ + Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] + ]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) -@auto_persist('result', 'successful') +@auto_persist("result", "successful") class Finished(State): + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ LABEL = ProcessState.FINISHED - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: + def __init__(self, process: "Process", result: Any, successful: bool) -> None: super().__init__(process) self.result = result self.successful = successful -@auto_persist('msg') +@auto_persist("msg") class Killed(State): + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + LABEL = ProcessState.KILLED - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, process: "Process", msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message - """ super().__init__(process) self.msg = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index a4b3b017..07e2d20c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -39,7 +39,16 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check @@ -52,7 +61,7 @@ __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) -PROCESS_STACK = ContextVar('process stack', default=[]) +PROCESS_STACK = ContextVar("process stack", default=[]) class BundleKeys: @@ -85,14 +94,20 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if self._closed: - raise exceptions.ClosedError('Process is closed') + raise exceptions.ClosedError("Process is closed") return func(self, *args, **kwargs) return func_wrapper @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + "_pid", + "_creation_time", + "_future", + "_paused", + "_status", + "_pre_paused_status", + "_event_helper", ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -146,7 +161,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe __called: bool = False @classmethod - def current(cls) -> Optional['Process']: + def current(cls) -> Optional["Process"]: """ Get the currently running process i.e. the one at the top of the stack @@ -182,15 +197,15 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: @classmethod def spec(cls) -> ProcessSpec: try: - return cls.__getattribute__(cls, '_spec') + return cls.__getattribute__(cls, "_spec") except AttributeError: try: cls._spec: ProcessSpec = cls._spec_class() # type: ignore cls.__called: bool = False # type: ignore cls.define(cls._spec) # type: ignore assert cls.__called, ( - f'Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in ' - 'your define? Try: super().define(spec)' + f"Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in " + "your define? Try: super().define(spec)" ) return cls._spec # type: ignore except Exception: @@ -222,18 +237,20 @@ def get_description(cls) -> Dict[str, Any]: description: Dict[str, Any] = {} if cls.__doc__: - description['description'] = cls.__doc__.strip() + description["description"] = cls.__doc__.strip() spec_description = cls.spec().get_description() if spec_description: - description['spec'] = spec_description + description["spec"] = spec_description return description @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None - ) -> 'Process': + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> "Process": """ Recreate a process from a saved state, passing any positional and keyword arguments on to load_instance_state @@ -281,7 +298,9 @@ def __init__( self._paused = None # Input/output - self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs) + self._raw_inputs = ( + None if inputs is None else utils.AttributesFrozendict(inputs) + ) self._pid = pid self._parsed_inputs: Optional[utils.AttributesFrozendict] = None self._outputs: Dict[str, Any] = {} @@ -304,27 +323,49 @@ def init(self) -> None: if self._communicator is not None: try: - identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) + identifier = self._communicator.add_rpc_subscriber( + self.message_receive, identifier=str(self.pid) + ) + self.add_cleanup( + functools.partial( + self._communicator.remove_rpc_subscriber, identifier + ) + ) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) + self.logger.exception( + "Process<%s>: failed to register as an RPC subscriber", self.pid + ) try: # filter out state change broadcasts - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) + subscriber = kiwipy.BroadcastFilter( + self.broadcast_receive, subject=re.compile(r"^(?!state_changed).*") + ) + identifier = self._communicator.add_broadcast_subscriber( + subscriber, identifier=str(self.pid) + ) + self.add_cleanup( + functools.partial( + self._communicator.remove_broadcast_subscriber, identifier + ) + ) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + "Process<%s>: failed to register as a broadcast subscriber", + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Killed by future being cancelled' + msg[MESSAGE_KEY] = "Killed by future being cancelled" if not self.kill(msg): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + self.logger.warning( + "Process<%s>: Failed to kill process on future cancel", + self.pid, + ) self._future.add_done_callback(try_killing) @@ -419,7 +460,7 @@ def future(self) -> persistence.SavableFuture: @ensure_not_closed def launch( self, - process_class: Type['Process'], + process_class: Type["Process"], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, @@ -428,7 +469,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -451,7 +498,7 @@ def result(self) -> Any: if isinstance(self._state, process_states.Killed): raise exceptions.KilledError(self._state.msg) if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + raise (self._state.exception or Exception("process excepted")) raise exceptions.InvalidStateError @@ -463,7 +510,9 @@ def successful(self) -> bool: try: return self._state.successful # type: ignore except AttributeError as exception: - raise exceptions.InvalidStateError('process is not in the finished state') from exception + raise exceptions.InvalidStateError( + "process is not in the finished state" + ) from exception @property def is_successful(self) -> bool: @@ -480,12 +529,12 @@ def killed(self) -> bool: """Return whether the process is killed.""" return self.state == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" if isinstance(self._state, process_states.Killed): return self._state.msg - raise exceptions.InvalidStateError('Has not been killed') + raise exceptions.InvalidStateError("Has not been killed") def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" @@ -520,7 +569,9 @@ def loop(self) -> asyncio.AbstractEventLoop: """Return the event loop of the process.""" return self._loop - def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> events.ProcessCallback: + def call_soon( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> events.ProcessCallback: """ Schedule a callback to what is considered an internal process function (this needn't be a method). @@ -532,7 +583,10 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: if self.state != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @@ -551,14 +605,16 @@ def _process_scope(self) -> Generator[None, None, None]: yield None finally: assert Process.current() is self, ( - 'Somehow, the process at the top of the stack is not me, but another process! ' - f'({self} != {Process.current()})' + "Somehow, the process at the top of the stack is not me, but another process! " + f"({self} != {Process.current()})" ) stack_copy = PROCESS_STACK.get().copy() stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -579,7 +635,9 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] + self, + out_state: SAVED_STATE_TYPE, + save_context: Optional[persistence.LoadSaveContext], ) -> None: """ Ask the process to save its current instance state. @@ -589,7 +647,7 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + out_state["_state"] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -602,7 +660,9 @@ def save_instance_state( out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + def load_instance_state( + self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext + ) -> None: """Load the process from its saved instance state. :param saved_state: A bundle to load the state from @@ -620,17 +680,17 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._logger = None self._communicator = None - if 'loop' in load_context: + if "loop" in load_context: self._loop = load_context.loop else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: process_states.State = self.recreate_state(saved_state["_state"]) - if 'communicator' in load_context: + if "communicator" in load_context: self._communicator = load_context.communicator - if 'logger' in load_context: + if "logger" in load_context: self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above @@ -679,7 +739,7 @@ def set_logger(self, logger: logging.Logger) -> None: @protected def log_with_pid(self, level: int, msg: str) -> None: """Log the message with the process pid.""" - self.logger.log(level, '%s: %s', self.pid, msg) + self.logger.log(level, "%s: %s", self.pid, msg) # region Events @@ -714,16 +774,24 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): - from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' - self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) + from_label = ( + cast(enum.Enum, from_state.LABEL).value + if from_state is not None + else None + ) + subject = f"state_changed.{from_label}.{self.state.value}" + self.logger.info( + "Process<%s>: Broadcasting state change: %s", self.pid, subject + ) try: - self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) + self._communicator.broadcast_send( + body=None, sender=self.pid, subject=subject + ) except (ConnectionClosed, ChannelInvalidStateError): - message = 'Process<%s>: no connection available to broadcast state change from %s to %s' + message = "Process<%s>: no connection available to broadcast state change from %s to %s" self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: - message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' + message = "Process<%s>: sending broadcast of state change from %s to %s timed out" self.logger.warning(message, self.pid, from_label, self.state.value) def on_exiting(self) -> None: @@ -741,7 +809,10 @@ def on_create(self) -> None: def recursively_copy_dictionaries(value: Any) -> Any: """Recursively copy the mapping but only create copies of the dictionaries not the values.""" if isinstance(value, dict): - return {key: recursively_copy_dictionaries(subvalue) for key, subvalue in value.items()} + return { + key: recursively_copy_dictionaries(subvalue) + for key, subvalue in value.items() + } return value # This will parse the inputs with respect to the input portnamespace of the spec and validate them. The @@ -749,7 +820,11 @@ def recursively_copy_dictionaries(value: Any) -> Any: # ``_raw_inputs`` should not be modified, we pass a clone of it. Note that we only need a clone of the nested # dictionaries, so we don't use ``copy.deepcopy`` (which might seem like the obvious choice) as that will also # create a clone of the values, which we don't want. - raw_inputs = recursively_copy_dictionaries(dict(self._raw_inputs)) if self._raw_inputs else {} + raw_inputs = ( + recursively_copy_dictionaries(dict(self._raw_inputs)) + if self._raw_inputs + else {} + ) self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs) result = self.spec().inputs.validate(self._parsed_inputs) @@ -782,7 +857,9 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) + self._event_helper.fire_event( + ProcessListener.on_output_emitted, self, output_port, value, dynamic + ) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -831,7 +908,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result, False) + raise StateEntryFailed( + process_states.Finished, result=result, successful=False + ) self.future().set_result(self.outputs) @@ -857,16 +936,17 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: @super_check def on_excepted(self) -> None: """Entered the EXCEPTED state.""" - self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) + self._fire_event( + ProcessListener.on_process_excepted, str(self.future().exception()) + ) @super_check def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" if msg is None: - msg_txt = '' + msg_txt = "" else: - # msg_txt = msg[MESSAGE_KEY] or '' - msg_txt = msg + msg_txt = msg[MESSAGE_KEY] or "" self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -915,14 +995,21 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc( + self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) + ) if intent == process_comms.Intent.KILL: return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: @@ -931,7 +1018,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An return status_info # Didn't match any known intents - raise RuntimeError('Unknown intent') + raise RuntimeError("Unknown intent") def broadcast_receive( self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any @@ -944,7 +1031,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -956,7 +1047,9 @@ def broadcast_receive( return self._schedule_rpc(self.kill, msg=body) return None - def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: + def _schedule_rpc( + self, callback: Callable[..., Any], *args: Any, **kwargs: Any + ) -> kiwipy.Future: """ Schedule a call to a callback as a result of an RPC communication call, this will return a future that resolves to the final result (even after one or more layer of futures being @@ -1010,15 +1103,23 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception, trace) + self.transition_to( + process_states.Excepted, exception=exception, trace_back=trace + ) - def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: + def pause( + self, msg: Union[str, None] = None + ) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1063,7 +1164,9 @@ def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_state return True - def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: + def _create_interrupt_action( + self, exception: process_states.Interruption + ) -> futures.CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1079,9 +1182,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - # Ignore the next state - # __import__('ipdb').set_trace() - self.transition_to(process_states.Killed, exception) + self.transition_to(process_states.Killed, msg=exception.msg) return True finally: self._killing = None @@ -1090,7 +1191,9 @@ def do_kill(_next_state: process_states.State) -> Any: raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: + def _set_interrupt_action( + self, new_action: Optional[futures.CancellableAction] + ) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1127,13 +1230,17 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail( + self, exception: Optional[BaseException], trace_back: Optional[TracebackType] + ) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception, trace_back) + self.transition_to( + process_states.Excepted, exception=exception, trace_back=trace_back + ) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1161,7 +1268,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg) + self.transition_to(process_states.Killed, msg=msg) return True @property @@ -1178,7 +1285,10 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return cast( + process_states.State, + self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), + ) def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: """ @@ -1188,7 +1298,9 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast( + process_states.State, persistence.Savable.load(saved_state, load_context) + ) # endregion @@ -1221,7 +1333,7 @@ async def step(self) -> None: The execute function running in this method is dependent on the state of the process. """ - assert not self.has_terminated(), 'Cannot step, already terminated' + assert not self.has_terminated(), "Cannot step, already terminated" if self.paused and self._paused is not None: await self._paused @@ -1246,7 +1358,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = self.create_state( + process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:] + ) self._set_interrupt_action(None) if self._interrupt_action: From 667af7a5d42dc86ff30903ca74f75da1737ae824 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Sun, 1 Dec 2024 17:35:40 +0100 Subject: [PATCH 07/29] Amend --- src/plumpy/base/state_machine.py | 84 +++++-------- src/plumpy/process_comms.py | 4 +- src/plumpy/process_states.py | 139 ++++++++------------- src/plumpy/processes.py | 205 +++++++++++-------------------- tests/test_processes.py | 1 - 5 files changed, 153 insertions(+), 280 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index a371084a..499612e0 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -43,7 +43,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -72,10 +72,10 @@ def __init__( super().__init__(self._format_msg()) def _format_msg(self) -> str: - msg = [f"{self.initial_state} -> {self.final_state}"] + msg = [f'{self.initial_state} -> {self.final_state}'] if self.traceback_str is not None: msg.append(self.traceback_str) - return "\n".join(msg) + return '\n'.join(msg) def event( @@ -83,16 +83,16 @@ def event( to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" - if from_states != "*": + if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) if not all(issubclass(state, State) for state in from_states): # type: ignore - raise TypeError(f"from_states: {from_states}") - if to_states != "*": + raise TypeError(f'from_states: {from_states}') + if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) if not all(issubclass(state, State) for state in to_states): # type: ignore - raise TypeError(f"to_states: {to_states}") + raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: evt_label = wrapped.__name__ @@ -101,20 +101,14 @@ def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: def transition(self: Any, *a: Any, **kw: Any) -> Any: initial = self._state - if from_states != "*" and not any( - isinstance(self._state, state) for state in from_states - ): # type: ignore - raise EventError( - evt_label, f"Event {evt_label} invalid in state {initial.LABEL}" - ) + if from_states != '*' and not any(isinstance(self._state, state) for state in from_states): # type: ignore + raise EventError(evt_label, f'Event {evt_label} invalid in state {initial.LABEL}') result = wrapped(self, *a, **kw) if not (result is False or isinstance(result, Future)): - if to_states != "*" and not any( - isinstance(self._state, state) for state in to_states - ): # type: ignore + if to_states != '*' and not any(isinstance(self._state, state) for state in to_states): # type: ignore if self._state == initial: - raise EventError(evt_label, "Machine did not transition") + raise EventError(evt_label, 'Machine did not transition') raise EventError( evt_label, @@ -160,7 +154,7 @@ def label(self) -> LABEL_TYPE: def enter(self) -> None: """Entering the state""" - def execute(self) -> Optional["State"]: + def execute(self) -> Optional['State']: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. @@ -170,9 +164,9 @@ def execute(self) -> Optional["State"]: def exit(self) -> None: """Exiting the state""" if self.is_terminal(): - raise InvalidStateError(f"Cannot exit a terminal state {self.LABEL}") + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> "State": + def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': return self.state_machine.create_state(state_label, *args, **kwargs) def do_enter(self) -> None: @@ -229,7 +223,7 @@ def get_states(cls) -> Sequence[Type[State]]: if cls.STATES is not None: return cls.STATES - raise RuntimeError("States not defined") + raise RuntimeError('States not defined') @classmethod def initial_state_label(cls) -> LABEL_TYPE: @@ -247,7 +241,7 @@ def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: def __ensure_built(cls) -> None: try: # Check if it's already been built (and therefore sealed) - if cls.__getattribute__(cls, "sealed"): + if cls.__getattribute__(cls, 'sealed'): return except AttributeError: pass @@ -271,9 +265,7 @@ def __init__(self) -> None: self.__ensure_built() self._state: Optional[State] = None self._exception_handler = None # Note this appears to never be used - self.set_debug( - (not sys.flags.ignore_environment and bool(os.environ.get("PYTHONSMDEBUG"))) - ) + self.set_debug((not sys.flags.ignore_environment and bool(os.environ.get('PYTHONSMDEBUG')))) self._transitioning = False self._event_callbacks: Dict[Hashable, List[EVENT_CALLBACK_TYPE]] = {} @@ -282,7 +274,7 @@ def init(self) -> None: """Called after entering initial state in `__call__` method of `StateMachineMeta`""" def __str__(self) -> str: - return f"<{self.__class__.__name__}> ({self.state})" + return f'<{self.__class__.__name__}> ({self.state})' def create_initial_state(self) -> State: return self.get_state_class(self.initial_state_label())(self) @@ -293,9 +285,7 @@ def state(self) -> Optional[LABEL_TYPE]: return None return self._state.LABEL - def add_state_event_callback( - self, hook: Hashable, callback: EVENT_CALLBACK_TYPE - ) -> None: + def add_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: """ Add a callback to be called on a particular state event hook. The callback should have form fn(state_machine, hook, state) @@ -305,10 +295,8 @@ def add_state_event_callback( """ self._event_callbacks.setdefault(hook, []).append(callback) - def remove_state_event_callback( - self, hook: Hashable, callback: EVENT_CALLBACK_TYPE - ) -> None: - if getattr(self, "_closed", False): + def remove_state_event_callback(self, hook: Hashable, callback: EVENT_CALLBACK_TYPE) -> None: + if getattr(self, '_closed', False): # if the process is closed, then all callbacks have already been removed return None try: @@ -324,19 +312,15 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to( - self, new_state: Union[State, Type[State]], **kwargs: Any - ) -> None: + def transition_to(self, new_state: Union[State, Type[State]], **kwargs: Any) -> None: """Transite to the new state. - The new target state will be create lazily when the state is not yet instantiated, + The new target state will be create lazily when the state is not yet instantiated, which will happened for states not in the expect path such as pause and kill. The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ - assert ( - not self._transitioning - ), "Cannot call transition_to when already transitioning state" + assert not self._transitioning, 'Cannot call transition_to when already transitioning state' initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -358,9 +342,7 @@ def transition_to( except StateEntryFailed as exception: # Make sure we have a state instance if not isinstance(exception.state, State): - new_state = self._create_state_instance( - exception.state, **exception.kwargs - ) + new_state = self._create_state_instance(exception.state, **exception.kwargs) label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) @@ -406,7 +388,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat try: return self.get_states_map()[state_label](self, *args, **kwargs) except KeyError: - raise ValueError(f"{state_label} is not a valid state") + raise ValueError(f'{state_label} is not a valid state') def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -415,15 +397,11 @@ def _exit_current_state(self, next_state: State) -> None: # in which case check the new state is the initial state if self._state is None: if next_state.label != self.initial_state_label(): - raise RuntimeError( - f"Cannot enter state '{next_state}' as the initial state" - ) + raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError( - f"Cannot transition from {self._state.LABEL} to {next_state.label}" - ) + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.do_exit() @@ -435,10 +413,8 @@ def _enter_next_state(self, next_state: State) -> None: self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - def _create_state_instance( - self, state_cls: type[State], **kwargs: Any - ) -> State: + def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f"{state_cls.LABEL} is not a valid state") + raise ValueError(f'{state_cls.LABEL} is not a valid state') return state_cls(self, **kwargs) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 9e1e4110..773a9742 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -33,6 +33,7 @@ MESSAGE_KEY = 'message' FORCE_KILL_KEY = 'force_kill' + class Intent: """Intent constants for a process message""" @@ -41,9 +42,10 @@ class Intent: KILL: str = 'kill' STATUS: str = 'status' + MessageType = dict[str, Any] -PAUSE_MSG: MessageType= {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 46f29d8f..ede846e4 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -import sys import copy +import sys import traceback from enum import Enum from types import TracebackType @@ -70,7 +70,7 @@ class Command(persistence.Savable): pass -@auto_persist("msg") +@auto_persist('msg') class Kill(Command): def __init__(self, msg: Optional[MessageType] = None): super().__init__() @@ -81,7 +81,7 @@ class Pause(Command): pass -@auto_persist("msg", "data") +@auto_persist('msg', 'data') class Wait(Command): def __init__( self, @@ -95,7 +95,7 @@ def __init__( self.data = data -@auto_persist("result") +@auto_persist('result') class Stop(Command): def __init__(self, result: Any, successful: bool) -> None: super().__init__() @@ -103,9 +103,9 @@ def __init__(self, result: Any, successful: bool) -> None: self.successful = successful -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Continue(Command): - CONTINUE_FN = "continue_fn" + CONTINUE_FN = 'continue_fn' def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): super().__init__() @@ -113,15 +113,11 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) @@ -148,7 +144,7 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist("in_state") +@auto_persist('in_state') class State(state_machine.State, persistence.Savable): @property def process(self) -> state_machine.StateMachine: @@ -157,9 +153,7 @@ def process(self) -> state_machine.StateMachine: """ return self.state_machine - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -167,41 +161,33 @@ def interrupt(self, reason: Any) -> None: pass -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Created(State): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} - RUN_FN = "run_fn" + RUN_FN = 'run_fn' - def __init__( - self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> None: + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: - return self.create_state( - ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs - ) + return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) -@auto_persist("args", "kwargs") +@auto_persist('args', 'kwargs') class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { @@ -212,17 +198,15 @@ class Running(State): ProcessState.EXCEPTED, } - RUN_FN = "run_fn" # The key used to store the function to run - COMMAND = "command" # The key used to store an upcoming command + RUN_FN = 'run_fn' # The key used to store the function to run + COMMAND = 'command' # The key used to store an upcoming command # Class level defaults _command: Union[None, Kill, Stop, Wait, Continue] = None _running: bool = False _run_handle = None - def __init__( - self, process: "Process", run_fn: Callable[..., Any], *args: Any, **kwargs: Any - ) -> None: + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None self.run_fn = run_fn @@ -230,23 +214,17 @@ def __init__( self.kwargs = kwargs self._run_handle = None - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: - self._command = persistence.Savable.load( - saved_state[self.COMMAND], load_context - ) # type: ignore + self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore def interrupt(self, reason: Any) -> None: pass @@ -286,24 +264,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state( - ProcessState.FINISHED, command.result, command.successful - ) + state = self.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state( - ProcessState.WAITING, command.continue_fn, command.msg, command.data - ) + state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state( - ProcessState.RUNNING, command.continue_fn, *command.args - ) + state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: - raise ValueError("Unrecognised command") + raise ValueError('Unrecognised command') return cast(State, state) # casting from base.State to process.State -@auto_persist("msg", "data") +@auto_persist('msg', 'data') class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { @@ -314,19 +286,19 @@ class Waiting(State): ProcessState.FINISHED, } - DONE_CALLBACK = "DONE_CALLBACK" + DONE_CALLBACK = 'DONE_CALLBACK' _interruption = None def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: - state_info += f" ({self.msg})" + state_info += f' ({self.msg})' return state_info def __init__( self, - process: "Process", + process: 'Process', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, data: Optional[Any] = None, @@ -337,16 +309,12 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -372,14 +340,12 @@ async def execute(self) -> State: # type: ignore if result == NULL: next_state = self.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state( - ProcessState.RUNNING, self.done_callback, result - ) + next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: - assert self._waiting_future is not None, "Not yet waiting" + assert self._waiting_future is not None, 'Not yet waiting' if self._waiting_future.done(): return @@ -397,12 +363,12 @@ class Excepted(State): LABEL = ProcessState.EXCEPTED - EXC_VALUE = "ex_value" - TRACEBACK = "traceback" + EXC_VALUE = 'ex_value' + TRACEBACK = 'traceback' def __init__( self, - process: "Process", + process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None, ): @@ -416,22 +382,16 @@ def __init__( self.traceback = trace_back def __str__(self) -> str: - exception = traceback.format_exception_only( - type(self.exception) if self.exception else None, self.exception - )[0] - return super().__str__() + f"({exception})" + exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] + return super().__str__() + f'({exception})' - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext - ) -> None: + def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: - out_state[self.TRACEBACK] = "".join(traceback.format_tb(self.traceback)) + out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -444,9 +404,7 @@ def load_instance_state( def get_exc_info( self, - ) -> Tuple[ - Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType] - ]: + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ @@ -457,22 +415,23 @@ def get_exc_info( ) -@auto_persist("result", "successful") +@auto_persist('result', 'successful') class Finished(State): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ + LABEL = ProcessState.FINISHED - def __init__(self, process: "Process", result: Any, successful: bool) -> None: + def __init__(self, process: 'Process', result: Any, successful: bool) -> None: super().__init__(process) self.result = result self.successful = successful -@auto_persist("msg") +@auto_persist('msg') class Killed(State): """ Represents a state where a process has been killed. @@ -485,7 +444,7 @@ class Killed(State): LABEL = ProcessState.KILLED - def __init__(self, process: "Process", msg: Optional[MessageType]): + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process :param msg: Optional kill message diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 07e2d20c..9358d927 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -15,6 +15,7 @@ import warnings from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -53,15 +54,18 @@ from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType + +if TYPE_CHECKING: + from .process_states import State __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) -PROCESS_STACK = ContextVar("process stack", default=[]) +PROCESS_STACK = ContextVar('process stack', default=[]) class BundleKeys: @@ -94,20 +98,20 @@ def ensure_not_closed(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: if self._closed: - raise exceptions.ClosedError("Process is closed") + raise exceptions.ClosedError('Process is closed') return func(self, *args, **kwargs) return func_wrapper @persistence.auto_persist( - "_pid", - "_creation_time", - "_future", - "_paused", - "_status", - "_pre_paused_status", - "_event_helper", + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): """ @@ -161,7 +165,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe __called: bool = False @classmethod - def current(cls) -> Optional["Process"]: + def current(cls) -> Optional['Process']: """ Get the currently running process i.e. the one at the top of the stack @@ -197,15 +201,15 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: @classmethod def spec(cls) -> ProcessSpec: try: - return cls.__getattribute__(cls, "_spec") + return cls.__getattribute__(cls, '_spec') except AttributeError: try: cls._spec: ProcessSpec = cls._spec_class() # type: ignore cls.__called: bool = False # type: ignore cls.define(cls._spec) # type: ignore assert cls.__called, ( - f"Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in " - "your define? Try: super().define(spec)" + f'Process.define() was not called by {cls}\nHint: Did you forget to call the superclass method in ' + 'your define? Try: super().define(spec)' ) return cls._spec # type: ignore except Exception: @@ -237,11 +241,11 @@ def get_description(cls) -> Dict[str, Any]: description: Dict[str, Any] = {} if cls.__doc__: - description["description"] = cls.__doc__.strip() + description['description'] = cls.__doc__.strip() spec_description = cls.spec().get_description() if spec_description: - description["spec"] = spec_description + description['spec'] = spec_description return description @@ -250,7 +254,7 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> "Process": + ) -> 'Process': """ Recreate a process from a saved state, passing any positional and keyword arguments on to load_instance_state @@ -298,9 +302,7 @@ def __init__( self._paused = None # Input/output - self._raw_inputs = ( - None if inputs is None else utils.AttributesFrozendict(inputs) - ) + self._raw_inputs = None if inputs is None else utils.AttributesFrozendict(inputs) self._pid = pid self._parsed_inputs: Optional[utils.AttributesFrozendict] = None self._outputs: Dict[str, Any] = {} @@ -323,35 +325,19 @@ def init(self) -> None: if self._communicator is not None: try: - identifier = self._communicator.add_rpc_subscriber( - self.message_receive, identifier=str(self.pid) - ) - self.add_cleanup( - functools.partial( - self._communicator.remove_rpc_subscriber, identifier - ) - ) + identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception( - "Process<%s>: failed to register as an RPC subscriber", self.pid - ) + self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts - subscriber = kiwipy.BroadcastFilter( - self.broadcast_receive, subject=re.compile(r"^(?!state_changed).*") - ) - identifier = self._communicator.add_broadcast_subscriber( - subscriber, identifier=str(self.pid) - ) - self.add_cleanup( - functools.partial( - self._communicator.remove_broadcast_subscriber, identifier - ) - ) + subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) + identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: self.logger.exception( - "Process<%s>: failed to register as a broadcast subscriber", + 'Process<%s>: failed to register as a broadcast subscriber', self.pid, ) @@ -360,10 +346,10 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = "Killed by future being cancelled" + msg[MESSAGE_KEY] = 'Killed by future being cancelled' if not self.kill(msg): self.logger.warning( - "Process<%s>: Failed to kill process on future cancel", + 'Process<%s>: Failed to kill process on future cancel', self.pid, ) @@ -460,7 +446,7 @@ def future(self) -> persistence.SavableFuture: @ensure_not_closed def launch( self, - process_class: Type["Process"], + process_class: Type['Process'], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, @@ -498,7 +484,7 @@ def result(self) -> Any: if isinstance(self._state, process_states.Killed): raise exceptions.KilledError(self._state.msg) if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception("process excepted")) + raise (self._state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -510,9 +496,7 @@ def successful(self) -> bool: try: return self._state.successful # type: ignore except AttributeError as exception: - raise exceptions.InvalidStateError( - "process is not in the finished state" - ) from exception + raise exceptions.InvalidStateError('process is not in the finished state') from exception @property def is_successful(self) -> bool: @@ -534,7 +518,7 @@ def killed_msg(self) -> Optional[MessageType]: if isinstance(self._state, process_states.Killed): return self._state.msg - raise exceptions.InvalidStateError("Has not been killed") + raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" @@ -569,9 +553,7 @@ def loop(self) -> asyncio.AbstractEventLoop: """Return the event loop of the process.""" return self._loop - def call_soon( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> events.ProcessCallback: + def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> events.ProcessCallback: """ Schedule a callback to what is considered an internal process function (this needn't be a method). @@ -605,16 +587,14 @@ def _process_scope(self) -> Generator[None, None, None]: yield None finally: assert Process.current() is self, ( - "Somehow, the process at the top of the stack is not me, but another process! " - f"({self} != {Process.current()})" + 'Somehow, the process at the top of the stack is not me, but another process! ' + f'({self} != {Process.current()})' ) stack_copy = PROCESS_STACK.get().copy() stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: + async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -647,7 +627,7 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state["_state"] = self._state.save() + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -660,9 +640,7 @@ def save_instance_state( out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext - ) -> None: + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. :param saved_state: A bundle to load the state from @@ -680,17 +658,17 @@ def load_instance_state( self._logger = None self._communicator = None - if "loop" in load_context: + if 'loop' in load_context: self._loop = load_context.loop else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state["_state"]) + self._state: process_states.State = self.recreate_state(saved_state['_state']) - if "communicator" in load_context: + if 'communicator' in load_context: self._communicator = load_context.communicator - if "logger" in load_context: + if 'logger' in load_context: self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above @@ -739,7 +717,7 @@ def set_logger(self, logger: logging.Logger) -> None: @protected def log_with_pid(self, level: int, msg: str) -> None: """Log the message with the process pid.""" - self.logger.log(level, "%s: %s", self.pid, msg) + self.logger.log(level, '%s: %s', self.pid, msg) # region Events @@ -774,24 +752,16 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): - from_label = ( - cast(enum.Enum, from_state.LABEL).value - if from_state is not None - else None - ) - subject = f"state_changed.{from_label}.{self.state.value}" - self.logger.info( - "Process<%s>: Broadcasting state change: %s", self.pid, subject - ) + from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None + subject = f'state_changed.{from_label}.{self.state.value}' + self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: - self._communicator.broadcast_send( - body=None, sender=self.pid, subject=subject - ) + self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) except (ConnectionClosed, ChannelInvalidStateError): - message = "Process<%s>: no connection available to broadcast state change from %s to %s" + message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) except kiwipy.TimeoutError: - message = "Process<%s>: sending broadcast of state change from %s to %s timed out" + message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' self.logger.warning(message, self.pid, from_label, self.state.value) def on_exiting(self) -> None: @@ -809,10 +779,7 @@ def on_create(self) -> None: def recursively_copy_dictionaries(value: Any) -> Any: """Recursively copy the mapping but only create copies of the dictionaries not the values.""" if isinstance(value, dict): - return { - key: recursively_copy_dictionaries(subvalue) - for key, subvalue in value.items() - } + return {key: recursively_copy_dictionaries(subvalue) for key, subvalue in value.items()} return value # This will parse the inputs with respect to the input portnamespace of the spec and validate them. The @@ -820,11 +787,7 @@ def recursively_copy_dictionaries(value: Any) -> Any: # ``_raw_inputs`` should not be modified, we pass a clone of it. Note that we only need a clone of the nested # dictionaries, so we don't use ``copy.deepcopy`` (which might seem like the obvious choice) as that will also # create a clone of the values, which we don't want. - raw_inputs = ( - recursively_copy_dictionaries(dict(self._raw_inputs)) - if self._raw_inputs - else {} - ) + raw_inputs = recursively_copy_dictionaries(dict(self._raw_inputs)) if self._raw_inputs else {} self._parsed_inputs = self.spec().inputs.pre_process(raw_inputs) result = self.spec().inputs.validate(self._parsed_inputs) @@ -857,9 +820,7 @@ def on_output_emitting(self, output_port: str, value: Any) -> None: """Output is about to be emitted.""" def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None: - self._event_helper.fire_event( - ProcessListener.on_output_emitted, self, output_port, value, dynamic - ) + self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic) @super_check def on_wait(self, awaitables: Sequence[Awaitable]) -> None: @@ -908,9 +869,7 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed( - process_states.Finished, result=result, successful=False - ) + raise StateEntryFailed(process_states.Finished, result=result, successful=False) self.future().set_result(self.outputs) @@ -936,17 +895,15 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: @super_check def on_excepted(self) -> None: """Entered the EXCEPTED state.""" - self._fire_event( - ProcessListener.on_process_excepted, str(self.future().exception()) - ) + self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" if msg is None: - msg_txt = "" + msg_txt = '' else: - msg_txt = msg[MESSAGE_KEY] or "" + msg_txt = msg[MESSAGE_KEY] or '' self.set_status(msg_txt) self.future().set_exception(exceptions.KilledError(msg_txt)) @@ -1007,9 +964,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PLAY: return self._schedule_rpc(self.play) if intent == process_comms.Intent.PAUSE: - return self._schedule_rpc( - self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None) - ) + return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: @@ -1018,7 +973,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An return status_info # Didn't match any known intents - raise RuntimeError("Unknown intent") + raise RuntimeError('Unknown intent') def broadcast_receive( self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any @@ -1047,9 +1002,7 @@ def broadcast_receive( return self._schedule_rpc(self.kill, msg=body) return None - def _schedule_rpc( - self, callback: Callable[..., Any], *args: Any, **kwargs: Any - ) -> kiwipy.Future: + def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: """ Schedule a call to a callback as a result of an RPC communication call, this will return a future that resolves to the final result (even after one or more layer of futures being @@ -1113,13 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to( - process_states.Excepted, exception=exception, trace_back=trace - ) + self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) - def pause( - self, msg: Union[str, None] = None - ) -> Union[bool, futures.CancellableAction]: + def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. :param msg: an optional message to set as the status. The current status will be saved in the private @@ -1164,9 +1113,7 @@ def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_state return True - def _create_interrupt_action( - self, exception: process_states.Interruption - ) -> futures.CancellableAction: + def _create_interrupt_action(self, exception: process_states.Interruption) -> futures.CancellableAction: """ Create an interrupt action from the corresponding interrupt exception @@ -1191,9 +1138,7 @@ def do_kill(_next_state: process_states.State) -> Any: raise ValueError(f"Got unknown interruption type '{type(exception)}'") - def _set_interrupt_action( - self, new_action: Optional[futures.CancellableAction] - ) -> None: + def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) -> None: """ Set the interrupt action cancelling the current one if it exists :param new_action: The new interrupt action to set @@ -1230,17 +1175,13 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail( - self, exception: Optional[BaseException], trace_back: Optional[TracebackType] - ) -> None: + def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to( - process_states.Excepted, exception=exception, trace_back=trace_back - ) + self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1298,9 +1239,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast( - process_states.State, persistence.Savable.load(saved_state, load_context) - ) + return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) # endregion @@ -1333,7 +1272,7 @@ async def step(self) -> None: The execute function running in this method is dependent on the state of the process. """ - assert not self.has_terminated(), "Cannot step, already terminated" + assert not self.has_terminated(), 'Cannot step, already terminated' if self.paused and self._paused is not None: await self._paused @@ -1358,9 +1297,7 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state( - process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:] - ) + next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/tests/test_processes.py b/tests/test_processes.py index 6481273e..47085c90 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -5,7 +5,6 @@ import copy import enum from plumpy.process_comms import KILL_MSG, MESSAGE_KEY -from test import utils import unittest import kiwipy From 4be6931d45fc339e86fb8de07813492581fc5bd5 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 00:29:20 +0100 Subject: [PATCH 08/29] If transition_to None do noting --- src/plumpy/base/state_machine.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 499612e0..853ca668 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -312,7 +312,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[State, Type[State]], **kwargs: Any) -> None: + def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -322,6 +322,9 @@ def transition_to(self, new_state: Union[State, Type[State]], **kwargs: Any) -> """ assert not self._transitioning, 'Cannot call transition_to when already transitioning state' + if new_state is None: + return None + initial_state_label = self._state.LABEL if self._state is not None else None label = None try: From 88259d67c243fe55f551c6ed20beb83fa65f2b00 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 00:29:40 +0100 Subject: [PATCH 09/29] KillMessage build msg from parameters --- src/plumpy/process_comms.py | 21 ++++++++++++++++----- src/plumpy/process_states.py | 5 ++--- src/plumpy/processes.py | 12 +++++------- tests/rmq/test_process_comms.py | 3 +-- tests/test_processes.py | 10 +++------- tests/utils.py | 6 ++---- 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 773a9742..bc2fa125 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -12,10 +12,10 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', 'PAUSE_MSG', 'PLAY_MSG', 'STATUS_MSG', + 'KillMessage', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', @@ -47,9 +47,20 @@ class Intent: PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} +# KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} + +class KillMessage: + @classmethod + def build(cls, message: str | None = None, force: bool = False) -> MessageType: + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: message, + FORCE_KILL_KEY: force, + } + + TASK_KEY = 'task' TASK_ARGS = 'args' PERSIST_KEY = 'persist' @@ -209,7 +220,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) :return: True if killed, False otherwise """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() # Wait for the communication to go through kill_future = self._communicator.rpc_send(pid, msg) @@ -384,7 +395,7 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> ki """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() return self._communicator.rpc_send(pid, msg) @@ -395,7 +406,7 @@ def kill_all(self, msg: Optional[MessageType]) -> None: :param msg: an optional pause message """ if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() self._communicator.broadcast_send(msg, subject=Intent.KILL) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index ede846e4..45178b42 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import copy import sys import traceback from enum import Enum @@ -9,7 +8,7 @@ import yaml from yaml.loader import Loader -from plumpy.process_comms import KILL_MSG, MessageType +from plumpy.process_comms import KillMessage, MessageType try: import tblib @@ -54,7 +53,7 @@ class KillInterruption(Interruption): def __init__(self, msg: MessageType | None): super().__init__() if msg is None: - msg = copy.copy(KILL_MSG) + msg = KillMessage.build() self.msg: MessageType = msg diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9358d927..ef558fa1 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -15,7 +15,6 @@ import warnings from types import TracebackType from typing import ( - TYPE_CHECKING, Any, Awaitable, Callable, @@ -27,6 +26,7 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) @@ -54,13 +54,12 @@ from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper -from .process_comms import KILL_MSG, MESSAGE_KEY, MessageType +from .process_comms import MESSAGE_KEY, KillMessage, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected -if TYPE_CHECKING: - from .process_states import State +T = TypeVar('T') __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] @@ -345,8 +344,7 @@ def init(self) -> None: def try_killing(future: futures.Future) -> None: if future.cancelled(): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Killed by future being cancelled' + msg = KillMessage.build(message='Killed by future being cancelled') if not self.kill(msg): self.logger.warning( 'Process<%s>: Failed to kill process on future cancel', @@ -594,7 +592,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..4c7a4f1a 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -196,8 +196,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.KillMessage.build(message='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) diff --git a/tests/test_processes.py b/tests/test_processes.py index 47085c90..cec20c51 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,9 +2,8 @@ """Process tests""" import asyncio -import copy import enum -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage import unittest import kiwipy @@ -16,7 +15,6 @@ import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY from plumpy.utils import AttributesFrozendict @@ -327,8 +325,7 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = KillMessage.build(message='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) @@ -434,8 +431,7 @@ class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..88638e01 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') return process_states.Kill(msg=msg) From c3c9db40230a7a60b25e58ad3a6541d4f2e74032 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 00:56:52 +0100 Subject: [PATCH 10/29] Pause/Play/Status all using message builder --- docs/source/tutorial.ipynb | 2 +- src/plumpy/process_comms.py | 61 +++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index c1fdb3b2..fe25892d 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -1118,7 +1118,7 @@ "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", - "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" + "pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())" ] }, { diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index bc2fa125..39b70d4f 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -2,7 +2,6 @@ """Module for process level communication functions and classes""" import asyncio -import copy import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -12,13 +11,13 @@ from .utils import PID_TYPE __all__ = [ - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', 'KillMessage', + 'PauseMessage', + 'PlayMessage', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', + 'StatusMessage', 'create_continue_body', 'create_launch_body', ] @@ -45,10 +44,27 @@ class Intent: MessageType = dict[str, Any] -PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} -PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -# KILL_MSG: MessageType = {INTENT_KEY: Intent.KILL, MESSAGE_KEY: None, FORCE_KILL_KEY: False} -STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} +# PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} +# PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} +# STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} + + +class PlayMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: message, + } + + +class PauseMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: message, + } class KillMessage: @@ -61,6 +77,15 @@ def build(cls, message: str | None = None, force: bool = False) -> MessageType: } +class StatusMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: message, + } + + TASK_KEY = 'task' TASK_ARGS = 'args' PERSIST_KEY = 'persist' @@ -176,7 +201,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, STATUS_MSG) + future = self._communicator.rpc_send(pid, StatusMessage.build()) result = await asyncio.wrap_future(future) return result @@ -188,11 +213,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - pause_future = self._communicator.rpc_send(pid, message) + pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future @@ -206,7 +229,7 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, PLAY_MSG) + play_future = self._communicator.rpc_send(pid, PlayMessage.build()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result @@ -344,7 +367,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, STATUS_MSG) + return self._communicator.rpc_send(pid, StatusMessage.build()) def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ @@ -355,11 +378,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def pause_all(self, msg: Any) -> None: """ @@ -377,7 +398,7 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, PLAY_MSG) + return self._communicator.rpc_send(pid, PlayMessage.build()) def play_all(self) -> None: """ From d0e4e73b810ad22d9fc9ef1130ae088053782fec Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 09:38:29 +0100 Subject: [PATCH 11/29] Remove duplicate codes --- tests/persistence/test_inmemory.py | 4 +--- tests/persistence/test_pickle.py | 4 ++-- tests/rmq/test_process_comms.py | 1 - tests/test_communications.py | 2 -- tests/test_expose.py | 38 +----------------------------- tests/test_process_comms.py | 2 +- tests/test_processes.py | 7 ++---- 7 files changed, 7 insertions(+), 51 deletions(-) diff --git a/tests/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py index b0db46e7..9e3141de 100644 --- a/tests/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from ..utils import ProcessWithCheckpoint - import plumpy -import plumpy +from ..utils import ProcessWithCheckpoint class TestInMemoryPersister(unittest.TestCase): diff --git a/tests/persistence/test_pickle.py b/tests/persistence/test_pickle.py index dd68b4fd..da4ede51 100644 --- a/tests/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -5,10 +5,10 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from ..utils import ProcessWithCheckpoint - import plumpy +from ..utils import ProcessWithCheckpoint + class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 4c7a4f1a..c6826a24 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import copy import kiwipy import pytest diff --git a/tests/test_communications.py b/tests/test_communications.py index 37177d6e..f7e04255 100644 --- a/tests/test_communications.py +++ b/tests/test_communications.py @@ -4,8 +4,6 @@ import pytest from kiwipy import CommunicatorHelper -import pytest -from kiwipy import CommunicatorHelper from plumpy.communications import LoopCommunicator diff --git a/tests/test_expose.py b/tests/test_expose.py index f48ce32e..c5e6014c 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -1,47 +1,11 @@ # -*- coding: utf-8 -*- import unittest -from .utils import NewLoopProcess - from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process - -def validator_function(input, port): - pass - - -class BaseNamespaceProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('top') - spec.input('namespace.sub_one') - spec.input('namespace.sub_two') - spec.inputs['namespace'].valid_type = (int, float) - spec.inputs['namespace'].validator = validator_function - - -class BaseProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.input('a', valid_type=str, default='a') - spec.input('b', valid_type=str, default='b') - spec.inputs.dynamic = True - spec.inputs.valid_type = str - - -class ExposeProcess(NewLoopProcess): - @classmethod - def define(cls, spec): - super().define(spec) - spec.expose_inputs(BaseProcess, namespace='base.name.space') - spec.input('c', valid_type=int, default=1) - spec.input('d', valid_type=int, default=2) - spec.inputs.dynamic = True - spec.inputs.valid_type = int +from .utils import NewLoopProcess def validator_function(input, port): diff --git a/tests/test_process_comms.py b/tests/test_process_comms.py index c59737ac..44947230 100644 --- a/tests/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy from plumpy import process_comms +from tests import utils class Process(plumpy.Process): diff --git a/tests/test_processes.py b/tests/test_processes.py index cec20c51..eb5bf599 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -3,19 +3,16 @@ import asyncio import enum -from plumpy.process_comms import KillMessage import unittest import kiwipy import pytest -from tests import utils - -import plumpy -import pytest import plumpy from plumpy import BundleKeys, Process, ProcessState +from plumpy.process_comms import KillMessage from plumpy.utils import AttributesFrozendict +from tests import utils class ForgetToCallParent(plumpy.Process): From e3c2ae806b76c4e777bfc82bd440338e37485c29 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 09:45:05 +0100 Subject: [PATCH 12/29] Future type annotation --- src/plumpy/base/state_machine.py | 2 ++ src/plumpy/process_comms.py | 4 +++- src/plumpy/process_states.py | 2 ++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 853ca668..2035d4ab 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The state machine for processes""" +from __future__ import annotations + import enum import functools import inspect diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 39b70d4f..5727bdae 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -42,7 +44,7 @@ class Intent: STATUS: str = 'status' -MessageType = dict[str, Any] +MessageType = Dict[str, Any] # PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} # PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 45178b42..dbbb7bef 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys import traceback from enum import Enum From e5c74ad7fe5ce52125e76d561c6c1a59b727bc18 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 10:06:22 +0100 Subject: [PATCH 13/29] Fix doc --- docs/source/tutorial.ipynb | 2 +- src/plumpy/process_comms.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index fe25892d..af1ed795 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -281,7 +281,7 @@ " def continue_fn(self):\n", " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill('I was killed')\n", + " return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n", "\n", "\n", "process = ContinueProcess()\n", diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 5727bdae..cd6e7238 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -46,10 +46,6 @@ class Intent: MessageType = Dict[str, Any] -# PAUSE_MSG: MessageType = {INTENT_KEY: Intent.PAUSE, MESSAGE_KEY: None} -# PLAY_MSG: MessageType = {INTENT_KEY: Intent.PLAY, MESSAGE_KEY: None} -# STATUS_MSG: MessageType = {INTENT_KEY: Intent.STATUS, MESSAGE_KEY: None} - class PlayMessage: @classmethod From 18eb56e38e3c30b663de0b15d018da42f1c5e12e Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 02:25:27 +0100 Subject: [PATCH 14/29] Mapping states from state name --- src/plumpy/base/state_machine.py | 13 +++---------- src/plumpy/processes.py | 20 +++++++++++++++----- tests/base/test_statemachine.py | 9 +++++---- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2035d4ab..be27e0cd 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -45,7 +45,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: type['State'], *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -314,7 +314,7 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: """Transite to the new state. The new target state will be create lazily when the state is not yet instantiated, @@ -331,11 +331,6 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> label = None try: self._transitioning = True - - if not isinstance(new_state, State): - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, **kwargs) - label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -345,9 +340,7 @@ def transition_to(self, new_state: State | type[State] | None, **kwargs: Any) -> try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - if not isinstance(exception.state, State): - new_state = self._create_state_instance(exception.state, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ef558fa1..5e2f4cbd 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -867,7 +867,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.Finished, result=result, successful=False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(self, result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1064,7 +1066,9 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1127,7 +1131,9 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: process_states.State) -> Any: try: - self.transition_to(process_states.Killed, msg=exception.msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=exception.msg) + self.transition_to(new_state) return True finally: self._killing = None @@ -1179,7 +1185,9 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac :param exception: The exception that caused the failure :param trace_back: Optional exception traceback """ - self.transition_to(process_states.Excepted, exception=exception, trace_back=trace_back) + state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ @@ -1207,7 +1215,9 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.Killed, msg=msg) + state_class = self.get_states_map()[process_states.ProcessState.KILLED] + new_state = self._create_state_instance(state_class, msg=msg) + self.transition_to(new_state) return True @property diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..3a1621a2 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -57,6 +57,7 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,7 +65,7 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) else: self.state_machine.transition_to(self.playing_state) @@ -80,7 +81,7 @@ def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self.state_machine.transition_to(Playing(self.state_machine, track=track)) class CdPlayer(state_machine.StateMachine): @@ -107,12 +108,12 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): From b5056284475cfa44c6c87d2d831e0e08e8e222e8 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 15:53:13 +0100 Subject: [PATCH 15/29] Remove the middle layer of statemachine.State + Savable abstraction --- docs/source/nitpick-exceptions | 2 +- src/plumpy/process_states.py | 111 +++++++++++++++++++++++---------- src/plumpy/processes.py | 26 ++++---- src/plumpy/workchains.py | 4 +- 4 files changed, 94 insertions(+), 49 deletions(-) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 2f354987..e1d6d969 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index dbbb7bef..44a916e9 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -120,6 +120,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process try: self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) except ValueError: @@ -145,25 +146,8 @@ class ProcessState(Enum): KILLED: str = 'killed' -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass - - -@auto_persist('args', 'kwargs') -class Created(State): +@auto_persist('args', 'kwargs', 'in_state') +class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -182,14 +166,23 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) def execute(self) -> state_machine.State: return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('args', 'kwargs') -class Running(State): + +@auto_persist('args', 'kwargs', 'in_state') +class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING ALLOWED = { ProcessState.RUNNING, @@ -223,6 +216,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -230,7 +225,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore if self._command is not None: command = self._command else: @@ -245,7 +240,7 @@ async def execute(self) -> State: # type: ignore raise except Exception: excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + return cast(state_machine.State, excepted) else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -259,7 +254,7 @@ async def execute(self) -> State: # type: ignore next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): state = self.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): @@ -273,11 +268,18 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return cast(state_machine.State, state) # casting from base.State to process.State + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine -@auto_persist('msg', 'data') -class Waiting(State): + +@auto_persist('msg', 'data', 'in_state') +class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING ALLOWED = { ProcessState.RUNNING, @@ -317,6 +319,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: self.done_callback = getattr(self.process, callback_name) @@ -328,7 +332,7 @@ def interrupt(self, reason: Any) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> state_machine.State: # type: ignore try: result = await self._waiting_future except Interruption: @@ -343,7 +347,7 @@ async def execute(self) -> State: # type: ignore else: next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) - return cast(State, next_state) # casting from base.State to process.State + return cast(state_machine.State, next_state) # casting from base.State to process.State def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -353,8 +357,16 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -class Excepted(State): +@auto_persist('in_state') +class Excepted(state_machine.State, persistence.Savable): """ Excepted state, can optionally provide exception and trace_back @@ -394,6 +406,8 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: @@ -415,9 +429,16 @@ def get_exc_info( self.traceback, ) + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + -@auto_persist('result', 'successful') -class Finished(State): +@auto_persist('result', 'successful', 'in_state') +class Finished(state_machine.State, persistence.Savable): """State for process is finished. :param result: The result of process @@ -431,9 +452,20 @@ def __init__(self, process: 'Process', result: Any, successful: bool) -> None: self.result = result self.successful = successful + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + -@auto_persist('msg') -class Killed(State): +@auto_persist('msg', 'in_state') +class Killed(state_machine.State, persistence.Savable): """ Represents a state where a process has been killed. @@ -453,5 +485,16 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): super().__init__(process) self.msg = msg + @property + def process(self) -> state_machine.StateMachine: + """ + :return: The process + """ + return self.state_machine + + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: + super().load_instance_state(saved_state, load_context) + self.state_machine = load_context.process + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 5e2f4cbd..1fe05470 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -177,7 +177,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -186,7 +186,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -357,10 +357,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -661,7 +661,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi else: self._loop = asyncio.get_event_loop() - self._state: process_states.State = self.recreate_state(saved_state['_state']) + self._state: state_machine.State = self.recreate_state(saved_state['_state']) if 'communicator' in load_context: self._communicator = load_context.communicator @@ -719,7 +719,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -735,7 +735,7 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -1103,7 +1103,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1129,7 +1129,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: state_class = self.get_states_map()[process_states.ProcessState.KILLED] new_state = self._create_state_instance(state_class, msg=exception.msg) @@ -1227,7 +1227,7 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. @@ -1235,11 +1235,11 @@ def create_initial_state(self) -> process_states.State: :return: A Created state """ return cast( - process_states.State, + state_machine.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), ) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1247,7 +1247,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..9eafcb50 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -25,6 +25,8 @@ import kiwipy +from plumpy.base import state_machine + from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -117,7 +119,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map From 7f8a30e215a5760736e968953172412b297f21a6 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 16:06:39 +0100 Subject: [PATCH 16/29] Move is_terminal as class attribute required --- src/plumpy/base/state_machine.py | 8 ++------ src/plumpy/process_states.py | 11 +++++++++++ src/plumpy/processes.py | 4 ++-- tests/base/test_statemachine.py | 6 ++++++ 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index be27e0cd..380f4610 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -133,10 +133,6 @@ class State: # from this one ALLOWED: Set[LABEL_TYPE] = set() - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): """ :param state_machine: The process this state belongs to @@ -165,7 +161,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: """Exiting the state""" - if self.is_terminal(): + if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': @@ -345,7 +341,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 44a916e9..91959c4d 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -152,6 +152,7 @@ class Created(state_machine.State, persistence.Savable): ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) @@ -200,6 +201,8 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None + is_terminal = False + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: super().__init__(process) assert run_fn is not None @@ -293,6 +296,8 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None + is_terminal = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -379,6 +384,8 @@ class Excepted(state_machine.State, persistence.Savable): EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal = True + def __init__( self, process: 'Process', @@ -447,6 +454,8 @@ class Finished(state_machine.State, persistence.Savable): LABEL = ProcessState.FINISHED + is_terminal = True + def __init__(self, process: 'Process', result: Any, successful: bool) -> None: super().__init__(process) self.result = result @@ -477,6 +486,8 @@ class Killed(state_machine.State, persistence.Savable): LABEL = ProcessState.KILLED + is_terminal = True + def __init__(self, process: 'Process', msg: Optional[MessageType]): """ :param process: The associated process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 1fe05470..1e745437 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -467,7 +467,7 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + return self._state.is_terminal def result(self) -> Any: """ @@ -540,7 +540,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self._state.is_terminal # endregion diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 3a1621a2..b6d7e2d3 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -20,6 +20,8 @@ class Playing(state_machine.State): ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' super().__init__(player) @@ -54,6 +56,8 @@ class Paused(state_machine.State): ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) @@ -77,6 +81,8 @@ class Stopped(state_machine.State): } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + def __str__(self): return '[]' From e2078927aacce69e54b26e2a6af84c351d564350 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 17:07:29 +0100 Subject: [PATCH 17/29] forming the enter/exit for State protocol --- src/plumpy/base/state_machine.py | 66 ++++---------- src/plumpy/process_states.py | 148 ++++++++++++++++++------------- src/plumpy/workchains.py | 25 +++--- tests/base/test_statemachine.py | 44 +++++++-- tests/test_processes.py | 2 +- 5 files changed, 157 insertions(+), 128 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 380f4610..2164737f 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -13,15 +13,17 @@ from typing import ( Any, Callable, + ClassVar, Dict, Hashable, Iterable, List, Optional, + Protocol, Sequence, - Set, Type, Union, + runtime_checkable, ) from plumpy.futures import Future @@ -88,12 +90,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -127,53 +129,20 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[LABEL_TYPE] - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False - - def __str__(self) -> str: - return str(self.LABEL) - - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - - @super_check - def enter(self) -> None: - """Entering the state""" - - def execute(self) -> Optional['State']: + async def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal: - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True + def enter(self) -> None: ... - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + def exit(self) -> None: ... class StateEventHook(enum.Enum): @@ -250,7 +219,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -380,7 +349,8 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat # This method should be replaced by `_create_state_instance`. # aiida-core using this method for its Waiting state override. try: - return self.get_states_map()[state_label](self, *args, **kwargs) + state_cls = self.get_states_map()[state_label] + return state_cls(self, *args, **kwargs) except KeyError: raise ValueError(f'{state_label} is not a valid state') @@ -390,20 +360,20 @@ def _exit_current_state(self, next_state: State) -> None: # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 91959c4d..88cab660 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,7 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final import yaml from yaml.loader import Loader @@ -146,6 +146,7 @@ class ProcessState(Enum): KILLED: str = 'killed' +@final @auto_persist('args', 'kwargs', 'in_state') class Created(state_machine.State, persistence.Savable): LABEL = ProcessState.CREATED @@ -155,11 +156,12 @@ class Created(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs + self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -167,21 +169,24 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + async def execute(self) -> state_machine.State: + return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False +@final @auto_persist('args', 'kwargs', 'in_state') class Running(state_machine.State, persistence.Savable): LABEL = ProcessState.RUNNING @@ -204,12 +209,13 @@ class Running(state_machine.State, persistence.Savable): is_terminal = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -219,7 +225,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) if self.COMMAND in saved_state: @@ -228,7 +234,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> state_machine.State: if self._command is not None: command = self._command else: @@ -242,7 +248,7 @@ async def execute(self) -> state_machine.State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) + excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) return cast(state_machine.State, excepted) else: if not isinstance(result, Command): @@ -259,28 +265,30 @@ async def execute(self) -> state_machine.State: # type: ignore def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = self.process.create_state(ProcessState.KILLED, command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) else: raise ValueError('Unrecognised command') return cast(state_machine.State, state) # casting from base.State to process.State - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False +@final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): LABEL = ProcessState.WAITING @@ -311,11 +319,12 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() + self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -324,7 +333,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process callback_name = saved_state.get(self.DONE_CALLBACK, None) if callback_name is not None: @@ -348,9 +357,9 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) return cast(state_machine.State, next_state) # casting from base.State to process.State @@ -362,12 +371,14 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('in_state') @@ -397,9 +408,10 @@ def __init__( :param exception: The exception instance :param trace_back: An optional exception traceback """ - super().__init__(process) + self.process = process self.exception = exception self.traceback = trace_back + self.in_state = False def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -413,7 +425,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -436,12 +448,17 @@ def get_exc_info( self.traceback, ) - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('result', 'successful', 'in_state') @@ -457,20 +474,26 @@ class Finished(state_machine.State, persistence.Savable): is_terminal = True def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + self.process = process self.result = result self.successful = successful - - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + async def execute(self) -> state_machine.State: # type: ignore + ... + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False @auto_persist('msg', 'in_state') @@ -493,19 +516,24 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): :param process: The associated process :param msg: Optional kill message """ - super().__init__(process) + self.process = process self.msg = msg - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine + async def execute(self) -> state_machine.State: # type: ignore + ... def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + self.process = load_context.process + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False # endregion diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9eafcb50..eefd57f1 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -26,6 +26,7 @@ import kiwipy from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -87,16 +88,6 @@ def __init__( resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -107,6 +98,20 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + class WorkChain(mixins.ContextMixin, processes.Process): """ diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6d7e2d3..b6100614 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -24,24 +26,16 @@ class Playing(state_machine.State): def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False @@ -50,6 +44,17 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + class Paused(state_machine.State): LABEL = PAUSED @@ -73,6 +78,15 @@ def play(self, track=None): else: self.state_machine.transition_to(self.playing_state) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class Stopped(state_machine.State): LABEL = STOPPED @@ -83,12 +97,24 @@ class Stopped(state_machine.State): is_terminal = False + def __init__(self, player): + self.state_machine = player + def __str__(self): return '[]' def play(self, track): self.state_machine.transition_to(Playing(self.state_machine, track=track)) + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False + class CdPlayer(state_machine.StateMachine): STATES = (Stopped, Playing, Paused) diff --git a/tests/test_processes.py b/tests/test_processes.py index eb5bf599..4b8cc606 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -653,7 +653,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) From 080d0364ce160f82ed0477081a9f73c6bc5ec7de Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 2 Dec 2024 23:20:45 +0100 Subject: [PATCH 18/29] Forming Interruptable and Proceedable protocol --- src/plumpy/base/state_machine.py | 20 +++++++++++++++----- src/plumpy/process_states.py | 12 ++---------- src/plumpy/processes.py | 22 +++++++++++++++++++++- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 2164737f..27b1e5f8 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -132,18 +132,28 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): LABEL: ClassVar[LABEL_TYPE] + is_terminal: ClassVar[bool] - async def execute(self) -> State | None: + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... + + +@runtime_checkable +class Proceedable(Protocol): + + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ ... - def enter(self) -> None: ... - - def exit(self) -> None: ... - class StateEventHook(enum.Enum): """ diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 88cab660..cc9169c7 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -288,6 +288,7 @@ def exit(self) -> None: self.in_state = False + @final @auto_persist('msg', 'data', 'in_state') class Waiting(state_machine.State, persistence.Savable): @@ -342,7 +343,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.done_callback = None self._waiting_future = futures.Future() - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) @@ -448,9 +449,6 @@ def get_exc_info( self.traceback, ) - async def execute(self) -> state_machine.State: # type: ignore - ... - def enter(self) -> None: self.in_state = True @@ -486,9 +484,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def enter(self) -> None: self.in_state = True - async def execute(self) -> state_machine.State: # type: ignore - ... - def exit(self) -> None: if self.is_terminal: raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -519,9 +514,6 @@ def __init__(self, process: 'Process', msg: Optional[MessageType]): self.process = process self.msg = msg - async def execute(self) -> state_machine.State: # type: ignore - ... - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 1e745437..74808291 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -51,7 +51,15 @@ utils, ) from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper from .process_comms import MESSAGE_KEY, KillMessage, MessageType @@ -1092,6 +1100,11 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: + if not isinstance(self._state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg) @@ -1103,6 +1116,10 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: @@ -1285,6 +1302,9 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self._state, Proceedable): + raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None From 6bfb87df69889f0082e8ecbe7732d24ad69e64c2 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 3 Dec 2024 00:48:53 +0100 Subject: [PATCH 19/29] Refactoring create_state as static function initialize state from label create_state refact Hashable initialized + parameters passed to Hashable Fix pre-commit errors --- src/plumpy/base/state_machine.py | 45 +++--- src/plumpy/process_states.py | 235 ++++++++++++++++--------------- src/plumpy/processes.py | 42 +++--- src/plumpy/workchains.py | 10 +- tests/base/test_statemachine.py | 15 +- 5 files changed, 173 insertions(+), 174 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index 27b1e5f8..fc926008 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -34,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -131,9 +130,12 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: @runtime_checkable class State(Protocol): - LABEL: ClassVar[LABEL_TYPE] + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] is_terminal: ClassVar[bool] + def __init__(self, *args: Any, **kwargs: Any): ... + def enter(self) -> None: ... def exit(self) -> None: ... @@ -146,7 +148,6 @@ def interrupt(self, reason: Exception) -> None: ... @runtime_checkable class Proceedable(Protocol): - def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. @@ -155,6 +156,14 @@ def execute(self) -> State | None: ... +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') + + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) + + class StateEventHook(enum.Enum): """ Hooks that can be used to register callback at various points in the state transition @@ -203,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -253,11 +262,11 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Optional[LABEL_TYPE]: + def state(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -297,6 +306,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: The arguments are passed to the state class to create state instance. (process arg does not need to pass since it will always call with 'self' as process) """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' if new_state is None: @@ -353,17 +363,6 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - # XXX: this method create state from label, which is duplicate as _create_state_instance and less generic - # because the label is defined after the state and required to be know before calling this function. - # This method should be replaced by `_create_state_instance`. - # aiida-core using this method for its Waiting state override. - try: - state_cls = self.get_states_map()[state_label] - return state_cls(self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" @@ -375,7 +374,7 @@ def _exit_current_state(self, next_state: State) -> None: return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) self._state.exit() @@ -386,9 +385,3 @@ def _enter_next_state(self, next_state: State) -> None: next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state_cls: type[State], **kwargs: Any) -> State: - if state_cls.LABEL not in self.get_states_map(): - raise ValueError(f'{state_cls.LABEL} is not a valid state') - - return state_cls(self, **kwargs) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index cc9169c7..5f3e8237 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -5,7 +5,20 @@ import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast, final +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Protocol, + Tuple, + Type, + Union, + cast, + final, + runtime_checkable, +) import yaml from yaml.loader import Loader @@ -20,9 +33,9 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import LoadSaveContext, auto_persist from .utils import SAVED_STATE_TYPE __all__ = [ @@ -138,22 +151,28 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' + # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' + + +@runtime_checkable +class Savable(Protocol): + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Created(state_machine.State, persistence.Savable): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +@auto_persist('args', 'kwargs') +class Created(persistence.Savable): + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -161,7 +180,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.run_fn = run_fn self.args = args self.kwargs = kwargs - self.in_state = True def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -173,24 +191,21 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - async def execute(self) -> state_machine.State: - return self.process.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) - - def enter(self) -> None: - self.in_state = True + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... @final -@auto_persist('args', 'kwargs', 'in_state') -class Running(state_machine.State, persistence.Savable): - LABEL = ProcessState.RUNNING - ALLOWED = { +@auto_persist('args', 'kwargs') +class Running(persistence.Savable): + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -206,7 +221,7 @@ class Running(state_machine.State, persistence.Savable): _running: bool = False _run_handle = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: assert run_fn is not None @@ -215,7 +230,6 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs self._run_handle = None - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -234,7 +248,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> state_machine.State: + def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -248,8 +262,10 @@ async def execute(self) -> state_machine.State: # Let this bubble up to the caller raise except Exception: - excepted = self.process.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(state_machine.State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -258,42 +274,52 @@ async def execute(self) -> state_machine.State: # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> state_machine.State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.process.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.process.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.process.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.process.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(state_machine.State, state) # casting from base.State to process.State - - def enter(self) -> None: - self.in_state = True + return state - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... -@final -@auto_persist('msg', 'data', 'in_state') -class Waiting(state_machine.State, persistence.Savable): - LABEL = ProcessState.WAITING - ALLOWED = { +@auto_persist('msg', 'data') +class Waiting(persistence.Savable): + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -305,7 +331,7 @@ class Waiting(state_machine.State, persistence.Savable): _interruption = None - is_terminal = False + is_terminal: ClassVar[bool] = False def __str__(self) -> str: state_info = super().__str__() @@ -325,7 +351,6 @@ def __init__( self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - self.in_state = False def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: super().save_instance_state(out_state, save_context) @@ -347,7 +372,7 @@ def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> state_machine.State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -358,11 +383,15 @@ async def execute(self) -> state_machine.State: # type: ignore raise if result == NULL: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.process.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(state_machine.State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -372,47 +401,39 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('in_state') -class Excepted(state_machine.State, persistence.Savable): +@final +class Excepted(persistence.Savable): """ - Excepted state, can optionally provide exception and trace_back + Excepted state, can optionally provide exception and traceback :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - LABEL = ProcessState.EXCEPTED + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' - is_terminal = True + is_terminal: ClassVar = True def __init__( self, - process: 'Process', exception: Optional[BaseException], - trace_back: Optional[TracebackType] = None, + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - self.process = process self.exception = exception - self.traceback = trace_back - self.in_state = False + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] @@ -426,7 +447,6 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -449,50 +469,40 @@ def get_exc_info( self.traceback, ) - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('result', 'successful', 'in_state') -class Finished(state_machine.State, persistence.Savable): +@final +@auto_persist('result', 'successful') +class Finished(persistence.Savable): """State for process is finished. :param result: The result of process :param successful: Boolean for the exit code is ``0`` the process is successful. """ - LABEL = ProcessState.FINISHED + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - self.process = process + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - self.in_state = False def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - def enter(self) -> None: - self.in_state = True + def enter(self) -> None: ... - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def exit(self) -> None: ... - self.in_state = False - -@auto_persist('msg', 'in_state') -class Killed(state_machine.State, persistence.Savable): +@final +@auto_persist('msg') +class Killed(persistence.Savable): """ Represents a state where a process has been killed. @@ -502,30 +512,23 @@ class Killed(state_machine.State, persistence.Savable): :param msg: An optional message explaining the reason for the process termination. """ - LABEL = ProcessState.KILLED + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - is_terminal = True + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', msg: Optional[MessageType]): + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message """ - self.process = process self.msg = msg def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.process = load_context.process - - def enter(self) -> None: - self.in_state = True - def exit(self) -> None: - if self.is_terminal: - raise exceptions.InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + def enter(self) -> None: ... - self.in_state = False + def exit(self) -> None: ... # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 74808291..bae08dd4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import contextlib @@ -58,6 +60,7 @@ StateMachine, StateMachineError, TransitionFailed, + create_state, event, ) from .base.utils import call_with_super_check, super_check @@ -194,7 +197,7 @@ def get_states(cls) -> Sequence[Type[state_machine.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -633,7 +636,9 @@ def save_instance_state( """ super().save_instance_state(out_state, save_context) - out_state['_state'] = self._state.save() + # FIXME: the combined ProcessState protocol should cover the case + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -876,7 +881,7 @@ def on_finish(self, result: Any, successful: bool) -> None: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] - finished_state = state_cls(self, result=result, successful=False) + finished_state = state_cls(result=result, successful=False) raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -1074,8 +1079,8 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: @@ -1148,10 +1153,11 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu def do_kill(_next_state: state_machine.State) -> Any: try: - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=exception.msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1196,14 +1202,14 @@ def resume(self, *args: Any) -> None: return self._state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] - new_state = self._create_state_instance(state_class, exception=exception, trace_back=trace_back) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) self.transition_to(new_state) def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: @@ -1223,7 +1229,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self._state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) @@ -1232,8 +1238,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] self._state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - state_class = self.get_states_map()[process_states.ProcessState.KILLED] - new_state = self._create_state_instance(state_class, msg=msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) self.transition_to(new_state) return True @@ -1251,10 +1256,7 @@ def create_initial_state(self) -> state_machine.State: :return: A Created state """ - return cast( - state_machine.State, - self.get_state_class(process_states.ProcessState.CREATED)(self, self.run), - ) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ @@ -1325,7 +1327,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index eefd57f1..865a5b61 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,7 +11,6 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, @@ -71,6 +70,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -80,11 +80,11 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key @@ -124,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[state_machine.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index b6100614..6a61fe00 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -17,7 +17,7 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -56,7 +56,7 @@ def exit(self) -> None: self.in_state = False -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} @@ -65,7 +65,6 @@ class Paused(state_machine.State): def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) self._player = player self.playing_state = playing_state @@ -74,9 +73,9 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing(player=self.state_machine, track=track)) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) def enter(self) -> None: self.in_state = True @@ -88,7 +87,7 @@ def exit(self) -> None: self.in_state = False -class Stopped(state_machine.State): +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, @@ -98,13 +97,13 @@ class Stopped(state_machine.State): is_terminal = False def __init__(self, player): - self.state_machine = player + self._player = player def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing(self.state_machine, track=track)) + self._player.transition_to(Playing(self._player, track=track)) def enter(self) -> None: self.in_state = True From ef964ed792c9e4b46b7c8e0dd4a21e12bc4b70a7 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 4 Dec 2024 16:28:43 +0100 Subject: [PATCH 20/29] To lenthy for rethinking --- src/plumpy/persistence.py | 82 ++++++++++++++++-------------------- src/plumpy/process_states.py | 1 - 2 files changed, 37 insertions(+), 46 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ba755bc5..a1d083cb 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -432,28 +432,6 @@ class Savable: _auto_persist: Optional[Set[str]] = None _persist_configured = False - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - @classmethod def auto_persist(cls, *members: str) -> None: if cls._auto_persist is None: @@ -484,13 +462,48 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: - self.load_members(self._auto_persist, saved_state, load_context) + for member in self._auto_persist: + setattr(self, member, self._get_value(saved_state, member, load_context)) + + @staticmethod + def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() if self._auto_persist is not None: - self.save_members(self._auto_persist, out_state) + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = {} @@ -513,27 +526,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY call_with_super_check(self.save_instance_state, out_state, save_context) return out_state - def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None: - for member in members: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - def load_members( - self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None - ) -> None: - for member in members: - setattr(self, member, self._get_value(saved_state, member, load_context)) - def _ensure_persist_configured(self) -> None: if not self._persist_configured: self.persist() diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 5f3e8237..d2743d81 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -151,7 +151,6 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - # FIXME: see LSP error of return a exception, the type is Literal[str] which is invariant, tricky CREATED = 'created' RUNNING = 'running' WAITING = 'waiting' From 937ad012f0f55e2ec874d4d5dbbf34b11c35d66f Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 4 Dec 2024 23:33:13 +0100 Subject: [PATCH 21/29] Move static method load outside --- src/plumpy/persistence.py | 102 +++++++++++++++++++------------------- src/plumpy/processes.py | 2 +- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index a1d083cb..6c3849a6 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -35,8 +35,33 @@ from .processes import Process +class LoadSaveContext: + def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + self._values = dict(**kwargs) + self.loader = loader + + def __getattr__(self, item: str) -> Any: + try: + return self._values[item] + except KeyError: + raise AttributeError(f"item '{item}' not found") + + def __iter__(self) -> Iterable[Any]: + return self._value.__iter__() + + def __contains__(self, item: Any) -> bool: + return self._values.__contains__(item) + + def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': + """Add additional information to the context by making a copy with the new values""" + extended = self._values.copy() + extended.update(kwargs) + loader = extended.pop('loader', self.loader) + return LoadSaveContext(loader=loader, **extended) + + class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): + def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the class loader that can be used to load the classes in the bundle. @@ -52,7 +77,7 @@ class loader that can be used to load the classes in the bundle. else: self.update(savable.save(save_context)) - def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable': + def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable': """ This method loads the class of the object and calls its recreate_from method passing the positional and keyword arguments. @@ -61,7 +86,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable :return: An instance of the Savable """ - return Savable.load(self, load_context) + return load(self, load_context) + + +def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = Savable._get_class_name(saved_state) + load_cls = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) _BUNDLE_TAG = '!plumpy:Bundle' @@ -392,31 +439,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV return context.copyextend(loader=loader) -class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: - self._values = dict(**kwargs) - self.loader = loader - - def __getattr__(self, item: str) -> Any: - try: - return self._values[item] - except KeyError: - raise AttributeError(f"item '{item}' not found") - - def __iter__(self) -> Iterable[Any]: - return self._value.__iter__() - - def __contains__(self, item: Any) -> bool: - return self._values.__contains__(item) - - def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """Add additional information to the context by making a copy with the new values""" - extended = self._values.copy() - extended.update(kwargs) - loader = extended.pop('loader', self.loader) - return LoadSaveContext(loader=loader, **extended) - - META: str = '!!meta' META__CLASS_NAME: str = 'class_name' META__OBJECT_LOADER: str = 'object_loader' @@ -465,28 +487,6 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. - - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance - - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking - try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) - except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - @super_check def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: self._ensure_persist_configured() @@ -580,7 +580,7 @@ def _get_value( if typ == META__TYPE__METHOD: value = getattr(self, value) elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) + value = load(value, load_context) return value diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index bae08dd4..56fdc570 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1266,7 +1266,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(state_machine.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.load(saved_state, load_context)) # endregion From 304f3bab15ad6c3cd12ac53a2ce13b0b40edf95b Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 11:01:32 +0100 Subject: [PATCH 22/29] save_instance_state simplify to only has save interface For the auto_persist attributes, the fn auto_save will take care of save the state --- src/plumpy/mixins.py | 13 ------ src/plumpy/persistence.py | 82 ++++++++++++++++++--------------- src/plumpy/process_states.py | 45 ++++++++++++------ src/plumpy/processes.py | 15 +++--- src/plumpy/workchains.py | 89 +++++++++++++++++++++++++++++++----- tests/test_processes.py | 2 +- 6 files changed, 160 insertions(+), 86 deletions(-) diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index 10142eb7..9dfa7539 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -23,19 +23,6 @@ def __init__(self, *args: Any, **kwargs: Any): def ctx(self) -> Optional[AttributesDict]: return self._context - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - """Add the instance state to ``out_state``. - .. important:: - - The instance state will contain a pointer to the ``ctx``, - and so should be deep copied or serialised before persisting. - """ - super().save_instance_state(out_state, save_context) - if self._context is not None: - out_state[self.CONTEXT] = self._context.__dict__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) try: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 6c3849a6..ccdeef26 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -104,7 +104,7 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N assert load_context.loader is not None # required for type checking try: class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) + load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') else: @@ -487,43 +487,9 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optio for member in self._auto_persist: setattr(self, member, self._get_value(saved_state, member, load_context)) - @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - call_with_super_check(self.save_instance_state, out_state, save_context) return out_state def _ensure_persist_configured(self) -> None: @@ -593,11 +559,13 @@ class SavableFuture(futures.Future, Savable): .. note: This does not save any assigned done callbacks. """ - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) if self.done() and self.exception() is not None: out_state['exception'] = self.exception() + return out_state + @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -643,3 +611,41 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadS # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: self.remove_done_callback(callback) # type: ignore[arg-type] + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) + + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable): + Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index d2743d81..0f811cb6 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy +import inspect import sys import traceback from enum import Enum @@ -23,6 +25,7 @@ import yaml from yaml.loader import Loader +from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType try: @@ -35,7 +38,7 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import LoadSaveContext, auto_persist +from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save from .utils import SAVED_STATE_TYPE __all__ = [ @@ -127,10 +130,12 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.state_machine = load_context.process @@ -159,10 +164,9 @@ class ProcessState(Enum): KILLED = 'killed' -@runtime_checkable -class Savable(Protocol): - def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - +# @runtime_checkable +# class Savable(Protocol): +# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... @final @auto_persist('args', 'kwargs') @@ -180,10 +184,12 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -230,12 +236,15 @@ def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, * self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -351,11 +360,14 @@ def __init__( self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self.process = load_context.process @@ -438,12 +450,15 @@ def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] return super().__str__() + f'({exception})' - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 56fdc570..61c2ff46 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -9,6 +9,7 @@ import copy import enum import functools +import inspect import logging import re import sys @@ -33,6 +34,8 @@ cast, ) +from plumpy import loaders + try: from aiocontextvars import ContextVar except ModuleNotFoundError: @@ -82,7 +85,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. """ @@ -623,18 +626,14 @@ async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) # region Persistence - def save_instance_state( - self, - out_state: SAVED_STATE_TYPE, - save_context: Optional[persistence.LoadSaveContext], - ) -> None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: """ Ask the process to save its current instance state. :param out_state: A bundle to save the state to :param save_context: The save context """ - super().save_instance_state(out_state, save_context) + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) # FIXME: the combined ProcessState protocol should cover the case if isinstance(self._state, process_states.Savable): @@ -650,6 +649,8 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) + return out_state + @protected def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: """Load the process from its saved instance state. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 865a5b61..9741f7ed 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import copy import abc import asyncio import collections @@ -29,6 +30,7 @@ from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE +from plumpy import loaders, utils __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -149,15 +151,69 @@ def on_create(self) -> None: super().on_create() self._stepper = self.spec().get_outline().create_stepper(self) - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """ + Ask the process to save its current instance state. + + :param out_state: A bundle to save the state to + :param save_context: The save context + """ + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = persistence.LoadSaveContext() + + utils.type_check(save_context, persistence.LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) + + self._ensure_persist_configured() + if self._auto_persist is not None: + for member in self._auto_persist: + value = getattr(self, member) + if inspect.ismethod(value): + if value.__self__ is not self: + raise TypeError('Cannot persist methods of other classes') + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, persistence.Savable): + persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + if isinstance(self._state, process_states.Savable): + out_state['_state'] = self._state.save() + + # Inputs/outputs + if self.raw_inputs is not None: + out_state[processes.BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs) + + if self.inputs is not None: + out_state[processes.BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs) + + if self.outputs: + out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself if self._stepper is not None: out_state[self._STEPPER_STATE] = self._stepper.save() + if self._context is not None: + out_state[self.CONTEXT] = self._context.__dict__ + + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) @@ -253,10 +309,12 @@ def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state['_fn'] = self._fn.__name__ + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._fn = getattr(self._workchain.__class__, saved_state['_fn']) @@ -326,11 +384,13 @@ def next_instruction(self) -> None: def finished(self) -> bool: return self._pos == len(self._block) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._block = load_context.block_instruction @@ -464,11 +524,13 @@ def step(self) -> Tuple[bool, Any]: def finished(self) -> bool: return self._pos == len(self._if_instruction) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._if_instruction = load_context.if_instruction @@ -558,11 +620,14 @@ def step(self) -> Tuple[bool, Any]: return False, result - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() + return out_state + def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) self._while_instruction = load_context.while_instruction diff --git a/tests/test_processes.py b/tests/test_processes.py index 4b8cc606..7fa33bb1 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -700,7 +700,7 @@ def step2(self): class TestProcessSaving(unittest.TestCase): maxDiff = None - def test_running_save_instance_state(self): + def test_running_save(self): loop = asyncio.get_event_loop() nsync_comeback = SavePauseProc() From ce6beae1a891a02370801f3151864a34dba4beaa Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 17:07:58 +0100 Subject: [PATCH 23/29] WIP: load_instance_state deabstract simplify - stepper de-abstract - remove ContextMixin - Stepper all using recreate_from --- src/plumpy/__init__.py | 2 - src/plumpy/mixins.py | 31 ------ src/plumpy/persistence.py | 17 +-- src/plumpy/process_states.py | 175 ++++++++++++++++++++++++------ src/plumpy/processes.py | 16 +-- src/plumpy/workchains.py | 204 +++++++++++++++++++++++++++-------- 6 files changed, 320 insertions(+), 125 deletions(-) delete mode 100644 src/plumpy/mixins.py diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 6f94b5bf..5aa23401 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -9,7 +9,6 @@ from .exceptions import * from .futures import * from .loaders import * -from .mixins import * from .persistence import * from .ports import * from .process_comms import * @@ -25,7 +24,6 @@ + processes.__all__ + utils.__all__ + futures.__all__ - + mixins.__all__ + persistence.__all__ + communications.__all__ + process_comms.__all__ diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py deleted file mode 100644 index 9dfa7539..00000000 --- a/src/plumpy/mixins.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Optional - -from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict - -__all__ = ['ContextMixin'] - - -class ContextMixin(persistence.Savable): - """ - Add a context to a Process. The contents of the context will be saved - in the instance state unlike standard instance variables. - """ - - CONTEXT: str = '_context' - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._context: Optional[AttributesDict] = AttributesDict() - - @property - def ctx(self) -> Optional[AttributesDict]: - return self._context - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) - except KeyError: - pass diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ccdeef26..d33afaa1 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -477,15 +477,11 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - call_with_super_check(obj.load_instance_state, saved_state, load_context) + obj.load_instance_state(saved_state, load_context) return obj - @super_check def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - setattr(self, member, self._get_value(saved_state, member, load_context)) + auto_load(self, saved_state, load_context) def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -606,7 +602,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa return obj def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + auto_load(self, saved_state, load_context) + if self._callbacks: # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list for callback in self._callbacks: @@ -649,3 +646,9 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S out_state[member] = value return out_state + +def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + obj._ensure_persist_configured() + if obj._auto_persist is not None: + for member in obj._auto_persist: + setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 0f811cb6..1d7f2350 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -27,6 +27,7 @@ from plumpy import loaders from plumpy.process_comms import KillMessage, MessageType +from plumpy.persistence import _ensure_object_loader try: import tblib @@ -38,7 +39,16 @@ from . import exceptions, futures, persistence, utils from .base import state_machine as st from .lang import NULL -from .persistence import META__OBJECT_LOADER, META__TYPE__METHOD, META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_persist, auto_save +from .persistence import ( + META__OBJECT_LOADER, + META__TYPE__METHOD, + META__TYPE__SAVABLE, + LoadSaveContext, + Savable, + auto_load, + auto_persist, + auto_save, +) from .utils import SAVED_STATE_TYPE __all__ = [ @@ -136,14 +146,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.state_machine = load_context.process try: - self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) + obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: process = load_context.process - self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN]) + obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + return obj # endregion @@ -168,6 +192,7 @@ class ProcessState(Enum): # class Savable(Protocol): # def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + @final @auto_persist('args', 'kwargs') class Created(persistence.Savable): @@ -190,11 +215,27 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process + + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + + return obj def execute(self) -> st.State: return st.create_state( @@ -245,13 +286,28 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + if obj.COMMAND in saved_state: + # FIXME: typing + obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: pass @@ -368,16 +424,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.process = load_context.process + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + + obj.process = load_context.process - callback_name = saved_state.get(self.DONE_CALLBACK, None) + callback_name = saved_state.get(obj.DONE_CALLBACK, None) if callback_name is not None: - self.done_callback = getattr(self.process, callback_name) + obj.done_callback = getattr(obj.process, callback_name) else: - self.done_callback = None - self._waiting_future = futures.Future() + obj.done_callback = None + obj._waiting_future = futures.Future() + return obj def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception @@ -459,17 +529,30 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) - self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) + obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) + obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False) except KeyError: - self.traceback = None + obj.traceback = None else: - self.traceback = None + obj.traceback = None + return obj def get_exc_info( self, @@ -506,8 +589,21 @@ def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... @@ -537,8 +633,21 @@ def __init__(self, msg: Optional[MessageType]): """ self.msg = msg - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj def enter(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 61c2ff46..f1a3f1f7 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -9,7 +9,6 @@ import copy import enum import functools -import inspect import logging import re import sys @@ -34,7 +33,7 @@ cast, ) -from plumpy import loaders +from plumpy.persistence import _ensure_object_loader try: from aiocontextvars import ContextVar @@ -277,9 +276,12 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - process = cast(Process, super().recreate_from(saved_state, load_context)) - call_with_super_check(process.init) - return process + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + proc.load_instance_state(saved_state, load_context) + + call_with_super_check(proc.init) + return proc def __init__( self, @@ -660,7 +662,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi """ # First make sure the state machine constructor is called - super().__init__() + state_machine.StateMachine.__init__(self) self._setup_event_hooks() @@ -684,7 +686,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - super().load_instance_state(saved_state, load_context) + persistence.auto_load(self, saved_state, load_context) # Inputs/outputs try: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 9741f7ed..cdb3b00e 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -16,6 +16,7 @@ Mapping, MutableSequence, Optional, + Protocol, Sequence, Tuple, Type, @@ -26,11 +27,14 @@ import kiwipy from plumpy.base import state_machine +from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.process_listener import ProcessListener -from . import lang, mixins, persistence, process_states, processes -from .utils import PID_TYPE, SAVED_STATE_TYPE +from . import lang, persistence, process_states, processes +from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict from plumpy import loaders, utils +from plumpy.persistence import _ensure_object_loader __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -104,18 +108,15 @@ def enter(self) -> None: for awaitable in self._awaiting: awaitable.add_done_callback(self._awaitable_done) - self.in_state = True - def exit(self) -> None: if self.is_terminal: raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - self.in_state = False for awaitable in self._awaiting: awaitable.remove_done_callback(self._awaitable_done) -class WorkChain(mixins.ContextMixin, processes.Process): +class WorkChain(processes.Process): """ A WorkChain is a series of instructions carried out with the ability to save state in between. @@ -123,7 +124,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' + CONTEXT = 'CONTEXT' @classmethod def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: @@ -140,9 +141,14 @@ def __init__( communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + self._context: Optional[AttributesDict] = AttributesDict() self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} + @property + def ctx(self) -> Optional[AttributesDict]: + return self._context + @classmethod def spec(cls) -> WorkChainSpec: return cast(WorkChainSpec, super().spec()) @@ -215,7 +221,63 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + ######### + # FIXME: dup of Process.load_instance_state + state_machine.StateMachine.__init__(self) + + self._setup_event_hooks() + + # Runtime variables, set initial states + self._future = persistence.SavableFuture() + self._event_helper = EventHelper(ProcessListener) + self._logger = None + self._communicator = None + + if 'loop' in load_context: + self._loop = load_context.loop + else: + self._loop = asyncio.get_event_loop() + + self._state: state_machine.State = self.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + self._communicator = load_context.communicator + + if 'logger' in load_context: + self._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(self, saved_state, load_context) + + # Inputs/outputs + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + self._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._raw_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + self._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + self._parsed_inputs = None + + try: + decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + self._outputs = decoded + except KeyError: + self._outputs = {} + + # + ######### + + # context mixin + try: + self._context = AttributesDict(**saved_state[self.CONTEXT]) + except KeyError: + pass + + # end of context mixin # Recreate the stepper self._stepper = None @@ -258,15 +320,8 @@ def _do_step(self) -> Any: return return_value -class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: - self._workchain = workchain - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._workchain = load_context.workchain - - @abc.abstractmethod +# XXX: Stepper is also a Saver with `save` method. +class Stepper(Protocol): def step(self) -> Tuple[bool, Any]: """ Execute on step of the instructions. @@ -275,6 +330,7 @@ def step(self) -> Tuple[bool, Any]: 1. The return value from the executed step """ + ... class _Instruction(metaclass=abc.ABCMeta): @@ -304,9 +360,9 @@ def get_description(self) -> Any: """ -class _FunctionStepper(Stepper): +class _FunctionStepper(persistence.Savable): def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): - super().__init__(workchain) + self._workchain = workchain self._fn = fn def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -315,9 +371,24 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._fn = getattr(self._workchain.__class__, saved_state['_fn']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) + + return obj def step(self) -> Tuple[bool, Any]: return True, self._fn(self._workchain) @@ -357,9 +428,9 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(Stepper): +class _BlockStepper(persistence.Savable): def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._block = block self._pos: int = 0 self._child_stepper: Optional[Stepper] = self._block[0].create_stepper(self._workchain) @@ -391,13 +462,28 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._block = load_context.block_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._block[self._pos].recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._block[obj._pos].recreate_stepper(stepper_state, obj._workchain) + + return obj def __str__(self) -> str: return str(self._pos) + ':' + str(self._child_stepper) @@ -490,9 +576,9 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(Stepper): +class _IfStepper(persistence.Savable): def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._if_instruction = if_instruction self._pos = 0 self._child_stepper: Optional[Stepper] = None @@ -531,13 +617,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._if_instruction = load_context.if_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._if_instruction[self._pos].body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._if_instruction[obj._pos].body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._if_instruction[self._pos]) @@ -599,9 +699,9 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(Stepper): +class _WhileStepper(persistence.Savable): def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._while_instruction = while_instruction self._child_stepper: Optional[_BlockStepper] = None @@ -628,13 +728,27 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._while_instruction = load_context.while_instruction + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + persistence.auto_load(obj, saved_state, load_context) + obj._workchain = load_context.workchain + obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._while_instruction.body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._while_instruction.body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._while_instruction) @@ -672,9 +786,9 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(Stepper): +class _ReturnStepper(persistence.Savable): def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._return_instruction = return_instruction def step(self) -> Tuple[bool, Any]: From 484ae879c3ac2e0a80e096621373267e80d1a681 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 22:07:15 +0100 Subject: [PATCH 24/29] ProcessListener recreate_from --- src/plumpy/process_listener.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 8e1acf94..166a811a 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -3,7 +3,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional from . import persistence -from .utils import SAVED_STATE_TYPE, protected +from .utils import SAVED_STATE_TYPE +from plumpy.persistence import LoadSaveContext, _ensure_object_loader __all__ = ['ProcessListener'] @@ -22,12 +23,21 @@ def __init__(self) -> None: def init(self, **kwargs: Any) -> None: self._params = kwargs - @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().load_instance_state(saved_state, load_context) - self.init(**saved_state['_params']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + obj.init(**saved_state['_params']) + return obj # endregion From c910d62838fd3ac494b1fa9f0c806eccaf5e8771 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 22:24:38 +0100 Subject: [PATCH 25/29] Absorb all load_instance_state into recreate_from --- src/plumpy/persistence.py | 23 ++++---- src/plumpy/processes.py | 112 +++++++++++++++++--------------------- src/plumpy/workchains.py | 80 ++++++++++++++++----------- 3 files changed, 110 insertions(+), 105 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index d33afaa1..13a21c61 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -477,12 +477,9 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = _ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) - obj.load_instance_state(saved_state, load_context) + auto_load(obj, saved_state, load_context) return obj - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - auto_load(self, saved_state, load_context) - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -599,15 +596,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = cls(loop=loop) obj.cancel() - return obj - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - auto_load(self, saved_state, load_context) + # ## XXX: load_instance_state: test not cover + # auto_load(obj, saved_state, load_context) + # + # if obj._callbacks: + # # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list + # for callback in obj._callbacks: + # obj.remove_done_callback(callback) # type: ignore[arg-type] + # ## UNTILHERE XXX: - if self._callbacks: - # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: - self.remove_done_callback(callback) # type: ignore[arg-type] + return obj def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -647,6 +645,7 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state + def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: obj._ensure_persist_configured() if obj._auto_persist is not None: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index f1a3f1f7..96689024 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -84,7 +84,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`. """ @@ -266,10 +266,8 @@ def recreate_from( cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None, - ) -> 'Process': - """ - Recreate a process from a saved state, passing any positional and - keyword arguments on to load_instance_state + ) -> Process: + """Recreate a process from a saved state, passing any positional :param saved_state: The saved state to load from :param load_context: The load context to use @@ -278,7 +276,53 @@ def recreate_from( """ load_context = _ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) - proc.load_instance_state(saved_state, load_context) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + proc._communicator = load_context.communicator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.auto_load(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} call_with_super_check(proc.init) return proc @@ -653,62 +697,6 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - """Load the process from its saved instance state. - - :param saved_state: A bundle to load the state from - :param load_context: The load context - - """ - # First make sure the state machine constructor is called - state_machine.StateMachine.__init__(self) - - self._setup_event_hooks() - - # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None - - if 'loop' in load_context: - self._loop = load_context.loop - else: - self._loop = asyncio.get_event_loop() - - self._state: state_machine.State = self.recreate_state(saved_state['_state']) - - if 'communicator' in load_context: - self._communicator = load_context.communicator - - if 'logger' in load_context: - self._logger = load_context.logger - - # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) - - # Inputs/outputs - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._raw_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._parsed_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS]) - self._outputs = decoded - except KeyError: - self._outputs = {} - - # endregion - def add_process_listener(self, listener: ProcessListener) -> None: """Add a process listener to the process. diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cdb3b00e..cf7ad81f 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -27,6 +27,7 @@ import kiwipy from plumpy.base import state_machine +from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError from plumpy.process_listener import ProcessListener @@ -220,70 +221,87 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - ######### - # FIXME: dup of Process.load_instance_state - state_machine.StateMachine.__init__(self) + @classmethod + def recreate_from( + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> WorkChain: + """Recreate a workchain from a saved state, passing any positional + + :param saved_state: The saved state to load from + :param load_context: The load context to use + :return: An instance of the object with its state loaded from the save state. + + """ + ### FIXME: dup from process.create_from + load_context = _ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) - self._setup_event_hooks() + proc._setup_event_hooks() # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None if 'loop' in load_context: - self._loop = load_context.loop + proc._loop = load_context.loop else: - self._loop = asyncio.get_event_loop() + proc._loop = asyncio.get_event_loop() - self._state: state_machine.State = self.recreate_state(saved_state['_state']) + proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: - self._communicator = load_context.communicator + proc._communicator = load_context.communicator if 'logger' in load_context: - self._logger = load_context.logger + proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(self, saved_state, load_context) + persistence.auto_load(proc, saved_state, load_context) # Inputs/outputs try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._raw_inputs = None + proc._raw_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) except KeyError: - self._parsed_inputs = None + proc._parsed_inputs = None try: - decoded = self.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) - self._outputs = decoded + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + proc._outputs = decoded except KeyError: - self._outputs = {} - - # - ######### + proc._outputs = {} + ### UNTILHERE FIXME: dup from process.create_from # context mixin try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) + proc._context = AttributesDict(**saved_state[proc.CONTEXT]) except KeyError: pass # end of context mixin # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) + proc._stepper = None + stepper_state = saved_state.get(proc._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc) + + call_with_super_check(proc.init) + return proc def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None: """ From 61c7fb80e5f90a52583a5a54b218e72f708537e7 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:01:53 +0100 Subject: [PATCH 26/29] Remove useless persist method of Savable class --- src/plumpy/persistence.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 13a21c61..2367c759 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -460,10 +460,6 @@ def auto_persist(cls, *members: str) -> None: cls._auto_persist = set() cls._auto_persist.update(members) - @classmethod - def persist(cls) -> None: - pass - @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -475,10 +471,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - return obj + ... def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = auto_save(self, save_context) @@ -487,7 +480,6 @@ def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TY def _ensure_persist_configured(self) -> None: if not self._persist_configured: - self.persist() self._persist_configured = True # region Metadata getter/setters From 55bc734ed4e65041c0ad8a2846ce9f8f5b99a969 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:02:09 +0100 Subject: [PATCH 27/29] Explicity recreate_from implementation --- src/plumpy/event_helper.py | 22 ++++++++++++++++-- tests/test_persistence.py | 46 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47ad4956..e20dae3f 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- import logging -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional + +from plumpy.utils import SAVED_STATE_TYPE from . import persistence +from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type - from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @@ -30,6 +32,22 @@ def remove_listener(self, listener: 'ProcessListener') -> None: def remove_all_listeners(self) -> None: self._listeners.clear() + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable: + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = _ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 78724aa0..65ef3226 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,6 +5,7 @@ import yaml import plumpy +from plumpy.persistence import auto_load from . import utils @@ -12,6 +13,21 @@ class SaveEmpty(plumpy.Savable): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): @@ -22,12 +38,42 @@ def __init__(self): def m(): pass + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + @plumpy.auto_persist('test') class Save(plumpy.Savable): def __init__(self): self.test = Save1() + @classmethod + def recreate_from(cls, saved_state, load_context= None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + class TestSavable(unittest.TestCase): def test_empty_savable(self): From 9b9a5b77289ff0363cb2b1518e1e515b7a488adb Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Mon, 9 Dec 2024 23:14:28 +0100 Subject: [PATCH 28/29] WIP: forming Savable protocol - remove persist_config flag of savable --- src/plumpy/event_helper.py | 12 +- src/plumpy/persistence.py | 245 ++++++++++++++++----------------- src/plumpy/process_listener.py | 16 ++- src/plumpy/process_states.py | 77 +++++++---- src/plumpy/processes.py | 9 +- src/plumpy/workchains.py | 75 ++++------ tests/test_persistence.py | 32 +++-- tests/test_processes.py | 4 + tests/test_workchains.py | 2 + 9 files changed, 250 insertions(+), 222 deletions(-) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index e20dae3f..abc2b24b 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -2,20 +2,21 @@ import logging from typing import TYPE_CHECKING, Any, Callable, Optional +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import persistence -from plumpy.persistence import Savable, LoadSaveContext, _ensure_object_loader, auto_load if TYPE_CHECKING: from typing import Set, Type + from .process_listener import ProcessListener _LOGGER = logging.getLogger(__name__) @persistence.auto_persist('_listeners', '_listener_type') -class EventHelper(persistence.Savable): +class EventHelper: def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -43,11 +44,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 2367c759..44d812d1 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,24 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + Protocol, + cast, + runtime_checkable, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ @@ -100,10 +112,10 @@ def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = N :return: The loaded Savable instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) assert load_context.loader is not None # required for type checking try: - class_name = Savable._get_class_name(saved_state) + class_name = SaveUtil.get_class_name(saved_state) load_cls: Savable = load_context.loader.load_object(class_name) except KeyError: raise ValueError('Class name not found in saved state') @@ -392,22 +404,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') - - -def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() - else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable - - return wrapped - - -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -429,7 +426,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -448,45 +445,10 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' - - _auto_persist: Optional[Set[str]] = None - _persist_configured = False - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) - - @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Recreate a :class:`Savable` from a saved state using an optional load context. - - :param saved_state: The saved state - :param load_context: An optional load context - - :return: The recreated instance - - """ - ... - - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - - return out_state - - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self._persist_configured = True - - # region Metadata getter/setters - +class SaveUtil: @staticmethod def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) user_dict[name] = value @staticmethod @@ -497,47 +459,127 @@ def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: return out_state.setdefault(META, {}) @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) type_dict[name] = type_spec @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: return saved_state[META][META__TYPES][name] except KeyError: pass - # endregion - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] +@runtime_checkable +class Savable(Protocol): + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + ... + + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... + + +@runtime_checkable +class SavableWithAutoPersist(Savable, Protocol): + _auto_persist: ClassVar[set[str]] = set() + + +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} + + if save_context is None: + save_context = LoadSaveContext() + + utils.type_check(save_context, LoadSaveContext) + + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) + + if isinstance(obj, SavableWithAutoPersist): + for member in obj._auto_persist: + value = getattr(obj, member) + if inspect.ismethod(value): + if value.__self__ is not obj: + raise TypeError('Cannot persist methods of other classes') + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) + value = value.__name__ + elif isinstance(value, Savable) and not isinstance(value, type): + # persist for a savable obj, call `save` method of obj. + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) + value = value.save() + else: + value = copy.deepcopy(value) + out_state[member] = value + + return out_state + + +def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: + for member in obj._auto_persist: + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) + - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = load(value, load_context) +def _get_value( + obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None +) -> MethodType | Savable: + value = saved_state[name] - return value + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value + + +def auto_persist(*members: str) -> Callable[..., Savable]: + def wrapped(savable_cls: type) -> Savable: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() # type: ignore[attr-defined] + else: + savable_cls._auto_persist = set(savable_cls._auto_persist) + savable_cls._auto_persist.update(members) # type: ignore[attr-defined] + # XXX: validate on `save` and `recreate_from` method?? + return cast(Savable, savable_cls) + return wrapped + + +# FIXME: move me to another module? savablefuture.py? @auto_persist('_state', '_result') -class SavableFuture(futures.Future, Savable): +class SavableFuture(futures.Future): """ A savable future. @@ -562,7 +604,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -598,48 +640,3 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa # ## UNTILHERE XXX: return obj - - -def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = LoadSaveContext() - - utils.type_check(save_context, LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - Savable._set_class_name(out_state, loader.identify_object(obj.__class__)) - - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - value = getattr(obj, member) - if inspect.ismethod(value): - if value.__self__ is not obj: - raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - return out_state - - -def auto_load(obj: Savable, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - obj._ensure_persist_configured() - if obj._auto_persist is not None: - for member in obj._auto_persist: - setattr(obj, member, obj._get_value(saved_state, member, load_context)) diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 166a811a..2ec07751 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,18 +2,23 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional +from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader + from . import persistence from .utils import SAVED_STATE_TYPE -from plumpy.persistence import LoadSaveContext, _ensure_object_loader __all__ = ['ProcessListener'] if TYPE_CHECKING: + from plumpy.persistence import Savable + from .processes import Process +# FIXME: test any process listener is a savable + @persistence.auto_persist('_params') -class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): +class ProcessListener(metaclass=abc.ABCMeta): # region Persistence methods def __init__(self) -> None: @@ -34,11 +39,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) obj.init(**saved_state['_params']) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + # endregion def on_process_created(self, process: 'Process') -> None: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 1d7f2350..337a3153 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy -import inspect import sys import traceback from enum import Enum @@ -13,21 +11,19 @@ Callable, ClassVar, Optional, - Protocol, Tuple, Type, Union, cast, final, - runtime_checkable, + override, ) import yaml from yaml.loader import Loader -from plumpy import loaders +from plumpy.persistence import ensure_object_loader from plumpy.process_comms import KillMessage, MessageType -from plumpy.persistence import _ensure_object_loader try: import tblib @@ -40,9 +36,6 @@ from .base import state_machine as st from .lang import NULL from .persistence import ( - META__OBJECT_LOADER, - META__TYPE__METHOD, - META__TYPE__SAVABLE, LoadSaveContext, Savable, auto_load, @@ -93,8 +86,26 @@ class PauseInterruption(Interruption): # region Commands -class Command(persistence.Savable): - pass +class Command: + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + obj = cls.__new__(cls) + auto_load(obj, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') @@ -140,12 +151,14 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs + @override def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ return out_state + @override @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -157,7 +170,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -188,14 +201,9 @@ class ProcessState(Enum): KILLED = 'killed' -# @runtime_checkable -# class Savable(Protocol): -# def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - - @final @auto_persist('args', 'kwargs') -class Created(persistence.Savable): +class Created: LABEL: ClassVar = ProcessState.CREATED ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} @@ -226,7 +234,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -249,7 +257,7 @@ def exit(self) -> None: ... @final @auto_persist('args', 'kwargs') -class Running(persistence.Savable): +class Running: LABEL: ClassVar = ProcessState.RUNNING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -297,7 +305,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -381,7 +389,7 @@ def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(persistence.Savable): +class Waiting: LABEL: ClassVar = ProcessState.WAITING ALLOWED: ClassVar = { ProcessState.RUNNING, @@ -435,7 +443,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -488,7 +496,8 @@ def exit(self) -> None: ... @final -class Excepted(persistence.Savable): +@auto_persist() +class Excepted: """ Excepted state, can optionally provide exception and traceback @@ -540,7 +549,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) @@ -573,7 +582,7 @@ def exit(self) -> None: ... @final @auto_persist('result', 'successful') -class Finished(persistence.Savable): +class Finished: """State for process is finished. :param result: The result of process @@ -600,11 +609,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... @@ -612,7 +626,7 @@ def exit(self) -> None: ... @final @auto_persist('msg') -class Killed(persistence.Savable): +class Killed: """ Represents a state where a process has been killed. @@ -644,11 +658,16 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) auto_load(obj, saved_state, load_context) return obj + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + def enter(self) -> None: ... def exit(self) -> None: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 96689024..53723493 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -33,7 +33,7 @@ cast, ) -from plumpy.persistence import _ensure_object_loader +from plumpy.persistence import ensure_object_loader try: from aiocontextvars import ContextVar @@ -125,7 +125,7 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: '_pre_paused_status', '_event_helper', ) -class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): +class Process(StateMachine, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -274,7 +274,7 @@ def recreate_from( :return: An instance of the object with its state loaded from the save state. """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -681,8 +681,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - # FIXME: the combined ProcessState protocol should cover the case - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index cf7ad81f..66418861 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from __future__ import annotations -import copy import abc import asyncio import collections @@ -9,6 +8,7 @@ import logging import re from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -26,16 +26,17 @@ import kiwipy +from plumpy import utils from plumpy.base import state_machine from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError +from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict -from plumpy import loaders, utils -from plumpy.persistence import _ensure_object_loader + __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -165,41 +166,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA :param out_state: A bundle to save the state to :param save_context: The save context """ - out_state: SAVED_STATE_TYPE = {} - - if save_context is None: - save_context = persistence.LoadSaveContext() + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) - utils.type_check(save_context, persistence.LoadSaveContext) - - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - persistence.Savable.set_custom_meta(out_state, persistence.META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader - - persistence.Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - - self._ensure_persist_configured() - if self._auto_persist is not None: - for member in self._auto_persist: - value = getattr(self, member) - if inspect.ismethod(value): - if value.__self__ is not self: - raise TypeError('Cannot persist methods of other classes') - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__METHOD) - value = value.__name__ - elif isinstance(value, persistence.Savable): - persistence.Savable._set_meta_type(out_state, member, persistence.META__TYPE__SAVABLE) - value = value.save() - else: - value = copy.deepcopy(value) - out_state[member] = value - - if isinstance(self._state, process_states.Savable): + if isinstance(self._state, persistence.Savable): out_state['_state'] = self._state.save() # Inputs/outputs @@ -213,7 +182,7 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself - if self._stepper is not None: + if self._stepper is not None and isinstance(self._stepper, Savable): out_state[self._STEPPER_STATE] = self._stepper.save() if self._context is not None: @@ -235,7 +204,7 @@ def recreate_from( """ ### FIXME: dup from process.create_from - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) proc = cls.__new__(cls) # XXX: load_instance_state @@ -378,7 +347,8 @@ def get_description(self) -> Any: """ -class _FunctionStepper(persistence.Savable): +@auto_persist() +class _FunctionStepper: def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): self._workchain = workchain self._fn = fn @@ -390,7 +360,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -400,7 +372,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -446,7 +418,7 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(persistence.Savable): +class _BlockStepper: def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: self._workchain = workchain self._block = block @@ -491,7 +463,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -594,7 +566,7 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(persistence.Savable): +class _IfStepper: def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: self._workchain = workchain self._if_instruction = if_instruction @@ -646,7 +618,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -717,7 +689,7 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(persistence.Savable): +class _WhileStepper: def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: self._workchain = workchain self._while_instruction = while_instruction @@ -747,7 +719,9 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA return out_state @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -757,7 +731,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) obj = cls.__new__(cls) persistence.auto_load(obj, saved_state, load_context) obj._workchain = load_context.workchain @@ -804,7 +778,8 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(persistence.Savable): +@persistence.auto_persist() +class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: self._workchain = workchain self._return_instruction = return_instruction diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 65ef3226..4ec4c1a5 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,17 @@ import yaml import plumpy -from plumpy.persistence import auto_load +from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.utils import SAVED_STATE_TYPE from . import utils -class SaveEmpty(plumpy.Savable): - pass +@auto_persist() +class SaveEmpty: @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -28,9 +29,14 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test', 'test_method') -class Save1(plumpy.Savable): +class Save1: def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -39,7 +45,7 @@ def m(): pass @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -53,14 +59,19 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') -class Save(plumpy.Savable): +class Save: def __init__(self): self.test = Save1() @classmethod - def recreate_from(cls, saved_state, load_context= None): + def recreate_from(cls, saved_state, load_context=None): """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -74,6 +85,11 @@ def recreate_from(cls, saved_state, load_context= None): auto_load(obj, saved_state, load_context) return obj + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self): diff --git a/tests/test_processes.py b/tests/test_processes.py index 7fa33bb1..8c15cf9a 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -14,6 +14,10 @@ from plumpy.utils import AttributesFrozendict from tests import utils +# FIXME: after deabstract on savable into a protocol, test that all state are savable +# FIXME: also that any process is savable +# FIXME: any process listener is savable +# FIXME: any process control commands are savable class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..4e34d2b4 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -11,6 +11,8 @@ from . import utils +# FIXME: after deabstract on savable into a protocol, test that all stepper are savable +# FIXME: workchani itself is savable class Wf(WorkChain): # Keep track of which steps were completed by the workflow From 3e6a2dd6f1851bc186d4647353e482fa846c9fd9 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Dec 2024 00:54:27 +0100 Subject: [PATCH 29/29] Make auto_load symmetry with auto_save and state/state_label distinguish --- src/plumpy/base/state_machine.py | 10 +++- src/plumpy/event_helper.py | 3 +- src/plumpy/persistence.py | 19 ++++++- src/plumpy/process_states.py | 42 ++++++---------- src/plumpy/processes.py | 79 +++++++++++++++-------------- src/plumpy/workchains.py | 24 ++++----- tests/base/test_statemachine.py | 10 ++-- tests/rmq/test_process_comms.py | 10 ++-- tests/test_persistence.py | 14 +++--- tests/test_processes.py | 85 ++++++++++++++++---------------- tests/utils.py | 2 +- 11 files changed, 153 insertions(+), 145 deletions(-) diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index fc926008..9e7ca122 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -266,7 +266,13 @@ def create_initial_state(self, *args: Any, **kwargs: Any) -> State: return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) @property - def state(self) -> Any: + def state(self) -> State | None: + if self._state is None: + return None + return self._state + + @property + def state_label(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -312,7 +318,7 @@ def transition_to(self, new_state: State | None, **kwargs: Any) -> None: if new_state is None: return None - initial_state_label = self._state.LABEL if self._state is not None else None + initial_state_label = self.state_label label = None try: self._transitioning = True diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index abc2b24b..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -45,8 +45,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 44d812d1..3b333edb 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -20,6 +20,7 @@ List, Optional, Protocol, + TypeVar, cast, runtime_checkable, ) @@ -535,6 +536,8 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S value = value.__name__ elif isinstance(value, Savable) and not isinstance(value, type): # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: @@ -544,11 +547,25 @@ def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> S return out_state -def auto_load(obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: for member in obj._auto_persist: setattr(obj, member, _get_value(obj, saved_state, member, load_context)) +T = TypeVar('T', bound=Savable) + + +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) + + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) + + return obj + + def _get_value( obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None ) -> MethodType | Savable: diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 337a3153..1a176b9b 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -22,7 +22,6 @@ import yaml from yaml.loader import Loader -from plumpy.persistence import ensure_object_loader from plumpy.process_comms import KillMessage, MessageType try: @@ -41,6 +40,7 @@ auto_load, auto_persist, auto_save, + ensure_object_loader, ) from .utils import SAVED_STATE_TYPE @@ -98,8 +98,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -171,15 +171,15 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) - obj.state_machine = load_context.process try: obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - obj.continue_fn = getattr(process, saved_state[obj.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise return obj @@ -235,12 +235,8 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) return obj @@ -306,15 +302,12 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process - obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) if obj.COMMAND in saved_state: - # FIXME: typing obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + return obj def interrupt(self, reason: Any) -> None: @@ -444,9 +437,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) - + obj = auto_load(cls, saved_state, load_context) obj.process = load_context.process callback_name = saved_state.get(obj.DONE_CALLBACK, None) @@ -550,8 +541,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: @@ -610,8 +600,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: @@ -659,8 +648,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 53723493..8b8107d4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -20,6 +20,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -175,6 +176,7 @@ class Process(StateMachine, metaclass=ProcessStateMachineMeta): _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -294,7 +296,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -303,7 +305,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -527,7 +529,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -537,12 +541,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -552,7 +556,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -563,25 +567,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -591,7 +595,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -600,7 +604,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal + return self.has_terminated() # endregion @@ -628,7 +632,7 @@ def callback_excepted( exception: Optional[BaseException], trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -681,8 +685,8 @@ def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SA """ out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if isinstance(self._state, persistence.Savable): - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -740,7 +744,7 @@ def on_entering(self, state: state_machine.State) -> None: def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -752,21 +756,21 @@ def on_entered(self, from_state: Optional[state_machine.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._communicator and isinstance(self.state, enum.Enum): + if self._communicator and isinstance(self.state_label, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) except (ConnectionClosed, ChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) except kiwipy.TimeoutError: message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -1069,7 +1073,6 @@ def transition_failed( if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) self.transition_to(new_state) @@ -1095,9 +1098,9 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: - if not isinstance(self._state, Interruptable): + if not isinstance(self.state, Interruptable): raise exceptions.InvalidStateError( - f'cannot interrupt {self._state.__class__}, method `interrupt` not implement' + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' ) # Ask the step function to pause by setting this flag and giving the @@ -1106,7 +1109,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) return self._do_pause(msg) @@ -1189,7 +1192,7 @@ def play(self) -> bool: @event(from_states=(process_states.Waiting)) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: @@ -1207,7 +1210,7 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1219,13 +1222,13 @@ def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future] # Already killing return self._killing - if self._stepping and isinstance(self._state, Interruptable): + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) @@ -1294,14 +1297,14 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused - if not isinstance(self._state, Proceedable): - raise StateMachineError(f'cannot step from {self._state.__class__}, async method `execute` not implemented') + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') try: self._stepping = True next_state = None try: - next_state = await self._run_task(self._state.execute) + next_state = await self._run_task(self.state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should # be an interrupt action ready to be executed, so just check if the cookie matches diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 66418861..5caf1882 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -8,7 +8,6 @@ import logging import re from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -31,13 +30,12 @@ from plumpy.base.utils import call_with_super_check from plumpy.event_helper import EventHelper from plumpy.exceptions import InvalidStateError -from plumpy.persistence import LoadSaveContext, auto_persist, auto_save, ensure_object_loader, Savable +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader from plumpy.process_listener import ProcessListener from . import lang, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict - __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] ToContext = dict @@ -224,7 +222,7 @@ def recreate_from( else: proc._loop = asyncio.get_event_loop() - proc._state: state_machine.State = proc.recreate_state(saved_state['_state']) + proc._state = proc.recreate_state(saved_state['_state']) if 'communicator' in load_context: proc._communicator = load_context.communicator @@ -233,7 +231,7 @@ def recreate_from( proc._logger = load_context.logger # Need to call this here as things downstream may rely on us having the runtime variable above - persistence.auto_load(proc, saved_state, load_context) + persistence.load_auto_persist_params(proc, saved_state, load_context) # Inputs/outputs try: @@ -373,8 +371,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) @@ -447,7 +444,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -464,8 +461,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -602,7 +598,7 @@ def finished(self) -> bool: def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - if self._child_stepper is not None: + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() return out_state @@ -619,8 +615,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) @@ -732,8 +727,7 @@ def recreate_from( """ load_context = ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - persistence.auto_load(obj, saved_state, load_context) + obj = persistence.auto_load(cls, saved_state, load_context) obj._workchain = load_context.workchain obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 6a61fe00..15a218ce 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -150,22 +150,22 @@ def stop(self): class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index c6826a24..307bfdb7 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -67,7 +67,7 @@ async def test_play(self, thread_communicator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -84,7 +84,7 @@ async def test_kill(self, thread_communicator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): @@ -172,7 +172,7 @@ async def test_play(self, thread_communicator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): @@ -186,7 +186,7 @@ async def test_kill(self, thread_communicator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, thread_communicator, sync_controller): @@ -199,7 +199,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 4ec4c1a5..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,7 +5,7 @@ import yaml import plumpy -from plumpy.persistence import auto_load, auto_persist, auto_save +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader from plumpy.utils import SAVED_STATE_TYPE from . import utils @@ -25,8 +25,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -55,8 +55,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: @@ -81,8 +81,8 @@ def recreate_from(cls, saved_state, load_context=None): :return: The recreated instance """ - obj = cls.__new__(cls) - auto_load(obj, saved_state, load_context) + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) return obj def save(self, save_context=None) -> SAVED_STATE_TYPE: diff --git a/tests/test_processes.py b/tests/test_processes.py index 8c15cf9a..a62bbd8d 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -19,6 +19,7 @@ # FIXME: any process listener is savable # FIXME: any process control commands are savable + class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): super().__init__() @@ -239,7 +240,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -277,7 +278,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -330,7 +331,7 @@ def test_kill(self): proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_wait_continue(self): proc = utils.WaitForSignalProcess() @@ -344,7 +345,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -368,7 +369,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -384,7 +385,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -405,7 +406,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -425,7 +426,7 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): @@ -443,7 +444,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -455,7 +456,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -479,7 +480,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -555,7 +556,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -573,7 +574,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -784,7 +785,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -799,7 +800,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -822,7 +823,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -864,7 +865,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -987,40 +988,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/utils.py b/tests/utils.py index 88638e01..be8f2a5e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -468,7 +468,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = plumpy.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: