From f4f0b6921e51c6b394a687ec350363dda3e6b1d8 Mon Sep 17 00:00:00 2001 From: zzstoatzz Date: Sun, 11 Aug 2024 00:55:23 -0500 Subject: [PATCH] add tiktoken cache --- src/marvin/settings.py | 24 ++++++++++++++++++++-- tests/utilities/test_tiktoken.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 tests/utilities/test_tiktoken.py diff --git a/src/marvin/settings.py b/src/marvin/settings.py index 1cf135294..7328a8f28 100644 --- a/src/marvin/settings.py +++ b/src/marvin/settings.py @@ -19,8 +19,7 @@ class MarvinSettings(BaseSettings): def __setattr__(self, name: str, value: Any) -> None: # wrap bare strings in SecretStr if the field is annotated with SecretStr - field = self.model_fields.get(name) - if field: + if field := self.model_fields.get(name): annotation = field.annotation base_types = ( getattr(annotation, "__args__", None) @@ -32,6 +31,18 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) +class TiktokenSettings(MarvinSettings): + model_config = SettingsConfigDict(env_prefix="marvin_tiktoken_", extra="ignore") + + cache_dir: Optional[str] = Field( + default=None, description="Directory to store cached tiktoken encoding files." + ) + verify_ssl: bool = Field( + default=True, + description="Whether to verify SSL certificates for tiktoken requests.", + ) + + class ChatCompletionSettings(MarvinSettings): model_config = SettingsConfigDict( env_prefix="marvin_chat_completions_", extra="ignore" @@ -40,10 +51,19 @@ class ChatCompletionSettings(MarvinSettings): temperature: float = Field(description="The default temperature to use.", default=1) + tiktoken: TiktokenSettings = Field(default_factory=TiktokenSettings) + @property def encoder(self): import tiktoken + if self.tiktoken.cache_dir: + os.environ["TIKTOKEN_CACHE_DIR"] = self.tiktoken.cache_dir + if not self.tiktoken.verify_ssl: + import ssl + + ssl._create_default_https_context = ssl._create_unverified_context + try: encoding = tiktoken.encoding_for_model(self.model) except KeyError: diff --git a/tests/utilities/test_tiktoken.py b/tests/utilities/test_tiktoken.py new file mode 100644 index 000000000..ebae38c2b --- /dev/null +++ b/tests/utilities/test_tiktoken.py @@ -0,0 +1,35 @@ +import os +from unittest.mock import MagicMock, patch + +from marvin.settings import ChatCompletionSettings, settings, temporary_settings + + +def test_tiktoken_cache_dir_setting(tmp_path): + with temporary_settings( + openai__chat__completions__tiktoken__cache_dir=str(tmp_path) + ): + _ = settings.openai.chat.completions.encoder + assert os.environ.get("TIKTOKEN_CACHE_DIR") == str(tmp_path) + + # Check that the environment is cleaned up after the test + assert "TIKTOKEN_CACHE_DIR" not in os.environ + + +def test_tiktoken_default_behavior(): + # Test with default settings (no cache dir, SSL verification enabled) + with patch("tiktoken.encoding_for_model") as mock_encoding: + mock_encoder = MagicMock() + mock_encoding.return_value = mock_encoder + + chat_settings = ChatCompletionSettings() + _ = chat_settings.encoder + + # Check that TIKTOKEN_CACHE_DIR is not set + assert "TIKTOKEN_CACHE_DIR" not in os.environ + + # Check that SSL verification is not modified + import ssl + + assert ssl._create_default_https_context != ssl._create_unverified_context + + mock_encoding.assert_called_once_with(chat_settings.model)