Skip to content

Commit

Permalink
check api compatibility when entering a client context (#15252)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan authored Sep 6, 2024
1 parent c0b2e47 commit c9fb300
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pendulum
import pydantic
from asgi_lifespan import LifespanManager
from packaging import version
from starlette import status
from typing_extensions import ParamSpec

Expand Down Expand Up @@ -3329,6 +3330,32 @@ async def read_resource_related_automations(
async def delete_resource_owned_automations(self, resource_id: str):
await self._client.delete(f"/automations/owned-by/{resource_id}")

async def api_version(self) -> str:
res = await self._client.get("/admin/version")
return res.json()

def client_version(self) -> str:
return prefect.__version__

async def raise_for_api_version_mismatch(self):
# Cloud is always compatible as a server
if self.server_type == ServerType.CLOUD:
return

try:
api_version = await self.api_version()
except Exception as e:
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
client_version = version.parse(self.client_version())

if api_version.major != client_version.major:
raise RuntimeError(
f"Found incompatible versions: client: {client_version}, server: {api_version}. "
f"Major versions must match."
)

async def __aenter__(self):
"""
Start the client.
Expand Down Expand Up @@ -3622,6 +3649,32 @@ def hello(self) -> httpx.Response:
"""
return self._client.get("/hello")

def api_version(self) -> str:
res = self._client.get("/admin/version")
return res.json()

def client_version(self) -> str:
return prefect.__version__

def raise_for_api_version_mismatch(self):
# Cloud is always compatible as a server
if self.server_type == ServerType.CLOUD:
return

try:
api_version = self.api_version()
except Exception as e:
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
client_version = version.parse(self.client_version())

if api_version.major != client_version.major:
raise RuntimeError(
f"Found incompatible versions: client: {client_version}, server: {api_version}. "
f"Major versions must match."
)

def create_flow(self, flow: "FlowObject") -> UUID:
"""
Create a flow in the Prefect API.
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.client.__enter__()
self.client.raise_for_api_version_mismatch()
return super().__enter__()
else:
return self
Expand Down Expand Up @@ -267,6 +268,7 @@ async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
await self.client.__aenter__()
await self.client.raise_for_api_version_mismatch()
return super().__enter__()
else:
return self
Expand Down
107 changes: 107 additions & 0 deletions tests/client/test_prefect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2563,6 +2563,57 @@ async def test_disabled_setting_disabled(self, hosted_api_server):
assert not prefect_client._client.enable_csrf_support


class TestPrefectClientRaiseForAPIVersionMismatch:
async def test_raise_for_api_version_mismatch(self, prefect_client):
await prefect_client.raise_for_api_version_mismatch()

async def test_raise_for_api_version_mismatch_when_api_unreachable(
self, prefect_client, monkeypatch
):
async def something_went_wrong(*args, **kwargs):
raise httpx.ConnectError

monkeypatch.setattr(prefect_client, "api_version", something_went_wrong)
with pytest.raises(RuntimeError) as e:
await prefect_client.raise_for_api_version_mismatch()

assert "Failed to reach API" in str(e.value)

async def test_raise_for_api_version_mismatch_against_cloud(
self, prefect_client, monkeypatch
):
# manually set the server type to cloud
monkeypatch.setattr(prefect_client, "server_type", ServerType.CLOUD)

api_version_mock = AsyncMock()
monkeypatch.setattr(prefect_client, "api_version", api_version_mock)

await prefect_client.raise_for_api_version_mismatch()

api_version_mock.assert_not_called()

@pytest.mark.parametrize(
"client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")]
)
async def test_raise_for_api_version_mismatch_with_incompatible_versions(
self, prefect_client, monkeypatch, client_version, api_version
):
monkeypatch.setattr(
prefect_client, "api_version", AsyncMock(return_value=api_version)
)
monkeypatch.setattr(
prefect_client, "client_version", Mock(return_value=client_version)
)

with pytest.raises(RuntimeError) as e:
await prefect_client.raise_for_api_version_mismatch()

assert (
f"Found incompatible versions: client: {client_version}, server: {api_version}. "
in str(e.value)
)


class TestSyncClient:
def test_get_sync_client(self):
client = get_client(sync_client=True)
Expand All @@ -2574,3 +2625,59 @@ def test_fixture_is_sync(self, sync_prefect_client):
def test_hello(self, sync_prefect_client):
response = sync_prefect_client.hello()
assert response.json() == "👋"

def test_api_version(self, sync_prefect_client):
version = sync_prefect_client.api_version()
assert prefect.__version__
assert version == prefect.__version__


class TestSyncClientRaiseForAPIVersionMismatch:
def test_raise_for_api_version_mismatch(self, sync_prefect_client):
sync_prefect_client.raise_for_api_version_mismatch()

def test_raise_for_api_version_mismatch_when_api_unreachable(
self, sync_prefect_client, monkeypatch
):
def something_went_wrong(*args, **kwargs):
raise httpx.ConnectError

monkeypatch.setattr(sync_prefect_client, "api_version", something_went_wrong)
with pytest.raises(RuntimeError) as e:
sync_prefect_client.raise_for_api_version_mismatch()

assert "Failed to reach API" in str(e.value)

def test_raise_for_api_version_mismatch_against_cloud(
self, sync_prefect_client, monkeypatch
):
# manually set the server type to cloud
monkeypatch.setattr(sync_prefect_client, "server_type", ServerType.CLOUD)

api_version_mock = Mock()
monkeypatch.setattr(sync_prefect_client, "api_version", api_version_mock)

sync_prefect_client.raise_for_api_version_mismatch()

api_version_mock.assert_not_called()

@pytest.mark.parametrize(
"client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")]
)
def test_raise_for_api_version_mismatch_with_incompatible_versions(
self, sync_prefect_client, monkeypatch, client_version, api_version
):
monkeypatch.setattr(
sync_prefect_client, "api_version", Mock(return_value=api_version)
)
monkeypatch.setattr(
sync_prefect_client, "client_version", Mock(return_value=client_version)
)

with pytest.raises(RuntimeError) as e:
sync_prefect_client.raise_for_api_version_mismatch()

assert (
f"Found incompatible versions: client: {client_version}, server: {api_version}. "
in str(e.value)
)

0 comments on commit c9fb300

Please sign in to comment.