Skip to content

Commit

Permalink
Development: Add truncate function (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus authored Jul 22, 2023
1 parent 164de0f commit 94d78e5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 0 deletions.
8 changes: 8 additions & 0 deletions app/services/guidance_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def truncate(history: list[any], max_length: int):
if max_length == 0:
return []

if max_length > 0:
return history[:max_length]

return history[max_length:]
2 changes: 2 additions & 0 deletions app/services/guidance_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from app.config import LLMModelConfig
from app.models.dtos import Content, ContentType
from app.services.guidance_functions import truncate


class GuidanceWrapper:
Expand Down Expand Up @@ -34,6 +35,7 @@ def query(self) -> Content:
template = guidance(self.handlebars)
result = template(
llm=self._get_llm(),
truncate=truncate,
**self.parameters,
)

Expand Down
21 changes: 21 additions & 0 deletions tests/services/guidance_functions_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest
from app.services.guidance_functions import truncate


@pytest.mark.parametrize(
"history,max_length,expected",
[
([], -2, []),
([], 0, []),
([], 2, []),
([1, 2, 3], 0, []),
# Get the last n elements
([1, 2, 3], -4, [1, 2, 3]),
([1, 2, 3], -2, [2, 3]),
# Get the first n elements
([1, 2, 3], 2, [1, 2]),
([1, 2, 3], 4, [1, 2, 3]),
],
)
def test_truncate(history, max_length, expected):
assert truncate(history, max_length) == expected
26 changes: 26 additions & 0 deletions tests/services/guidance_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,32 @@ def test_query_success(mocker):
assert result.text_content == "the output"


def test_query_using_truncate_function(mocker):
mocker.patch.object(
GuidanceWrapper,
"_get_llm",
return_value=guidance.llms.Mock("the output"),
)

handlebars = """{{#user~}}I want a response to the following query:
{{query}}{{~/user}}{{#assistant~}}
{{gen 'answer' temperature=0.0 max_tokens=500}}{{~/assistant}}
{{set 'response' (truncate answer 3)}}
"""

guidance_wrapper = GuidanceWrapper(
model=llm_model_config,
handlebars=handlebars,
parameters={"query": "Some query"},
)

result = guidance_wrapper.query()

assert isinstance(result, Content)
assert result.type == ContentType.TEXT
assert result.text_content == "the"


def test_query_missing_required_params(mocker):
mocker.patch.object(
GuidanceWrapper,
Expand Down

0 comments on commit 94d78e5

Please sign in to comment.