From b2ea6784ec290ea1a91efb5f20251776c269336c Mon Sep 17 00:00:00 2001 From: Subreptivus Date: Thu, 10 Jun 2021 11:24:31 +0300 Subject: [PATCH] feat(metadata): ability to get artifacts location for Argo-Workflows v3.0+ --- .../metadata_writer/src/metadata_writer.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/backend/metadata_writer/src/metadata_writer.py b/backend/metadata_writer/src/metadata_writer.py index 62a7a502f99..aa5bc4a9ffa 100644 --- a/backend/metadata_writer/src/metadata_writer.py +++ b/backend/metadata_writer/src/metadata_writer.py @@ -84,7 +84,6 @@ def patch_pod_metadata( def output_name_to_argo(name: str) -> str: - import re # This sanitization code should be kept in sync with the code in the DSL compiler. # See https://github.com/kubeflow/pipelines/blob/39975e3cde7ba4dcea2bca835b92d0fe40b1ae3c/sdk/python/kfp/compiler/_k8s_helper.py#L33 return re.sub('-+', '-', re.sub('[^-_0-9A-Za-z]+', '-', name)).strip('-') @@ -98,15 +97,24 @@ def get_object_store_provider(endpoint: str) -> bool: else: return 'minio' -def argo_artifact_to_uri(artifact: dict) -> str: +def argo_artifact_to_uri(artifact: dict, skip_key=False) -> str: # s3 here means s3 compatible object storage. not AWS S3. if 's3' in artifact: s3_artifact = artifact['s3'] - return '{provider}://{bucket}/{key}'.format( - provider=get_object_store_provider(s3_artifact['endpoint']), - bucket=s3_artifact.get('bucket', ''), - key=s3_artifact.get('key', ''), - ) + if (s3_artifact.keys() >= {'endpoint', 'bucket'}): + if skip_key: + return '{provider}://{bucket}/'.format( + provider=get_object_store_provider(s3_artifact['endpoint']), + bucket=s3_artifact.get('bucket', ''), + ) + else: + return '{provider}://{bucket}/{key}'.format( + provider=get_object_store_provider(s3_artifact['endpoint']), + bucket=s3_artifact.get('bucket', ''), + key=s3_artifact.get('key', ''), + ) + else: + return s3_artifact.get('key', '') elif 'raw' in artifact: return None else: @@ -325,8 +333,14 @@ def is_kfp_v2_pod(pod) -> bool: output_artifacts = [] for name, art in argo_output_artifacts.items(): - artifact_uri = argo_artifact_to_uri(art) - if not artifact_uri: + artifact_uri_check = argo_artifact_to_uri(art) + if artifact_uri_check: + if re.search('(\W+)', artifact_uri_check).group(1) != '://': + artifact_uri_wo_key = argo_artifact_to_uri(argo_template.get('archiveLocation', {}), True) + artifact_uri = artifact_uri_wo_key + artifact_uri_check + else: + artifact_uri = artifact_uri_check + else: continue artifact_type_name = argo_output_name_to_type.get(name, 'NoType') # Cannot be None or ''