From aa00477b7eb5f4e8a8928864ddb7a91bea86e7b6 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Thu, 5 Sep 2024 18:16:38 -0400 Subject: [PATCH] check api compatibility --- src/prefect/client/orchestration.py | 53 ++++++++++++++ src/prefect/context.py | 2 + tests/client/test_prefect_client.py | 103 ++++++++++++++++++++++++++++ 3 files changed, 158 insertions(+) diff --git a/src/prefect/client/orchestration.py b/src/prefect/client/orchestration.py index e93881e752c4..31cf18b110ed 100644 --- a/src/prefect/client/orchestration.py +++ b/src/prefect/client/orchestration.py @@ -24,6 +24,7 @@ import pendulum import pydantic from asgi_lifespan import LifespanManager +from packaging import version from starlette import status from typing_extensions import ParamSpec @@ -3329,6 +3330,32 @@ async def read_resource_related_automations( async def delete_resource_owned_automations(self, resource_id: str): await self._client.delete(f"/automations/owned-by/{resource_id}") + async def api_version(self) -> str: + res = await self._client.get("/admin/version") + return res.json() + + def client_version(self) -> str: + return prefect.__version__ + + async def api_compatible(self): + # Cloud is always compatible as a server + if self.server_type == ServerType.CLOUD: + return + + try: + api_version = await self.api_version() + except Exception as e: + raise RuntimeError(f"Failed to reach API at {self.api_url}") from e + + api_version = version.parse(api_version) + client_version = version.parse(self.client_version()) + + if api_version.major != client_version.major: + raise RuntimeError( + f"Client version {client_version} is incompatible with api version {api_version}. " + f"Both client and api must be on the same major version." + ) + async def __aenter__(self): """ Start the client. @@ -3622,6 +3649,32 @@ def hello(self) -> httpx.Response: """ return self._client.get("/hello") + def api_version(self) -> str: + res = self._client.get("/admin/version") + return res.json() + + def client_version(self) -> str: + return prefect.__version__ + + def api_compatible(self): + # Cloud is always compatible as a server + if self.server_type == ServerType.CLOUD: + return + + try: + api_version = self.api_version() + except Exception as e: + raise RuntimeError(f"Failed to reach API at {self.api_url}") from e + + api_version = version.parse(api_version) + client_version = version.parse(self.client_version()) + + if api_version.major != client_version.major: + raise RuntimeError( + f"Client version {client_version} is incompatible with api version {api_version}. " + f"Both client and api must be on the same major version." + ) + def create_flow(self, flow: "FlowObject") -> UUID: """ Create a flow in the Prefect API. diff --git a/src/prefect/context.py b/src/prefect/context.py index 022d9776a578..cd52569713d7 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -210,6 +210,7 @@ def __enter__(self): self._context_stack += 1 if self._context_stack == 1: self.client.__enter__() + self.client.api_compatible() return super().__enter__() else: return self @@ -267,6 +268,7 @@ async def __aenter__(self): self._context_stack += 1 if self._context_stack == 1: await self.client.__aenter__() + await self.client.api_compatible() return super().__enter__() else: return self diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index a89835cd9fb8..b9ae3ab57a68 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -2563,6 +2563,55 @@ async def test_disabled_setting_disabled(self, hosted_api_server): assert not prefect_client._client.enable_csrf_support +class TestPrefectClientAPICompatibility: + async def test_api_compatible(self, prefect_client): + await prefect_client.api_compatible() + + async def test_api_compatible_when_api_unreachable( + self, prefect_client, monkeypatch + ): + async def something_went_wrong(*args, **kwargs): + raise httpx.ConnectError + + monkeypatch.setattr(prefect_client, "api_version", something_went_wrong) + with pytest.raises(RuntimeError) as e: + await prefect_client.api_compatible() + + assert "Failed to reach API" in str(e.value) + + async def test_api_compatible_against_cloud(self, prefect_client, monkeypatch): + # manually set the server type to cloud + monkeypatch.setattr(prefect_client, "server_type", ServerType.CLOUD) + + api_version_mock = AsyncMock() + monkeypatch.setattr(prefect_client, "api_version", api_version_mock) + + await prefect_client.api_compatible() + + api_version_mock.assert_not_called() + + @pytest.mark.parametrize( + "client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")] + ) + async def test_api_compatible_with_incompatible_versions( + self, prefect_client, monkeypatch, client_version, api_version + ): + monkeypatch.setattr( + prefect_client, "api_version", AsyncMock(return_value=api_version) + ) + monkeypatch.setattr( + prefect_client, "client_version", Mock(return_value=client_version) + ) + + with pytest.raises(RuntimeError) as e: + await prefect_client.api_compatible() + + assert ( + f"Client version {client_version} is incompatible with api version {api_version}" + in str(e.value) + ) + + class TestSyncClient: def test_get_sync_client(self): client = get_client(sync_client=True) @@ -2574,3 +2623,57 @@ def test_fixture_is_sync(self, sync_prefect_client): def test_hello(self, sync_prefect_client): response = sync_prefect_client.hello() assert response.json() == "👋" + + def test_api_version(self, sync_prefect_client): + version = sync_prefect_client.api_version() + assert prefect.__version__ + assert version == prefect.__version__ + + +class TestSyncClientAPICompatible: + def test_api_compatible(self, sync_prefect_client): + sync_prefect_client.api_compatible() + + def test_api_compatible_when_api_unreachable( + self, sync_prefect_client, monkeypatch + ): + def something_went_wrong(*args, **kwargs): + raise httpx.ConnectError + + monkeypatch.setattr(sync_prefect_client, "api_version", something_went_wrong) + with pytest.raises(RuntimeError) as e: + sync_prefect_client.api_compatible() + + assert "Failed to reach API" in str(e.value) + + def test_api_compatible_against_cloud(self, sync_prefect_client, monkeypatch): + # manually set the server type to cloud + monkeypatch.setattr(sync_prefect_client, "server_type", ServerType.CLOUD) + + api_version_mock = Mock() + monkeypatch.setattr(sync_prefect_client, "api_version", api_version_mock) + + sync_prefect_client.api_compatible() + + api_version_mock.assert_not_called() + + @pytest.mark.parametrize( + "client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")] + ) + def test_api_compatible_with_incompatible_versions( + self, sync_prefect_client, monkeypatch, client_version, api_version + ): + monkeypatch.setattr( + sync_prefect_client, "api_version", Mock(return_value=api_version) + ) + monkeypatch.setattr( + sync_prefect_client, "client_version", Mock(return_value=client_version) + ) + + with pytest.raises(RuntimeError) as e: + sync_prefect_client.api_compatible() + + assert ( + f"Client version {client_version} is incompatible with api version {api_version}" + in str(e.value) + )