Skip to content

Commit

Permalink
Standardize callables
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 21, 2024
1 parent e6ffc90 commit 1f9357f
Show file tree
Hide file tree
Showing 60 changed files with 189 additions and 183 deletions.
4 changes: 2 additions & 2 deletions docs/examples/src/multi_agent_workflow_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
),
id="research",
driver=LocalStructureRunDriver(
structure_factory_fn=build_researcher,
structure_factory=build_researcher,
),
),
)
Expand All @@ -150,7 +150,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
{{ parent_outputs["research"] }}""",
),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda writer=writer: build_writer(
structure_factory=lambda writer=writer: build_writer(
role=writer["role"],
goal=writer["goal"],
backstory=writer["backstory"],
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/src/sql_drivers_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def get_snowflake_connection() -> SnowflakeConnection:
)


driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection)
driver = SnowflakeSqlDriver(connection_provider=get_snowflake_connection)

driver.execute_query("select * from people;")
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def build_joke_rewriter() -> Agent:
tasks=[
StructureRunTask(
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_teller,
structure_factory=build_joke_teller,
),
),
StructureRunTask(
("Rewrite this joke: {{ parent_output }}",),
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_rewriter,
structure_factory=build_joke_rewriter,
),
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StructureRunTask(
("Think of a question related to Retrieval Augmented Generation.",),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda: Agent(
structure_factory=lambda: Agent(
rules=[
Rule(
value="You are an expert in Retrieval Augmented Generation.",
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/engines/summary-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Summary engines are used to summarize text and collections of [TextArtifact](../

## Prompt

Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_generator), [user_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).
Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_provider](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_provider), [user_template_provider](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_provider), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).

Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string.

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/tasks_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def character_counter(task: CodeExecutionTask) -> BaseArtifact:

pipeline.add_tasks(
# take the first argument from the pipeline `run` method
CodeExecutionTask(run_fn=character_counter),
CodeExecutionTask(run_handler=character_counter),
# # take the output from the previous task and insert it into the prompt
PromptTask("{{args[0]}} using {{ parent_output }} characters"),
)
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/tasks_16.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def build_writer() -> Agent:
"""Perform a detailed examination of the newest developments in AI as of 2024.
Pinpoint major trends, breakthroughs, and their implications for various industries.""",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_researcher),
driver=LocalStructureRunDriver(structure_factory=build_researcher),
),
StructureRunTask(
(
Expand All @@ -122,7 +122,7 @@ def build_writer() -> Agent:
Keep the tone appealing and use simple language to make it less technical.""",
"{{parent_output}}",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_writer),
driver=LocalStructureRunDriver(structure_factory=build_writer),
),
],
)
Expand Down
14 changes: 8 additions & 6 deletions griptape/drivers/sql/snowflake_sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,26 @@

@define
class SnowflakeSqlDriver(BaseSqlDriver):
connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True)
connection_provider: Callable[[], SnowflakeConnection] = field(kw_only=True)
_engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

@connection_func.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None:
snowflake_connection = connection_func()
@connection_provider.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_connection_provider(
self, _: Attribute, connection_provider: Callable[[], SnowflakeConnection]
) -> None:
snowflake_connection = connection_provider()
snowflake = import_optional_dependency("snowflake")

if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
raise ValueError("The connection_func must return a SnowflakeConnection")
raise ValueError("The connection_provider must return a SnowflakeConnection")
if not snowflake_connection.schema or not snowflake_connection.database:
raise ValueError("Provide a schema and database for the Snowflake connection")

@lazy_property()
def engine(self) -> Engine:
return import_optional_dependency("sqlalchemy").create_engine(
"snowflake://not@used/db",
creator=self.connection_func,
creator=self.connection_provider,
)

def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure_factory_fn: Callable[[], Structure] = field(kw_only=True)
structure_factory: Callable[[], Structure] = field(kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
try:
os.environ.update(self.env)
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
structure = self.structure_factory().run(*[arg.value for arg in args])
finally:
os.environ.clear()
os.environ.update(old_env)

if structure_factory_fn.output_task.output is not None:
return structure_factory_fn.output_task.output
if structure.output_task.output is not None:
return structure.output_task.output
else:
return InfoArtifact("No output found in response")
4 changes: 2 additions & 2 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class LocalVectorStoreDriver(BaseVectorStoreDriver):
entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict)
persist_file: Optional[str] = field(default=None)
relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
relatedness_calculator: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -95,7 +95,7 @@ def query(
entries = self.entries

entries_and_relatednesses = [
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values())
(entry, self.relatedness_calculator(query_embedding, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)
Expand Down
14 changes: 7 additions & 7 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[dict], str] = field(
system_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
row_formatter: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

Expand All @@ -45,7 +45,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row])))))
rows.append(TextArtifact(self.row_formatter(dict(zip(column_names, [x.strip() for x in row])))))

