forked from flyteorg/flytekit-python-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathintegration.py
119 lines (106 loc) · 5.51 KB
/
integration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import sys, os, json
from flytekit.configuration import SerializationSettings, Config, PlatformConfig, AuthType, ImageConfig
from flytekit.core.base_task import PythonTask
from flytekit.core.workflow import WorkflowBase
from flytekit.remote import FlyteRemote, FlyteTask, FlyteWorkflow
from datetime import timedelta
from uuid import uuid4
from contextlib import contextmanager
from typing import Union, List
# Directory containing this script; each template lives in a subdirectory of it.
root_directory = os.path.abspath(os.path.dirname(__file__))


@contextmanager
def workflows_module_management(workflow_name: str):
    """Temporarily import the ``workflows`` package of a single template.

    The template's ``{{cookiecutter.project_name}}`` directory is inserted at
    the front of ``sys.path``, ``workflows`` is imported from it, and the
    imported module is yielded.

    Every template ships a package with the same name, ``workflows``. On exit,
    the path entry is therefore removed, and ``workflows`` plus all of its
    ``workflows.*`` submodules are evicted from ``sys.modules``. Without this,
    the next template would silently reuse cached submodules from this one.

    Args:
        workflow_name: name of the template directory under ``root_directory``.

    Yields:
        The freshly imported ``workflows`` module.
    """
    module_name = "workflows"
    path = os.path.join(root_directory, workflow_name, "{{cookiecutter.project_name}}")
    sys.path.insert(0, path)
    try:
        yield __import__(module_name)
    finally:
        # The imported code may have altered sys.path itself; a missing entry
        # must not mask an exception raised inside the ``with`` body.
        if path in sys.path:
            sys.path.remove(path)
        # Only sys.modules needs cleaning. The import does not bind names into
        # this script's globals(), so deleting globals() entries that match the
        # module's attributes would only remove this script's own imports
        # (e.g. ``os``). Submodules must be evicted too, or a later template's
        # ``from .x import y`` would resolve to this template's cached module.
        # NOTE(review): other top-level modules imported from the template
        # directory (not under ``workflows``) remain cached.
        prefix = module_name + "."
        stale = [name for name in sys.modules if name == module_name or name.startswith(prefix)]
        for name in stale:
            del sys.modules[name]
def register_all(
    context: FlyteRemote,
    templates: List[dict],
    image_hostname: str,
    image_suffix: str,
    project: str = "flytetester",
    domain: str = "development",
):
    """Register one workflow or task from each template on the remote cluster.

    Each template's ``workflows`` module is imported in isolation via
    ``workflows_module_management``. The attribute named by the template's
    ``workflow_name`` is registered with the image
    ``{image_hostname}:{template_name}-{image_suffix}``. All entities share a
    single randomly generated version string.

    Args:
        context: remote used to perform the registrations.
        templates: dicts with at least ``"template_name"`` (template directory)
            and ``"workflow_name"`` (attribute of that template's ``workflows``
            module).
        image_hostname: image repository placed before the ``:`` tag separator.
        image_suffix: appended to the image tag after the template name.
        project: Flyte project to register into.
        domain: Flyte domain to register into.

    Returns:
        The registered entities, in template order, as returned by
        ``register_workflow`` / ``register_task``.

    Raises:
        TypeError: if the named attribute is neither a ``WorkflowBase`` nor a
            ``PythonTask``.
    """
    version = str(uuid4())
    registered_workflows = []
    for template in templates:
        template_name = template["template_name"]
        workflow_name = template["workflow_name"]
        # Registration stays inside the ``with`` so the template directory is
        # still importable while the entity is serialized (presumably needed
        # if flytekit resolves the module; kept as in the original flow).
        with workflows_module_management(template_name) as wf_module:
            workflow = getattr(wf_module, workflow_name)
            # Choose the registration call before touching ``workflow.name`` so
            # an unsupported object fails with a clear error, not AttributeError.
            if isinstance(workflow, WorkflowBase):
                register = context.register_workflow
            elif isinstance(workflow, PythonTask):
                register = context.register_task
            else:
                raise TypeError(
                    f"Unknown workflow type for {template_name}/{workflow_name}: "
                    f"{type(workflow).__name__}"
                )
            print(workflow.name)
            image = f"{image_hostname}:{template_name}-{image_suffix}"
            print(f"Registering workflow: {template_name} with image: {image}")
            serialization_settings = SerializationSettings(
                image_config=ImageConfig.from_images(image),
                project=project,
                domain=domain,
            )
            reg_workflow = register(
                entity=workflow,
                serialization_settings=serialization_settings,
                version=version,
            )
        print(f"Registered workflow: {template_name}")
        registered_workflows.append(reg_workflow)
    return registered_workflows
def execute_all(
    remote_context: FlyteRemote,
    reg_workflows: List[Union[FlyteWorkflow, FlyteTask]],
    project: str = "flytetester",
    domain: str = "development",
    timeout: timedelta = timedelta(minutes=10),
):
    """Execute each registered entity with no inputs and wait for it to finish.

    Executions run one at a time: each one is started, its console URL is
    printed, and it is waited on before the next begins.

    Args:
        remote_context: remote used to launch and wait on executions.
        reg_workflows: entities returned by registration.
        project: Flyte project to execute in.
        domain: Flyte domain to execute in.
        timeout: maximum time passed to ``FlyteRemote.wait`` per execution.
            What happens when it elapses is up to ``wait``; check the
            flytekit docs.

    Raises:
        RuntimeError: if a completed execution reports an error. Remaining
            entities are not executed.
    """
    for reg_workflow in reg_workflows:
        print(f"Executing workflow: {reg_workflow.id}")
        execution = remote_context.execute(reg_workflow, inputs={}, project=project, domain=domain)
        print(f"Execution url: {remote_context.generate_console_url(execution)}")
        completed_execution = remote_context.wait(execution, timeout=timeout)
        if completed_execution.error is not None:
            raise RuntimeError(
                f"Execution of {reg_workflow.id} failed with error: {completed_execution.error}"
            )
        print(f"Execution succeeded: {completed_execution.outputs}")
def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``, so an explicit parser is required.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_base_parser() -> argparse.ArgumentParser:
    """Build the parser for the options shared by every auth mode."""
    parser = argparse.ArgumentParser(
        description="Register every template on a remote Flyte cluster and execute it."
    )
    parser.add_argument("--host", type=str, required=True)
    parser.add_argument("--auth_type", type=str, choices=["CLIENT_CREDENTIALS", "PKCE"], default="CLIENT_CREDENTIALS")
    # Accept both a bare ``--insecure`` flag and an explicit ``--insecure <bool>``
    # (the latter keeps older invocations such as ``--insecure True`` working).
    parser.add_argument("--insecure", type=_str_to_bool, nargs="?", const=True, default=False)
    parser.add_argument("--image_hostname", type=str, default="ghcr.io/flyteorg/flytekit-python-template")
    parser.add_argument("--image_suffix", type=str, default="latest")
    return parser


if __name__ == "__main__":
    # Register every template listed in templates.json on a remote Flyte
    # cluster, then execute each one and print its console URL and outputs.
    parser = _build_base_parser()
    known_args, _ = parser.parse_known_args()
    auth_type = getattr(AuthType, known_args.auth_type)

    # Client-credential flags are only required for CLIENT_CREDENTIALS auth,
    # so they go on a child parser built once the auth type is known.
    client_credential_parser = argparse.ArgumentParser(parents=[parser], add_help=False)
    if auth_type == AuthType.CLIENT_CREDENTIALS:
        client_credential_parser.add_argument("--client_id", type=str, required=True)
        client_credential_parser.add_argument("--client_secret", type=str, required=True)
    args = client_credential_parser.parse_args()

    platform_args = {"endpoint": args.host, "auth_mode": auth_type, "insecure": args.insecure}
    if auth_type == AuthType.CLIENT_CREDENTIALS:
        platform_args["client_id"] = args.client_id
        platform_args["client_credentials_secret"] = args.client_secret
    remote = FlyteRemote(config=Config(platform=PlatformConfig(**platform_args)))

    # Resolved relative to the current working directory.
    with open("templates.json", encoding="utf-8") as f:
        templates_list = json.load(f)
    print(templates_list)

    remote_wfs = register_all(remote, templates_list, args.image_hostname, args.image_suffix)
    print("All workflows Registered")
    execute_all(remote, remote_wfs)
    print("All executions completed")