diff --git a/app/core/custom_exceptions.py b/app/core/custom_exceptions.py index e4cb41e4..a7c5ff23 100644 --- a/app/core/custom_exceptions.py +++ b/app/core/custom_exceptions.py @@ -5,26 +5,63 @@ class RequiresAuthenticationException(HTTPException): def __init__(self): super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Requires authentication", + detail={ + "type": "not_authenticated", + "errorMessage": "Requires authentication", + }, ) class PermissionDeniedException(HTTPException): def __init__(self): super().__init__( - status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied" + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "type": "not_authorized", + "errorMessage": "Permission denied", + }, ) -class BadDataException(HTTPException): +class InternalServerException(HTTPException): def __init__(self, error_message: str): super().__init__( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=[ - { - "loc": [], - "msg": error_message, - "type": "value_error.bad_data", - } - ], + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "type": "other", + "errorMessage": error_message, + }, + ) + + +class MissingParameterException(HTTPException): + def __init__(self, error_message: str): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "type": "missing_parameter", + "errorMessage": error_message, + }, + ) + + +class InvalidTemplateException(HTTPException): + def __init__(self, error_message: str): + super().__init__( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "type": "invalid_template", + "errorMessage": error_message, + }, + ) + + +class InvalidModelException(HTTPException): + def __init__(self, error_message: str): + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "type": "invalid_model", + "errorMessage": error_message, + }, ) diff --git a/app/routes/messages.py b/app/routes/messages.py index a545c4a7..0eb5766d 100644 --- a/app/routes/messages.py +++ b/app/routes/messages.py @@ -1,6 +1,16 @@ from fastapi import APIRouter, Depends from datetime import datetime, timezone +from parsimonious.exceptions import IncompleteParseError + +from app.core.custom_exceptions import ( + MissingParameterException, + InvalidTemplateException, + InternalServerException, + InvalidModelException, +) +from app.dependencies import PermissionsValidator +from app.models.dtos import SendMessageRequest, SendMessageResponse, LLMModel from app.core.custom_exceptions import BadDataException from app.dependencies import TokenPermissionsValidator from app.models.dtos import SendMessageRequest, SendMessageResponse @@ -13,16 +23,25 @@ "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())] ) def send_message(body: SendMessageRequest) -> SendMessageResponse: + try: + model = LLMModel(body.preferred_model) + except ValueError as e: + raise InvalidModelException(str(e)) + guidance = GuidanceWrapper( - model=body.preferred_model, + model=model, handlebars=body.template.content, parameters=body.parameters, ) try: content = guidance.query() - except (KeyError, ValueError) as e: - raise BadDataException(str(e)) + except KeyError as e: + raise MissingParameterException(str(e)) + except (SyntaxError, IncompleteParseError) as e: + raise InvalidTemplateException(str(e)) + except Exception as e: + raise InternalServerException(str(e)) # Turn content into an array if it's not already if not isinstance(content, list): diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py index a77d43c3..0c382573 100644 --- a/tests/routes/messages_test.py +++ b/tests/routes/messages_test.py @@ -76,15 +76,12 @@ def test_send_message_raise_value_error(test_client, headers, mocker): "parameters": {"query": "Some query"}, } response = test_client.post("/api/v1/messages", headers=headers, json=body) - assert response.status_code == 422 + assert response.status_code == 500 assert response.json() == { - "detail": [ - { - "loc": [], - "msg": "value error message", - "type": "value_error.bad_data", - } - ] + "detail": { + "type": "other", + "errorMessage": "value error message", + } } @@ -101,15 +98,12 @@ def test_send_message_raise_key_error(test_client, headers, mocker): "parameters": {"query": "Some query"}, } response = test_client.post("/api/v1/messages", headers=headers, json=body) - assert response.status_code == 422 + assert response.status_code == 400 assert response.json() == { - "detail": [ - { - "loc": [], - "msg": "'key error message'", - "type": "value_error.bad_data", - } - ] + "detail": { + "type": "missing_parameter", + "errorMessage": "'key error message'", + } } @@ -117,10 +111,16 @@ def test_send_message_with_wrong_api_key(test_client): headers = {"Authorization": "wrong api key"} response = test_client.post("/api/v1/messages", headers=headers, json={}) assert response.status_code == 403 - assert response.json()["detail"] == "Permission denied" + assert response.json()["detail"] == { + "type": "not_authorized", + "errorMessage": "Permission denied", + } def test_send_message_without_authorization_header(test_client): response = test_client.post("/api/v1/messages", json={}) assert response.status_code == 401 - assert response.json()["detail"] == "Requires authentication" + assert response.json()["detail"] == { + "type": "not_authenticated", + "errorMessage": "Requires authentication", + }