Skip to content

Commit

Permalink
Enable secret_requests overriding
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli <[email protected]>
  • Loading branch information
Mecoli authored and Mecoli committed Apr 16, 2024
1 parent 87bbbb4 commit ab9c8cb
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
10 changes: 10 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.security import Secret, SecurityContext
from flytekit.models.task import Resources as _resources_model


Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None
self._container_image: typing.Optional[str] = None
self._security_context: typing.Optional[SecurityContext] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -196,6 +198,14 @@ def with_overrides(self, *args, **kwargs):
assert_not_promise(v, "container_image")
self._container_image = v

if "secret_requests" in kwargs:
v = kwargs["secret_requests"]
assert_not_promise(v, "secret_requests")
for secret in v:
if not isinstance(secret, Secret):
raise ValueError("secret_requests should be a list of flytekit.Secret objects")
self._security_context = SecurityContext(secrets=v)

if "accelerator" in kwargs:
v = kwargs["accelerator"]
assert_not_promise(v, "accelerator")
Expand Down
31 changes: 28 additions & 3 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from flytekit.models.core import identifier as _identifier
from flytekit.models.literals import Binding as _Binding
from flytekit.models.literals import RetryStrategy as _RetryStrategy
from flytekit.models.security import SecurityContext
from flytekit.models.task import Resources


Expand Down Expand Up @@ -599,10 +600,12 @@ def __init__(
resources: typing.Optional[Resources],
extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
container_image: typing.Optional[str] = None,
security_context: typing.Optional[SecurityContext] = None,
):
self._resources = resources
self._extended_resources = extended_resources
self._container_image = container_image
self._security_context = security_context

@property
def resources(self) -> Resources:
Expand All @@ -616,21 +619,43 @@ def extended_resources(self) -> tasks_pb2.ExtendedResources:
def container_image(self) -> typing.Optional[str]:
return self._container_image

@property
def security_context(self) -> typing.Optional[SecurityContext]:
return self._security_context

def to_flyte_idl(self):
return _core_workflow.TaskNodeOverrides(
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
extended_resources=self.extended_resources,
container_image=self.container_image,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
resources = Resources.from_flyte_idl(pb2_object.resources)
extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None
container_image = (
pb2_object.container_image if len(pb2_object.container_image) > 0 else None
) #! What if container_image is None?
security_context = (
SecurityContext.from_flyte_idl(pb2_object.security_context)
if pb2_object.HasField("security_context")
else None
)
if bool(resources.requests) or bool(resources.limits):
return cls(resources=resources, extended_resources=extended_resources, container_image=container_image)
return cls(resources=None, extended_resources=extended_resources, container_image=container_image)
return cls(
resources=resources,
extended_resources=extended_resources,
container_image=container_image,
security_context=security_context,
)
return cls(
resources=None,
extended_resources=extended_resources,
container_image=container_image,
security_context=security_context,
)


class TaskNode(_common.FlyteIdlEntity):
Expand Down
3 changes: 3 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def get_serializable_node(
resources=entity._resources,
extended_resources=entity._extended_resources,
container_image=entity._container_image,
security_context=entity._security_context,
),
),
)
Expand Down Expand Up @@ -563,6 +564,7 @@ def get_serializable_node(
resources=entity._resources,
extended_resources=entity._extended_resources,
container_image=entity._container_image,
security_context=entity._security_context,
),
),
)
Expand Down Expand Up @@ -616,6 +618,7 @@ def get_serializable_array_node(
resources=node._resources,
extended_resources=node._extended_resources,
container_image=node._container_image,
security_context=node._security_context,
),
)
node = workflow_model.Node(
Expand Down

0 comments on commit ab9c8cb

Please sign in to comment.