Skip to content

Commit

Permalink
check api compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan committed Sep 5, 2024
1 parent ce50e5e commit aa00477
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/prefect/client/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pendulum
import pydantic
from asgi_lifespan import LifespanManager
from packaging import version
from starlette import status
from typing_extensions import ParamSpec

Expand Down Expand Up @@ -3329,6 +3330,32 @@ async def read_resource_related_automations(
async def delete_resource_owned_automations(self, resource_id: str):
await self._client.delete(f"/automations/owned-by/{resource_id}")

async def api_version(self) -> str:
res = await self._client.get("/admin/version")
return res.json()

def client_version(self) -> str:
return prefect.__version__

async def api_compatible(self):
# Cloud is always compatible as a server
if self.server_type == ServerType.CLOUD:
return

try:
api_version = await self.api_version()
except Exception as e:
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
client_version = version.parse(self.client_version())

if api_version.major != client_version.major:
raise RuntimeError(
f"Client version {client_version} is incompatible with api version {api_version}. "
f"Both client and api must be on the same major version."
)

async def __aenter__(self):
"""
Start the client.
Expand Down Expand Up @@ -3622,6 +3649,32 @@ def hello(self) -> httpx.Response:
"""
return self._client.get("/hello")

def api_version(self) -> str:
res = self._client.get("/admin/version")
return res.json()

def client_version(self) -> str:
return prefect.__version__

def api_compatible(self):
# Cloud is always compatible as a server
if self.server_type == ServerType.CLOUD:
return

try:
api_version = self.api_version()
except Exception as e:
raise RuntimeError(f"Failed to reach API at {self.api_url}") from e

api_version = version.parse(api_version)
client_version = version.parse(self.client_version())

if api_version.major != client_version.major:
raise RuntimeError(
f"Client version {client_version} is incompatible with api version {api_version}. "
f"Both client and api must be on the same major version."
)

def create_flow(self, flow: "FlowObject") -> UUID:
"""
Create a flow in the Prefect API.
Expand Down
2 changes: 2 additions & 0 deletions src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def __enter__(self):
self._context_stack += 1
if self._context_stack == 1:
self.client.__enter__()
self.client.api_compatible()
return super().__enter__()
else:
return self
Expand Down Expand Up @@ -267,6 +268,7 @@ async def __aenter__(self):
self._context_stack += 1
if self._context_stack == 1:
await self.client.__aenter__()
await self.client.api_compatible()
return super().__enter__()
else:
return self
Expand Down
103 changes: 103 additions & 0 deletions tests/client/test_prefect_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2563,6 +2563,55 @@ async def test_disabled_setting_disabled(self, hosted_api_server):
assert not prefect_client._client.enable_csrf_support


class TestPrefectClientAPICompatibility:
async def test_api_compatible(self, prefect_client):
await prefect_client.api_compatible()

async def test_api_compatible_when_api_unreachable(
self, prefect_client, monkeypatch
):
async def something_went_wrong(*args, **kwargs):
raise httpx.ConnectError

monkeypatch.setattr(prefect_client, "api_version", something_went_wrong)
with pytest.raises(RuntimeError) as e:
await prefect_client.api_compatible()

assert "Failed to reach API" in str(e.value)

async def test_api_compatible_against_cloud(self, prefect_client, monkeypatch):
# manually set the server type to cloud
monkeypatch.setattr(prefect_client, "server_type", ServerType.CLOUD)

api_version_mock = AsyncMock()
monkeypatch.setattr(prefect_client, "api_version", api_version_mock)

await prefect_client.api_compatible()

api_version_mock.assert_not_called()

@pytest.mark.parametrize(
"client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")]
)
async def test_api_compatible_with_incompatible_versions(
self, prefect_client, monkeypatch, client_version, api_version
):
monkeypatch.setattr(
prefect_client, "api_version", AsyncMock(return_value=api_version)
)
monkeypatch.setattr(
prefect_client, "client_version", Mock(return_value=client_version)
)

with pytest.raises(RuntimeError) as e:
await prefect_client.api_compatible()

assert (
f"Client version {client_version} is incompatible with api version {api_version}"
in str(e.value)
)


class TestSyncClient:
def test_get_sync_client(self):
client = get_client(sync_client=True)
Expand All @@ -2574,3 +2623,57 @@ def test_fixture_is_sync(self, sync_prefect_client):
def test_hello(self, sync_prefect_client):
response = sync_prefect_client.hello()
assert response.json() == "👋"

def test_api_version(self, sync_prefect_client):
version = sync_prefect_client.api_version()
assert prefect.__version__
assert version == prefect.__version__


class TestSyncClientAPICompatible:
def test_api_compatible(self, sync_prefect_client):
sync_prefect_client.api_compatible()

def test_api_compatible_when_api_unreachable(
self, sync_prefect_client, monkeypatch
):
def something_went_wrong(*args, **kwargs):
raise httpx.ConnectError

monkeypatch.setattr(sync_prefect_client, "api_version", something_went_wrong)
with pytest.raises(RuntimeError) as e:
sync_prefect_client.api_compatible()

assert "Failed to reach API" in str(e.value)

def test_api_compatible_against_cloud(self, sync_prefect_client, monkeypatch):
# manually set the server type to cloud
monkeypatch.setattr(sync_prefect_client, "server_type", ServerType.CLOUD)

api_version_mock = Mock()
monkeypatch.setattr(sync_prefect_client, "api_version", api_version_mock)

sync_prefect_client.api_compatible()

api_version_mock.assert_not_called()

@pytest.mark.parametrize(
"client_version, api_version", [("3.0.0", "2.0.0"), ("2.0.0", "3.0.0")]
)
def test_api_compatible_with_incompatible_versions(
self, sync_prefect_client, monkeypatch, client_version, api_version
):
monkeypatch.setattr(
sync_prefect_client, "api_version", Mock(return_value=api_version)
)
monkeypatch.setattr(
sync_prefect_client, "client_version", Mock(return_value=client_version)
)

with pytest.raises(RuntimeError) as e:
sync_prefect_client.api_compatible()

assert (
f"Client version {client_version} is incompatible with api version {api_version}"
in str(e.value)
)

0 comments on commit aa00477

Please sign in to comment.