diff --git a/flytekit/models/security.py b/flytekit/models/security.py index e210c910b7..be1c586d54 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -17,6 +17,9 @@ class Secret(_common.FlyteIdlEntity): key is optional and can be an individual secret identifier within the secret For k8s this is required version is the version of the secret. This is an optional field mount_requirement provides a hint to the system as to how the secret should be injected + env_var is optional. Custom environment name to set the value of the secret. + If mount_requirement is ENV_VAR, then the value is the secret itself. + If mount_requirement is FILE, then the value is the path to the secret file. """ class MountType(Enum): @@ -39,6 +42,7 @@ class MountType(Enum): key: Optional[str] = None group_version: Optional[str] = None mount_requirement: MountType = MountType.ANY + env_var: Optional[str] = None def __post_init__(self): from flytekit.configuration.plugin import get_plugin @@ -56,6 +60,7 @@ def to_flyte_idl(self) -> _sec.Secret: group_version=self.group_version, key=self.key, mount_requirement=self.mount_requirement.value, + env_var=self.env_var, ) @classmethod @@ -65,6 +70,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret": group_version=pb2_object.group_version if pb2_object.group_version else None, key=pb2_object.key if pb2_object.key else None, mount_requirement=Secret.MountType(pb2_object.mount_requirement), + env_var=pb2_object.env_var if pb2_object.env_var else None, ) diff --git a/pyproject.toml b/pyproject.toml index 63a2dce6a9..765d66e5e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "diskcache>=5.2.1", "docker>=4.0.0", "docstring-parser>=0.9.0", - "flyteidl>=1.14.1", + "flyteidl>=1.14.2", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", "googleapis-common-protos>=1.57", diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 82c18b3c50..ba66345deb 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -1,4 +1,5 @@ import botocore.session +import shutil from contextlib import ExitStack, contextmanager import datetime import hashlib @@ -899,3 +900,49 @@ def retry_operation(operation): remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]} + + +@pytest.fixture +def kubectl_secret(): + secret = "abc-xyz" + # Create secret + kubectl = shutil.which("kubectl") + if kubectl is None: + pytest.skip("kubectl not found") + + subprocess.run([ + kubectl, + "create", + "secret", + "-n", + "flytesnacks-development", + "generic", + "my-group", + f"--from-literal=token={secret}", + ], capture_output=True, text=True) + yield secret + + # Remove secret + subprocess.run([ + kubectl, + "delete", + "secrets", + "-n", + "flytesnacks-development", + "my-group", + ], capture_output=True, text=True) + + +# To enable this test, kubectl must be available. +@pytest.mark.skip(reason="Waiting for flyte release that includes https://github.com/flyteorg/flyte/pull/6176") +@pytest.mark.parametrize("task", ["get_secret_env_var", "get_secret_file"]) +def test_check_secret(kubectl_secret, task): + execution_id = run("get_secret.py", task) + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution=execution) + assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, ( + f"Execution failed with phase: {execution.closure.phase}" + ) + assert execution.outputs['o0'] == kubectl_secret diff --git a/tests/flytekit/integration/remote/workflows/basic/get_secret.py b/tests/flytekit/integration/remote/workflows/basic/get_secret.py new file mode 100644 index 0000000000..a7d7ecb488 --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/get_secret.py @@ -0,0 +1,26 @@ +from flytekit import task, Secret, workflow +from os import getenv + +secret_env_var = Secret( + group="my-group", + key="token", + env_var="MY_SECRET", + mount_requirement=Secret.MountType.ENV_VAR, +) +secret_env_file = Secret( + group="my-group", + key="token", + env_var="MY_SECRET_FILE", + mount_requirement=Secret.MountType.FILE, +) + + +@task(secret_requests=[secret_env_var]) +def get_secret_env_var() -> str: + return getenv("MY_SECRET", "") + + +@task(secret_requests=[secret_env_file]) +def get_secret_file() -> str: + with open(getenv("MY_SECRET_FILE")) as f: + return f.read()