From 8ec4065cd8c58beb3fa7e749fc5dcb7b16b5c074 Mon Sep 17 00:00:00 2001 From: Nick Sullivan Date: Fri, 4 Aug 2023 19:10:54 +0200 Subject: [PATCH] Refactor config reading and expand test coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors the way we read the configuration in `lm.py` to ensure that we always return a dictionary, even when the configuration file is not found. This change simplifies the code and makes it more robust. In `output.py`, we've updated the comment to better reflect the actual code change, which is setting the padding to 0. In `test_lm.py`, we've expanded the test coverage to include more models. This will help us catch potential issues with these models earlier. These changes should improve the overall quality and reliability of our code. 🚀👍🏽 --- aicodebot/lm.py | 4 ++-- aicodebot/output.py | 2 +- tests/test_lm.py | 11 +++++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/aicodebot/lm.py b/aicodebot/lm.py index 550816e..ad5dab7 100644 --- a/aicodebot/lm.py +++ b/aicodebot/lm.py @@ -158,7 +158,7 @@ def get_api_key(self, key_name): if api_key: return api_key else: - config = read_config() + config = read_config() or {} key_name_lower = key_name.lower() # Try both upper and lower case from the config file if key_name_lower in config: @@ -218,7 +218,7 @@ def get_token_size(self, text): def read_model_config(self): # Figure out which model to use, based on the config file or environment variables - config = read_config() + config = read_config() or {} self.provider = os.getenv( "AICODEBOT_MODEL_PROVIDER", config.get("language_model_provider", self.DEFAULT_PROVIDER) ) diff --git a/aicodebot/output.py b/aicodebot/output.py index e18e6f4..43da03e 100644 --- a/aicodebot/output.py +++ b/aicodebot/output.py @@ -38,7 +38,7 @@ class OurCodeBlock(CodeBlock): def __rich_console__(self, console, options): code = str(self.text) - # set dedent=True to remove leading spaces and turn off padding + # Set the padding to 0 syntax = Syntax(code, self.lexer_name, theme=self.theme, word_wrap=True, padding=0) yield syntax diff --git a/tests/test_lm.py b/tests/test_lm.py index 5042601..bdf69e3 100644 --- a/tests/test_lm.py +++ b/tests/test_lm.py @@ -14,7 +14,13 @@ def test_token_size(monkeypatch): @pytest.mark.parametrize( - "provider,model_name", [(LanguageModelManager.OPENAI, "gpt-4"), (LanguageModelManager.OPENROUTER, "gpt-4")] + "provider,model_name", + [ + (LanguageModelManager.OPENAI, "gpt-4"), + (LanguageModelManager.OPENAI, "gpt-3.5-turbo"), + (LanguageModelManager.OPENROUTER, "gpt-4"), + (LanguageModelManager.OPENROUTER, "gpt-4-32k"), + ], ) def test_chain_factory(provider, model_name, monkeypatch): monkeypatch.setenv("AICODEBOT_MODEL_PROVIDER", provider) @@ -31,6 +37,3 @@ def test_chain_factory(provider, model_name, monkeypatch): if hasattr(chain.llm, "model_name"): # OpenAI compatible assert chain.llm.model_name == model_name - elif hasattr(chain.llm, "repo_id"): - # Hugging Face Hub - assert chain.llm.repo_id == model_name