return rows

Expand All @@ -57,11 +57,11 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.system_template_provider.render(
column_names=self.column_names,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.user_template_provider.render(
text=artifacts_text,
)

Expand All @@ -86,7 +86,7 @@ def _extract_rec(
return rows
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.user_template_provider.render(
text=chunks[0].value,
)

Expand Down
12 changes: 5 additions & 7 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)
system_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
user_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract_artifacts(
self,
Expand Down Expand Up @@ -54,11 +52,11 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.system_template_provider.render(
json_template_schema=json.dumps(self.template_schema),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.user_template_provider.render(
text=artifacts_text,
)

Expand All @@ -82,7 +80,7 @@ def _extract_rec(
return extractions
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.user_template_provider.render(
text=chunks[0].value,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
class TranslateQueryRagModule(BaseQueryRagModule):
prompt_driver: BasePromptDriver = field()
language: str = field()
generate_user_template: Callable[[str, str], str] = field(
default=Factory(lambda self: self.default_user_template_generator, takes_self=True),
user_template_provider: Callable[[str, str], str] = field(
default=Factory(lambda self: self.default_user_template_provider, takes_self=True),
)

def run(self, context: RagContext) -> RagContext:
user_prompt = self.generate_user_template(context.query, self.language)
user_prompt = self.user_template_provider(context.query, self.language)
output = self.prompt_driver.run(self.generate_prompt_stack(None, user_prompt)).to_artifact()

context.query = output.to_text()

return context

def default_user_template_generator(self, query: str, language: str) -> str:
def default_user_template_provider(self, query: str, language: str) -> str:
return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@define(kw_only=True)
class FootnotePromptResponseRagModule(PromptResponseRagModule):
def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
def default_system_template_provider(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
return J2("engines/rag/modules/response/footnote_prompt/system.j2").render(
text_chunk_artifacts=artifacts,
references=utils.references_from_artifacts(artifacts),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,28 @@ class PromptResponseRagModule(BaseResponseRagModule, RuleMixin):
prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
answer_token_offset: int = field(default=400)
metadata: Optional[str] = field(default=None)
generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field(
default=Factory(lambda self: self.default_system_template_generator, takes_self=True),
system_template_provider: Callable[[RagContext, list[TextArtifact]], str] = field(
default=Factory(lambda self: self.default_system_template_provider, takes_self=True),
)

def run(self, context: RagContext) -> BaseArtifact:
query = context.query
tokenizer = self.prompt_driver.tokenizer
included_chunks = []
system_prompt = self.generate_system_template(context, included_chunks)
system_prompt = self.system_template_provider(context, included_chunks)

for artifact in context.text_chunks:
included_chunks.append(artifact)

system_prompt = self.generate_system_template(context, included_chunks)
system_prompt = self.system_template_provider(context, included_chunks)
message_token_count = self.prompt_driver.tokenizer.count_tokens(
self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)),
)

if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens:
included_chunks.pop()

system_prompt = self.generate_system_template(context, included_chunks)
system_prompt = self.system_template_provider(context, included_chunks)

break

Expand All @@ -53,7 +53,7 @@ def run(self, context: RagContext) -> BaseArtifact:
else:
raise ValueError("Prompt driver did not return a TextArtifact")

def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
def default_system_template_provider(self, context: RagContext, artifacts: list[TextArtifact]) -> str:
params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]}

if len(self.rulesets) > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TextLoaderRetrievalRagModule(BaseRetrievalRagModule):
vector_store_driver: BaseVectorStoreDriver = field()
source: Any = field()
query_params: dict[str, Any] = field(factory=dict)
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
query_output_processor: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
default=Factory(lambda: lambda es: [e.to_artifact() for e in es]),
)

Expand All @@ -43,4 +43,4 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]:

self.vector_store_driver.upsert_text_artifacts({namespace: chunks})

return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))
return self.query_output_processor(self.vector_store_driver.query(context.query, **query_params))
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class VectorStoreRetrievalRagModule(BaseRetrievalRagModule):
default=Factory(lambda: Defaults.drivers_config.vector_store_driver)
)
query_params: dict[str, Any] = field(factory=dict)
process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
query_output_processor: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field(
default=Factory(lambda: lambda es: [e.to_artifact() for e in es]),
)

def run(self, context: RagContext) -> Sequence[TextArtifact]:
query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params"))

return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params))
return self.query_output_processor(self.vector_store_driver.query(context.query, **query_params))
Loading

0 comments on commit 1f9357f

Please sign in to comment.