diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py index 879b8cb3e030..ba2c6b1fdd5e 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/utilities.py @@ -4,7 +4,9 @@ import sys from typing import Optional, TypeVar -from kubernetes_asyncio.client import ApiClient +from aiohttp import ClientResponse +from aiohttp.client_reqrep import ClientRequest +from aiohttp.connector import Connection from slugify import slugify # Note: `dict(str, str)` is the Kubernetes API convention for @@ -14,34 +16,34 @@ V1KubernetesModel = TypeVar("V1KubernetesModel") -def enable_socket_keep_alive(client: ApiClient) -> None: +class KeepAliveClientRequest(ClientRequest): """ - Setting the keep-alive flags on the kubernetes client object. - Unfortunately neither the kubernetes library nor the urllib3 library which - kubernetes is using internally offer the functionality to enable keep-alive - messages. Thus the flags are added to be used on the underlying sockets. + aiohttp only directly implements socket keepalive for incoming connections + in its RequestHandler. For client connections, we need to set the keepalive + ourselves. + Refer to https://github.com/aio-libs/aiohttp/issues/3904#issuecomment-759205696 """ - socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] + async def send(self, conn: Connection) -> ClientResponse: + sock = conn.protocol.transport.get_extra_info("socket") + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if hasattr(socket, "TCP_KEEPINTVL"): - socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 30)) + if hasattr(socket, "TCP_KEEPIDLE"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30) - if hasattr(socket, "TCP_KEEPCNT"): - socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 6)) + if hasattr(socket, "TCP_KEEPINTVL"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 6) - if hasattr(socket, "TCP_KEEPIDLE"): - socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 6)) + if hasattr(socket, "TCP_KEEPCNT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 6) - if sys.platform == "darwin": - # TCP_KEEP_ALIVE not available on socket module in macOS, but defined in - # https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/tcp.h#L215 - TCP_KEEP_ALIVE = 0x10 - socket_options.append((socket.IPPROTO_TCP, TCP_KEEP_ALIVE, 30)) + if sys.platform == "darwin": + # TCP_KEEP_ALIVE not available on socket module in macOS, but defined in + # https://github.com/apple/darwin-xnu/blob/2ff845c2e033bd0ff64b5b6aa6063a1f8f65aa32/bsd/netinet/tcp.h#L215 + TCP_KEEP_ALIVE = 0x10 + sock.setsockopt(socket.IPPROTO_TCP, TCP_KEEP_ALIVE, 30) - client.rest_client.pool_manager.connection_pool_kw[ - "socket_options" - ] = socket_options + return await super().send(conn) def _slugify_name(name: str, max_length: int = 45) -> Optional[str]: diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py index 67e5bbc3e965..a9c9a532a5e2 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py @@ -167,6 +167,7 @@ from prefect_kubernetes.credentials import KubernetesClusterConfig from prefect_kubernetes.events import KubernetesEventsReplicator from prefect_kubernetes.utilities import ( + KeepAliveClientRequest, _slugify_label_key, _slugify_label_value, _slugify_name, @@ -739,6 +740,12 @@ async def _get_configured_kubernetes_client( except config.ConfigException: # If in-cluster config fails, load the local kubeconfig client = await config.new_client_from_config() + + if os.environ.get( + "PREFECT_KUBERNETES_WORKER_ADD_TCP_KEEPALIVE", "TRUE" + ).strip().lower() in ("true", "1"): + client.rest_client.pool_manager._request_class = KeepAliveClientRequest + try: yield client finally: diff --git a/src/integrations/prefect-kubernetes/tests/test_utilities.py b/src/integrations/prefect-kubernetes/tests/test_utilities.py index bccecc14c59f..62980fc0dc20 100644 --- a/src/integrations/prefect-kubernetes/tests/test_utilities.py +++ b/src/integrations/prefect-kubernetes/tests/test_utilities.py @@ -2,9 +2,6 @@ import pytest from kubernetes_asyncio.config import ConfigException -from prefect_kubernetes.utilities import ( - enable_socket_keep_alive, -) FAKE_CLUSTER = "fake-cluster" @@ -26,14 +23,3 @@ def mock_cluster_config(monkeypatch): @pytest.fixture def mock_api_client(mock_cluster_config): return AsyncMock() - - -def test_keep_alive_updates_socket_options(mock_api_client): - enable_socket_keep_alive(mock_api_client) - - assert ( - mock_api_client.rest_client.pool_manager.connection_pool_kw[ - "socket_options" - ]._mock_set_call - is not None - ) diff --git a/src/integrations/prefect-kubernetes/tests/test_worker.py b/src/integrations/prefect-kubernetes/tests/test_worker.py index 47b0cc7a4acc..2f497bb206cb 100644 --- a/src/integrations/prefect-kubernetes/tests/test_worker.py +++ b/src/integrations/prefect-kubernetes/tests/test_worker.py @@ -25,7 +25,11 @@ ) from kubernetes_asyncio.config import ConfigException from prefect_kubernetes import KubernetesWorker -from prefect_kubernetes.utilities import _slugify_label_value, _slugify_name +from prefect_kubernetes.utilities import ( + KeepAliveClientRequest, + _slugify_label_value, + _slugify_name, +) from prefect_kubernetes.worker import KubernetesWorkerJobConfiguration from pydantic import VERSION as PYDANTIC_VERSION @@ -1512,7 +1516,7 @@ async def test_can_store_api_key_in_secret( ) # Make sure secret gets deleted - assert mock_core_client.return_value.delete_namespaced_secret( + assert await mock_core_client.return_value.delete_namespaced_secret( name=f"prefect-{_slugify_name(k8s_worker.name)}-api-key", namespace=configuration.namespace, ) @@ -2115,6 +2119,24 @@ async def test_uses_specified_image_pull_policy( ) assert call_image_pull_policy == "IfNotPresent" + @pytest.mark.usefixtures("mock_core_client_lean", "mock_cluster_config") + async def test_keepalive_enabled( + self, + ): + configuration = await KubernetesWorkerJobConfiguration.from_template_and_values( + KubernetesWorker.get_default_base_job_template(), + {"image": "foo"}, + ) + + async with KubernetesWorker(work_pool_name="test") as k8s_worker: + async with k8s_worker._get_configured_kubernetes_client( + configuration + ) as client: + assert ( + client.rest_client.pool_manager._request_class + is KeepAliveClientRequest + ) + async def test_defaults_to_incluster_config( self, flow_run,