diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 518e8b21..a6ff0bca 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,5 +44,8 @@ jobs: run: | poetry run flake8 . + - name: Start cache service + run: docker compose up -d --build cache + - name: Test with pytest run: poetry run pytest diff --git a/.gitignore b/.gitignore index 6d39fe93..75eeaafd 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ venv.bak/ # Config files application.yml +application-docker.yml # Spyder project settings .spyderproject diff --git a/README.md b/README.md index 6ab57a7f..45bc7772 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ ## With docker - (optional) Install **[Task](https://taskfile.dev)** +- Copy `application.example.yml` to `application-docker.yml` and change necessary values - Build docker: `docker compose build` or `task build` - Run server: `docker compose up -d` or `task up` - Stop server: `docker compose down` or `task down` diff --git a/Taskfile.yml b/Taskfile.yml index 9dabc13a..ee9dd81e 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -30,3 +30,8 @@ tasks: logs: cmds: - docker compose logs -f server + + console: + aliases: [ c ] + cmds: + - docker attach pyris-server # using to detach diff --git a/app/config.py b/app/config.py index 0a12c02e..82271dd7 100644 --- a/app/config.py +++ b/app/config.py @@ -16,19 +16,29 @@ class APIKeyConfig(BaseModel): llm_access: list[str] +class CacheSettings(BaseModel): + class CacheParams(BaseModel): + host: str + port: int + + hazelcast: CacheParams + + class Settings(BaseModel): class PyrisSettings(BaseModel): api_keys: list[APIKeyConfig] llms: dict[str, LLMModelConfig] + cache: CacheSettings pyris: PyrisSettings @classmethod def get_settings(cls): + postfix = "-docker" if "DOCKER" in os.environ else "" if "RUN_ENV" in os.environ and os.environ["RUN_ENV"] == "test": - file_path = "application.test.yml" + file_path = f"application{postfix}.test.yml" else: - file_path 
= "application.yml" + file_path = f"application{postfix}.yml" return Settings.parse_obj(parse_config(file_path)) diff --git a/app/main.py b/app/main.py index 276d6dee..3c71b8b1 100644 --- a/app/main.py +++ b/app/main.py @@ -2,9 +2,18 @@ from fastapi.responses import ORJSONResponse from app.routes.messages import router as messages_router +from app.routes.health import router as health_router +from app.services.hazelcast_client import hazelcast_client from app.routes.models import router as models_router app = FastAPI(default_response_class=ORJSONResponse) + +@app.on_event("shutdown") +async def shutdown(): + hazelcast_client.shutdown() + + app.include_router(messages_router) +app.include_router(health_router) app.include_router(models_router) diff --git a/app/models/dtos.py b/app/models/dtos.py index 47ffaa24..2fe7c28f 100644 --- a/app/models/dtos.py +++ b/app/models/dtos.py @@ -3,6 +3,12 @@ from datetime import datetime +class LLMStatus(str, Enum): + UP = "UP" + DOWN = "DOWN" + NOT_AVAILABLE = "NOT_AVAILABLE" + + class ContentType(str, Enum): TEXT = "text" @@ -33,6 +39,11 @@ class Message(BaseModel): message: Message +class ModelStatus(BaseModel): + model: str + status: LLMStatus + + class LLMModelResponse(BaseModel): id: str name: str diff --git a/app/routes/health.py b/app/routes/health.py new file mode 100644 index 00000000..8bdb26b6 --- /dev/null +++ b/app/routes/health.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter, Depends + +from app.dependencies import TokenValidator +from app.models.dtos import LLMStatus, ModelStatus +from app.services.guidance_wrapper import GuidanceWrapper +from app.services.circuit_breaker import CircuitBreaker +from app.config import settings + +router = APIRouter() + + +@router.get("/api/v1/health", dependencies=[Depends(TokenValidator())]) +def checkhealth() -> list[ModelStatus]: + result = [] + + for key, model in settings.pyris.llms.items(): + circuit_status = CircuitBreaker.get_status( + 
checkhealth_func=GuidanceWrapper(model=model).is_up, + cache_key=key, + ) + status = ( + LLMStatus.UP + if circuit_status == CircuitBreaker.Status.CLOSED + else LLMStatus.DOWN + ) + result.append(ModelStatus(model=key, status=status)) + + return result diff --git a/app/routes/messages.py b/app/routes/messages.py index 40c584b8..aae6279b 100644 --- a/app/routes/messages.py +++ b/app/routes/messages.py @@ -11,6 +11,7 @@ ) from app.dependencies import TokenPermissionsValidator from app.models.dtos import SendMessageRequest, SendMessageResponse +from app.services.circuit_breaker import CircuitBreaker from app.services.guidance_wrapper import GuidanceWrapper from app.config import settings @@ -33,7 +34,11 @@ def send_message(body: SendMessageRequest) -> SendMessageResponse: ) try: - content = guidance.query() + content = CircuitBreaker.protected_call( + func=guidance.query, + cache_key=body.preferred_model, + accepted_exceptions=(KeyError, SyntaxError, IncompleteParseError), + ) except KeyError as e: raise MissingParameterException(str(e)) except (SyntaxError, IncompleteParseError) as e: diff --git a/app/services/cache.py b/app/services/cache.py new file mode 100644 index 00000000..4591b37b --- /dev/null +++ b/app/services/cache.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Optional, Any +from pydantic import BaseModel +from app.services.hazelcast_client import hazelcast_client + + +class CacheStoreInterface(ABC): + @abstractmethod + def get(self, name: str) -> Any: + """ + Get the value at key ``name`` + + Returns: + The value at key ``name``, or None if the key doesn't exist + """ + pass + + @abstractmethod + def set(self, name: str, value, ex: Optional[int] = None): + """ + Set the value at key ``name`` to ``value`` + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. 
+ """ + pass + + @abstractmethod + def expire(self, name: str, ex: int): + """ + Set an expire flag on key ``name`` for ``ex`` seconds + """ + pass + + @abstractmethod + def incr(self, name: str) -> int: + """ + Increase the integer value of a key ``name`` by 1. + If the key does not exist, it is set to 0 + before performing the operation. + + Returns: + The value of key ``name`` after the increment + + Raises: + TypeError: if cache value is not an integer + """ + pass + + @abstractmethod + def flushdb(self): + """ + Delete all keys in the current database + """ + pass + + @abstractmethod + def delete(self, name: str): + """ + Delete the key specified by ``name`` + """ + pass + + +class InMemoryCacheStore(CacheStoreInterface): + class CacheValue(BaseModel): + value: Any + expired_at: Optional[datetime] + + def __init__(self): + self._cache: dict[str, Optional[InMemoryCacheStore.CacheValue]] = {} + + def get(self, name: str) -> Any: + current_time = datetime.now() + data = self._cache.get(name) + if data is None: + return None + if data.expired_at and current_time > data.expired_at: + del self._cache[name] + return None + return data.value + + def set(self, name: str, value, ex: Optional[int] = None): + self._cache[name] = InMemoryCacheStore.CacheValue( + value=value, expired_at=None + ) + if ex is not None: + self.expire(name, ex) + + def expire(self, name: str, ex: int): + current_time = datetime.now() + expired_at = current_time + timedelta(seconds=ex) + data = self._cache.get(name) + if data is None: + return + data.expired_at = expired_at + + def incr(self, name: str) -> int: + value = self.get(name) + if value is None: + self._cache[name] = InMemoryCacheStore.CacheValue( + value=1, expired_at=None + ) + return 1 + if isinstance(value, int): + value += 1 + self.set(name, value) + return value + raise TypeError("value is not an integer") + + def flushdb(self): + self._cache = {} + + def delete(self, name: str): + if name in self._cache: + del self._cache[name] + 
+ +class HazelcastCacheStore(CacheStoreInterface): + def __init__(self): + self._cache = hazelcast_client.get_map("cache_store").blocking() + + def get(self, name: str) -> Any: + return self._cache.get(name) + + def set(self, name: str, value, ex: Optional[int] = None): + self._cache.put(name, value, ex) + + def expire(self, name: str, ex: int): + self._cache.set_ttl(name, ex) + + def incr(self, name: str) -> int: + flag = False + value = 0 + while not flag: + value = self._cache.get(name) + if value is None: + self._cache.set(name, 1) + return 1 + if not isinstance(value, int): + raise TypeError("value is not an integer") + + flag = self._cache.replace_if_same(name, value, value + 1) + + return value + 1 + + def flushdb(self): + self._cache.clear() + + def delete(self, name: str): + self._cache.remove(name) + + +cache_store = HazelcastCacheStore() diff --git a/app/services/circuit_breaker.py b/app/services/circuit_breaker.py new file mode 100644 index 00000000..2cd1fa1d --- /dev/null +++ b/app/services/circuit_breaker.py @@ -0,0 +1,92 @@ +import logging +from enum import Enum +from app.services.cache import cache_store + +log = logging.getLogger(__name__) + + +class CircuitBreaker: + """Circuit breaker pattern.""" + + MAX_FAILURES = 3 + CLOSED_TTL_SECONDS = 300 + OPEN_TTL_SECONDS = 30 + + class Status(str, Enum): + OPEN = "OPEN" + CLOSED = "CLOSED" + + @classmethod + def protected_call( + cls, func, cache_key: str, accepted_exceptions: tuple = () + ): + """Wrap function call to avoid too many failures in a row. + + Params: + func: function to be called + cache_key: key to be used in the cache store + accepted_exceptions: exceptions that are not considered failures + + Raises: + ValueError: if within the last OPEN_TTL_SECONDS seconds, + the function throws an exception for the MAX_FAILURES-th time. 
+ """ + + num_failures_key = f"{cache_key}:num_failures" + status_key = f"{cache_key}:status" + + status = cache_store.get(status_key) + if status == cls.Status.OPEN: + raise ValueError("Too many failures! Please try again later.") + + try: + response = func() + cache_store.set( + status_key, cls.Status.CLOSED, ex=cls.CLOSED_TTL_SECONDS + ) + return response + except accepted_exceptions as e: + raise e + except Exception as e: + num_failures = cache_store.incr(num_failures_key) + cache_store.expire(num_failures_key, cls.OPEN_TTL_SECONDS) + if num_failures >= cls.MAX_FAILURES: + cache_store.set( + status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS + ) + + raise e + + @classmethod + def get_status(cls, checkhealth_func, cache_key: str): + """Get the status of the cache_key. + If the key is not in the cache, performs the function call. + Otherwise, returns the cached value. + + Params: + checkhealth_func: function to be called. Only returns a boolean. + cache_key: key to be used in the cache store + + Returns: Status of the cache_key. + Circuit is CLOSED if cache_key is up and running, and OPEN otherwise. 
+ """ + status_key = f"{cache_key}:status" + status = cache_store.get(status_key) + + if status: + return cls.Status(status) + + is_up = False + try: + is_up = checkhealth_func() + except Exception as e: + log.error(e) + + if is_up: + cache_store.set( + status_key, cls.Status.CLOSED, ex=cls.CLOSED_TTL_SECONDS + ) + return cls.Status.CLOSED + + cache_store.set(status_key, cls.Status.OPEN, ex=cls.OPEN_TTL_SECONDS) + return cls.Status.OPEN diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 43f94177..54f1130a 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -7,8 +7,11 @@ class GuidanceWrapper: """A wrapper service to all guidance package's methods.""" + def __new__(cls, *_, **__): + return super(GuidanceWrapper, cls).__new__(cls) + def __init__( - self, model: LLMModelConfig, handlebars: str, parameters=None + self, model: LLMModelConfig, handlebars="", parameters=None ) -> None: if parameters is None: parameters = {} @@ -38,6 +41,27 @@ def query(self) -> Content: return Content(type=ContentType.TEXT, textContent=result["response"]) + def is_up(self) -> bool: + """Check if the chosen LLM model is up. + + Returns: + True if the model is up, False otherwise. 
+ """ + + guidance.llms.OpenAI.cache.clear() + handlebars = """ + {{#user~}}Say 1{{~/user}} + {{#assistant~}} + {{gen 'response' temperature=0.0 max_tokens=1}} + {{~/assistant}} + """ + content = ( + GuidanceWrapper(model=self.model, handlebars=handlebars) + .query() + .text_content + ) + return content == "1" + def _get_llm(self): llm_credentials = self.model.llm_credentials return guidance.llms.OpenAI(**llm_credentials) diff --git a/app/services/hazelcast_client.py b/app/services/hazelcast_client.py new file mode 100644 index 00000000..9c610ced --- /dev/null +++ b/app/services/hazelcast_client.py @@ -0,0 +1,9 @@ +import hazelcast +from app.config import settings + +hazelcast_client = hazelcast.HazelcastClient( + cluster_members=[ + f"{settings.pyris.cache.hazelcast.host}:" + f"{settings.pyris.cache.hazelcast.port}" + ], +) diff --git a/application-docker.test.yml b/application-docker.test.yml new file mode 100644 index 00000000..55723fd5 --- /dev/null +++ b/application-docker.test.yml @@ -0,0 +1,22 @@ +pyris: + cache: + hazelcast: + host: cache + port: 5701 + api_keys: + - token: "secret" + comment: "DUMMY" + llm_access: + - DUMMY + + llms: + DUMMY: + name: "Dummy model" + description: "Dummy model for testing" + llm_credentials: + api_base: "" + api_type: "" + api_version: "" + deployment_id: "" + model: "" + token: "" diff --git a/application.example.yml b/application.example.yml index b41935a2..05dca1d5 100644 --- a/application.example.yml +++ b/application.example.yml @@ -1,4 +1,8 @@ pyris: + cache: + hazelcast: + host: localhost + port: 5701 api_keys: - token: "secret" comment: "DUMMY" diff --git a/application.test.yml b/application.test.yml index b41935a2..05dca1d5 100644 --- a/application.test.yml +++ b/application.test.yml @@ -1,4 +1,8 @@ pyris: + cache: + hazelcast: + host: localhost + port: 5701 api_keys: - token: "secret" comment: "DUMMY" diff --git a/docker-compose.yml b/docker-compose.yml index 2e5e790a..1653445f 100644 --- a/docker-compose.yml +++ 
b/docker-compose.yml @@ -3,15 +3,30 @@ version: "3.9" services: server: + container_name: pyris-server build: context: . dockerfile: Dockerfile.development + environment: + - DOCKER=1 + depends_on: + - cache ports: - "8000:8000" volumes: - .:/app networks: - pyris + stdin_open: true + tty: true + + cache: + container_name: pyris-cache + image: hazelcast/hazelcast:5.3.1 + networks: + - pyris + ports: + - '5701:5701' networks: pyris: diff --git a/docker/pyris-production.yml b/docker/pyris-production.yml index 1da5d2c1..c246fab1 100644 --- a/docker/pyris-production.yml +++ b/docker/pyris-production.yml @@ -20,6 +20,11 @@ services: networks: - pyris + cache: + image: hazelcast/hazelcast:5.3.1 + networks: + - pyris + nginx: extends: file: ./nginx.yml diff --git a/poetry.lock b/poetry.lock index 8be2fdc0..bbecd5c1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -672,6 +672,20 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "hazelcast-python-client" +version = "5.3.0" +description = "Hazelcast Python Client" +optional = false +python-versions = "*" +files = [ + {file = "hazelcast-python-client-5.3.0.tar.gz", hash = "sha256:9defd4adf4d25d66ca49383c78d8762960ae6ab96a3e1fecaa729b026542a5fc"}, + {file = "hazelcast_python_client-5.3.0-py3-none-any.whl", hash = "sha256:effba1748ba57b2dbb1a5aaa912e52ee039ed49f2d20dab7540fe7ac9c72688c"}, +] + +[package.extras] +stats = ["psutil"] + [[package]] name = "httpcore" version = "0.17.3" @@ -1688,4 +1702,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "6945b3c561481b43e8528c3ea3b82872c030560edb1e9aaeb6f9a133a665e5bb" +content-hash = "6fede1f3c79ddf48987ce1ff0b906bf90104a7d7a2fa3765c9932dc9df544084" diff --git a/pyproject.toml b/pyproject.toml index ba7f2c7b..638fefcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ black = "^23.3.0" flake8 = "^6.0.0" pyaml-env = "^1.2.1" 
gunicorn = "^20.1.0" +hazelcast-python-client = "^5.3.0" parsimonious = "^0.10.0" diff --git a/tests/conftest.py b/tests/conftest.py index adf0e3bf..6de0ae74 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,19 @@ import pytest from fastapi.testclient import TestClient from app.main import app +from app.services.cache import cache_store + + +@pytest.fixture(autouse=True) +def clean_cache_store(): + cache_store.flushdb() + yield + cache_store.flushdb() + + +@pytest.fixture(scope="session") +def test_cache_store(): + return cache_store @pytest.fixture(scope="session") diff --git a/tests/routes/health_test.py b/tests/routes/health_test.py new file mode 100644 index 00000000..ce6fd725 --- /dev/null +++ b/tests/routes/health_test.py @@ -0,0 +1,117 @@ +import pytest +import app.config as config + + +@pytest.fixture(scope="function") +def model_configs(): + llm_model_config1 = config.LLMModelConfig( + name="test1", description="test", llm_credentials={} + ) + llm_model_config2 = config.LLMModelConfig( + name="test2", description="test", llm_credentials={} + ) + llm_model_config3 = config.LLMModelConfig( + name="test3", description="test", llm_credentials={} + ) + config.settings.pyris.llms = { + "GPT35_TURBO": llm_model_config1, + "GPT35_TURBO_16K_0613": llm_model_config2, + "GPT35_TURBO_0613": llm_model_config3, + } + + +@pytest.mark.usefixtures("model_configs") +def test_checkhealth(test_client, headers, mocker): + objA = mocker.Mock() + objB = mocker.Mock() + objC = mocker.Mock() + + def side_effect_func(*_, **kwargs): + if kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO"]: + return objA + elif ( + kwargs["model"] + == config.settings.pyris.llms["GPT35_TURBO_16K_0613"] + ): + return objB + elif kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO_0613"]: + return objC + + mocker.patch( + "app.services.guidance_wrapper.GuidanceWrapper.__new__", + side_effect=side_effect_func, + ) + mocker.patch.object(objA, "is_up", return_value=True) + 
mocker.patch.object(objB, "is_up", return_value=False) + mocker.patch.object(objC, "is_up", return_value=True) + + response = test_client.get("/api/v1/health", headers=headers) + assert response.status_code == 200 + assert response.json() == [ + {"model": "GPT35_TURBO", "status": "UP"}, + {"model": "GPT35_TURBO_16K_0613", "status": "DOWN"}, + {"model": "GPT35_TURBO_0613", "status": "UP"}, + ] + + # Assert the second call is cached + test_client.get("/api/v1/health", headers=headers) + objA.is_up.assert_called_once() + objB.is_up.assert_called_once() + objC.is_up.assert_called_once() + + +@pytest.mark.usefixtures("model_configs") +def test_checkhealth_save_to_cache( + test_client, test_cache_store, headers, mocker +): + objA = mocker.Mock() + objB = mocker.Mock() + objC = mocker.Mock() + + def side_effect_func(*_, **kwargs): + if kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO"]: + return objA + elif ( + kwargs["model"] + == config.settings.pyris.llms["GPT35_TURBO_16K_0613"] + ): + return objB + elif kwargs["model"] == config.settings.pyris.llms["GPT35_TURBO_0613"]: + return objC + + mocker.patch( + "app.services.guidance_wrapper.GuidanceWrapper.__new__", + side_effect=side_effect_func, + ) + mocker.patch.object(objA, "is_up", return_value=True) + mocker.patch.object(objB, "is_up", return_value=False) + mocker.patch.object(objC, "is_up", return_value=True) + + assert test_cache_store.get("GPT35_TURBO:status") is None + assert test_cache_store.get("GPT35_TURBO_16K_0613:status") is None + assert test_cache_store.get("GPT35_TURBO_0613:status") is None + + test_client.get("/api/v1/health", headers=headers) + + assert test_cache_store.get("GPT35_TURBO:status") == "CLOSED" + assert test_cache_store.get("GPT35_TURBO_16K_0613:status") == "OPEN" + assert test_cache_store.get("GPT35_TURBO_0613:status") == "CLOSED" + + +def test_checkhealth_with_wrong_api_key(test_client): + headers = {"Authorization": "wrong api key"} + response = test_client.get("/api/v1/health", 
headers=headers) + assert response.status_code == 403 + assert response.json()["detail"] == { + "type": "not_authorized", + "errorMessage": "Permission denied", + } + + +def test_checkhealth_without_authorization_header(test_client): + response = test_client.get("/api/v1/health") + assert response.status_code == 401 + assert response.json()["detail"] == { + "type": "not_authenticated", + "errorMessage": "Requires authentication", + } diff --git a/tests/routes/messages_test.py b/tests/routes/messages_test.py index c194cb67..18b2e8eb 100644 --- a/tests/routes/messages_test.py +++ b/tests/routes/messages_test.py @@ -1,19 +1,24 @@ +import pytest from freezegun import freeze_time from app.models.dtos import Content, ContentType from app.services.guidance_wrapper import GuidanceWrapper import app.config as config -llm_model_config = config.LLMModelConfig( - name="test", description="test", llm_credentials={} -) -config.settings.pyris.llms = {"GPT35_TURBO": llm_model_config} -api_key_config = config.APIKeyConfig( - token="secret", comment="test", llm_access=["GPT35_TURBO"] -) -config.settings.pyris.api_keys = [api_key_config] + +@pytest.fixture(scope="function") +def model_configs(): + llm_model_config = config.LLMModelConfig( + name="test", description="test", llm_credentials={} + ) + config.settings.pyris.llms = {"GPT35_TURBO": llm_model_config} + api_key_config = config.APIKeyConfig( + token="secret", comment="test", llm_access=["GPT35_TURBO"] + ) + config.settings.pyris.api_keys = [api_key_config] @freeze_time("2023-06-16 03:21:34 +02:00") +@pytest.mark.usefixtures("model_configs") def test_send_message(test_client, headers, mocker): mocker.patch.object( GuidanceWrapper, @@ -77,6 +82,7 @@ def test_send_message_missing_params(test_client, headers): } +@pytest.mark.usefixtures("model_configs") def test_send_message_raise_value_error(test_client, headers, mocker): mocker.patch.object( GuidanceWrapper, "query", side_effect=ValueError("value error message") @@ -99,6 +105,7 
@@ def test_send_message_raise_value_error(test_client, headers, mocker): } +@pytest.mark.usefixtures("model_configs") def test_send_message_raise_key_error(test_client, headers, mocker): mocker.patch.object( GuidanceWrapper, "query", side_effect=KeyError("key error message") @@ -138,3 +145,50 @@ def test_send_message_without_authorization_header(test_client): "type": "not_authenticated", "errorMessage": "Requires authentication", } + + +@pytest.mark.usefixtures("model_configs") +def test_send_message_fail_three_times( + test_client, mocker, headers, test_cache_store +): + mocker.patch.object( + GuidanceWrapper, "query", side_effect=Exception("some error") + ) + body = { + "template": { + "id": 123, + "content": "some template", + }, + "preferredModel": "GPT35_TURBO", + "parameters": {"query": "Some query"}, + } + + for _ in range(3): + try: + test_client.post("/api/v1/messages", headers=headers, json=body) + except Exception: + ... + + # Restrict access + response = test_client.post("/api/v1/messages", headers=headers, json=body) + assert test_cache_store.get("GPT35_TURBO:status") == "OPEN" + assert test_cache_store.get("GPT35_TURBO:num_failures") == 3 + assert response.status_code == 500 + assert response.json() == { + "detail": { + "errorMessage": "Too many failures! 
Please try again later.", + "type": "other", + } + } + + # Can access after TTL + test_cache_store.delete("GPT35_TURBO:status") + test_cache_store.delete("GPT35_TURBO:num_failures") + response = test_client.post("/api/v1/messages", headers=headers, json=body) + assert response.status_code == 500 + assert response.json() == { + "detail": { + "errorMessage": "some error", + "type": "other", + } + } diff --git a/tests/services/cache_test.py b/tests/services/cache_test.py new file mode 100644 index 00000000..099d6b5e --- /dev/null +++ b/tests/services/cache_test.py @@ -0,0 +1,219 @@ +import pytest +from datetime import datetime +from freezegun import freeze_time +from app.services.cache import HazelcastCacheStore, InMemoryCacheStore + + +@pytest.fixture(scope="function") +def in_memory_cache(): + return InMemoryCacheStore() + + +@pytest.fixture(scope="function") +def hazelcast_cache(): + return HazelcastCacheStore() + + +class TestInMemoryCacheStore: + def test_set_and_get(self, in_memory_cache): + in_memory_cache.set("test_key", "test_value") + assert in_memory_cache.get("test_key") == "test_value" + assert in_memory_cache._cache["test_key"].expired_at is None + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_set_and_get_an_existing_key(self, in_memory_cache): + in_memory_cache.set("test_key", "test_value", ex=30) + assert in_memory_cache.get("test_key") == "test_value" + assert in_memory_cache._cache["test_key"].expired_at == datetime( + year=2023, month=6, day=16, hour=1, minute=22, second=4 + ) + + in_memory_cache.set("test_key", "test_value2") + assert in_memory_cache.get("test_key") == "test_value2" + assert in_memory_cache._cache["test_key"].expired_at is None + + def test_set_and_get_before_expired_time(self, in_memory_cache): + initial_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=3 + ) + other_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=34 + ) + + with freeze_time(initial_datetime) as 
frozen_datetime: + in_memory_cache.set("test_key", "test_value", ex=40) + assert in_memory_cache.get("test_key") == "test_value" + assert in_memory_cache._cache["test_key"].expired_at == datetime( + year=1, month=1, day=1, hour=15, minute=2, second=43 + ) + + frozen_datetime.move_to(other_datetime) + + assert in_memory_cache.get("test_key") == "test_value" + assert in_memory_cache._cache["test_key"].expired_at == datetime( + year=1, month=1, day=1, hour=15, minute=2, second=43 + ) + + def test_set_and_get_after_expired_time(self, in_memory_cache): + initial_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=3 + ) + other_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=34 + ) + + with freeze_time(initial_datetime) as frozen_datetime: + in_memory_cache.set("test_key", "test_value", ex=30) + assert in_memory_cache.get("test_key") == "test_value" + assert in_memory_cache._cache["test_key"].expired_at == datetime( + year=1, month=1, day=1, hour=15, minute=2, second=33 + ) + + frozen_datetime.move_to(other_datetime) + + assert in_memory_cache.get("test_key") is None + assert "test_key" not in in_memory_cache._cache + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_expire_a_existing_key(self, in_memory_cache): + in_memory_cache.set("test_key", "test_value") + in_memory_cache.expire("test_key", 30) + assert in_memory_cache._cache["test_key"].expired_at == datetime( + year=2023, month=6, day=16, hour=1, minute=22, second=4 + ) + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_expire_a_non_existing_key(self, in_memory_cache): + in_memory_cache.expire("test_key", 30) + assert "test_key" not in in_memory_cache._cache + + def test_incr_a_non_existing_key(self, in_memory_cache): + response = in_memory_cache.incr("test_key") + assert response == 1 + assert in_memory_cache.get("test_key") == 1 + + def test_incr_a_existing_key(self, in_memory_cache): + in_memory_cache.set("test_key", 1) + response = 
in_memory_cache.incr("test_key") + assert response == 2 + assert in_memory_cache.get("test_key") == 2 + + def test_incr_a_existing_key_not_a_number(self, in_memory_cache): + in_memory_cache.set("test_key", "not a number") + + with pytest.raises(TypeError, match="value is not an integer"): + in_memory_cache.incr("test_key") + + def test_flushdb(self, in_memory_cache): + in_memory_cache.set("test_key", "test_value") + in_memory_cache.flushdb() + + assert in_memory_cache._cache == {} + + def test_delete_a_non_existing_key(self, in_memory_cache): + in_memory_cache.delete("test_key") + + assert in_memory_cache.get("test_key") is None + assert "test_key" not in in_memory_cache._cache + + def test_delete_a_existing_key(self, in_memory_cache): + in_memory_cache.set("test_key", "test_value", ex=1000) + in_memory_cache.delete("test_key") + + assert in_memory_cache.get("test_key") is None + assert "test_key" not in in_memory_cache._cache + + +class TestHazelcastCacheStore: + def test_set_and_get(self, hazelcast_cache): + hazelcast_cache.set("test_key", "test_value") + assert hazelcast_cache.get("test_key") == "test_value" + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_set_and_get_an_existing_key(self, hazelcast_cache): + hazelcast_cache.set("test_key", "test_value", ex=30) + assert hazelcast_cache.get("test_key") == "test_value" + + hazelcast_cache.set("test_key", "test_value2") + assert hazelcast_cache.get("test_key") == "test_value2" + + def test_set_and_get_before_expired_time(self, hazelcast_cache): + initial_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=3 + ) + other_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=34 + ) + + with freeze_time(initial_datetime) as frozen_datetime: + hazelcast_cache.set("test_key", "test_value", ex=40) + assert hazelcast_cache.get("test_key") == "test_value" + + frozen_datetime.move_to(other_datetime) + + assert hazelcast_cache.get("test_key") == "test_value" + + def 
test_set_and_get_after_expired_time(self, hazelcast_cache): + initial_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=3 + ) + other_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=34 + ) + + with freeze_time(initial_datetime) as frozen_datetime: + hazelcast_cache.set("test_key", "test_value", ex=30) + assert hazelcast_cache.get("test_key") == "test_value" + + frozen_datetime.move_to(other_datetime) + # NOTE: Hazelcast doesn't expire with freezegun + hazelcast_cache.delete("test_key") + + assert hazelcast_cache.get("test_key") is None + assert not hazelcast_cache._cache.contains_key("test_key") + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_expire_a_existing_key(self, hazelcast_cache): + hazelcast_cache.set("test_key", "test_value") + hazelcast_cache.expire("test_key", 30) + + @freeze_time("2023-06-16 03:21:34 +02:00") + def test_expire_a_non_existing_key(self, hazelcast_cache): + hazelcast_cache.expire("test_key", 30) + assert not hazelcast_cache._cache.contains_key("test_key") + + def test_incr_a_non_existing_key(self, hazelcast_cache): + response = hazelcast_cache.incr("test_key") + assert response == 1 + assert hazelcast_cache.get("test_key") == 1 + + def test_incr_a_existing_key(self, hazelcast_cache): + hazelcast_cache.set("test_key", 1) + response = hazelcast_cache.incr("test_key") + assert response == 2 + assert hazelcast_cache.get("test_key") == 2 + + def test_incr_a_existing_key_not_a_number(self, hazelcast_cache): + hazelcast_cache.set("test_key", "not a number") + + with pytest.raises(TypeError, match="value is not an integer"): + hazelcast_cache.incr("test_key") + + def test_flushdb(self, hazelcast_cache): + hazelcast_cache.set("test_key", "test_value") + hazelcast_cache.flushdb() + + assert hazelcast_cache._cache.is_empty() + + def test_delete_a_non_existing_key(self, hazelcast_cache): + hazelcast_cache.delete("test_key") + + assert hazelcast_cache.get("test_key") is None + assert not 
hazelcast_cache._cache.contains_key("test_key") + + def test_delete_a_existing_key(self, hazelcast_cache): + hazelcast_cache.set("test_key", "test_value", ex=1000) + hazelcast_cache.delete("test_key") + + assert hazelcast_cache.get("test_key") is None + assert not hazelcast_cache._cache.contains_key("test_key") diff --git a/tests/services/circuit_breaker_test.py b/tests/services/circuit_breaker_test.py new file mode 100644 index 00000000..fef7d1b7 --- /dev/null +++ b/tests/services/circuit_breaker_test.py @@ -0,0 +1,86 @@ +from datetime import datetime +import pytest +from freezegun import freeze_time +from app.services.circuit_breaker import CircuitBreaker + + +def test_protected_call_success(mocker): + mock_function = mocker.MagicMock(return_value=3) + + result = CircuitBreaker.protected_call( + func=mock_function, cache_key="test_key" + ) + + mock_function.assert_called_once() + assert result == 3 + + +def test_protected_call_fail_with_accepted_exceptions( + mocker, test_cache_store +): + mock_function = mocker.MagicMock(side_effect=KeyError("foo")) + + with pytest.raises(KeyError, match="foo"): + CircuitBreaker.protected_call( + func=mock_function, + cache_key="test_key", + accepted_exceptions=(KeyError,), + ) + + assert test_cache_store.get("test_key:num_failures") is None + + +@freeze_time("2023-06-16 03:21:34 +02:00") +def test_protected_call_too_many_failures_not_in_a_row( + mocker, test_cache_store +): + initial_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=3 + ) + other_datetime = datetime( + year=1, month=1, day=1, hour=15, minute=2, second=34 + ) + + with freeze_time(initial_datetime) as frozen_datetime: + test_cache_store.set("test_key:num_failures", 2) + mock_function = mocker.MagicMock(side_effect=KeyError("foo")) + + with pytest.raises(KeyError, match="foo"): + CircuitBreaker.protected_call( + func=mock_function, + cache_key="test_key", + ) + + frozen_datetime.move_to(other_datetime) + + # NOTE: Hazelcast doesn't expire 
with freezegun + test_cache_store.delete("test_key:status") + test_cache_store.delete("test_key:num_failures") + + with pytest.raises(KeyError, match="foo"): + CircuitBreaker.protected_call( + func=mock_function, + cache_key="test_key", + ) + assert test_cache_store.get("test_key:num_failures") == 1 + + +def test_protected_call_too_many_failures_in_a_row(mocker, test_cache_store): + test_cache_store.set("test_key:num_failures", 2) + mock_function = mocker.MagicMock(side_effect=KeyError("foo")) + + with pytest.raises(KeyError, match="foo"): + CircuitBreaker.protected_call( + func=mock_function, + cache_key="test_key", + ) + + with pytest.raises( + ValueError, match="Too many failures! Please try again later." + ): + CircuitBreaker.protected_call( + func=mock_function, + cache_key="test_key", + ) + + assert test_cache_store.get("test_key:num_failures") == 3