-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ResponseRagStage
and PromptResponseRagModule
updates
#1056
Changes from 14 commits
044e870
8b97745
6f59f87
ebc6949
481f95a
6b602a9
d2f9c3b
ddc398c
7654f86
e45c2b5
8fd947a
d5a507f
f6e5201
4db3d30
8395cf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Callable | ||
from typing import TYPE_CHECKING, Any, Callable, Optional | ||
|
||
from attrs import Factory, define, field | ||
|
||
|
@@ -9,19 +9,23 @@ | |
from griptape.utils import J2 | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import BaseArtifact | ||
from griptape.drivers import BasePromptDriver | ||
from griptape.engines.rag import RagContext | ||
from griptape.rules import Ruleset | ||
|
||
|
||
@define(kw_only=True) | ||
class PromptResponseRagModule(BaseResponseRagModule): | ||
answer_token_offset: int = field(default=400) | ||
prompt_driver: BasePromptDriver = field() | ||
answer_token_offset: int = field(default=400) | ||
rulesets: list[Ruleset] = field(factory=list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should implement |
||
metadata: Optional[str] = field(default=None) | ||
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( | ||
default=Factory(lambda self: self.default_system_template_generator, takes_self=True), | ||
) | ||
|
||
def run(self, context: RagContext) -> RagContext: | ||
def run(self, context: RagContext) -> BaseArtifact: | ||
query = context.query | ||
tokenizer = self.prompt_driver.tokenizer | ||
included_chunks = [] | ||
|
@@ -45,15 +49,17 @@ def run(self, context: RagContext) -> RagContext: | |
output = self.prompt_driver.run(self.generate_prompt_stack(system_prompt, query)).to_artifact() | ||
|
||
if isinstance(output, TextArtifact): | ||
context.output = output | ||
return output | ||
else: | ||
raise ValueError("Prompt driver did not return a TextArtifact") | ||
|
||
return context | ||
|
||
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: | ||
return J2("engines/rag/modules/response/prompt/system.j2").render( | ||
text_chunks=[c.to_text() for c in artifacts], | ||
before_system_prompt="\n\n".join(context.before_query), | ||
after_system_prompt="\n\n".join(context.after_query), | ||
) | ||
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} | ||
|
||
if len(self.rulesets) > 0: | ||
params["rulesets"] = J2("rulesets/rulesets.j2").render(rulesets=self.rulesets) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After implementing |
||
|
||
if self.metadata is not None: | ||
params["metadata"] = J2("engines/rag/modules/response/metadata/system.j2").render(metadata=self.metadata) | ||
|
||
return J2("engines/rag/modules/response/prompt/system.j2").render(**params) |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,11 @@ | ||
from attrs import define | ||
|
||
from griptape.artifacts import ListArtifact | ||
from griptape.artifacts import BaseArtifact, ListArtifact | ||
from griptape.engines.rag import RagContext | ||
from griptape.engines.rag.modules import BaseResponseRagModule | ||
|
||
|
||
@define(kw_only=True) | ||
class TextChunksResponseRagModule(BaseResponseRagModule): | ||
def run(self, context: RagContext) -> RagContext: | ||
context.output = ListArtifact(context.text_chunks) | ||
|
||
return context | ||
def run(self, context: RagContext) -> BaseArtifact: | ||
return ListArtifact(context.text_chunks) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,49 +5,35 @@ | |
|
||
from attrs import define, field | ||
|
||
from griptape import utils | ||
from griptape.engines.rag.stages import BaseRagStage | ||
|
||
if TYPE_CHECKING: | ||
from griptape.engines.rag import RagContext | ||
from griptape.engines.rag.modules import ( | ||
BaseAfterResponseRagModule, | ||
BaseBeforeResponseRagModule, | ||
BaseRagModule, | ||
BaseResponseRagModule, | ||
) | ||
|
||
|
||
@define(kw_only=True) | ||
class ResponseRagStage(BaseRagStage): | ||
before_response_modules: list[BaseBeforeResponseRagModule] = field(factory=list) | ||
response_module: BaseResponseRagModule = field() | ||
after_response_modules: list[BaseAfterResponseRagModule] = field(factory=list) | ||
response_modules: list[BaseResponseRagModule] = field() | ||
|
||
@property | ||
def modules(self) -> list[BaseRagModule]: | ||
ms = [] | ||
|
||
ms.extend(self.before_response_modules) | ||
ms.extend(self.after_response_modules) | ||
|
||
if self.response_module is not None: | ||
ms.append(self.response_module) | ||
ms.extend(self.response_modules) | ||
|
||
return ms | ||
Comment on lines
24
to
29
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we still need this property? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We do, because we test for module name uniqueness by using this property. |
||
|
||
def run(self, context: RagContext) -> RagContext: | ||
logging.info("GenerationStage: running %s before modules sequentially", len(self.before_response_modules)) | ||
|
||
for generator in self.before_response_modules: | ||
context = generator.run(context) | ||
|
||
logging.info("GenerationStage: running generation module") | ||
|
||
context = self.response_module.run(context) | ||
logging.info("ResponseRagStage: running %s retrieval modules in parallel", len(self.response_modules)) | ||
|
||
logging.info("GenerationStage: running %s after modules sequentially", len(self.after_response_modules)) | ||
with self.futures_executor_fn() as executor: | ||
results = utils.execute_futures_list([executor.submit(r.run, context) for r in self.response_modules]) | ||
|
||
for generator in self.after_response_modules: | ||
context = generator.run(context) | ||
context.outputs = results | ||
|
||
return context |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does name uniqueness get us?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much easier to add multiple modules of the same type without having to explicitly define names. Maybe that's something we add to tools as well?