diff --git a/flytekit/core/node.py b/flytekit/core/node.py index e31e9e5f56f..dba183fed72 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -11,6 +11,7 @@ from flytekit.loggers import logger from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model +from flytekit.models.security import Secret, SecurityContext from flytekit.models.task import Resources as _resources_model @@ -66,6 +67,7 @@ def __init__( self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None self._container_image: typing.Optional[str] = None + self._security_context: typing.Optional[SecurityContext] = None def runs_before(self, other: Node): """ @@ -196,6 +198,14 @@ def with_overrides(self, *args, **kwargs): assert_not_promise(v, "container_image") self._container_image = v + if "secret_requests" in kwargs: + v = kwargs["secret_requests"] + assert_not_promise(v, "secret_requests") + for secret in v: + if not isinstance(secret, Secret): + raise ValueError("secret_requests should be a list of flytekit.Secret objects") + self._security_context = SecurityContext(secrets=v) + if "accelerator" in kwargs: v = kwargs["accelerator"] assert_not_promise(v, "accelerator") diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 44fe7e1f444..38c8f0c5c1f 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -11,6 +11,7 @@ from flytekit.models.core import identifier as _identifier from flytekit.models.literals import Binding as _Binding from flytekit.models.literals import RetryStrategy as _RetryStrategy +from flytekit.models.security import SecurityContext from flytekit.models.task import Resources @@ -599,10 +600,12 @@ def __init__( resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources], container_image: typing.Optional[str] = None, + security_context: typing.Optional[SecurityContext] = None, ): self._resources = resources self._extended_resources = extended_resources self._container_image = container_image + self._security_context = security_context @property def resources(self) -> Resources: @@ -616,21 +619,43 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources: def container_image(self) -> typing.Optional[str]: return self._container_image + @property + def security_context(self) -> typing.Optional[SecurityContext]: + return self._security_context + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, container_image=self.container_image, + security_context=self.security_context.to_flyte_idl() if self.security_context else None, ) @classmethod def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None - container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None + container_image = ( + pb2_object.container_image if len(pb2_object.container_image) > 0 else None + ) #! What if container_image is None? + security_context = ( + SecurityContext.from_flyte_idl(pb2_object.security_context) + if pb2_object.HasField("security_context") + else None + ) if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources, extended_resources=extended_resources, container_image=container_image) - return cls(resources=None, extended_resources=extended_resources, container_image=container_image) + return cls( + resources=resources, + extended_resources=extended_resources, + container_image=container_image, + security_context=security_context, + ) + return cls( + resources=None, + extended_resources=extended_resources, + container_image=container_image, + security_context=security_context, + ) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b49639d23aa..70dd30bb9ab 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -482,6 +482,7 @@ def get_serializable_node( resources=entity._resources, extended_resources=entity._extended_resources, container_image=entity._container_image, + security_context=entity._security_context, ), ), ) @@ -563,6 +564,7 @@ def get_serializable_node( resources=entity._resources, extended_resources=entity._extended_resources, container_image=entity._container_image, + security_context=entity._security_context, ), ), ) @@ -616,6 +618,7 @@ def get_serializable_array_node( resources=node._resources, extended_resources=node._extended_resources, container_image=node._container_image, + security_context=node._security_context, ), ) node = workflow_model.Node(