Skip to content

Commit

Permalink
reset/reapply_updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Jan 28, 2024
1 parent 092ac9c commit 35969b2
Showing 1 changed file with 164 additions and 0 deletions.
164 changes: 164 additions & 0 deletions reset/reapply_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import asyncio
import socket
from typing import Iterator, List
from uuid import uuid4

import temporalio.api.common.v1
import temporalio.api.enums.v1
import temporalio.api.history.v1
import temporalio.api.workflowservice.v1
from temporalio import workflow
from temporalio.client import Client, WorkflowHandle
from temporalio.worker import Worker

RunId = str

WORKFLOW_ID = uuid4().hex
TASK_QUEUE = __file__


@workflow.defn
class WorkflowWithUpdateHandler:
def __init__(self) -> None:
self.update_args = []
self.signal_args = []

@workflow.update
async def my_update(self, arg: int):
self.update_args.append(arg)

@workflow.signal
async def my_signal(self, arg: int):
self.signal_args.append(arg)

@workflow.run
async def run(self) -> dict:
await asyncio.sleep(0xFFFF)
await workflow.wait_condition(
lambda: False and bool(self.signal_args and self.update_args)
)
return {"update_args": self.update_args, "signal_args": self.signal_args}


async def app(client: Client):
handle = await client.start_workflow(
WorkflowWithUpdateHandler.run, id=WORKFLOW_ID, task_queue=TASK_QUEUE
)

print(
f"started workflow http://{server()}/namespaces/default/workflows/{WORKFLOW_ID}"
)

if input("send signals?") in ["y", ""]:
for i in range(2):
await handle.signal(WorkflowWithUpdateHandler.my_signal, arg=i)
print(f"sent signal")

if input("execute updates?") in ["y", ""]:
for i in range(2):
await handle.execute_update(WorkflowWithUpdateHandler.my_update, arg=i)
print(f"executed update")

if input("reset?") in ["y", ""]:
history = [e async for e in handle.fetch_history_events()]
reset_to = next_event(
history,
temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_TASK_COMPLETED,
)

print(f"about to reset to event {reset_to.event_id}")
run_id = get_first_execution_run_id(history)
new_run_id = await reset_workflow(run_id, reset_to, client)
print(
f"did reset: http://localhost:8080/namespaces/default/workflows/{WORKFLOW_ID}/{new_run_id}"
)

new_handle = client.get_workflow_handle(WORKFLOW_ID, run_id=new_run_id)

history = [e async for e in new_handle.fetch_history_events()]

print("new history")
for e in history:
print(f"{e.event_id} {e.event_type}")
print(e)

await asyncio.sleep(0xFFFF)
wf_result = handle.result()
print(f"wf result: {wf_result}")


async def reset_workflow(
run_id: str,
event: temporalio.api.history.v1.HistoryEvent,
client: Client,
) -> RunId:
resp = await client.workflow_service.reset_workflow_execution(
temporalio.api.workflowservice.v1.ResetWorkflowExecutionRequest(
namespace="default",
workflow_execution=temporalio.api.common.v1.WorkflowExecution(
workflow_id=WORKFLOW_ID,
run_id=run_id,
),
reason="Reset to test update reapply",
request_id="1",
reset_reapply_type=temporalio.api.enums.v1.ResetReapplyType.RESET_REAPPLY_TYPE_UNSPECIFIED, # TODO
workflow_task_finish_event_id=event.event_id,
)
)
assert resp.run_id
return resp.run_id


def next_event(
history: List[temporalio.api.history.v1.HistoryEvent],
event_type: temporalio.api.enums.v1.EventType.ValueType,
) -> temporalio.api.history.v1.HistoryEvent:
return next(e for e in history if e.event_type == event_type)


def get_first_execution_run_id(
history: List[temporalio.api.history.v1.HistoryEvent],
) -> str:
# TODO: correct way to obtain run_id
wf_started_event = next_event(
history, temporalio.api.enums.v1.EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED
)
run_id = (
wf_started_event.workflow_execution_started_event_attributes.first_execution_run_id
)
assert run_id
return run_id


async def main():
client = await Client.connect("localhost:7233")
async with Worker(
client, task_queue=TASK_QUEUE, workflows=[WorkflowWithUpdateHandler]
):
await app(client)


def only(it: Iterator):
t = next(it)
assert next(it, it) == it
return t


def is_listening(addr: str) -> bool:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
h, p = addr.split(":")
try:
s.connect((h, int(p)))
return True
except socket.error:
return False
finally:
s.close()


def server() -> str:
return only(filter(is_listening, ["localhost:8080", "localhost:8233"]))


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit 35969b2

Please sign in to comment.