From 1bca1a52654396c2d01499477de4981bb7a1075e Mon Sep 17 00:00:00 2001 From: Mecoli1219 Date: Tue, 7 May 2024 21:31:44 +0800 Subject: [PATCH] #5085 Override task secret_requests using with_overrides Signed-off-by: Mecoli1219 --- flytekit/core/node.py | 10 ++++++++++ flytekit/models/core/workflow.py | 29 +++++++++++++++++++++++++++-- flytekit/tools/translator.py | 3 +++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index e31e9e5f56..6855cad319 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -11,6 +11,7 @@ from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.security import Secret, SecurityContext from flytekit.models.task import Resources as _resources_model @@ -66,6 +67,7 @@ def __init__( self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None self._container_image: typing.Optional[str] = None + self._override_security_context: typing.Optional[SecurityContext] = None def runs_before(self, other: Node): """ @@ -196,6 +198,14 @@ def with_overrides(self, *args, **kwargs): assert_not_promise(v, "container_image") self._container_image = v + if "secret_requests" in kwargs: + v = kwargs["secret_requests"] + assert_not_promise(v, "secret_requests") + for secret in v: + if not isinstance(secret, Secret): + raise ValueError("secret_requests should be a list of flytekit.Secret objects") + self._override_security_context = SecurityContext(secrets=v) + if "accelerator" in kwargs: v = kwargs["accelerator"] assert_not_promise(v, "accelerator") diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 44fe7e1f44..3483c21e8a 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -11,6 +11,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.models.literals import Binding as _Binding from flytekit.models.literals import RetryStrategy as _RetryStrategy +from flytekit.models.security import SecurityContext from flytekit.models.task import Resources @@ -599,10 +600,12 @@ def __init__( resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources], container_image: typing.Optional[str] = None, + override_security_context: typing.Optional[SecurityContext] = None, ): self._resources = resources self._extended_resources = extended_resources self._container_image = container_image + self._override_security_context = override_security_context @property def resources(self) -> Resources: @@ -616,11 +619,18 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources: def container_image(self) -> typing.Optional[str]: return self._container_image + @property + def override_security_context(self) -> typing.Optional[SecurityContext]: + return self._override_security_context + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, container_image=self.container_image, + override_security_context=self.override_security_context.to_flyte_idl() + if self.override_security_context + else None, ) @classmethod @@ -628,9 +638,24 @@ def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None + security_context = ( + SecurityContext.from_flyte_idl(pb2_object.override_security_context) + if pb2_object.HasField("override_security_context") + else None + ) if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources, extended_resources=extended_resources, container_image=container_image) - return cls(resources=None, extended_resources=extended_resources, container_image=container_image) + return cls( + resources=resources, + extended_resources=extended_resources, + container_image=container_image, + override_security_context=security_context, + ) + return cls( + resources=None, + extended_resources=extended_resources, + container_image=container_image, + override_security_context=security_context, + ) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b49639d23a..b656b0d259 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -482,6 +482,7 @@ def get_serializable_node( resources=entity._resources, extended_resources=entity._extended_resources, container_image=entity._container_image, + override_security_context=entity._override_security_context, ), ), ) @@ -563,6 +564,7 @@ def get_serializable_node( resources=entity._resources, extended_resources=entity._extended_resources, container_image=entity._container_image, + override_security_context=entity._override_security_context, ), ), ) @@ -616,6 +618,7 @@ def get_serializable_array_node( resources=node._resources, extended_resources=node._extended_resources, container_image=node._container_image, + override_security_context=node._override_security_context, ), ) node = workflow_model.Node(