Skip to content

Commit

Permalink
[IMP] fastapi: Factorize error handling and use it in tests with rais…
Browse files Browse the repository at this point in the history
…e_server_exceptions=False
  • Loading branch information
paradoxxxzero authored and lmignon committed Oct 16, 2024
1 parent 7c0cbe6 commit 3cda449
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 40 deletions.
45 changes: 43 additions & 2 deletions fastapi/error_handlers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,55 @@
# Copyright 2022 ACSONE SA/NV
# License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL).


from starlette.exceptions import WebSocketException
from starlette import status
from starlette.exceptions import HTTPException, WebSocketException
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket
from werkzeug.exceptions import HTTPException as WerkzeugHTTPException

from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError

from fastapi import Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code


def convert_exception_to_status_body(exc: Exception) -> tuple[int, dict]:
body = {}
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
details = "Internal Server Error"

if isinstance(exc, WerkzeugHTTPException):
status_code = exc.code
details = exc.description

Check warning on line 27 in fastapi/error_handlers.py

View check run for this annotation

Codecov / codecov/patch

fastapi/error_handlers.py#L26-L27

Added lines #L26 - L27 were not covered by tests
elif isinstance(exc, HTTPException):
status_code = exc.status_code
details = exc.detail
elif isinstance(exc, RequestValidationError):
status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
details = jsonable_encoder(exc.errors())
elif isinstance(exc, WebSocketRequestValidationError):
status_code = status.WS_1008_POLICY_VIOLATION
details = jsonable_encoder(exc.errors())

Check warning on line 36 in fastapi/error_handlers.py

View check run for this annotation

Codecov / codecov/patch

fastapi/error_handlers.py#L35-L36

Added lines #L35 - L36 were not covered by tests
elif isinstance(exc, AccessDenied | AccessError):
status_code = status.HTTP_403_FORBIDDEN
details = "AccessError"
elif isinstance(exc, MissingError):
status_code = status.HTTP_404_NOT_FOUND
details = "MissingError"
elif isinstance(exc, UserError):
status_code = status.HTTP_400_BAD_REQUEST
details = exc.args[0]

if is_body_allowed_for_status_code(status_code):
# use the same format as in
# fastapi.exception_handlers.http_exception_handler
body = {"detail": details}
return status_code, body


# we need to monkey patch the ServerErrorMiddleware and ExceptionMiddleware classes
# to ensure that all the exceptions that are handled by these specific
Expand Down
39 changes: 2 additions & 37 deletions fastapi/fastapi_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@
from contextlib import contextmanager
from io import BytesIO

from starlette import status
from starlette.exceptions import HTTPException
from werkzeug.exceptions import HTTPException as WerkzeugHTTPException

from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError
from odoo.http import Dispatcher, request

from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError
from fastapi.utils import is_body_allowed_for_status_code

from .context import odoo_env_ctx
from .error_handlers import convert_exception_to_status_body


class FastApiDispatcher(Dispatcher):
Expand Down Expand Up @@ -47,34 +39,7 @@ def dispatch(self, endpoint, args):

def handle_error(self, exc):
headers = getattr(exc, "headers", None)
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
details = "Internal Server Error"
if isinstance(exc, WerkzeugHTTPException):
status_code = exc.code
details = exc.description
elif isinstance(exc, HTTPException):
status_code = exc.status_code
details = exc.detail
elif isinstance(exc, RequestValidationError):
status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
details = jsonable_encoder(exc.errors())
elif isinstance(exc, WebSocketRequestValidationError):
status_code = status.WS_1008_POLICY_VIOLATION
details = jsonable_encoder(exc.errors())
elif isinstance(exc, AccessDenied | AccessError):
status_code = status.HTTP_403_FORBIDDEN
details = "AccessError"
elif isinstance(exc, MissingError):
status_code = status.HTTP_404_NOT_FOUND
details = "MissingError"
elif isinstance(exc, UserError):
status_code = status.HTTP_400_BAD_REQUEST
details = exc.args[0]
body = {}
if is_body_allowed_for_status_code(status_code):
# use the same format as in
# fastapi.exception_handlers.http_exception_handler
body = {"detail": details}
status_code, body = convert_exception_to_status_body(exc)
return self.request.make_json_response(
body, status=status_code, headers=headers
)
Expand Down
29 changes: 29 additions & 0 deletions fastapi/tests/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Copyright 2023 ACSONE SA/NV
# License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL).
import logging
from collections.abc import Callable
from contextlib import contextmanager
from functools import partial
from typing import Any

from starlette import status
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from odoo.api import Environment
from odoo.tests import tagged
from odoo.tests.common import TransactionCase
Expand All @@ -20,6 +25,25 @@
authenticated_partner_impl,
optionally_authenticated_partner_impl,
)
from ..error_handlers import convert_exception_to_status_body

_logger = logging.getLogger(__name__)


def default_exception_handler(request: Request, exc: Exception) -> Response:
"""
Default exception handler that returns a response with the exception details.
"""
status_code, body = convert_exception_to_status_body(exc)

if status_code == status.HTTP_500_INTERNAL_SERVER_ERROR:
# In testing we want to see the exception details of 500 errors
_logger.error("[%d] Error occurred: %s", exc_info=exc)

return JSONResponse(
status_code=status_code,
content=body,
)


@tagged("post_install", "-at_install")
Expand Down Expand Up @@ -125,6 +149,11 @@ def _create_test_client(
if router:
app.include_router(router)
app.dependency_overrides = dependencies

if not raise_server_exceptions:
# Handle exceptions as in FastAPIDispatcher
app.exception_handlers.setdefault(Exception, default_exception_handler)

ctx_token = odoo_env_ctx.set(env)
testclient_kwargs = testclient_kwargs or {}
try:
Expand Down
47 changes: 46 additions & 1 deletion fastapi/tests/test_fastapi_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from requests import Response

from odoo.exceptions import UserError

from fastapi import status

from ..dependencies import fastapi_endpoint
from ..routers import demo_router
from ..schemas import DemoEndpointAppInfo
from ..schemas import DemoEndpointAppInfo, DemoExceptionType
from .common import FastAPITransactionCase


Expand Down Expand Up @@ -61,3 +63,46 @@ def test_endpoint_info(self) -> None:
response.json(),
DemoEndpointAppInfo.model_validate(demo_app).model_dump(by_alias=True),
)

def test_exception_raised(self) -> None:
with self.assertRaisesRegex(UserError, "User Error"):
with self._create_test_client() as test_client:
test_client.get(
"/demo/exception",
params={
"exception_type": DemoExceptionType.user_error.value,
"error_message": "User Error",
},
)
with self.assertRaisesRegex(NotImplementedError, "Bare Exception"):
with self._create_test_client() as test_client:
test_client.get(
"/demo/exception",
params={
"exception_type": DemoExceptionType.bare_exception.value,
"error_message": "Bare Exception",
},
)

def test_exception_not_raised(self) -> None:
with self._create_test_client(raise_server_exceptions=False) as test_client:
response: Response = test_client.get(
"/demo/exception",
params={
"exception_type": DemoExceptionType.user_error.value,
"error_message": "User Error",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertDictEqual(response.json(), {"detail": "User Error"})

with self._create_test_client(raise_server_exceptions=False) as test_client:
response: Response = test_client.get(
"/demo/exception",
params={
"exception_type": DemoExceptionType.bare_exception.value,
"error_message": "Bare Exception",
},
)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
self.assertDictEqual(response.json(), {"detail": "Internal Server Error"})

0 comments on commit 3cda449

Please sign in to comment.