diff --git a/fastapi/error_handlers.py b/fastapi/error_handlers.py index ca054a58..2e4202d6 100644 --- a/fastapi/error_handlers.py +++ b/fastapi/error_handlers.py @@ -1,14 +1,55 @@ # Copyright 2022 ACSONE SA/NV # License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL). - -from starlette.exceptions import WebSocketException +from starlette import status +from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.errors import ServerErrorMiddleware from starlette.middleware.exceptions import ExceptionMiddleware from starlette.responses import JSONResponse from starlette.websockets import WebSocket +from werkzeug.exceptions import HTTPException as WerkzeugHTTPException + +from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError from fastapi import Request +from fastapi.encoders import jsonable_encoder +from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError +from fastapi.utils import is_body_allowed_for_status_code + + +def convert_exception_to_status_body(exc: Exception) -> tuple[int, dict]: + body = {} + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + details = "Internal Server Error" + + if isinstance(exc, WerkzeugHTTPException): + status_code = exc.code + details = exc.description + elif isinstance(exc, HTTPException): + status_code = exc.status_code + details = exc.detail + elif isinstance(exc, RequestValidationError): + status_code = status.HTTP_422_UNPROCESSABLE_ENTITY + details = jsonable_encoder(exc.errors()) + elif isinstance(exc, WebSocketRequestValidationError): + status_code = status.WS_1008_POLICY_VIOLATION + details = jsonable_encoder(exc.errors()) + elif isinstance(exc, AccessDenied | AccessError): + status_code = status.HTTP_403_FORBIDDEN + details = "AccessError" + elif isinstance(exc, MissingError): + status_code = status.HTTP_404_NOT_FOUND + details = "MissingError" + elif isinstance(exc, UserError): + status_code = status.HTTP_400_BAD_REQUEST + details = exc.args[0] + + if is_body_allowed_for_status_code(status_code): + # use the same format as in + # fastapi.exception_handlers.http_exception_handler + body = {"detail": details} + return status_code, body + # we need to monkey patch the ServerErrorMiddleware and ExceptionMiddleware classes # to ensure that all the exceptions that are handled by these specific diff --git a/fastapi/fastapi_dispatcher.py b/fastapi/fastapi_dispatcher.py index 8bd9aa41..1a8eb353 100644 --- a/fastapi/fastapi_dispatcher.py +++ b/fastapi/fastapi_dispatcher.py @@ -4,18 +4,10 @@ from contextlib import contextmanager from io import BytesIO -from starlette import status -from starlette.exceptions import HTTPException -from werkzeug.exceptions import HTTPException as WerkzeugHTTPException - -from odoo.exceptions import AccessDenied, AccessError, MissingError, UserError from odoo.http import Dispatcher, request -from fastapi.encoders import jsonable_encoder -from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError -from fastapi.utils import is_body_allowed_for_status_code - from .context import odoo_env_ctx +from .error_handlers import convert_exception_to_status_body class FastApiDispatcher(Dispatcher): @@ -47,34 +39,7 @@ def dispatch(self, endpoint, args): def handle_error(self, exc): headers = getattr(exc, "headers", None) - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - details = "Internal Server Error" - if isinstance(exc, WerkzeugHTTPException): - status_code = exc.code - details = exc.description - elif isinstance(exc, HTTPException): - status_code = exc.status_code - details = exc.detail - elif isinstance(exc, RequestValidationError): - status_code = status.HTTP_422_UNPROCESSABLE_ENTITY - details = jsonable_encoder(exc.errors()) - elif isinstance(exc, WebSocketRequestValidationError): - status_code = status.WS_1008_POLICY_VIOLATION - details = jsonable_encoder(exc.errors()) - elif isinstance(exc, AccessDenied | AccessError): - status_code = status.HTTP_403_FORBIDDEN - details = "AccessError" - elif isinstance(exc, MissingError): - status_code = status.HTTP_404_NOT_FOUND - details = "MissingError" - elif isinstance(exc, UserError): - status_code = status.HTTP_400_BAD_REQUEST - details = exc.args[0] - body = {} - if is_body_allowed_for_status_code(status_code): - # use the same format as in - # fastapi.exception_handlers.http_exception_handler - body = {"detail": details} + status_code, body = convert_exception_to_status_body(exc) return self.request.make_json_response( body, status=status_code, headers=headers ) diff --git a/fastapi/tests/common.py b/fastapi/tests/common.py index 65aba7dc..f9c05b67 100644 --- a/fastapi/tests/common.py +++ b/fastapi/tests/common.py @@ -1,10 +1,15 @@ # Copyright 2023 ACSONE SA/NV # License LGPL-3.0 or later (http://www.gnu.org/licenses/LGPL). +import logging from collections.abc import Callable from contextlib import contextmanager from functools import partial from typing import Any +from starlette import status +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + from odoo.api import Environment from odoo.tests import tagged from odoo.tests.common import TransactionCase @@ -20,6 +25,25 @@ authenticated_partner_impl, optionally_authenticated_partner_impl, ) +from ..error_handlers import convert_exception_to_status_body + +_logger = logging.getLogger(__name__) + + +def default_exception_handler(request: Request, exc: Exception) -> Response: + """ + Default exception handler that returns a response with the exception details. + """ + status_code, body = convert_exception_to_status_body(exc) + + if status_code == status.HTTP_500_INTERNAL_SERVER_ERROR: + # In testing we want to see the exception details of 500 errors + _logger.error("[%d] Error occurred: %s", exc_info=exc) + + return JSONResponse( + status_code=status_code, + content=body, + ) @tagged("post_install", "-at_install") @@ -125,6 +149,11 @@ def _create_test_client( if router: app.include_router(router) app.dependency_overrides = dependencies + + if not raise_server_exceptions: + # Handle exceptions as in FastAPIDispatcher + app.exception_handlers.setdefault(Exception, default_exception_handler) + ctx_token = odoo_env_ctx.set(env) testclient_kwargs = testclient_kwargs or {} try: diff --git a/fastapi/tests/test_fastapi_demo.py b/fastapi/tests/test_fastapi_demo.py index 1692e69f..43503b6d 100644 --- a/fastapi/tests/test_fastapi_demo.py +++ b/fastapi/tests/test_fastapi_demo.py @@ -5,11 +5,13 @@ from requests import Response +from odoo.exceptions import UserError + from fastapi import status from ..dependencies import fastapi_endpoint from ..routers import demo_router -from ..schemas import DemoEndpointAppInfo +from ..schemas import DemoEndpointAppInfo, DemoExceptionType from .common import FastAPITransactionCase @@ -61,3 +63,46 @@ def test_endpoint_info(self) -> None: response.json(), DemoEndpointAppInfo.model_validate(demo_app).model_dump(by_alias=True), ) + + def test_exception_raised(self) -> None: + with self.assertRaisesRegex(UserError, "User Error"): + with self._create_test_client() as test_client: + test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.user_error.value, + "error_message": "User Error", + }, + ) + with self.assertRaisesRegex(NotImplementedError, "Bare Exception"): + with self._create_test_client() as test_client: + test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.bare_exception.value, + "error_message": "Bare Exception", + }, + ) + + def test_exception_not_raised(self) -> None: + with self._create_test_client(raise_server_exceptions=False) as test_client: + response: Response = test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.user_error.value, + "error_message": "User Error", + }, + ) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), {"detail": "User Error"}) + + with self._create_test_client(raise_server_exceptions=False) as test_client: + response: Response = test_client.get( + "/demo/exception", + params={ + "exception_type": DemoExceptionType.bare_exception.value, + "error_message": "Bare Exception", + }, + ) + self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) + self.assertDictEqual(response.json(), {"detail": "Internal Server Error"})