Development: Improve LLM handling in Pyris (#12)

Hialus authored Jul 18, 2023
1 parent b5a4043 commit a95f0aa

Showing 11 changed files with 147 additions and 68 deletions.
19 changes: 16 additions & 3 deletions app/config.py
@@ -1,12 +1,25 @@
 import os
-from pydantic import BaseModel
 
 from pyaml_env import parse_config
+from pydantic import BaseModel
 
 
+class LLMModelConfig(BaseModel):
+    name: str
+    description: str
+    llm_credentials: dict
+
+
+class APIKeyConfig(BaseModel):
+    token: str
+    comment: str
+    llm_access: list[str]
+
+
 class Settings(BaseModel):
     class PyrisSettings(BaseModel):
-        api_key: str
-        llm: dict
+        api_keys: list[APIKeyConfig]
+        llms: dict[str, LLMModelConfig]
 
     pyris: PyrisSettings

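The tail of app/config.py is collapsed in this view, but it presumably still parses the YAML file with pyaml_env and validates it into Settings. A minimal sketch of that loading step, assuming an APPLICATION_YML_PATH environment variable points at the config file (the variable name and helper function are assumptions, not shown in the diff):

```python
import os

from pyaml_env import parse_config


def load_settings() -> Settings:
    # Hypothetical loading step: parse the YAML (resolving env-var
    # placeholders) and validate it against the new Settings schema above.
    file_path = os.environ.get("APPLICATION_YML_PATH", "application.yml")
    return Settings.parse_obj(parse_config(file_path))


settings = load_settings()
```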
31 changes: 26 additions & 5 deletions app/dependencies.py
@@ -1,10 +1,11 @@
 from fastapi import Depends
 from starlette.requests import Request as StarletteRequest
-from app.config import settings
+from app.config import settings, APIKeyConfig
 
 from app.core.custom_exceptions import (
     PermissionDeniedException,
     RequiresAuthenticationException,
+    InvalidModelException,
 )


@@ -17,7 +18,27 @@ def _get_api_key(request: StarletteRequest) -> str:
     return authorization_header
 
 
-class PermissionsValidator:
-    def __call__(self, api_key: str = Depends(_get_api_key)):
-        if api_key != settings.pyris.api_key:
-            raise PermissionDeniedException
+class TokenValidator:
+    async def __call__(
+        self, api_key: str = Depends(_get_api_key)
+    ) -> APIKeyConfig:
+        for key in settings.pyris.api_keys:
+            if key.token == api_key:
+                return key
+        raise PermissionDeniedException
+
+
+class TokenPermissionsValidator:
+    async def __call__(
+        self, request: StarletteRequest, api_key: str = Depends(_get_api_key)
+    ):
+        for key in settings.pyris.api_keys:
+            if key.token == api_key:
+                body = await request.json()
+                if body.get("preferredModel") in key.llm_access:
+                    return
+                else:
+                    raise InvalidModelException(
+                        str(body.get("preferredModel"))
+                    )
+        raise PermissionDeniedException
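Only the tail of _get_api_key is visible in the hunk above. Presumably it takes the Authorization header verbatim as the token and raises when it is missing; a sketch of that collapsed helper (an assumption, reconstructed from the imports and the visible return statement):

```python
from starlette.requests import Request as StarletteRequest

from app.core.custom_exceptions import RequiresAuthenticationException


def _get_api_key(request: StarletteRequest) -> str:
    # Presumed body of the collapsed helper: require an Authorization
    # header and hand its raw value to the validators above.
    authorization_header = request.headers.get("Authorization")
    if authorization_header is None:
        raise RequiresAuthenticationException
    return authorization_header
```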
2 changes: 2 additions & 0 deletions app/main.py
@@ -2,7 +2,9 @@
 from fastapi.responses import ORJSONResponse
 
 from app.routes.messages import router as messages_router
+from app.routes.models import router as models_router
 
 app = FastAPI(default_response_class=ORJSONResponse)
 
 app.include_router(messages_router)
+app.include_router(models_router)
14 changes: 7 additions & 7 deletions app/models/dtos.py
@@ -3,12 +3,6 @@
 from datetime import datetime
 
 
-class LLMModel(str, Enum):
-    GPT35_TURBO = "GPT35_TURBO"
-    GPT35_TURBO_16K_0613 = "GPT35_TURBO_16K_0613"
-    GPT35_TURBO_0613 = "GPT35_TURBO_0613"
-
-
 class ContentType(str, Enum):
     TEXT = "text"

@@ -35,5 +29,11 @@ class Message(BaseModel):
     )
     content: list[Content]
 
-    used_model: LLMModel = Field(..., alias="usedModel")
+    used_model: str = Field(..., alias="usedModel")
     message: Message
+
+
+class LLMModelResponse(BaseModel):
+    id: str
+    name: str
+    description: str
11 changes: 6 additions & 5 deletions app/routes/messages.py
@@ -1,6 +1,6 @@
-from fastapi import APIRouter, Depends
 from datetime import datetime, timezone
 
+from fastapi import APIRouter, Depends
 from parsimonious.exceptions import IncompleteParseError
 
 from app.core.custom_exceptions import (
@@ -9,19 +9,20 @@
     InternalServerException,
     InvalidModelException,
 )
-from app.dependencies import PermissionsValidator
-from app.models.dtos import SendMessageRequest, SendMessageResponse, LLMModel
+from app.dependencies import TokenPermissionsValidator
+from app.models.dtos import SendMessageRequest, SendMessageResponse
 from app.services.guidance_wrapper import GuidanceWrapper
+from app.config import settings
 
 router = APIRouter(tags=["messages"])
 
 
 @router.post(
-    "/api/v1/messages", dependencies=[Depends(PermissionsValidator())]
+    "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())]
 )
 def send_message(body: SendMessageRequest) -> SendMessageResponse:
     try:
-        model = LLMModel(body.preferred_model)
+        model = settings.pyris.llms[body.preferred_model]
     except ValueError as e:
         raise InvalidModelException(str(e))

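One caveat: indexing settings.pyris.llms raises KeyError rather than ValueError, so the except clause above will not catch an unknown model on its own; in practice TokenPermissionsValidator already rejects models outside the key's llm_access before the handler runs. A defensive variant of the lookup inside send_message (a sketch, not part of this commit):

```python
# Look the model up explicitly instead of relying on the dependency
# to have filtered out unknown names.
model = settings.pyris.llms.get(body.preferred_model)
if model is None:
    raise InvalidModelException(body.preferred_model)
```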
23 changes: 23 additions & 0 deletions app/routes/models.py
@@ -0,0 +1,23 @@
+from fastapi import APIRouter, Depends
+
+from app.dependencies import TokenValidator
+from app.config import settings, APIKeyConfig
+from app.models.dtos import LLMModelResponse
+
+router = APIRouter(tags=["models"])
+
+
+@router.get("/api/v1/models")
+def send_message(
+    api_key_config: APIKeyConfig = Depends(TokenValidator()),
+) -> list[LLMModelResponse]:
+    llm_ids = api_key_config.llm_access
+
+    # Convert the result of this to a list of LLMModelResponse
+    return [
+        LLMModelResponse(
+            id=key, name=config.name, description=config.description
+        )
+        for key, config in settings.pyris.llms.items()
+        if key in llm_ids
+    ]
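Given the dummy configuration from application.example.yml below, the new endpoint can be exercised like this (a sketch using FastAPI's TestClient; the token and the expected payload come from the example file):

```python
from fastapi.testclient import TestClient

from app.main import app

client = TestClient(app)
response = client.get("/api/v1/models", headers={"Authorization": "secret"})
# Expected shape of response.json():
# [{"id": "DUMMY", "name": "Dummy model",
#   "description": "Dummy model for testing"}]
```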
8 changes: 4 additions & 4 deletions app/services/guidance_wrapper.py
@@ -1,14 +1,14 @@
 import guidance
 
-from app.config import settings
-from app.models.dtos import Content, ContentType, LLMModel
+from app.config import LLMModelConfig
+from app.models.dtos import Content, ContentType
 
 
 class GuidanceWrapper:
     """A wrapper service to all guidance package's methods."""
 
     def __init__(
-        self, model: LLMModel, handlebars: str, parameters=None
+        self, model: LLMModelConfig, handlebars: str, parameters=None
     ) -> None:
         if parameters is None:
             parameters = {}
@@ -39,5 +39,5 @@ def query(self) -> Content:
         return Content(type=ContentType.TEXT, textContent=result["response"])
 
     def _get_llm(self):
-        llm_credentials = settings.pyris.llm[self.model.value]
+        llm_credentials = self.model.llm_credentials
         return guidance.llms.OpenAI(**llm_credentials)
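Because the dict is unpacked with **, every key in a model's llm_credentials must correspond to a keyword argument of guidance.llms.OpenAI. Spelled out for the keys used in application.example.yml (a sketch; all values here are hypothetical, and the accepted keywords depend on the installed guidance version):

```python
import guidance

# Equivalent expansion of guidance.llms.OpenAI(**llm_credentials) for a
# credentials dict carrying the example keys; the values are made up.
llm = guidance.llms.OpenAI(
    model="gpt-35-turbo",
    token="dummy-token",
    api_base="https://example-resource.openai.azure.com",
    api_type="azure",
    api_version="2023-03-15-preview",
    deployment_id="example-deployment",
)
```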
34 changes: 17 additions & 17 deletions application.example.yml
@@ -1,18 +1,18 @@
 pyris:
-  api_key: secret
-  llm:
-    GPT35_TURBO:
-      model:
-      token:
-      api_base:
-      api_type:
-      api_version:
-      deployment_id:
-    GPT35_TURBO_16K_0613:
-      model:
-      token:
-      chat_mode:
-    GPT35_TURBO_0613:
-      model:
-      token:
-      chat_mode:
+  api_keys:
+    - token: "secret"
+      comment: "DUMMY"
+      llm_access:
+        - DUMMY
+
+  llms:
+    DUMMY:
+      name: "Dummy model"
+      description: "Dummy model for testing"
+      llm_credentials:
+        api_base: ""
+        api_type: ""
+        api_version: ""
+        deployment_id: ""
+        model: ""
+        token: ""
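For reference, this example file parses into the new app/config.py schema roughly as follows (a sketch, using pydantic's parse_obj):

```python
from app.config import Settings

# What application.example.yml yields once parsed and validated.
settings = Settings.parse_obj({
    "pyris": {
        "api_keys": [
            {"token": "secret", "comment": "DUMMY", "llm_access": ["DUMMY"]}
        ],
        "llms": {
            "DUMMY": {
                "name": "Dummy model",
                "description": "Dummy model for testing",
                "llm_credentials": {
                    "api_base": "", "api_type": "", "api_version": "",
                    "deployment_id": "", "model": "", "token": "",
                },
            }
        },
    }
})

assert settings.pyris.api_keys[0].llm_access == ["DUMMY"]
```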
34 changes: 17 additions & 17 deletions application.test.yml
@@ -1,18 +1,18 @@
 pyris:
-  api_key: secret
-  llm:
-    GPT35_TURBO:
-      model:
-      token:
-      api_base:
-      api_type:
-      api_version:
-      deployment_id:
-    GPT35_TURBO_16K_0613:
-      model:
-      token:
-      chat_mode:
-    GPT35_TURBO_0613:
-      model:
-      token:
-      chat_mode:
+  api_keys:
+    - token: "secret"
+      comment: "DUMMY"
+      llm_access:
+        - DUMMY
+
+  llms:
+    DUMMY:
+      name: "Dummy model"
+      description: "Dummy model for testing"
+      llm_credentials:
+        api_base: ""
+        api_type: ""
+        api_version: ""
+        deployment_id: ""
+        model: ""
+        token: ""
26 changes: 20 additions & 6 deletions tests/routes/messages_test.py
@@ -1,6 +1,16 @@
 from freezegun import freeze_time
 from app.models.dtos import Content, ContentType
 from app.services.guidance_wrapper import GuidanceWrapper
+import app.config as config
+
+llm_model_config = config.LLMModelConfig(
+    name="test", description="test", llm_credentials={}
+)
+config.settings.pyris.llms = {"GPT35_TURBO": llm_model_config}
+api_key_config = config.APIKeyConfig(
+    token="secret", comment="test", llm_access=["GPT35_TURBO"]
+)
+config.settings.pyris.api_keys = [api_key_config]
 
 
 @freeze_time("2023-06-16 03:21:34 +02:00")
@@ -39,8 +49,17 @@ def test_send_message(test_client, headers, mocker):
     }
 
 
-def test_send_message_missing_params(test_client, headers):
+def test_send_message_missing_model(test_client, headers):
     response = test_client.post("/api/v1/messages", headers=headers, json={})
+    assert response.status_code == 404
+
+
+def test_send_message_missing_params(test_client, headers):
+    response = test_client.post(
+        "/api/v1/messages",
+        headers=headers,
+        json={"preferredModel": "GPT35_TURBO"},
+    )
     assert response.status_code == 422
     assert response.json() == {
         "detail": [
@@ -49,11 +68,6 @@ def test_send_message_missing_params(test_client, headers):
                 "msg": "field required",
                 "type": "value_error.missing",
             },
-            {
-                "loc": ["body", "preferredModel"],
-                "msg": "field required",
-                "type": "value_error.missing",
-            },
             {
                 "loc": ["body", "parameters"],
                 "msg": "field required",
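The module-level assignments at the top of this test file overwrite whatever application.test.yml provided, so every test in the module sees a deterministic GPT35_TURBO entry. A pytest fixture would scope that override per test and restore the original settings afterwards (a sketch of an alternative, not part of this commit):

```python
import pytest

import app.config as config


@pytest.fixture(autouse=True)
def stub_llm_config():
    # Swap in a deterministic model/key pair for each test, then restore
    # whatever the parsed application.test.yml had provided.
    old_llms = config.settings.pyris.llms
    old_keys = config.settings.pyris.api_keys
    config.settings.pyris.llms = {
        "GPT35_TURBO": config.LLMModelConfig(
            name="test", description="test", llm_credentials={}
        )
    }
    config.settings.pyris.api_keys = [
        config.APIKeyConfig(
            token="secret", comment="test", llm_access=["GPT35_TURBO"]
        )
    ]
    yield
    config.settings.pyris.llms = old_llms
    config.settings.pyris.api_keys = old_keys
```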
13 changes: 9 additions & 4 deletions tests/services/guidance_wrapper_test.py
@@ -1,8 +1,13 @@
 import pytest
 import guidance
 
-from app.models.dtos import Content, ContentType, LLMModel
+from app.models.dtos import Content, ContentType
 from app.services.guidance_wrapper import GuidanceWrapper
+from app.config import LLMModelConfig
+
+llm_model_config = LLMModelConfig(
+    name="test", description="test", llm_credentials={}
+)
 
 
 def test_query_success(mocker):
@@ -17,7 +22,7 @@ def test_query_success(mocker):
     {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}"""
 
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"query": "Some query"},
     )
@@ -41,7 +46,7 @@ def test_query_missing_required_params(mocker):
     {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}"""
 
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"somethingelse": "Something"},
     )
@@ -63,7 +68,7 @@ def test_query_handlebars_not_generate_response(mocker):
 
     handlebars = "Not a valid handlebars"
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"query": "Something"},
     )
