Skip to content

Commit

Permalink
Merge branch 'main' into feature/better-llm-handling
Browse files Browse the repository at this point in the history
# Conflicts:
#	app/routes/messages.py
  • Loading branch information
Hialus committed Jul 12, 2023
2 parents 9613d1b + bf16ac8 commit a6609cb
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 32 deletions.
59 changes: 48 additions & 11 deletions app/core/custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,63 @@ class RequiresAuthenticationException(HTTPException):
def __init__(self):
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Requires authentication",
detail={
"type": "not_authenticated",
"errorMessage": "Requires authentication",
},
)


class PermissionDeniedException(HTTPException):
def __init__(self):
super().__init__(
status_code=status.HTTP_403_FORBIDDEN, detail="Permission denied"
status_code=status.HTTP_403_FORBIDDEN,
detail={
"type": "not_authorized",
"errorMessage": "Permission denied",
},
)


class BadDataException(HTTPException):
class InternalServerException(HTTPException):
def __init__(self, error_message: str):
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=[
{
"loc": [],
"msg": error_message,
"type": "value_error.bad_data",
}
],
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"type": "other",
"errorMessage": error_message,
},
)


class MissingParameterException(HTTPException):
def __init__(self, error_message: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"type": "missing_parameter",
"errorMessage": error_message,
},
)


class InvalidTemplateException(HTTPException):
def __init__(self, error_message: str):
super().__init__(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"type": "invalid_template",
"errorMessage": error_message,
},
)


class InvalidModelException(HTTPException):
def __init__(self, error_message: str):
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"type": "invalid_model",
"errorMessage": error_message,
},
)
25 changes: 22 additions & 3 deletions app/routes/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from fastapi import APIRouter, Depends
from datetime import datetime, timezone

from parsimonious.exceptions import IncompleteParseError

from app.core.custom_exceptions import (
MissingParameterException,
InvalidTemplateException,
InternalServerException,
InvalidModelException,
)
from app.dependencies import PermissionsValidator
from app.models.dtos import SendMessageRequest, SendMessageResponse, LLMModel
from app.core.custom_exceptions import BadDataException
from app.dependencies import TokenPermissionsValidator
from app.models.dtos import SendMessageRequest, SendMessageResponse
Expand All @@ -13,16 +23,25 @@
"/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())]
)
def send_message(body: SendMessageRequest) -> SendMessageResponse:
try:
model = LLMModel(body.preferred_model)
except ValueError as e:
raise InvalidModelException(str(e))

guidance = GuidanceWrapper(
model=body.preferred_model,
model=model,
handlebars=body.template.content,
parameters=body.parameters,
)

try:
content = guidance.query()
except (KeyError, ValueError) as e:
raise BadDataException(str(e))
except KeyError as e:
raise MissingParameterException(str(e))
except (SyntaxError, IncompleteParseError) as e:
raise InvalidTemplateException(str(e))
except Exception as e:
raise InternalServerException(str(e))

# Turn content into an array if it's not already
if not isinstance(content, list):
Expand Down
36 changes: 18 additions & 18 deletions tests/routes/messages_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,12 @@ def test_send_message_raise_value_error(test_client, headers, mocker):
"parameters": {"query": "Some query"},
}
response = test_client.post("/api/v1/messages", headers=headers, json=body)
assert response.status_code == 422
assert response.status_code == 500
assert response.json() == {
"detail": [
{
"loc": [],
"msg": "value error message",
"type": "value_error.bad_data",
}
]
"detail": {
"type": "other",
"errorMessage": "value error message",
}
}


Expand All @@ -101,26 +98,29 @@ def test_send_message_raise_key_error(test_client, headers, mocker):
"parameters": {"query": "Some query"},
}
response = test_client.post("/api/v1/messages", headers=headers, json=body)
assert response.status_code == 422
assert response.status_code == 400
assert response.json() == {
"detail": [
{
"loc": [],
"msg": "'key error message'",
"type": "value_error.bad_data",
}
]
"detail": {
"type": "missing_parameter",
"errorMessage": "'key error message'",
}
}


def test_send_message_with_wrong_api_key(test_client):
headers = {"Authorization": "wrong api key"}
response = test_client.post("/api/v1/messages", headers=headers, json={})
assert response.status_code == 403
assert response.json()["detail"] == "Permission denied"
assert response.json()["detail"] == {
"type": "not_authorized",
"errorMessage": "Permission denied",
}


def test_send_message_without_authorization_header(test_client):
response = test_client.post("/api/v1/messages", json={})
assert response.status_code == 401
assert response.json()["detail"] == "Requires authentication"
assert response.json()["detail"] == {
"type": "not_authenticated",
"errorMessage": "Requires authentication",
}

0 comments on commit a6609cb

Please sign in to comment.