From b2abb60cc4b0bf767b19cc6c4fc8cd5c6ee92920 Mon Sep 17 00:00:00 2001 From: Timor Morrien Date: Wed, 19 Jul 2023 13:13:01 +0200 Subject: [PATCH 1/3] Add truncate function --- app/services/guidance_functions.py | 3 +++ app/services/guidance_wrapper.py | 2 ++ 2 files changed, 5 insertions(+) create mode 100644 app/services/guidance_functions.py diff --git a/app/services/guidance_functions.py b/app/services/guidance_functions.py new file mode 100644 index 00000000..18e6f9dc --- /dev/null +++ b/app/services/guidance_functions.py @@ -0,0 +1,3 @@ +def truncate(history: list[any], max_length: int): + return history[-max_length:] + diff --git a/app/services/guidance_wrapper.py b/app/services/guidance_wrapper.py index 43f94177..6c7243bb 100644 --- a/app/services/guidance_wrapper.py +++ b/app/services/guidance_wrapper.py @@ -2,6 +2,7 @@ from app.config import LLMModelConfig from app.models.dtos import Content, ContentType +from app.services.guidance_functions import truncate class GuidanceWrapper: @@ -30,6 +31,7 @@ def query(self) -> Content: template = guidance(self.handlebars) result = template( llm=self._get_llm(), + truncate=truncate, **self.parameters, ) From d0a585970eb7a1bd6f2dcbc0a6822fdbe032f79b Mon Sep 17 00:00:00 2001 From: Khoa Nguyen Date: Wed, 19 Jul 2023 13:38:20 +0200 Subject: [PATCH 2/3] Add tests --- app/services/guidance_functions.py | 7 ++++++- tests/services/guidance_functions_test.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 tests/services/guidance_functions_test.py diff --git a/app/services/guidance_functions.py b/app/services/guidance_functions.py index 18e6f9dc..a703775d 100644 --- a/app/services/guidance_functions.py +++ b/app/services/guidance_functions.py @@ -1,3 +1,8 @@ def truncate(history: list[any], max_length: int): - return history[-max_length:] + if max_length == 0: + return [] + if max_length > 0: + return history[:max_length] + + return history[max_length:] diff --git a/tests/services/guidance_functions_test.py b/tests/services/guidance_functions_test.py new file mode 100644 index 00000000..182c3c4b --- /dev/null +++ b/tests/services/guidance_functions_test.py @@ -0,0 +1,21 @@ +import pytest +from app.services.guidance_functions import truncate + + +@pytest.mark.parametrize( + "history,max_length,expected", + [ + ([], -2, []), + ([], 0, []), + ([], 2, []), + ([1, 2, 3], 0, []), + # Get the last n elements + ([1, 2, 3], -4, [1, 2, 3]), + ([1, 2, 3], -2, [2, 3]), + # Get the first n elements + ([1, 2, 3], 2, [1, 2]), + ([1, 2, 3], 4, [1, 2, 3]), + ], +) +def test_truncate(history, max_length, expected): + assert truncate(history, max_length) == expected From 36029ebcc091bc6c18750564baafb20f19d25ddd Mon Sep 17 00:00:00 2001 From: Khoa Nguyen Date: Wed, 19 Jul 2023 13:58:40 +0200 Subject: [PATCH 3/3] Add test --- tests/services/guidance_wrapper_test.py | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/services/guidance_wrapper_test.py b/tests/services/guidance_wrapper_test.py index f43a558d..9f1312db 100644 --- a/tests/services/guidance_wrapper_test.py +++ b/tests/services/guidance_wrapper_test.py @@ -34,6 +34,32 @@ def test_query_success(mocker): assert result.text_content == "the output" +def test_query_using_truncate_function(mocker): + mocker.patch.object( + GuidanceWrapper, + "_get_llm", + return_value=guidance.llms.Mock("the output"), + ) + + handlebars = """{{#user~}}I want a response to the following query: + {{query}}{{~/user}}{{#assistant~}} + {{gen 'answer' temperature=0.0 max_tokens=500}}{{~/assistant}} + {{set 'response' (truncate answer 3)}} + """ + + guidance_wrapper = GuidanceWrapper( + model=llm_model_config, + handlebars=handlebars, + parameters={"query": "Some query"}, + ) + + result = guidance_wrapper.query() + + assert isinstance(result, Content) + assert result.type == ContentType.TEXT + assert result.text_content == "the" + + @pytest.mark.skip( reason="This tests library behavior changed by Guidance version bump" )