Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vasinov committed Jul 8, 2024
1 parent 78c0a6c commit 0368ce4
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@ def test_run(self, module):

def test_prompt(self, module):
system_message = module.default_system_template_generator(
RagContext(
query="test",
text_chunks=[
TextArtifact("*TEXT SEGMENT 1*", reference=Reference(title="source 1")),
TextArtifact("*TEXT SEGMENT 2*", reference=Reference(title="source 2")),
TextArtifact("*TEXT SEGMENT 3*"),
],
before_query=["*RULESET*", "*META*"],
)
RagContext(query="test", before_query=["*RULESET*", "*META*"]),
artifacts=[
TextArtifact("*TEXT SEGMENT 1*", reference=Reference(title="source 1")),
TextArtifact("*TEXT SEGMENT 2*", reference=Reference(title="source 2")),
TextArtifact("*TEXT SEGMENT 3*"),
],
)

assert "*RULESET*" in system_message
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from griptape.artifacts import TextArtifact
from griptape.common import Reference
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import PromptResponseRagModule
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand All @@ -16,15 +17,26 @@ def test_run(self, module):

def test_prompt(self, module):
system_message = module.default_system_template_generator(
RagContext(
query="test",
text_chunks=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")],
before_query=["*RULESET*", "*META*"],
after_query=[],
)
RagContext(query="test", before_query=["*RULESET*", "*META*"], after_query=[]),
artifacts=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")],
)

assert "*RULESET*" in system_message
assert "*META*" in system_message
assert "*TEXT SEGMENT 1*" in system_message
assert "*TEXT SEGMENT 2*" in system_message

def test_references_from_artifacts(self, module):
reference1 = Reference(title="foo")
reference2 = Reference(title="bar")
artifacts = [
TextArtifact("foo", reference=reference1),
TextArtifact("foo", reference=reference1),
TextArtifact("foo"),
TextArtifact("foo", reference=reference2),
]
references = module.references_from_artifacts(artifacts)

assert len(references) == 2
assert references[0].id == reference1.id
assert references[1].id == reference2.id
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import Mock
import pytest
from cohere import RerankResponseResultsItem, RerankResponseResultsItemDocument
from griptape.artifacts import TextArtifact
from griptape.drivers import CohereRerankDriver
from griptape.engines.rag import RagContext
from griptape.engines.rag.modules import TextChunksRerankRagModule
Expand All @@ -9,12 +10,19 @@ class TestTextChunksRerankRagModule:
@pytest.fixture
def mock_client(self, mocker):
mock_client = mocker.patch("cohere.Client").return_value
mock_client.rerank.return_value.results = [Mock(), Mock()]
mock_client.rerank.return_value.results = [
RerankResponseResultsItem(
index=1, relevance_score=1.0, document=RerankResponseResultsItemDocument(text="foo")
),
RerankResponseResultsItem(
index=2, relevance_score=0.5, document=RerankResponseResultsItemDocument(text="bar")
),
]

return mock_client

def test_run(self, mock_client):
module = TextChunksRerankRagModule(rerank_driver=CohereRerankDriver(api_key="api-key"))
result = module.run(RagContext(query="test"))
result = module.run(RagContext(query="test", text_chunks=[TextArtifact("foo"), TextArtifact("bar")]))

assert len(result) == 2
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_run(self):
module = TextLoaderRetrievalRagModule(
loader=WebLoader(max_tokens=MAX_TOKENS, embedding_driver=embedding_driver),
vector_store_driver=LocalVectorStoreDriver(embedding_driver=embedding_driver),
source="https://www.griptape.ai"
source="https://www.griptape.ai",
)

assert module.run(RagContext(query="foo"))[0].value == "foobar"
12 changes: 9 additions & 3 deletions tests/unit/mixins/test_seriliazable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,22 @@ def test_from_json(self):

def test_str(self):
assert str(MockSerializable()) == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None}
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
)

def test_to_json(self):
assert MockSerializable().to_json() == json.dumps(
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None}
{"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None, "nested": None}
)

def test_to_dict(self):
assert MockSerializable().to_dict() == {"type": "MockSerializable", "foo": "bar", "bar": None, "baz": None}
assert MockSerializable().to_dict() == {
"type": "MockSerializable",
"foo": "bar",
"bar": None,
"baz": None,
"nested": None,
}

def test_import_class_rec(self):
assert (
Expand Down

0 comments on commit 0368ce4

Please sign in to comment.