Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn if websocket connection can't be made #15261

Merged
merged 25 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ async def _ensure_connected(self):
AssertionError,
websockets.exceptions.ConnectionClosedError,
) as e:
if isinstance(e, AssertionError) or e.code == WS_1008_POLICY_VIOLATION:
if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
if isinstance(e, AssertionError):
reason = e.args[0]
elif isinstance(e, websockets.exceptions.ConnectionClosedError):
reason = e.reason
reason = e.rcvd.reason

if isinstance(e, AssertionError) or e.code == WS_1008_POLICY_VIOLATION:
if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
raise Exception(
"Unable to authenticate to the subscription. Please "
"ensure the provided `PREFECT_API_KEY` you are using is "
Expand Down
52 changes: 35 additions & 17 deletions src/prefect/events/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PREFECT_API_KEY,
PREFECT_API_URL,
PREFECT_CLOUD_API_URL,
PREFECT_DEBUG_MODE,
PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
)

Expand Down Expand Up @@ -67,6 +68,18 @@
logger = get_logger(__name__)


def http_to_ws(url: str) -> str:
    """Convert an HTTP(S) URL into its websocket equivalent.

    Only the leading scheme is rewritten (``https://`` becomes ``wss://`` and
    ``http://`` becomes ``ws://``); any URL embedded later in the path or
    query string is left untouched. URLs that start with neither scheme
    (for example ones already using ``ws://``) pass through unchanged.

    Args:
        url: The URL to convert.

    Returns:
        The converted URL with all trailing ``/`` characters removed.
    """
    scheme_map = (("https://", "wss://"), ("http://", "ws://"))
    for http_scheme, ws_scheme in scheme_map:
        if url.startswith(http_scheme):
            # Swap just the prefix so that occurrences of "http://" elsewhere
            # in the URL (e.g. inside a query parameter) are preserved.
            url = ws_scheme + url[len(http_scheme):]
            break
    return url.rstrip("/")


def events_in_socket_from_api_url(url: str):
    """Return `url` converted by `http_to_ws` with `/events/in` appended."""
    socket_base = http_to_ws(url)
    return f"{socket_base}/events/in"


def events_out_socket_from_api_url(url: str):
    """Return `url` converted by `http_to_ws` with `/events/out` appended."""
    socket_base = http_to_ws(url)
    return f"{socket_base}/events/out"


def get_events_client(
reconnection_attempts: int = 10,
checkpoint_every: int = 700,
Expand Down Expand Up @@ -251,12 +264,7 @@ def __init__(
"api_url must be provided or set in the Prefect configuration"
)

self._events_socket_url = (
api_url.replace("https://", "wss://")
.replace("http://", "ws://")
.rstrip("/")
+ "/events/in"
)
self._events_socket_url = events_in_socket_from_api_url(api_url)
self._connect = connect(self._events_socket_url)
self._websocket = None
self._reconnection_attempts = reconnection_attempts
Expand Down Expand Up @@ -285,11 +293,26 @@ async def _reconnect(self) -> None:
self._websocket = None
await self._connect.__aexit__(None, None, None)

self._websocket = await self._connect.__aenter__()

# make sure we have actually connected
pong = await self._websocket.ping()
await pong
try:
self._websocket = await self._connect.__aenter__()
# make sure we have actually connected
pong = await self._websocket.ping()
await pong
except Exception as e:
# The client is frequently run in a background thread
# so we log an additional warning to ensure
# surfacing the error to the user.
logger.warning(
"Unable to connect to %r. "
"Please check your network settings to ensure websocket connections "
"to the API are allowed. Otherwise event data (including task run data) may be lost. "
"Reason: %s. "
"Set PREFECT_DEBUG_MODE=1 to see the full error.",
self._events_socket_url,
str(e),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a blocker - it depends entirely on how you want this displayed - but str(e) hides the exception type whereas repr(e) preserves it; totally up to you which you think is better for this situation, e.g.,

str(ValueError("foo")) # 'foo'
repr(ValueError("foo")) # "ValueError('foo')"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My fault, I had enabled auto-merge here. I think repr is probably more correct; let me open another PR to update it.

exc_info=PREFECT_DEBUG_MODE,
)
raise

events_to_resend = self._unconfirmed_events
# Clear the unconfirmed events here, because they are going back through emit
Expand Down Expand Up @@ -412,7 +435,6 @@ def __init__(
reconnection_attempts=reconnection_attempts,
checkpoint_every=checkpoint_every,
)

self._connect = connect(
self._events_socket_url,
extra_headers={"Authorization": f"bearer {api_key}"},
Expand Down Expand Up @@ -468,11 +490,7 @@ def __init__(
self._filter = filter or EventFilter() # type: ignore[call-arg]
self._seen_events = TTLCache(maxsize=SEEN_EVENTS_SIZE, ttl=SEEN_EVENTS_TTL)

socket_url = (
api_url.replace("https://", "wss://")
.replace("http://", "ws://")
.rstrip("/")
) + "/events/out"
socket_url = events_out_socket_from_api_url(api_url)

logger.debug("Connecting to %s", socket_url)

Expand Down
27 changes: 27 additions & 0 deletions tests/events/client/test_events_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Type
from unittest import mock

Expand Down Expand Up @@ -342,3 +343,29 @@ async def test_recovers_from_long_lasting_error_reconnecting(
# event 2 never made it because we cause that error during reconnection
# event 3 never made it because we told the server to refuse future connects
]


async def test_events_client_warn_if_connect_fails(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
):
    """Entering the client with a failing connector raises and logs a warning."""

    class FailingConnect:
        """Async context manager stand-in whose entry always raises."""

        async def __aenter__(self):
            raise Exception("Connection failed")

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

    # Swap the module-level `connect` for a factory that yields the failing stub.
    monkeypatch.setattr(
        "prefect.events.clients.connect",
        lambda *args, **kwargs: FailingConnect(),
    )

    with caplog.at_level(logging.WARNING), pytest.raises(
        Exception, match="Connection failed"
    ):
        async with PrefectEventsClient("ws://localhost"):
            pass

    # The original error must propagate AND a warning must have been recorded.
    expected_fragment = "Unable to connect to 'ws"
    logged_messages = (record.message for record in caplog.records)
    assert any(expected_fragment in message for message in logged_messages)