From a95f0aa74c2be38088c11e306c92015d503cb7d8 Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Tue, 18 Jul 2023 14:56:22 +0200 Subject: [PATCH] `Development`: Improve LLM handling in Pyris (#12) --- app/config.py | 19 +++++++++++--- app/dependencies.py | 31 ++++++++++++++++++---- app/main.py | 2 ++ app/models/dtos.py | 14 +++++----- app/routes/messages.py | 11 ++++---- app/routes/models.py | 23 +++++++++++++++++ app/services/guidance_wrapper.py | 8 +++--- application.example.yml | 34 ++++++++++++------------- application.test.yml | 34 ++++++++++++------------- tests/routes/messages_test.py | 26 ++++++++++++++----- tests/services/guidance_wrapper_test.py | 13 +++++++--- 11 files changed, 147 insertions(+), 68 deletions(-) create mode 100644 app/routes/models.py diff --git a/app/config.py b/app/config.py index b8df14a6..0a12c02e 100644 --- a/app/config.py +++ b/app/config.py @@ -1,12 +1,25 @@ import os -from pydantic import BaseModel + from pyaml_env import parse_config +from pydantic import BaseModel + + +class LLMModelConfig(BaseModel): + name: str + description: str + llm_credentials: dict + + +class APIKeyConfig(BaseModel): + token: str + comment: str + llm_access: list[str] class Settings(BaseModel): class PyrisSettings(BaseModel): - api_key: str - llm: dict + api_keys: list[APIKeyConfig] + llms: dict[str, LLMModelConfig] pyris: PyrisSettings diff --git a/app/dependencies.py b/app/dependencies.py index 4436fda5..91be6c78 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -1,10 +1,11 @@ from fastapi import Depends from starlette.requests import Request as StarletteRequest -from app.config import settings +from app.config import settings, APIKeyConfig from app.core.custom_exceptions import ( PermissionDeniedException, RequiresAuthenticationException, + InvalidModelException, ) @@ -17,7 +18,27 @@ def _get_api_key(request: StarletteRequest) -> str: return authorization_header -class PermissionsValidator: - def __call__(self, api_key: str = Depends(_get_api_key)): - if api_key != settings.pyris.api_key: - raise PermissionDeniedException +class TokenValidator: + async def __call__( + self, api_key: str = Depends(_get_api_key) + ) -> APIKeyConfig: + for key in settings.pyris.api_keys: + if key.token == api_key: + return key + raise PermissionDeniedException + + +class TokenPermissionsValidator: + async def __call__( + self, request: StarletteRequest, api_key: str = Depends(_get_api_key) + ): + for key in settings.pyris.api_keys: + if key.token == api_key: + body = await request.json() + if body.get("preferredModel") in key.llm_access: + return + else: + raise InvalidModelException( + str(body.get("preferredModel")) + ) + raise PermissionDeniedException diff --git a/app/main.py b/app/main.py index ee3665ab..276d6dee 100644 --- a/app/main.py +++ b/app/main.py @@ -2,7 +2,9 @@ from fastapi.responses import ORJSONResponse from app.routes.messages import router as messages_router +from app.routes.models import router as models_router app = FastAPI(default_response_class=ORJSONResponse) app.include_router(messages_router) +app.include_router(models_router) diff --git a/app/models/dtos.py b/app/models/dtos.py index 4b58908a..47ffaa24 100644 --- a/app/models/dtos.py +++ b/app/models/dtos.py @@ -3,12 +3,6 @@ from datetime import datetime -class LLMModel(str, Enum): - GPT35_TURBO = "GPT35_TURBO" - GPT35_TURBO_16K_0613 = "GPT35_TURBO_16K_0613" - GPT35_TURBO_0613 = "GPT35_TURBO_0613" - - class ContentType(str, Enum): TEXT = "text" @@ -35,5 +29,11 @@ class Message(BaseModel): ) content: list[Content] - used_model: LLMModel = Field(..., alias="usedModel") + used_model: str = Field(..., alias="usedModel") message: Message + + +class LLMModelResponse(BaseModel): + id: str + name: str + description: str diff --git a/app/routes/messages.py b/app/routes/messages.py index 04382d5c..40c584b8 100644 --- a/app/routes/messages.py +++ b/app/routes/messages.py @@ -1,6 +1,6 @@ -from fastapi import APIRouter, Depends from datetime import datetime, timezone +from fastapi import APIRouter, Depends from parsimonious.exceptions import IncompleteParseError from app.core.custom_exceptions import ( @@ -9,19 +9,20 @@ InternalServerException, InvalidModelException, ) -from app.dependencies import PermissionsValidator -from app.models.dtos import SendMessageRequest, SendMessageResponse, LLMModel +from app.dependencies import TokenPermissionsValidator +from app.models.dtos import SendMessageRequest, SendMessageResponse from app.services.guidance_wrapper import GuidanceWrapper +from app.config import settings router = APIRouter(tags=["messages"]) @router.post( - "/api/v1/messages", dependencies=[Depends(PermissionsValidator())] + "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())] ) def send_message(body: SendMessageRequest) -> SendMessageResponse: try: - model = LLMModel(body.preferred_model) + model = settings.pyris.llms[body.preferred_model] except ValueError as e: raise InvalidModelException(str(e)) diff --git a/app/routes/models.py b/app/routes/models.py new file mode 100644 index 00000000..782e913c --- /dev/null +++ b/app/routes/models.py @@ -0,0 +1,23 @@ +from fastapi import APIRouter, Depends + +from app.dependencies import TokenValidator +from app.config import settings, APIKeyConfig +from app.models.dtos import LLMModelResponse + +router = APIRouter(tags=["models"]) + + +@router.get("/api/v1/models") +def send_message( + api_key_config: APIKeyConfig = Depends(TokenValidator()), +) -> list[LLMModelResponse]: + llm_ids = api_key_config.llm_access + + # Convert the result of this to a list of LLMModelResponse + return [ + LLMModelResponse( + id=key, name=config.name, description=config.description + ) + for key, config in settings.pyris.llms.items() + if key in llm_ids + ] diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 9aeca440..43f94177 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -1,14 +1,14 @@ import guidance -from app.config import settings -from app.models.dtos import Content, ContentType, LLMModel +from app.config import LLMModelConfig +from app.models.dtos import Content, ContentType class GuidanceWrapper: """A wrapper service to all guidance package's methods.""" def __init__( - self, model: LLMModel, handlebars: str, parameters=None + self, model: LLMModelConfig, handlebars: str, parameters=None ) -> None: if parameters is None: parameters = {} @@ -39,5 +39,5 @@ def query(self) -> Content: return Content(type=ContentType.TEXT, textContent=result["response"]) def _get_llm(self): - llm_credentials = settings.pyris.llm[self.model.value] + llm_credentials = self.model.llm_credentials return guidance.llms.OpenAI(**llm_credentials) diff --git a/application.example.yml b/application.example.yml index 7b850048..b41935a2 100644 --- a/application.example.yml +++ b/application.example.yml @@ -1,18 +1,18 @@ pyris: - api_key: secret - llm: - GPT35_TURBO: - model: - token: - api_base: - api_type: - api_version: - deployment_id: - GPT35_TURBO_16K_0613: - model: - token: - chat_mode: - GPT35_TURBO_0613: - model: - token: - chat_mode: + api_keys: + - token: "secret" + comment: "DUMMY" + llm_access: + - DUMMY + + llms: + DUMMY: + name: "Dummy model" + description: "Dummy model for testing" + llm_credentials: + api_base: "" + api_type: "" + api_version: "" + deployment_id: "" + model: "" + token: "" diff --git a/application.test.yml b/application.test.yml index 7b850048..b41935a2 100644 --- a/application.test.yml +++ b/application.test.yml @@ -1,18 +1,18 @@ pyris: - api_key: secret - llm: - GPT35_TURBO: - model: - token: - api_base: - api_type: - api_version: - deployment_id: - GPT35_TURBO_16K_0613: - model: - token: - chat_mode: - GPT35_TURBO_0613: - model: - token: - chat_mode: + api_keys: + - token: "secret" + comment: "DUMMY" + llm_access: + - DUMMY + + llms: + DUMMY: + name: "Dummy model" + description: "Dummy model for testing" + llm_credentials: + api_base: "" + api_type: "" + api_version: "" + deployment_id: "" + model: "" + token: "" diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py index 0c382573..c194cb67 100644 --- a/tests/routes/messages_test.py +++ b/tests/routes/messages_test.py @@ -1,6 +1,16 @@ from freezegun import freeze_time from app.models.dtos import Content, ContentType from app.services.guidance_wrapper import GuidanceWrapper +import app.config as config + +llm_model_config = config.LLMModelConfig( + name="test", description="test", llm_credentials={} +) +config.settings.pyris.llms = {"GPT35_TURBO": llm_model_config} +api_key_config = config.APIKeyConfig( + token="secret", comment="test", llm_access=["GPT35_TURBO"] +) +config.settings.pyris.api_keys = [api_key_config] @freeze_time("2023-06-16 03:21:34 +02:00") @@ -39,8 +49,17 @@ def test_send_message(test_client, headers, mocker): } -def test_send_message_missing_params(test_client, headers): +def test_send_message_missing_model(test_client, headers): response = test_client.post("/api/v1/messages", headers=headers, json={}) + assert response.status_code == 404 + + +def test_send_message_missing_params(test_client, headers): + response = test_client.post( + "/api/v1/messages", + headers=headers, + json={"preferredModel": "GPT35_TURBO"}, + ) assert response.status_code == 422 assert response.json() == { "detail": [ @@ -49,11 +68,6 @@ def test_send_message_missing_params(test_client, headers): "msg": "field required", "type": "value_error.missing", }, - { - "loc": ["body", "preferredModel"], - "msg": "field required", - "type": "value_error.missing", - }, { "loc": ["body", "parameters"], "msg": "field required", diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py index e1e84df7..b20bb666 100644 --- a/tests/services/guidance_wrapper_test.py +++ b/tests/services/guidance_wrapper_test.py @@ -1,8 +1,13 @@ import pytest import guidance -from app.models.dtos import Content, ContentType, LLMModel +from app.models.dtos import Content, ContentType from app.services.guidance_wrapper import GuidanceWrapper +from app.config import LLMModelConfig + +llm_model_config = LLMModelConfig( + name="test", description="test", llm_credentials={} +) def test_query_success(mocker): @@ -17,7 +22,7 @@ def test_query_success(mocker): {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}""" guidance_wrapper = GuidanceWrapper( - model=LLMModel.GPT35_TURBO, + model=llm_model_config, handlebars=handlebars, parameters={"query": "Some query"}, ) @@ -41,7 +46,7 @@ def test_query_missing_required_params(mocker): {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}""" guidance_wrapper = GuidanceWrapper( - model=LLMModel.GPT35_TURBO, + model=llm_model_config, handlebars=handlebars, parameters={"somethingelse": "Something"}, ) @@ -63,7 +68,7 @@ def test_query_handlebars_not_generate_response(mocker): handlebars = "Not a valid handlebars" guidance_wrapper = GuidanceWrapper( - model=LLMModel.GPT35_TURBO, + model=llm_model_config, handlebars=handlebars, parameters={"query": "Something"}, )