Commit
Fixes #89
We were caching model-list endpoint errors as "no models". We now record that the request failed and retry on the next call.

This wasn't important for OpenAI/OpenRouter-style servers, which are almost always up, but it matters for LiteLLM and other local servers that are started and stopped often.
scosman committed Dec 23, 2024
1 parent 662b8cf commit d932bcf
Showing 2 changed files with 79 additions and 43 deletions.
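The change boils down to tracking whether the last refresh hit an error and treating an errored cache as stale, so the next call retries. Below is a minimal, self-contained sketch of that pattern; ProviderCache, load_cache, and fetch_models are illustrative stand-ins, not the module's actual API, which also handles per-provider config and the OpenAI client.

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Callable, List


@dataclass
class ProviderCache:
    # Illustrative stand-in for OpenAICompatibleProviderCache.
    providers: List[Any] = field(default_factory=list)
    last_updated: datetime | None = None
    had_error: bool = False

    def is_stale(self) -> bool:
        # Never populated, or populated while an upstream request failed:
        # report stale so the next call rebuilds the cache (the retry).
        if self.last_updated is None or self.had_error:
            return True
        # Otherwise expire after 60 minutes, matching the existing TTL.
        return datetime.now() - self.last_updated > timedelta(minutes=60)


def load_cache(fetch_models: Callable[[], List[Any]]) -> ProviderCache:
    # fetch_models is a placeholder for the per-provider model-list request.
    providers: List[Any] = []
    had_error = False
    try:
        providers = fetch_models()
    except Exception as e:
        # Previously an exception here produced an "empty" cache that looked
        # valid for the full TTL; now the error is remembered.
        print(f"Error listing models: {e}")
        had_error = True
    return ProviderCache(
        providers=providers,
        last_updated=datetime.now(),
        had_error=had_error,
    )

A caller keeps the returned ProviderCache and rebuilds it whenever is_stale() returns True, which now includes the case where the previous load errored.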
40 changes: 26 additions & 14 deletions app/desktop/studio_server/provider_api.py
@@ -567,12 +567,16 @@ class OpenAICompatibleProviderCache:
providers: List[AvailableModels]
last_updated: datetime | None = None
openai_compat_config_when_cached: Any | None = None
had_error: bool = False

# Cache for 60 minutes, or if the config changes
def is_stale(self) -> bool:
if self.last_updated is None:
return True

if self.had_error:
return True

if datetime.now() - self.last_updated > timedelta(minutes=60):
return True

@@ -594,23 +598,26 @@ def openai_compatible_providers() -> List[AvailableModels]:
or _openai_compatible_providers_cache.is_stale()
):
# Load values and cache them
provider_config = Config.shared().openai_compatible_providers
providers = openai_compatible_providers_uncached(provider_config)
_openai_compatible_providers_cache = OpenAICompatibleProviderCache(
providers=providers,
last_updated=datetime.now(),
openai_compat_config_when_cached=provider_config,
)
cache = openai_compatible_providers_load_cache()
_openai_compatible_providers_cache = cache

if _openai_compatible_providers_cache is None:
return []

return _openai_compatible_providers_cache.providers


def openai_compatible_providers_uncached(providers: List[Any]) -> List[AvailableModels]:
if not providers or len(providers) == 0:
return []
def openai_compatible_providers_load_cache() -> OpenAICompatibleProviderCache | None:
provider_config = Config.shared().openai_compatible_providers
if not provider_config or len(provider_config) == 0:
return None

# Errors that can be retried, like network issues, are tracked in cache.
# We retry populating the cache on each call
has_error = False

openai_compatible_models: List[AvailableModels] = []
for provider in providers:
for provider in provider_config:
models: List[ModelDetails] = []
base_url = provider.get("base_url")
if not base_url or not base_url.startswith("http"):
@@ -650,9 +657,14 @@ def openai_compatible_providers_uncached(providers: List[Any]) -> List[AvailableModels]:
)
except Exception as e:
print(f"Error connecting to OpenAI compatible provider {name}: {e}")
has_error = True
continue

if len(openai_compatible_models) == 0:
return []
cache = OpenAICompatibleProviderCache(
providers=openai_compatible_models,
last_updated=datetime.now(),
openai_compat_config_when_cached=provider_config,
had_error=has_error,
)

return openai_compatible_models
return cache
82 changes: 53 additions & 29 deletions app/desktop/studio_server/test_provider_api.py
@@ -1,6 +1,6 @@
import json
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest
from fastapi import FastAPI, HTTPException
@@ -30,7 +30,7 @@
custom_models,
model_from_ollama_tag,
openai_compatible_providers,
openai_compatible_providers_uncached,
openai_compatible_providers_load_cache,
)


@@ -1191,31 +1191,35 @@ def test_openai_compatible_providers():
with (
patch("app.desktop.studio_server.provider_api.Config.shared") as mock_config,
patch(
"app.desktop.studio_server.provider_api.openai_compatible_providers_uncached"
"app.desktop.studio_server.provider_api.openai_compatible_providers_load_cache"
) as mock_uncached,
):
mock_config.return_value.openai_compatible_providers = mock_provider_config
mock_uncached.return_value = [
AvailableModels(
provider_id=ModelProviderName.openai_compatible,
provider_name="test_provider",
models=[
ModelDetails(
id="test_provider::model1",
name="model1",
supports_structured_output=False,
supports_data_gen=False,
untested_model=True,
)
],
)
]
mock_uncached.return_value = OpenAICompatibleProviderCache(
providers=[
AvailableModels(
provider_id=ModelProviderName.openai_compatible,
provider_name="test_provider",
models=[
ModelDetails(
id="test_provider::model1",
name="model1",
supports_structured_output=False,
supports_data_gen=False,
untested_model=True,
)
],
),
],
last_updated=datetime.now(),
openai_compat_config_when_cached=mock_provider_config,
)

# First call should create cache
result1 = openai_compatible_providers()
assert len(result1) == 1
assert result1[0].provider_name == "test_provider"
mock_uncached.assert_called_once_with(mock_provider_config)
mock_uncached.assert_called_once()

# Second call should use cache
mock_uncached.reset_mock()
@@ -1253,8 +1257,12 @@ def test_openai_compatible_providers_uncached():
mock_client = MagicMock()
mock_client.models.list = mock_models_list

with patch("openai.OpenAI", return_value=mock_client):
result = openai_compatible_providers_uncached(mock_providers)
with (
patch("openai.OpenAI", return_value=mock_client),
patch("app.desktop.studio_server.provider_api.Config.shared") as mock_config,
):
mock_config.return_value.openai_compatible_providers = mock_providers
result = openai_compatible_providers_load_cache().providers

assert len(result) == 1
assert result[0].provider_name == "test_provider"
@@ -1266,8 +1274,12 @@ def test_openai_compatible_providers_uncached():


def test_openai_compatible_providers_uncached_empty_providers():
assert openai_compatible_providers_uncached([]) == []
assert openai_compatible_providers_uncached(None) == []
with (
patch("app.desktop.studio_server.provider_api.Config.shared") as mock_config,
):
mock_config.return_value.openai_compatible_providers = []
cached = openai_compatible_providers_load_cache()
assert cached is None


def test_openai_compatible_providers_uncached_invalid_provider():
@@ -1282,9 +1294,13 @@ def test_openai_compatible_providers_uncached_invalid_provider():
{"name": "test", "api_key": "key"}, # No base_url
]

with patch("openai.OpenAI") as mock_openai:
result = openai_compatible_providers_uncached(invalid_providers)
assert result == []
with (
patch("openai.OpenAI") as mock_openai,
patch("app.desktop.studio_server.provider_api.Config.shared") as mock_config,
):
mock_config.return_value.openai_compatible_providers = invalid_providers
result = openai_compatible_providers_load_cache()
assert result.providers == []
mock_openai.assert_not_called()


@@ -1297,7 +1313,15 @@ def test_openai_compatible_providers_uncached_api_error():
}
]

with patch("openai.OpenAI") as mock_openai:
with (
patch("openai.OpenAI") as mock_openai,
patch("app.desktop.studio_server.provider_api.Config.shared") as mock_config,
):
mock_config.return_value.openai_compatible_providers = mock_providers
mock_openai.return_value.models.list.side_effect = Exception("API Error")
result = openai_compatible_providers_uncached(mock_providers)
assert result == []
result = openai_compatible_providers_load_cache()
assert result.providers == []

# Confirm the cache knows about the error and reports stale
assert result.had_error
assert result.is_stale()
