Skip to content

Commit

Permalink
feat: updated exception handler (#42)
Browse files Browse the repository at this point in the history
* feat: updated exception handler

* fix: don't change the exc

* fix: simplify response

* feat: use response instead of back

* fix: updated test
  • Loading branch information
cofin authored Aug 4, 2024
1 parent 689c195 commit 18d9d89
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 38 deletions.
37 changes: 18 additions & 19 deletions litestar_vite/inertia/exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from litestar.plugins.flash import flash
from litestar.repository.exceptions import (
ConflictError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
NotFoundError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
RepositoryError, # pyright: ignore[reportUnknownVariableType,reportAttributeAccessIssue]
)
Expand Down Expand Up @@ -47,13 +48,14 @@ class _HTTPConflictException(HTTPException):
def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]:
"""Handler for all exceptions subclassed from HTTPException."""
inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
if isinstance(exc, NotFoundError):
http_exc = NotFoundException
elif isinstance(exc, RepositoryError):
http_exc = _HTTPConflictException # type: ignore[assignment]
else:
http_exc = InternalServerException # type: ignore[assignment]

if not inertia_enabled:
if isinstance(exc, NotFoundError):
http_exc = NotFoundException
elif isinstance(exc, (RepositoryError, ConflictError)):
http_exc = _HTTPConflictException # type: ignore[assignment]
else:
http_exc = InternalServerException # type: ignore[assignment]
if request.app.debug and http_exc not in (PermissionDeniedException, NotFoundError):
return cast("Response[Any]", create_debug_response(request, exc))
return cast("Response[Any]", create_exception_response(request, http_exc(detail=str(exc.__cause__))))
Expand Down Expand Up @@ -88,16 +90,13 @@ def create_inertia_exception_response(request: Request[UserT, AuthT, StateT], ex
return InertiaBack(request)
if isinstance(exc, PermissionDeniedException):
return InertiaBack(request)
if status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException):
if (
inertia_plugin.config.redirect_unauthorized_to is not None
and str(request.url) != inertia_plugin.config.redirect_unauthorized_to
):
return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to)
if str(request.url) != inertia_plugin.config.redirect_unauthorized_to:
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
return InertiaBack(request)
if (status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException)) and (
inertia_plugin.config.redirect_unauthorized_to is not None
and str(request.url) != inertia_plugin.config.redirect_unauthorized_to
):
return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to)
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
10 changes: 10 additions & 0 deletions litestar_vite/inertia/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,13 @@ def is_inertia(self) -> bool:
def inertia_enabled(self) -> bool:
"""True if the route handler contains an inertia enabled configuration."""
return bool(self.inertia.route_component is not None)

@property
def is_partial_render(self) -> bool:
"""True if the route handler contains an inertia enabled configuration."""
return self.inertia.is_partial_render

@property
def partial_keys(self) -> set[str]:
"""True if the route handler contains an inertia enabled configuration."""
return set(self.inertia.partial_keys)
29 changes: 16 additions & 13 deletions litestar_vite/inertia/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from litestar.connection.base import AuthT, StateT, UserT
from litestar.types import ResponseCookies, ResponseHeaders, TypeEncodersMap

from litestar_vite.inertia.request import InertiaRequest
from litestar_vite.inertia.routes import Routes

from .plugin import InertiaPlugin
Expand Down Expand Up @@ -65,7 +64,10 @@ def error(
connection.logger.warning(msg)


def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, Any]: # noqa: UP006
def get_shared_props(
request: ASGIConnection[Any, Any, Any, Any],
partial_data: set[str] | None = None,
) -> Dict[str, Any]: # noqa: UP006
"""Return shared session props for a request
Expand Down Expand Up @@ -217,9 +219,11 @@ def to_asgi_response(
removal_in="3.0.0",
alternative="request.app",
)
inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
is_inertia = getattr(request, "is_inertia", False)

inertia_enabled = cast(
"bool",
getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False),
)
is_inertia = cast("bool", getattr(request, "is_inertia", False))
headers = {**headers, **self.headers} if headers is not None else self.headers
cookies = self.cookies if cookies is None else itertools.chain(self.cookies, cookies)
type_encoders = (
Expand All @@ -238,15 +242,18 @@ def to_asgi_response(
media_type=media_type,
status_code=self.status_code or status_code,
)
is_partial_render = cast("bool", getattr(request, "is_partial_render", False))
partial_keys = cast("set[str]", getattr(request, "partial_keys", {}))
vite_plugin = request.app.plugins.get(VitePlugin)
template_engine = vite_plugin.template_config.to_engine()
headers.update(
{"Vary": "Accept", **get_headers(InertiaHeaderType(enabled=True))},
)
shared_props = get_shared_props(request)
shared_props = get_shared_props(request, partial_data=partial_keys if is_partial_render else None)
shared_props["content"] = self.content
page_props = PageProps[T](
component=request.inertia.route_component, # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportUnknownMemberType,reportAttributeAccessIssue]
props={"content": self.content, **shared_props}, # pyright: ignore[reportArgumentType]
props=shared_props, # pyright: ignore[reportArgumentType]
version=template_engine.asset_loader.version_id,
url=request.url.path,
)
Expand Down Expand Up @@ -337,7 +344,7 @@ def __init__(
"""Initialize external redirect, Set status code to 409 (required by Inertia),
and pass redirect url.
"""
referer = urlparse(request.headers.get("referer", str(request.base_url)))
referer = urlparse(request.headers.get("Referer", str(request.base_url)))
redirect_to = urlunparse(urlparse(redirect_to)._replace(scheme=referer.scheme))
super().__init__(
path=redirect_to,
Expand All @@ -358,12 +365,8 @@ def __init__(
"""Initialize external redirect, Set status code to 409 (required by Inertia),
and pass redirect url.
"""
referer = request.headers.get("referer", str(request.base_url))
inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
if inertia_enabled:
referer = cast("InertiaRequest[Any, Any, Any]", request).inertia.referer or referer
super().__init__(
path=request.headers.get("referer", str(request.base_url)),
path=request.headers.get("Referer", str(request.base_url)),
status_code=HTTP_307_TEMPORARY_REDIRECT if request.method == "GET" else HTTP_303_SEE_OTHER,
cookies=request.cookies,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ license = { text = "MIT" }
name = "litestar-vite"
readme = "README.md"
requires-python = ">=3.8"
version = "0.2.8"
version = "0.2.9"

[project.urls]
Changelog = "https://cofin.github.io/litestar-vite/latest/changelog"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_inertia/test_inertia_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def handler(request: InertiaRequest[Any, Any, Any]) -> bool:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.text
== '{"component":null,"url":"/","version":"1.0","props":{"content":true,"flash":{},"errors":{},"csrf_token":""}}'
== '{"component":null,"url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":true}}'
)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_inertia/test_inertia_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":{"thing":"value"}}}'
)


Expand All @@ -72,7 +72,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{"info":["a flash message"]},"errors":{},"csrf_token":"","content":{"thing":"value"}}}'
)


Expand All @@ -99,7 +99,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":"","content":{"thing":"value"}}}'
)


Expand Down Expand Up @@ -135,5 +135,5 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
)
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"flash":{},"errors":{},"csrf_token":"","content":{"thing":"value"}}}'
)

0 comments on commit 18d9d89

Please sign in to comment.