diff --git a/src/integrations/prefect-kubernetes/README.md b/src/integrations/prefect-kubernetes/README.md index 4054a6acc48e..d62c14110dca 100644 --- a/src/integrations/prefect-kubernetes/README.md +++ b/src/integrations/prefect-kubernetes/README.md @@ -111,8 +111,8 @@ def kubernetes_orchestrator(): #### Patch an existing deployment ```python -import yaml -from kubernetes.client.models import V1Deployment +from kubernetes_asyncio.client.models import V1Deployment + from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.deployments import patch_namespaced_deployment diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/credentials.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/credentials.py index c5149c6b2045..596aa06698cc 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/credentials.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/credentials.py @@ -1,12 +1,12 @@ """Module for defining Kubernetes credential handling and client generation.""" -from contextlib import contextmanager +from contextlib import asynccontextmanager from pathlib import Path -from typing import Dict, Generator, Optional, Type, Union +from typing import AsyncGenerator, Dict, Optional, Type, Union import yaml -from kubernetes import config -from kubernetes.client import ( +from kubernetes_asyncio import config +from kubernetes_asyncio.client import ( ApiClient, AppsV1Api, BatchV1Api, @@ -14,7 +14,7 @@ CoreV1Api, CustomObjectsApi, ) -from kubernetes.config.config_exception import ConfigException +from kubernetes_asyncio.config.config_exception import ConfigException from pydantic import Field, field_validator from typing_extensions import Literal, Self @@ -103,24 +103,23 @@ def from_file( # Load the entire config file config_file_contents = path.read_text() config_dict = yaml.safe_load(config_file_contents) - return cls(config=config_dict, context_name=context_name) - def get_api_client(self) -> "ApiClient": + async def get_api_client(self) -> "ApiClient": """ Returns a Kubernetes API client for this cluster config. """ - return config.kube_config.new_client_from_config_dict( + return await config.kube_config.new_client_from_config_dict( config_dict=self.config, context=self.context_name ) - def configure_client(self) -> None: + async def configure_client(self) -> None: """ Activates this cluster configuration by loading the configuration into the Kubernetes Python client. After calling this, Kubernetes API clients can use this config's context. """ - config.kube_config.load_kube_config_from_dict( + await config.kube_config.load_kube_config_from_dict( config_dict=self.config, context=self.context_name ) @@ -147,12 +146,12 @@ class KubernetesCredentials(Block): cluster_config: Optional[KubernetesClusterConfig] = None - @contextmanager - def get_client( + @asynccontextmanager + async def get_client( self, client_type: Literal["apps", "batch", "core", "custom_objects"], configuration: Optional[Configuration] = None, - ) -> Generator[KubernetesClient, None, None]: + ) -> AsyncGenerator[KubernetesClient, None]: """Convenience method for retrieving a Kubernetes API client for deployment resources. 
Args: @@ -165,22 +164,37 @@ def get_client( ```python from prefect_kubernetes.credentials import KubernetesCredentials - with KubernetesCredentials.get_client("core") as core_v1_client: - for pod in core_v1_client.list_namespaced_pod(): + kubernetes_credentials = KubernetesCredentials.load("k8s-creds") # illustrative block name + async with kubernetes_credentials.get_client("core") as core_v1_client: + pods = await core_v1_client.list_namespaced_pod() + for pod in pods.items: print(pod.metadata.name) ``` """ - client_config = configuration or Configuration() + client_configuration = configuration or Configuration() + if self.cluster_config: + config_dict = self.cluster_config.config + context = self.cluster_config.context_name + + # Use Configuration to load configuration from a dictionary - with ApiClient(configuration=client_config) as generic_client: + await config.load_kube_config_from_dict( + config_dict=config_dict, + context=context, + client_configuration=client_configuration, + ) + async with ApiClient(configuration=client_configuration) as api_client: try: - yield self.get_resource_specific_client(client_type) + yield await self.get_resource_specific_client( + client_type, api_client=api_client + ) finally: - generic_client.rest_client.pool_manager.clear() + await api_client.close() - def get_resource_specific_client( + async def get_resource_specific_client( self, client_type: str, + api_client: ApiClient, ) -> Union[AppsV1Api, BatchV1Api, CoreV1Api]: """ Utility function for configuring a generic Kubernetes client. @@ -209,15 +222,15 @@ def get_resource_specific_client( """ if self.cluster_config: - self.cluster_config.configure_client() + await self.cluster_config.configure_client() else: try: config.load_incluster_config() except ConfigException: - config.load_kube_config() + await config.load_kube_config() try: - return K8S_CLIENT_TYPES[client_type]() + return K8S_CLIENT_TYPES[client_type](api_client) except KeyError: raise ValueError( f"Invalid client type provided '{client_type}'." 
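The credentials change above is the pivot of the whole migration: `get_client` becomes an `asynccontextmanager` over a `kubernetes_asyncio` `ApiClient`, and every resource-specific client call returns a coroutine. A minimal sketch of what calling code looks like after this change — assuming a `KubernetesCredentials` block has been saved under the illustrative name `"k8s-creds"`:

```python
import asyncio

from prefect_kubernetes.credentials import KubernetesCredentials


async def main():
    # Load a previously saved block document; "k8s-creds" is a placeholder name.
    credentials = await KubernetesCredentials.load("k8s-creds")

    # get_client is now an async context manager, and the resource-specific
    # client methods return coroutines that must be awaited.
    async with credentials.get_client("core") as core_v1_client:
        pods = await core_v1_client.list_namespaced_pod(namespace="default")
        for pod in pods.items:
            print(pod.metadata.name)


if __name__ == "__main__":
    asyncio.run(main())
```
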
diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/custom_objects.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/custom_objects.py index 926179fdb19e..ce2deb2a6a9c 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/custom_objects.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/custom_objects.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Optional from prefect import task -from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials @@ -55,9 +54,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.create_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.create_namespaced_custom_object( group=group, version=version, plural=plural, @@ -113,9 +113,10 @@ def kubernetes_orchestrator(): ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.delete_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.delete_namespaced_custom_object( group=group, version=version, plural=plural, @@ -172,9 +173,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.get_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.get_namespaced_custom_object( group=group, version=version, plural=plural, @@ -230,9 +232,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.get_namespaced_custom_object_status, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.get_namespaced_custom_object_status( group=group, version=version, plural=plural, @@ -284,9 +287,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.list_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.list_namespaced_custom_object( group=group, version=version, plural=plural, @@ -354,9 +358,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - custom_objects_client.patch_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.patch_namespaced_custom_object( group=group, version=version, plural=plural, @@ -423,9 +428,10 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("custom_objects") as custom_objects_client: - return await run_sync_in_worker_thread( - 
custom_objects_client.replace_namespaced_custom_object, + async with kubernetes_credentials.get_client( + "custom_objects" + ) as custom_objects_client: + return await custom_objects_client.replace_namespaced_custom_object( group=group, version=version, plural=plural, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/deployments.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/deployments.py index faeb3e6e6b97..c06665a48845 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/deployments.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/deployments.py @@ -2,10 +2,13 @@ from typing import Any, Dict, Optional -from kubernetes.client.models import V1DeleteOptions, V1Deployment, V1DeploymentList +from kubernetes_asyncio.client.models import ( + V1DeleteOptions, + V1Deployment, + V1DeploymentList, +) from prefect import task -from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials @@ -34,7 +37,7 @@ async def create_namespaced_deployment( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.deployments import create_namespaced_deployment - from kubernetes.client.models import V1Deployment + from kubernetes_asyncio.client.models import V1Deployment @flow def kubernetes_orchestrator(): @@ -44,9 +47,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.create_namespaced_deployment, + async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.create_namespaced_deployment( namespace=namespace, body=new_deployment, **kube_kwargs, @@ -80,7 +82,7 @@ async def delete_namespaced_deployment( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.deployments import delete_namespaced_deployment - from kubernetes.client.models import V1DeleteOptions + from kubernetes_asyncio.client.models import V1DeleteOptions @flow def kubernetes_orchestrator(): @@ -91,9 +93,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.delete_namespaced_deployment, + async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.delete_namespaced_deployment( deployment_name, body=delete_options, namespace=namespace, @@ -132,9 +133,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.list_namespaced_deployment, + async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.list_namespaced_deployment( namespace=namespace, **kube_kwargs, ) @@ -167,7 +167,7 @@ async def patch_namespaced_deployment( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.deployments import patch_namespaced_deployment - from kubernetes.client.models import V1Deployment + from kubernetes_asyncio.client.models import V1Deployment @flow def kubernetes_orchestrator(): @@ -178,9 +178,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.patch_namespaced_deployment, 
+ async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.patch_namespaced_deployment( name=deployment_name, namespace=namespace, body=deployment_updates, @@ -221,9 +220,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.read_namespaced_deployment, + async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.read_namespaced_deployment( name=deployment_name, namespace=namespace, **kube_kwargs, @@ -257,7 +255,7 @@ async def replace_namespaced_deployment( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.deployments import replace_namespaced_deployment - from kubernetes.client.models import V1Deployment + from kubernetes_asyncio.client.models import V1Deployment @flow def kubernetes_orchestrator(): @@ -268,9 +266,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("apps") as apps_v1_client: - return await run_sync_in_worker_thread( - apps_v1_client.replace_namespaced_deployment, + async with kubernetes_credentials.get_client("apps") as apps_v1_client: + return await apps_v1_client.replace_namespaced_deployment( body=new_deployment, name=deployment_name, namespace=namespace, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py index 85101a2751e7..48ee0d01d50d 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py @@ -1,17 +1,11 @@ -import atexit -import threading -from typing import TYPE_CHECKING, Dict, List, Optional +import asyncio +from typing import Dict, List, Optional -from prefect.events import Event, RelatedResource, emit_event -from prefect.utilities.importtools import lazy_import +import kubernetes_asyncio +import kubernetes_asyncio.watch +from kubernetes_asyncio.client import ApiClient, V1Pod -if TYPE_CHECKING: - import kubernetes - import kubernetes.client - import kubernetes.watch - from kubernetes.client import ApiClient, V1Pod -else: - kubernetes = lazy_import("kubernetes") +from prefect.events import Event, RelatedResource, emit_event EVICTED_REASONS = { "OOMKilled", @@ -43,6 +37,7 @@ def __init__( self._job_name = job_name self._namespace = namespace self._timeout_seconds = timeout_seconds + self._task = None # All events emitted by this replicator have the pod itself as the # resource. 
The `worker_resource` is what the worker uses when it's @@ -51,29 +46,19 @@ def __init__( worker_resource["prefect.resource.role"] = "worker" worker_related_resource = RelatedResource(worker_resource) self._related_resources = related_resources + [worker_related_resource] - - self._watch = kubernetes.watch.Watch() - self._thread = threading.Thread(target=self._replicate_pod_events) - self._state = "READY" - atexit.register(self.stop) - - def __enter__(self): - """Start the replicator thread.""" - self._thread.start() + async def __aenter__(self): + """Start the Kubernetes event watcher when entering the context.""" + self._task = asyncio.create_task(self._replicate_pod_events()) self._state = "STARTED" + return self - def __exit__(self, *args, **kwargs): - """Stop the replicator thread.""" - self.stop() - - def stop(self): - """Stop watching for pod events and stop thread.""" - if self._thread.is_alive(): - self._watch.stop() - self._thread.join() - self._state = "STOPPED" + async def __aexit__(self, exc_type, exc_value, traceback): + """Stop the Kubernetes event watcher and ensure all tasks are completed before exiting the context.""" + self._state = "STOPPED" + if self._task: + self._task.cancel() def _pod_as_resource(self, pod: "V1Pod") -> Dict[str, str]: """Convert a pod to a resource dictionary""" @@ -83,14 +68,15 @@ def _pod_as_resource(self, pod: "V1Pod") -> Dict[str, str]: "kubernetes.namespace": pod.metadata.namespace, } - def _replicate_pod_events(self): + async def _replicate_pod_events(self): """Replicate Kubernetes pod events as Prefect Events.""" seen_phases = set() last_event = None - try: - core_client = kubernetes.client.CoreV1Api(api_client=self._client) - for event in self._watch.stream( + core_client = kubernetes_asyncio.client.CoreV1Api(api_client=self._client) + watch = kubernetes_asyncio.watch.Watch() + async with watch: + async for event in watch.stream( func=core_client.list_namespaced_pod, namespace=self._namespace, label_selector=f"job-name={self._job_name}", @@ -99,14 +85,14 @@ def _replicate_pod_events(self): phase = event["object"].status.phase if phase not in seen_phases: - last_event = self._emit_pod_event(event, last_event=last_event) + last_event = await self._emit_pod_event( + event, last_event=last_event + ) seen_phases.add(phase) if phase in FINAL_PHASES: - self._watch.stop() - finally: - self._client.rest_client.pool_manager.clear() + break - def _emit_pod_event( + async def _emit_pod_event( self, pod_event: Dict, last_event: Optional[Event] = None, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/exceptions.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/exceptions.py index 2d01bd7fecfb..d221f08f6d53 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/exceptions.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/exceptions.py @@ -1,6 +1,6 @@ """Module to define common exceptions within `prefect_kubernetes`.""" -from kubernetes.client.exceptions import ApiException, OpenApiException +from kubernetes_asyncio.client.exceptions import ApiException, OpenApiException class KubernetesJobDefinitionError(OpenApiException): diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py index efc360d1df55..2b3978684823 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py @@ -5,13 +5,13 @@ from typing import Any, Callable, Dict, 
Optional, Type, Union import yaml -from kubernetes.client.models import V1DeleteOptions, V1Job, V1JobList, V1Status +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Job, V1JobList, V1Status from pydantic import Field from typing_extensions import Self from prefect import task from prefect.blocks.abstract import JobBlock, JobRun -from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible +from prefect.utilities.asyncutils import sync_compatible from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.exceptions import KubernetesJobTimeoutError from prefect_kubernetes.pods import list_namespaced_pod, read_namespaced_pod_log @@ -45,7 +45,7 @@ async def create_namespaced_job( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.jobs import create_namespaced_job - from kubernetes.client.models import V1Job + from kubernetes_asyncio.client.models import V1Job @flow def kubernetes_orchestrator(): @@ -55,9 +55,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.create_namespaced_job, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.create_namespaced_job( namespace=namespace, body=new_job, **kube_kwargs, @@ -90,7 +89,7 @@ async def delete_namespaced_job( Example: Delete "my-job" in the default namespace: ```python - from kubernetes.client.models import V1DeleteOptions + from kubernetes_asyncio.client.models import V1DeleteOptions from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.jobs import delete_namespaced_job @@ -105,9 +104,8 @@ def kubernetes_orchestrator(): ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.delete_namespaced_job, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.delete_namespaced_job( name=job_name, body=delete_options, namespace=namespace, @@ -148,9 +146,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.list_namespaced_job, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.list_namespaced_job( namespace=namespace, **kube_kwargs, ) @@ -188,7 +185,7 @@ async def patch_namespaced_job( from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.jobs import patch_namespaced_job - from kubernetes.client.models import V1Job + from kubernetes_asyncio.client.models import V1Job @flow def kubernetes_orchestrator(): @@ -200,9 +197,8 @@ def kubernetes_orchestrator(): ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.patch_namespaced_job, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.patch_namespaced_job( name=job_name, namespace=namespace, body=job_updates, @@ -248,9 +244,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.read_namespaced_job, + async with kubernetes_credentials.get_client("batch") 
as batch_v1_client: + return await batch_v1_client.read_namespaced_job( name=job_name, namespace=namespace, **kube_kwargs, @@ -295,9 +290,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.replace_namespaced_job, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.replace_namespaced_job( name=job_name, body=new_job, namespace=namespace, @@ -340,9 +334,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("batch") as batch_v1_client: - return await run_sync_in_worker_thread( - batch_v1_client.read_namespaced_job_status, + async with kubernetes_credentials.get_client("batch") as batch_v1_client: + return await batch_v1_client.read_namespaced_job_status( name=job_name, namespace=namespace, **kube_kwargs, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/pods.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/pods.py index 4877d57733d9..e4006a11075e 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/pods.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/pods.py @@ -2,11 +2,10 @@ from typing import Any, Callable, Dict, Optional, Union -from kubernetes.client.models import V1DeleteOptions, V1Pod, V1PodList -from kubernetes.watch import Watch +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Pod, V1PodList +from kubernetes_asyncio.watch import Watch from prefect import task -from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials @@ -35,7 +34,7 @@ async def create_namespaced_pod( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.pods import create_namespaced_pod - from kubernetes.client.models import V1Pod + from kubernetes_asyncio.client.models import V1Pod @flow def kubernetes_orchestrator(): @@ -45,9 +44,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.create_namespaced_pod, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.create_namespaced_pod( namespace=namespace, body=new_pod, **kube_kwargs, @@ -81,7 +79,7 @@ async def delete_namespaced_pod( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.pods import delete_namespaced_pod - from kubernetes.client.models import V1DeleteOptions + from kubernetes_asyncio.client.models import V1DeleteOptions @flow def kubernetes_orchestrator(): @@ -92,9 +90,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.delete_namespaced_pod, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.delete_namespaced_pod( pod_name, body=delete_options, namespace=namespace, @@ -133,9 +130,9 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.list_namespaced_pod, namespace=namespace, **kube_kwargs + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.list_namespaced_pod( + 
namespace=namespace, **kube_kwargs ) @@ -166,7 +163,7 @@ async def patch_namespaced_pod( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.pods import patch_namespaced_pod - from kubernetes.client.models import V1Pod + from kubernetes_asyncio.client.models import V1Pod @flow def kubernetes_orchestrator(): @@ -177,9 +174,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.patch_namespaced_pod, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.patch_namespaced_pod( name=pod_name, namespace=namespace, body=pod_updates, @@ -220,9 +216,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.read_namespaced_pod, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.read_namespaced_pod( name=pod_name, namespace=namespace, **kube_kwargs, @@ -277,20 +272,20 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: + async with kubernetes_credentials.get_client("core") as core_v1_client: if print_func is not None: # should no longer need to manually refresh on ApiException.status == 410 # as of https://github.com/kubernetes-client/python-base/pull/133 - for log_line in Watch().stream( - core_v1_client.read_namespaced_pod_log, - name=pod_name, - namespace=namespace, - container=container, - ): - print_func(log_line) - - return await run_sync_in_worker_thread( - core_v1_client.read_namespaced_pod_log, + async with Watch() as watch: + async for log_line in watch.stream( + core_v1_client.read_namespaced_pod_log, + name=pod_name, + namespace=namespace, + container=container, + ): + print_func(log_line) + + return await core_v1_client.read_namespaced_pod_log( name=pod_name, namespace=namespace, container=container, @@ -325,7 +320,7 @@ async def replace_namespaced_pod( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.pods import replace_namespaced_pod - from kubernetes.client.models import V1Pod + from kubernetes_asyncio.client.models import V1Pod @flow def kubernetes_orchestrator(): @@ -336,9 +331,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.replace_namespaced_pod, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.replace_namespaced_pod( body=new_pod, name=pod_name, namespace=namespace, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/services.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/services.py index a6d422aae5da..5ef7eaa9b593 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/services.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/services.py @@ -2,10 +2,9 @@ from typing import Any, Dict, Optional -from kubernetes.client.models import V1DeleteOptions, V1Service, V1ServiceList +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Service, V1ServiceList from prefect import task -from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect_kubernetes.credentials import KubernetesCredentials @@ -34,7 +33,7 @@ 
async def create_namespaced_service( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.services import create_namespaced_service - from kubernetes.client.models import V1Service + from kubernetes_asyncio.client.models import V1Service @flow def create_service_flow(): @@ -44,9 +43,8 @@ def create_service_flow(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.create_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.create_namespaced_service( body=new_service, namespace=namespace, **kube_kwargs, @@ -90,9 +88,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.delete_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.delete_namespaced_service( name=service_name, namespace=namespace, body=delete_options, @@ -131,9 +128,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.list_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.list_namespaced_service( namespace=namespace, **kube_kwargs, ) @@ -165,7 +161,7 @@ async def patch_namespaced_service( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.services import patch_namespaced_service - from kubernetes.client.models import V1Service + from kubernetes_asyncio.client.models import V1Service @flow def kubernetes_orchestrator(): @@ -177,9 +173,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.patch_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.patch_namespaced_service( name=service_name, body=service_updates, namespace=namespace, @@ -221,9 +216,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.read_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.read_namespaced_service( name=service_name, namespace=namespace, **kube_kwargs, @@ -256,7 +250,7 @@ async def replace_namespaced_service( from prefect import flow from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.services import replace_namespaced_service - from kubernetes.client.models import V1Service + from kubernetes_asyncio.client.models import V1Service @flow def kubernetes_orchestrator(): @@ -268,9 +262,8 @@ def kubernetes_orchestrator(): ) ``` """ - with kubernetes_credentials.get_client("core") as core_v1_client: - return await run_sync_in_worker_thread( - core_v1_client.replace_namespaced_service, + async with kubernetes_credentials.get_client("core") as core_v1_client: + return await core_v1_client.replace_namespaced_service( name=service_name, body=new_service, namespace=namespace, diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py 
b/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py index 6d53b68f5e32..879b8cb3e030 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py @@ -4,7 +4,7 @@ import sys from typing import Optional, TypeVar -from kubernetes.client import ApiClient +from kubernetes_asyncio.client import ApiClient from slugify import slugify # Note: `dict(str, str)` is the Kubernetes API convention for diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py index abb2973260ec..5afd34b9756b 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py @@ -105,22 +105,43 @@ import enum import json import logging -import math import os import shlex -import time -from contextlib import contextmanager +from contextlib import asynccontextmanager from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union +from typing import ( + Any, + AsyncGenerator, + Dict, + Optional, + Tuple, +) +import aiohttp import anyio.abc -from kubernetes.client.exceptions import ApiException -from kubernetes.client.models import V1ObjectMeta, V1Secret +import kubernetes_asyncio +from kubernetes_asyncio import config +from kubernetes_asyncio.client import ( + ApiClient, + BatchV1Api, + Configuration, + CoreV1Api, + V1Job, + V1Pod, +) +from kubernetes_asyncio.client.exceptions import ApiException +from kubernetes_asyncio.client.models import ( + CoreV1Event, + CoreV1EventList, + V1ObjectMeta, + V1Secret, +) from pydantic import Field, model_validator from tenacity import retry, stop_after_attempt, wait_fixed, wait_random from typing_extensions import Literal, Self import prefect +from prefect.client.schemas import FlowRun from prefect.exceptions import ( InfrastructureError, InfrastructureNotAvailable, @@ -128,11 +149,10 @@ ) from prefect.server.schemas.core import Flow from prefect.server.schemas.responses import DeploymentResponse -from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.dockerutils import get_prefect_image_name -from prefect.utilities.importtools import lazy_import from prefect.utilities.pydantic import JsonPatch from prefect.utilities.templating import find_placeholders +from prefect.utilities.timeout import timeout_async from prefect.workers.base import ( BaseJobConfiguration, BaseVariables, @@ -145,21 +165,8 @@ _slugify_label_key, _slugify_label_value, _slugify_name, - enable_socket_keep_alive, ) -if TYPE_CHECKING: - import kubernetes - import kubernetes.client - import kubernetes.client.exceptions - import kubernetes.config - import kubernetes.watch - from kubernetes.client import ApiClient, BatchV1Api, CoreV1Api, V1Job, V1Pod - - from prefect.client.schemas import FlowRun -else: - kubernetes = lazy_import("kubernetes") - MAX_ATTEMPTS = 3 RETRY_MIN_DELAY_SECONDS = 1 RETRY_MIN_DELAY_JITTER_SECONDS = 0 @@ -343,6 +350,7 @@ def prepare_for_flow_run( preparation. flow: The flow associated with the flow run used for preparation. 
""" + super().prepare_for_flow_run(flow_run, deployment, flow) # Update configuration env and job manifest env self._update_prefect_api_url_if_local_server() @@ -575,21 +583,17 @@ async def run( final state of the flow run """ logger = self.get_flow_run_logger(flow_run) - - with self._get_configured_kubernetes_client(configuration) as client: + async with self._get_configured_kubernetes_client(configuration) as client: logger.info("Creating Kubernetes job...") - job = await run_sync_in_worker_thread( - self._create_job, configuration, client - ) - pid = await run_sync_in_worker_thread( - self._get_infrastructure_pid, job, client - ) + + job = await self._create_job(configuration, client) + + pid = await self._get_infrastructure_pid(job, client) # Indicate that the job has started if task_status is not None: task_status.started(pid) # Monitor the job until completion - events_replicator = KubernetesEventsReplicator( client=client, job_name=job.metadata.name, @@ -600,11 +604,11 @@ async def run( ), timeout_seconds=configuration.pod_watch_timeout_seconds, ) - - with events_replicator: - status_code = await run_sync_in_worker_thread( - self._watch_job, logger, job.metadata.name, configuration, client + async with events_replicator: + status_code = await self._watch_job( + logger, job.metadata.name, configuration, client ) + return KubernetesWorkerResult(identifier=pid, status_code=status_code) async def kill_infrastructure( @@ -614,12 +618,10 @@ async def kill_infrastructure( grace_seconds: int = 30, ): """ - Stops a job for a cancelled flow run based on the provided infrastructure PID - and run configuration. - """ - await run_sync_in_worker_thread( - self._stop_job, infrastructure_pid, configuration, grace_seconds - ) + Stops a job for a cancelled flow run based on the provided infrastructure PID + and run configuration. + att""" + await self._stop_job(infrastructure_pid, configuration, grace_seconds) async def teardown(self, *exc_info): await super().teardown(*exc_info) @@ -628,33 +630,27 @@ async def teardown(self, *exc_info): async def _clean_up_created_secrets(self): """Deletes any secrets created during the worker's operation.""" - coros = [] for key, configuration in self._created_secrets.items(): - with self._get_configured_kubernetes_client(configuration) as client: - with self._get_core_client(client) as core_client: - coros.append( - run_sync_in_worker_thread( - core_client.delete_namespaced_secret, - name=key[0], - namespace=key[1], - ) - ) - - results = await asyncio.gather(*coros, return_exceptions=True) - for result in results: - if isinstance(result, Exception): - self._logger.warning( - "Failed to delete created secret with exception: %s", result + async with self._get_configured_kubernetes_client(configuration) as client: + v1 = CoreV1Api(client) + result = await v1.delete_namespaced_secret( + name=key[0], + namespace=key[1], ) - def _stop_job( + if isinstance(result, Exception): + self._logger.warning( + "Failed to delete created secret with exception: %s", result + ) + + async def _stop_job( self, infrastructure_pid: str, configuration: KubernetesWorkerJobConfiguration, grace_seconds: int = 30, ): """Removes the given Job from the Kubernetes cluster""" - with self._get_configured_kubernetes_client(configuration) as client: + async with self._get_configured_kubernetes_client(configuration) as client: job_cluster_uid, job_namespace, job_name = self._parse_infrastructure_pid( infrastructure_pid ) @@ -667,16 +663,16 @@ def _stop_job( "deployment configuration." 
) - current_cluster_uid = self._get_cluster_uid(client) + current_cluster_uid = await self._get_cluster_uid(client) if job_cluster_uid != current_cluster_uid: raise InfrastructureNotAvailable( f"Unable to kill job {job_name!r}: The job is running on another " "cluster than the one specified by the infrastructure PID." ) - with self._get_batch_client(client) as batch_client: + async with self._get_batch_client(client) as batch_client: try: - batch_client.delete_namespaced_job( + await batch_client.delete_namespaced_job( name=job_name, namespace=job_namespace, grace_period_seconds=grace_seconds, @@ -686,7 +682,7 @@ def _stop_job( # See: https://kubernetes.io/docs/concepts/architecture/garbage-collection/#foreground-deletion # noqa propagation_policy="Foreground", ) - except kubernetes.client.exceptions.ApiException as exc: + except kubernetes_asyncio.client.exceptions.ApiException as exc: if exc.status == 404: raise InfrastructureNotFound( f"Unable to kill job {job_name!r}: The job was not found." @@ -694,42 +690,41 @@ def _stop_job( else: raise - @contextmanager - def _get_configured_kubernetes_client( + @asynccontextmanager + async def _get_configured_kubernetes_client( self, configuration: KubernetesWorkerJobConfiguration - ) -> Generator["ApiClient", None, None]: + ) -> AsyncGenerator["ApiClient", None]: """ Returns a configured Kubernetes client. """ - + client = None + + if configuration.cluster_config: + config_dict = configuration.cluster_config.config + context = configuration.cluster_config.context_name + + # Use Configuration to load configuration from a dictionary + client_configuration = Configuration() + await config.load_kube_config_from_dict( + config_dict=config_dict, + context=context, + client_configuration=client_configuration, + ) + client = ApiClient(configuration=client_configuration) + else: + # Try to load in-cluster configuration + try: + config.load_incluster_config() + client = ApiClient() + except config.ConfigException: + # If in-cluster config fails, load the local kubeconfig + client = await config.new_client_from_config() try: - if configuration.cluster_config: - client = kubernetes.config.new_client_from_config_dict( - config_dict=configuration.cluster_config.config, - context=configuration.cluster_config.context_name, - ) - else: - # If no hardcoded config specified, try to load Kubernetes configuration - # within a cluster. If that doesn't work, try to load the configuration - # from the local environment, allowing any further ConfigExceptions to - # bubble up. 
- try: - kubernetes.config.load_incluster_config() - config = kubernetes.client.Configuration.get_default_copy() - client = kubernetes.client.ApiClient(configuration=config) - except kubernetes.config.ConfigException: - client = kubernetes.config.new_client_from_config() - - if os.environ.get( - "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE" - ).strip().lower() in ("true", "1"): - enable_socket_keep_alive(client) - yield client finally: - client.rest_client.pool_manager.clear() + await client.close() - def _replace_api_key_with_secret( + async def _replace_api_key_with_secret( self, configuration: KubernetesWorkerJobConfiguration, client: "ApiClient" ): """Replaces the PREFECT_API_KEY environment variable with a Kubernetes secret""" @@ -747,7 +742,7 @@ def _replace_api_key_with_secret( api_key = manifest_api_key_env.get("value") if api_key: secret_name = f"prefect-{_slugify_name(self.name)}-api-key" - secret = self._upsert_secret( + secret = await self._upsert_secret( name=secret_name, value=api_key, namespace=configuration.namespace, @@ -779,7 +774,7 @@ def _replace_api_key_with_secret( ), reraise=True, ) - def _create_job( + async def _create_job( self, configuration: KubernetesWorkerJobConfiguration, client: "ApiClient" ) -> "V1Job": """ @@ -788,15 +783,17 @@ def _create_job( if os.environ.get( "PREFECT_KUBERNETES_WORKER_STORE_PREFECT_API_IN_SECRET", "" ).strip().lower() in ("true", "1"): - self._replace_api_key_with_secret( + await self._replace_api_key_with_secret( configuration=configuration, client=client ) + try: - with self._get_batch_client(client) as batch_client: - job = batch_client.create_namespaced_job( - configuration.namespace, configuration.job_manifest - ) - except kubernetes.client.exceptions.ApiException as exc: + batch_client = BatchV1Api(client) + job = await batch_client.create_namespaced_job( + configuration.namespace, + configuration.job_manifest, + ) + except kubernetes_asyncio.client.exceptions.ApiException as exc: # Parse the reason and message from the response if feasible message = "" if exc.reason: @@ -810,56 +807,56 @@ def _create_job( return job - def _upsert_secret( + async def _upsert_secret( self, name: str, value: str, namespace: str, client: "ApiClient" ): encoded_value = base64.b64encode(value.encode("utf-8")).decode("utf-8") - with self._get_core_client(client) as core_client: - try: - # Get the current version of the Secret and update it with the - # new value - current_secret = core_client.read_namespaced_secret( - name=name, namespace=namespace - ) - current_secret.data = {"value": encoded_value} - secret = core_client.replace_namespaced_secret( - name=name, namespace=namespace, body=current_secret - ) - except ApiException as exc: - if exc.status != 404: - raise - # Create the secret if it doesn't already exist - metadata = V1ObjectMeta(name=name, namespace=namespace) - secret = V1Secret( - api_version="v1", - kind="Secret", - metadata=metadata, - data={"value": encoded_value}, - ) - secret = core_client.create_namespaced_secret( - namespace=namespace, body=secret - ) - return secret + core_client = CoreV1Api(client) + try: + # Get the current version of the Secret and update it with the + # new value + current_secret = await core_client.read_namespaced_secret( + name=name, namespace=namespace + ) + current_secret.data = {"value": encoded_value} + secret = await core_client.replace_namespaced_secret( + name=name, namespace=namespace, body=current_secret + ) + except ApiException as exc: + if exc.status != 404: + raise + # Create the 
secret if it doesn't already exist + metadata = V1ObjectMeta(name=name, namespace=namespace) + secret = V1Secret( + api_version="v1", + kind="Secret", + metadata=metadata, + data={"value": encoded_value}, + ) + secret = await core_client.create_namespaced_secret( + namespace=namespace, body=secret + ) + return secret - @contextmanager - def _get_batch_client( + @asynccontextmanager + async def _get_batch_client( self, client: "ApiClient" - ) -> Generator["BatchV1Api", None, None]: + ) -> AsyncGenerator["BatchV1Api", None]: """ Context manager for retrieving a Kubernetes batch client. """ try: - yield kubernetes.client.BatchV1Api(api_client=client) + yield BatchV1Api(api_client=client) finally: - client.rest_client.pool_manager.clear() + await client.close() - def _get_infrastructure_pid(self, job: "V1Job", client: "ApiClient") -> str: + async def _get_infrastructure_pid(self, job: "V1Job", client: "ApiClient") -> str: """ Generates a Kubernetes infrastructure PID. The PID is in the format: "::". """ - cluster_uid = self._get_cluster_uid(client) + cluster_uid = await self._get_cluster_uid(client) pid = f"{cluster_uid}:{job.metadata.namespace}:{job.metadata.name}" return pid @@ -874,19 +871,7 @@ def _parse_infrastructure_pid( cluster_uid, namespace, job_name = infrastructure_pid.split(":", 2) return cluster_uid, namespace, job_name - @contextmanager - def _get_core_client( - self, client: "ApiClient" - ) -> Generator["CoreV1Api", None, None]: - """ - Context manager for retrieving a Kubernetes core client. - """ - try: - yield kubernetes.client.CoreV1Api(api_client=client) - finally: - client.rest_client.pool_manager.clear() - - def _get_cluster_uid(self, client: "ApiClient") -> str: + async def _get_cluster_uid(self, client: "ApiClient") -> str: """ Gets a unique id for the current cluster being used. @@ -906,20 +891,52 @@ def _get_cluster_uid(self, client: "ApiClient") -> str: return env_cluster_uid # Read the UID from the cluster namespace - with self._get_core_client(client) as core_client: - namespace = core_client.read_namespace("kube-system") + v1 = CoreV1Api(client) + namespace = await v1.read_namespace("kube-system") cluster_uid = namespace.metadata.uid - return cluster_uid - def _job_events( + async def _stream_job_logs( + self, + logger: logging.Logger, + pod_name: str, + job_name: str, + configuration: KubernetesWorkerJobConfiguration, + client, + ): + timeout = aiohttp.ClientTimeout(total=None) + core_client = CoreV1Api(client) + + logs = await core_client.read_namespaced_pod_log( + pod_name, + configuration.namespace, + follow=True, + _preload_content=False, + container="prefect-job", + _request_timeout=timeout, + ) + try: + async for line in logs.content: + if not line: + break + print(line.decode().rstrip()) + except Exception: + logger.warning( + ( + "Error occurred while streaming logs - " + "Job will continue to run but logs will " + "no longer be streamed to stdout." + ), + exc_info=True, + ) + + async def _job_events( self, - watch: kubernetes.watch.Watch, - batch_client: kubernetes.client.BatchV1Api, + batch_client: BatchV1Api, job_name: str, namespace: str, watch_kwargs: dict, - ) -> Generator[Union[Any, dict, str], Any, None]: + ): """ Stream job events. 
@@ -928,25 +945,77 @@ def _job_events( See https://kubernetes.io/docs/reference/using-api/api-concepts/#efficient-detection-of-changes # noqa """ - while True: - try: - return watch.stream( - func=batch_client.list_namespaced_job, - namespace=namespace, - field_selector=f"metadata.name={job_name}", - **watch_kwargs, - ) - except ApiException as e: - if e.status == 410: - job_list = batch_client.list_namespaced_job( - namespace=namespace, field_selector=f"metadata.name={job_name}" - ) - resource_version = job_list.metadata.resource_version - watch_kwargs["resource_version"] = resource_version - else: - raise + watch = kubernetes_asyncio.watch.Watch() + resource_version = None + async with watch: + while True: + try: + async for event in watch.stream( + func=batch_client.list_namespaced_job, + namespace=namespace, + field_selector=f"metadata.name={job_name}", + **watch_kwargs, + ): + yield event + except ApiException as e: + if e.status == 410: + job_list = await batch_client.list_namespaced_job( + namespace=namespace, + field_selector=f"metadata.name={job_name}", + ) + + resource_version = job_list.metadata.resource_version + watch_kwargs["resource_version"] = resource_version + else: + raise - def _watch_job( + async def _monitor_job_events(self, batch_client, job_name, logger, configuration): + job = await batch_client.read_namespaced_job( + name=job_name, namespace=configuration.namespace + ) + completed = job.status.completion_time is not None + watch_kwargs = ( + {"timeout_seconds": configuration.job_watch_timeout_seconds} + if configuration.job_watch_timeout_seconds + else {} + ) + + while not completed: + async for event in self._job_events( + batch_client, + job_name, + configuration.namespace, + watch_kwargs, + ): + if event["type"] == "DELETED": + logger.error(f"Job {job_name!r}: Job has been deleted.") + completed = True + elif event["object"].status.completion_time: + if not event["object"].status.succeeded: + # Job failed, exit while loop and return pod exit code + logger.error(f"Job {job_name!r}: Job failed.") + completed = True + # Check if the job has reached its backoff limit + # and stop watching if it has + elif ( + event["object"].spec.backoff_limit is not None + and event["object"].status.failed is not None + and event["object"].status.failed + > event["object"].spec.backoff_limit + ): + logger.error(f"Job {job_name!r}: Job reached backoff limit.") + completed = True + # If the job has no backoff limit, check if it has failed + # and stop watching if it has + elif ( + not event["object"].spec.backoff_limit + and event["object"].status.failed + ): + completed = True + if completed: + break + + async def _watch_job( self, logger: logging.Logger, job_name: str, @@ -958,119 +1027,58 @@ def _watch_job( Return the final status code of the first container. 
""" + logger.debug(f"Job {job_name!r}: Monitoring job...") - job = self._get_job(logger, job_name, configuration, client) + job = await self._get_job(logger, job_name, configuration, client) if not job: return -1 - pod = self._get_job_pod(logger, job_name, configuration, client) + pod = await self._get_job_pod(logger, job_name, configuration, client) if not pod: return -1 - # Calculate the deadline before streaming output - deadline = ( - (time.monotonic() + configuration.job_watch_timeout_seconds) - if configuration.job_watch_timeout_seconds is not None - else None - ) - - if configuration.stream_output: - with self._get_core_client(client) as core_client: - logs = core_client.read_namespaced_pod_log( - pod.metadata.name, - configuration.namespace, - follow=True, - _preload_content=False, - container="prefect-job", + # Create a list of tasks to run concurrently + async with self._get_batch_client(client) as batch_client: + tasks = [ + self._monitor_job_events( + batch_client, + job_name, + logger, + configuration, ) - try: - for log in logs.stream(): - print(log.decode().rstrip()) - - # Check if we have passed the deadline and should stop streaming - # logs - remaining_time = ( - deadline - time.monotonic() if deadline else None + ] + try: + with timeout_async(seconds=configuration.job_watch_timeout_seconds): + if configuration.stream_output: + tasks.append( + self._stream_job_logs( + logger, + pod.metadata.name, + job_name, + configuration, + client, + ) ) - if deadline and remaining_time <= 0: - break - - except Exception: - logger.warning( - ( - "Error occurred while streaming logs - " - "Job will continue to run but logs will " - "no longer be streamed to stdout." - ), - exc_info=True, - ) - - with self._get_batch_client(client) as batch_client: - # Check if the job is completed before beginning a watch - job = batch_client.read_namespaced_job( - name=job_name, namespace=configuration.namespace - ) - completed = job.status.completion_time is not None - while not completed: - remaining_time = ( - math.ceil(deadline - time.monotonic()) if deadline else None + results = await asyncio.gather(*tasks, return_exceptions=True) + if any(isinstance(result, Exception) for result in results): + for result in results: + if isinstance(result, Exception): + logger.error( + f"Error during task execution: {result}", + exc_info=True, + ) + except TimeoutError: + logger.error( + f"Job {job_name!r}: Job did not complete within " + f"timeout of {configuration.job_watch_timeout_seconds}s." ) - if deadline and remaining_time <= 0: - logger.error( - f"Job {job_name!r}: Job did not complete within " - f"timeout of {configuration.job_watch_timeout_seconds}s." 
- ) - return -1 - - watch = kubernetes.watch.Watch() - - # The kubernetes library will disable retries if the timeout kwarg is - # present regardless of the value so we do not pass it unless given - # https://github.com/kubernetes-client/python/blob/84f5fea2a3e4b161917aa597bf5e5a1d95e24f5a/kubernetes/base/watch/watch.py#LL160 - watch_kwargs = {"timeout_seconds": remaining_time} if deadline else {} - - for event in self._job_events( - watch, - batch_client, - job_name, - configuration.namespace, - watch_kwargs, - ): - if event["type"] == "DELETED": - logger.error(f"Job {job_name!r}: Job has been deleted.") - completed = True - elif event["object"].status.completion_time: - if not event["object"].status.succeeded: - # Job failed, exit while loop and return pod exit code - logger.error(f"Job {job_name!r}: Job failed.") - completed = True - # Check if the job has reached its backoff limit - # and stop watching if it has - elif ( - event["object"].spec.backoff_limit is not None - and event["object"].status.failed is not None - and event["object"].status.failed - > event["object"].spec.backoff_limit - ): - logger.error(f"Job {job_name!r}: Job reached backoff limit.") - completed = True - # If the job has no backoff limit, check if it has failed - # and stop watching if it has - elif ( - not event["object"].spec.backoff_limit - and event["object"].status.failed - ): - completed = True - - if completed: - watch.stop() - break + return -1 - with self._get_core_client(client) as core_client: + core_client = CoreV1Api(client) # Get all pods for the job - pods = core_client.list_namespaced_pod( + pods = await core_client.list_namespaced_pod( namespace=configuration.namespace, label_selector=f"job-name={job_name}" ) # Get the status for only the most recently used pod @@ -1083,28 +1091,28 @@ def _watch_job( if most_recent_pod else None ) - if not first_container_status: - logger.error(f"Job {job_name!r}: No pods found for job.") - return -1 - # In some cases, such as spot instance evictions, the pod will be forcibly - # terminated and not report a status correctly. - elif ( - first_container_status.state is None - or first_container_status.state.terminated is None - or first_container_status.state.terminated.exit_code is None - ): - logger.error( - f"Could not determine exit code for {job_name!r}." - "Exit code will be reported as -1." - f"First container status info did not report an exit code." - f"First container info: {first_container_status}." - ) - return -1 + if not first_container_status: + logger.error(f"Job {job_name!r}: No pods found for job.") + return -1 + # In some cases, such as spot instance evictions, the pod will be forcibly + # terminated and not report a status correctly. + elif ( + first_container_status.state is None + or first_container_status.state.terminated is None + or first_container_status.state.terminated.exit_code is None + ): + logger.error( + f"Could not determine exit code for {job_name!r}." + "Exit code will be reported as -1." + f"First container status info did not report an exit code." + f"First container info: {first_container_status}." 
+ ) + return -1 return first_container_status.state.terminated.exit_code - def _get_job( + async def _get_job( self, logger: logging.Logger, job_id: str, @@ -1112,17 +1120,18 @@ def _get_job( client: "ApiClient", ) -> Optional["V1Job"]: """Get a Kubernetes job by id.""" - with self._get_batch_client(client) as batch_client: - try: - job = batch_client.read_namespaced_job( - name=job_id, namespace=configuration.namespace - ) - except kubernetes.client.exceptions.ApiException: - logger.error(f"Job {job_id!r} was removed.", exc_info=True) - return None - return job - def _get_job_pod( + try: + batch_client = BatchV1Api(client) + job = await batch_client.read_namespaced_job( + name=job_id, namespace=configuration.namespace + ) + except kubernetes_asyncio.client.exceptions.ApiException: + logger.error(f"Job {job_id!r} was removed.", exc_info=True) + return None + return job + + async def _get_job_pod( self, logger: logging.Logger, job_name: str, @@ -1130,14 +1139,14 @@ def _get_job_pod( client: "ApiClient", ) -> Optional["V1Pod"]: """Get the first running pod for a job.""" - from kubernetes.client.models import V1Pod - watch = kubernetes.watch.Watch() - logger.debug(f"Job {job_name!r}: Starting watch for pod start...") + watch = kubernetes_asyncio.watch.Watch() + logger.info(f"Job {job_name!r}: Starting watch for pod start...") last_phase = None last_pod_name: Optional[str] = None - with self._get_core_client(client) as core_client: - for event in watch.stream( + core_client = CoreV1Api(client) + async with watch: + async for event in watch.stream( func=core_client.list_namespaced_pod, namespace=configuration.namespace, label_selector=f"job-name={job_name}", @@ -1145,26 +1154,26 @@ def _get_job_pod( ): pod: V1Pod = event["object"] last_pod_name = pod.metadata.name - + logger.info(f"Job {job_name!r}: Pod {last_pod_name!r} has started.") phase = pod.status.phase if phase != last_phase: logger.info(f"Job {job_name!r}: Pod has status {phase!r}.") if phase != "Pending": - watch.stop() return pod last_phase = phase + # If we've gotten here, we never found the Pod that was created for the flow run + # Job, so let's inspect the situation and log what we can find. It's possible + # that the Job ran into scheduling constraints it couldn't satisfy, like + # memory/CPU requests, or a volume that wasn't available, or a node with an + # available GPU. + logger.error(f"Job {job_name!r}: Pod never started.") + await self._log_recent_events( + logger, job_name, last_pod_name, configuration, client + ) - # If we've gotten here, we never found the Pod that was created for the flow run - # Job, so let's inspect the situation and log what we can find. It's possible - # that the Job ran into scheduling constraints it couldn't satisfy, like - # memory/CPU requests, or a volume that wasn't available, or a node with an - # available GPU. 
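For reference, `kubernetes_asyncio`'s `Watch` is entered with `async with` and its `stream()` is consumed with `async for`, which removes the manual `watch.stop()` bookkeeping the old sync code needed. A minimal standalone sketch of the pattern, not code from this integration; the job name, selector, and kubeconfig assumption are illustrative:

```python
# Minimal sketch of the kubernetes_asyncio watch pattern; assumes a reachable
# kubeconfig, and the job name / label selector below are illustrative.
import asyncio

from kubernetes_asyncio import config, watch
from kubernetes_asyncio.client import ApiClient, CoreV1Api


async def wait_for_pod_start(job_name: str, namespace: str = "default"):
    await config.load_kube_config()
    async with ApiClient() as api_client:
        core_client = CoreV1Api(api_client)
        w = watch.Watch()
        async with w:
            # stream() yields watch events asynchronously; leaving the
            # `async with` block cleans up, so no manual watch.stop()
            async for event in w.stream(
                func=core_client.list_namespaced_pod,
                namespace=namespace,
                label_selector=f"job-name={job_name}",
            ):
                pod = event["object"]
                if pod.status.phase != "Pending":
                    return pod
    return None


# asyncio.run(wait_for_pod_start("my-job"))
```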
- logger.error(f"Job {job_name!r}: Pod never started.") - self._log_recent_events(logger, job_name, last_pod_name, configuration, client) - - def _log_recent_events( + async def _log_recent_events( self, logger: logging.Logger, job_name: str, @@ -1174,7 +1183,6 @@ def _log_recent_events( ) -> None: """Look for reasons why a Job may not have been able to schedule a Pod, or why a Pod may not have been able to start and log them to the provided logger.""" - from kubernetes.client.models import CoreV1Event, CoreV1EventList def best_event_time(event: CoreV1Event) -> datetime: """Choose the best timestamp from a Kubernetes event""" @@ -1200,25 +1208,27 @@ def log_event(event: CoreV1Event): event.message, ) - with self._get_core_client(client) as core_client: - events: CoreV1EventList = core_client.list_namespaced_event( - configuration.namespace - ) - event: CoreV1Event - for event in sorted(events.items, key=best_event_time): - if ( - event.involved_object.api_version == "batch/v1" - and event.involved_object.kind == "Job" - and event.involved_object.namespace == configuration.namespace - and event.involved_object.name == job_name - ): - log_event(event) - - if ( - pod_name - and event.involved_object.api_version == "v1" - and event.involved_object.kind == "Pod" - and event.involved_object.namespace == configuration.namespace - and event.involved_object.name == pod_name - ): - log_event(event) + core_client = CoreV1Api(client) + + events: CoreV1EventList = await core_client.list_namespaced_event( + configuration.namespace + ) + + event: CoreV1Event + for event in sorted(events.items, key=best_event_time): + if ( + event.involved_object.api_version == "batch/v1" + and event.involved_object.kind == "Job" + and event.involved_object.namespace == configuration.namespace + and event.involved_object.name == job_name + ): + log_event(event) + + if ( + pod_name + and event.involved_object.api_version == "v1" + and event.involved_object.kind == "Pod" + and event.involved_object.namespace == configuration.namespace + and event.involved_object.name == pod_name + ): + log_event(event) diff --git a/src/integrations/prefect-kubernetes/pyproject.toml b/src/integrations/prefect-kubernetes/pyproject.toml index 3e73bbdedc35..3ba1acae82de 100644 --- a/src/integrations/prefect-kubernetes/pyproject.toml +++ b/src/integrations/prefect-kubernetes/pyproject.toml @@ -3,6 +3,14 @@ requires = ["setuptools>=45", "wheel", "setuptools_scm>=6.2"] build-backend = "setuptools.build_meta" [project] +dependencies = [ + "prefect>=3.0.0rc1", + "kubernetes-asyncio>=29.0.0", + "tenacity>=8.2.3", + "exceptiongroup", + "pyopenssl>=24.1.0", + +] name = "prefect-kubernetes" description = "Prefect integrations for interacting with Kubernetes." 
readme = "README.md" @@ -11,41 +19,35 @@ license = { text = "Apache License 2.0" } keywords = ["prefect"] authors = [{ name = "Prefect Technologies, Inc.", email = "help@prefect.io" }] classifiers = [ - "Natural Language :: English", - "Intended Audience :: Developers", - "Intended Audience :: System Administrators", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Topic :: Software Development :: Libraries", -] -dependencies = [ - "prefect>=3.0.0rc1", - "kubernetes>=24.2.0", - "tenacity>=8.2.3", - "exceptiongroup", + "Natural Language :: English", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development :: Libraries", ] dynamic = ["version"] [project.optional-dependencies] dev = [ - "aiohttp", - "coverage", - "interrogate", - "mkdocs-gen-files", - "mkdocs-material", - "mkdocs", - "mkdocstrings[python]", - "mypy", - "pillow", - "pre-commit", - "pytest-asyncio", - "pytest", - "pytest-env", - "pytest-xdist", + "aiohttp", + "coverage", + "interrogate", + "mkdocs-gen-files", + "mkdocs-material", + "mkdocs", + "mkdocstrings[python]", + "mypy", + "pillow", + "pre-commit", + "pytest-asyncio", + "pytest", + "pytest-env", + "pytest-xdist", ] [project.urls] @@ -77,6 +79,4 @@ show_missing = true [tool.pytest.ini_options] asyncio_mode = "auto" -env = [ - "PREFECT_TEST_MODE=1", -] +env = ["PREFECT_TEST_MODE=1"] diff --git a/src/integrations/prefect-kubernetes/tests/conftest.py b/src/integrations/prefect-kubernetes/tests/conftest.py index 477440141f88..55994625f583 100644 --- a/src/integrations/prefect-kubernetes/tests/conftest.py +++ b/src/integrations/prefect-kubernetes/tests/conftest.py @@ -1,18 +1,27 @@ -from contextlib import contextmanager from pathlib import Path -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest import yaml -from kubernetes.client import AppsV1Api, BatchV1Api, CoreV1Api, CustomObjectsApi, models -from kubernetes.client.exceptions import ApiException +from kubernetes_asyncio.client import ( + AppsV1Api, + BatchV1Api, + CoreV1Api, + CustomObjectsApi, + models, +) +from kubernetes_asyncio.client.exceptions import ApiException from prefect_kubernetes.credentials import KubernetesCredentials from prefect_kubernetes.jobs import KubernetesJob from prefect.settings import PREFECT_LOGGING_TO_API_ENABLED, temporary_settings from prefect.testing.utilities import prefect_test_harness -BASEDIR = Path("tests") +BASEDIR = ( + Path.cwd() / "src" / "integrations" / "prefect-kubernetes" / "tests" + if Path.cwd().name == "prefect" + else Path.cwd() / "tests" +) GOOD_CONFIG_FILE_PATH = BASEDIR / "kube_config.yaml" @@ -21,8 +30,14 @@ def prefect_db(): """ Sets up test harness for temporary DB during test runs. 
""" - with prefect_test_harness(): - yield + try: + with prefect_test_harness(): + yield + except OSError as e: + if "Directory not empty" in str(e): + pass + else: + raise e @pytest.fixture(scope="session", autouse=True) @@ -74,31 +89,21 @@ def kubernetes_credentials(kube_config_dict): @pytest.fixture def _mock_api_app_client(monkeypatch): - app_client = MagicMock(spec=AppsV1Api) - - @contextmanager - def get_client(self, _): - yield app_client - + app_client = AsyncMock(spec=AppsV1Api) monkeypatch.setattr( - "prefect_kubernetes.credentials.KubernetesCredentials.get_client", - get_client, + "prefect_kubernetes.credentials.KubernetesCredentials.get_resource_specific_client", + app_client, ) - return app_client @pytest.fixture -def _mock_api_batch_client(monkeypatch): - batch_client = MagicMock(spec=BatchV1Api) - - @contextmanager - def get_client(self, _): - yield batch_client +async def _mock_api_batch_client(monkeypatch): + batch_client = AsyncMock(spec=BatchV1Api) monkeypatch.setattr( - "prefect_kubernetes.credentials.KubernetesCredentials.get_client", - get_client, + "prefect_kubernetes.credentials.KubernetesCredentials.get_resource_specific_client", + batch_client, ) return batch_client @@ -106,31 +111,22 @@ def get_client(self, _): @pytest.fixture def _mock_api_core_client(monkeypatch): - core_client = MagicMock(spec=CoreV1Api) - - @contextmanager - def get_client(self, _): - yield core_client + core_client = AsyncMock(spec=CoreV1Api) monkeypatch.setattr( - "prefect_kubernetes.credentials.KubernetesCredentials.get_client", - get_client, + "prefect_kubernetes.credentials.KubernetesCredentials.get_resource_specific_client", + core_client, ) - return core_client @pytest.fixture def _mock_api_custom_objects_client(monkeypatch): - custom_objects_client = MagicMock(spec=CustomObjectsApi) - - @contextmanager - def get_client(self, _): - yield custom_objects_client + custom_objects_client = AsyncMock(spec=CustomObjectsApi) monkeypatch.setattr( - "prefect_kubernetes.credentials.KubernetesCredentials.get_client", - get_client, + "prefect_kubernetes.credentials.KubernetesCredentials.get_resource_specific_client", + custom_objects_client, ) return custom_objects_client @@ -138,18 +134,18 @@ def get_client(self, _): @pytest.fixture def mock_create_namespaced_job(monkeypatch): - mock_v1_job = MagicMock( + mock_v1_job = AsyncMock( return_value=models.V1Job(metadata=models.V1ObjectMeta(name="test")) ) monkeypatch.setattr( - "kubernetes.client.BatchV1Api.create_namespaced_job", mock_v1_job + "kubernetes_asyncio.client.api.BatchV1Api.create_namespaced_job", mock_v1_job ) return mock_v1_job @pytest.fixture def mock_read_namespaced_job_status(monkeypatch): - mock_v1_job_status = MagicMock( + mock_v1_job_status = AsyncMock( return_value=models.V1Job( metadata=models.V1ObjectMeta( name="test", labels={"controller-uid": "test"} @@ -170,7 +166,7 @@ def mock_read_namespaced_job_status(monkeypatch): ) ) monkeypatch.setattr( - "kubernetes.client.BatchV1Api.read_namespaced_job_status", + "kubernetes_asyncio.client.BatchV1Api.read_namespaced_job_status", mock_v1_job_status, ) return mock_v1_job_status @@ -178,11 +174,11 @@ def mock_read_namespaced_job_status(monkeypatch): @pytest.fixture def mock_delete_namespaced_job(monkeypatch): - mock_v1_job = MagicMock( + mock_v1_job = AsyncMock( return_value=models.V1Job(metadata=models.V1ObjectMeta(name="test")) ) monkeypatch.setattr( - "kubernetes.client.BatchV1Api.delete_namespaced_job", mock_v1_job + "kubernetes_asyncio.client.BatchV1Api.delete_namespaced_job", 
mock_v1_job ) return mock_v1_job @@ -190,16 +186,19 @@ def mock_delete_namespaced_job(monkeypatch): @pytest.fixture def mock_stream_timeout(monkeypatch): monkeypatch.setattr( - "kubernetes.watch.Watch.stream", + "kubernetes_asyncio.watch.Watch.stream", MagicMock(side_effect=ApiException(status=408)), ) @pytest.fixture def mock_pod_log(monkeypatch): + async def pod_log(*args, **kwargs): + yield "test log" + monkeypatch.setattr( - "kubernetes.watch.Watch.stream", - MagicMock(return_value=["test log"]), + "kubernetes_asyncio.watch.Watch.stream", + MagicMock(side_effect=pod_log), ) @@ -213,25 +212,27 @@ def mock_list_namespaced_pod(monkeypatch): ) ] ) - mock_pod_list = MagicMock(return_value=result) + mock_pod_list = AsyncMock(return_value=result) monkeypatch.setattr( - "kubernetes.client.CoreV1Api.list_namespaced_pod", mock_pod_list + "kubernetes_asyncio.client.api.CoreV1Api.list_namespaced_pod", mock_pod_list ) return mock_pod_list @pytest.fixture def read_pod_logs(monkeypatch): - pod_log = MagicMock(return_value="test log") + pod_log = AsyncMock(return_value="test log") - monkeypatch.setattr("kubernetes.client.CoreV1Api.read_namespaced_pod_log", pod_log) + monkeypatch.setattr( + "kubernetes_asyncio.client.api.CoreV1Api.read_namespaced_pod_log", pod_log + ) return pod_log @pytest.fixture def valid_kubernetes_job_block(kubernetes_credentials): - with open("tests/sample_k8s_resources/sample_job.yaml") as f: + with open(BASEDIR / "sample_k8s_resources" / "sample_job.yaml") as f: job_dict = yaml.safe_load(f) return KubernetesJob( diff --git a/src/integrations/prefect-kubernetes/tests/test_credentials.py b/src/integrations/prefect-kubernetes/tests/test_credentials.py index b8222d1c57ae..1f8d1864b7a3 100644 --- a/src/integrations/prefect-kubernetes/tests/test_credentials.py +++ b/src/integrations/prefect-kubernetes/tests/test_credentials.py @@ -1,137 +1,194 @@ -import base64 +import os +import tempfile from pathlib import Path from typing import Dict import pydantic import pytest import yaml -from kubernetes.client import ( +from kubernetes_asyncio.client import ( ApiClient, AppsV1Api, BatchV1Api, CoreV1Api, CustomObjectsApi, ) -from kubernetes.config.kube_config import list_kube_config_contexts +from kubernetes_asyncio.config.kube_config import list_kube_config_contexts +from OpenSSL import crypto from prefect_kubernetes.credentials import KubernetesClusterConfig -sample_base64_string = base64.b64encode(b"hello marvin from the other side") - -CONFIG_CONTENT = f""" - apiVersion: v1 - clusters: - - cluster: - certificate-authority-data: {sample_base64_string} - server: https://kubernetes.docker.internal:6443 - name: docker-desktop - contexts: - - context: - cluster: docker-desktop - user: docker-desktop - name: docker-desktop - current-context: docker-desktop - kind: Config - preferences: {{}} - users: - - name: docker-desktop - user: - client-certificate-data: {sample_base64_string} - client-key-data: {sample_base64_string} -""" - @pytest.fixture -def config_file(tmp_path) -> Path: - config_file = tmp_path / "kube_config" - - config_file.write_text(CONFIG_CONTENT) - - return config_file - +def create_temp_self_signed_cert(tmp_path): + """Create a self signed SSL certificate in temporary files for host + 'localhost' + + Returns a tuple containing the certificate file name and the key + file name. 
+ + It is the caller's responsibility to delete the files after use + """ + # create a key pair + key = crypto.PKey() + key.generate_key(crypto.TYPE_RSA, 2048) + + # create a self-signed cert + cert = crypto.X509() + cert.get_subject().C = "US" + cert.get_subject().ST = "Chicago" + cert.get_subject().L = "Chicago" + cert.get_subject().O = "myapp" + cert.get_subject().OU = "myapp" + cert.get_subject().CN = "localhost" + cert.set_serial_number(1000) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60) + cert.set_issuer(cert.get_subject()) + cert.set_pubkey(key) + cert.sign(key, "sha1") + + # Save certificate in temporary file + (cert_file_fd, cert_file_name) = tempfile.mkstemp( + suffix=".crt", prefix="cert", dir=tmp_path + ) + cert_file = os.fdopen(cert_file_fd, "wb") + cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert)) + cert_file.close() -@pytest.mark.parametrize( - "resource_type,client_type", - [ - ("apps", AppsV1Api), - ("batch", BatchV1Api), - ("core", CoreV1Api), - ("custom_objects", CustomObjectsApi), - ], -) -def test_client_return_type(kubernetes_credentials, resource_type, client_type): - with kubernetes_credentials.get_client(resource_type) as client: - assert isinstance(client, client_type) + # Save key in temporary file + (key_file_fd, key_file_name) = tempfile.mkstemp( + suffix=".key", prefix="cert", dir=tmp_path + ) + key_file = os.fdopen(key_file_fd, "wb") + key_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key)) + key_file.close() -def test_client_bad_resource_type(kubernetes_credentials): - with pytest.raises( - ValueError, match="Invalid client type provided 'shoo-ba-daba-doo'" - ): - with kubernetes_credentials.get_client("shoo-ba-daba-doo"): - pass + # Return file names + yield (cert_file_name, key_file_name) + os.remove(cert_file_name) + os.remove(key_file_name) -async def test_instantiation_from_file(config_file): - cluster_config = KubernetesClusterConfig.from_file(path=config_file) - assert isinstance(cluster_config, KubernetesClusterConfig) - assert isinstance(cluster_config.config, Dict) - assert isinstance(cluster_config.context_name, str) - - assert cluster_config.config == yaml.safe_load(CONFIG_CONTENT) - assert cluster_config.context_name == "docker-desktop" +@pytest.fixture +def config_context(create_temp_self_signed_cert): + _cert_file, _cert_key = create_temp_self_signed_cert + CONFIG_CONTENT = f""" + apiVersion: v1 + clusters: + - cluster: + certificate-authority: {_cert_file} + server: https://kubernetes.docker.internal:6443 + name: docker-desktop + contexts: + - context: + cluster: docker-desktop + user: docker-desktop + name: docker-desktop + current-context: docker-desktop + kind: Config + preferences: {{}} + users: + - name: docker-desktop + user: + client-certificate: {_cert_file} + client-key: {_cert_key} + """ + return CONFIG_CONTENT -async def test_instantiation_from_dict(config_file): - cluster_config = KubernetesClusterConfig( - config=yaml.safe_load(CONFIG_CONTENT), context_name="docker-desktop" - ) +@pytest.fixture +def config_file(tmp_path, config_context) -> Path: + config_file = tmp_path / "kube_config" - assert isinstance(cluster_config, KubernetesClusterConfig) - assert isinstance(cluster_config.config, Dict) - assert isinstance(cluster_config.context_name, str) + config_file.write_text(config_context) - assert cluster_config.config == yaml.safe_load(CONFIG_CONTENT) - assert cluster_config.context_name == "docker-desktop" + return config_file -async def 
test_instantiation_from_str(): - cluster_config = KubernetesClusterConfig( - config=CONFIG_CONTENT, context_name="docker-desktop" +class TestCredentials: + @pytest.mark.parametrize( + "resource_type,client_type", + [ + ("apps", AppsV1Api), + ("batch", BatchV1Api), + ("core", CoreV1Api), + ("custom_objects", CustomObjectsApi), + ], ) + async def test_client_return_type( + self, kubernetes_credentials, resource_type, client_type + ): + async with kubernetes_credentials.get_client(resource_type) as client: + assert isinstance(client, client_type) - assert isinstance(cluster_config, KubernetesClusterConfig) - assert isinstance(cluster_config.config, Dict) - assert isinstance(cluster_config.context_name, str) + async def test_client_bad_resource_type(self, kubernetes_credentials): + with pytest.raises( + ValueError, match="Invalid client type provided 'shoo-ba-daba-doo'" + ): + async with kubernetes_credentials.get_client("shoo-ba-daba-doo"): + pass - assert cluster_config.config == yaml.safe_load(CONFIG_CONTENT) - assert cluster_config.context_name == "docker-desktop" +class TestKubernetesClusterConfig: + async def test_instantiation_from_file(self, config_file, config_context): + cluster_config = KubernetesClusterConfig.from_file(path=config_file) -async def test_instantiation_from_invalid_str(): - with pytest.raises( - pydantic.ValidationError, - match="Input should be a valid dictionary", - ): - KubernetesClusterConfig(config="foo", context_name="docker-desktop") + assert isinstance(cluster_config, KubernetesClusterConfig) + assert isinstance(cluster_config.config, Dict) + assert isinstance(cluster_config.context_name, str) + assert cluster_config.config == yaml.safe_load(config_context) + assert cluster_config.context_name == "docker-desktop" -async def test_instantiation_from_file_with_unknown_context_name(config_file): - with pytest.raises(ValueError): - KubernetesClusterConfig.from_file( - path=config_file, context_name="random_not_real" + async def test_instantiation_from_dict(self, config_file, config_context): + cluster_config = KubernetesClusterConfig( + config=yaml.safe_load(config_context), + context_name="docker-desktop", ) + assert isinstance(cluster_config, KubernetesClusterConfig) + assert isinstance(cluster_config.config, Dict) + assert isinstance(cluster_config.context_name, str) -async def test_get_api_client(config_file): - cluster_config = KubernetesClusterConfig.from_file(path=config_file) - api_client = cluster_config.get_api_client() - assert isinstance(api_client, ApiClient) + assert cluster_config.config == yaml.safe_load(config_context) + assert cluster_config.context_name == "docker-desktop" + async def test_instantiation_from_str(self, config_context): + cluster_config = KubernetesClusterConfig( + config=config_context, context_name="docker-desktop" + ) -async def test_configure_client(config_file): - cluster_config = KubernetesClusterConfig.from_file(path=config_file) - cluster_config.configure_client() - context_dict = list_kube_config_contexts(config_file=str(config_file)) - current_context = context_dict[1]["name"] - assert cluster_config.context_name == current_context + assert isinstance(cluster_config, KubernetesClusterConfig) + assert isinstance(cluster_config.config, Dict) + assert isinstance(cluster_config.context_name, str) + + assert cluster_config.config == yaml.safe_load(config_context) + assert cluster_config.context_name == "docker-desktop" + + async def test_instantiation_from_invalid_str(self): + with pytest.raises( + pydantic.ValidationError, + 
match="Input should be a valid dictionary", + ): + KubernetesClusterConfig(config="foo", context_name="docker-desktop") + + async def test_instantiation_from_file_with_unknown_context_name(self, config_file): + with pytest.raises(ValueError): + await KubernetesClusterConfig.from_file( + path=config_file, context_name="random_not_real" + ) + + async def test_get_api_client(self, config_file): + cluster_config = KubernetesClusterConfig.from_file(path=config_file) + + api_client = await cluster_config.get_api_client() + assert isinstance(api_client, ApiClient) + + async def test_configure_client(self, config_file): + cluster_config = KubernetesClusterConfig.from_file(path=config_file) + await cluster_config.configure_client() + context_dict = list_kube_config_contexts(config_file=str(config_file)) + current_context = context_dict[1]["name"] + assert cluster_config.context_name == current_context diff --git a/src/integrations/prefect-kubernetes/tests/test_custom_objects.py b/src/integrations/prefect-kubernetes/tests/test_custom_objects.py index 4fa39f0effa0..f2efa3142b1c 100644 --- a/src/integrations/prefect-kubernetes/tests/test_custom_objects.py +++ b/src/integrations/prefect-kubernetes/tests/test_custom_objects.py @@ -1,5 +1,5 @@ import pytest -from kubernetes.client.exceptions import ApiValueError +from kubernetes_asyncio.client.exceptions import ApiValueError from prefect_kubernetes.custom_objects import ( create_namespaced_custom_object, delete_namespaced_custom_object, @@ -59,34 +59,37 @@ async def test_create_namespaced_crd( ) assert ( - _mock_api_custom_objects_client.create_namespaced_custom_object.call_args[1][ - "a" - ] + _mock_api_custom_objects_client.return_value.create_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.create_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.create_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.create_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.create_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.create_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.create_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) # We can't have models for Custom Resources. 
- assert _mock_api_custom_objects_client.create_namespaced_custom_object.call_args[1][ - "body" - ]["metadata"] == {"name": "test"} + assert ( + _mock_api_custom_objects_client.return_value.create_namespaced_custom_object.call_args[ + 1 + ]["body"]["metadata"] + == {"name": "test"} + ) async def test_get_namespaced_custom_object( @@ -102,33 +105,35 @@ async def test_get_namespaced_custom_object( ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object.call_args[1]["a"] + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) # We can't have models for Custom Resources. assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object.call_args[1][ - "name" - ] + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object.call_args[ + 1 + ]["name"] == "test-name" ) @@ -146,33 +151,33 @@ async def test_get_namespaced_custom_object_status( ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object_status.call_args[ + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object_status.call_args[ 1 ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object_status.call_args[ + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object_status.call_args[ 1 ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object_status.call_args[ + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object_status.call_args[ 1 ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object_status.call_args[ + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object_status.call_args[ 1 ]["plural"] == "ops" ) # We can't have models for Custom Resources. 
assert ( - _mock_api_custom_objects_client.get_namespaced_custom_object_status.call_args[ + _mock_api_custom_objects_client.return_value.get_namespaced_custom_object_status.call_args[ 1 ]["name"] == "test-name" @@ -192,35 +197,35 @@ async def test_delete_namespaced_custom_object( ) assert ( - _mock_api_custom_objects_client.delete_namespaced_custom_object.call_args[1][ - "a" - ] + _mock_api_custom_objects_client.return_value.delete_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.delete_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.delete_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.delete_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.delete_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.delete_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.delete_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) # We can't have models for Custom Resources. assert ( - _mock_api_custom_objects_client.delete_namespaced_custom_object.call_args[1][ - "name" - ] + _mock_api_custom_objects_client.return_value.delete_namespaced_custom_object.call_args[ + 1 + ]["name"] == "test-name" ) @@ -237,26 +242,28 @@ async def test_list_namespaced_custom_object( ) assert ( - _mock_api_custom_objects_client.list_namespaced_custom_object.call_args[1]["a"] + _mock_api_custom_objects_client.return_value.list_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.list_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.list_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.list_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.list_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.list_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.list_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) @@ -281,37 +288,42 @@ async def test_patch_namespaced_custom_object( ) assert ( - _mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1]["a"] + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) assert ( - _mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1][ - "name" - ] + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["name"] == "test-name" ) - assert 
_mock_api_custom_objects_client.patch_namespaced_custom_object.call_args[1][ - "body" - ]["metadata"] == {"name": "test"} + assert ( + _mock_api_custom_objects_client.return_value.patch_namespaced_custom_object.call_args[ + 1 + ]["body"]["metadata"] + == {"name": "test"} + ) async def test_replace_namespaced_custom_object( @@ -334,36 +346,39 @@ async def test_replace_namespaced_custom_object( ) assert ( - _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[1][ - "a" - ] + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["a"] == "test" ) assert ( - _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[1][ - "group" - ] + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["group"] == "my-group" ) assert ( - _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[1][ - "version" - ] + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["version"] == "v1" ) assert ( - _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[1][ - "plural" - ] + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["plural"] == "ops" ) assert ( - _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[1][ - "name" - ] + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["name"] == "test-name" ) - assert _mock_api_custom_objects_client.replace_namespaced_custom_object.call_args[ - 1 - ]["body"]["metadata"] == {"name": "test"} + assert ( + _mock_api_custom_objects_client.return_value.replace_namespaced_custom_object.call_args[ + 1 + ]["body"]["metadata"] + == {"name": "test"} + ) diff --git a/src/integrations/prefect-kubernetes/tests/test_deployments.py b/src/integrations/prefect-kubernetes/tests/test_deployments.py index 802216602341..956381a80ab4 100644 --- a/src/integrations/prefect-kubernetes/tests/test_deployments.py +++ b/src/integrations/prefect-kubernetes/tests/test_deployments.py @@ -1,5 +1,5 @@ import pytest -from kubernetes.client.models import V1DeleteOptions, V1Deployment +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Deployment from prefect_kubernetes.deployments import ( create_namespaced_deployment, delete_namespaced_deployment, @@ -19,10 +19,13 @@ async def test_create_namespaced_deployment( kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_app_client.create_namespaced_deployment.call_args[1][ + assert _mock_api_app_client.return_value.create_namespaced_deployment.call_args[1][ "body" ].metadata == {"name": "test-deployment"} - assert _mock_api_app_client.create_namespaced_deployment.call_args[1]["a"] == "test" + assert ( + _mock_api_app_client.return_value.create_namespaced_deployment.call_args[1]["a"] + == "test" + ) async def test_delete_namespaced_deployment( @@ -35,12 +38,17 @@ async def test_delete_namespaced_deployment( a="test", ) assert ( - _mock_api_app_client.delete_namespaced_deployment.call_args[1]["namespace"] + _mock_api_app_client.return_value.delete_namespaced_deployment.call_args[1][ + "namespace" + ] == "default" ) - assert _mock_api_app_client.delete_namespaced_deployment.call_args[1]["a"] == "test" assert ( - _mock_api_app_client.delete_namespaced_deployment.call_args[1][ + _mock_api_app_client.return_value.delete_namespaced_deployment.call_args[1]["a"] + == "test" + ) + assert ( + 
_mock_api_app_client.return_value.delete_namespaced_deployment.call_args[1][ "body" ].grace_period_seconds == 42 @@ -63,10 +71,15 @@ async def test_list_namespaced_deployment(kubernetes_credentials, _mock_api_app_ kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_app_client.list_namespaced_deployment.call_args[1]["namespace"] + _mock_api_app_client.return_value.list_namespaced_deployment.call_args[1][ + "namespace" + ] == "ns" ) - assert _mock_api_app_client.list_namespaced_deployment.call_args[1]["a"] == "test" + assert ( + _mock_api_app_client.return_value.list_namespaced_deployment.call_args[1]["a"] + == "test" + ) async def test_patch_namespaced_deployment( @@ -78,14 +91,19 @@ async def test_patch_namespaced_deployment( deployment_name="test_deployment", a="test", ) - assert _mock_api_app_client.patch_namespaced_deployment.call_args[1][ + assert _mock_api_app_client.return_value.patch_namespaced_deployment.call_args[1][ "body" ].metadata == {"name": "test-deployment"} assert ( - _mock_api_app_client.patch_namespaced_deployment.call_args[1]["name"] + _mock_api_app_client.return_value.patch_namespaced_deployment.call_args[1][ + "name" + ] == "test_deployment" ) - assert _mock_api_app_client.patch_namespaced_deployment.call_args[1]["a"] == "test" + assert ( + _mock_api_app_client.return_value.patch_namespaced_deployment.call_args[1]["a"] + == "test" + ) async def test_read_namespaced_deployment(kubernetes_credentials, _mock_api_app_client): @@ -96,14 +114,21 @@ async def test_read_namespaced_deployment(kubernetes_credentials, _mock_api_app_ kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_app_client.read_namespaced_deployment.call_args[1]["name"] + _mock_api_app_client.return_value.read_namespaced_deployment.call_args[1][ + "name" + ] == "test_deployment" ) assert ( - _mock_api_app_client.read_namespaced_deployment.call_args[1]["namespace"] + _mock_api_app_client.return_value.read_namespaced_deployment.call_args[1][ + "namespace" + ] == "ns" ) - assert _mock_api_app_client.read_namespaced_deployment.call_args[1]["a"] == "test" + assert ( + _mock_api_app_client.return_value.read_namespaced_deployment.call_args[1]["a"] + == "test" + ) async def test_replace_namespaced_deployment( @@ -117,18 +142,25 @@ async def test_replace_namespaced_deployment( kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_app_client.replace_namespaced_deployment.call_args[1]["name"] + _mock_api_app_client.return_value.replace_namespaced_deployment.call_args[1][ + "name" + ] == "test_deployment" ) assert ( - _mock_api_app_client.replace_namespaced_deployment.call_args[1]["namespace"] + _mock_api_app_client.return_value.replace_namespaced_deployment.call_args[1][ + "namespace" + ] == "ns" ) - assert _mock_api_app_client.replace_namespaced_deployment.call_args[1][ + assert _mock_api_app_client.return_value.replace_namespaced_deployment.call_args[1][ "body" ].metadata == {"name": "test-deployment"} assert ( - _mock_api_app_client.replace_namespaced_deployment.call_args[1]["a"] == "test" + _mock_api_app_client.return_value.replace_namespaced_deployment.call_args[1][ + "a" + ] + == "test" ) diff --git a/src/integrations/prefect-kubernetes/tests/test_events_replicator.py b/src/integrations/prefect-kubernetes/tests/test_events_replicator.py index 2b10f0f3663e..d94308cbc627 100644 --- a/src/integrations/prefect-kubernetes/tests/test_events_replicator.py +++ b/src/integrations/prefect-kubernetes/tests/test_events_replicator.py @@ -1,21 +1,22 @@ +import asyncio import 
copy -import threading -import time -from unittest.mock import MagicMock, call, patch +from unittest import mock +from unittest.mock import AsyncMock, MagicMock, call, patch import pytest -from kubernetes.client import V1Pod +from kubernetes_asyncio.client import CoreV1Api, V1Pod from prefect_kubernetes.events import EVICTED_REASONS, KubernetesEventsReplicator from prefect.events import RelatedResource from prefect.utilities.importtools import lazy_import -kubernetes = lazy_import("kubernetes") +kubernetes_asyncio = lazy_import("kubernetes_asyncio") @pytest.fixture -def client(): - return MagicMock() +async def client(): + async with AsyncMock() as mock: + yield mock @pytest.fixture @@ -55,6 +56,22 @@ def failed_pod(pod): return failed_pod +@pytest.fixture +def mock_watch(monkeypatch): + mock = MagicMock(return_value=AsyncMock()) + monkeypatch.setattr("kubernetes_asyncio.watch.Watch", mock) + return mock + + +@pytest.fixture +def mock_core_client(monkeypatch): + mock = MagicMock(spec=CoreV1Api, return_value=AsyncMock()) + + monkeypatch.setattr("prefect_kubernetes.worker.CoreV1Api", mock) + monkeypatch.setattr("kubernetes_asyncio.client.CoreV1Api", mock) + return mock + + @pytest.fixture def evicted_pod(pod): container_status = MagicMock() @@ -70,57 +87,71 @@ def evicted_pod(pod): @pytest.fixture -def successful_pod_stream(pending_pod, running_pod, succeeded_pod): - return [ - { - "type": "ADDED", - "object": pending_pod, - }, - { - "type": "MODIFIED", - "object": running_pod, - }, - { - "type": "MODIFIED", - "object": succeeded_pod, - }, - ] +async def successful_pod_stream( + pending_pod, running_pod, succeeded_pod, mock_core_client +): + async def event_stream(**kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + events = [ + {"type": "ADDED", "object": pending_pod}, + {"type": "MODIFIED", "object": running_pod}, + {"type": "MODIFIED", "object": succeeded_pod}, + ] + for event in events: + yield event + await asyncio.sleep(0.1) # simulate async behavior + + return event_stream @pytest.fixture -def failed_pod_stream(pending_pod, running_pod, failed_pod): - return [ - { - "type": "ADDED", - "object": pending_pod, - }, - { - "type": "MODIFIED", - "object": running_pod, - }, - { - "type": "MODIFIED", - "object": failed_pod, - }, - ] +def failed_pod_stream(pending_pod, running_pod, failed_pod, mock_core_client): + async def event_stream(**kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + events = [ + { + "type": "ADDED", + "object": pending_pod, + }, + { + "type": "MODIFIED", + "object": running_pod, + }, + { + "type": "MODIFIED", + "object": failed_pod, + }, + ] + for event in events: + yield event + await asyncio.sleep(0.1) # simulate async behavior + + return event_stream @pytest.fixture -def evicted_pod_stream(pending_pod, running_pod, evicted_pod): - return [ - { - "type": "ADDED", - "object": pending_pod, - }, - { - "type": "MODIFIED", - "object": running_pod, - }, - { - "type": "MODIFIED", - "object": evicted_pod, - }, - ] +def evicted_pod_stream(pending_pod, running_pod, evicted_pod, mock_core_client): + async def event_stream(**kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + events = [ + { + "type": "ADDED", + "object": pending_pod, + }, + { + "type": "MODIFIED", + "object": running_pod, + }, + { + "type": "MODIFIED", + "object": evicted_pod, + }, + ] + for event in events: + yield event + await asyncio.sleep(0.1) # simulate async behavior + + return event_stream @pytest.fixture 
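The pod-stream fixtures above share one mechanic: `Watch.stream` is mocked with a `MagicMock` whose `side_effect` is an async generator function, so each call produces a fresh stream that `async for` can consume. A minimal sketch of just that pattern; the phases and event shapes are illustrative:

```python
# Faking an async watch stream for tests: the MagicMock's side_effect returns
# the async generator produced by calling fake_stream(**kwargs).
import asyncio
from unittest.mock import MagicMock


async def fake_stream(**kwargs):
    for phase in ("Pending", "Running", "Succeeded"):
        # dotted kwargs configure nested attributes on the MagicMock
        yield {"type": "MODIFIED", "object": MagicMock(**{"status.phase": phase})}
        await asyncio.sleep(0)  # yield control, mimicking a live watch


async def consume() -> None:
    mock_watch = MagicMock()
    mock_watch.stream = MagicMock(side_effect=fake_stream)
    async for event in mock_watch.stream(func=None):
        print(event["type"], event["object"].status.phase)


asyncio.run(consume())
```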
@@ -152,25 +183,23 @@ def replicator(client, worker_resource, related_resources): ) -def test_lifecycle(replicator): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_thread = MagicMock(spec=threading.Thread) +async def test_lifecycle(replicator, mock_watch, mock_core_client, pod): + async def mock_stream(**kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"type": "ADDED", "object": pod} - with patch.object(replicator, "_watch", mock_watch): - with patch.object(replicator, "_thread", mock_thread): - with replicator: - assert replicator._state == "STARTED" - mock_thread.start.assert_called_once_with() - mock_thread.reset_mock() + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + async with replicator: + await asyncio.sleep(0.3) + assert replicator._state == "STARTED" assert replicator._state == "STOPPED" - mock_watch.stop.assert_called_once_with() - mock_thread.join.assert_called_once_with() -def test_replicate_successful_pod_events(replicator, successful_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = successful_pod_stream +async def test_replicate_successful_pod_events( + replicator, mock_watch, successful_pod_stream +): + mock_watch.return_value.stream = mock.Mock(side_effect=successful_pod_stream) event_count = 0 @@ -180,9 +209,8 @@ def event(*args, **kwargs): return event_count with patch("prefect_kubernetes.events.emit_event", side_effect=event) as mock_emit: - with patch.object(replicator, "_watch", mock_watch): - with replicator: - time.sleep(0.3) + async with replicator: + await asyncio.sleep(0.5) # allow some time for the events to be processed mock_emit.assert_has_calls( [ @@ -257,12 +285,10 @@ def event(*args, **kwargs): ), ] ) - mock_watch.stop.assert_called_once_with() -def test_replicate_failed_pod_events(replicator, failed_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = failed_pod_stream +async def test_replicate_failed_pod_events(replicator, mock_watch, failed_pod_stream): + mock_watch.return_value.stream = mock.Mock(side_effect=failed_pod_stream) event_count = 0 @@ -272,9 +298,8 @@ def event(*args, **kwargs): return event_count with patch("prefect_kubernetes.events.emit_event", side_effect=event) as mock_emit: - with patch.object(replicator, "_watch", mock_watch): - with replicator: - time.sleep(0.3) + async with replicator: + await asyncio.sleep(0.5) mock_emit.assert_has_calls( [ @@ -349,13 +374,11 @@ def event(*args, **kwargs): ), ] ) - mock_watch.stop.assert_called_once_with() - -def test_replicate_evicted_pod_events(replicator, evicted_pod_stream): - mock_watch = MagicMock(spec=kubernetes.watch.Watch) - mock_watch.stream.return_value = evicted_pod_stream +@pytest.mark.asyncio +async def test_replicate_evicted_pod_events(replicator, mock_watch, evicted_pod_stream): + mock_watch.return_value.stream = mock.Mock(side_effect=evicted_pod_stream) event_count = 0 def event(*args, **kwargs): @@ -364,9 +387,8 @@ def event(*args, **kwargs): return event_count with patch("prefect_kubernetes.events.emit_event", side_effect=event) as mock_emit: - with patch.object(replicator, "_watch", mock_watch): - with replicator: - time.sleep(0.3) + async with replicator: + await asyncio.sleep(0.5) mock_emit.assert_has_calls( [ @@ -442,4 +464,3 @@ def event(*args, **kwargs): ), ] ) - mock_watch.stop.assert_called_once_with() diff --git a/src/integrations/prefect-kubernetes/tests/test_jobs.py 
b/src/integrations/prefect-kubernetes/tests/test_jobs.py index 9b4a2f707409..b7b089b876a0 100644 --- a/src/integrations/prefect-kubernetes/tests/test_jobs.py +++ b/src/integrations/prefect-kubernetes/tests/test_jobs.py @@ -1,6 +1,8 @@ +from pathlib import Path + import pytest -from kubernetes.client.exceptions import ApiValueError -from kubernetes.client.models import V1Job +from kubernetes_asyncio.client.exceptions import ApiValueError +from kubernetes_asyncio.client.models import V1Job from prefect_kubernetes.jobs import ( KubernetesJob, create_namespaced_job, @@ -35,10 +37,13 @@ async def test_create_namespaced_job(kubernetes_credentials, _mock_api_batch_cli kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_batch_client.create_namespaced_job.call_args[1][ + assert _mock_api_batch_client.return_value.create_namespaced_job.call_args[1][ "body" ].metadata == {"name": "test-job"} - assert _mock_api_batch_client.create_namespaced_job.call_args[1]["a"] == "test" + assert ( + _mock_api_batch_client.return_value.create_namespaced_job.call_args[1]["a"] + == "test" + ) async def test_delete_namespaced_job(kubernetes_credentials, _mock_api_batch_client): @@ -48,9 +53,13 @@ async def test_delete_namespaced_job(kubernetes_credentials, _mock_api_batch_cli kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_batch_client.delete_namespaced_job.call_args[1]["name"] == "test-job" + _mock_api_batch_client.return_value.delete_namespaced_job.call_args[1]["name"] + == "test-job" + ) + assert ( + _mock_api_batch_client.return_value.delete_namespaced_job.call_args[1]["a"] + == "test" ) - assert _mock_api_batch_client.delete_namespaced_job.call_args[1]["a"] == "test" async def test_list_namespaced_job(kubernetes_credentials, _mock_api_batch_client): @@ -59,8 +68,16 @@ async def test_list_namespaced_job(kubernetes_credentials, _mock_api_batch_clien a="test", kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_batch_client.list_namespaced_job.call_args[1]["namespace"] == "ns" - assert _mock_api_batch_client.list_namespaced_job.call_args[1]["a"] == "test" + assert ( + _mock_api_batch_client.return_value.list_namespaced_job.call_args[1][ + "namespace" + ] + == "ns" + ) + assert ( + _mock_api_batch_client.return_value.list_namespaced_job.call_args[1]["a"] + == "test" + ) async def test_patch_namespaced_job(kubernetes_credentials, _mock_api_batch_client): @@ -70,13 +87,17 @@ async def test_patch_namespaced_job(kubernetes_credentials, _mock_api_batch_clie a="test", kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_batch_client.patch_namespaced_job.call_args[1][ + assert _mock_api_batch_client.return_value.patch_namespaced_job.call_args[1][ "body" ].metadata == {"name": "test-job"} assert ( - _mock_api_batch_client.patch_namespaced_job.call_args[1]["name"] == "test-job" + _mock_api_batch_client.return_value.patch_namespaced_job.call_args[1]["name"] + == "test-job" + ) + assert ( + _mock_api_batch_client.return_value.patch_namespaced_job.call_args[1]["a"] + == "test" ) - assert _mock_api_batch_client.patch_namespaced_job.call_args[1]["a"] == "test" async def test_read_namespaced_job(kubernetes_credentials, _mock_api_batch_client): @@ -86,9 +107,20 @@ async def test_read_namespaced_job(kubernetes_credentials, _mock_api_batch_clien a="test", kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_batch_client.read_namespaced_job.call_args[1]["name"] == "test-job" - assert _mock_api_batch_client.read_namespaced_job.call_args[1]["namespace"] == 
"ns" - assert _mock_api_batch_client.read_namespaced_job.call_args[1]["a"] == "test" + assert ( + _mock_api_batch_client.return_value.read_namespaced_job.call_args[1]["name"] + == "test-job" + ) + assert ( + _mock_api_batch_client.return_value.read_namespaced_job.call_args[1][ + "namespace" + ] + == "ns" + ) + assert ( + _mock_api_batch_client.return_value.read_namespaced_job.call_args[1]["a"] + == "test" + ) async def test_replace_namespaced_job(kubernetes_credentials, _mock_api_batch_client): @@ -100,15 +132,22 @@ async def test_replace_namespaced_job(kubernetes_credentials, _mock_api_batch_cl kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_batch_client.replace_namespaced_job.call_args[1]["name"] == "test-job" + _mock_api_batch_client.return_value.replace_namespaced_job.call_args[1]["name"] + == "test-job" ) assert ( - _mock_api_batch_client.replace_namespaced_job.call_args[1]["namespace"] == "ns" + _mock_api_batch_client.return_value.replace_namespaced_job.call_args[1][ + "namespace" + ] + == "ns" ) - assert _mock_api_batch_client.replace_namespaced_job.call_args[1][ + assert _mock_api_batch_client.return_value.replace_namespaced_job.call_args[1][ "body" ].metadata == {"name": "test-job"} - assert _mock_api_batch_client.replace_namespaced_job.call_args[1]["a"] == "test" + assert ( + _mock_api_batch_client.return_value.replace_namespaced_job.call_args[1]["a"] + == "test" + ) async def test_read_namespaced_job_status( @@ -121,20 +160,40 @@ async def test_read_namespaced_job_status( kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_batch_client.read_namespaced_job_status.call_args[1]["name"] + _mock_api_batch_client.return_value.read_namespaced_job_status.call_args[1][ + "name" + ] == "test-job" ) assert ( - _mock_api_batch_client.read_namespaced_job_status.call_args[1]["namespace"] + _mock_api_batch_client.return_value.read_namespaced_job_status.call_args[1][ + "namespace" + ] == "ns" ) - assert _mock_api_batch_client.read_namespaced_job_status.call_args[1]["a"] == "test" + assert ( + _mock_api_batch_client.return_value.read_namespaced_job_status.call_args[1]["a"] + == "test" + ) async def test_job_block_from_job_yaml(kubernetes_credentials): + DIR = ( + ( + Path.cwd() + / "src" + / "integrations" + / "prefect-kubernetes" + / "tests" + / "sample_k8s_resources" + / "sample_job.yaml" + ) + if Path.cwd().name == "prefect" + else Path.cwd() / "tests" / "sample_k8s_resources" / "sample_job.yaml" + ) job = KubernetesJob.from_yaml_file( credentials=kubernetes_credentials, - manifest_path="tests/sample_k8s_resources/sample_job.yaml", + manifest_path=DIR, ) assert isinstance(job, KubernetesJob) assert job.v1_job["metadata"]["name"] == "pi" diff --git a/src/integrations/prefect-kubernetes/tests/test_pods.py b/src/integrations/prefect-kubernetes/tests/test_pods.py index 1c8beae6fd85..cf5d7048dbab 100644 --- a/src/integrations/prefect-kubernetes/tests/test_pods.py +++ b/src/integrations/prefect-kubernetes/tests/test_pods.py @@ -1,6 +1,6 @@ import pytest -from kubernetes.client.exceptions import ApiException, ApiValueError -from kubernetes.client.models import V1DeleteOptions, V1Pod +from kubernetes_asyncio.client.exceptions import ApiException, ApiValueError +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Pod from prefect_kubernetes.pods import ( create_namespaced_pod, delete_namespaced_pod, @@ -29,11 +29,13 @@ async def test_create_namespaced_pod(kubernetes_credentials, _mock_api_core_clie a="test", 
kubernetes_credentials=kubernetes_credentials, ) - - assert _mock_api_core_client.create_namespaced_pod.call_args[1][ + assert _mock_api_core_client.return_value.create_namespaced_pod.call_args[1][ "body" ].metadata == {"name": "test-pod"} - assert _mock_api_core_client.create_namespaced_pod.call_args[1]["a"] == "test" + assert ( + _mock_api_core_client.return_value.create_namespaced_pod.call_args[1]["a"] + == "test" + ) async def test_delete_namespaced_pod(kubernetes_credentials, _mock_api_core_client): @@ -44,12 +46,17 @@ async def test_delete_namespaced_pod(kubernetes_credentials, _mock_api_core_clie a="test", ) assert ( - _mock_api_core_client.delete_namespaced_pod.call_args[1]["namespace"] + _mock_api_core_client.return_value.delete_namespaced_pod.call_args[1][ + "namespace" + ] == "default" ) - assert _mock_api_core_client.delete_namespaced_pod.call_args[1]["a"] == "test" assert ( - _mock_api_core_client.delete_namespaced_pod.call_args[1][ + _mock_api_core_client.return_value.delete_namespaced_pod.call_args[1]["a"] + == "test" + ) + assert ( + _mock_api_core_client.return_value.delete_namespaced_pod.call_args[1][ "body" ].grace_period_seconds == 42 @@ -71,8 +78,14 @@ async def test_list_namespaced_pod(kubernetes_credentials, _mock_api_core_client a="test", kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_core_client.list_namespaced_pod.call_args[1]["namespace"] == "ns" - assert _mock_api_core_client.list_namespaced_pod.call_args[1]["a"] == "test" + assert ( + _mock_api_core_client.return_value.list_namespaced_pod.call_args[1]["namespace"] + == "ns" + ) + assert ( + _mock_api_core_client.return_value.list_namespaced_pod.call_args[1]["a"] + == "test" + ) async def test_patch_namespaced_pod(kubernetes_credentials, _mock_api_core_client): @@ -82,11 +95,17 @@ async def test_patch_namespaced_pod(kubernetes_credentials, _mock_api_core_clien pod_name="test_pod", a="test", ) - assert _mock_api_core_client.patch_namespaced_pod.call_args[1]["body"].metadata == { - "name": "test-pod" - } - assert _mock_api_core_client.patch_namespaced_pod.call_args[1]["name"] == "test_pod" - assert _mock_api_core_client.patch_namespaced_pod.call_args[1]["a"] == "test" + assert _mock_api_core_client.return_value.patch_namespaced_pod.call_args[1][ + "body" + ].metadata == {"name": "test-pod"} + assert ( + _mock_api_core_client.return_value.patch_namespaced_pod.call_args[1]["name"] + == "test_pod" + ) + assert ( + _mock_api_core_client.return_value.patch_namespaced_pod.call_args[1]["a"] + == "test" + ) async def test_read_namespaced_pod(kubernetes_credentials, _mock_api_core_client): @@ -96,9 +115,18 @@ async def test_read_namespaced_pod(kubernetes_credentials, _mock_api_core_client a="test", kubernetes_credentials=kubernetes_credentials, ) - assert _mock_api_core_client.read_namespaced_pod.call_args[1]["name"] == "test_pod" - assert _mock_api_core_client.read_namespaced_pod.call_args[1]["namespace"] == "ns" - assert _mock_api_core_client.read_namespaced_pod.call_args[1]["a"] == "test" + assert ( + _mock_api_core_client.return_value.read_namespaced_pod.call_args[1]["name"] + == "test_pod" + ) + assert ( + _mock_api_core_client.return_value.read_namespaced_pod.call_args[1]["namespace"] + == "ns" + ) + assert ( + _mock_api_core_client.return_value.read_namespaced_pod.call_args[1]["a"] + == "test" + ) async def test_read_namespaced_pod_logs(kubernetes_credentials, _mock_api_core_client): @@ -110,16 +138,25 @@ async def test_read_namespaced_pod_logs(kubernetes_credentials, _mock_api_core_c 
kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["name"] == "test_pod" + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1]["name"] + == "test_pod" ) assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["namespace"] == "ns" + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1][ + "namespace" + ] + == "ns" ) assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["container"] + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1][ + "container" + ] == "test_container" ) - assert _mock_api_core_client.read_namespaced_pod_log.call_args[1]["a"] == "test" + assert ( + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1]["a"] + == "test" + ) async def test_replace_namespaced_pod(kubernetes_credentials, _mock_api_core_client): @@ -131,15 +168,22 @@ async def test_replace_namespaced_pod(kubernetes_credentials, _mock_api_core_cli kubernetes_credentials=kubernetes_credentials, ) assert ( - _mock_api_core_client.replace_namespaced_pod.call_args[1]["name"] == "test_pod" + _mock_api_core_client.return_value.replace_namespaced_pod.call_args[1]["name"] + == "test_pod" ) assert ( - _mock_api_core_client.replace_namespaced_pod.call_args[1]["namespace"] == "ns" + _mock_api_core_client.return_value.replace_namespaced_pod.call_args[1][ + "namespace" + ] + == "ns" ) - assert _mock_api_core_client.replace_namespaced_pod.call_args[1][ + assert _mock_api_core_client.return_value.replace_namespaced_pod.call_args[1][ "body" ].metadata == {"name": "test-pod"} - assert _mock_api_core_client.replace_namespaced_pod.call_args[1]["a"] == "test" + assert ( + _mock_api_core_client.return_value.replace_namespaced_pod.call_args[1]["a"] + == "test" + ) @pytest.mark.parametrize( @@ -172,13 +216,19 @@ async def test_read_pod_log_custom_print_func( assert capsys.readouterr().out == "test log\n" assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["name"] == "test_pod" + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1]["name"] + == "test_pod" ) assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["namespace"] == "ns" + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1][ + "namespace" + ] + == "ns" ) assert ( - _mock_api_core_client.read_namespaced_pod_log.call_args[1]["container"] + _mock_api_core_client.return_value.read_namespaced_pod_log.call_args[1][ + "container" + ] == "test_container" ) diff --git a/src/integrations/prefect-kubernetes/tests/test_services.py b/src/integrations/prefect-kubernetes/tests/test_services.py index 245d3acfda43..a12a3e734328 100644 --- a/src/integrations/prefect-kubernetes/tests/test_services.py +++ b/src/integrations/prefect-kubernetes/tests/test_services.py @@ -1,4 +1,4 @@ -from kubernetes.client.models import V1DeleteOptions, V1Service +from kubernetes_asyncio.client.models import V1DeleteOptions, V1Service from prefect_kubernetes.services import ( create_namespaced_service, delete_namespaced_service, @@ -16,12 +16,14 @@ async def test_create_namespaced_service(kubernetes_credentials, _mock_api_core_ namespace="default", ) - assert _mock_api_core_client.create_namespaced_service.call_count == 1 - assert _mock_api_core_client.create_namespaced_service.call_args[1][ + assert _mock_api_core_client.return_value.create_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.create_namespaced_service.call_args[1][ "body" 
].metadata == {"name": "test-service"} assert ( - _mock_api_core_client.create_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.create_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) @@ -34,19 +36,23 @@ async def test_delete_namespaced_service(kubernetes_credentials, _mock_api_core_ namespace="default", ) - assert _mock_api_core_client.delete_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.delete_namespaced_service.call_count == 1 assert ( - _mock_api_core_client.delete_namespaced_service.call_args[1]["name"] + _mock_api_core_client.return_value.delete_namespaced_service.call_args[1][ + "name" + ] == "test-service" ) assert ( - _mock_api_core_client.delete_namespaced_service.call_args[1][ + _mock_api_core_client.return_value.delete_namespaced_service.call_args[1][ "body" ].grace_period_seconds == 42 ) assert ( - _mock_api_core_client.delete_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.delete_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) @@ -57,9 +63,11 @@ async def test_list_namespaced_service(kubernetes_credentials, _mock_api_core_cl namespace="default", ) - assert _mock_api_core_client.list_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.list_namespaced_service.call_count == 1 assert ( - _mock_api_core_client.list_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.list_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) @@ -72,16 +80,18 @@ async def test_patch_namespaced_service(kubernetes_credentials, _mock_api_core_c namespace="default", ) - assert _mock_api_core_client.patch_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.patch_namespaced_service.call_count == 1 assert ( - _mock_api_core_client.patch_namespaced_service.call_args[1]["name"] + _mock_api_core_client.return_value.patch_namespaced_service.call_args[1]["name"] == "test-service-old" ) - assert _mock_api_core_client.patch_namespaced_service.call_args[1][ + assert _mock_api_core_client.return_value.patch_namespaced_service.call_args[1][ "body" ].metadata == {"name": "test-service"} assert ( - _mock_api_core_client.patch_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.patch_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) @@ -93,13 +103,15 @@ async def test_read_namespaced_service(kubernetes_credentials, _mock_api_core_cl namespace="default", ) - assert _mock_api_core_client.read_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.read_namespaced_service.call_count == 1 assert ( - _mock_api_core_client.read_namespaced_service.call_args[1]["name"] + _mock_api_core_client.return_value.read_namespaced_service.call_args[1]["name"] == "test-service" ) assert ( - _mock_api_core_client.read_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.read_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) @@ -114,15 +126,19 @@ async def test_replace_namespaced_service( namespace="default", ) - assert _mock_api_core_client.replace_namespaced_service.call_count == 1 + assert _mock_api_core_client.return_value.replace_namespaced_service.call_count == 1 assert ( - _mock_api_core_client.replace_namespaced_service.call_args[1]["name"] + _mock_api_core_client.return_value.replace_namespaced_service.call_args[1][ + "name" + ] == "test-service" ) - assert 
_mock_api_core_client.replace_namespaced_service.call_args[1][ + assert _mock_api_core_client.return_value.replace_namespaced_service.call_args[1][ "body" ].metadata == {"labels": {"foo": "bar"}} assert ( - _mock_api_core_client.replace_namespaced_service.call_args[1]["namespace"] + _mock_api_core_client.return_value.replace_namespaced_service.call_args[1][ + "namespace" + ] == "default" ) diff --git a/src/integrations/prefect-kubernetes/tests/test_utilities.py b/src/integrations/prefect-kubernetes/tests/test_utilities.py index 9fd05e3bdd29..bccecc14c59f 100644 --- a/src/integrations/prefect-kubernetes/tests/test_utilities.py +++ b/src/integrations/prefect-kubernetes/tests/test_utilities.py @@ -1,7 +1,7 @@ -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest -from kubernetes.config import ConfigException +from kubernetes_asyncio.config import ConfigException from prefect_kubernetes.utilities import ( enable_socket_keep_alive, ) @@ -18,14 +18,14 @@ def mock_cluster_config(monkeypatch): [], {"context": {"cluster": FAKE_CLUSTER}}, ) - monkeypatch.setattr("kubernetes.config", mock) - monkeypatch.setattr("kubernetes.config.ConfigException", ConfigException) + monkeypatch.setattr("kubernetes_asyncio.config", mock) + monkeypatch.setattr("kubernetes_asyncio.config.ConfigException", ConfigException) return mock @pytest.fixture def mock_api_client(mock_cluster_config): - return MagicMock() + return AsyncMock() def test_keep_alive_updates_socket_options(mock_api_client): diff --git a/src/integrations/prefect-kubernetes/tests/test_worker.py b/src/integrations/prefect-kubernetes/tests/test_worker.py index f01372bc3a6d..8f12fc81ae42 100644 --- a/src/integrations/prefect-kubernetes/tests/test_worker.py +++ b/src/integrations/prefect-kubernetes/tests/test_worker.py @@ -2,19 +2,20 @@ import json import re import uuid -from contextlib import contextmanager +from contextlib import asynccontextmanager from time import monotonic, sleep from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock import anyio import anyio.abc -import kubernetes +import kubernetes_asyncio import pendulum import pytest from exceptiongroup import ExceptionGroup, catch -from kubernetes.client.exceptions import ApiException -from kubernetes.client.models import ( +from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Pod +from kubernetes_asyncio.client.exceptions import ApiException +from kubernetes_asyncio.client.models import ( CoreV1Event, CoreV1EventList, V1ListMeta, @@ -22,7 +23,7 @@ V1ObjectReference, V1Secret, ) -from kubernetes.config import ConfigException +from kubernetes_asyncio.config import ConfigException from prefect_kubernetes import KubernetesWorker from prefect_kubernetes.utilities import _slugify_label_value, _slugify_name from prefect_kubernetes.worker import KubernetesWorkerJobConfiguration @@ -50,34 +51,38 @@ @pytest.fixture def mock_watch(monkeypatch): - pytest.importorskip("kubernetes") + mock = MagicMock(return_value=AsyncMock()) + monkeypatch.setattr("kubernetes_asyncio.watch.Watch", mock) + return mock - mock = MagicMock() - monkeypatch.setattr("kubernetes.watch.Watch", MagicMock(return_value=mock)) - return mock +async def mock_stream(*args, **kwargs): + async for event in mock_pods_stream_that_returns_completed_pod(*args, **kwargs): + yield event @pytest.fixture def mock_cluster_config(monkeypatch): - mock = MagicMock() + mock = AsyncMock() # We cannot mock this or the `except` clause 
will complain - mock.config.ConfigException = ConfigException - mock.list_kube_config_contexts.return_value = ( + mock.return_value.ConfigException.return_value = ConfigException + mock.return_value.list_kube_config_contexts.return_value = ( [], {"context": {"cluster": FAKE_CLUSTER}}, ) - monkeypatch.setattr("kubernetes.config", mock) - monkeypatch.setattr("kubernetes.config.ConfigException", ConfigException) + monkeypatch.setattr("prefect_kubernetes.worker.config", mock) + monkeypatch.setattr( + "prefect_kubernetes.worker.config.ConfigException", ConfigException + ) return mock @pytest.fixture -def mock_anyio_sleep_monotonic(monkeypatch): +def mock_anyio_sleep_monotonic(monkeypatch, event_loop): def mock_monotonic(): return mock_sleep.current_time - def mock_sleep(duration): + async def mock_sleep(duration): mock_sleep.current_time += duration mock_sleep.current_time = monotonic() @@ -87,58 +92,96 @@ def mock_sleep(duration): @pytest.fixture def mock_job(): - mock = MagicMock(spec=kubernetes.client.V1Job) + mock = AsyncMock(spec=kubernetes_asyncio.client.V1Job) + mock.metadata.name = "mock-job" mock.metadata.namespace = "mock-namespace" return mock @pytest.fixture -def mock_core_client(monkeypatch, mock_cluster_config): - mock = MagicMock(spec=kubernetes.client.CoreV1Api) - mock.read_namespace.return_value.metadata.uid = MOCK_CLUSTER_UID +def mock_pod(): + pod = MagicMock(spec=V1Pod) + pod.status.phase = "Running" + pod.metadata.name = "mock-pod" + pod.metadata.namespace = "mock-namespace" + pod.metadata.uid = "1234" + return pod - @contextmanager - def get_core_client(*args, **kwargs): - yield mock + +@pytest.fixture +def mock_core_client(monkeypatch, mock_cluster_config): + mock = MagicMock(spec=CoreV1Api, return_value=AsyncMock()) + mock.return_value.read_namespace.return_value.metadata.uid = MOCK_CLUSTER_UID + mock.return_value.list_namespaced_pod.return_value.items.sort = MagicMock() + mock.return_value.read_namespaced_pod_log.return_value.content.readline = AsyncMock( + return_value=None + ) monkeypatch.setattr( - "prefect_kubernetes.worker.KubernetesWorker._get_core_client", - get_core_client, + "prefect_kubernetes.worker.KubernetesWorker._get_configured_kubernetes_client", + MagicMock(spec=ApiClient), ) + monkeypatch.setattr("prefect_kubernetes.worker.CoreV1Api", mock) + monkeypatch.setattr("kubernetes_asyncio.client.CoreV1Api", mock) return mock @pytest.fixture -def mock_batch_client(monkeypatch, mock_cluster_config, mock_job): - pytest.importorskip("kubernetes") +def mock_core_client_lean(monkeypatch): + mock = MagicMock(spec=CoreV1Api, return_value=AsyncMock()) + monkeypatch.setattr("prefect_kubernetes.worker.CoreV1Api", mock) + monkeypatch.setattr("kubernetes_asyncio.client.CoreV1Api", mock) + mock.return_value.list_namespaced_pod.return_value.items.sort = MagicMock() + return mock - mock = MagicMock(spec=kubernetes.client.BatchV1Api) - mock.read_namespaced_job.return_value = mock_job - mock.create_namespaced_job.return_value = mock_job - @contextmanager - def get_batch_client(*args, **kwargs): - yield mock +@pytest.fixture +def mock_batch_client(monkeypatch, mock_job): + mock = MagicMock(spec=BatchV1Api, return_value=AsyncMock()) + + @asynccontextmanager + async def get_batch_client(*args, **kwargs): + yield mock() monkeypatch.setattr( "prefect_kubernetes.worker.KubernetesWorker._get_batch_client", get_batch_client, ) + + mock.return_value.create_namespaced_job.return_value = mock_job + monkeypatch.setattr("prefect_kubernetes.worker.BatchV1Api", mock) return mock -def 
_mock_pods_stream_that_returns_running_pod(*args, **kwargs): - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" +@pytest.fixture +async def mock_pods_stream_that_returns_running_pod( + mock_core_client, mock_pod, mock_job +): + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} + if kwargs["func"] == mock_core_client.return_value.list_namespaced_job: + mock_job.status.completion_time = pendulum.now("utc").timestamp() + yield {"object": mock_job, "type": "MODIFIED"} + + return mock_stream + - job = MagicMock(spec=kubernetes.client.V1Job) - job.status.completion_time = pendulum.now("utc").timestamp() +@pytest.fixture +async def mock_pods_stream_that_returns_completed_pod( + mock_core_client, mock_pod, mock_job +): + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} + if kwargs["func"] == mock_core_client.return_value.list_namespaced_job: + mock_job.status.completion_time = True + mock_job.status.failed = 0 + mock_job.spec.backoff_limit = 6 + yield {"object": mock_job, "type": "MODIFIED"} - return [ - {"object": job_pod, "type": "MODIFIED"}, - {"object": job, "type": "MODIFIED"}, - ] + return mock_stream @pytest.fixture @@ -1236,19 +1279,21 @@ async def test_creates_job_by_building_a_manifest( mock_batch_client, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.prepare_for_flow_run(flow_run) expected_manifest = default_configuration.job_manifest - + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run=flow_run, configuration=default_configuration) - mock_core_client.list_namespaced_pod.assert_called_with( + mock_core_client.return_value.list_namespaced_pod.assert_called_with( namespace=default_configuration.namespace, label_selector="job-name=mock-job", ) - mock_batch_client.create_namespaced_job.assert_called_with( + mock_batch_client.return_value.create_namespaced_job.assert_called_with( "default", expected_manifest, ) @@ -1260,8 +1305,11 @@ async def test_task_status_receives_job_pid( mock_batch_client, mock_core_client, mock_watch, - monkeypatch, + mock_pods_stream_that_returns_completed_pod, ): + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: fake_status = MagicMock(spec=anyio.abc.TaskStatus) await k8s_worker.run( @@ -1279,8 +1327,12 @@ async def test_cluster_uid_uses_env_var_if_set( mock_batch_client, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, monkeypatch, ): + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: monkeypatch.setenv("PREFECT_KUBERNETES_CLUSTER_UID", "test-uid") fake_status = MagicMock(spec=anyio.abc.TaskStatus) @@ -1302,8 +1354,11 @@ async def test_task_group_start_returns_job_pid( mock_batch_client, mock_core_client, mock_watch, - monkeypatch, + mock_pods_stream_that_returns_completed_pod, ): + mock_watch.return_value.stream = mock.Mock( + 
side_effect=mock_pods_stream_that_returns_completed_pod + ) expected_value = f"{MOCK_CLUSTER_UID}:mock-namespace:mock-job" async with anyio.create_task_group() as tg: async with KubernetesWorker(work_pool_name="test") as k8s_worker: @@ -1317,11 +1372,15 @@ async def test_missing_job_returns_bad_status_code( mock_batch_client, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, caplog, ): - mock_batch_client.read_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.read_namespaced_job.side_effect = ApiException( status=404, reason="Job not found" ) + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run( @@ -1353,19 +1412,22 @@ async def test_job_name_creates_valid_name( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, mock_batch_client, job_name, clean_name, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod default_configuration.name = job_name default_configuration.prepare_for_flow_run(flow_run) + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run=flow_run, configuration=default_configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - call_name = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["generateName"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + call_name = mock_batch_client.return_value.create_namespaced_job.call_args[ + 0 + ][1]["metadata"]["generateName"] assert call_name == clean_name async def test_uses_image_variable( @@ -1373,19 +1435,21 @@ async def test_uses_image_variable( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod - + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - image = mock_batch_client.create_namespaced_job.call_args[0][1]["spec"][ - "template" - ]["spec"]["containers"][0]["image"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + image = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["spec"]["template"]["spec"]["containers"][0]["image"] assert image == "foo" async def test_can_store_api_key_in_secret( @@ -1393,11 +1457,16 @@ async def test_can_store_api_key_in_secret( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_completed_pod, mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod - mock_core_client.read_namespaced_secret.side_effect = ApiException(status=404) + mock_watch.return_value.stream = mock.Mock( + side_effect=mock_pods_stream_that_returns_completed_pod + ) + mock_core_client.return_value.read_namespaced_secret.side_effect = ApiException( + status=404 + ) configuration = await 
KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} @@ -1406,10 +1475,10 @@ async def test_can_store_api_key_in_secret( async with KubernetesWorker(work_pool_name="test") as k8s_worker: configuration.prepare_for_flow_run(flow_run=flow_run) await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - env = mock_batch_client.create_namespaced_job.call_args[0][1]["spec"][ - "template" - ]["spec"]["containers"][0]["env"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + env = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["spec"]["template"]["spec"]["containers"][0]["env"] assert { "name": "PREFECT_API_KEY", "valueFrom": { @@ -1419,7 +1488,7 @@ async def test_can_store_api_key_in_secret( } }, } in env - mock_core_client.create_namespaced_secret.assert_called_with( + mock_core_client.return_value.create_namespaced_secret.assert_called_with( namespace=configuration.namespace, body=V1Secret( api_version="v1", @@ -1437,7 +1506,7 @@ async def test_can_store_api_key_in_secret( ) # Make sure secret gets deleted - assert mock_core_client.delete_namespaced_secret( + assert mock_core_client.return_value.delete_namespaced_secret( name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key", namespace=configuration.namespace, ) @@ -1447,36 +1516,39 @@ async def test_store_api_key_in_existing_secret( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, enable_store_api_key_in_secret, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image": "foo"} ) with temporary_settings(updates={PREFECT_API_KEY: "fake"}): async with KubernetesWorker(work_pool_name="test") as k8s_worker: - mock_core_client.read_namespaced_secret.return_value = V1Secret( - api_version="v1", - kind="Secret", - metadata=V1ObjectMeta( - name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key", - namespace=configuration.namespace, - ), - data={ - "value": base64.b64encode("fake".encode("utf-8")).decode( - "utf-8" - ) - }, + mock_core_client.return_value.read_namespaced_secret.return_value = ( + V1Secret( + api_version="v1", + kind="Secret", + metadata=V1ObjectMeta( + name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key", + namespace=configuration.namespace, + ), + data={ + "value": base64.b64encode("fake".encode("utf-8")).decode( + "utf-8" + ) + }, + ) ) configuration.prepare_for_flow_run(flow_run=flow_run) await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - env = mock_batch_client.create_namespaced_job.call_args[0][1]["spec"][ - "template" - ]["spec"]["containers"][0]["env"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + env = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["spec"]["template"]["spec"]["containers"][0]["env"] assert { "name": "PREFECT_API_KEY", "valueFrom": { @@ -1486,7 +1558,7 @@ async def test_store_api_key_in_existing_secret( } }, } in env - mock_core_client.replace_namespaced_secret.assert_called_with( + mock_core_client.return_value.replace_namespaced_secret.assert_called_with( name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key", 
namespace=configuration.namespace, body=V1Secret( @@ -1527,7 +1599,7 @@ async def test_create_job_failure( response.status = 403 response.reason = "Forbidden" - mock_batch_client.create_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.create_namespaced_job.side_effect = ApiException( http_resp=response ) @@ -1570,7 +1642,7 @@ async def test_create_job_retries( response.status = 403 response.reason = "Forbidden" - mock_batch_client.create_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.create_namespaced_job.side_effect = ApiException( http_resp=response ) @@ -1589,7 +1661,10 @@ async def test_create_job_retries( ): await k8s_worker.run(flow_run, configuration) - assert mock_batch_client.create_namespaced_job.call_count == MAX_ATTEMPTS + assert ( + mock_batch_client.return_value.create_namespaced_job.call_count + == MAX_ATTEMPTS + ) async def test_create_job_failure_no_reason( self, @@ -1614,7 +1689,7 @@ async def test_create_job_failure_no_reason( response.status = 403 response.reason = None - mock_batch_client.create_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.create_namespaced_job.side_effect = ApiException( http_resp=response ) @@ -1655,7 +1730,7 @@ async def test_create_job_failure_no_message( response.status = 403 response.reason = "Test" - mock_batch_client.create_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.create_namespaced_job.side_effect = ApiException( http_resp=response ) @@ -1681,7 +1756,7 @@ async def test_create_job_failure_no_response_body( response.status = 403 response.reason = "Test" - mock_batch_client.create_namespaced_job.side_effect = ApiException( + mock_batch_client.return_value.create_namespaced_job.side_effect = ApiException( http_resp=response ) @@ -1701,9 +1776,10 @@ async def test_allows_image_setting_from_manifest( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["spec"]["template"]["spec"]["containers"][0][ "image" @@ -1712,10 +1788,10 @@ async def test_allows_image_setting_from_manifest( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - image = mock_batch_client.create_namespaced_job.call_args[0][1]["spec"][ - "template" - ]["spec"]["containers"][0]["image"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + image = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["spec"]["template"]["spec"]["containers"][0]["image"] assert image == "test" async def test_uses_labels_setting( @@ -1723,10 +1799,10 @@ async def test_uses_labels_setting( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod - + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"labels": {"foo": "foo", "bar": "bar"}}, @@ -1734,10 +1810,10 @@ async def test_uses_labels_setting( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - 
mock_batch_client.create_namespaced_job.assert_called_once() - labels = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["labels"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + labels = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["metadata"]["labels"] assert labels["foo"] == "foo" assert labels["bar"] == "bar" @@ -1746,8 +1822,10 @@ async def test_sets_environment_variables( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"env": {"foo": "FOO", "bar": "BAR"}}, @@ -1756,9 +1834,11 @@ async def test_sets_environment_variables( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() + mock_batch_client.return_value.create_namespaced_job.assert_called_once() - manifest = mock_batch_client.create_namespaced_job.call_args[0][1] + manifest = mock_batch_client.return_value.create_namespaced_job.call_args[ + 0 + ][1] pod = manifest["spec"]["template"]["spec"] env = pod["containers"][0]["env"] assert env == [ @@ -1776,8 +1856,10 @@ async def test_allows_unsetting_environment_variables( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"env": {"PREFECT_TEST_MODE": None}}, @@ -1785,9 +1867,11 @@ async def test_allows_unsetting_environment_variables( configuration.prepare_for_flow_run(flow_run) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() + mock_batch_client.return_value.create_namespaced_job.assert_called_once() - manifest = mock_batch_client.create_namespaced_job.call_args[0][1] + manifest = mock_batch_client.return_value.create_namespaced_job.call_args[ + 0 + ][1] pod = manifest["spec"]["template"]["spec"] env = pod["containers"][0]["env"] env_names = {variable["name"] for variable in env} @@ -1838,23 +1922,26 @@ async def test_sanitizes_user_label_keys( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, given, expected, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), - {"labels": {given: "foo"}}, + { + "labels": {given: "foo"}, + }, ) configuration.prepare_for_flow_run(flow_run) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - labels = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["labels"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + labels = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["metadata"]["labels"] assert labels[expected] == "foo" 
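The assertion churn in the hunks above and below follows one mechanical rule: the `kubernetes_asyncio` API classes are now patched as callables whose call produces an `AsyncMock` instance (mirroring `CoreV1Api(api_client)` in the worker), so recorded calls live on the produced instance and every assertion dereferences `.return_value` before the method name. A minimal sketch of the pattern, using only `unittest.mock` and illustrative names that are not part of this patch:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

# Patched API class: calling it (as the worker does with CoreV1Api(api_client))
# returns the AsyncMock instance, whose attributes are awaitable methods.
mock_core_v1 = MagicMock(return_value=AsyncMock())

async def code_under_test():
    client = mock_core_v1("fake-api-client")  # stands in for CoreV1Api(api_client)
    await client.list_namespaced_pod(namespace="default")

asyncio.run(code_under_test())

# Calls are recorded on the instance the mock produced, hence the
# `.return_value` hop seen throughout the updated assertions:
mock_core_v1.return_value.list_namespaced_pod.assert_called_once_with(
    namespace="default"
)
```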
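The watch mocking changes follow the same logic: `kubernetes_asyncio.watch.Watch().stream(...)` is consumed as an async generator, so these tests swap the old list-returning stub for an async generator function wrapped in `mock.Mock(side_effect=...)`, which keeps call recording (for `assert_has_calls` on `namespace`, `timeout_seconds`, and so on) while handing back a fresh generator on each call. A runnable sketch of that shape, again with hypothetical names rather than anything from this patch:

```python
import asyncio
from unittest import mock

async def fake_stream(*args, **kwargs):
    # Stand-in for Watch().stream(func=..., namespace=...): yields watch events.
    yield {"object": "pod", "type": "MODIFIED"}
    yield {"object": "job", "type": "ADDED"}

# side_effect=<async generator function>: each call returns a new async
# generator, while the Mock still records the kwargs it was called with.
stream = mock.Mock(side_effect=fake_stream)

async def consume():
    async for event in stream(func="list_namespaced_job", namespace="default"):
        print(event["type"])  # MODIFIED, then ADDED

asyncio.run(consume())
stream.assert_called_once_with(func="list_namespaced_job", namespace="default")
```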
@pytest.mark.parametrize( @@ -1886,11 +1973,12 @@ async def test_sanitizes_user_label_values( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, given, expected, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), @@ -1900,10 +1988,10 @@ async def test_sanitizes_user_label_values( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - labels = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["labels"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + labels = mock_batch_client.return_value.create_namespaced_job.call_args[0][ + 1 + ]["metadata"]["labels"] assert labels["foo"] == expected async def test_uses_namespace_setting( @@ -1911,9 +1999,10 @@ async def test_uses_namespace_setting( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"namespace": "foo"}, @@ -1921,10 +2010,10 @@ async def test_uses_namespace_setting( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - namespace = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["namespace"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + namespace = mock_batch_client.return_value.create_namespaced_job.call_args[ + 0 + ][1]["metadata"]["namespace"] assert namespace == "foo" async def test_allows_namespace_setting_from_manifest( @@ -1933,19 +2022,20 @@ async def test_allows_namespace_setting_from_manifest( default_configuration, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod default_configuration.job_manifest["metadata"]["namespace"] = "test" default_configuration.prepare_for_flow_run(flow_run) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - namespace = mock_batch_client.create_namespaced_job.call_args[0][1][ - "metadata" - ]["namespace"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + namespace = mock_batch_client.return_value.create_namespaced_job.call_args[ + 0 + ][1]["metadata"]["namespace"] assert namespace == "test" async def test_uses_service_account_name_setting( @@ -1953,9 +2043,10 @@ async def test_uses_service_account_name_setting( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await 
KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"service_account_name": "foo"}, @@ -1963,10 +2054,12 @@ async def test_uses_service_account_name_setting( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - service_account_name = mock_batch_client.create_namespaced_job.call_args[0][ - 1 - ]["spec"]["template"]["spec"]["serviceAccountName"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + service_account_name = ( + mock_batch_client.return_value.create_namespaced_job.call_args[0][1][ + "spec" + ]["template"]["spec"]["serviceAccountName"] + ) assert service_account_name == "foo" async def test_uses_finished_job_ttl_setting( @@ -1974,9 +2067,10 @@ async def test_uses_finished_job_ttl_setting( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"finished_job_ttl": 123}, @@ -1984,10 +2078,12 @@ async def test_uses_finished_job_ttl_setting( async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - finished_job_ttl = mock_batch_client.create_namespaced_job.call_args[0][1][ - "spec" - ]["ttlSecondsAfterFinished"] + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + finished_job_ttl = ( + mock_batch_client.return_value.create_namespaced_job.call_args[0][1][ + "spec" + ]["ttlSecondsAfterFinished"] + ) assert finished_job_ttl == 123 async def test_uses_specified_image_pull_policy( @@ -1995,31 +2091,43 @@ async def test_uses_specified_image_pull_policy( flow_run, mock_core_client, mock_watch, + mock_pods_stream_that_returns_running_pod, mock_batch_client, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + mock_watch.return_value.stream = mock_pods_stream_that_returns_running_pod configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( KubernetesWorker.get_default_base_job_template(), {"image_pull_policy": "IfNotPresent"}, ) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, configuration) - mock_batch_client.create_namespaced_job.assert_called_once() - call_image_pull_policy = mock_batch_client.create_namespaced_job.call_args[ - 0 - ][1]["spec"]["template"]["spec"]["containers"][0].get("imagePullPolicy") + mock_batch_client.return_value.create_namespaced_job.assert_called_once() + call_image_pull_policy = ( + mock_batch_client.return_value.create_namespaced_job.call_args[0][1][ + "spec" + ]["template"]["spec"]["containers"][0].get("imagePullPolicy") + ) assert call_image_pull_policy == "IfNotPresent" async def test_defaults_to_incluster_config( self, flow_run, default_configuration, - mock_core_client, + mock_core_client_lean, mock_watch, mock_cluster_config, mock_batch_client, + mock_job, + mock_pod, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} + if 
kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_job: + mock_job.status.completion_time = pendulum.now("utc").timestamp() + yield {"object": mock_job, "type": "MODIFIED"} + + mock_watch.return_value.stream = mock_stream async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) @@ -2031,609 +2139,520 @@ async def test_uses_cluster_config_if_not_in_cluster( self, flow_run, default_configuration, - mock_core_client, mock_watch, mock_cluster_config, mock_batch_client, + mock_core_client_lean, + mock_job, + mock_pod, ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} + if kwargs["func"] == mock_core_client_lean.return_value.list_namespaced_job: + mock_job.status.completion_time = pendulum.now("utc").timestamp() + yield {"object": mock_job, "type": "MODIFIED"} + + mock_watch.return_value.stream = mock_stream mock_cluster_config.load_incluster_config.side_effect = ConfigException() async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run(flow_run, default_configuration) - mock_cluster_config.new_client_from_config.assert_called_once() - @pytest.mark.parametrize("job_timeout", [24, 100]) - async def test_allows_configurable_timeouts_for_pod_and_job_watches( - self, - mock_core_client, - mock_watch, - mock_batch_client, - job_timeout, - default_configuration: KubernetesWorkerJobConfiguration, - flow_run, - ): - mock_watch.stream = Mock(side_effect=_mock_pods_stream_that_returns_running_pod) + class TestPodWatch: + @pytest.mark.parametrize("job_timeout", [24, 100]) + async def test_allows_configurable_timeouts_for_pod_and_job_watches( + self, + mock_core_client, + mock_watch, + mock_batch_client, + job_timeout, + default_configuration: KubernetesWorkerJobConfiguration, + flow_run, + mock_pod, + mock_job, + ): + async def mock_stream(*args, **kwargs): + mock_job.status.completion_time = pendulum.now("utc").timestamp() + stream = [ + {"object": mock_job, "type": "MODIFIED"}, + {"object": mock_pod, "type": "MODIFIED"}, + ] + for item in stream: + yield item + + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None + + k8s_job_args = dict( + command=["echo", "hello"], + pod_watch_timeout_seconds=42, + ) + expected_job_call_kwargs = dict( + func=mock_batch_client.return_value.list_namespaced_job, + namespace=mock.ANY, + field_selector=mock.ANY, + ) - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None + if job_timeout is not None: + k8s_job_args["job_watch_timeout_seconds"] = job_timeout + expected_job_call_kwargs["timeout_seconds"] = pytest.approx( + job_timeout, abs=1 + ) - k8s_job_args = dict( - command=["echo", "hello"], - pod_watch_timeout_seconds=42, - ) - expected_job_call_kwargs = dict( - func=mock_batch_client.list_namespaced_job, - namespace=mock.ANY, - field_selector=mock.ANY, - ) + default_configuration.job_watch_timeout_seconds = job_timeout + default_configuration.pod_watch_timeout_seconds = 42 - if job_timeout is not None: - k8s_job_args["job_watch_timeout_seconds"] = job_timeout - expected_job_call_kwargs["timeout_seconds"] = pytest.approx( - job_timeout, abs=1 + async 
with KubernetesWorker(work_pool_name="test") as k8s_worker: + await k8s_worker.run(flow_run, default_configuration) + + mock_watch.return_value.stream.assert_has_calls( + [ + mock.call( + func=mock_core_client.return_value.list_namespaced_pod, + namespace=mock.ANY, + label_selector=mock.ANY, + timeout_seconds=42, + ), + mock.call(**expected_job_call_kwargs), + ] ) - default_configuration.job_watch_timeout_seconds = job_timeout - default_configuration.pod_watch_timeout_seconds = 42 - - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - await k8s_worker.run(flow_run, default_configuration) + @pytest.mark.parametrize("job_timeout", [None]) + async def test_excludes_timeout_from_job_watches_when_null( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_pods_stream_that_returns_running_pod, + mock_batch_client, + job_timeout, + mock_pod, + mock_job, + ): + async def mock_stream(*args, **kwargs): + mock_job.status.completion_time = pendulum.now("utc").timestamp() + stream = [ + {"object": mock_job, "type": "MODIFIED"}, + {"object": mock_pod, "type": "MODIFIED"}, + ] + for item in stream: + yield item - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=42, - ), - mock.call(**expected_job_call_kwargs), - ] - ) + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None - @pytest.mark.parametrize("job_timeout", [None]) - async def test_excludes_timeout_from_job_watches_when_null( - self, - flow_run, - default_configuration, - mock_core_client, - mock_watch, - mock_batch_client, - job_timeout, - ): - mock_watch.stream = mock.Mock( - side_effect=_mock_pods_stream_that_returns_running_pod - ) - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None + default_configuration.job_watch_timeout_seconds = job_timeout - default_configuration.job_watch_timeout_seconds = job_timeout + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + await k8s_worker.run(flow_run, default_configuration) + + mock_watch.return_value.stream.assert_has_calls( + [ + mock.call( + func=mock_core_client.return_value.list_namespaced_pod, + namespace=mock.ANY, + label_selector=mock.ANY, + timeout_seconds=mock.ANY, + ), + mock.call( + func=mock_batch_client.return_value.list_namespaced_job, + namespace=mock.ANY, + field_selector=mock.ANY, + # Note: timeout_seconds is excluded here + ), + ] + ) - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - await k8s_worker.run(flow_run, default_configuration) + async def test_watches_the_right_namespace( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_batch_client, + mock_pod, + mock_job, + ): + async def mock_stream(*args, **kwargs): + mock_job.status.completion_time = pendulum.now("utc").timestamp() + stream = [ + {"object": mock_job, "type": "MODIFIED"}, + {"object": mock_pod, "type": "MODIFIED"}, + ] + for item in stream: + yield item + + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None + default_configuration.namespace = "my-awesome-flows" + 
default_configuration.prepare_for_flow_run(flow_run) - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=mock.ANY, - ), - mock.call( - func=mock_batch_client.list_namespaced_job, - namespace=mock.ANY, - field_selector=mock.ANY, - # Note: timeout_seconds is excluded here - ), - ] - ) + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + await k8s_worker.run(flow_run, default_configuration) + + mock_watch.return_value.stream.assert_has_calls( + [ + mock.call( + func=mock_core_client.return_value.list_namespaced_pod, + namespace="my-awesome-flows", + label_selector=mock.ANY, + timeout_seconds=60, + ), + mock.call( + func=mock_batch_client.return_value.list_namespaced_job, + namespace="my-awesome-flows", + field_selector=mock.ANY, + ), + ] + ) - async def test_watches_the_right_namespace( - self, - flow_run, - default_configuration, - mock_core_client, - mock_watch, - mock_batch_client, - ): - mock_watch.stream = mock.Mock( - side_effect=_mock_pods_stream_that_returns_running_pod - ) - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None - default_configuration.namespace = "my-awesome-flows" - default_configuration.prepare_for_flow_run(flow_run) + async def test_streaming_pod_logs_timeout_warns( + self, + flow_run, + default_configuration: KubernetesWorkerJobConfiguration, + mock_core_client, + mock_watch, + mock_batch_client, + caplog, + mock_pod, + mock_job, + ): + async def mock_stream(*args, **kwargs): + mock_job.status.completion_time = pendulum.now("utc").timestamp() + stream = [ + {"object": mock_job, "type": "MODIFIED"}, + {"object": mock_pod, "type": "MODIFIED"}, + ] + for item in stream: + yield item + + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None + + async def mock_log_stream(*args, **kwargs): + yield RuntimeError("something went wrong") + + mock_core_client.return_value.read_namespaced_pod_log.return_value.content = mock_log_stream + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + with caplog.at_level("WARNING"): + result = await k8s_worker.run(flow_run, default_configuration) - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - await k8s_worker.run(flow_run, default_configuration) + assert result.status_code == 1 + assert "Error occurred while streaming logs - " in caplog.text - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace="my-awesome-flows", - label_selector=mock.ANY, - timeout_seconds=60, - ), - mock.call( - func=mock_batch_client.list_namespaced_job, - namespace="my-awesome-flows", - field_selector=mock.ANY, - ), - ] - ) + async def test_watch_timeout( + self, + mock_core_client, + mock_watch, + mock_batch_client, + flow_run, + default_configuration, + mock_pod, + ): + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None - async def test_streaming_pod_logs_timeout_warns( - self, - flow_run, - default_configuration: KubernetesWorkerJobConfiguration, - mock_core_client, - mock_watch, - mock_batch_client, - caplog, - ): - mock_watch.stream = _mock_pods_stream_that_returns_running_pod - # The job should not be completed to start - 
mock_batch_client.read_namespaced_job.return_value.status.completion_time = None + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "ADDED"} - mock_logs = MagicMock() - mock_logs.stream = MagicMock(side_effect=RuntimeError("something went wrong")) + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) + job.status.completion_time = None + yield {"object": job, "type": "ADDED"} + sleep(0.5) + yield {"object": job, "type": "ADDED"} - mock_core_client.read_namespaced_pod_log = MagicMock(return_value=mock_logs) + default_configuration.pod_watch_timeout_seconds = 42 + default_configuration.job_watch_timeout_seconds = 0 + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - with caplog.at_level("WARNING"): + async with KubernetesWorker(work_pool_name="test") as k8s_worker: result = await k8s_worker.run(flow_run, default_configuration) + assert result.status_code == -1 - assert result.status_code == 1 - assert "Error occurred while streaming logs - " in caplog.text - - async def test_watch_timeout( - self, - mock_core_client, - mock_watch, - mock_batch_client, - flow_run, - default_configuration, - ): - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None + async def test_watch_deadline_is_computed_before_log_streams( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_batch_client, + mock_pod, + ): + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None - def mock_stream(*args, **kwargs): - if kwargs["func"] == mock_core_client.list_namespaced_pod: - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" - yield {"object": job_pod} + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} - if kwargs["func"] == mock_batch_client.list_namespaced_job: - job = MagicMock(spec=kubernetes.client.V1Job) - job.status.completion_time = None - yield {"object": job} - sleep(0.5) - yield {"object": job} + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) - mock_watch.stream.side_effect = mock_stream + # Yield the completed job + job.status.completion_time = True + job.status.failed = 0 + job.spec.backoff_limit = 6 + yield {"object": job, "type": "ADDED"} - default_configuration.pod_watch_timeout_seconds = 42 - default_configuration.job_watch_timeout_seconds = 0 + async def mock_log_stream(*args, **kwargs): + await anyio.sleep(50) + yield MagicMock() - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - result = await k8s_worker.run(flow_run, default_configuration) - assert result.status_code == -1 + mock_core_client.return_value.read_namespaced_pod_log.return_value.stream = mock_log_stream + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) - async def test_watch_deadline_is_computed_before_log_streams( - self, - flow_run, - default_configuration, - mock_core_client, - mock_watch, - mock_batch_client, - mock_anyio_sleep_monotonic, - ): - # The job should not be completed to start - 
mock_batch_client.read_namespaced_job.return_value.status.completion_time = None - - def mock_stream(*args, **kwargs): - if kwargs["func"] == mock_core_client.list_namespaced_pod: - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" - yield {"object": job_pod} - - if kwargs["func"] == mock_batch_client.list_namespaced_job: - job = MagicMock(spec=kubernetes.client.V1Job) - - # Yield the completed job - job.status.completion_time = True - job.status.failed = 0 - job.spec.backoff_limit = 6 - yield {"object": job, "type": "ADDED"} + default_configuration.job_watch_timeout_seconds = 100 + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + result = await k8s_worker.run(flow_run, default_configuration) - def mock_log_stream(*args, **kwargs): - anyio.sleep(500) - return MagicMock() + assert result.status_code == 1 - mock_core_client.read_namespaced_pod_log.side_effect = mock_log_stream - mock_watch.stream.side_effect = mock_stream + mock_watch.return_value.stream.assert_has_calls( + [ + mock.call( + func=mock_core_client.return_value.list_namespaced_pod, + namespace=mock.ANY, + label_selector=mock.ANY, + timeout_seconds=mock.ANY, + ), + # Starts with the full timeout minus the amount we slept streaming logs + mock.call( + func=mock_batch_client.return_value.list_namespaced_job, + field_selector=mock.ANY, + namespace=mock.ANY, + timeout_seconds=pytest.approx(50, 1), + ), + ] + ) - default_configuration.job_watch_timeout_seconds = 1000 - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - result = await k8s_worker.run(flow_run, default_configuration) + async def test_watch_timeout_is_restarted_until_job_is_complete( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_batch_client, + mock_pod, + ): + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None - assert result.status_code == 1 + # TODO investigate why it needs type + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "MODIFIED"} - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=mock.ANY, - ), - # Starts with the full timeout minus the amount we slept streaming logs - mock.call( - func=mock_batch_client.list_namespaced_job, - field_selector=mock.ANY, - namespace=mock.ANY, - timeout_seconds=pytest.approx(500, 1), - ), - ] - ) + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) - async def test_timeout_is_checked_during_log_streams( - self, - flow_run, - default_configuration, - mock_core_client, - mock_watch, - mock_batch_client, - capsys, - ): - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None - - def mock_stream(*args, **kwargs): - if kwargs["func"] == mock_core_client.list_namespaced_pod: - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" - yield {"object": job_pod, "type": "ADDED"} - - if kwargs["func"] == mock_batch_client.list_namespaced_job: - job = MagicMock(spec=kubernetes.client.V1Job) - - # Yield the job then return exiting the stream - # After restarting the watch a few times, we'll report completion - job.status.completion_time = ( - None if 
mock_watch.stream.call_count < 3 else True - ) - yield {"object": job} + # Sleep a little + await anyio.sleep(10) - def mock_log_stream(*args, **kwargs): - for i in range(10): - sleep(0.25) - yield f"test {i}".encode() + # Yield the job then return exiting the stream + job.status.completion_time = None + job.status.failed = 0 + job.spec.backoff_limit = 6 + yield {"object": job, "type": "ADDED"} - mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream - mock_watch.stream.side_effect = mock_stream + # mock_watch.return_value.stream = mock_stream + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) + default_configuration.job_watch_timeout_seconds = 1 + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + result = await k8s_worker.run(flow_run, default_configuration) - default_configuration.job_watch_timeout_seconds = 1 + assert result.status_code == -1 - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - result = await k8s_worker.run(flow_run, default_configuration) - - # The job should timeout - assert result.status_code == -1 - - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=mock.ANY, - ), - # No watch call is made because the deadline is exceeded beforehand + async def test_watch_stops_after_backoff_limit_reached( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_batch_client, + mock_pod, + ): + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None + job_pod = MagicMock(spec=kubernetes_asyncio.client.V1Pod) + job_pod.status.phase = "Running" + mock_container_status = MagicMock( + spec=kubernetes_asyncio.client.V1ContainerStatus + ) + mock_container_status.state.terminated.exit_code = 137 + job_pod.status.container_statuses = [mock_container_status] + mock_core_client.return_value.list_namespaced_pod.return_value.items = [ + job_pod ] - ) - - # Check for logs - stdout, _ = capsys.readouterr() - - # Before the deadline, logs should be displayed - for i in range(4): - assert f"test {i}" in stdout - for i in range(4, 10): - assert f"test {i}" not in stdout - - async def test_timeout_during_log_stream_does_not_fail_completed_job( - self, - mock_core_client, - mock_watch, - mock_batch_client, - capsys, - flow_run, - default_configuration, - ): - # Pretend the job is completed immediately - mock_batch_client.read_namespaced_job.return_value.status.completion_time = True - def mock_stream(*args, **kwargs): - if kwargs["func"] == mock_core_client.list_namespaced_pod: - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" - yield {"object": job_pod} + # TODO investigate why it needs type + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "ADDED"} - def mock_log_stream(*args, **kwargs): - for i in range(10): - sleep(0.25) - yield f"test {i}".encode() + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) - mock_core_client.read_namespaced_pod_log.return_value.stream = mock_log_stream - mock_watch.stream.side_effect = mock_stream + # Yield the job then return exiting the stream + job.status.completion_time = None + job.spec.backoff_limit = 6 + for i in range(0, 8): + job.status.failed = 
i + yield {"object": job, "type": "ADDED"} - default_configuration.job_watch_timeout_seconds = 1 - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - result = await k8s_worker.run(flow_run, default_configuration) - - # The job should not timeout - assert result.status_code == 1 - - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=mock.ANY, - ), - # No watch call is made because the job is completed already - ] - ) + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) - # Check for logs - stdout, _ = capsys.readouterr() + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + result = await k8s_worker.run(flow_run, default_configuration) - # Before the deadline, logs should be displayed - for i in range(4): - assert f"test {i}" in stdout - for i in range(4, 10): - assert f"test {i}" not in stdout + assert result.status_code == 137 - @pytest.mark.flaky # Rarely, the sleep times we check for do not fit within the tolerances - async def test_watch_timeout_is_restarted_until_job_is_complete( - self, - flow_run, - default_configuration, - mock_core_client, - mock_watch, - mock_batch_client, - mock_anyio_sleep_monotonic, - ): - # The job should not be completed to start - mock_batch_client.read_namespaced_job.return_value.status.completion_time = None + async def test_watch_handles_no_pod( + self, + flow_run, + default_configuration, + mock_core_client, + mock_watch, + mock_batch_client, + mock_pod, + ): + # The job should not be completed to start + mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None + mock_core_client.return_value.list_namespaced_pod.return_value.items = [] - def mock_stream(*args, **kwargs): - if kwargs["func"] == mock_core_client.list_namespaced_pod: - job_pod = MagicMock(spec=kubernetes.client.V1Pod) - job_pod.status.phase = "Running" - yield {"object": job_pod} + # TODO investigate why it needs type + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod: + yield {"object": mock_pod, "type": "ADDED"} - if kwargs["func"] == mock_batch_client.list_namespaced_job: - job = MagicMock(spec=kubernetes.client.V1Job) + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) - # Sleep a little - anyio.sleep(10) + # Yield the job then return exiting the stream + job.status.completion_time = None + job.spec.backoff_limit = 6 + for i in range(0, 8): + job.status.failed = i + yield {"object": job, "type": "ADDED"} - # Yield the job then return exiting the stream - job.status.completion_time = None - job.status.failed = 0 - job.spec.backoff_limit = 6 - yield {"object": job, "type": "ADDED"} + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) - mock_watch.stream.side_effect = mock_stream - default_configuration.job_watch_timeout_seconds = 40 - async with KubernetesWorker(work_pool_name="test") as k8s_worker: - result = await k8s_worker.run(flow_run, default_configuration) + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + result = await k8s_worker.run(flow_run, default_configuration) - assert result.status_code == -1 + assert result.status_code == -1 - mock_watch.stream.assert_has_calls( - [ - mock.call( - func=mock_core_client.list_namespaced_pod, - namespace=mock.ANY, - label_selector=mock.ANY, - timeout_seconds=mock.ANY, 
-        mock_watch.stream.assert_has_calls(
-            [
-                mock.call(
-                    func=mock_core_client.list_namespaced_pod,
-                    namespace=mock.ANY,
-                    label_selector=mock.ANY,
-                    timeout_seconds=mock.ANY,
-                ),
-                # Starts with the full timeout
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    field_selector=mock.ANY,
-                    namespace=mock.ANY,
-                    timeout_seconds=pytest.approx(40, abs=1),
-                ),
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    field_selector=mock.ANY,
-                    namespace=mock.ANY,
-                    timeout_seconds=pytest.approx(30, abs=1),
-                ),
-                # Then, elapsed time removed on each call
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    field_selector=mock.ANY,
-                    namespace=mock.ANY,
-                    timeout_seconds=pytest.approx(20, abs=1),
-                ),
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    field_selector=mock.ANY,
-                    namespace=mock.ANY,
-                    timeout_seconds=pytest.approx(10, abs=1),
-                ),
+    async def test_watch_handles_pod_without_exit_code(
+        self,
+        flow_run,
+        default_configuration,
+        mock_core_client,
+        mock_watch,
+        mock_batch_client,
+        mock_pod,
+    ):
+        """
+        This test case mimics the behavior of a pod that has been forcefully terminated
+        (e.g., AWS spot instance termination or another node failure).
+        """
+        mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None
+        job_pod = MagicMock(spec=kubernetes_asyncio.client.V1Pod)
+        job_pod.status.phase = "Running"
+        mock_container_status = MagicMock(
+            spec=kubernetes_asyncio.client.V1ContainerStatus
+        )
+        # The container may exist but because it has been forcefully terminated
+        # it will not have an exit code.
+        mock_container_status.state.terminated = None
+        job_pod.status.container_statuses = [mock_container_status]
+        mock_core_client.return_value.list_namespaced_pod.return_value.items = [
+            job_pod
+        ]
-        )
-    async def test_watch_stops_after_backoff_limit_reached(
-        self,
-        flow_run,
-        default_configuration,
-        mock_core_client,
-        mock_watch,
-        mock_batch_client,
-    ):
-        # The job should not be completed to start
-        mock_batch_client.read_namespaced_job.return_value.status.completion_time = None
-        job_pod = MagicMock(spec=kubernetes.client.V1Pod)
-        job_pod.status.phase = "Running"
-        mock_container_status = MagicMock(spec=kubernetes.client.V1ContainerStatus)
-        mock_container_status.state.terminated.exit_code = 137
-        job_pod.status.container_statuses = [mock_container_status]
-        mock_core_client.list_namespaced_pod.return_value.items = [job_pod]
-
-        def mock_stream(*args, **kwargs):
-            if kwargs["func"] == mock_core_client.list_namespaced_pod:
-                yield {"object": job_pod}
-
-            if kwargs["func"] == mock_batch_client.list_namespaced_job:
-                job = MagicMock(spec=kubernetes.client.V1Job)
-
-                # Yield the job then return exiting the stream
-                job.status.completion_time = None
-                job.spec.backoff_limit = 6
-                for i in range(0, 8):
-                    job.status.failed = i
-                    yield {"object": job, "type": "ADDED"}
+        # TODO: investigate why the mocked stream must include the event "type" key
+        async def mock_stream(*args, **kwargs):
+            if kwargs["func"] == mock_core_client.return_value.list_namespaced_pod:
+                yield {"object": mock_pod, "type": "ADDED"}
-        mock_watch.stream.side_effect = mock_stream
+            if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job:
+                job = MagicMock(spec=kubernetes_asyncio.client.V1Job)
-        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
-            result = await k8s_worker.run(flow_run, default_configuration)
+                # Yield the job, then return, exiting the stream
+                job.status.completion_time = None
+                job.spec.backoff_limit = 6
+                for i in range(0, 8):
+                    job.status.failed = i
+                    yield {"object": job, "type": "ADDED"}
-        assert result.status_code == 137
+        mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream)
-    async def test_watch_handles_no_pod(
-        self,
-        flow_run,
-        default_configuration,
-        mock_core_client,
-        mock_watch,
-        mock_batch_client,
-    ):
-        # The job should not be completed to start
-        mock_batch_client.read_namespaced_job.return_value.status.completion_time = None
-        mock_core_client.list_namespaced_pod.return_value.items = []
-
-        def mock_stream(*args, **kwargs):
-            if kwargs["func"] == mock_core_client.list_namespaced_pod:
-                job_pod = MagicMock(spec=kubernetes.client.V1Pod)
-                job_pod.status.phase = "Running"
-                yield {"object": job_pod}
-
-            if kwargs["func"] == mock_batch_client.list_namespaced_job:
-                job = MagicMock(spec=kubernetes.client.V1Job)
-
-                # Yield the job then return exiting the stream
-                job.status.completion_time = None
-                job.spec.backoff_limit = 6
-                for i in range(0, 8):
-                    job.status.failed = i
-                    yield {"object": job, "type": "ADDED"}
-
-        mock_watch.stream.side_effect = mock_stream
-
-        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
-            result = await k8s_worker.run(flow_run, default_configuration)
-
-        assert result.status_code == -1
-
-    async def test_watch_handles_pod_without_exit_code(
-        self,
-        flow_run,
-        default_configuration,
-        mock_core_client,
-        mock_watch,
-        mock_batch_client,
-    ):
-        """
-        This test case mimics the behavior of a pod that has been forcefully terminated
-        (i.e. AWS spot instance termination or another node failure).
-        """
-        mock_batch_client.read_namespaced_job.return_value.status.completion_time = None
-        job_pod = MagicMock(spec=kubernetes.client.V1Pod)
-        job_pod.status.phase = "Running"
-        mock_container_status = MagicMock(spec=kubernetes.client.V1ContainerStatus)
-        # The container may exist but because it has been forcefully terminated
-        # it will not have an exit code.
-        mock_container_status.state.terminated = None
-        job_pod.status.container_statuses = [mock_container_status]
-        mock_core_client.list_namespaced_pod.return_value.items = [job_pod]
-
-        def mock_stream(*args, **kwargs):
-            if kwargs["func"] == mock_core_client.list_namespaced_pod:
-                job_pod = MagicMock(spec=kubernetes.client.V1Pod)
-                job_pod.status.phase = "Running"
-                yield {"object": job_pod}
-
-            if kwargs["func"] == mock_batch_client.list_namespaced_job:
-                job = MagicMock(spec=kubernetes.client.V1Job)
-
-                # Yield the job then return exiting the stream
-                job.status.completion_time = None
-                job.spec.backoff_limit = 6
-                for i in range(0, 8):
-                    job.status.failed = i
-                    yield {"object": job, "type": "ADDED"}
-
-        mock_watch.stream.side_effect = mock_stream
-
-        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
-            result = await k8s_worker.run(flow_run, default_configuration)
-
-        assert result.status_code == -1
+        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
+            result = await k8s_worker.run(flow_run, default_configuration)
-    async def test_watch_handles_410(
-        self,
-        default_configuration: KubernetesWorkerJobConfiguration,
-        flow_run,
-        mock_batch_client,
-        mock_core_client,
-        mock_watch,
-    ):
-        mock_watch.stream.side_effect = [
-            _mock_pods_stream_that_returns_running_pod(),
-            _mock_pods_stream_that_returns_running_pod(),
-            ApiException(status=410),
-            _mock_pods_stream_that_returns_running_pod(),
-        ]
+            assert result.status_code == -1
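The backoff-limit and no-exit-code tests above pin down how the worker reports a status when the watch sees a pod whose container never terminated cleanly. A hedged sketch of that inference; `infer_status_code` is a hypothetical helper name, not the worker's API:

```python
# Illustrative only: a forcefully terminated container has no state.terminated,
# so the only honest report is a sentinel status code of -1. A container that
# did terminate (e.g. OOM-killed) carries its real exit code, such as 137.
from typing import Optional


def infer_status_code(container_statuses: Optional[list]) -> int:
    """Return the first container exit code, or -1 when none is available."""
    for status in container_statuses or []:
        terminated = status.state.terminated
        if terminated is not None and terminated.exit_code is not None:
            return terminated.exit_code
    # No pod, no container statuses, or a container killed without an exit code.
    return -1
```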
-        job_list = MagicMock(spec=kubernetes.client.V1JobList)
-        job_list.metadata.resource_version = "1"
+    async def test_watch_handles_410(
+        self,
+        default_configuration: KubernetesWorkerJobConfiguration,
+        flow_run,
+        mock_batch_client,
+        mock_core_client,
+        mock_watch,
+        mock_job,
+        mock_pod,
+    ):
+        async def mock_stream(*args, **kwargs):
+            mock_job.status.completion_time = pendulum.now("utc").timestamp()
+            items = [
+                {"object": mock_pod, "type": "MODIFIED"},
+                {"object": mock_job, "type": "MODIFIED"},
+            ]
+            for item in items:
+                yield item
+
+        stream_return = [
+            mock_stream(),
+            mock_stream(),
+            ApiException(status=410),
+            mock_stream(),
+        ]
+        mock_watch.return_value.stream = mock.Mock(side_effect=stream_return)
+        job_list = MagicMock(spec=kubernetes_asyncio.client.V1JobList)
+        job_list.metadata.resource_version = "1"
-        mock_batch_client.list_namespaced_job.side_effect = [job_list]
+        mock_batch_client.return_value.list_namespaced_job.side_effect = [job_list]
-        # The job should not be completed to start
-        mock_batch_client.read_namespaced_job.return_value.status.completion_time = None
+        # The job should not be completed to start
+        mock_batch_client.return_value.read_namespaced_job.return_value.status.completion_time = None
-        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
-            await k8s_worker.run(flow_run=flow_run, configuration=default_configuration)
+        async with KubernetesWorker(work_pool_name="test") as k8s_worker:
+            await k8s_worker.run(
+                flow_run=flow_run, configuration=default_configuration
+            )
-        mock_watch.stream.assert_has_calls(
-            [
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    namespace=mock.ANY,
-                    field_selector="metadata.name=mock-job",
-                ),
-                mock.call(
-                    func=mock_batch_client.list_namespaced_job,
-                    namespace=mock.ANY,
-                    field_selector="metadata.name=mock-job",
-                    resource_version="1",
-                ),
-            ]
-        )
+        mock_watch.return_value.stream.assert_has_calls(
+            [
+                mock.call(
+                    func=mock_batch_client.return_value.list_namespaced_job,
+                    namespace=mock.ANY,
+                    field_selector="metadata.name=mock-job",
+                ),
+                mock.call(
+                    func=mock_batch_client.return_value.list_namespaced_job,
+                    namespace=mock.ANY,
+                    field_selector="metadata.name=mock-job",
+                    resource_version="1",
+                ),
+            ]
+        )
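The 410 test encodes the standard Kubernetes watch-resume dance: on `410 Gone`, re-list to obtain a fresh `resource_version`, then restart the watch from it. A rough sketch of that control flow, assuming a watch object whose `stream(func=..., ...)` behaves like the mocked one above; this is illustrative, not the worker's actual code:

```python
# Hedged sketch: `watch` and `list_jobs` are caller-supplied stand-ins.
from kubernetes_asyncio.client.exceptions import ApiException


async def watch_until_complete(watch, list_jobs, namespace, job_name):
    resource_version = None
    while True:
        kwargs = {
            "namespace": namespace,
            "field_selector": f"metadata.name={job_name}",
        }
        if resource_version is not None:
            kwargs["resource_version"] = resource_version
        try:
            async for event in watch.stream(func=list_jobs, **kwargs):
                if event["object"].status.completion_time:
                    return event["object"]
        except ApiException as exc:
            if exc.status != 410:
                raise
            # Watch expired: re-list to refresh the resource version, resume.
            job_list = await list_jobs(namespace=namespace)
            resource_version = job_list.metadata.resource_version
```

The second `mock.call(..., resource_version="1")` assertion above is exactly the resumed watch in this loop.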
 
 class TestKillInfrastructure:
     async def test_kill_infrastructure_calls_delete_namespaced_job(
@@ -2642,7 +2661,6 @@ async def test_kill_infrastructure_calls_delete_namespaced_job(
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
         async with KubernetesWorker(work_pool_name="test") as k8s_worker:
             await k8s_worker.kill_infrastructure(
@@ -2652,7 +2670,7 @@
             )
 
         assert len(mock_batch_client.mock_calls) == 1
-        mock_batch_client.delete_namespaced_job.assert_called_once_with(
+        mock_batch_client.return_value.delete_namespaced_job.assert_called_once_with(
             name="mock-k8s-v1-job",
             namespace="default",
             grace_period_seconds=0,
@@ -2665,7 +2683,6 @@ async def test_kill_infrastructure_uses_correct_grace_seconds(
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
         GRACE_SECONDS = 42
         async with KubernetesWorker(work_pool_name="test") as k8s_worker:
             await k8s_worker.kill_infrastructure(
@@ -2676,7 +2693,7 @@
             )
 
         assert len(mock_batch_client.mock_calls) == 1
-        mock_batch_client.delete_namespaced_job.assert_called_once_with(
+        mock_batch_client.return_value.delete_namespaced_job.assert_called_once_with(
             name="mock-k8s-v1-job",
             namespace="default",
             grace_period_seconds=GRACE_SECONDS,
@@ -2689,7 +2706,6 @@ async def test_kill_infrastructure_raises_infra_not_available_on_mismatched_clus
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
         BAD_NAMESPACE = "dog"
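These teardown tests all drive variations of one call: deleting the Job with a caller-supplied grace period and translating a 404 into a clearer error. A hedged sketch under those assumptions; `kill_job` and the `RuntimeError` mapping are illustrative, and the worker raises its own infrastructure-not-found error type rather than `RuntimeError`:

```python
# Illustrative sketch of the delete call the tests assert on.
from kubernetes_asyncio.client import BatchV1Api
from kubernetes_asyncio.client.exceptions import ApiException


async def kill_job(
    batch_client: BatchV1Api, name: str, namespace: str, grace_seconds: int = 0
):
    try:
        # grace_period_seconds controls how long pods get to shut down cleanly.
        await batch_client.delete_namespaced_job(
            name=name,
            namespace=namespace,
            grace_period_seconds=grace_seconds,
        )
    except ApiException as exc:
        if exc.status == 404:
            raise RuntimeError(
                f"Job {name!r} was not found in namespace {namespace!r}"
            ) from exc
        raise  # other API errors propagate unchanged
```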
@@ -2715,7 +2731,6 @@ async def test_kill_infrastructure_raises_infra_not_available_on_mismatched_clus
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
         BAD_CLUSTER = "4321"
@@ -2738,9 +2753,8 @@ async def test_kill_infrastructure_raises_infrastructure_not_found_on_404(
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
-        mock_batch_client.delete_namespaced_job.side_effect = [
+        mock_batch_client.return_value.delete_namespaced_job.side_effect = [
             ApiException(status=404)
         ]
@@ -2766,9 +2780,8 @@ async def test_kill_infrastructure_passes_other_k8s_api_errors_through(
         mock_batch_client,
         mock_core_client,
         mock_watch,
-        monkeypatch,
     ):
-        mock_batch_client.delete_namespaced_job.side_effect = [
+        mock_batch_client.return_value.delete_namespaced_job.side_effect = [
             ApiException(status=400)
         ]
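The 404/400 cases rely on `side_effect` given as a list: entries are consumed one per call, and an entry that is an exception is raised instead of returned. A runnable illustration with arbitrary names:

```python
# A single-item list of ApiException makes the first delete call fail once.
import asyncio
from unittest import mock

from kubernetes_asyncio.client.exceptions import ApiException


async def main():
    batch_client = mock.AsyncMock()
    batch_client.delete_namespaced_job.side_effect = [ApiException(status=404)]

    try:
        await batch_client.delete_namespaced_job(
            name="mock-job", namespace="default"
        )
    except ApiException as exc:
        print(exc.status)  # -> 404


asyncio.run(main())
```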
RAM", - ), - CoreV1Event( - metadata=V1ObjectMeta(), - involved_object=V1ObjectReference( - api_version="v1", - kind="Pod", - namespace="default", - name="somebody-else", # not my pod + CoreV1Event( + metadata=V1ObjectMeta(), + involved_object=V1ObjectReference( + api_version="v1", + kind="Pod", + namespace="default", + name="somebody-else", # not my pod + ), + reason="NotMeDude", + count=1, + last_timestamp=pendulum.parse("2022-01-02T03:04:05Z"), + message="You ain't getting no more RAM", ), - reason="NotMeDude", - count=1, - last_timestamp=pendulum.parse("2022-01-02T03:04:05Z"), - message="You ain't getting no more RAM", - ), - CoreV1Event( - metadata=V1ObjectMeta(), - involved_object=V1ObjectReference( - api_version="batch/v1", - kind="Job", - namespace="default", - name="mock-job", + CoreV1Event( + metadata=V1ObjectMeta(), + involved_object=V1ObjectReference( + api_version="batch/v1", + kind="Job", + namespace="default", + name="mock-job", + ), + reason="StuffBlewUp", + count=2, + last_timestamp=pendulum.parse("2022-01-02T03:04:05Z"), + message="I mean really really bad", ), - reason="StuffBlewUp", - count=2, - last_timestamp=pendulum.parse("2022-01-02T03:04:05Z"), - message="I mean really really bad", - ), - ], + ], + ) ) async def test_explains_what_might_have_gone_wrong_in_scheduling_the_pod( @@ -2882,6 +2897,13 @@ async def test_explains_what_might_have_gone_wrong_in_scheduling_the_pod( ): """Regression test for #87, where workers were giving only very vague information about the reason a pod was never scheduled.""" + + async def mock_stream(*args, **kwargs): + if kwargs["func"] == mock_batch_client.return_value.list_namespaced_job: + job = MagicMock(spec=kubernetes_asyncio.client.V1Job) + yield {"object": job, "type": "ADDED"} + + mock_watch.return_value.stream = mock.Mock(side_effect=mock_stream) async with KubernetesWorker(work_pool_name="test") as k8s_worker: await k8s_worker.run( flow_run=flow_run, @@ -2889,7 +2911,7 @@ async def test_explains_what_might_have_gone_wrong_in_scheduling_the_pod( task_status=MagicMock(spec=anyio.abc.TaskStatus), ) - mock_core_client.list_namespaced_event.assert_called_once_with( + mock_core_client.return_value.list_namespaced_event.assert_called_once_with( default_configuration.namespace ) @@ -2919,7 +2941,7 @@ async def test_explains_what_might_have_gone_wrong_in_starting_the_pod( logger = k8s_worker.get_flow_run_logger(flow_run) mock_client = mock.Mock() - k8s_worker._log_recent_events( + await k8s_worker._log_recent_events( logger, "mock-job", "my-pod", default_configuration, mock_client )