Skip to content

Commit

Permalink
[Fix 358] Be able to isolate between workflow executions to user-defi… (
Browse files Browse the repository at this point in the history
flink-extended#372)

* [Fix 358] Be able to isolate between workflow executions to user-defined events
  • Loading branch information
jiangxin369 committed Sep 7, 2022
1 parent 1e365e6 commit 5d9b02a
Show file tree
Hide file tree
Showing 14 changed files with 279 additions and 44 deletions.
6 changes: 6 additions & 0 deletions ai_flow/cli/commands/db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
import shutil

from ai_flow.common.env import get_aiflow_home
from ai_flow.metadata.base import Base
from ai_flow.common.util.db_util import db_migration
from ai_flow.common.configuration import config_constants
Expand All @@ -39,7 +40,12 @@ def reset(args):
if args.yes or input("This will drop existing tables if they exist. Proceed? (y/n)").upper() == "Y":
db_migration.reset_db(url=db_uri, metadata=Base.metadata)
if os.path.isdir(config_constants.LOCAL_REGISTRY_PATH):
print("Removing registry files of local task executor.")
shutil.rmtree(config_constants.LOCAL_REGISTRY_PATH)
ckp_file = os.path.join(get_aiflow_home(), '.checkpoint')
if os.path.exists(ckp_file):
print("Removing checkpoint file.")
os.remove(ckp_file)
else:
_logger.info('Cancel reset the database, db uri: {}'.format(db_uri))

Expand Down
22 changes: 10 additions & 12 deletions ai_flow/cli/commands/task_manager_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@
import os
import signal

from notification_service.client.embedded_notification_client import EmbeddedNotificationClient

from ai_flow.common.configuration import config_constants
from ai_flow.blob_manager.blob_manager_interface import BlobManagerFactory, BlobManagerConfig
from ai_flow.common.exception.exceptions import TaskFailedException, TaskForceStoppedException
from ai_flow.common.util import workflow_utils
from ai_flow.common.util.thread_utils import RepeatedTimer
from ai_flow.model.context import Context
from ai_flow.model.internal.contexts import TaskExecutionContext
from ai_flow.model.internal.events import TaskStatusEvent, TaskStatusChangedEvent
from ai_flow.model.operator import AIFlowOperator
from ai_flow.model.status import TaskStatus
from ai_flow.model.task_execution import TaskExecutionKey
from ai_flow.rpc.client.aiflow_client import get_notification_client
from ai_flow.rpc.client.heartbeat_client import HeartbeatClient
from ai_flow.common.configuration.helpers import AIFLOW_HOME
from ai_flow.model.internal.contexts import set_runtime_task_context

logger = logging.getLogger('aiflow.task')

Expand Down Expand Up @@ -74,10 +74,11 @@ def __init__(self,
self.workflow_name = workflow_name
self.task_execution_key = task_execution_key
self.workflow_snapshot_path = workflow_snapshot_path
self.notification_client = EmbeddedNotificationClient(
server_uri=notification_server_uri, namespace='task_status_change', sender='task_manager')
self.notification_client = get_notification_client(
notification_server_uri=notification_server_uri, namespace='task_status_change', sender='task_manager')
self.heartbeat_client = HeartbeatClient(heartbeat_server_uri)
self.heartbeat_thread = RepeatedTimer(heartbeat_interval, self._send_heartbeat)
self.context = TaskExecutionContext(task_execution_key)

def start(self):
self.heartbeat_thread.start()
Expand All @@ -95,18 +96,16 @@ def run_task(self):

def _execute(self):
task = self._get_task()
set_runtime_task_context(self.context)
try:
if isinstance(task, AIFlowOperator):
def signal_handler(signum, frame): # pylint: disable=unused-argument
logger.error("Received SIGTERM. Terminating subprocesses.")
context = Context()
task.stop(context)
task.stop(self.context)
raise TaskForceStoppedException("Task received SIGTERM signal")
signal.signal(signal.SIGTERM, signal_handler)

context = Context()
task.start(context)
task.await_termination(context)
task.start(self.context)
task.await_termination(self.context)
except TaskForceStoppedException:
raise
except (Exception, KeyboardInterrupt) as e:
Expand Down Expand Up @@ -160,4 +159,3 @@ def stop(self):
self.heartbeat_thread.cancel()
self.heartbeat_thread.join()
self.notification_client.close()

26 changes: 19 additions & 7 deletions ai_flow/model/internal/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import json

from ai_flow.model.context import Context
from ai_flow.model.task_execution import TaskExecution
from ai_flow.model.task_execution import TaskExecutionKey
from ai_flow.model.workflow import Workflow
from ai_flow.model.workflow_execution import WorkflowExecution

Expand All @@ -42,9 +45,18 @@ def __init__(self,
class TaskExecutionContext(Context):
"""It contains a workflow, a workflow execution and a task execution. It is used to execute operators"""
def __init__(self,
workflow: Workflow,
workflow_execution: WorkflowExecution,
task_execution: TaskExecution):
self.workflow = workflow
self.workflow_execution = workflow_execution
self.task_execution = task_execution
task_execution_key: TaskExecutionKey):
self.task_execution_key = task_execution_key


_CURRENT_TASK_CONTEXT: TaskExecutionContext = None


def set_runtime_task_context(context: TaskExecutionContext):
global _CURRENT_TASK_CONTEXT
_CURRENT_TASK_CONTEXT = context


def get_runtime_task_context():
return _CURRENT_TASK_CONTEXT

18 changes: 18 additions & 0 deletions ai_flow/notification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
106 changes: 106 additions & 0 deletions ai_flow/notification/notification_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import json
from datetime import datetime
from typing import List

from notification_service.client.embedded_notification_client import EmbeddedNotificationClient
from notification_service.client.notification_client import ListenerRegistrationId, \
ListenerProcessor
from notification_service.model.event import Event, EventKey

from ai_flow.common.exception.exceptions import AIFlowException
from ai_flow.model.internal.contexts import get_runtime_task_context
from ai_flow.model.internal.events import EventContextConstant


class AIFlowNotificationClient(object):

def __init__(self,
server_uri: str,
namespace: str = None,
sender: str = None,
client_id: int = None,
initial_seq_num: int = None):
self.client = EmbeddedNotificationClient(
server_uri=server_uri,
namespace=namespace,
sender=sender,
client_id=client_id,
initial_seq_num=initial_seq_num
)

def send_event_to_all_workflow_executions(self, event: Event) -> Event:
"""
Send event to all workflow executions.
:param event: the event to send.
:return: The sent event.
"""
return self.client.send_event(event)

def send_event(self, event: Event) -> Event:
"""
Send event to current workflow execution. This function can only be used
in AIFlow Operator runtime. It will retrieve the workflow execution info from runtime
context and set to context of the event.
:param event: the event to send.
:return: The sent event.
"""
context = get_runtime_task_context()
if not context:
raise AIFlowException("send_event can only be used in AIFlow Operator runtime.")
workflow_execution_id = context.task_execution_key.workflow_execution_id
if event.context is not None:
context_dict: dict = json.loads(event.context)
context_dict.update({
EventContextConstant.WORKFLOW_EXECUTION_ID: workflow_execution_id
})
else:
event.context = json.dumps({
EventContextConstant.WORKFLOW_EXECUTION_ID: workflow_execution_id
})
return self.client.send_event(event)

def register_listener(self,
listener_processor: ListenerProcessor,
event_keys: List[EventKey] = None,
offset: int = None) -> ListenerRegistrationId:
return self.client.register_listener(
listener_processor=listener_processor,
event_keys=event_keys,
offset=offset
)

def unregister_listener(self, registration_id: ListenerRegistrationId):
self.client.unregister_listener(registration_id)

def list_events(self, name: str = None, namespace: str = None, event_type: str = None, sender: str = None,
offset: int = None) -> List[Event]:
return self.client.list_events(
name=name,
namespace=namespace,
event_type=event_type,
sender=sender,
offset=offset
)

def time_to_offset(self, time: datetime) -> int:
return self.client.time_to_offset(time)
26 changes: 22 additions & 4 deletions ai_flow/rpc/client/aiflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,31 @@
from ai_flow.common.configuration import config_constants


def get_scheduler_client():
server_uri = config_constants.SERVER_ADDRESS
def get_scheduler_client(server_uri: str = None):
"""
Create a client to connect with AIFlow scheduler server.
:param server_uri: The uri of AIFlow server.
Use the default value in aiflow_client.yaml if not set.
"""
if server_uri is None:
server_uri = config_constants.SERVER_ADDRESS
return SchedulerClient(server_uri=server_uri)


def get_notification_client(namespace=DEFAULT_NAMESPACE, sender=None):
def get_notification_client(notification_server_uri: str = None,
namespace: str = DEFAULT_NAMESPACE,
sender: str = None):
"""
Create a notification client to connect with notification server.
:param notification_server_uri: The uri of notification server.
:param namespace: The event namespace.
:param sender: The event sender.
"""
if notification_server_uri is None:
notification_server_uri = config_constants.NOTIFICATION_SERVER_URI
return EmbeddedNotificationClient(
server_uri=config_constants.NOTIFICATION_SERVER_URI,
server_uri=notification_server_uri,
namespace=namespace,
sender=sender)
2 changes: 2 additions & 0 deletions ai_flow/rpc/service/scheduler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ def startWorkflowExecution(self, request, context):
workflow_executor = WorkflowExecutor(metadata_manager=metadata_manager)

workflow_meta = metadata_manager.get_workflow_by_name(namespace, workflow_name)
if workflow_meta is None:
raise AIFlowRpcServerException(f'Workflow {namespace}.{workflow_name} not exists')
latest_snapshot = metadata_manager.get_latest_snapshot(workflow_meta.id)
event = StartWorkflowExecutionEvent(workflow_meta.id, latest_snapshot.id)
workflow_execution_start_command = scheduling_event_processor.process(event=event)
Expand Down
10 changes: 4 additions & 6 deletions ai_flow/scheduler/runtime_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ai_flow.model.internal.contexts import WorkflowContext, WorkflowExecutionContext, TaskExecutionContext
from ai_flow.model.state import State, StateDescriptor
from ai_flow.model.status import TaskStatus
from ai_flow.model.task_execution import TaskExecution
from ai_flow.model.task_execution import TaskExecution, TaskExecutionKey
from ai_flow.model.workflow import Workflow
from ai_flow.model.workflow_execution import WorkflowExecution

Expand Down Expand Up @@ -63,14 +63,12 @@ def get_task_status(self, task_name) -> TaskStatus:

class TaskExecutionContextImpl(TaskExecutionContext):
def __init__(self,
workflow: Workflow,
workflow_execution: WorkflowExecution,
task_execution: TaskExecution,
task_execution_key: TaskExecutionKey,
metadata_manager: MetadataManager):
super().__init__(workflow, workflow_execution, task_execution)
super().__init__(task_execution_key)
self._metadata_manager = metadata_manager

def get_state(self, state_descriptor: StateDescriptor) -> State:
return self._metadata_manager.get_or_create_workflow_execution_state(
workflow_execution_id=self.workflow_execution.id,
workflow_execution_id=self.task_execution_key.workflow_execution_id,
descriptor=state_descriptor)
9 changes: 3 additions & 6 deletions ai_flow/task_executor/common/heartbeat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import logging
import threading
import time
from concurrent import futures

import grpc
from notification_service.client.embedded_notification_client import EmbeddedNotificationClient
from concurrent import futures

from ai_flow.common.configuration.config_constants import NOTIFICATION_SERVER_URI
from ai_flow.common.util.db_util.session import create_session
from ai_flow.model.task_execution import TaskExecutionKey
from ai_flow.rpc.client.aiflow_client import get_notification_client

from ai_flow.rpc.protobuf.message_pb2 import Response, SUCCESS

Expand Down Expand Up @@ -55,8 +53,7 @@ def __init__(self):
self.heartbeat_check_thread = StoppableThread(target=self._check_heartbeat_timeout)

def start(self):
self.notification_client = EmbeddedNotificationClient(
server_uri=NOTIFICATION_SERVER_URI, namespace='task_status_change', sender='task_executor')
self.notification_client = get_notification_client(namespace='task_status_change', sender='task_executor')

self.grpc_server.start()
logger.info('Heartbeat Service started.')
Expand Down
15 changes: 11 additions & 4 deletions samples/quickstart/quickstart_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,30 @@
#
import time

from ai_flow.rpc.client.aiflow_client import get_notification_client
from notification_service.model.event import EventKey, Event

from ai_flow.model.action import TaskAction
from ai_flow.notification.notification_client import AIFlowNotificationClient
from ai_flow.operators.bash import BashOperator
from ai_flow.operators.python import PythonOperator
from ai_flow.model.status import TaskStatus

from ai_flow.model.workflow import Workflow

EVENT_KEY = EventKey(name='quickstart_key',
event_type='quickstart_type')
EVENT_KEY = EventKey(name='event_name',
event_type='user_defined_type',
namespace="my_namespace",
sender="task3"
)


def func():
time.sleep(5)
notification_client = get_notification_client()
notification_client = AIFlowNotificationClient(
server_uri="localhost:50052",
namespace="my_namespace",
sender="task3"
)
event = Event(event_key=EVENT_KEY, message='This is a custom message.')
notification_client.send_event(event)

Expand Down
Loading

0 comments on commit 5d9b02a

Please sign in to comment.