Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

De-abstract Savable by making it a protocol #298

Draft
wants to merge 29 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1117eeb
amend from rebase
unkcpz Nov 27, 2024
b82791d
Add default MESSAGE_KEY to None value and FORCE_KILL_KEY
unkcpz Nov 29, 2024
d4c0489
Alias MessageType for message passing
unkcpz Nov 30, 2024
c5a195c
Simplify _create_state_instance so it only need to do real create
unkcpz Nov 30, 2024
8db6675
Furthur simplipy _create_state_instant only create state from class
unkcpz Nov 30, 2024
74d048d
Killed state all through passing msg
unkcpz Dec 1, 2024
667af7a
Amend
unkcpz Dec 1, 2024
4be6931
If transition_to None do noting
unkcpz Dec 1, 2024
88259d6
KillMessage build msg from parameters
unkcpz Dec 1, 2024
c3c9db4
Pause/Play/Status all using message builder
unkcpz Dec 1, 2024
d0e4e73
Remove duplicate codes
unkcpz Dec 2, 2024
e3c2ae8
Future type annotation
unkcpz Dec 2, 2024
e5c74ad
Fix doc
unkcpz Dec 2, 2024
18eb56e
Mapping states from state name
unkcpz Dec 3, 2024
b505628
Remove the middle layer of statemachine.State + Savable abstraction
unkcpz Dec 2, 2024
7f8a30e
Move is_terminal as class attribute required
unkcpz Dec 2, 2024
e207892
forming the enter/exit for State protocol
unkcpz Dec 2, 2024
080d036
Forming Interruptable and Proceedable protocol
unkcpz Dec 2, 2024
6bfb87d
Refactoring create_state as static function initialize state from label
unkcpz Dec 2, 2024
ef964ed
To lenthy for rethinking
unkcpz Dec 4, 2024
937ad01
Move static method load outside
unkcpz Dec 4, 2024
304f3ba
save_instance_state simplify to only has save interface
unkcpz Dec 9, 2024
ce6beae
WIP: load_instance_state deabstract simplify
unkcpz Dec 9, 2024
484ae87
ProcessListener recreate_from
unkcpz Dec 9, 2024
c910d62
Absorb all load_instance_state into recreate_from
unkcpz Dec 9, 2024
61c7fb8
Remove useless persist method of Savable class
unkcpz Dec 9, 2024
55bc734
Explicity recreate_from implementation
unkcpz Dec 9, 2024
9b9a5b7
WIP: forming Savable protocol
unkcpz Dec 9, 2024
3e6a2dd
Make auto_load symmetry with auto_save and state/state_label distinguish
unkcpz Dec 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ py:class kiwipy.communications.Communicator

# unavailable forward references
py:class plumpy.process_states.Command
py:class plumpy.process_states.State
py:class plumpy.state_machine.State
py:class plumpy.base.state_machine.State
py:class State
py:class Process
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
" def continue_fn(self):\n",
" print('continuing')\n",
" # message is stored in the process status\n",
" return plumpy.Kill('I was killed')\n",
" return plumpy.Kill(plumpy.KillMessage.build('I was killed'))\n",
"\n",
"\n",
"process = ContinueProcess()\n",
Expand Down Expand Up @@ -1118,7 +1118,7 @@
"\n",
"process = SimpleProcess(communicator=communicator)\n",
"\n",
"pprint(communicator.rpc_send(str(process.pid), plumpy.STATUS_MSG).result())"
"pprint(communicator.rpc_send(str(process.pid), plumpy.StatusMessage.build()).result())"
]
},
{
Expand Down
2 changes: 0 additions & 2 deletions src/plumpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .exceptions import *
from .futures import *
from .loaders import *
from .mixins import *
from .persistence import *
from .ports import *
from .process_comms import *
Expand All @@ -25,7 +24,6 @@
+ processes.__all__
+ utils.__all__
+ futures.__all__
+ mixins.__all__
+ persistence.__all__
+ communications.__all__
+ process_comms.__all__
Expand Down
165 changes: 78 additions & 87 deletions src/plumpy/base/state_machine.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
# -*- coding: utf-8 -*-
"""The state machine for processes"""

from __future__ import annotations

import enum
import functools
import inspect
import logging
import os
import sys
from types import TracebackType
from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Sequence, Set, Type, Union, cast
from typing import (
Any,
Callable,
ClassVar,
Dict,
Hashable,
Iterable,
List,
Optional,
Protocol,
Sequence,
Type,
Union,
runtime_checkable,
)

from plumpy.futures import Future

Expand All @@ -18,7 +34,6 @@

_LOGGER = logging.getLogger(__name__)

LABEL_TYPE = Union[None, enum.Enum, str]
EVENT_CALLBACK_TYPE = Callable[['StateMachine', Hashable, Optional['State']], None]


Expand All @@ -31,7 +46,7 @@ class StateEntryFailed(Exception): # noqa: N818
Failed to enter a state, can provide the next state to go to via this exception
"""

def __init__(self, state: Hashable = None, *args: Any, **kwargs: Any) -> None:
def __init__(self, state: State, *args: Any, **kwargs: Any) -> None:
super().__init__('failed to enter state')
self.state = state
self.args = args
Expand Down Expand Up @@ -74,12 +89,12 @@ def event(
if from_states != '*':
if inspect.isclass(from_states):
from_states = (from_states,)
if not all(issubclass(state, State) for state in from_states): # type: ignore
if not all(isinstance(state, State) for state in from_states): # type: ignore
raise TypeError(f'from_states: {from_states}')
if to_states != '*':
if inspect.isclass(to_states):
to_states = (to_states,)
if not all(issubclass(state, State) for state in to_states): # type: ignore
if not all(isinstance(state, State) for state in to_states): # type: ignore
raise TypeError(f'to_states: {to_states}')

def wrapper(wrapped: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -113,57 +128,40 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any:
return wrapper


class State:
LABEL: LABEL_TYPE = None
# A set containing the labels of states that can be entered
# from this one
ALLOWED: Set[LABEL_TYPE] = set()
@runtime_checkable
class State(Protocol):
LABEL: ClassVar[Any]
ALLOWED: ClassVar[set[Any]]
is_terminal: ClassVar[bool]

@classmethod
def is_terminal(cls) -> bool:
return not cls.ALLOWED
def __init__(self, *args: Any, **kwargs: Any): ...

def __init__(self, state_machine: 'StateMachine', *args: Any, **kwargs: Any):
"""
:param state_machine: The process this state belongs to
"""
self.state_machine = state_machine
self.in_state: bool = False
def enter(self) -> None: ...

def __str__(self) -> str:
return str(self.LABEL)
def exit(self) -> None: ...

@property
def label(self) -> LABEL_TYPE:
"""Convenience property to get the state label"""
return self.LABEL

@super_check
def enter(self) -> None:
"""Entering the state"""
@runtime_checkable
class Interruptable(Protocol):
def interrupt(self, reason: Exception) -> None: ...

def execute(self) -> Optional['State']:

@runtime_checkable
class Proceedable(Protocol):
def execute(self) -> State | None:
"""
Execute the state, performing the actions that this state is responsible for.
:returns: a state to transition to or None if finished.
"""
...

@super_check
def exit(self) -> None:
"""Exiting the state"""
if self.is_terminal():
raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}')

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> 'State':
return self.state_machine.create_state(state_label, *args, **kwargs)

def do_enter(self) -> None:
call_with_super_check(self.enter)
self.in_state = True
def create_state(st: StateMachine, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
if state_label not in st.get_states_map():
raise ValueError(f'{state_label} is not a valid state')

def do_exit(self) -> None:
call_with_super_check(self.exit)
self.in_state = False
state_cls = st.get_states_map()[state_label]
return state_cls(*args, **kwargs)


class StateEventHook(enum.Enum):
Expand All @@ -187,7 +185,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine':
:param kwargs: Any keyword arguments to be passed to the constructor
:return: An instance of the state machine
"""
inst = super().__call__(*args, **kwargs)
inst: StateMachine = super().__call__(*args, **kwargs)
inst.transition_to(inst.create_initial_state())
call_with_super_check(inst.init)
return inst
Expand All @@ -214,13 +212,13 @@ def get_states(cls) -> Sequence[Type[State]]:
raise RuntimeError('States not defined')

@classmethod
def initial_state_label(cls) -> LABEL_TYPE:
def initial_state_label(cls) -> Any:
cls.__ensure_built()
assert cls.STATES is not None
return cls.STATES[0].LABEL

@classmethod
def get_state_class(cls, label: LABEL_TYPE) -> Type[State]:
def get_state_class(cls, label: Any) -> Type[State]:
cls.__ensure_built()
assert cls._STATES_MAP is not None
return cls._STATES_MAP[label]
Expand All @@ -240,7 +238,7 @@ def __ensure_built(cls) -> None:
# Build the states map
cls._STATES_MAP = {}
for state_cls in cls.STATES:
assert issubclass(state_cls, State)
assert isinstance(state_cls, State)
label = state_cls.LABEL
assert label not in cls._STATES_MAP, f"Duplicate label '{label}'"
cls._STATES_MAP[label] = state_cls
Expand All @@ -264,11 +262,17 @@ def init(self) -> None:
def __str__(self) -> str:
return f'<{self.__class__.__name__}> ({self.state})'

def create_initial_state(self) -> State:
return self.get_state_class(self.initial_state_label())(self)
def create_initial_state(self, *args: Any, **kwargs: Any) -> State:
return self.get_state_class(self.initial_state_label())(self, *args, **kwargs)

@property
def state(self) -> State | None:
if self._state is None:
return None
return self._state

@property
def state(self) -> Optional[LABEL_TYPE]:
def state_label(self) -> Any:
if self._state is None:
return None
return self._state.LABEL
Expand Down Expand Up @@ -300,16 +304,24 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None:
def on_terminated(self) -> None:
"""Called when a terminal state is entered"""

def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None:
def transition_to(self, new_state: State | None, **kwargs: Any) -> None:
"""Transite to the new state.

The new target state will be create lazily when the state is not yet instantiated,
which will happened for states not in the expect path such as pause and kill.
The arguments are passed to the state class to create state instance.
(process arg does not need to pass since it will always call with 'self' as process)
"""
print(f'try: {self._state} -> {new_state}')
assert not self._transitioning, 'Cannot call transition_to when already transitioning state'

initial_state_label = self._state.LABEL if self._state is not None else None
if new_state is None:
return None

initial_state_label = self.state_label
label = None
try:
self._transitioning = True

# Make sure we have a state instance
new_state = self._create_state_instance(new_state, *args, **kwargs)
label = new_state.LABEL

# If the previous transition failed, do not try to exit it but go straight to next state
Expand All @@ -319,13 +331,12 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
try:
self._enter_next_state(new_state)
except StateEntryFailed as exception:
# Make sure we have a state instance
new_state = self._create_state_instance(exception.state, *exception.args, **exception.kwargs)
new_state = exception.state
label = new_state.LABEL
self._exit_current_state(new_state)
self._enter_next_state(new_state)

if self._state is not None and self._state.is_terminal():
if self._state is not None and self._state.is_terminal:
call_with_super_check(self.on_terminated)
except Exception:
self._transitioning = False
Expand All @@ -338,7 +349,11 @@ def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: A
self._transitioning = False

def transition_failed(
self, initial_state: Hashable, final_state: Hashable, exception: Exception, trace: TracebackType
self,
initial_state: Hashable,
final_state: Hashable,
exception: Exception,
trace: TracebackType,
) -> None:
"""Called when a state transitions fails.

Expand All @@ -354,49 +369,25 @@ def get_debug(self) -> bool:
def set_debug(self, enabled: bool) -> None:
self._debug: bool = enabled

def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> State:
try:
return self.get_states_map()[state_label](self, *args, **kwargs)
except KeyError:
raise ValueError(f'{state_label} is not a valid state')

def _exit_current_state(self, next_state: State) -> None:
"""Exit the given state"""

# If we're just being constructed we may not have a state yet to exit,
# in which case check the new state is the initial state
if self._state is None:
if next_state.label != self.initial_state_label():
if next_state.LABEL != self.initial_state_label():
raise RuntimeError(f"Cannot enter state '{next_state}' as the initial state")
return # Nothing to exit

if next_state.LABEL not in self._state.ALLOWED:
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.label}')
raise RuntimeError(f'Cannot transition from {self._state.LABEL} to {next_state.LABEL}')
self._fire_state_event(StateEventHook.EXITING_STATE, next_state)
self._state.do_exit()
self._state.exit()

def _enter_next_state(self, next_state: State) -> None:
last_state = self._state
self._fire_state_event(StateEventHook.ENTERING_STATE, next_state)
# Enter the new state
next_state.do_enter()
next_state.enter()
self._state = next_state
self._fire_state_event(StateEventHook.ENTERED_STATE, last_state)

def _create_state_instance(self, state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> State:
if isinstance(state, State):
# It's already a state instance
return state

# OK, have to create it
state_cls = self._ensure_state_class(state)
return state_cls(self, *args, **kwargs)

def _ensure_state_class(self, state: Union[Hashable, Type[State]]) -> Type[State]:
if inspect.isclass(state) and issubclass(state, State):
return state

try:
return self.get_states_map()[cast(Hashable, state)]
except KeyError:
raise ValueError(f'{state} is not a valid state')
Loading