From 18d9d899dd0dfdaeda1482eef9c9aa1ba970dfeb Mon Sep 17 00:00:00 2001 From: Cody Fincher <204685+cofin@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:20:32 -0500 Subject: [PATCH] feat: updated exception handler (#42) * feat: updated exception handler * fix: don't change the exc * fix: simplify response * feat: use response instead of back * fix: updated test --- litestar_vite/inertia/exception_handler.py | 37 ++++++++++----------- litestar_vite/inertia/request.py | 10 ++++++ litestar_vite/inertia/response.py | 29 ++++++++-------- pyproject.toml | 2 +- tests/test_inertia/test_inertia_request.py | 2 +- tests/test_inertia/test_inertia_response.py | 8 ++--- 6 files changed, 50 insertions(+), 38 deletions(-) diff --git a/litestar_vite/inertia/exception_handler.py b/litestar_vite/inertia/exception_handler.py index 1374c37..8d3af9d 100644 --- a/litestar_vite/inertia/exception_handler.py +++ b/litestar_vite/inertia/exception_handler.py @@ -18,6 +18,7 @@ ) from litestar.plugins.flash import flash from litestar.repository.exceptions import ( + ConflictError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] NotFoundError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] RepositoryError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue] ) @@ -47,13 +48,14 @@ class _HTTPConflictException(HTTPException): def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]: """Handler for all exceptions subclassed from HTTPException.""" inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False) - if isinstance(exc, NotFoundError): - http_exc = NotFoundException - elif isinstance(exc, RepositoryError): - http_exc = _HTTPConflictException # type: ignore[assignment] - else: - http_exc = InternalServerException # type: ignore[assignment] + if not inertia_enabled: + if isinstance(exc, NotFoundError): + http_exc = NotFoundException + elif isinstance(exc, (RepositoryError, ConflictError)): + http_exc = _HTTPConflictException # type: ignore[assignment] + else: + http_exc = InternalServerException # type: ignore[assignment] if request.app.debug and http_exc not in (PermissionDeniedException, NotFoundError): return cast("Response[Any]", create_debug_response(request, exc)) return cast("Response[Any]", create_exception_response(request, http_exc(detail=str(exc.__cause__)))) @@ -88,16 +90,13 @@ def create_inertia_exception_response(request: Request[UserT, AuthT, StateT], ex return InertiaBack(request) if isinstance(exc, PermissionDeniedException): return InertiaBack(request) - if status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException): - if ( - inertia_plugin.config.redirect_unauthorized_to is not None - and str(request.url) != inertia_plugin.config.redirect_unauthorized_to - ): - return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to) - if str(request.url) != inertia_plugin.config.redirect_unauthorized_to: - return InertiaResponse[Any]( - media_type=preferred_type, - content=content, - status_code=status_code, - ) - return InertiaBack(request) + if (status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException)) and ( + inertia_plugin.config.redirect_unauthorized_to is not None + and str(request.url) != inertia_plugin.config.redirect_unauthorized_to + ): + return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to) + return InertiaResponse[Any]( + media_type=preferred_type, + content=content, + status_code=status_code, + ) diff --git a/litestar_vite/inertia/request.py b/litestar_vite/inertia/request.py index c6c6a71..6dc8af3 100644 --- a/litestar_vite/inertia/request.py +++ b/litestar_vite/inertia/request.py @@ -114,3 +114,13 @@ def is_inertia(self) -> bool: def inertia_enabled(self) -> bool: """True if the route handler contains an inertia enabled configuration.""" return bool(self.inertia.route_component is not None) + + @property + def is_partial_render(self) -> bool: + """True if the route handler contains an inertia enabled configuration.""" + return self.inertia.is_partial_render + + @property + def partial_keys(self) -> set[str]: + """True if the route handler contains an inertia enabled configuration.""" + return set(self.inertia.partial_keys) diff --git a/litestar_vite/inertia/response.py b/litestar_vite/inertia/response.py index 9929b49..02c78d8 100644 --- a/litestar_vite/inertia/response.py +++ b/litestar_vite/inertia/response.py @@ -33,7 +33,6 @@ from litestar.connection.base import AuthT, StateT, UserT from litestar.types import ResponseCookies, ResponseHeaders, TypeEncodersMap - from litestar_vite.inertia.request import InertiaRequest from litestar_vite.inertia.routes import Routes from .plugin import InertiaPlugin @@ -65,7 +64,10 @@ def error( connection.logger.warning(msg) -def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, Any]: # noqa: UP006 +def get_shared_props( + request: ASGIConnection[Any, Any, Any, Any], + partial_data: set[str] | None = None, +) -> Dict[str, Any]: # noqa: UP006 """Return shared session props for a request @@ -217,9 +219,11 @@ def to_asgi_response( removal_in="3.0.0", alternative="request.app", ) - inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False) - is_inertia = getattr(request, "is_inertia", False) - + inertia_enabled = cast( + "bool", + getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False), + ) + is_inertia = cast("bool", getattr(request, "is_inertia", False)) headers = {**headers, **self.headers} if headers is not None else self.headers cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies) type_encoders = ( @@ -238,15 +242,18 @@ def to_asgi_response( media_type=media_type, status_code=self.status_code or status_code, ) + is_partial_render = cast("bool", getattr(request, "is_partial_render", False)) + partial_keys = cast("set[str]", getattr(request, "partial_keys", {})) vite_plugin = request.app.plugins.get(VitePlugin) template_engine = vite_plugin.template_config.to_engine() headers.update( {"Vary": "Accept", **get_headers(InertiaHeaderType(enabled=True))}, ) - shared_props = get_shared_props(request) + shared_props = get_shared_props(request, partial_data=partial_keys if is_partial_render else None) + shared_props["content"] = self.content page_props = PageProps[T]( component=request.inertia.route_component, # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue] - props={"content": self.content, **shared_props}, # pyright: ignore[reportArgumentType] + props=shared_props, # pyright: ignore[reportArgumentType] version=template_engine.asset_loader.version_id, url=request.url.path, ) @@ -337,7 +344,7 @@ def __init__( """Initialize external redirect, Set status code to 409 (required by Inertia), and pass redirect url. """ - referer = urlparse(request.headers.get("referer", str(request.base_url))) + referer = urlparse(request.headers.get("Referer", str(request.base_url))) redirect_to = urlunparse(urlparse(redirect_to)._replace(scheme=referer.scheme)) super().__init__( path=redirect_to, @@ -358,12 +365,8 @@ def __init__( """Initialize external redirect, Set status code to 409 (required by Inertia), and pass redirect url. """ - referer = request.headers.get("referer", str(request.base_url)) - inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False) - if inertia_enabled: - referer = cast("InertiaRequest[Any, Any, Any]", request).inertia.referer or referer super().__init__( - path=request.headers.get("referer", str(request.base_url)), + path=request.headers.get("Referer", str(request.base_url)), status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER, cookies=request.cookies, **kwargs, diff --git a/pyproject.toml b/pyproject.toml index 9ffb9b1..c4fc89b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ license = { text = "MIT" } name = "litestar-vite" readme = "README.md" requires-python = ">=3.8" -version = "0.2.8" +version = "0.2.9" [project.urls] Changelog = "https://cofin.github.io/litestar-vite/latest/changelog" diff --git a/tests/test_inertia/test_inertia_request.py b/tests/test_inertia/test_inertia_request.py index e8b1f84..979e001 100644 --- a/tests/test_inertia/test_inertia_request.py +++ b/tests/test_inertia/test_inertia_request.py @@ -73,7 +73,7 @@ async def handler(request: InertiaRequest[Any, Any, Any]) -> bool: response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"}) assert ( response.text - == '{"component":null,"url":"/","version":"1.0","props":{"content":true,"flash":{},"errors":{},"csrf_token":""}}' + == '{"component":null,"url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":true}}' ) diff --git a/tests/test_inertia/test_inertia_response.py b/tests/test_inertia/test_inertia_response.py index c5b4a10..bb1df11 100644 --- a/tests/test_inertia/test_inertia_response.py +++ b/tests/test_inertia/test_inertia_response.py @@ -49,7 +49,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]: response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"}) assert ( response.content - == b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}' + == b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":{"thing":"value"}}}' ) @@ -72,7 +72,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]: response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"}) assert ( response.content - == b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}' + == b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{"info":["a flash message"]},"errors":{},"csrf_token":"","content":{"thing":"value"}}}' ) @@ -99,7 +99,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]: response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"}) assert ( response.content - == b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}' + == b'{"component":"Home","url":"/","version":"1.0","props":{"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":"","content":{"thing":"value"}}}' ) @@ -135,5 +135,5 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]: ) assert ( response.content - == b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}' + == b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":{"thing":"value"}}}' )