diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 0c9c4195..d859323e 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -42,8 +42,8 @@ jobs:
           if (context.ref.startsWith("refs/tags/")) {
             return context.ref.slice(10);
           }
-          if (context.ref === "refs/heads/develop") {
-            return "develop";
+          if (context.ref === "refs/heads/main") {
+            return "latest";
           }
         }
         return "FALSE";
diff --git a/Dockerfile b/Dockerfile
index 1653ab81..23ca4b37 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,4 +7,4 @@ RUN poetry install
 COPY app /app/app
 
 EXPOSE 8000
-CMD ["poetry", "run", "gunicorn", "app.main:app", "-w", "32", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
+CMD ["poetry", "run", "gunicorn", "app.main:app", "-w", "8", "-k", "uvicorn.workers.UvicornWorker", "-b", "0.0.0.0:8000"]
diff --git a/app/config.py b/app/config.py
index bbcb9b8a..82271dd7 100644
--- a/app/config.py
+++ b/app/config.py
@@ -1,6 +1,19 @@
 import os
-from pydantic import BaseModel
+
 from pyaml_env import parse_config
+from pydantic import BaseModel
+
+
+class LLMModelConfig(BaseModel):
+    name: str
+    description: str
+    llm_credentials: dict
+
+
+class APIKeyConfig(BaseModel):
+    token: str
+    comment: str
+    llm_access: list[str]
 
 
 class CacheSettings(BaseModel):
@@ -13,9 +26,9 @@ class CacheParams(BaseModel):
 
 
 class Settings(BaseModel):
     class PyrisSettings(BaseModel):
-        api_key: str
+        api_keys: list[APIKeyConfig]
+        llms: dict[str, LLMModelConfig]
         cache: CacheSettings
-        llm: dict
 
     pyris: PyrisSettings
diff --git a/app/dependencies.py b/app/dependencies.py
index 4436fda5..91be6c78 100644
--- a/app/dependencies.py
+++ b/app/dependencies.py
@@ -1,10 +1,11 @@
 from fastapi import Depends
 from starlette.requests import Request as StarletteRequest
 
-from app.config import settings
+from app.config import settings, APIKeyConfig
 from app.core.custom_exceptions import (
     PermissionDeniedException,
     RequiresAuthenticationException,
+    InvalidModelException,
 )
 
 
@@ -17,7 +18,27 @@ def _get_api_key(request: StarletteRequest) -> str:
     return authorization_header
 
 
-class PermissionsValidator:
-    def __call__(self, api_key: str = Depends(_get_api_key)):
-        if api_key != settings.pyris.api_key:
-            raise PermissionDeniedException
+class TokenValidator:
+    async def __call__(
+        self, api_key: str = Depends(_get_api_key)
+    ) -> APIKeyConfig:
+        for key in settings.pyris.api_keys:
+            if key.token == api_key:
+                return key
+        raise PermissionDeniedException
+
+
+class TokenPermissionsValidator:
+    async def __call__(
+        self, request: StarletteRequest, api_key: str = Depends(_get_api_key)
+    ):
+        for key in settings.pyris.api_keys:
+            if key.token == api_key:
+                body = await request.json()
+                if body.get("preferredModel") in key.llm_access:
+                    return
+                else:
+                    raise InvalidModelException(
+                        str(body.get("preferredModel"))
+                    )
+        raise PermissionDeniedException
diff --git a/app/main.py b/app/main.py
index e6b84fd1..3c71b8b1 100644
--- a/app/main.py
+++ b/app/main.py
@@ -4,6 +4,7 @@
 from app.routes.messages import router as messages_router
 from app.routes.health import router as health_router
 from app.services.hazelcast_client import hazelcast_client
+from app.routes.models import router as models_router
 
 app = FastAPI(default_response_class=ORJSONResponse)
@@ -15,3 +16,4 @@ async def shutdown():
 
 app.include_router(messages_router)
 app.include_router(health_router)
+app.include_router(models_router)
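Review note: `app/config.py` now models per-key access (`APIKeyConfig`) and a registry of models (`LLMModelConfig`) instead of a single `api_key` and a raw `llm` dict. Below is a minimal sketch of how the new schema is consumed; the loader itself is not part of this diff, so the file name and the `parse_obj` call are assumptions based on the `parse_config` import added above.

```python
# Sketch only: validate a config file against the new schema.
from pyaml_env import parse_config

from app.config import Settings

settings = Settings.parse_obj(parse_config("application.example.yml"))
assert "DUMMY" in settings.pyris.llms  # model registry, keyed by model id
assert settings.pyris.api_keys[0].llm_access == ["DUMMY"]
```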
diff --git a/app/models/dtos.py b/app/models/dtos.py
index 899ec00d..2fe7c28f 100644
--- a/app/models/dtos.py
+++ b/app/models/dtos.py
@@ -3,12 +3,6 @@ from datetime import datetime
 
 
-class LLMModel(str, Enum):
-    GPT35_TURBO = "GPT35_TURBO"
-    GPT35_TURBO_16K_0613 = "GPT35_TURBO_16K_0613"
-    GPT35_TURBO_0613 = "GPT35_TURBO_0613"
-
-
 class LLMStatus(str, Enum):
     UP = "UP"
     DOWN = "DOWN"
 
@@ -41,10 +35,16 @@ class Message(BaseModel):
     )
     content: list[Content]
-    used_model: LLMModel = Field(..., alias="usedModel")
+    used_model: str = Field(..., alias="usedModel")
 
     message: Message
 
 
 class ModelStatus(BaseModel):
-    model: LLMModel
+    model: str
     status: LLMStatus
+
+
+class LLMModelResponse(BaseModel):
+    id: str
+    name: str
+    description: str
diff --git a/app/routes/health.py b/app/routes/health.py
index 48594b56..8bdb26b6 100644
--- a/app/routes/health.py
+++ b/app/routes/health.py
@@ -1,27 +1,28 @@
 from fastapi import APIRouter, Depends
 
-from app.dependencies import PermissionsValidator
-from app.models.dtos import LLMModel, LLMStatus, ModelStatus
+from app.dependencies import TokenValidator
+from app.models.dtos import LLMStatus, ModelStatus
 from app.services.guidance_wrapper import GuidanceWrapper
 from app.services.circuit_breaker import CircuitBreaker
+from app.config import settings
 
 router = APIRouter()
 
 
-@router.get("/api/v1/health", dependencies=[Depends(PermissionsValidator())])
+@router.get("/api/v1/health", dependencies=[Depends(TokenValidator())])
 def checkhealth() -> list[ModelStatus]:
     result = []
 
-    for model in LLMModel:
+    for key, model in settings.pyris.llms.items():
         circuit_status = CircuitBreaker.get_status(
             checkhealth_func=GuidanceWrapper(model=model).is_up,
-            cache_key=model,
+            cache_key=key,
         )
         status = (
             LLMStatus.UP
             if circuit_status == CircuitBreaker.Status.CLOSED
             else LLMStatus.DOWN
         )
-        result.append(ModelStatus(model=model, status=status))
+        result.append(ModelStatus(model=key, status=status))
 
     return result
diff --git a/app/routes/messages.py b/app/routes/messages.py
index 48d9224d..aae6279b 100644
--- a/app/routes/messages.py
+++ b/app/routes/messages.py
@@ -1,6 +1,6 @@
-from fastapi import APIRouter, Depends
 from datetime import datetime, timezone
 
+from fastapi import APIRouter, Depends
 from parsimonious.exceptions import IncompleteParseError
 
 from app.core.custom_exceptions import (
@@ -9,20 +9,21 @@
     InternalServerException,
     InvalidModelException,
 )
-from app.dependencies import PermissionsValidator
-from app.models.dtos import SendMessageRequest, SendMessageResponse, LLMModel
+from app.dependencies import TokenPermissionsValidator
+from app.models.dtos import SendMessageRequest, SendMessageResponse
 from app.services.circuit_breaker import CircuitBreaker
 from app.services.guidance_wrapper import GuidanceWrapper
+from app.config import settings
 
 router = APIRouter(tags=["messages"])
 
 
 @router.post(
-    "/api/v1/messages", dependencies=[Depends(PermissionsValidator())]
+    "/api/v1/messages", dependencies=[Depends(TokenPermissionsValidator())]
 )
 def send_message(body: SendMessageRequest) -> SendMessageResponse:
     try:
-        model = LLMModel(body.preferred_model)
-    except ValueError as e:
+        model = settings.pyris.llms[body.preferred_model]
+    except KeyError as e:
         raise InvalidModelException(str(e))
@@ -35,7 +36,7 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse:
     try:
         content = CircuitBreaker.protected_call(
             func=guidance.query,
-            cache_key=model,
+            cache_key=body.preferred_model,
             accepted_exceptions=(KeyError, SyntaxError, IncompleteParseError),
         )
     except KeyError as e:
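Review note: a hand-built walk-through of the gate that `TokenPermissionsValidator` applies before `send_message` runs; the values mirror the test fixtures rather than a real deployment. (The dict lookup in the route raises `KeyError`, not `ValueError`, for an unknown model, hence the corrected `except KeyError` above.)

```python
from app.config import APIKeyConfig

key = APIKeyConfig(token="secret", comment="test", llm_access=["GPT35_TURBO"])
body = {"preferredModel": "GPT35_TURBO"}

# Mirrors TokenPermissionsValidator: the request passes only when the
# preferred model is on the key's allow-list; otherwise it is rejected
# with InvalidModelException before the route body executes.
assert body.get("preferredModel") in key.llm_access
```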
diff --git a/app/routes/models.py b/app/routes/models.py
new file mode 100644
index 00000000..782e913c
--- /dev/null
+++ b/app/routes/models.py
@@ -0,0 +1,23 @@
+from fastapi import APIRouter, Depends
+
+from app.dependencies import TokenValidator
+from app.config import settings, APIKeyConfig
+from app.models.dtos import LLMModelResponse
+
+router = APIRouter(tags=["models"])
+
+
+@router.get("/api/v1/models")
+def get_models(
+    api_key_config: APIKeyConfig = Depends(TokenValidator()),
+) -> list[LLMModelResponse]:
+    llm_ids = api_key_config.llm_access
+
+    # Expose only the models this API key is allowed to access
+    return [
+        LLMModelResponse(
+            id=key, name=config.name, description=config.description
+        )
+        for key, config in settings.pyris.llms.items()
+        if key in llm_ids
+    ]
diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py
index 0858f111..54f1130a 100644
--- a/app/services/guidance_wrapper.py
+++ b/app/services/guidance_wrapper.py
@@ -1,7 +1,7 @@
 import guidance
 
-from app.config import settings
-from app.models.dtos import Content, ContentType, LLMModel
+from app.config import LLMModelConfig
+from app.models.dtos import Content, ContentType
 
 
 class GuidanceWrapper:
@@ -11,7 +11,7 @@ def __new__(cls, *_, **__):
         return super(GuidanceWrapper, cls).__new__(cls)
 
     def __init__(
-        self, model: LLMModel, handlebars="", parameters=None
+        self, model: LLMModelConfig, handlebars="", parameters=None
     ) -> None:
         if parameters is None:
             parameters = {}
@@ -63,5 +63,5 @@ def is_up(self) -> bool:
         return content == "1"
 
     def _get_llm(self):
-        llm_credentials = settings.pyris.llm[self.model.value]
+        llm_credentials = self.model.llm_credentials
         return guidance.llms.OpenAI(**llm_credentials)
diff --git a/application-docker.test.yml b/application-docker.test.yml
index ff2bba88..9b4a5276 100644
--- a/application-docker.test.yml
+++ b/application-docker.test.yml
@@ -4,19 +4,20 @@ pyris:
     hazelcast:
       host: cache
       port: 5701
-  llm:
-    GPT35_TURBO:
-      model:
-      token:
-      api_base:
-      api_type:
-      api_version:
-      deployment_id:
-    GPT35_TURBO_16K_0613:
-      model:
-      token:
-      chat_mode:
-    GPT35_TURBO_0613:
-      model:
-      token:
-      chat_mode:
+  api_keys:
+    - token: "secret"
+      comment: "DUMMY"
+      llm_access:
+        - DUMMY
+
+  llms:
+    DUMMY:
+      name: "Dummy model"
+      description: "Dummy model for testing"
+      llm_credentials:
+        api_base: ""
+        api_type: ""
+        api_version: ""
+        deployment_id: ""
+        model: ""
+        token: ""
diff --git a/application.example.yml b/application.example.yml
index 2f607570..05dca1d5 100644
--- a/application.example.yml
+++ b/application.example.yml
@@ -1,22 +1,22 @@
 pyris:
-  api_key: secret
   cache:
     hazelcast:
      host: localhost
      port: 5701
-  llm:
-    GPT35_TURBO:
-      model:
-      token:
-      api_base:
-      api_type:
-      api_version:
-      deployment_id:
-    GPT35_TURBO_16K_0613:
-      model:
-      token:
-      chat_mode:
-    GPT35_TURBO_0613:
-      model:
-      token:
-      chat_mode:
+  api_keys:
+    - token: "secret"
+      comment: "DUMMY"
+      llm_access:
+        - DUMMY
+
+  llms:
+    DUMMY:
+      name: "Dummy model"
+      description: "Dummy model for testing"
+      llm_credentials:
+        api_base: ""
+        api_type: ""
+        api_version: ""
+        deployment_id: ""
+        model: ""
+        token: ""
diff --git a/application.test.yml b/application.test.yml
index 2f607570..7948fd02 100644
--- a/application.test.yml
+++ b/application.test.yml
@@ -4,19 +4,20 @@ pyris:
     hazelcast:
       host: localhost
       port: 5701
-  llm:
-    GPT35_TURBO:
-      model:
-      token:
-      api_base:
-      api_type:
-      api_version:
-      deployment_id:
-    GPT35_TURBO_16K_0613:
-      model:
-      token:
-      chat_mode:
-    GPT35_TURBO_0613:
-      model:
-      token:
-      chat_mode:
+  api_keys:
+    - token: "secret"
+      comment: "DUMMY"
+      llm_access:
+        - DUMMY
+
+  llms:
+    DUMMY:
+      name: "Dummy model"
+      description: "Dummy model for testing"
+      llm_credentials:
+        api_base: ""
+        api_type: ""
+        api_version: ""
+        deployment_id: ""
+        model: ""
+        token: ""
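Review note: each `llm_credentials` block above is forwarded verbatim to `guidance.llms.OpenAI` by `_get_llm`, so a model entry can carry any keyword that constructor accepts. A sketch of building a wrapper from such an entry (the values are the dummy placeholders from the test configs):

```python
from app.config import LLMModelConfig
from app.services.guidance_wrapper import GuidanceWrapper

model = LLMModelConfig(
    name="Dummy model",
    description="Dummy model for testing",
    llm_credentials={},  # forwarded as **kwargs to guidance.llms.OpenAI
)
wrapper = GuidanceWrapper(model=model)
# wrapper.is_up() would build the client from model.llm_credentials
```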
diff --git a/tests/routes/health_test.py b/tests/routes/health_test.py
index 8e019c5e..ce6fd725 100644
--- a/tests/routes/health_test.py
+++ b/tests/routes/health_test.py
@@ -1,17 +1,40 @@
-from app.models.dtos import LLMModel
+import pytest
+import app.config as config
+
+
+@pytest.fixture(scope="function")
+def model_configs():
+    llm_model_config1 = config.LLMModelConfig(
+        name="test1", description="test", llm_credentials={}
+    )
+    llm_model_config2 = config.LLMModelConfig(
+        name="test2", description="test", llm_credentials={}
+    )
+    llm_model_config3 = config.LLMModelConfig(
+        name="test3", description="test", llm_credentials={}
+    )
+    config.settings.pyris.llms = {
+        "GPT35_TURBO": llm_model_config1,
+        "GPT35_TURBO_16K_0613": llm_model_config2,
+        "GPT35_TURBO_0613": llm_model_config3,
+    }
+
+
+@pytest.mark.usefixtures("model_configs")
 def test_checkhealth(test_client, headers, mocker):
     objA = mocker.Mock()
     objB = mocker.Mock()
     objC = mocker.Mock()
 
     def side_effect_func(*_, **kwargs):
-        if kwargs["model"] == LLMModel.GPT35_TURBO:
+        if kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO"]:
             return objA
-        elif kwargs["model"] == LLMModel.GPT35_TURBO_16K_0613:
+        elif (
+            kwargs["model"]
+            == config.settings.pyris.llms["GPT35_TURBO_16K_0613"]
+        ):
             return objB
-        elif kwargs["model"] == LLMModel.GPT35_TURBO_0613:
+        elif kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO_0613"]:
             return objC
 
     mocker.patch(
@@ -37,6 +60,7 @@ def side_effect_func(*_, **kwargs):
     objC.is_up.assert_called_once()
 
 
+@pytest.mark.usefixtures("model_configs")
 def test_checkhealth_save_to_cache(
     test_client, test_cache_store, headers, mocker
 ):
@@ -45,11 +69,14 @@
     objA = mocker.Mock()
     objB = mocker.Mock()
     objC = mocker.Mock()
 
     def side_effect_func(*_, **kwargs):
-        if kwargs["model"] == LLMModel.GPT35_TURBO:
+        if kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO"]:
             return objA
-        elif kwargs["model"] == LLMModel.GPT35_TURBO_16K_0613:
+        elif (
+            kwargs["model"]
+            == config.settings.pyris.llms["GPT35_TURBO_16K_0613"]
+        ):
             return objB
-        elif kwargs["model"] == LLMModel.GPT35_TURBO_0613:
+        elif kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO_0613"]:
             return objC
 
     mocker.patch(
@@ -60,17 +87,15 @@ def side_effect_func(*_, **kwargs):
     mocker.patch.object(objB, "is_up", return_value=False)
     mocker.patch.object(objC, "is_up", return_value=True)
 
-    assert test_cache_store.get("LLMModel.GPT35_TURBO:status") is None
-    assert test_cache_store.get("LLMModel.GPT35_TURBO_16K_0613:status") is None
-    assert test_cache_store.get("LLMModel.GPT35_TURBO_0613:status") is None
+    assert test_cache_store.get("GPT35_TURBO:status") is None
+    assert test_cache_store.get("GPT35_TURBO_16K_0613:status") is None
+    assert test_cache_store.get("GPT35_TURBO_0613:status") is None
 
     test_client.get("/api/v1/health", headers=headers)
 
-    assert test_cache_store.get("LLMModel.GPT35_TURBO:status") == "CLOSED"
-    assert (
-        test_cache_store.get("LLMModel.GPT35_TURBO_16K_0613:status") == "OPEN"
-    )
-    assert test_cache_store.get("LLMModel.GPT35_TURBO_0613:status") == "CLOSED"
+    assert test_cache_store.get("GPT35_TURBO:status") == "CLOSED"
+    assert test_cache_store.get("GPT35_TURBO_16K_0613:status") == "OPEN"
+    assert test_cache_store.get("GPT35_TURBO_0613:status") == "CLOSED"
 
 
 def test_checkhealth_with_wrong_api_key(test_client):
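Review note: the shorter cache keys in these assertions fall out of the `cache_key` change in the routes: the key is now the plain `llms` dict key instead of an enum member. Assuming `CircuitBreaker` stringifies whatever key it is given (its internals are not shown in this diff), the difference looks like this:

```python
from enum import Enum


class LLMModel(str, Enum):  # the enum this diff removes
    GPT35_TURBO = "GPT35_TURBO"


# str() on a mixin Enum member includes the class name, which is where
# the old "LLMModel.GPT35_TURBO:status" cache keys came from.
assert str(LLMModel.GPT35_TURBO) == "LLMModel.GPT35_TURBO"
# A plain dict key stringifies to itself, giving "GPT35_TURBO:status".
assert str("GPT35_TURBO") == "GPT35_TURBO"
```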
diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py
index c030c616..f15567c9 100644
--- a/tests/routes/messages_test.py
+++ b/tests/routes/messages_test.py
@@ -1,9 +1,24 @@
+import pytest
 from freezegun import freeze_time
 from app.models.dtos import Content, ContentType
 from app.services.guidance_wrapper import GuidanceWrapper
+import app.config as config
+
+
+@pytest.fixture(scope="function")
+def model_configs():
+    llm_model_config = config.LLMModelConfig(
+        name="test", description="test", llm_credentials={}
+    )
+    config.settings.pyris.llms = {"GPT35_TURBO": llm_model_config}
+    api_key_config = config.APIKeyConfig(
+        token="secret", comment="test", llm_access=["GPT35_TURBO"]
+    )
+    config.settings.pyris.api_keys = [api_key_config]
 
 
 @freeze_time("2023-06-16 03:21:34 +02:00")
+@pytest.mark.usefixtures("model_configs")
 def test_send_message(test_client, headers, mocker):
     mocker.patch.object(
         GuidanceWrapper,
@@ -39,8 +54,17 @@
     }
 
 
-def test_send_message_missing_params(test_client, headers):
+def test_send_message_missing_model(test_client, headers):
     response = test_client.post("/api/v1/messages", headers=headers, json={})
+    assert response.status_code == 404
+
+
+def test_send_message_missing_params(test_client, headers):
+    response = test_client.post(
+        "/api/v1/messages",
+        headers=headers,
+        json={"preferredModel": "GPT35_TURBO"},
+    )
     assert response.status_code == 422
     assert response.json() == {
         "detail": [
@@ -49,11 +73,6 @@
                 "msg": "field required",
                 "type": "value_error.missing",
             },
-            {
-                "loc": ["body", "preferredModel"],
-                "msg": "field required",
-                "type": "value_error.missing",
-            },
             {
                 "loc": ["body", "parameters"],
                 "msg": "field required",
@@ -63,6 +82,7 @@
     }
 
 
+@pytest.mark.usefixtures("model_configs")
 def test_send_message_raise_value_error(test_client, headers, mocker):
     mocker.patch.object(
         GuidanceWrapper, "query", side_effect=ValueError("value error message")
@@ -85,6 +105,7 @@
     }
 
 
+@pytest.mark.usefixtures("model_configs")
 def test_send_message_raise_key_error(test_client, headers, mocker):
     mocker.patch.object(
         GuidanceWrapper, "query", side_effect=KeyError("key error message")
@@ -126,6 +147,7 @@
     }
 
 
+@pytest.mark.usefixtures("model_configs")
 def test_send_message_fail_three_times(
     test_client, mocker, headers, test_cache_store
 ):
@@ -147,8 +169,8 @@ def test_send_message_fail_three_times(
 
     # Restrict access
     response = test_client.post("/api/v1/messages", headers=headers, json=body)
-    assert test_cache_store.get("LLMModel.GPT35_TURBO:status") == "OPEN"
-    assert test_cache_store.get("LLMModel.GPT35_TURBO:num_failures") == 3
+    assert test_cache_store.get("GPT35_TURBO:status") == "OPEN"
+    assert test_cache_store.get("GPT35_TURBO:num_failures") == 3
     assert response.status_code == 500
     assert response.json() == {
         "detail": {
@@ -158,7 +180,7 @@ def test_send_message_fail_three_times(
     }
 
     # Can access after TTL
-    test_cache_store.delete("LLMModel.GPT35_TURBO:status")
-    test_cache_store.delete("LLMModel.GPT35_TURBO:num_failures")
+    test_cache_store.delete("GPT35_TURBO:status")
+    test_cache_store.delete("GPT35_TURBO:num_failures")
     response = test_client.post("/api/v1/messages", headers=headers, json=body)
     assert response.status_code == 500
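Review note: the `model_configs` fixtures assign stubs to the global `config.settings` and never restore them, so the stubs leak into tests that run later. A variant using pytest's `monkeypatch`, which undoes the assignment after each test, might look like this (an editorial suggestion, not part of the diff):

```python
import pytest

import app.config as config


@pytest.fixture()
def model_configs(monkeypatch):
    # Same stub as above, but monkeypatch restores the original
    # settings.pyris.llms when the test finishes.
    llm_model_config = config.LLMModelConfig(
        name="test", description="test", llm_credentials={}
    )
    monkeypatch.setattr(
        config.settings.pyris, "llms", {"GPT35_TURBO": llm_model_config}
    )
```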
diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py
index e1e84df7..b20bb666 100644
--- a/tests/services/guidance_wrapper_test.py
+++ b/tests/services/guidance_wrapper_test.py
@@ -1,8 +1,13 @@
 import pytest
 import guidance
 
-from app.models.dtos import Content, ContentType, LLMModel
+from app.models.dtos import Content, ContentType
 from app.services.guidance_wrapper import GuidanceWrapper
+from app.config import LLMModelConfig
+
+llm_model_config = LLMModelConfig(
+    name="test", description="test", llm_credentials={}
+)
 
 
 def test_query_success(mocker):
@@ -17,7 +22,7 @@
     {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}"""
 
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"query": "Some query"},
     )
@@ -41,7 +46,7 @@ def test_query_missing_required_params(mocker):
     {{gen 'response' temperature=0.0 max_tokens=500}}{{~/assistant}}"""
 
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"somethingelse": "Something"},
     )
@@ -63,7 +68,7 @@ def test_query_handlebars_not_generate_response(mocker):
     handlebars = "Not a valid handlebars"
 
     guidance_wrapper = GuidanceWrapper(
-        model=LLMModel.GPT35_TURBO,
+        model=llm_model_config,
         handlebars=handlebars,
         parameters={"query": "Something"},
     )
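Review note: the new `/api/v1/models` route ships without a test in this diff. A sketch in the style of the suite above, assuming the same `test_client`/`headers` fixtures and the `DUMMY` entry from `application.test.yml`:

```python
def test_get_models(test_client, headers):
    response = test_client.get("/api/v1/models", headers=headers)
    assert response.status_code == 200
    # Only models on the key's llm_access list are returned.
    assert response.json() == [
        {
            "id": "DUMMY",
            "name": "Dummy model",
            "description": "Dummy model for testing",
        }
    ]
```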