From 044e87091ec1f408f120f784d272cd3145dfd6a0 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Thu, 8 Aug 2024 15:24:44 -0600 Subject: [PATCH 01/13] Add support for multiple to ResponseRagStage --- docs/examples/query-webpage-astra-db.md | 8 ++++-- docs/examples/talk-to-a-pdf.md | 8 ++++-- docs/examples/talk-to-a-webpage.md | 8 ++++-- .../griptape-framework/engines/rag-engines.md | 10 ++++--- .../structures/task-memory.md | 8 ++++-- docs/griptape-framework/structures/tasks.md | 8 ++++-- .../official-tools/rag-client.md | 8 ++++-- .../response/base_response_rag_module.py | 3 +- .../response/prompt_response_rag_module.py | 7 ++--- .../text_chunks_response_rag_module.py | 8 ++---- griptape/engines/rag/rag_context.py | 6 ++-- .../engines/rag/stages/query_rag_stage.py | 2 +- .../engines/rag/stages/response_rag_stage.py | 28 +++++-------------- .../engines/rag/stages/retrieval_rag_stage.py | 6 ++-- .../task/storage/text_artifact_storage.py | 10 +++---- griptape/structures/structure.py | 10 ++++--- griptape/tasks/rag_task.py | 10 +++---- griptape/tools/rag_client/tool.py | 11 ++++---- ...est_footnote_prompt_response_rag_module.py | 2 +- .../test_prompt_response_rag_module.py | 2 +- .../test_text_chunks_response_rag_module.py | 2 +- tests/unit/engines/rag/test_rag_engine.py | 8 ++++-- tests/unit/structures/test_agent.py | 2 +- tests/unit/tasks/test_rag_task.py | 4 ++- tests/unit/tools/test_rag_client.py | 2 +- tests/utils/defaults.py | 6 +++- 26 files changed, 98 insertions(+), 89 deletions(-) diff --git a/docs/examples/query-webpage-astra-db.md b/docs/examples/query-webpage-astra-db.md index 46989c2ba..53153a85f 100644 --- a/docs/examples/query-webpage-astra-db.md +++ b/docs/examples/query-webpage-astra-db.md @@ -56,9 +56,11 @@ engine = RagEngine( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + 
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ) diff --git a/docs/examples/talk-to-a-pdf.md b/docs/examples/talk-to-a-pdf.md index a3f47f0c6..f7a653289 100644 --- a/docs/examples/talk-to-a-pdf.md +++ b/docs/examples/talk-to-a-pdf.md @@ -30,9 +30,11 @@ engine = RagEngine( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ) vector_store_tool = RagClient( diff --git a/docs/examples/talk-to-a-webpage.md b/docs/examples/talk-to-a-webpage.md index d1d0da5fa..091a8c8c4 100644 --- a/docs/examples/talk-to-a-webpage.md +++ b/docs/examples/talk-to-a-webpage.md @@ -29,9 +29,11 @@ engine = RagEngine( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ) diff --git a/docs/griptape-framework/engines/rag-engines.md b/docs/griptape-framework/engines/rag-engines.md index c557c4087..f1ca8351e 100644 --- a/docs/griptape-framework/engines/rag-engines.md +++ b/docs/griptape-framework/engines/rag-engines.md @@ -86,9 +86,11 @@ rag_engine = RagEngine( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=prompt_driver - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ) @@ -104,6 +106,6 @@ rag_context = RagContext( ) print( - rag_engine.process(rag_context).output.to_text() + rag_engine.process(rag_context).outputs[0].to_text() ) ``` \ No newline at end of file diff --git a/docs/griptape-framework/structures/task-memory.md b/docs/griptape-framework/structures/task-memory.md index ea4a787f6..955fddb7b 100644 --- a/docs/griptape-framework/structures/task-memory.md +++ 
b/docs/griptape-framework/structures/task-memory.md @@ -243,9 +243,11 @@ agent = Agent( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ), retrieval_rag_module_name="VectorStoreRetrievalRagModule", diff --git a/docs/griptape-framework/structures/tasks.md b/docs/griptape-framework/structures/tasks.md index 6d479578b..5057056c2 100644 --- a/docs/griptape-framework/structures/tasks.md +++ b/docs/griptape-framework/structures/tasks.md @@ -421,9 +421,11 @@ agent.add_task( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ), ) diff --git a/docs/griptape-tools/official-tools/rag-client.md b/docs/griptape-tools/official-tools/rag-client.md index 8b1447768..080b7498b 100644 --- a/docs/griptape-tools/official-tools/rag-client.md +++ b/docs/griptape-tools/official-tools/rag-client.md @@ -37,9 +37,11 @@ rag_client = RagClient( ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule( - prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") - ) + response_modules=[ + PromptResponseRagModule( + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o") + ) + ] ) ) ) diff --git a/griptape/engines/rag/modules/response/base_response_rag_module.py b/griptape/engines/rag/modules/response/base_response_rag_module.py index 30ab82201..1bd3ddeb7 100644 --- a/griptape/engines/rag/modules/response/base_response_rag_module.py +++ b/griptape/engines/rag/modules/response/base_response_rag_module.py @@ -2,6 +2,7 @@ from attrs import define +from griptape.artifacts import BaseArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import 
BaseRagModule @@ -9,4 +10,4 @@ @define(kw_only=True) class BaseResponseRagModule(BaseRagModule, ABC): @abstractmethod - def run(self, context: RagContext) -> RagContext: ... + def run(self, context: RagContext) -> BaseArtifact: ... diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 1cdb9a6f0..118ed8891 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -4,6 +4,7 @@ from attrs import Factory, define, field +from griptape.artifacts import BaseArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 @@ -21,7 +22,7 @@ class PromptResponseRagModule(BaseResponseRagModule): default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) - def run(self, context: RagContext) -> RagContext: + def run(self, context: RagContext) -> BaseArtifact: query = context.query tokenizer = self.prompt_driver.tokenizer included_chunks = [] @@ -45,12 +46,10 @@ def run(self, context: RagContext) -> RagContext: output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact() if isinstance(output, TextArtifact): - context.output = output + return output else: raise ValueError("Prompt driver did not return a TextArtifact") - return context - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: return J2("engines/rag/modules/response/prompt/system.j2").render( text_chunks=[c.to_text() for c in artifacts], diff --git a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py index fd57b3905..370f69bff 100644 --- a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py +++ 
b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py @@ -1,13 +1,11 @@ from attrs import define -from griptape.artifacts import ListArtifact +from griptape.artifacts import ListArtifact, BaseArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import BaseResponseRagModule @define(kw_only=True) class TextChunksResponseRagModule(BaseResponseRagModule): - def run(self, context: RagContext) -> RagContext: - context.output = ListArtifact(context.text_chunks) - - return context + def run(self, context: RagContext) -> BaseArtifact: + return ListArtifact(context.text_chunks) diff --git a/griptape/engines/rag/rag_context.py b/griptape/engines/rag/rag_context.py index 3146070c2..1ddfbb1b0 100644 --- a/griptape/engines/rag/rag_context.py +++ b/griptape/engines/rag/rag_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from attrs import define, field @@ -22,7 +22,7 @@ class RagContext(SerializableMixin): before_query: An optional list of strings to add before the query in response modules. after_query: An optional list of strings to add after the query in response modules. text_chunks: A list of text chunks to pass around from the retrieval stage to the response stage. - output: Final output from the response stage. + outputs: List of outputs from the response stage. 
""" query: str = field(metadata={"serializable": True}) @@ -30,7 +30,7 @@ class RagContext(SerializableMixin): before_query: list[str] = field(factory=list, metadata={"serializable": True}) after_query: list[str] = field(factory=list, metadata={"serializable": True}) text_chunks: list[TextArtifact] = field(factory=list, metadata={"serializable": True}) - output: Optional[BaseArtifact] = field(default=None, metadata={"serializable": True}) + outputs: list[BaseArtifact] = field(factory=list, metadata={"serializable": True}) def get_references(self) -> list[Reference]: return utils.references_from_artifacts(self.text_chunks) diff --git a/griptape/engines/rag/stages/query_rag_stage.py b/griptape/engines/rag/stages/query_rag_stage.py index 97a6c2e2d..ebde3170c 100644 --- a/griptape/engines/rag/stages/query_rag_stage.py +++ b/griptape/engines/rag/stages/query_rag_stage.py @@ -23,7 +23,7 @@ def modules(self) -> Sequence[BaseRagModule]: return self.query_modules def run(self, context: RagContext) -> RagContext: - logging.info("QueryStage: running %s query generation modules sequentially", len(self.query_modules)) + logging.info("QueryRagStage: running %s query generation modules sequentially", len(self.query_modules)) [qm.run(context) for qm in self.query_modules] diff --git a/griptape/engines/rag/stages/response_rag_stage.py b/griptape/engines/rag/stages/response_rag_stage.py index b63b5bc21..4bc0b2be5 100644 --- a/griptape/engines/rag/stages/response_rag_stage.py +++ b/griptape/engines/rag/stages/response_rag_stage.py @@ -5,13 +5,12 @@ from attrs import define, field +from griptape import utils from griptape.engines.rag.stages import BaseRagStage if TYPE_CHECKING: from griptape.engines.rag import RagContext from griptape.engines.rag.modules import ( - BaseAfterResponseRagModule, - BaseBeforeResponseRagModule, BaseRagModule, BaseResponseRagModule, ) @@ -19,35 +18,22 @@ @define(kw_only=True) class ResponseRagStage(BaseRagStage): - before_response_modules: 
list[BaseBeforeResponseRagModule] = field(factory=list) - response_module: BaseResponseRagModule = field() - after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list) + response_modules: list[BaseResponseRagModule] = field() @property def modules(self) -> list[BaseRagModule]: ms = [] - ms.extend(self.before_response_modules) - ms.extend(self.after_response_modules) - - if self.response_module is not None: - ms.append(self.response_module) + ms.extend(self.response_modules) return ms def run(self, context: RagContext) -> RagContext: - logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules)) - - for generator in self.before_response_modules: - context = generator.run(context) - - logging.info("GenerationStage: running generation module") - - context = self.response_module.run(context) + logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) - logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules)) + with self.futures_executor_fn() as executor: + results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules]) - for generator in self.after_response_modules: - context = generator.run(context) + context.outputs = results return context diff --git a/griptape/engines/rag/stages/retrieval_rag_stage.py b/griptape/engines/rag/stages/retrieval_rag_stage.py index 50b84abfc..fa618a7ff 100644 --- a/griptape/engines/rag/stages/retrieval_rag_stage.py +++ b/griptape/engines/rag/stages/retrieval_rag_stage.py @@ -33,7 +33,7 @@ def modules(self) -> list[BaseRagModule]: return ms def run(self, context: RagContext) -> RagContext: - logging.info("RetrievalStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) + logging.info("RetrievalRagStage: running %s retrieval modules in parallel", len(self.retrieval_modules)) with self.futures_executor_fn() as executor: 
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.retrieval_modules]) @@ -47,7 +47,7 @@ def run(self, context: RagContext) -> RagContext: chunks_after_dedup = len(results) logging.info( - "RetrievalStage: deduplicated %s " "chunks (%s - %s)", + "RetrievalRagStage: deduplicated %s " "chunks (%s - %s)", chunks_before_dedup - chunks_after_dedup, chunks_before_dedup, chunks_after_dedup, @@ -56,7 +56,7 @@ def run(self, context: RagContext) -> RagContext: context.text_chunks = [a for a in results if isinstance(a, TextArtifact)] if self.rerank_module: - logging.info("RetrievalStage: running rerank module on %s chunks", chunks_after_dedup) + logging.info("RetrievalRagStage: running rerank module on %s chunks", chunks_after_dedup) context.text_chunks = [a for a in self.rerank_module.run(context) if isinstance(a, TextArtifact)] diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e66c5aba..62a517bc9 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -52,7 +52,7 @@ def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifac if self.retrieval_rag_module_name is None: raise ValueError("retrieval_rag_module_name is not set") - result = self.rag_engine.process( + outputs = self.rag_engine.process( RagContext( query=query, module_configs={ @@ -64,9 +64,9 @@ def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifac }, }, ), - ).output + ).outputs - if result is None: - return InfoArtifact("Empty output") + if len(outputs) > 0: + return outputs[0] else: - return result + return InfoArtifact("Empty output") diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index d68457ebc..f318f4daf 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -180,11 +180,13 @@ def default_rag_engine(self) -> 
RagEngine: retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], ), response_stage=ResponseRagStage( - before_response_modules=[ - RulesetsBeforeResponseRagModule(rulesets=self.rulesets), - MetadataBeforeResponseRagModule(), + # before_response_modules=[ + # RulesetsBeforeResponseRagModule(rulesets=self.rulesets), + # MetadataBeforeResponseRagModule(), + # ], + response_modules=[ + PromptResponseRagModule(prompt_driver=self.config.prompt_driver) ], - response_module=PromptResponseRagModule(prompt_driver=self.config.prompt_driver), ), ) diff --git a/griptape/tasks/rag_task.py b/griptape/tasks/rag_task.py index 3f88f34d1..2f44fdfa4 100644 --- a/griptape/tasks/rag_task.py +++ b/griptape/tasks/rag_task.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact from griptape.tasks import BaseTextInputTask if TYPE_CHECKING: @@ -29,9 +29,9 @@ def rag_engine(self, value: RagEngine) -> None: self._rag_engine = value def run(self) -> BaseArtifact: - result = self.rag_engine.process_query(self.input.to_text()).output + outputs = self.rag_engine.process_query(self.input.to_text()).outputs - if result is None: - return ErrorArtifact("empty output") + if len(outputs) > 0: + return ListArtifact(outputs) else: - return result + return ErrorArtifact("empty output") diff --git a/griptape/tools/rag_client/tool.py b/griptape/tools/rag_client/tool.py index bbdef8159..613e254af 100644 --- a/griptape/tools/rag_client/tool.py +++ b/griptape/tools/rag_client/tool.py @@ -5,7 +5,7 @@ from attrs import define, field from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, ListArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @@ -35,11 +35,12 @@ def search(self, params: dict) -> 
BaseArtifact: query = params["values"]["query"] try: - result = self.rag_engine.process_query(query) + outputs = self.rag_engine.process_query(query).outputs - if result.output is None: - return ErrorArtifact("query output is empty") + if len(outputs) > 0: + return ListArtifact(outputs) else: - return result.output + return ErrorArtifact("query output is empty") + except Exception as e: return ErrorArtifact(f"error querying: {e}") diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 385cf0c04..4d0aad139 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -13,7 +13,7 @@ def module(self): return FootnotePromptResponseRagModule(prompt_driver=MockPromptDriver()) def test_run(self, module): - assert module.run(RagContext(query="test")).output.value == "mock output" + assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): system_message = module.default_system_template_generator( diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index 2f8a912e2..fb78ddc3e 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -12,7 +12,7 @@ def module(self): return PromptResponseRagModule(prompt_driver=MockPromptDriver()) def test_run(self, module): - assert module.run(RagContext(query="test")).output.value == "mock output" + assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): system_message = module.default_system_template_generator( diff --git 
a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py index ae4410b2c..05b8042e6 100644 --- a/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_text_chunks_response_rag_module.py @@ -13,4 +13,4 @@ def module(self): def test_run(self, module): text_chunks = [TextArtifact("foo"), TextArtifact("bar")] - assert module.run(RagContext(query="test", text_chunks=text_chunks)).output.value == text_chunks + assert module.run(RagContext(query="test", text_chunks=text_chunks)).value == text_chunks diff --git a/tests/unit/engines/rag/test_rag_engine.py b/tests/unit/engines/rag/test_rag_engine.py index c3d728bb3..964a52650 100644 --- a/tests/unit/engines/rag/test_rag_engine.py +++ b/tests/unit/engines/rag/test_rag_engine.py @@ -19,7 +19,9 @@ def engine(self): ) ] ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver())), + response_stage=ResponseRagStage( + response_modules=[PromptResponseRagModule(prompt_driver=MockPromptDriver())] + ), ) def test_module_name_uniqueness(self): @@ -45,7 +47,7 @@ def test_module_name_uniqueness(self): ) def test_process_query(self, engine): - assert engine.process_query("test").output.value == "mock output" + assert engine.process_query("test").outputs[0].value == "mock output" def test_process(self, engine): - assert engine.process(RagContext(query="test")).output.value == "mock output" + assert engine.process(RagContext(query="test")).outputs[0].value == "mock output" diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index a09ad0f9a..e3d9034c4 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -238,7 +238,7 @@ def test_task_memory_defaults(self): storage = list(agent.task_memory.artifact_storages.values())[0] assert 
isinstance(storage, TextArtifactStorage) - assert storage.rag_engine.response_stage.response_module.prompt_driver == prompt_driver + assert storage.rag_engine.response_stage.response_modules[0].prompt_driver == prompt_driver assert ( storage.rag_engine.retrieval_stage.retrieval_modules[0].vector_store_driver.embedding_driver == embedding_driver diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index b205d385a..590a4ce9f 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -15,7 +15,9 @@ def task(self): input="test", rag_engine=RagEngine( response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=MockPromptDriver()) + response_modules=[ + PromptResponseRagModule(prompt_driver=MockPromptDriver()) + ] ) ), ) diff --git a/tests/unit/tools/test_rag_client.py b/tests/unit/tools/test_rag_client.py index 9c6497b02..60a0df722 100644 --- a/tests/unit/tools/test_rag_client.py +++ b/tests/unit/tools/test_rag_client.py @@ -10,4 +10,4 @@ def test_search(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) tool = RagClient(description="Test", rag_engine=rag_engine(MockPromptDriver(), vector_store_driver)) - assert tool.search({"values": {"query": "test"}}).value == "mock output" + assert tool.search({"values": {"query": "test"}}).value[0].value == "mock output" diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index bad7f0d79..c0fa5220f 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -34,5 +34,9 @@ def rag_engine(prompt_driver, vector_store_driver): retrieval_stage=RetrievalRagStage( retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver)] ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)), + response_stage=ResponseRagStage( + response_modules=[ + PromptResponseRagModule(prompt_driver=prompt_driver) + ] + ), ) From 
8b977454f1f37f030d4c1bc3b8e6d2bfd9eb5254 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Thu, 8 Aug 2024 15:26:59 -0600 Subject: [PATCH 02/13] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e016228c..7073557b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. +- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 From 6f59f87b3e59a9c2fc5d78a7a4fee2e38201fb10 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:07:47 -0600 Subject: [PATCH 03/13] Update rulesets and metadata usage in PromptResponseRagModule --- .../griptape-framework/engines/rag-engines.md | 2 -- griptape/engines/rag/modules/__init__.py | 4 --- .../metadata_before_response_rag_module.py | 25 ------------------- .../response/prompt_response_rag_module.py | 23 +++++++++++------ .../rulesets_before_response_rag_module.py | 22 ---------------- .../text_chunks_response_rag_module.py | 2 +- griptape/structures/structure.py | 8 +----- .../rag/modules/response/prompt/system.j2 | 8 ++++-- ...est_metadata_before_response_rag_module.py | 21 ---------------- .../test_prompt_response_rag_module.py | 9 +++++-- ...est_rulesets_before_response_rag_module.py | 10 -------- tests/unit/tasks/test_rag_task.py | 4 +-- tests/utils/defaults.py | 6 +---- 13 files changed, 32 insertions(+), 112 deletions(-) delete mode 100644 griptape/engines/rag/modules/response/metadata_before_response_rag_module.py delete mode 100644 griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py delete
mode 100644 tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py delete mode 100644 tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py diff --git a/docs/griptape-framework/engines/rag-engines.md b/docs/griptape-framework/engines/rag-engines.md index f1ca8351e..d2990f23d 100644 --- a/docs/griptape-framework/engines/rag-engines.md +++ b/docs/griptape-framework/engines/rag-engines.md @@ -33,8 +33,6 @@ RAG modules are used to implement concrete actions in the RAG pipeline. `RagEngi - `TextChunksRerankRagModule` is for re-ranking retrieved results. #### Response Modules -- `MetadataBeforeResponseRagModule` is for appending metadata. -- `RulesetsBeforeResponseRagModule` is for appending rulesets. - `PromptResponseRagModule` is for generating responses based on retrieved text chunks. - `TextChunksResponseRagModule` is for responding with retrieved text chunks. - `FootnotePromptResponseRagModule` is for responding with automatic footnotes from text chunk references. 
diff --git a/griptape/engines/rag/modules/__init__.py b/griptape/engines/rag/modules/__init__.py index be66082f0..ace9f4a3b 100644 --- a/griptape/engines/rag/modules/__init__.py +++ b/griptape/engines/rag/modules/__init__.py @@ -10,8 +10,6 @@ from .response.base_after_response_rag_module import BaseAfterResponseRagModule from .response.base_response_rag_module import BaseResponseRagModule from .response.prompt_response_rag_module import PromptResponseRagModule -from .response.rulesets_before_response_rag_module import RulesetsBeforeResponseRagModule -from .response.metadata_before_response_rag_module import MetadataBeforeResponseRagModule from .response.text_chunks_response_rag_module import TextChunksResponseRagModule from .response.footnote_prompt_response_rag_module import FootnotePromptResponseRagModule @@ -28,8 +26,6 @@ "BaseAfterResponseRagModule", "BaseResponseRagModule", "PromptResponseRagModule", - "RulesetsBeforeResponseRagModule", - "MetadataBeforeResponseRagModule", "TextChunksResponseRagModule", "FootnotePromptResponseRagModule", ] diff --git a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py b/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py deleted file mode 100644 index d2d546213..000000000 --- a/griptape/engines/rag/modules/response/metadata_before_response_rag_module.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Optional - -from attrs import define, field - -from griptape.engines.rag.modules import BaseBeforeResponseRagModule -from griptape.utils import J2 - -if TYPE_CHECKING: - from griptape.engines.rag import RagContext - - -@define(kw_only=True) -class MetadataBeforeResponseRagModule(BaseBeforeResponseRagModule): - metadata: Optional[str] = field(default=None) - - def run(self, context: RagContext) -> RagContext: - context_metadata = self.get_context_param(context, "metadata") - metadata = self.metadata if context_metadata is 
None else context_metadata - - if metadata is not None: - context.before_query.append(J2("engines/rag/modules/response/metadata/system.j2").render(metadata=metadata)) - - return context diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 118ed8891..02f10f5db 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -1,23 +1,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Optional, Any from attrs import Factory, define, field -from griptape.artifacts import BaseArtifact from griptape.artifacts.text_artifact import TextArtifact from griptape.engines.rag.modules import BaseResponseRagModule from griptape.utils import J2 if TYPE_CHECKING: + from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.engines.rag import RagContext + from griptape.rules import Ruleset @define(kw_only=True) class PromptResponseRagModule(BaseResponseRagModule): - answer_token_offset: int = field(default=400) prompt_driver: BasePromptDriver = field() + answer_token_offset: int = field(default=400) + rulesets: list[Ruleset] = field(factory=list) + metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), ) @@ -51,8 +54,12 @@ def run(self, context: RagContext) -> BaseArtifact: raise ValueError("Prompt driver did not return a TextArtifact") def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: - return J2("engines/rag/modules/response/prompt/system.j2").render( - text_chunks=[c.to_text() for c in artifacts], - before_system_prompt="\n\n".join(context.before_query), 
- after_system_prompt="\n\n".join(context.after_query), - ) + params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} + + if len(self.rulesets) > 0: + params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) + + if self.metadata is not None: + params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) + + return J2("engines/rag/modules/response/prompt/system.j2").render(**params) diff --git a/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py b/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py deleted file mode 100644 index 81b8410ce..000000000 --- a/griptape/engines/rag/modules/response/rulesets_before_response_rag_module.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from attrs import define, field - -from griptape.engines.rag.modules import BaseBeforeResponseRagModule -from griptape.utils import J2 - -if TYPE_CHECKING: - from griptape.engines.rag import RagContext - from griptape.rules import Ruleset - - -@define -class RulesetsBeforeResponseRagModule(BaseBeforeResponseRagModule): - rulesets: list[Ruleset] = field(kw_only=True) - - def run(self, context: RagContext) -> RagContext: - context.before_query.append(J2("rulesets/rulesets.j2").render(rulesets=self.rulesets)) - - return context diff --git a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py index 370f69bff..35da0592b 100644 --- a/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py +++ b/griptape/engines/rag/modules/response/text_chunks_response_rag_module.py @@ -1,6 +1,6 @@ from attrs import define -from griptape.artifacts import ListArtifact, BaseArtifact +from griptape.artifacts import BaseArtifact, ListArtifact from griptape.engines.rag import RagContext from 
griptape.engines.rag.modules import BaseResponseRagModule diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index f318f4daf..8f095dfeb 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -22,9 +22,7 @@ from griptape.engines import CsvExtractionEngine, JsonExtractionEngine, PromptSummaryEngine from griptape.engines.rag import RagEngine from griptape.engines.rag.modules import ( - MetadataBeforeResponseRagModule, PromptResponseRagModule, - RulesetsBeforeResponseRagModule, VectorStoreRetrievalRagModule, ) from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage @@ -180,12 +178,8 @@ def default_rag_engine(self) -> RagEngine: retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=self.config.vector_store_driver)], ), response_stage=ResponseRagStage( - # before_response_modules=[ - # RulesetsBeforeResponseRagModule(rulesets=self.rulesets), - # MetadataBeforeResponseRagModule(), - # ], response_modules=[ - PromptResponseRagModule(prompt_driver=self.config.prompt_driver) + PromptResponseRagModule(prompt_driver=self.config.prompt_driver, rulesets=self.rulesets) ], ), ) diff --git a/griptape/templates/engines/rag/modules/response/prompt/system.j2 b/griptape/templates/engines/rag/modules/response/prompt/system.j2 index 1fa9d8c12..38b0297d5 100644 --- a/griptape/templates/engines/rag/modules/response/prompt/system.j2 +++ b/griptape/templates/engines/rag/modules/response/prompt/system.j2 @@ -1,6 +1,10 @@ You are an expert Q&A system. Always answer the question using the provided context information, and not prior knowledge. Always be truthful. Don't make up facts. You can answer questions by searching through text chunks. -{% if before_system_prompt %} -{{ before_system_prompt }} +{% if rulesets %} +{{ rulesets }} + +{% endif %} +{% if metadata %} +{{ metadata }} {% endif %} Use the following list of text chunks to respond. 
If there are no text chunks available or text chunks don't have relevant information respond with "I could not find an answer." diff --git a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py deleted file mode 100644 index 9519c8017..000000000 --- a/tests/unit/engines/rag/modules/generation/test_metadata_before_response_rag_module.py +++ /dev/null @@ -1,21 +0,0 @@ -from griptape.engines.rag import RagContext -from griptape.engines.rag.modules import MetadataBeforeResponseRagModule - - -class TestMetadataBeforeResponseRagModule: - def test_run(self): - module = MetadataBeforeResponseRagModule(name="foo") - - assert ( - "foo" in module.run(RagContext(module_configs={"foo": {"metadata": "foo"}}, query="test")).before_query[0] - ) - - def test_run_with_override(self): - module = MetadataBeforeResponseRagModule(name="foo", metadata="bar") - - assert ( - "bar" - in module.run( - RagContext(module_configs={"foo": {"query_params": {"metadata": "foo"}}}, query="test") - ).before_query[0] - ) diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index fb78ddc3e..7e8e86988 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -3,20 +3,25 @@ from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule +from griptape.rules import Ruleset, Rule from tests.mocks.mock_prompt_driver import MockPromptDriver class TestPromptResponseRagModule: @pytest.fixture() def module(self): - return PromptResponseRagModule(prompt_driver=MockPromptDriver()) + return PromptResponseRagModule( + prompt_driver=MockPromptDriver(), + 
rulesets=[Ruleset(name="test", rules=[Rule("*RULESET*")])], + metadata="*META*", + ) def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): system_message = module.default_system_template_generator( - RagContext(query="test", before_query=["*RULESET*", "*META*"], after_query=[]), + RagContext(query="test"), artifacts=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")], ) diff --git a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py deleted file mode 100644 index bc85cf266..000000000 --- a/tests/unit/engines/rag/modules/generation/test_rulesets_before_response_rag_module.py +++ /dev/null @@ -1,10 +0,0 @@ -from griptape.engines.rag import RagContext -from griptape.engines.rag.modules import RulesetsBeforeResponseRagModule -from griptape.rules import Rule, Ruleset - - -class TestRulesetsBeforeResponseRagModule: - def test_run(self): - module = RulesetsBeforeResponseRagModule(rulesets=[Ruleset(name="test ruleset", rules=[Rule("test rule")])]) - - assert "test rule" in module.run(RagContext(query="test")).before_query[0] diff --git a/tests/unit/tasks/test_rag_task.py b/tests/unit/tasks/test_rag_task.py index 590a4ce9f..dc8603a2a 100644 --- a/tests/unit/tasks/test_rag_task.py +++ b/tests/unit/tasks/test_rag_task.py @@ -15,9 +15,7 @@ def task(self): input="test", rag_engine=RagEngine( response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=MockPromptDriver()) - ] + response_modules=[PromptResponseRagModule(prompt_driver=MockPromptDriver())] ) ), ) diff --git a/tests/utils/defaults.py b/tests/utils/defaults.py index c0fa5220f..154f63ac4 100644 --- a/tests/utils/defaults.py +++ b/tests/utils/defaults.py @@ -34,9 +34,5 @@ def rag_engine(prompt_driver, vector_store_driver): retrieval_stage=RetrievalRagStage( 
retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver)] ), - response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=prompt_driver) - ] - ), + response_stage=ResponseRagStage(response_modules=[PromptResponseRagModule(prompt_driver=prompt_driver)]), ) From 481f95acb89e3396183b74290703c13dbbbf1957 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:10:47 -0600 Subject: [PATCH 04/13] Update docs --- docs/examples/src/query_webpage_astra_db_1.py | 4 +- docs/examples/src/talk_to_a_pdf_1.py | 4 +- docs/examples/src/talk_to_a_webpage_1.py | 4 +- .../engines/src/rag_engines_1.py | 67 ++++++++++++++----- .../structures/src/task_memory_6.py | 4 +- .../structures/src/tasks_9.py | 4 +- .../official-tools/src/rag_client_1.py | 4 +- 7 files changed, 68 insertions(+), 23 deletions(-) diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 8b0adfb9a..966c0edff 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -40,7 +40,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ) diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index ee309cba2..644af24e5 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -22,7 +22,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ) vector_store_tool = RagClient( diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 
76638113f..7d6a0042e 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -22,7 +22,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index 435a5d470..2d96c64e8 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,36 +1,69 @@ -from griptape.artifacts import ErrorArtifact -from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver -from griptape.engines.rag import RagContext, RagEngine -from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule -from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage +from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver, OpenAiChatPromptDriver +from griptape.engines.rag import RagEngine, RagContext +from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule, TranslateQueryRagModule +from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage, QueryRagStage from griptape.loaders import WebLoader +from griptape.rules import Ruleset, Rule prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) -vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) - -artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") -if isinstance(artifacts, ErrorArtifact): - raise ValueError(artifacts.value) +vector_store = LocalVectorStoreDriver( + embedding_driver=OpenAiEmbeddingDriver() +) -vector_store.upsert_text_artifacts({"griptape": artifacts}) 
+vector_store.upsert_text_artifacts({ + "griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"), +}) rag_engine = RagEngine( - query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="English")]), + query_stage=QueryRagStage( + query_modules=[ + TranslateQueryRagModule( + prompt_driver=prompt_driver, + language="english" + ) + ] + ), retrieval_stage=RetrievalRagStage( max_chunks=5, retrieval_modules=[ VectorStoreRetrievalRagModule( - name="MyAwesomeRetriever", vector_store_driver=vector_store, query_params={"top_n": 20} + name="MyAwesomeRetriever", + vector_store_driver=vector_store, + query_params={ + "top_n": 20 + } ) - ], + ] ), - response_stage=ResponseRagStage(response_module=PromptResponseRagModule(prompt_driver=prompt_driver)), + response_stage=ResponseRagStage( + response_modules=[ + PromptResponseRagModule( + prompt_driver=prompt_driver, + rulesets=[ + Ruleset( + name="persona", + rules=[ + Rule("Talk like a pirate") + ] + ) + ] + ) + ] + ) ) rag_context = RagContext( query="¿Qué ofrecen los servicios en la nube de Griptape?", - module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}}, + module_configs={ + "MyAwesomeRetriever": { + "query_params": { + "namespace": "griptape" + } + } + } ) -print(rag_engine.process(rag_context).output.to_text()) +print( + rag_engine.process(rag_context).outputs[0].to_text() +) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 70c3bde55..3bff961be 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -34,7 +34,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ), 
retrieval_rag_module_name="VectorStoreRetrievalRagModule", diff --git a/docs/griptape-framework/structures/src/tasks_9.py b/docs/griptape-framework/structures/src/tasks_9.py index 6fca66cc5..21b7c7b9a 100644 --- a/docs/griptape-framework/structures/src/tasks_9.py +++ b/docs/griptape-framework/structures/src/tasks_9.py @@ -29,7 +29,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ), ) diff --git a/docs/griptape-tools/official-tools/src/rag_client_1.py b/docs/griptape-tools/official-tools/src/rag_client_1.py index 8751e80b1..d558c4946 100644 --- a/docs/griptape-tools/official-tools/src/rag_client_1.py +++ b/docs/griptape-tools/official-tools/src/rag_client_1.py @@ -27,7 +27,9 @@ ] ), response_stage=ResponseRagStage( - response_module=PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + response_modules=[ + PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) + ] ), ), ) From 6b602a99adb205b83574638297e0149e9375080b Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:19:13 -0600 Subject: [PATCH 05/13] Update doc examples --- docs/examples/src/query_webpage_astra_db_1.py | 4 +- docs/examples/src/talk_to_a_pdf_1.py | 4 +- docs/examples/src/talk_to_a_webpage_1.py | 4 +- .../engines/src/rag_engines_1.py | 53 +++++-------------- .../structures/src/task_memory_6.py | 4 +- .../structures/src/tasks_9.py | 4 +- .../official-tools/src/rag_client_1.py | 4 +- 7 files changed, 19 insertions(+), 58 deletions(-) diff --git a/docs/examples/src/query_webpage_astra_db_1.py b/docs/examples/src/query_webpage_astra_db_1.py index 966c0edff..b5d2b0a01 100644 --- a/docs/examples/src/query_webpage_astra_db_1.py +++ b/docs/examples/src/query_webpage_astra_db_1.py @@ -40,9 +40,7 @@ ] ), response_stage=ResponseRagStage( - 
response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) diff --git a/docs/examples/src/talk_to_a_pdf_1.py b/docs/examples/src/talk_to_a_pdf_1.py index 644af24e5..2ac184a22 100644 --- a/docs/examples/src/talk_to_a_pdf_1.py +++ b/docs/examples/src/talk_to_a_pdf_1.py @@ -22,9 +22,7 @@ ] ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) vector_store_tool = RagClient( diff --git a/docs/examples/src/talk_to_a_webpage_1.py b/docs/examples/src/talk_to_a_webpage_1.py index 7d6a0042e..d24eb9427 100644 --- a/docs/examples/src/talk_to_a_webpage_1.py +++ b/docs/examples/src/talk_to_a_webpage_1.py @@ -22,9 +22,7 @@ ] ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index 2d96c64e8..ff4826123 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -7,63 +7,36 @@ prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) -vector_store = LocalVectorStoreDriver( - embedding_driver=OpenAiEmbeddingDriver() -) +vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) -vector_store.upsert_text_artifacts({ - "griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"), -}) +vector_store.upsert_text_artifacts( + { + "griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"), + } +) rag_engine = RagEngine( - 
query_stage=QueryRagStage( - query_modules=[ - TranslateQueryRagModule( - prompt_driver=prompt_driver, - language="english" - ) - ] - ), + query_stage=QueryRagStage(query_modules=[TranslateQueryRagModule(prompt_driver=prompt_driver, language="english")]), retrieval_stage=RetrievalRagStage( max_chunks=5, retrieval_modules=[ VectorStoreRetrievalRagModule( - name="MyAwesomeRetriever", - vector_store_driver=vector_store, - query_params={ - "top_n": 20 - } + name="MyAwesomeRetriever", vector_store_driver=vector_store, query_params={"top_n": 20} ) - ] + ], ), response_stage=ResponseRagStage( response_modules=[ PromptResponseRagModule( - prompt_driver=prompt_driver, - rulesets=[ - Ruleset( - name="persona", - rules=[ - Rule("Talk like a pirate") - ] - ) - ] + prompt_driver=prompt_driver, rulesets=[Ruleset(name="persona", rules=[Rule("Talk like a pirate")])] ) ] - ) + ), ) rag_context = RagContext( query="¿Qué ofrecen los servicios en la nube de Griptape?", - module_configs={ - "MyAwesomeRetriever": { - "query_params": { - "namespace": "griptape" - } - } - } + module_configs={"MyAwesomeRetriever": {"query_params": {"namespace": "griptape"}}}, ) -print( - rag_engine.process(rag_context).outputs[0].to_text() -) +print(rag_engine.process(rag_context).outputs[0].to_text()) diff --git a/docs/griptape-framework/structures/src/task_memory_6.py b/docs/griptape-framework/structures/src/task_memory_6.py index 3bff961be..7bbc5614a 100644 --- a/docs/griptape-framework/structures/src/task_memory_6.py +++ b/docs/griptape-framework/structures/src/task_memory_6.py @@ -34,9 +34,7 @@ ] ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), retrieval_rag_module_name="VectorStoreRetrievalRagModule", diff --git a/docs/griptape-framework/structures/src/tasks_9.py 
b/docs/griptape-framework/structures/src/tasks_9.py index 21b7c7b9a..1033f1b2f 100644 --- a/docs/griptape-framework/structures/src/tasks_9.py +++ b/docs/griptape-framework/structures/src/tasks_9.py @@ -29,9 +29,7 @@ ] ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), ) diff --git a/docs/griptape-tools/official-tools/src/rag_client_1.py b/docs/griptape-tools/official-tools/src/rag_client_1.py index d558c4946..01e71e253 100644 --- a/docs/griptape-tools/official-tools/src/rag_client_1.py +++ b/docs/griptape-tools/official-tools/src/rag_client_1.py @@ -27,9 +27,7 @@ ] ), response_stage=ResponseRagStage( - response_modules=[ - PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o")) - ] + response_modules=[PromptResponseRagModule(prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"))] ), ), ) From d2f9c3bbd436b82723cb3188183c8e380c159c34 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:19:27 -0600 Subject: [PATCH 06/13] Always generate unique RAG module names --- griptape/engines/rag/modules/base_rag_module.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/griptape/engines/rag/modules/base_rag_module.py b/griptape/engines/rag/modules/base_rag_module.py index 829a24565..01d6d1b1d 100644 --- a/griptape/engines/rag/modules/base_rag_module.py +++ b/griptape/engines/rag/modules/base_rag_module.py @@ -1,5 +1,6 @@ from __future__ import annotations +import uuid from abc import ABC from concurrent import futures from typing import TYPE_CHECKING, Any, Callable, Optional @@ -14,7 +15,9 @@ @define(kw_only=True) class BaseRagModule(ABC): - name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) + name: str = field( + default=Factory(lambda self: 
f"{self.__class__.__name__}-{uuid.uuid4().hex}", takes_self=True), kw_only=True + ) futures_executor_fn: Callable[[], futures.Executor] = field( default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), ) From ddc398c77ba90e179086be654b83e54f00564d4c Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:19:36 -0600 Subject: [PATCH 07/13] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7073557b2..b4fc8d6d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,11 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Method `try_find_task` to `Structure`. - `TranslateQueryRagModule` `RagEngine` module for translating input queries. - Global event bus, `griptape.events.event_bus`, for publishing and subscribing to events. +- Unique name generation for all `RagEngine` modules. ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. - **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules we adjusted accordingly. + **BREAKING**: Removed before and after response modules from `ResponseRagStage`. + **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. 
## [0.29.0] - 2024-07-30 From 7654f8601f209ca29ebd172be261984cf6c31307 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:21:28 -0600 Subject: [PATCH 08/13] Changelog update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4fc8d6d6..dd0d4d8d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. -- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules we adjusted accordingly. +- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. **BREAKING**: Removed before and after response modules from `ResponseRagStage`. **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. 
From e45c2b5f92e72b19cb5a0c08e50cd7b2b91fe6ae Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:41:32 -0600 Subject: [PATCH 09/13] Fix tests and formatting --- docs/griptape-framework/engines/src/rag_engines_1.py | 10 +++++----- .../rag/modules/response/prompt_response_rag_module.py | 2 +- .../generation/test_prompt_response_rag_module.py | 2 +- .../test_vector_store_retrieval_rag_module.py | 5 +++-- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index ff4826123..55c3b50f8 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,9 +1,9 @@ -from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver, OpenAiChatPromptDriver -from griptape.engines.rag import RagEngine, RagContext -from griptape.engines.rag.modules import VectorStoreRetrievalRagModule, PromptResponseRagModule, TranslateQueryRagModule -from griptape.engines.rag.stages import RetrievalRagStage, ResponseRagStage, QueryRagStage +from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver +from griptape.engines.rag import RagContext, RagEngine +from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule +from griptape.engines.rag.stages import QueryRagStage, ResponseRagStage, RetrievalRagStage from griptape.loaders import WebLoader -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 02f10f5db..0f53fec18 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ 
b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional, Any +from typing import TYPE_CHECKING, Any, Callable, Optional from attrs import Factory, define, field diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index 7e8e86988..cc8d35f0e 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -3,7 +3,7 @@ from griptape.artifacts import TextArtifact from griptape.engines.rag import RagContext from griptape.engines.rag.modules import PromptResponseRagModule -from griptape.rules import Ruleset, Rule +from griptape.rules import Rule, Ruleset from tests.mocks.mock_prompt_driver import MockPromptDriver diff --git a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py index 9fecc3c0e..70b259220 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py @@ -40,6 +40,7 @@ def test_run_with_namespace(self): def test_run_with_namespace_overrides(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) module = VectorStoreRetrievalRagModule( + name="TestModule", vector_store_driver=vector_store_driver, query_params={"namespace": "test"} ) @@ -48,13 +49,13 @@ def test_run_with_namespace_overrides(self): result1 = module.run( RagContext( - query="test", module_configs={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "empty"}}} + query="test", module_configs={"TestModule": {"query_params": {"namespace": "empty"}}} ) ) result2 = module.run( 
RagContext( - query="test", module_configs={"VectorStoreRetrievalRagModule": {"query_params": {"namespace": "test"}}} + query="test", module_configs={"TestModule": {"query_params": {"namespace": "test"}}} ) ) From 8fd947aecc0318c1bee66c877dae5d976a3f7899 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Fri, 9 Aug 2024 13:51:09 -0600 Subject: [PATCH 10/13] Fix formatting --- .../test_vector_store_retrieval_rag_module.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py index 70b259220..96a91280e 100644 --- a/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py +++ b/tests/unit/engines/rag/modules/retrieval/test_vector_store_retrieval_rag_module.py @@ -40,23 +40,18 @@ def test_run_with_namespace(self): def test_run_with_namespace_overrides(self): vector_store_driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) module = VectorStoreRetrievalRagModule( - name="TestModule", - vector_store_driver=vector_store_driver, query_params={"namespace": "test"} + name="TestModule", vector_store_driver=vector_store_driver, query_params={"namespace": "test"} ) vector_store_driver.upsert_text_artifact(TextArtifact("foobar1"), namespace="test") vector_store_driver.upsert_text_artifact(TextArtifact("foobar2"), namespace="test") result1 = module.run( - RagContext( - query="test", module_configs={"TestModule": {"query_params": {"namespace": "empty"}}} - ) + RagContext(query="test", module_configs={"TestModule": {"query_params": {"namespace": "empty"}}}) ) result2 = module.run( - RagContext( - query="test", module_configs={"TestModule": {"query_params": {"namespace": "test"}}} - ) + RagContext(query="test", module_configs={"TestModule": {"query_params": {"namespace": "test"}}}) ) assert len(result1) == 0 From 
f6e520137650e8f926a42b9cb3c3b567659a1057 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Mon, 12 Aug 2024 10:59:39 -0600 Subject: [PATCH 11/13] Fix docs --- docs/griptape-framework/engines/src/rag_engines_1.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/griptape-framework/engines/src/rag_engines_1.py b/docs/griptape-framework/engines/src/rag_engines_1.py index 55c3b50f8..c257cd4df 100644 --- a/docs/griptape-framework/engines/src/rag_engines_1.py +++ b/docs/griptape-framework/engines/src/rag_engines_1.py @@ -1,3 +1,4 @@ +from griptape.artifacts import ErrorArtifact from griptape.drivers import LocalVectorStoreDriver, OpenAiChatPromptDriver, OpenAiEmbeddingDriver from griptape.engines.rag import RagContext, RagEngine from griptape.engines.rag.modules import PromptResponseRagModule, TranslateQueryRagModule, VectorStoreRetrievalRagModule @@ -8,10 +9,14 @@ prompt_driver = OpenAiChatPromptDriver(model="gpt-4o", temperature=0) vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver()) +artifacts = WebLoader(max_tokens=500).load("https://www.griptape.ai") + +if isinstance(artifacts, ErrorArtifact): + raise Exception(artifacts.value) vector_store.upsert_text_artifacts( { - "griptape": WebLoader(max_tokens=500).load("https://www.griptape.ai"), + "griptape": artifacts, } ) From 4db3d30819b5bf7b9aa7e9cef38910f1e9c6f567 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Mon, 12 Aug 2024 11:00:15 -0600 Subject: [PATCH 12/13] Update CHANGELOG --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd0d4d8d7..e7c2f70be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,8 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Removed all uses of `EventPublisherMixin` in favor of `event_bus`. - **BREAKING**: Removed `EventPublisherMixin`. 
- **BREAKING**: `RagContext.output` was changed to `RagContext.outputs` to support multiple outputs. All relevant RAG modules were adjusted accordingly. - **BREAKING**: Removed before and after response modules from `ResponseRagStage`. - **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. +- **BREAKING**: Removed before and after response modules from `ResponseRagStage`. +- **BREAKING**: Moved ruleset and metadata ingestion from standalone modules to `PromptResponseRagModule`. - `BaseTask.add_parent/child` will now call `self.structure.add_task` if possible. ## [0.29.0] - 2024-07-30 From 8395cf80d81a7ec359dc675e48e649873cefd0f2 Mon Sep 17 00:00:00 2001 From: Vasily Vasinov Date: Mon, 12 Aug 2024 11:05:11 -0600 Subject: [PATCH 13/13] Integrate RuleMixin with PromptResponseRagModule --- .../rag/modules/response/prompt_response_rag_module.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index 0f53fec18..99bdf5f5e 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -6,20 +6,19 @@ from griptape.artifacts.text_artifact import TextArtifact from griptape.engines.rag.modules import BaseResponseRagModule +from griptape.mixins import RuleMixin from griptape.utils import J2 if TYPE_CHECKING: from griptape.artifacts import BaseArtifact from griptape.drivers import BasePromptDriver from griptape.engines.rag import RagContext - from griptape.rules import Ruleset @define(kw_only=True) -class PromptResponseRagModule(BaseResponseRagModule): +class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): prompt_driver: BasePromptDriver = field() answer_token_offset: int = field(default=400) - rulesets: list[Ruleset] = field(factory=list) metadata: Optional[str] 
= field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( default=Factory(lambda self: self.default_system_template_generator, takes_self=True), @@ -56,8 +55,8 @@ def run(self, context: RagContext) -> BaseArtifact: def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} - if len(self.rulesets) > 0: - params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) + if len(self.all_rulesets) > 0: + params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.all_rulesets) if self.metadata is not None: params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata)