diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 2f354987..e1d6d969 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator # unavailable forward references py:class plumpy.process_states.Command -py:class plumpy.process_states.State +py:class plumpy.state_machine.State py:class plumpy.base.state_machine.State py:class State py:class Process diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index c1fdb3b2..af1ed795 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -281,7 +281,7 @@ " def continue_fn(self):\n", " print('continuing')\n", " # message is stored in the process status\n", - " return plumpy.Kill('I was killed')\n", + " return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n", "\n", "\n", "process = ContinueProcess()\n", @@ -1118,7 +1118,7 @@ "\n", "process = SimpleProcess(communicator=communicator)\n", "\n", - "pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())" + "pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())" ] }, { diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index 6f94b5bf..5aa23401 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -9,7 +9,6 @@ from .exceptions import * from .futures import * from .loaders import * -from .mixins import * from .persistence import * from .ports import * from .process_comms import * @@ -25,7 +24,6 @@ + processes.__all__ + utils.__all__ + futures.__all__ - + mixins.__all__ + persistence.__all__ + communications.__all__ + process_comms.__all__ diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index d99d0705..9e7ca122 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The state machine for processes""" +from __future__ import annotations + import enum import functools import inspect @@ -8,7 +10,21 @@ import os import sys from types import TracebackType -from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Hashable, + Iterable, + List, + Optional, + Protocol, + Sequence, + Type, + Union, + runtime_checkable, +) from plumpy.futures import Future @@ -18,7 +34,6 @@ _LOGGER = logging.getLogger(__name__) -LABEL_TYPE = Union[None, enum.Enum, str] EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None] @@ -31,7 +46,7 @@ class StateEntryFailed(Exception): # noqa: N818 Failed to enter a state, can provide the next state to go to via this exception """ - def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None: + def __init__(self, state: State, *args: Any, **kwargs: Any) -> None: super().__init__('failed to enter state') self.state = state self.args = args @@ -74,12 +89,12 @@ def event( if from_states != '*': if inspect.isclass(from_states): from_states = (from_states,) - if not all(issubclass(state, State) for state in from_states): # type: ignore + if not all(isinstance(state, State) for state in from_states): # type: ignore raise TypeError(f'from_states: {from_states}') if to_states != '*': if inspect.isclass(to_states): to_states = (to_states,) - if not all(issubclass(state, State) for state in to_states): # type: ignore + if not all(isinstance(state, State) for state in to_states): # type: ignore raise TypeError(f'to_states: {to_states}') def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]: @@ -113,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: return wrapper -class State: - LABEL: LABEL_TYPE = None - # A set containing the labels of states that can be entered - # from this one - ALLOWED: Set[LABEL_TYPE] = set() +@runtime_checkable +class State(Protocol): + LABEL: ClassVar[Any] + ALLOWED: ClassVar[set[Any]] + is_terminal: ClassVar[bool] - @classmethod - def is_terminal(cls) -> bool: - return not cls.ALLOWED + def __init__(self, *args: Any, **kwargs: Any): ... - def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any): - """ - :param state_machine: The process this state belongs to - """ - self.state_machine = state_machine - self.in_state: bool = False + def enter(self) -> None: ... - def __str__(self) -> str: - return str(self.LABEL) + def exit(self) -> None: ... - @property - def label(self) -> LABEL_TYPE: - """Convenience property to get the state label""" - return self.LABEL - @super_check - def enter(self) -> None: - """Entering the state""" +@runtime_checkable +class Interruptable(Protocol): + def interrupt(self, reason: Exception) -> None: ... - def execute(self) -> Optional['State']: + +@runtime_checkable +class Proceedable(Protocol): + def execute(self) -> State | None: """ Execute the state, performing the actions that this state is responsible for. :returns: a state to transition to or None if finished. """ + ... - @super_check - def exit(self) -> None: - """Exiting the state""" - if self.is_terminal(): - raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') - - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State': - return self.state_machine.create_state(state_label, *args, **kwargs) - def do_enter(self) -> None: - call_with_super_check(self.enter) - self.in_state = True +def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State: + if state_label not in st.get_states_map(): + raise ValueError(f'{state_label} is not a valid state') - def do_exit(self) -> None: - call_with_super_check(self.exit) - self.in_state = False + state_cls = st.get_states_map()[state_label] + return state_cls(*args, **kwargs) class StateEventHook(enum.Enum): @@ -187,7 +185,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': :param kwargs: Any keyword arguments to be passed to the constructor :return: An instance of the state machine """ - inst = super().__call__(*args, **kwargs) + inst: StateMachine = super().__call__(*args, **kwargs) inst.transition_to(inst.create_initial_state()) call_with_super_check(inst.init) return inst @@ -214,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]: raise RuntimeError('States not defined') @classmethod - def initial_state_label(cls) -> LABEL_TYPE: + def initial_state_label(cls) -> Any: cls.__ensure_built() assert cls.STATES is not None return cls.STATES[0].LABEL @classmethod - def get_state_class(cls, label: LABEL_TYPE) -> Type[State]: + def get_state_class(cls, label: Any) -> Type[State]: cls.__ensure_built() assert cls._STATES_MAP is not None return cls._STATES_MAP[label] @@ -240,7 +238,7 @@ def __ensure_built(cls) -> None: # Build the states map cls._STATES_MAP = {} for state_cls in cls.STATES: - assert issubclass(state_cls, State) + assert isinstance(state_cls, State) label = state_cls.LABEL assert label not in cls._STATES_MAP, f"Duplicate label '{label}'" cls._STATES_MAP[label] = state_cls @@ -264,11 +262,17 @@ def init(self) -> None: def __str__(self) -> str: return f'<{self.__class__.__name__}> ({self.state})' - def create_initial_state(self) -> State: - return self.get_state_class(self.initial_state_label())(self) + def create_initial_state(self, *args: Any, **kwargs: Any) -> State: + return self.get_state_class(self.initial_state_label())(self, *args, **kwargs) + + @property + def state(self) -> State | None: + if self._state is None: + return None + return self._state @property - def state(self) -> Optional[LABEL_TYPE]: + def state_label(self) -> Any: if self._state is None: return None return self._state.LABEL @@ -300,16 +304,24 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: def on_terminated(self) -> None: """Called when a terminal state is entered""" - def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: + def transition_to(self, new_state: State | None, **kwargs: Any) -> None: + """Transite to the new state. + + The new target state will be create lazily when the state is not yet instantiated, + which will happened for states not in the expect path such as pause and kill. + The arguments are passed to the state class to create state instance. + (process arg does not need to pass since it will always call with 'self' as process) + """ + print(f'try: {self._state} -> {new_state}') assert not self._transitioning, 'Cannot call transition_to when already transitioning state' - initial_state_label = self._state.LABEL if self._state is not None else None + if new_state is None: + return None + + initial_state_label = self.state_label label = None try: self._transitioning = True - - # Make sure we have a state instance - new_state = self._create_state_instance(new_state, *args, **kwargs) label = new_state.LABEL # If the previous transition failed, do not try to exit it but go straight to next state @@ -319,13 +331,12 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A try: self._enter_next_state(new_state) except StateEntryFailed as exception: - # Make sure we have a state instance - new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs) + new_state = exception.state label = new_state.LABEL self._exit_current_state(new_state) self._enter_next_state(new_state) - if self._state is not None and self._state.is_terminal(): + if self._state is not None and self._state.is_terminal: call_with_super_check(self.on_terminated) except Exception: self._transitioning = False @@ -338,7 +349,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A self._transitioning = False def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: """Called when a state transitions fails. @@ -354,49 +369,25 @@ def get_debug(self) -> bool: def set_debug(self, enabled: bool) -> None: self._debug: bool = enabled - def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State: - try: - return self.get_states_map()[state_label](self, *args, **kwargs) - except KeyError: - raise ValueError(f'{state_label} is not a valid state') - def _exit_current_state(self, next_state: State) -> None: """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state if self._state is None: - if next_state.label != self.initial_state_label(): + if next_state.LABEL != self.initial_state_label(): raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state") return # Nothing to exit if next_state.LABEL not in self._state.ALLOWED: - raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}') + raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}') self._fire_state_event(StateEventHook.EXITING_STATE, next_state) - self._state.do_exit() + self._state.exit() def _enter_next_state(self, next_state: State) -> None: last_state = self._state self._fire_state_event(StateEventHook.ENTERING_STATE, next_state) # Enter the new state - next_state.do_enter() + next_state.enter() self._state = next_state self._fire_state_event(StateEventHook.ENTERED_STATE, last_state) - - def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State: - if isinstance(state, State): - # It's already a state instance - return state - - # OK, have to create it - state_cls = self._ensure_state_class(state) - return state_cls(self, *args, **kwargs) - - def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]: - if inspect.isclass(state) and issubclass(state, State): - return state - - try: - return self.get_states_map()[cast(Hashable, state)] - except KeyError: - raise ValueError(f'{state} is not a valid state') diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 47ad4956..9262f856 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- import logging -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Optional + +from plumpy.persistence import LoadSaveContext, Savable, auto_load, auto_save, ensure_object_loader +from plumpy.utils import SAVED_STATE_TYPE from . import persistence @@ -13,7 +16,7 @@ @persistence.auto_persist('_listeners', '_listener_type') -class EventHelper(persistence.Savable): +class EventHelper: def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' @@ -30,6 +33,26 @@ def remove_listener(self, listener: 'ProcessListener') -> None: def remove_all_listeners(self) -> None: self._listeners.clear() + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> Savable: + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @property def listeners(self) -> 'Set[ProcessListener]': return self._listeners diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py deleted file mode 100644 index 10142eb7..00000000 --- a/src/plumpy/mixins.py +++ /dev/null @@ -1,44 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Any, Optional - -from . import persistence -from .utils import SAVED_STATE_TYPE, AttributesDict - -__all__ = ['ContextMixin'] - - -class ContextMixin(persistence.Savable): - """ - Add a context to a Process. The contents of the context will be saved - in the instance state unlike standard instance variables. - """ - - CONTEXT: str = '_context' - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._context: Optional[AttributesDict] = AttributesDict() - - @property - def ctx(self) -> Optional[AttributesDict]: - return self._context - - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - """Add the instance state to ``out_state``. - .. important:: - - The instance state will contain a pointer to the ``ctx``, - and so should be deep copied or serialised before persisting. - """ - super().save_instance_state(out_state, save_context) - if self._context is not None: - out_state[self.CONTEXT] = self._context.__dict__ - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - try: - self._context = AttributesDict(**saved_state[self.CONTEXT]) - except KeyError: - pass diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index ba755bc5..3b333edb 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -9,12 +9,25 @@ import os import pickle from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Optional, Set, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Iterable, + List, + Optional, + Protocol, + TypeVar, + cast, + runtime_checkable, +) import yaml from . import futures, loaders, utils -from .base.utils import call_with_super_check, super_check from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ @@ -35,8 +48,33 @@ from .processes import Process +class LoadSaveContext: + def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: + self._values = dict(**kwargs) + self.loader = loader + + def __getattr__(self, item: str) -> Any: + try: + return self._values[item] + except KeyError: + raise AttributeError(f"item '{item}' not found") + + def __iter__(self) -> Iterable[Any]: + return self._value.__iter__() + + def __contains__(self, item: Any) -> bool: + return self._values.__contains__(item) + + def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': + """Add additional information to the context by making a copy with the new values""" + extended = self._values.copy() + extended.update(kwargs) + loader = extended.pop('loader', self.loader) + return LoadSaveContext(loader=loader, **extended) + + class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): + def __init__(self, savable: 'Savable', save_context: LoadSaveContext | None = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the class loader that can be used to load the classes in the bundle. @@ -52,7 +90,7 @@ class loader that can be used to load the classes in the bundle. else: self.update(savable.save(save_context)) - def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable': + def unbundle(self, load_context: LoadSaveContext | None = None) -> 'Savable': """ This method loads the class of the object and calls its recreate_from method passing the positional and keyword arguments. @@ -61,7 +99,29 @@ def unbundle(self, load_context: Optional['LoadSaveContext'] = None) -> 'Savable :return: An instance of the Savable """ - return Savable.load(self, load_context) + return load(self, load_context) + + +def load(saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': + """ + Load a `Savable` from a saved instance state. The load context is a way of passing + runtime data to the object being loaded. + + :param saved_state: The saved state + :param load_context: Additional runtime state that can be passed into when loading. + The type and content (if any) is completely user defined + :return: The loaded Savable instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + assert load_context.loader is not None # required for type checking + try: + class_name = SaveUtil.get_class_name(saved_state) + load_cls: Savable = load_context.loader.load_object(class_name) + except KeyError: + raise ValueError('Class name not found in saved state') + else: + return load_cls.recreate_from(saved_state, load_context) _BUNDLE_TAG = '!plumpy:Bundle' @@ -345,22 +405,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: del self._checkpoints[pid] -SavableClsType = TypeVar('SavableClsType', bound='type[Savable]') - - -def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: - if savable._auto_persist is None: - savable._auto_persist = set() - else: - savable._auto_persist = set(savable._auto_persist) - savable.auto_persist(*members) - return savable - - return wrapped - - -def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': +def ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAVED_STATE_TYPE) -> 'LoadSaveContext': """ Given a LoadSaveContext this method will ensure that it has a valid class loader using the following priorities: @@ -382,7 +427,7 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV # 2) Try getting from saved_state default_loader = loaders.get_object_loader() try: - loader_identifier = Savable.get_custom_meta(saved_state, META__OBJECT_LOADER) + loader_identifier = SaveUtil.get_custom_meta(saved_state, META__OBJECT_LOADER) except ValueError: # 3) Fall back to default loader = default_loader @@ -392,31 +437,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV return context.copyextend(loader=loader) -class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: - self._values = dict(**kwargs) - self.loader = loader - - def __getattr__(self, item: str) -> Any: - try: - return self._values[item] - except KeyError: - raise AttributeError(f"item '{item}' not found") - - def __iter__(self) -> Iterable[Any]: - return self._value.__iter__() - - def __contains__(self, item: Any) -> bool: - return self._values.__contains__(item) - - def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """Add additional information to the context by making a copy with the new values""" - extended = self._values.copy() - extended.update(kwargs) - loader = extended.pop('loader', self.loader) - return LoadSaveContext(loader=loader, **extended) - - META: str = '!!meta' META__CLASS_NAME: str = 'class_name' META__OBJECT_LOADER: str = 'object_loader' @@ -426,46 +446,48 @@ def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': META__TYPE__SAVABLE: str = 'S' -class Savable: - CLASS_NAME: str = 'class_name' +class SaveUtil: + @staticmethod + def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: + user_dict = SaveUtil.get_create_meta(out_state).setdefault(META__USER, {}) + user_dict[name] = value - _auto_persist: Optional[Set[str]] = None - _persist_configured = False + @staticmethod + def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: + try: + return saved_state[META][name] + except KeyError: + raise ValueError(f"Unknown meta key '{name}'") @staticmethod - def load(saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': - """ - Load a `Savable` from a saved instance state. The load context is a way of passing - runtime data to the object being loaded. + def get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: + return out_state.setdefault(META, {}) - :param saved_state: The saved state - :param load_context: Additional runtime state that can be passed into when loading. - The type and content (if any) is completely user defined - :return: The loaded Savable instance + @staticmethod + def set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: + SaveUtil.get_create_meta(out_state)[META__CLASS_NAME] = name - """ - load_context = _ensure_object_loader(load_context, saved_state) - assert load_context.loader is not None # required for type checking + @staticmethod + def get_class_name(saved_state: SAVED_STATE_TYPE) -> str: + return SaveUtil.get_create_meta(saved_state)[META__CLASS_NAME] + + @staticmethod + def set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: + type_dict = SaveUtil.get_create_meta(out_state).setdefault(META__TYPES, {}) + type_dict[name] = type_spec + + @staticmethod + def get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: try: - class_name = Savable._get_class_name(saved_state) - load_cls = load_context.loader.load_object(class_name) + return saved_state[META][META__TYPES][name] except KeyError: - raise ValueError('Class name not found in saved state') - else: - return load_cls.recreate_from(saved_state, load_context) - - @classmethod - def auto_persist(cls, *members: str) -> None: - if cls._auto_persist is None: - cls._auto_persist = set() - cls._auto_persist.update(members) + pass - @classmethod - def persist(cls) -> None: - pass +@runtime_checkable +class Savable(Protocol): @classmethod - def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None = None) -> 'Savable': """ Recreate a :class:`Savable` from a saved state using an optional load context. @@ -475,137 +497,119 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) - obj = cls.__new__(cls) - call_with_super_check(obj.load_instance_state, saved_state, load_context) - return obj + ... - @super_check - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.load_members(self._auto_persist, saved_state, load_context) + def save(self, save_context: LoadSaveContext | None = None) -> SAVED_STATE_TYPE: ... - @super_check - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: Optional[LoadSaveContext]) -> None: - self._ensure_persist_configured() - if self._auto_persist is not None: - self.save_members(self._auto_persist, out_state) - def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: - out_state: SAVED_STATE_TYPE = {} +@runtime_checkable +class SavableWithAutoPersist(Savable, Protocol): + _auto_persist: ClassVar[set[str]] = set() - if save_context is None: - save_context = LoadSaveContext() - utils.type_check(save_context, LoadSaveContext) +def auto_save(obj: Savable, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = {} - default_loader = loaders.get_object_loader() - # If the user has specified a class loader, then save it in the saved state - if save_context.loader is not None: - loader_class = default_loader.identify_object(save_context.loader.__class__) - Savable.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) - loader = save_context.loader - else: - loader = default_loader + if save_context is None: + save_context = LoadSaveContext() - Savable._set_class_name(out_state, loader.identify_object(self.__class__)) - call_with_super_check(self.save_instance_state, out_state, save_context) - return out_state + utils.type_check(save_context, LoadSaveContext) - def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> None: - for member in members: - value = getattr(self, member) + default_loader = loaders.get_object_loader() + # If the user has specified a class loader, then save it in the saved state + if save_context.loader is not None: + loader_class = default_loader.identify_object(save_context.loader.__class__) + SaveUtil.set_custom_meta(out_state, META__OBJECT_LOADER, loader_class) + loader = save_context.loader + else: + loader = default_loader + + SaveUtil.set_class_name(out_state, loader.identify_object(obj.__class__)) + + if isinstance(obj, SavableWithAutoPersist): + for member in obj._auto_persist: + value = getattr(obj, member) if inspect.ismethod(value): - if value.__self__ is not self: + if value.__self__ is not obj: raise TypeError('Cannot persist methods of other classes') - Savable._set_meta_type(out_state, member, META__TYPE__METHOD) + SaveUtil.set_meta_type(out_state, member, META__TYPE__METHOD) value = value.__name__ - elif isinstance(value, Savable): - Savable._set_meta_type(out_state, member, META__TYPE__SAVABLE) + elif isinstance(value, Savable) and not isinstance(value, type): + # persist for a savable obj, call `save` method of obj. + # the rhs branch is for when value is a Savable class, it is true runtime check + # of lhs condition. + SaveUtil.set_meta_type(out_state, member, META__TYPE__SAVABLE) value = value.save() else: value = copy.deepcopy(value) out_state[member] = value - def load_members( - self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None - ) -> None: - for member in members: - setattr(self, member, self._get_value(saved_state, member, load_context)) + return out_state - def _ensure_persist_configured(self) -> None: - if not self._persist_configured: - self.persist() - self._persist_configured = True - # region Metadata getter/setters +def load_auto_persist_params( + obj: SavableWithAutoPersist, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None +) -> None: + for member in obj._auto_persist: + setattr(obj, member, _get_value(obj, saved_state, member, load_context)) - @staticmethod - def set_custom_meta(out_state: SAVED_STATE_TYPE, name: str, value: Any) -> None: - user_dict = Savable._get_create_meta(out_state).setdefault(META__USER, {}) - user_dict[name] = value - @staticmethod - def get_custom_meta(saved_state: SAVED_STATE_TYPE, name: str) -> Any: - try: - return saved_state[META][name] - except KeyError: - raise ValueError(f"Unknown meta key '{name}'") +T = TypeVar('T', bound=Savable) - @staticmethod - def _get_create_meta(out_state: SAVED_STATE_TYPE) -> Dict[str, Any]: - return out_state.setdefault(META, {}) - @staticmethod - def _set_class_name(out_state: SAVED_STATE_TYPE, name: str) -> None: - Savable._get_create_meta(out_state)[META__CLASS_NAME] = name +def auto_load(cls: type[T], saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext | None) -> T: + obj = cls.__new__(cls) - @staticmethod - def _get_class_name(saved_state: SAVED_STATE_TYPE) -> str: - return Savable._get_create_meta(saved_state)[META__CLASS_NAME] + if isinstance(obj, SavableWithAutoPersist): + load_auto_persist_params(obj, saved_state, load_context) - @staticmethod - def _set_meta_type(out_state: SAVED_STATE_TYPE, name: str, type_spec: Any) -> None: - type_dict = Savable._get_create_meta(out_state).setdefault(META__TYPES, {}) - type_dict[name] = type_spec + return obj - @staticmethod - def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: - try: - return saved_state[META][META__TYPES][name] - except KeyError: - pass - # endregion +def _get_value( + obj: Any, saved_state: SAVED_STATE_TYPE, name: str, load_context: LoadSaveContext | None +) -> MethodType | Savable: + value = saved_state[name] + + typ = SaveUtil.get_meta_type(saved_state, name) + if typ == META__TYPE__METHOD: + value = getattr(obj, value) + elif typ == META__TYPE__SAVABLE: + value = load(value, load_context) + + return value - def _get_value( - self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] - ) -> Union[MethodType, 'Savable']: - value = saved_state[name] - typ = Savable._get_meta_type(saved_state, name) - if typ == META__TYPE__METHOD: - value = getattr(self, value) - elif typ == META__TYPE__SAVABLE: - value = Savable.load(value, load_context) +def auto_persist(*members: str) -> Callable[..., Savable]: + def wrapped(savable_cls: type) -> Savable: + if not hasattr(savable_cls, '_auto_persist') or savable_cls._auto_persist is None: + savable_cls._auto_persist = set() # type: ignore[attr-defined] + else: + savable_cls._auto_persist = set(savable_cls._auto_persist) + + savable_cls._auto_persist.update(members) # type: ignore[attr-defined] + # XXX: validate on `save` and `recreate_from` method?? + return cast(Savable, savable_cls) - return value + return wrapped +# FIXME: move me to another module? savablefuture.py? @auto_persist('_state', '_result') -class SavableFuture(futures.Future, Savable): +class SavableFuture(futures.Future): """ A savable future. .. note: This does not save any assigned done callbacks. """ - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) if self.done() and self.exception() is not None: out_state['exception'] = self.exception() + return out_state + @classmethod def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': """ @@ -617,7 +621,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa :return: The recreated instance """ - load_context = _ensure_object_loader(load_context, saved_state) + load_context = ensure_object_loader(load_context, saved_state) try: loop = load_context.loop @@ -643,11 +647,13 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[Loa obj = cls(loop=loop) obj.cancel() - return obj + # ## XXX: load_instance_state: test not cover + # auto_load(obj, saved_state, load_context) + # + # if obj._callbacks: + # # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list + # for callback in obj._callbacks: + # obj.remove_done_callback(callback) # type: ignore[arg-type] + # ## UNTILHERE XXX: - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - if self._callbacks: - # typing says asyncio.Future._callbacks needs to be called, but in the python 3.7 code it is a simple list - for callback in self._callbacks: - self.remove_done_callback(callback) # type: ignore[arg-type] + return obj diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index 293c680b..cd6e7238 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" +from __future__ import annotations + import asyncio -import copy import logging from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast @@ -12,13 +13,13 @@ from .utils import PID_TYPE __all__ = [ - 'KILL_MSG', - 'PAUSE_MSG', - 'PLAY_MSG', - 'STATUS_MSG', + 'KillMessage', + 'PauseMessage', + 'PlayMessage', 'ProcessLauncher', 'RemoteProcessController', 'RemoteProcessThreadController', + 'StatusMessage', 'create_continue_body', 'create_launch_body', ] @@ -31,6 +32,7 @@ INTENT_KEY = 'intent' MESSAGE_KEY = 'message' +FORCE_KILL_KEY = 'force_kill' class Intent: @@ -42,10 +44,45 @@ class Intent: STATUS: str = 'status' -PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} -PLAY_MSG = {INTENT_KEY: Intent.PLAY} -KILL_MSG = {INTENT_KEY: Intent.KILL} -STATUS_MSG = {INTENT_KEY: Intent.STATUS} +MessageType = Dict[str, Any] + + +class PlayMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PLAY, + MESSAGE_KEY: message, + } + + +class PauseMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.PAUSE, + MESSAGE_KEY: message, + } + + +class KillMessage: + @classmethod + def build(cls, message: str | None = None, force: bool = False) -> MessageType: + return { + INTENT_KEY: Intent.KILL, + MESSAGE_KEY: message, + FORCE_KILL_KEY: force, + } + + +class StatusMessage: + @classmethod + def build(cls, message: str | None = None) -> MessageType: + return { + INTENT_KEY: Intent.STATUS, + MESSAGE_KEY: message, + } + TASK_KEY = 'task' TASK_ARGS = 'args' @@ -162,7 +199,7 @@ async def get_status(self, pid: 'PID_TYPE') -> 'ProcessStatus': :param pid: the process id :return: the status response from the process """ - future = self._communicator.rpc_send(pid, STATUS_MSG) + future = self._communicator.rpc_send(pid, StatusMessage.build()) result = await asyncio.wrap_future(future) return result @@ -174,11 +211,9 @@ async def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pr :param msg: optional pause message :return: True if paused, False otherwise """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - pause_future = self._communicator.rpc_send(pid, message) + pause_future = self._communicator.rpc_send(pid, msg) # rpc_send return a thread future from communicator future = await asyncio.wrap_future(pause_future) # future is just returned from rpc call which return a kiwipy future @@ -192,12 +227,12 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult': :param pid: the pid of the process to play :return: True if played, False otherwise """ - play_future = self._communicator.rpc_send(pid, PLAY_MSG) + play_future = self._communicator.rpc_send(pid, PlayMessage.build()) future = await asyncio.wrap_future(play_future) result = await asyncio.wrap_future(future) return result - async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult': + async def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> 'ProcessResult': """ Kill the process @@ -205,12 +240,11 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro :param msg: optional kill message :return: True if killed, False otherwise """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = KillMessage.build() # Wait for the communication to go through - kill_future = self._communicator.rpc_send(pid, message) + kill_future = self._communicator.rpc_send(pid, msg) future = await asyncio.wrap_future(kill_future) # Now wait for the kill to be enacted result = await asyncio.wrap_future(future) @@ -331,7 +365,7 @@ def get_status(self, pid: 'PID_TYPE') -> kiwipy.Future: :param pid: the process id :return: the status response from the process """ - return self._communicator.rpc_send(pid, STATUS_MSG) + return self._communicator.rpc_send(pid, StatusMessage.build()) def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: """ @@ -342,11 +376,9 @@ def pause_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fu :return: a response future from the process to be paused """ - message = copy.copy(PAUSE_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + msg = PauseMessage.build(message=msg) - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) def pause_all(self, msg: Any) -> None: """ @@ -364,7 +396,7 @@ def play_process(self, pid: 'PID_TYPE') -> kiwipy.Future: :return: a response future from the process to be played """ - return self._communicator.rpc_send(pid, PLAY_MSG) + return self._communicator.rpc_send(pid, PlayMessage.build()) def play_all(self) -> None: """ @@ -372,7 +404,7 @@ def play_all(self) -> None: """ self._communicator.broadcast_send(None, subject=Intent.PLAY) - def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future: + def kill_process(self, pid: 'PID_TYPE', msg: Optional[MessageType] = None) -> kiwipy.Future: """ Kill the process @@ -381,18 +413,20 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Fut :return: a response future from the process to be killed """ - message = copy.copy(KILL_MSG) - if msg is not None: - message[MESSAGE_KEY] = msg + if msg is None: + msg = KillMessage.build() - return self._communicator.rpc_send(pid, message) + return self._communicator.rpc_send(pid, msg) - def kill_all(self, msg: Optional[Any]) -> None: + def kill_all(self, msg: Optional[MessageType]) -> None: """ Kill all processes that are subscribed to the same communicator :param msg: an optional pause message """ + if msg is None: + msg = KillMessage.build() + self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 8e1acf94..2ec07751 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -2,17 +2,23 @@ import abc from typing import TYPE_CHECKING, Any, Dict, Optional +from plumpy.persistence import LoadSaveContext, auto_save, ensure_object_loader + from . import persistence -from .utils import SAVED_STATE_TYPE, protected +from .utils import SAVED_STATE_TYPE __all__ = ['ProcessListener'] if TYPE_CHECKING: + from plumpy.persistence import Savable + from .processes import Process +# FIXME: test any process listener is a savable + @persistence.auto_persist('_params') -class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): +class ProcessListener(metaclass=abc.ABCMeta): # region Persistence methods def __init__(self) -> None: @@ -22,12 +28,26 @@ def __init__(self) -> None: def init(self, **kwargs: Any) -> None: self._params = kwargs - @protected - def load_instance_state( - self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().load_instance_state(saved_state, load_context) - self.init(**saved_state['_params']) + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = cls.__new__(cls) + obj.init(**saved_state['_params']) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state # endregion diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 7ae6e9bd..1a176b9b 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,13 +1,29 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import sys import traceback from enum import Enum from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Optional, + Tuple, + Type, + Union, + cast, + final, + override, +) import yaml from yaml.loader import Loader +from plumpy.process_comms import KillMessage, MessageType + try: import tblib @@ -16,9 +32,16 @@ _HAS_TBLIB = False from . import exceptions, futures, persistence, utils -from .base import state_machine +from .base import state_machine as st from .lang import NULL -from .persistence import auto_persist +from .persistence import ( + LoadSaveContext, + Savable, + auto_load, + auto_persist, + auto_save, + ensure_object_loader, +) from .utils import SAVED_STATE_TYPE __all__ = [ @@ -48,7 +71,12 @@ class Interruption(Exception): # noqa: N818 class KillInterruption(Interruption): - pass + def __init__(self, msg: MessageType | None): + super().__init__() + if msg is None: + msg = KillMessage.build() + + self.msg: MessageType = msg class PauseInterruption(Interruption): @@ -58,13 +86,31 @@ class PauseInterruption(Interruption): # region Commands -class Command(persistence.Savable): - pass +class Command: + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): + def __init__(self, msg: Optional[MessageType] = None): super().__init__() self.msg = msg @@ -76,7 +122,10 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): def __init__( - self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None + self, + continue_fn: Optional[Callable[..., Any]] = None, + msg: Optional[Any] = None, + data: Optional[Any] = None, ): super().__init__() self.continue_fn = continue_fn @@ -102,17 +151,36 @@ def __init__(self, continue_fn: Callable[..., Any], *args: Any, **kwargs: Any): self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + @override + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state[self.CONTINUE_FN] = self.continue_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + return out_state + + @override + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + try: - self.continue_fn = utils.load_function(saved_state[self.CONTINUE_FN]) + obj.continue_fn = utils.load_function(saved_state[obj.CONTINUE_FN]) except ValueError: - process = load_context.process - self.continue_fn = getattr(process, saved_state[self.CONTINUE_FN]) + if load_context is not None: + obj.continue_fn = getattr(load_context.proc, saved_state[obj.CONTINUE_FN]) + else: + raise + return obj # endregion @@ -125,61 +193,69 @@ class ProcessState(Enum): The possible states that a :class:`~plumpy.processes.Process` can be in. """ - CREATED: str = 'created' - RUNNING: str = 'running' - WAITING: str = 'waiting' - FINISHED: str = 'finished' - EXCEPTED: str = 'excepted' - KILLED: str = 'killed' - - -@auto_persist('in_state') -class State(state_machine.State, persistence.Savable): - @property - def process(self) -> state_machine.StateMachine: - """ - :return: The process - """ - return self.state_machine - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.state_machine = load_context.process - - def interrupt(self, reason: Any) -> None: - pass + CREATED = 'created' + RUNNING = 'running' + WAITING = 'waiting' + FINISHED = 'finished' + EXCEPTED = 'excepted' + KILLED = 'killed' +@final @auto_persist('args', 'kwargs') -class Created(State): - LABEL = ProcessState.CREATED - ALLOWED = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} +class Created: + LABEL: ClassVar = ProcessState.CREATED + ALLOWED: ClassVar = {ProcessState.RUNNING, ProcessState.KILLED, ProcessState.EXCEPTED} RUN_FN = 'run_fn' + is_terminal: ClassVar[bool] = False def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) out_state[self.RUN_FN] = self.run_fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance - def execute(self) -> state_machine.State: - return self.create_state(ProcessState.RUNNING, self.run_fn, *self.args, **self.kwargs) + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + + return obj + + def execute(self) -> st.State: + return st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.run_fn, *self.args, **self.kwargs + ) + + def enter(self) -> None: ... + def exit(self) -> None: ... + +@final @auto_persist('args', 'kwargs') -class Running(State): - LABEL = ProcessState.RUNNING - ALLOWED = { +class Running: + LABEL: ClassVar = ProcessState.RUNNING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, @@ -195,30 +271,49 @@ class Running(State): _running: bool = False _run_handle = None + is_terminal: ClassVar[bool] = False + def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: - super().__init__(process) assert run_fn is not None + self.process = process self.run_fn = run_fn self.args = args self.kwargs = kwargs self._run_handle = None - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.RUN_FN] = self.run_fn.__name__ if self._command is not None: out_state[self.COMMAND] = self._command.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) - if self.COMMAND in saved_state: - self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + obj.run_fn = getattr(obj.process, saved_state[obj.RUN_FN]) + if obj.COMMAND in saved_state: + obj._command = persistence.load(saved_state[obj.COMMAND], load_context) # type: ignore + + return obj def interrupt(self, reason: Any) -> None: pass - async def execute(self) -> State: # type: ignore + def execute(self) -> st.State: if self._command is not None: command = self._command else: @@ -232,8 +327,10 @@ async def execute(self) -> State: # type: ignore # Let this bubble up to the caller raise except Exception: - excepted = self.create_state(ProcessState.EXCEPTED, *sys.exc_info()[1:]) - return cast(State, excepted) + _, exception, traceback = sys.exc_info() + # excepted = state_cls(exception=exception, traceback=traceback) + excepted = Excepted(exception=exception, traceback=traceback) + return excepted else: if not isinstance(result, Command): if isinstance(result, exceptions.UnsuccessfulResult): @@ -242,32 +339,52 @@ async def execute(self) -> State: # type: ignore # Got passed a basic return type result = Stop(result, True) - command = result + command = cast(Stop, result) next_state = self._action_command(command) return next_state - def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: + def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> st.State: if isinstance(command, Kill): - state = self.create_state(ProcessState.KILLED, command.msg) + state = st.create_state(self.process, ProcessState.KILLED, msg=command.msg) # elif isinstance(command, Pause): # self.pause() elif isinstance(command, Stop): - state = self.create_state(ProcessState.FINISHED, command.result, command.successful) + state = st.create_state( + self.process, ProcessState.FINISHED, result=command.result, successful=command.successful + ) elif isinstance(command, Wait): - state = self.create_state(ProcessState.WAITING, command.continue_fn, command.msg, command.data) + state = st.create_state( + self.process, + ProcessState.WAITING, + process=self.process, + done_callback=command.continue_fn, + msg=command.msg, + data=command.data, + ) elif isinstance(command, Continue): - state = self.create_state(ProcessState.RUNNING, command.continue_fn, *command.args) + state = st.create_state( + self.process, + ProcessState.RUNNING, + process=self.process, + run_fn=command.continue_fn, + *command.args, + **command.kwargs, + ) else: raise ValueError('Unrecognised command') - return cast(State, state) # casting from base.State to process.State + return state + + def enter(self) -> None: ... + + def exit(self) -> None: ... @auto_persist('msg', 'data') -class Waiting(State): - LABEL = ProcessState.WAITING - ALLOWED = { +class Waiting: + LABEL: ClassVar = ProcessState.WAITING + ALLOWED: ClassVar = { ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, @@ -279,6 +396,8 @@ class Waiting(State): _interruption = None + is_terminal: ClassVar[bool] = False + def __str__(self) -> str: state_info = super().__str__() if self.msg is not None: @@ -292,31 +411,48 @@ def __init__( msg: Optional[str] = None, data: Optional[Any] = None, ) -> None: - super().__init__(process) + self.process = process self.done_callback = done_callback self.msg = msg self.data = data self._waiting_future: futures.Future = futures.Future() - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + if self.done_callback is not None: out_state[self.DONE_CALLBACK] = self.done_callback.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - callback_name = saved_state.get(self.DONE_CALLBACK, None) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + obj.process = load_context.process + + callback_name = saved_state.get(obj.DONE_CALLBACK, None) if callback_name is not None: - self.done_callback = getattr(self.process, callback_name) + obj.done_callback = getattr(obj.process, callback_name) else: - self.done_callback = None - self._waiting_future = futures.Future() + obj.done_callback = None + obj._waiting_future = futures.Future() + return obj - def interrupt(self, reason: Any) -> None: + def interrupt(self, reason: Exception) -> None: # This will cause the future in execute() to raise the exception self._waiting_future.set_exception(reason) - async def execute(self) -> State: # type: ignore + async def execute(self) -> st.State: try: result = await self._waiting_future except Interruption: @@ -327,11 +463,15 @@ async def execute(self) -> State: # type: ignore raise if result == NULL: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, run_fn=self.done_callback + ) else: - next_state = self.create_state(ProcessState.RUNNING, self.done_callback, result) + next_state = st.create_state( + self.process, ProcessState.RUNNING, process=self.process, done_callback=self.done_callback, *result + ) - return cast(State, next_state) # casting from base.State to process.State + return next_state def resume(self, value: Any = NULL) -> None: assert self._waiting_future is not None, 'Not yet waiting' @@ -341,75 +481,184 @@ def resume(self, value: Any = NULL) -> None: self._waiting_future.set_result(value) + def enter(self) -> None: ... -class Excepted(State): - LABEL = ProcessState.EXCEPTED + def exit(self) -> None: ... + + +@final +@auto_persist() +class Excepted: + """ + Excepted state, can optionally provide exception and traceback + + :param exception: The exception instance + :param traceback: An optional exception traceback + """ + + LABEL: ClassVar = ProcessState.EXCEPTED + ALLOWED: ClassVar[set[str]] = set() EXC_VALUE = 'ex_value' TRACEBACK = 'traceback' + is_terminal: ClassVar = True + def __init__( - self, process: 'Process', exception: Optional[BaseException], trace_back: Optional[TracebackType] = None + self, + exception: Optional[BaseException], + traceback: Optional[TracebackType] = None, ): """ - :param process: The associated process :param exception: The exception instance - :param trace_back: An optional exception traceback + :param traceback: An optional exception traceback """ - super().__init__(process) self.exception = exception - self.traceback = trace_back + self.traceback = traceback def __str__(self) -> str: exception = traceback.format_exception_only(type(self.exception) if self.exception else None, self.exception)[0] return super().__str__() + f'({exception})' - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + out_state[self.EXC_VALUE] = yaml.dump(self.exception) if self.traceback is not None: out_state[self.TRACEBACK] = ''.join(traceback.format_tb(self.traceback)) - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + + obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) + obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False) except KeyError: - self.traceback = None + obj.traceback = None else: - self.traceback = None + obj.traceback = None + return obj - def get_exc_info(self) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: + def get_exc_info( + self, + ) -> Tuple[Optional[Type[BaseException]], Optional[BaseException], Optional[TracebackType]]: """ Recreate the exc_info tuple and return it """ - return type(self.exception) if self.exception else None, self.exception, self.traceback + return ( + type(self.exception) if self.exception else None, + self.exception, + self.traceback, + ) + + def enter(self) -> None: ... + def exit(self) -> None: ... + +@final @auto_persist('result', 'successful') -class Finished(State): - LABEL = ProcessState.FINISHED +class Finished: + """State for process is finished. + + :param result: The result of process + :param successful: Boolean for the exit code is ``0`` the process is successful. + """ + + LABEL: ClassVar = ProcessState.FINISHED + ALLOWED: ClassVar[set[str]] = set() + + is_terminal: ClassVar[bool] = True - def __init__(self, process: 'Process', result: Any, successful: bool) -> None: - super().__init__(process) + def __init__(self, result: Any, successful: bool) -> None: self.result = result self.successful = successful + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + return out_state + + def enter(self) -> None: ... + + def exit(self) -> None: ... + + +@final @auto_persist('msg') -class Killed(State): - LABEL = ProcessState.KILLED +class Killed: + """ + Represents a state where a process has been killed. + + This state is used to indicate that a process has been terminated and can optionally + include a message providing details about the termination. + + :param msg: An optional message explaining the reason for the process termination. + """ + + LABEL: ClassVar = ProcessState.KILLED + ALLOWED: ClassVar[set[str]] = set() - def __init__(self, process: 'Process', msg: Optional[str]): + is_terminal: ClassVar[bool] = True + + def __init__(self, msg: Optional[MessageType]): """ - :param process: The associated process :param msg: Optional kill message - """ - super().__init__(process) self.msg = msg + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context: Optional[LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + + def enter(self) -> None: ... + + def exit(self) -> None: ... + # endregion diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index ba7967d3..8b8107d4 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- """The main Process module""" +from __future__ import annotations + import abc import asyncio import contextlib @@ -18,6 +20,7 @@ Any, Awaitable, Callable, + ClassVar, Dict, Generator, Hashable, @@ -26,10 +29,13 @@ Sequence, Tuple, Type, + TypeVar, Union, cast, ) +from plumpy.persistence import ensure_object_loader + try: from aiocontextvars import ContextVar except ModuleNotFoundError: @@ -39,15 +45,36 @@ import yaml from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed -from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils +from . import ( + events, + exceptions, + futures, + persistence, + ports, + process_comms, + process_states, + utils, +) from .base import state_machine -from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event +from .base.state_machine import ( + Interruptable, + Proceedable, + StateEntryFailed, + StateMachine, + StateMachineError, + TransitionFailed, + create_state, + event, +) from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .process_comms import MESSAGE_KEY, KillMessage, MessageType from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected +T = TypeVar('T') + __all__ = ['BundleKeys', 'Process', 'ProcessSpec', 'TransitionFailed'] _LOGGER = logging.getLogger(__name__) @@ -58,7 +85,7 @@ class BundleKeys: """ String keys used by the process to save its state in the state bundle. - See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. + See :meth:`plumpy.processes.Process.save` and :meth:`plumpy.processes.Process.recreate_from`. """ @@ -91,9 +118,15 @@ def func_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: @persistence.auto_persist( - '_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper' + '_pid', + '_creation_time', + '_future', + '_paused', + '_status', + '_pre_paused_status', + '_event_helper', ) -class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta): +class Process(StateMachine, metaclass=ProcessStateMachineMeta): """ The Process class is the base for any unit of work in plumpy. @@ -143,6 +176,7 @@ class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMe _cleanups: Optional[List[Callable[[], None]]] = None __called: bool = False + _auto_persist: ClassVar[set[str]] @classmethod def current(cls) -> Optional['Process']: @@ -158,7 +192,7 @@ def current(cls) -> Optional['Process']: return None @classmethod - def get_states(cls) -> Sequence[Type[process_states.State]]: + def get_states(cls) -> Sequence[Type[state_machine.State]]: """Return all allowed states of the process.""" state_classes = cls.get_state_classes() return ( @@ -167,7 +201,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: ) @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> dict[process_states.ProcessState, Type[state_machine.State]]: # A mapping of the State constants to the corresponding state class return { process_states.ProcessState.CREATED: process_states.Created, @@ -231,20 +265,69 @@ def get_description(cls) -> Dict[str, Any]: @classmethod def recreate_from( - cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None - ) -> 'Process': - """ - Recreate a process from a saved state, passing any positional and - keyword arguments on to load_instance_state + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> Process: + """Recreate a process from a saved state, passing any positional :param saved_state: The saved state to load from :param load_context: The load context to use :return: An instance of the object with its state loaded from the save state. """ - process = cast(Process, super().recreate_from(saved_state, load_context)) - call_with_super_check(process.init) - return process + load_context = ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state = proc.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + proc._communicator = load_context.communicator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.load_auto_persist_params(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} + + call_with_super_check(proc.init) + return proc def __init__( self, @@ -314,14 +397,21 @@ def init(self) -> None: identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except kiwipy.TimeoutError: - self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) + self.logger.exception( + 'Process<%s>: failed to register as a broadcast subscriber', + self.pid, + ) if not self._future.done(): def try_killing(future: futures.Future) -> None: if future.cancelled(): - if not self.kill('Killed by future being cancelled'): - self.logger.warning('Process<%s>: Failed to kill process on future cancel', self.pid) + msg = KillMessage.build(message='Killed by future being cancelled') + if not self.kill(msg): + self.logger.warning( + 'Process<%s>: Failed to kill process on future cancel', + self.pid, + ) self._future.add_done_callback(try_killing) @@ -329,10 +419,10 @@ def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( - cast(process_states.State, state) + cast(state_machine.State, state) ), state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( - cast(Optional[process_states.State], from_state) + cast(Optional[state_machine.State], from_state) ), state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } @@ -425,7 +515,13 @@ def launch( The process is started asynchronously, without blocking other task in the event loop. """ - process = process_class(inputs=inputs, pid=pid, logger=logger, loop=self.loop, communicator=self._communicator) + process = process_class( + inputs=inputs, + pid=pid, + logger=logger, + loop=self.loop, + communicator=self._communicator, + ) self.loop.create_task(process.step_until_terminated()) return process @@ -433,7 +529,9 @@ def launch( def has_terminated(self) -> bool: """Return whether the process was terminated.""" - return self._state.is_terminal() + if self.state is None: + raise exceptions.InvalidStateError('process is not in state None that is invalid') + return self.state.is_terminal def result(self) -> Any: """ @@ -443,12 +541,12 @@ def result(self) -> Any: If in any other state this will raise an InvalidStateError. :return: The result of the process """ - if isinstance(self._state, process_states.Finished): - return self._state.result - if isinstance(self._state, process_states.Killed): - raise exceptions.KilledError(self._state.msg) - if isinstance(self._state, process_states.Excepted): - raise (self._state.exception or Exception('process excepted')) + if isinstance(self.state, process_states.Finished): + return self.state.result + if isinstance(self.state, process_states.Killed): + raise exceptions.KilledError(self.state.msg) + if isinstance(self.state, process_states.Excepted): + raise (self.state.exception or Exception('process excepted')) raise exceptions.InvalidStateError @@ -458,7 +556,7 @@ def successful(self) -> bool: Will raise if the process is not in the FINISHED state """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError as exception: raise exceptions.InvalidStateError('process is not in the finished state') from exception @@ -469,25 +567,25 @@ def is_successful(self) -> bool: :return: boolean, True if the process is in `Finished` state with `successful` attribute set to `True` """ try: - return self._state.successful # type: ignore + return self.state.successful # type: ignore except AttributeError: return False def killed(self) -> bool: """Return whether the process is killed.""" - return self.state == process_states.ProcessState.KILLED + return self.state_label == process_states.ProcessState.KILLED - def killed_msg(self) -> Optional[str]: + def killed_msg(self) -> Optional[MessageType]: """Return the killed message.""" - if isinstance(self._state, process_states.Killed): - return self._state.msg + if isinstance(self.state, process_states.Killed): + return self.state.msg raise exceptions.InvalidStateError('Has not been killed') def exception(self) -> Optional[BaseException]: """Return exception, if the process is terminated in excepted state.""" - if isinstance(self._state, process_states.Excepted): - return self._state.exception + if isinstance(self.state, process_states.Excepted): + return self.state.exception return None @@ -497,7 +595,7 @@ def is_excepted(self) -> bool: :return: boolean, True if the process is in ``EXCEPTED`` state. """ - return self.state == process_states.ProcessState.EXCEPTED + return self.state_label == process_states.ProcessState.EXCEPTED def done(self) -> bool: """Return True if the call was successfully killed or finished running. @@ -506,7 +604,7 @@ def done(self) -> bool: Use the `has_terminated` method instead """ warnings.warn('method is deprecated, use `has_terminated` instead', DeprecationWarning) - return self._state.is_terminal() + return self.has_terminated() # endregion @@ -529,9 +627,12 @@ def call_soon(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> return handle def callback_excepted( - self, _callback: Callable[..., Any], exception: Optional[BaseException], trace: Optional[TracebackType] + self, + _callback: Callable[..., Any], + exception: Optional[BaseException], + trace: Optional[TracebackType], ) -> None: - if self.state != process_states.ProcessState.EXCEPTED: + if self.state_label != process_states.ProcessState.EXCEPTED: self.fail(exception, trace) @contextlib.contextmanager @@ -555,7 +656,7 @@ def _process_scope(self) -> Generator[None, None, None]: stack_copy.pop() PROCESS_STACK.set(stack_copy) - async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + async def _run_task(self, callback: Callable[..., T], *args: Any, **kwargs: Any) -> T: """ This method should be used to run all Process related functions and coroutines. If there is an exception the process will enter the EXCEPTED state. @@ -575,18 +676,17 @@ async def _run_task(self, callback: Callable[..., Any], *args: Any, **kwargs: An # region Persistence - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: """ Ask the process to save its current instance state. :param out_state: A bundle to save the state to :param save_context: The save context """ - super().save_instance_state(out_state, save_context) + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) - out_state['_state'] = self._state.save() + if isinstance(self.state, persistence.Savable): + out_state['_state'] = self.state.save() # Inputs/outputs if self.raw_inputs is not None: @@ -598,61 +698,7 @@ def save_instance_state( if self.outputs: out_state[BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) - @protected - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - """Load the process from its saved instance state. - - :param saved_state: A bundle to load the state from - :param load_context: The load context - - """ - # First make sure the state machine constructor is called - super().__init__() - - self._setup_event_hooks() - - # Runtime variables, set initial states - self._future = persistence.SavableFuture() - self._event_helper = EventHelper(ProcessListener) - self._logger = None - self._communicator = None - - if 'loop' in load_context: - self._loop = load_context.loop - else: - self._loop = asyncio.get_event_loop() - - self._state: process_states.State = self.recreate_state(saved_state['_state']) - - if 'communicator' in load_context: - self._communicator = load_context.communicator - - if 'logger' in load_context: - self._logger = load_context.logger - - # Need to call this here as things downstream may rely on us having the runtime variable above - super().load_instance_state(saved_state, load_context) - - # Inputs/outputs - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_RAW]) - self._raw_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._raw_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.INPUTS_PARSED]) - self._parsed_inputs = utils.AttributesFrozendict(decoded) - except KeyError: - self._parsed_inputs = None - - try: - decoded = self.decode_input_args(saved_state[BundleKeys.OUTPUTS]) - self._outputs = decoded - except KeyError: - self._outputs = {} - - # endregion + return out_state def add_process_listener(self, listener: ProcessListener) -> None: """Add a process listener to the process. @@ -680,7 +726,7 @@ def log_with_pid(self, level: int, msg: str) -> None: # region Events - def on_entering(self, state: process_states.State) -> None: + def on_entering(self, state: state_machine.State) -> None: # Map these onto direct functions that the subclass can implement state_label = state.LABEL if state_label == process_states.ProcessState.CREATED: @@ -696,9 +742,9 @@ def on_entering(self, state: process_states.State) -> None: elif state_label == process_states.ProcessState.EXCEPTED: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore - def on_entered(self, from_state: Optional[process_states.State]) -> None: + def on_entered(self, from_state: Optional[state_machine.State]) -> None: # Map these onto direct functions that the subclass can implement - state_label = self._state.LABEL + state_label = self.state_label if state_label == process_states.ProcessState.RUNNING: call_with_super_check(self.on_running) elif state_label == process_states.ProcessState.WAITING: @@ -710,21 +756,21 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._communicator and isinstance(self.state, enum.Enum): + if self._communicator and isinstance(self.state_label, enum.Enum): from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None - subject = f'state_changed.{from_label}.{self.state.value}' + subject = f'state_changed.{from_label}.{self.state_label.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) except (ConnectionClosed, ChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) except kiwipy.TimeoutError: message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' - self.logger.warning(message, self.pid, from_label, self.state.value) + self.logger.warning(message, self.pid, from_label, self.state_label.value) def on_exiting(self) -> None: - state = self.state + state = self.state_label if state == process_states.ProcessState.WAITING: call_with_super_check(self.on_exit_waiting) elif state == process_states.ProcessState.RUNNING: @@ -828,7 +874,9 @@ def on_finish(self, result: Any, successful: bool) -> None: if successful: validation_error = self.spec().outputs.validate(self.outputs) if validation_error: - raise StateEntryFailed(process_states.ProcessState.FINISHED, result, False) + state_cls = self.get_states_map()[process_states.ProcessState.FINISHED] + finished_state = state_cls(result=result, successful=False) + raise StateEntryFailed(finished_state) self.future().set_result(self.outputs) @@ -857,10 +905,15 @@ def on_excepted(self) -> None: self._fire_event(ProcessListener.on_process_excepted, str(self.future().exception())) @super_check - def on_kill(self, msg: Optional[str]) -> None: + def on_kill(self, msg: Optional[MessageType]) -> None: """Entering the KILLED state.""" - self.set_status(msg) - self.future().set_exception(exceptions.KilledError(msg)) + if msg is None: + msg_txt = '' + else: + msg_txt = msg[MESSAGE_KEY] or '' + + self.set_status(msg_txt) + self.future().set_exception(exceptions.KilledError(msg_txt)) @super_check def on_killed(self) -> None: @@ -906,7 +959,12 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An :param msg: the message :return: the outcome of processing the message, the return value will be sent back as a response to the sender """ - self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg) + self.logger.debug( + "Process<%s>: received RPC message with communicator '%s': %r", + self.pid, + _comm, + msg, + ) intent = msg[process_comms.INTENT_KEY] @@ -915,7 +973,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An if intent == process_comms.Intent.PAUSE: return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None)) if intent == process_comms.Intent.KILL: - return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None)) + return self._schedule_rpc(self.kill, msg=msg) if intent == process_comms.Intent.STATUS: status_info: Dict[str, Any] = {} self.get_status_info(status_info) @@ -935,7 +993,11 @@ def broadcast_receive( """ self.logger.debug( - "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body + "Process<%s>: received broadcast message '%s' with communicator '%s': %r", + self.pid, + subject, + _comm, + body, ) # If we get a message we recognise then action it, otherwise ignore @@ -1001,13 +1063,18 @@ def close(self) -> None: # region State related methods def transition_failed( - self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType + self, + initial_state: Hashable, + final_state: Hashable, + exception: Exception, + trace: TracebackType, ) -> None: # If we are creating, then reraise instead of failing. if final_state == process_states.ProcessState.CREATED: raise exception.with_traceback(trace) - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace) + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=trace) + self.transition_to(new_state) def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.CancellableAction]: """Pause the process. @@ -1031,18 +1098,27 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._pausing if self._stepping: + if not isinstance(self.state, Interruptable): + raise exceptions.InvalidStateError( + f'cannot interrupt {self.state.__class__}, method `interrupt` not implement' + ) + # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.PauseInterruption(msg) self._set_interrupt_action_from_exception(interrupt_exception) self._pausing = self._interrupt_action # Try to interrupt the state - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) return self._do_pause(msg) - def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: + @staticmethod + def _interrupt(state: Interruptable, reason: Exception) -> None: + state.interrupt(reason) + + def _do_pause(self, state_msg: Optional[str], next_state: Optional[state_machine.State] = None) -> bool: """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: @@ -1068,11 +1144,13 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu if isinstance(exception, process_states.KillInterruption): - def do_kill(_next_state: process_states.State) -> Any: + def do_kill(_next_state: state_machine.State) -> Any: try: - # Ignore the next state - self.transition_to(process_states.ProcessState.KILLED, str(exception)) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=exception.msg) + self.transition_to(new_state) return True + # FIXME: if try block except, will hit deadlock in event loop + # need to know how to debug it, and where to set a timeout. finally: self._killing = None @@ -1114,23 +1192,25 @@ def play(self) -> bool: @event(from_states=(process_states.Waiting)) def resume(self, *args: Any) -> None: """Start running the process again.""" - return self._state.resume(*args) # type: ignore + return self.state.resume(*args) # type: ignore @event(to_states=process_states.Excepted) - def fail(self, exception: Optional[BaseException], trace_back: Optional[TracebackType]) -> None: + def fail(self, exception: Optional[BaseException], traceback: Optional[TracebackType]) -> None: """ Fail the process in response to an exception :param exception: The exception that caused the failure - :param trace_back: Optional exception traceback + :param traceback: Optional exception traceback """ - self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back) + # state_class = self.get_states_map()[process_states.ProcessState.EXCEPTED] + new_state = create_state(self, process_states.ProcessState.EXCEPTED, exception=exception, traceback=traceback) + self.transition_to(new_state) - def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: + def kill(self, msg: Optional[MessageType] = None) -> Union[bool, asyncio.Future]: """ Kill the process :param msg: An optional kill message """ - if self.state == process_states.ProcessState.KILLED: + if self.state_label == process_states.ProcessState.KILLED: # Already killed return True @@ -1142,16 +1222,17 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]: # Already killing return self._killing - if self._stepping: + if self._stepping and isinstance(self.state, Interruptable): # Ask the step function to pause by setting this flag and giving the # caller back a future interrupt_exception = process_states.KillInterruption(msg) self._set_interrupt_action_from_exception(interrupt_exception) self._killing = self._interrupt_action - self._state.interrupt(interrupt_exception) + self.state.interrupt(interrupt_exception) return cast(futures.CancellableAction, self._interrupt_action) - self.transition_to(process_states.ProcessState.KILLED, msg) + new_state = create_state(self, process_states.ProcessState.KILLED, msg=msg) + self.transition_to(new_state) return True @property @@ -1161,16 +1242,16 @@ def is_killing(self) -> bool: # endregion - def create_initial_state(self) -> process_states.State: + def create_initial_state(self) -> state_machine.State: """This method is here to override its superclass. Automatically enter the CREATED state when the process is created. :return: A Created state """ - return cast(process_states.State, self.get_state_class(process_states.ProcessState.CREATED)(self, self.run)) + return self.get_state_class(process_states.ProcessState.CREATED)(self, self.run) - def recreate_state(self, saved_state: persistence.Bundle) -> process_states.State: + def recreate_state(self, saved_state: persistence.Bundle) -> state_machine.State: """ Create a state object from a saved state @@ -1178,7 +1259,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat :return: An instance of the object with its state loaded from the save state. """ load_context = persistence.LoadSaveContext(process=self) - return cast(process_states.State, persistence.Savable.load(saved_state, load_context)) + return cast(state_machine.State, persistence.load(saved_state, load_context)) # endregion @@ -1216,11 +1297,14 @@ async def step(self) -> None: if self.paused and self._paused is not None: await self._paused + if not isinstance(self.state, Proceedable): + raise StateMachineError(f'cannot step from {self.state.__class__}, async method `execute` not implemented') + try: self._stepping = True next_state = None try: - next_state = await self._run_task(self._state.execute) + next_state = await self._run_task(self.state.execute) except process_states.Interruption as exception: # If the interruption was caused by a call to a Process method then there should # be an interrupt action ready to be executed, so just check if the cookie matches @@ -1236,7 +1320,9 @@ async def step(self) -> None: raise except Exception: # Overwrite the next state to go to excepted directly - next_state = self.create_state(process_states.ProcessState.EXCEPTED, *sys.exc_info()[1:]) + next_state = create_state( + self, process_states.ProcessState.EXCEPTED, exception=sys.exc_info()[1], traceback=sys.exc_info()[2] + ) self._set_interrupt_action(None) if self._interrupt_action: diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 748a44d7..5caf1882 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -11,11 +11,11 @@ Any, Callable, Dict, - Hashable, List, Mapping, MutableSequence, Optional, + Protocol, Sequence, Tuple, Type, @@ -25,8 +25,16 @@ import kiwipy -from . import lang, mixins, persistence, process_states, processes -from .utils import PID_TYPE, SAVED_STATE_TYPE +from plumpy import utils +from plumpy.base import state_machine +from plumpy.base.utils import call_with_super_check +from plumpy.event_helper import EventHelper +from plumpy.exceptions import InvalidStateError +from plumpy.persistence import LoadSaveContext, Savable, auto_persist, auto_save, ensure_object_loader +from plumpy.process_listener import ProcessListener + +from . import lang, persistence, process_states, processes +from .utils import PID_TYPE, SAVED_STATE_TYPE, AttributesDict __all__ = ['ToContext', 'WorkChain', 'WorkChainSpec', 'if_', 'return_', 'while_'] @@ -68,6 +76,7 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: return self._outline +# FIXME: better use composition here @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): """Overwrite the waiting state""" @@ -77,24 +86,14 @@ def __init__( process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, + data: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: - super().__init__(process, done_callback, msg, awaiting) + super().__init__(process, done_callback, msg, data) self._awaiting: Dict[asyncio.Future, str] = {} - for awaitable, key in (awaiting or {}).items(): + for awaitable, key in (data or {}).items(): resolved_awaitable = awaitable.future() if isinstance(awaitable, processes.Process) else awaitable self._awaiting[resolved_awaitable] = key - def enter(self) -> None: - super().enter() - for awaitable in self._awaiting: - awaitable.add_done_callback(self._awaitable_done) - - def exit(self) -> None: - super().exit() - for awaitable in self._awaiting: - awaitable.remove_done_callback(self._awaitable_done) - def _awaitable_done(self, awaitable: asyncio.Future) -> None: key = self._awaiting.pop(awaitable) try: @@ -105,8 +104,19 @@ def _awaitable_done(self, awaitable: asyncio.Future) -> None: if not self._awaiting: self._waiting_future.set_result(lang.NULL) + def enter(self) -> None: + for awaitable in self._awaiting: + awaitable.add_done_callback(self._awaitable_done) + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + for awaitable in self._awaiting: + awaitable.remove_done_callback(self._awaitable_done) + -class WorkChain(mixins.ContextMixin, processes.Process): +class WorkChain(processes.Process): """ A WorkChain is a series of instructions carried out with the ability to save state in between. @@ -114,10 +124,10 @@ class WorkChain(mixins.ContextMixin, processes.Process): _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' + CONTEXT = 'CONTEXT' @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: + def get_state_classes(cls) -> Dict[process_states.ProcessState, Type[state_machine.State]]: states_map = super().get_state_classes() states_map[process_states.ProcessState.WAITING] = Waiting return states_map @@ -131,9 +141,14 @@ def __init__( communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + self._context: Optional[AttributesDict] = AttributesDict() self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} + @property + def ctx(self) -> Optional[AttributesDict]: + return self._context + @classmethod def spec(cls) -> WorkChainSpec: return cast(WorkChainSpec, super().spec()) @@ -142,23 +157,118 @@ def on_create(self) -> None: super().on_create() self._stepper = self.spec().get_outline().create_stepper(self) - def save_instance_state( - self, out_state: SAVED_STATE_TYPE, save_context: Optional[persistence.LoadSaveContext] - ) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + """ + Ask the process to save its current instance state. + + :param out_state: A bundle to save the state to + :param save_context: The save context + """ + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + if isinstance(self._state, persistence.Savable): + out_state['_state'] = self._state.save() + + # Inputs/outputs + if self.raw_inputs is not None: + out_state[processes.BundleKeys.INPUTS_RAW] = self.encode_input_args(self.raw_inputs) + + if self.inputs is not None: + out_state[processes.BundleKeys.INPUTS_PARSED] = self.encode_input_args(self.inputs) + + if self.outputs: + out_state[processes.BundleKeys.OUTPUTS] = self.encode_input_args(self.outputs) # Ask the stepper to save itself - if self._stepper is not None: + if self._stepper is not None and isinstance(self._stepper, Savable): out_state[self._STEPPER_STATE] = self._stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) + if self._context is not None: + out_state[self.CONTEXT] = self._context.__dict__ + + return out_state + + @classmethod + def recreate_from( + cls, + saved_state: SAVED_STATE_TYPE, + load_context: Optional[persistence.LoadSaveContext] = None, + ) -> WorkChain: + """Recreate a workchain from a saved state, passing any positional + + :param saved_state: The saved state to load from + :param load_context: The load context to use + :return: An instance of the object with its state loaded from the save state. + + """ + ### FIXME: dup from process.create_from + load_context = ensure_object_loader(load_context, saved_state) + proc = cls.__new__(cls) + + # XXX: load_instance_state + # First make sure the state machine constructor is called + state_machine.StateMachine.__init__(proc) + + proc._setup_event_hooks() + + # Runtime variables, set initial states + proc._future = persistence.SavableFuture() + proc._event_helper = EventHelper(ProcessListener) + proc._logger = None + proc._communicator = None + + if 'loop' in load_context: + proc._loop = load_context.loop + else: + proc._loop = asyncio.get_event_loop() + + proc._state = proc.recreate_state(saved_state['_state']) + + if 'communicator' in load_context: + proc._communicator = load_context.communicator + + if 'logger' in load_context: + proc._logger = load_context.logger + + # Need to call this here as things downstream may rely on us having the runtime variable above + persistence.load_auto_persist_params(proc, saved_state, load_context) + + # Inputs/outputs + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_RAW]) + proc._raw_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._raw_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.INPUTS_PARSED]) + proc._parsed_inputs = utils.AttributesFrozendict(decoded) + except KeyError: + proc._parsed_inputs = None + + try: + decoded = proc.decode_input_args(saved_state[processes.BundleKeys.OUTPUTS]) + proc._outputs = decoded + except KeyError: + proc._outputs = {} + ### UNTILHERE FIXME: dup from process.create_from + + # context mixin + try: + proc._context = AttributesDict(**saved_state[proc.CONTEXT]) + except KeyError: + pass + + # end of context mixin # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) + proc._stepper = None + stepper_state = saved_state.get(proc._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + proc._stepper = proc.spec().get_outline().recreate_stepper(stepper_state, proc) + + call_with_super_check(proc.init) + return proc def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None: """ @@ -195,15 +305,8 @@ def _do_step(self) -> Any: return return_value -class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: - self._workchain = workchain - - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._workchain = load_context.workchain - - @abc.abstractmethod +# XXX: Stepper is also a Saver with `save` method. +class Stepper(Protocol): def step(self) -> Tuple[bool, Any]: """ Execute on step of the instructions. @@ -212,6 +315,7 @@ def step(self) -> Tuple[bool, Any]: 1. The return value from the executed step """ + ... class _Instruction(metaclass=abc.ABCMeta): @@ -241,18 +345,37 @@ def get_description(self) -> Any: """ -class _FunctionStepper(Stepper): +@auto_persist() +class _FunctionStepper: def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): - super().__init__(workchain) + self._workchain = workchain self._fn = fn - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) out_state['_fn'] = self._fn.__name__ - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._fn = getattr(self._workchain.__class__, saved_state['_fn']) + return out_state + + @classmethod + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._fn = getattr(obj._workchain.__class__, saved_state['_fn']) + + return obj def step(self) -> Tuple[bool, Any]: return True, self._fn(self._workchain) @@ -292,9 +415,9 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') -class _BlockStepper(Stepper): +class _BlockStepper: def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._block = block self._pos: int = 0 self._child_stepper: Optional[Stepper] = self._block[0].create_stepper(self._workchain) @@ -319,18 +442,34 @@ def next_instruction(self) -> None: def finished(self) -> bool: return self._pos == len(self._block) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) - if self._child_stepper is not None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._block = load_context.block_instruction + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._block = load_context.block_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._block[self._pos].recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._block[obj._pos].recreate_stepper(stepper_state, obj._workchain) + + return obj def __str__(self) -> str: return str(self._pos) + ':' + str(self._child_stepper) @@ -423,9 +562,9 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') -class _IfStepper(Stepper): +class _IfStepper: def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._if_instruction = if_instruction self._pos = 0 self._child_stepper: Optional[Stepper] = None @@ -457,18 +596,33 @@ def step(self) -> Tuple[bool, Any]: def finished(self) -> bool: return self._pos == len(self._if_instruction) - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) - if self._child_stepper is not None: + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None and isinstance(self._child_stepper, Savable): out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._if_instruction = load_context.if_instruction + return out_state + + @classmethod + def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._if_instruction = load_context.if_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._if_instruction[self._pos].body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._if_instruction[obj._pos].body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._if_instruction[self._pos]) @@ -530,9 +684,9 @@ def get_description(self) -> Mapping[str, Any]: return description -class _WhileStepper(Stepper): +class _WhileStepper: def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._while_instruction = while_instruction self._child_stepper: Optional[_BlockStepper] = None @@ -551,18 +705,36 @@ def step(self) -> Tuple[bool, Any]: return False, result - def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persistence.LoadSaveContext) -> None: - super().save_instance_state(out_state, save_context) + def save(self, save_context: Optional[persistence.LoadSaveContext] = None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = persistence.auto_save(self, save_context) + if self._child_stepper is not None: out_state[STEPPER_STATE] = self._child_stepper.save() - def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: - super().load_instance_state(saved_state, load_context) - self._while_instruction = load_context.while_instruction + return out_state + + @classmethod + def recreate_from( + cls, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext] = None + ) -> 'Savable': + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = persistence.auto_load(cls, saved_state, load_context) + obj._workchain = load_context.workchain + obj._while_instruction = load_context.while_instruction stepper_state = saved_state.get(STEPPER_STATE, None) - self._child_stepper = None + obj._child_stepper = None if stepper_state is not None: - self._child_stepper = self._while_instruction.body.recreate_stepper(stepper_state, self._workchain) + obj._child_stepper = obj._while_instruction.body.recreate_stepper(stepper_state, obj._workchain) + return obj def __str__(self) -> str: string = str(self._while_instruction) @@ -600,9 +772,10 @@ def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: self.exit_code = exit_code -class _ReturnStepper(Stepper): +@persistence.auto_persist() +class _ReturnStepper: def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: - super().__init__(workchain) + self._workchain = workchain self._return_instruction = return_instruction def step(self) -> Tuple[bool, Any]: diff --git a/tests/base/test_statemachine.py b/tests/base/test_statemachine.py index 5b4b73d8..15a218ce 100644 --- a/tests/base/test_statemachine.py +++ b/tests/base/test_statemachine.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import time +from typing import final import unittest from plumpy.base import state_machine +from plumpy.exceptions import InvalidStateError # Events PLAY = 'Play' @@ -15,31 +17,25 @@ STOPPED = 'Stopped' -class Playing(state_machine.State): +class Playing: LABEL = PLAYING ALLOWED = {PAUSED, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, track): assert track is not None, 'Must provide a track name' - super().__init__(player) self.track = track self._last_time = None self._played = 0.0 + self.in_state = False def __str__(self): if self.in_state: self._update_time() return f'> {self.track} ({self._played}s)' - def enter(self): - super().enter() - self._last_time = time.time() - - def exit(self): - super().exit() - self._update_time() - def play(self, track=None): # pylint: disable=no-self-use, unused-argument return False @@ -48,15 +44,28 @@ def _update_time(self): self._played += current_time - self._last_time self._last_time = current_time + def enter(self) -> None: + self._last_time = time.time() + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self._update_time() + self.in_state = False + -class Paused(state_machine.State): +class Paused: LABEL = PAUSED ALLOWED = {PLAYING, STOPPED} TRANSITIONS = {STOP: STOPPED} + is_terminal = False + def __init__(self, player, playing_state): assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' - super().__init__(player) + self._player = player self.playing_state = playing_state def __str__(self): @@ -64,23 +73,46 @@ def __str__(self): def play(self, track=None): if track is not None: - self.state_machine.transition_to(Playing, track=track) + self._player.transition_to(Playing(player=self.state_machine, track=track)) else: - self.state_machine.transition_to(self.playing_state) + self._player.transition_to(self.playing_state) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + self.in_state = False -class Stopped(state_machine.State): + +class Stopped: LABEL = STOPPED ALLOWED = { PLAYING, } TRANSITIONS = {PLAY: PLAYING} + is_terminal = False + + def __init__(self, player): + self._player = player + def __str__(self): return '[]' def play(self, track): - self.state_machine.transition_to(Playing, track=track) + self._player.transition_to(Playing(self._player, track=track)) + + def enter(self) -> None: + self.in_state = True + + def exit(self) -> None: + if self.is_terminal: + raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') + + self.in_state = False class CdPlayer(state_machine.StateMachine): @@ -107,33 +139,33 @@ def play(self, track=None): @state_machine.event(from_states=Playing, to_states=Paused) def pause(self): - self.transition_to(Paused, playing_state=self._state) + self.transition_to(Paused(self, playing_state=self._state)) return True @state_machine.event(from_states=(Playing, Paused), to_states=Stopped) def stop(self): - self.transition_to(Stopped) + self.transition_to(Stopped(self)) class TestStateMachine(unittest.TestCase): def test_basic(self): cd_player = CdPlayer() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) cd_player.play('Eminem - The Real Slim Shady') - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) time.sleep(1.0) cd_player.pause() - self.assertEqual(cd_player.state, PAUSED) + self.assertEqual(cd_player.state_label, PAUSED) cd_player.play() - self.assertEqual(cd_player.state, PLAYING) + self.assertEqual(cd_player.state_label, PLAYING) self.assertEqual(cd_player.play(), False) cd_player.stop() - self.assertEqual(cd_player.state, STOPPED) + self.assertEqual(cd_player.state_label, STOPPED) def test_invalid_event(self): cd_player = CdPlayer() diff --git a/tests/persistence/test_inmemory.py b/tests/persistence/test_inmemory.py index b0db46e7..9e3141de 100644 --- a/tests/persistence/test_inmemory.py +++ b/tests/persistence/test_inmemory.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- import unittest -from ..utils import ProcessWithCheckpoint - import plumpy -import plumpy +from ..utils import ProcessWithCheckpoint class TestInMemoryPersister(unittest.TestCase): diff --git a/tests/persistence/test_pickle.py b/tests/persistence/test_pickle.py index dd68b4fd..da4ede51 100644 --- a/tests/persistence/test_pickle.py +++ b/tests/persistence/test_pickle.py @@ -5,10 +5,10 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from ..utils import ProcessWithCheckpoint - import plumpy +from ..utils import ProcessWithCheckpoint + class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 7223b888..307bfdb7 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import asyncio -import copy import kiwipy import pytest @@ -68,7 +67,7 @@ async def test_play(self, thread_communicator, async_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.WAITING + assert proc.state_label == plumpy.ProcessState.WAITING # if not close the background process will raise exception # make sure proc reach the final state @@ -85,7 +84,7 @@ async def test_kill(self, thread_communicator, async_controller): # Check the outcome assert result - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): @@ -173,7 +172,7 @@ async def test_play(self, thread_communicator, sync_controller): # Check that all is as we expect assert result - assert proc.state == plumpy.ProcessState.CREATED + assert proc.state_label == plumpy.ProcessState.CREATED @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): @@ -187,7 +186,7 @@ async def test_kill(self, thread_communicator, sync_controller): # Check the outcome assert result # Occasionally fail - assert proc.state == plumpy.ProcessState.KILLED + assert proc.state_label == plumpy.ProcessState.KILLED @pytest.mark.asyncio async def test_kill_all(self, thread_communicator, sync_controller): @@ -196,12 +195,11 @@ async def test_kill_all(self, thread_communicator, sync_controller): for _ in range(10): procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) - msg = copy.copy(process_comms.KILL_MSG) - msg[process_comms.MESSAGE_KEY] = 'bang bang, I shot you down' + msg = process_comms.KillMessage.build(message='bang bang, I shot you down') sync_controller.kill_all(msg) await utils.wait_util(lambda: all([proc.killed() for proc in procs])) - assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) + assert all([proc.state_label == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): diff --git a/tests/test_expose.py b/tests/test_expose.py index 0f6f8087..c5e6014c 100644 --- a/tests/test_expose.py +++ b/tests/test_expose.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import unittest -from .utils import NewLoopProcess - from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process +from .utils import NewLoopProcess + def validator_function(input, port): pass diff --git a/tests/test_persistence.py b/tests/test_persistence.py index 78724aa0..7f616433 100644 --- a/tests/test_persistence.py +++ b/tests/test_persistence.py @@ -5,16 +5,38 @@ import yaml import plumpy +from plumpy.persistence import auto_load, auto_persist, auto_save, ensure_object_loader +from plumpy.utils import SAVED_STATE_TYPE from . import utils -class SaveEmpty(plumpy.Savable): - pass +@auto_persist() +class SaveEmpty: + + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state @plumpy.auto_persist('test', 'test_method') -class Save1(plumpy.Savable): +class Save1: def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -22,12 +44,52 @@ def __init__(self): def m(): pass + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + @plumpy.auto_persist('test') -class Save(plumpy.Savable): +class Save: def __init__(self): self.test = Save1() + @classmethod + def recreate_from(cls, saved_state, load_context=None): + """ + Recreate a :class:`Savable` from a saved state using an optional load context. + + :param saved_state: The saved state + :param load_context: An optional load context + + :return: The recreated instance + + """ + load_context = ensure_object_loader(load_context, saved_state) + obj = auto_load(cls, saved_state, load_context) + return obj + + def save(self, save_context=None) -> SAVED_STATE_TYPE: + out_state: SAVED_STATE_TYPE = auto_save(self, save_context) + + return out_state + class TestSavable(unittest.TestCase): def test_empty_savable(self): diff --git a/tests/test_process_comms.py b/tests/test_process_comms.py index c59737ac..44947230 100644 --- a/tests/test_process_comms.py +++ b/tests/test_process_comms.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import pytest -from tests import utils import plumpy from plumpy import process_comms +from tests import utils class Process(plumpy.Process): diff --git a/tests/test_processes.py b/tests/test_processes.py index faea9eae..a62bbd8d 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -2,18 +2,22 @@ """Process tests""" import asyncio -import copy import enum import unittest import kiwipy import pytest -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage from plumpy.utils import AttributesFrozendict +from tests import utils + +# FIXME: after deabstract on savable into a protocol, test that all state are savable +# FIXME: also that any process is savable +# FIXME: any process listener is savable +# FIXME: any process control commands are savable class ForgetToCallParent(plumpy.Process): @@ -236,7 +240,7 @@ def test_execute(self): proc.execute() self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertEqual(proc.outputs, {'default': 5}) def test_run_from_class(self): @@ -274,7 +278,7 @@ def test_exception(self): proc = utils.ExceptionProcess() with self.assertRaises(RuntimeError): proc.execute() - self.assertEqual(proc.state, ProcessState.EXCEPTED) + self.assertEqual(proc.state_label, ProcessState.EXCEPTED) def test_run_kill(self): proc = utils.KillProcess() @@ -323,12 +327,11 @@ def run(self, **kwargs): def test_kill(self): proc: Process = utils.DummyProcess() - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'Farewell!' + msg = KillMessage.build(message='Farewell!') proc.kill(msg) self.assertTrue(proc.killed()) self.assertEqual(proc.killed_msg(), msg) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_wait_continue(self): proc = utils.WaitForSignalProcess() @@ -342,7 +345,7 @@ def test_wait_continue(self): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_exc_info(self): proc = utils.ExceptionProcess() @@ -366,7 +369,7 @@ def test_wait_pause_play_resume(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause() self.assertTrue(result) @@ -382,7 +385,7 @@ async def async_test(): # Check it's done self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) @@ -403,7 +406,7 @@ def test_pause_play_status_messaging(self): async def async_test(): await utils.run_until_waiting(proc) - self.assertEqual(proc.state, ProcessState.WAITING) + self.assertEqual(proc.state_label, ProcessState.WAITING) result = await proc.pause(PAUSE_STATUS) self.assertTrue(result) @@ -423,15 +426,14 @@ async def async_test(): loop.run_until_complete(async_test()) self.assertTrue(proc.has_terminated()) - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) def test_kill_in_run(self): class KillProcess(Process): after_kill = False def run(self, **kwargs): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') self.kill(msg) # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state @@ -442,7 +444,7 @@ def run(self, **kwargs): proc.execute() self.assertTrue(proc.after_kill) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused_in_run(self): class PauseProcess(Process): @@ -454,7 +456,7 @@ def run(self, **kwargs): with self.assertRaises(plumpy.KilledError): proc.execute() - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_kill_when_paused(self): loop = asyncio.get_event_loop() @@ -478,7 +480,7 @@ async def async_test(): loop.create_task(proc.step_until_terminated()) loop.run_until_complete(async_test()) - self.assertEqual(proc.state, ProcessState.KILLED) + self.assertEqual(proc.state_label, ProcessState.KILLED) def test_run_multiple(self): # Create and play some processes @@ -554,7 +556,7 @@ def run(self): loop.run_forever() self.assertTrue(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_pause_play_in_process(self): """Test that we can pause and play that by playing within the process""" @@ -572,7 +574,7 @@ def run(self): proc.execute() self.assertFalse(proc.paused) - self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) + self.assertEqual(proc.state_label, plumpy.ProcessState.FINISHED) def test_process_stack(self): test_case = self @@ -656,7 +658,7 @@ def test_exception_during_on_entered(self): class RaisingProcess(Process): def on_entered(self, from_state): - if from_state is not None and from_state.label == ProcessState.RUNNING: + if from_state is not None and from_state.LABEL == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') super().on_entered(from_state) @@ -703,7 +705,7 @@ def step2(self): class TestProcessSaving(unittest.TestCase): maxDiff = None - def test_running_save_instance_state(self): + def test_running_save(self): loop = asyncio.get_event_loop() nsync_comeback = SavePauseProc() @@ -783,7 +785,7 @@ def test_saving_each_step(self): proc = proc_class() saver = utils.ProcessSaver(proc) saver.capture() - self.assertEqual(proc.state, ProcessState.FINISHED) + self.assertEqual(proc.state_label, ProcessState.FINISHED) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -798,7 +800,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it loaded_proc.resume() @@ -821,7 +823,7 @@ async def async_test(): # Load a process from the saved state loaded_proc = saved_state.unbundle() - self.assertEqual(loaded_proc.state, ProcessState.WAITING) + self.assertEqual(loaded_proc.state_label, ProcessState.WAITING) # Now resume it twice in succession loaded_proc.resume() @@ -863,7 +865,7 @@ async def async_test(): def test_killed(self): proc = utils.DummyProcess() proc.kill() - self.assertEqual(proc.state, plumpy.ProcessState.KILLED) + self.assertEqual(proc.state_label, plumpy.ProcessState.KILLED) self._check_round_trip(proc) def _check_round_trip(self, proc1): @@ -986,40 +988,40 @@ def run(self): self.out(namespace_nested + '.two', 2) # Run the process in default mode which should not add any outputs and therefore fail - process = DummyDynamicProcess() - process.execute() + proc = DummyDynamicProcess() + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertDictEqual(process.outputs, {}) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertDictEqual(proc.outputs, {}) # Attaching only namespaced ports should fail, because the required port is not added - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.DYNAMIC_PORT_NAMESPACE}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertFalse(process.is_successful) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertFalse(proc.is_successful) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) # Attaching only the single required top-level port should be fine - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) - process.execute() + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.SINGLE_REQUIRED_PORT}) + proc.execute() - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) # Attaching both the required and namespaced ports should result in a successful termination - process = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) - process.execute() - - self.assertIsNotNone(process.outputs) - self.assertEqual(process.state, ProcessState.FINISHED) - self.assertTrue(process.is_successful) - self.assertEqual(process.outputs['required_bool'], False) - self.assertEqual(process.outputs[namespace]['nested']['one'], 1) - self.assertEqual(process.outputs[namespace]['nested']['two'], 2) + proc = DummyDynamicProcess(inputs={'output_mode': OutputMode.BOTH_SINGLE_AND_NAMESPACE}) + proc.execute() + + self.assertIsNotNone(proc.outputs) + self.assertEqual(proc.state_label, ProcessState.FINISHED) + self.assertTrue(proc.is_successful) + self.assertEqual(proc.outputs['required_bool'], False) + self.assertEqual(proc.outputs[namespace]['nested']['one'], 1) + self.assertEqual(proc.outputs[namespace]['nested']['two'], 2) class TestProcessEvents(unittest.TestCase): diff --git a/tests/test_workchains.py b/tests/test_workchains.py index 08c7317a..4e34d2b4 100644 --- a/tests/test_workchains.py +++ b/tests/test_workchains.py @@ -11,6 +11,8 @@ from . import utils +# FIXME: after deabstract on savable into a protocol, test that all stepper are savable +# FIXME: workchani itself is savable class Wf(WorkChain): # Keep track of which steps were completed by the workflow diff --git a/tests/utils.py b/tests/utils.py index f2a58dfc..be8f2a5e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,13 +3,12 @@ import asyncio import collections -import copy import unittest from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils -from plumpy.process_comms import KILL_MSG, MESSAGE_KEY +from plumpy.process_comms import KillMessage Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) @@ -86,8 +85,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override def run(self): - msg = copy.copy(KILL_MSG) - msg[MESSAGE_KEY] = 'killed' + msg = KillMessage.build(message='killed') return process_states.Kill(msg=msg) @@ -470,7 +468,7 @@ def run_until_waiting(proc): listener = plumpy.ProcessListener() in_waiting = plumpy.Future() - if proc.state == ProcessState.WAITING: + if proc.state_label == ProcessState.WAITING: in_waiting.set_result(True) else: