-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
092ac9c
commit 35969b2
Showing
1 changed file
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import asyncio | ||
import socket | ||
from typing import Iterator, List | ||
from uuid import uuid4 | ||
|
||
import temporalio.api.common.v1 | ||
import temporalio.api.enums.v1 | ||
import temporalio.api.history.v1 | ||
import temporalio.api.workflowservice.v1 | ||
from temporalio import workflow | ||
from temporalio.client import Client, WorkflowHandle | ||
from temporalio.worker import Worker | ||
|
||
RunId = str | ||
|
||
WORKFLOW_ID = uuid4().hex | ||
TASK_QUEUE = __file__ | ||
|
||
|
||
@workflow.defn | ||
class WorkflowWithUpdateHandler: | ||
def __init__(self) -> None: | ||
self.update_args = [] | ||
self.signal_args = [] | ||
|
||
@workflow.update | ||
async def my_update(self, arg: int): | ||
self.update_args.append(arg) | ||
|
||
@workflow.signal | ||
async def my_signal(self, arg: int): | ||
self.signal_args.append(arg) | ||
|
||
@workflow.run | ||
async def run(self) -> dict: | ||
await asyncio.sleep(0xFFFF) | ||
await workflow.wait_condition( | ||
lambda: False and bool(self.signal_args and self.update_args) | ||
) | ||
return {"update_args": self.update_args, "signal_args": self.signal_args} | ||
|
||
|
||
async def app(client: Client): | ||
handle = await client.start_workflow( | ||
WorkflowWithUpdateHandler.run, id=WORKFLOW_ID, task_queue=TASK_QUEUE | ||
) | ||
|
||
print( | ||
f"started workflow http://{server()}/namespaces/default/workflows/{WORKFLOW_ID}" | ||
) | ||
|
||
if input("send signals?") in ["y", ""]: | ||
for i in range(2): | ||
await handle.signal(WorkflowWithUpdateHandler.my_signal, arg=i) | ||
print(f"sent signal") | ||
|
||
if input("execute updates?") in ["y", ""]: | ||
for i in range(2): | ||
await handle.execute_update(WorkflowWithUpdateHandler.my_update, arg=i) | ||
print(f"executed update") | ||
|
||
if input("reset?") in ["y", ""]: | ||
history = [e async for e in handle.fetch_history_events()] | ||
reset_to = next_event( | ||
history, | ||
temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_TASK_COMPLETED, | ||
) | ||
|
||
print(f"about to reset to event {reset_to.event_id}") | ||
run_id = get_first_execution_run_id(history) | ||
new_run_id = await reset_workflow(run_id, reset_to, client) | ||
print( | ||
f"did reset: http://localhost:8080/namespaces/default/workflows/{WORKFLOW_ID}/{new_run_id}" | ||
) | ||
|
||
new_handle = client.get_workflow_handle(WORKFLOW_ID, run_id=new_run_id) | ||
|
||
history = [e async for e in new_handle.fetch_history_events()] | ||
|
||
print("new history") | ||
for e in history: | ||
print(f"{e.event_id} {e.event_type}") | ||
print(e) | ||
|
||
await asyncio.sleep(0xFFFF) | ||
wf_result = handle.result() | ||
print(f"wf result: {wf_result}") | ||
|
||
|
||
async def reset_workflow( | ||
run_id: str, | ||
event: temporalio.api.history.v1.HistoryEvent, | ||
client: Client, | ||
) -> RunId: | ||
resp = await client.workflow_service.reset_workflow_execution( | ||
temporalio.api.workflowservice.v1.ResetWorkflowExecutionRequest( | ||
namespace="default", | ||
workflow_execution=temporalio.api.common.v1.WorkflowExecution( | ||
workflow_id=WORKFLOW_ID, | ||
run_id=run_id, | ||
), | ||
reason="Reset to test update reapply", | ||
request_id="1", | ||
reset_reapply_type=temporalio.api.enums.v1.ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED, # TODO | ||
workflow_task_finish_event_id=event.event_id, | ||
) | ||
) | ||
assert resp.run_id | ||
return resp.run_id | ||
|
||
|
||
def next_event( | ||
history: List[temporalio.api.history.v1.HistoryEvent], | ||
event_type: temporalio.api.enums.v1.EventType.ValueType, | ||
) -> temporalio.api.history.v1.HistoryEvent: | ||
return next(e for e in history if e.event_type == event_type) | ||
|
||
|
||
def get_first_execution_run_id( | ||
history: List[temporalio.api.history.v1.HistoryEvent], | ||
) -> str: | ||
# TODO: correct way to obtain run_id | ||
wf_started_event = next_event( | ||
history, temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED | ||
) | ||
run_id = ( | ||
wf_started_event.workflow_execution_started_event_attributes.first_execution_run_id | ||
) | ||
assert run_id | ||
return run_id | ||
|
||
|
||
async def main(): | ||
client = await Client.connect("localhost:7233") | ||
async with Worker( | ||
client, task_queue=TASK_QUEUE, workflows=[WorkflowWithUpdateHandler] | ||
): | ||
await app(client) | ||
|
||
|
||
def only(it: Iterator): | ||
t = next(it) | ||
assert next(it, it) == it | ||
return t | ||
|
||
|
||
def is_listening(addr: str) -> bool: | ||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | ||
h, p = addr.split(":") | ||
try: | ||
s.connect((h, int(p))) | ||
return True | ||
except socket.error: | ||
return False | ||
finally: | ||
s.close() | ||
|
||
|
||
def server() -> str: | ||
return only(filter(is_listening, ["localhost:8080", "localhost:8233"])) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |