From 5aae5bdb4cae58423d8a19959e569a17135c00f5 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 21 Oct 2024 14:42:37 -0700 Subject: [PATCH] Standardize callables --- CHANGELOG.md | 20 +++++++++++ docs/examples/src/multi_agent_workflow_1.py | 4 +-- .../drivers/src/sql_drivers_3.py | 2 +- .../drivers/src/structure_run_drivers_1.py | 4 +-- .../drivers/src/structure_run_drivers_2.py | 2 +- .../engines/summary-engines.md | 2 +- .../structures/src/tasks_10.py | 2 +- .../structures/src/tasks_16.py | 4 +-- griptape/drivers/sql/snowflake_sql_driver.py | 14 ++++---- .../local_structure_run_driver.py | 8 ++--- .../vector/local_vector_store_driver.py | 4 +-- .../extraction/csv_extraction_engine.py | 14 ++++---- .../extraction/json_extraction_engine.py | 12 +++---- .../query/translate_query_rag_module.py | 8 ++--- .../footnote_prompt_response_rag_module.py | 2 +- .../response/prompt_response_rag_module.py | 12 +++---- .../text_loader_retrieval_rag_module.py | 4 +-- .../vector_store_retrieval_rag_module.py | 4 +-- .../engines/summary/prompt_summary_engine.py | 10 +++--- griptape/loaders/csv_loader.py | 4 +-- griptape/loaders/sql_loader.py | 4 +-- .../structure/summary_conversation_memory.py | 8 ++--- griptape/mixins/futures_executor_mixin.py | 4 +-- griptape/rules/json_schema_rule.py | 6 ++-- griptape/tasks/code_execution_task.py | 4 +-- griptape/tasks/prompt_task.py | 8 ++--- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 20 +++++------ griptape/tools/vector_store/tool.py | 6 ++-- griptape/utils/chat.py | 34 ++++++++++--------- tests/integration/rules/test_rule.py | 2 +- .../tasks/test_csv_extraction_task.py | 2 +- .../tasks/test_json_extraction_task.py | 2 +- tests/integration/tasks/test_prompt_task.py | 4 ++- tests/integration/tasks/test_rag_task.py | 2 +- .../tasks/test_text_summary_task.py | 2 +- tests/integration/tasks/test_tool_task.py | 4 ++- tests/integration/tasks/test_toolkit_task.py | 2 +- .../integration/tools/test_calculator_tool.py | 2 +- .../tools/test_file_manager_tool.py | 2 +- .../tools/test_google_docs_tool.py | 2 +- .../tools/test_google_drive_tool.py | 2 +- tests/mocks/mock_event_listener_driver.py | 12 +++---- .../test_base_event_listener_driver.py | 8 ++--- ...est_hugging_face_pipeline_prompt_driver.py | 18 +++++----- .../drivers/sql/test_snowflake_sql_driver.py | 14 ++++---- .../test_local_structure_run_driver.py | 4 +-- ...est_footnote_prompt_response_rag_module.py | 2 +- .../test_prompt_response_rag_module.py | 2 +- .../mixins/test_futures_executor_mixin.py | 2 +- tests/unit/structures/test_pipeline.py | 4 +-- tests/unit/structures/test_workflow.py | 4 +-- tests/unit/tasks/test_code_execution_task.py | 6 ++-- tests/unit/tasks/test_structure_run_task.py | 4 +-- tests/unit/tasks/test_tool_task.py | 2 +- tests/unit/tasks/test_toolkit_task.py | 2 +- tests/unit/tools/test_structure_run_tool.py | 4 +-- .../test_variation_image_generation_tool.py | 22 ++++++------ tests/unit/tools/test_vector_store_tool.py | 4 +-- tests/unit/utils/test_chat.py | 16 ++++----- tests/utils/structure_tester.py | 2 +- 61 files changed, 209 insertions(+), 183 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d38286dcc7..a519cbac86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Moved `griptape.common.observable.observable` to `griptape.common.decorators.observable`. - **BREAKING**: `AnthropicDriversConfig` no longer bundles `VoyageAiEmbeddingDriver`. - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`. +- **BREAKING**: Renamed callables throughout the framework for consistency: + - Renamed `LocalStructureRunDriver.structure_factory_fn` to `LocalStructureRunDriver.structure_factory`. + - Renamed `SnowflakeSqlDriver.connection_func` to `SnowflakeSqlDriver.connection_provider`. + - Renamed `CsvLoader.formatter_fn` to `CsvLoader.row_formatter`. + - Renamed `SqlLoader.formatter_fn` to `SqlLoader.row_formatter`. + - Renamed `CsvExtractionEngine.system_template_generator` to `CsvExtractionEngine.system_template_provider`. + - Renamed `CsvExtractionEngine.user_template_generator` to `CsvExtractionEngine.user_template_provider`. + - Renamed `JsonExtractionEngine.system_template_generator` to `JsonExtractionEngine.system_template_provider`. + - Renamed `JsonExtractionEngine.user_template_generator` to `JsonExtractionEngine.user_template_provider`. + - Renamed `PromptResponseRagModule.generate_system_template` to `PromptResponseRagModule.system_template_provider`. + - Renamed `PromptTask.generate_system_template` to `PromptTask.system_template_provider`. + - Renamed `ToolkitTask.generate_assistant_subtask_template` to `ToolkitTask.assistant_subtask_template_provider`. + - Renamed `JsonSchemaRule.template_generator` to `JsonSchemaRule.template_provider`. + - Renamed `ToolkitTask.generate_user_subtask_template` to `ToolkitTask.user_subtask_template_provider`. + - Renamed `TextLoaderRetrievalRagModule.process_query_output_fn` to `TextLoaderRetrievalRagModule.query_output_processor`. + - Renamed `FuturesExecutorMixin.futures_executor_fn` to `FuturesExecutorMixin.futures_executor_factory`. + - Renamed `VectorStoreTool.process_query_output_fn` to `VectorStoreTool.query_output_processor`. + - Renamed `CodeExecutionTask.run_fn` to `CodeExecutionTask.runner`. + - Renamed `Chat.input_fn` to `Chat.input_handler`. + - Renamed `Chat.output_fn` to `Chat.output_handler`. - `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. - `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. - `Pipeline.context["parent_output"]` has changed type from `str | None` to `BaseArtifact | None`. diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index ad9436a55a..e313995f9f 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -133,7 +133,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: ), id="research", driver=LocalStructureRunDriver( - structure_factory_fn=build_researcher, + structure_factory=build_researcher, ), ), ) @@ -150,7 +150,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: {{ parent_outputs["research"] }}""", ), driver=LocalStructureRunDriver( - structure_factory_fn=lambda writer=writer: build_writer( + structure_factory=lambda writer=writer: build_writer( role=writer["role"], goal=writer["goal"], backstory=writer["backstory"], diff --git a/docs/griptape-framework/drivers/src/sql_drivers_3.py b/docs/griptape-framework/drivers/src/sql_drivers_3.py index 29ee4a818b..d33ec8cae4 100644 --- a/docs/griptape-framework/drivers/src/sql_drivers_3.py +++ b/docs/griptape-framework/drivers/src/sql_drivers_3.py @@ -17,6 +17,6 @@ def get_snowflake_connection() -> SnowflakeConnection: ) -driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection) +driver = SnowflakeSqlDriver(connection_provider=get_snowflake_connection) driver.execute_query("select * from people;") diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py index a29bfbedfe..3d251700cb 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py @@ -28,13 +28,13 @@ def build_joke_rewriter() -> Agent: tasks=[ StructureRunTask( driver=LocalStructureRunDriver( - structure_factory_fn=build_joke_teller, + structure_factory=build_joke_teller, ), ), StructureRunTask( ("Rewrite this joke: {{ parent_output }}",), driver=LocalStructureRunDriver( - structure_factory_fn=build_joke_rewriter, + structure_factory=build_joke_rewriter, ), ), ] diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py index 6103a65071..97556c126a 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py @@ -15,7 +15,7 @@ StructureRunTask( ("Think of a question related to Retrieval Augmented Generation.",), driver=LocalStructureRunDriver( - structure_factory_fn=lambda: Agent( + structure_factory=lambda: Agent( rules=[ Rule( value="You are an expert in Retrieval Augmented Generation.", diff --git a/docs/griptape-framework/engines/summary-engines.md b/docs/griptape-framework/engines/summary-engines.md index cfd99f3a87..f2f36d0b21 100644 --- a/docs/griptape-framework/engines/summary-engines.md +++ b/docs/griptape-framework/engines/summary-engines.md @@ -9,7 +9,7 @@ Summary engines are used to summarize text and collections of [TextArtifact](../ ## Prompt -Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_generator), [user_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker). +Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_provider](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_provider), [user_template_provider](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_provider), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker). Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string. diff --git a/docs/griptape-framework/structures/src/tasks_10.py b/docs/griptape-framework/structures/src/tasks_10.py index c94fa79192..b337832690 100644 --- a/docs/griptape-framework/structures/src/tasks_10.py +++ b/docs/griptape-framework/structures/src/tasks_10.py @@ -14,7 +14,7 @@ def character_counter(task: CodeExecutionTask) -> BaseArtifact: pipeline.add_tasks( # take the first argument from the pipeline `run` method - CodeExecutionTask(run_fn=character_counter), + CodeExecutionTask(run_handler=character_counter), # # take the output from the previous task and insert it into the prompt PromptTask("{{args[0]}} using {{ parent_output }} characters"), ) diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index 7496d2d9c1..64e40a915d 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -112,7 +112,7 @@ def build_writer() -> Agent: """Perform a detailed examination of the newest developments in AI as of 2024. Pinpoint major trends, breakthroughs, and their implications for various industries.""", ), - driver=LocalStructureRunDriver(structure_factory_fn=build_researcher), + driver=LocalStructureRunDriver(structure_factory=build_researcher), ), StructureRunTask( ( @@ -122,7 +122,7 @@ def build_writer() -> Agent: Keep the tone appealing and use simple language to make it less technical.""", "{{parent_output}}", ), - driver=LocalStructureRunDriver(structure_factory_fn=build_writer), + driver=LocalStructureRunDriver(structure_factory=build_writer), ), ], ) diff --git a/griptape/drivers/sql/snowflake_sql_driver.py b/griptape/drivers/sql/snowflake_sql_driver.py index d1b4310b5a..b7fc156b07 100644 --- a/griptape/drivers/sql/snowflake_sql_driver.py +++ b/griptape/drivers/sql/snowflake_sql_driver.py @@ -15,16 +15,18 @@ @define class SnowflakeSqlDriver(BaseSqlDriver): - connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True) + connection_provider: Callable[[], SnowflakeConnection] = field(kw_only=True) _engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False}) - @connection_func.validator # pyright: ignore[reportFunctionMemberAccess] - def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None: - snowflake_connection = connection_func() + @connection_provider.validator # pyright: ignore[reportFunctionMemberAccess] + def validate_connection_provider( + self, _: Attribute, connection_provider: Callable[[], SnowflakeConnection] + ) -> None: + snowflake_connection = connection_provider() snowflake = import_optional_dependency("snowflake") if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection): - raise ValueError("The connection_func must return a SnowflakeConnection") + raise ValueError("The connection_provider must return a SnowflakeConnection") if not snowflake_connection.schema or not snowflake_connection.database: raise ValueError("Provide a schema and database for the Snowflake connection") @@ -32,7 +34,7 @@ def validate_connection_func(self, _: Attribute, connection_func: Callable[[], S def engine(self) -> Engine: return import_optional_dependency("sqlalchemy").create_engine( "snowflake://not@used/db", - creator=self.connection_func, + creator=self.connection_provider, ) def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]: diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index c0049b29aa..2bfb257edf 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -14,18 +14,18 @@ @define class LocalStructureRunDriver(BaseStructureRunDriver): - structure_factory_fn: Callable[[], Structure] = field(kw_only=True) + structure_factory: Callable[[], Structure] = field(kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) - structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + structure = self.structure_factory().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env) - if structure_factory_fn.output_task.output is not None: - return structure_factory_fn.output_task.output + if structure.output_task.output is not None: + return structure.output_task.output else: return InfoArtifact("No output found in response") diff --git a/griptape/drivers/vector/local_vector_store_driver.py b/griptape/drivers/vector/local_vector_store_driver.py index 36203d5401..ab5ba68918 100644 --- a/griptape/drivers/vector/local_vector_store_driver.py +++ b/griptape/drivers/vector/local_vector_store_driver.py @@ -19,7 +19,7 @@ class LocalVectorStoreDriver(BaseVectorStoreDriver): entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict) persist_file: Optional[str] = field(default=None) - relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) + relatedness_calculator: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) def __attrs_post_init__(self) -> None: @@ -95,7 +95,7 @@ def query( entries = self.entries entries_and_relatednesses = [ - (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values()) + (entry, self.relatedness_calculator(query_embedding, entry.vector)) for entry in list(entries.values()) ] entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 7fb2a164bf..f250f3e7c9 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -18,9 +18,9 @@ @define class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(kw_only=True) - system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) - formatter_fn: Callable[[dict], str] = field( + system_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) + user_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) + row_formatter: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -45,7 +45,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row]))))) + rows.append(TextArtifact(self.row_formatter(dict(zip(column_names, [x.strip() for x in row]))))) return rows @@ -57,11 +57,11 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.system_template_provider.render( column_names=self.column_names, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render( + user_prompt = self.user_template_provider.render( text=artifacts_text, ) @@ -86,7 +86,7 @@ def _extract_rec( return rows else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render( + partial_text = self.user_template_provider.render( text=chunks[0].value, ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index c817efd5f7..376d213be1 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -21,10 +21,8 @@ class JsonExtractionEngine(BaseExtractionEngine): JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" template_schema: dict = field(kw_only=True) - system_template_generator: J2 = field( - default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True - ) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) + system_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True) + user_template_provider: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) def extract_artifacts( self, @@ -54,11 +52,11 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[JsonArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.system_template_provider.render( json_template_schema=json.dumps(self.template_schema), rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render( + user_prompt = self.user_template_provider.render( text=artifacts_text, ) @@ -82,7 +80,7 @@ def _extract_rec( return extractions else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render( + partial_text = self.user_template_provider.render( text=chunks[0].value, ) diff --git a/griptape/engines/rag/modules/query/translate_query_rag_module.py b/griptape/engines/rag/modules/query/translate_query_rag_module.py index f1f9ca0ece..5526d366d0 100644 --- a/griptape/engines/rag/modules/query/translate_query_rag_module.py +++ b/griptape/engines/rag/modules/query/translate_query_rag_module.py @@ -16,17 +16,17 @@ class TranslateQueryRagModule(BaseQueryRagModule): prompt_driver: BasePromptDriver = field() language: str = field() - generate_user_template: Callable[[str, str], str] = field( - default=Factory(lambda self: self.default_user_template_generator, takes_self=True), + user_template_provider: Callable[[str, str], str] = field( + default=Factory(lambda self: self.default_user_template_provider, takes_self=True), ) def run(self, context: RagContext) -> RagContext: - user_prompt = self.generate_user_template(context.query, self.language) + user_prompt = self.user_template_provider(context.query, self.language) output = self.prompt_driver.run(self.generate_prompt_stack(None, user_prompt)).to_artifact() context.query = output.to_text() return context - def default_user_template_generator(self, query: str, language: str) -> str: + def default_user_template_provider(self, query: str, language: str) -> str: return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language) diff --git a/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py b/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py index 3687d1942b..d21e9d3434 100644 --- a/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py @@ -15,7 +15,7 @@ @define(kw_only=True) class FootnotePromptResponseRagModule(PromptResponseRagModule): - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: + def default_system_template_provider(self, context: RagContext, artifacts: list[TextArtifact]) -> str: return J2("engines/rag/modules/response/footnote_prompt/system.j2").render( text_chunk_artifacts=artifacts, references=utils.references_from_artifacts(artifacts), diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index b62a0eba3c..a7413b515c 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -21,20 +21,20 @@ class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver)) answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) - generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( - default=Factory(lambda self: self.default_system_template_generator, takes_self=True), + system_template_provider: Callable[[RagContext, list[TextArtifact]], str] = field( + default=Factory(lambda self: self.default_system_template_provider, takes_self=True), ) def run(self, context: RagContext) -> BaseArtifact: query = context.query tokenizer = self.prompt_driver.tokenizer included_chunks = [] - system_prompt = self.generate_system_template(context, included_chunks) + system_prompt = self.system_template_provider(context, included_chunks) for artifact in context.text_chunks: included_chunks.append(artifact) - system_prompt = self.generate_system_template(context, included_chunks) + system_prompt = self.system_template_provider(context, included_chunks) message_token_count = self.prompt_driver.tokenizer.count_tokens( self.prompt_driver.prompt_stack_to_string(self.generate_prompt_stack(system_prompt, query)), ) @@ -42,7 +42,7 @@ def run(self, context: RagContext) -> BaseArtifact: if message_token_count + self.answer_token_offset >= tokenizer.max_input_tokens: included_chunks.pop() - system_prompt = self.generate_system_template(context, included_chunks) + system_prompt = self.system_template_provider(context, included_chunks) break @@ -53,7 +53,7 @@ def run(self, context: RagContext) -> BaseArtifact: else: raise ValueError("Prompt driver did not return a TextArtifact") - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: + def default_system_template_provider(self, context: RagContext, artifacts: list[TextArtifact]) -> str: params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} if len(self.rulesets) > 0: diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 0348a20946..7264c6be53 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -25,7 +25,7 @@ class TextLoaderRetrievalRagModule(BaseRetrievalRagModule): vector_store_driver: BaseVectorStoreDriver = field() source: Any = field() query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( + query_output_processor: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), ) @@ -43,4 +43,4 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: self.vector_store_driver.upsert_text_artifacts({namespace: chunks}) - return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) + return self.query_output_processor(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index ddff2549c3..1027548f38 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -22,11 +22,11 @@ class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): default=Factory(lambda: Defaults.drivers_config.vector_store_driver) ) query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( + query_output_processor: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), ) def run(self, context: RagContext) -> Sequence[TextArtifact]: query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params")) - return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) + return self.query_output_processor(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 3cc3dd4704..9724506799 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -20,8 +20,8 @@ class PromptSummaryEngine(BaseSummaryEngine): chunk_joiner: str = field(default="\n\n", kw_only=True) max_token_multiplier: float = field(default=0.5, kw_only=True) - system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) + system_template_provider: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) + user_template_provider: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True ) @@ -67,12 +67,12 @@ def summarize_artifacts_rec( artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.system_template_provider.render( summary=summary, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render(text=artifacts_text) + user_prompt = self.user_template_provider.render(text=artifacts_text) if ( self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) @@ -94,7 +94,7 @@ def summarize_artifacts_rec( else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render(text=chunks[0].value) + partial_text = self.user_template_provider.render(text=chunks[0].value) return self.summarize_artifacts_rec( chunks[1:], diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 4487d7aec5..ed683bc659 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -14,7 +14,7 @@ class CsvLoader(BaseFileLoader[ListArtifact[TextArtifact]]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - formatter_fn: Callable[[dict], str] = field( + row_formatter: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -22,5 +22,5 @@ def parse(self, data: bytes) -> ListArtifact[TextArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact( - [TextArtifact(self.formatter_fn(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] + [TextArtifact(self.row_formatter(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] ) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 0c6e8bdf94..36ae8a4635 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -12,7 +12,7 @@ @define class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact[TextArtifact]]): sql_driver: BaseSqlDriver = field(kw_only=True) - formatter_fn: Callable[[dict], str] = field( + row_formatter: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -20,4 +20,4 @@ def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: - return ListArtifact([TextArtifact(self.formatter_fn(row.cells)) for row in data]) + return ListArtifact([TextArtifact(self.row_formatter(row.cells)) for row in data]) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 055057d344..440a39e8a8 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -23,8 +23,8 @@ class SummaryConversationMemory(ConversationMemory): ) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) - summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) - summarize_conversation_template_generator: J2 = field( + summary_template_provider: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) + summarize_conversation_template_provider: J2 = field( default=Factory(lambda: J2("memory/conversation/summarize_conversation.j2")), kw_only=True, ) @@ -32,7 +32,7 @@ class SummaryConversationMemory(ConversationMemory): def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: - stack.add_user_message(self.summary_template_generator.render(summary=self.summary)) + stack.add_user_message(self.summary_template_provider.render(summary=self.summary)) for r in self.unsummarized_runs(last_n): stack.add_user_message(r.input) @@ -66,7 +66,7 @@ def try_add_run(self, run: Run) -> None: def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | None: try: if len(runs) > 0: - summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) + summary = self.summarize_conversation_template_provider.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( prompt_stack=PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]), ).to_text() diff --git a/griptape/mixins/futures_executor_mixin.py b/griptape/mixins/futures_executor_mixin.py index 84a3b6f25e..f97ade883b 100644 --- a/griptape/mixins/futures_executor_mixin.py +++ b/griptape/mixins/futures_executor_mixin.py @@ -9,12 +9,12 @@ @define(slots=False, kw_only=True) class FuturesExecutorMixin(ABC): - futures_executor_fn: Callable[[], futures.Executor] = field( + futures_executor_factory: Callable[[], futures.Executor] = field( default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), ) futures_executor: Optional[futures.Executor] = field( - default=Factory(lambda self: self.futures_executor_fn(), takes_self=True) + default=Factory(lambda self: self.futures_executor_factory(), takes_self=True) ) def __del__(self) -> None: diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py index 1bd4184646..f311cc022e 100644 --- a/griptape/rules/json_schema_rule.py +++ b/griptape/rules/json_schema_rule.py @@ -2,7 +2,7 @@ import json -from attrs import define, field +from attrs import Factory, define, field from griptape.rules import BaseRule from griptape.utils import J2 @@ -11,7 +11,7 @@ @define(frozen=True) class JsonSchemaRule(BaseRule): value: dict = field() - template_generator: J2 = field(default=J2("rules/json_schema.j2")) + template_provider: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2"))) def to_text(self) -> str: - return self.template_generator.render(json_schema=json.dumps(self.value)) + return self.template_provider.render(json_schema=json.dumps(self.value)) diff --git a/griptape/tasks/code_execution_task.py b/griptape/tasks/code_execution_task.py index d627382fd0..855ad848f4 100644 --- a/griptape/tasks/code_execution_task.py +++ b/griptape/tasks/code_execution_task.py @@ -12,7 +12,7 @@ @define class CodeExecutionTask(BaseTextInputTask): - run_fn: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True) + run_handler: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True) def run(self) -> BaseArtifact: - return self.run_fn(self) + return self.run_handler(self) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index ed3ffa4524..dd199bc56b 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -24,8 +24,8 @@ class PromptTask(RuleMixin, BaseTask): prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True ) - generate_system_template: Callable[[PromptTask], str] = field( - default=Factory(lambda self: self.default_system_template_generator, takes_self=True), + system_template_provider: Callable[[PromptTask], str] = field( + default=Factory(lambda self: self.default_system_template_provider, takes_self=True), kw_only=True, ) _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( @@ -64,7 +64,7 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack() memory = self.structure.conversation_memory if self.structure is not None else None - system_template = self.generate_system_template(self) + system_template = self.system_template_provider(self) if system_template: stack.add_system_message(system_template) @@ -79,7 +79,7 @@ def prompt_stack(self) -> PromptStack: return stack - def default_system_template_generator(self, _: PromptTask) -> str: + def default_system_template_provider(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), ) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index a9a36ddb6a..a5c791fb45 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -49,7 +49,7 @@ def preprocess(self, structure: Structure) -> ToolTask: return self - def default_system_template_generator(self, _: PromptTask) -> str: + def default_system_template_provider(self, _: PromptTask) -> str: return J2("tasks/tool_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_schema=utils.minify_json(json.dumps(self.tool.schema())), diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index e2b0e23a5e..fba2f8c554 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -31,12 +31,12 @@ class ToolkitTask(PromptTask, ActionsSubtaskOriginMixin): max_subtasks: int = field(default=DEFAULT_MAX_STEPS, kw_only=True) task_memory: Optional[TaskMemory] = field(default=None, kw_only=True) subtasks: list[ActionsSubtask] = field(factory=list) - generate_assistant_subtask_template: Callable[[ActionsSubtask], str] = field( - default=Factory(lambda self: self.default_assistant_subtask_template_generator, takes_self=True), + assistant_subtask_template_provider: Callable[[ActionsSubtask], str] = field( + default=Factory(lambda self: self.default_assistant_subtask_template_provider, takes_self=True), kw_only=True, ) - generate_user_subtask_template: Callable[[ActionsSubtask], str] = field( - default=Factory(lambda self: self.default_user_subtask_template_generator, takes_self=True), + user_subtask_template_provider: Callable[[ActionsSubtask], str] = field( + default=Factory(lambda self: self.default_user_subtask_template_provider, takes_self=True), kw_only=True, ) response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True) @@ -70,7 +70,7 @@ def prompt_stack(self) -> PromptStack: stack = PromptStack(tools=self.tools) memory = self.structure.conversation_memory if self.structure is not None else None - stack.add_system_message(self.generate_system_template(self)) + stack.add_system_message(self.system_template_provider(self)) stack.add_user_message(self.input) @@ -110,8 +110,8 @@ def prompt_stack(self) -> PromptStack: ), ) else: - stack.add_assistant_message(self.generate_assistant_subtask_template(s)) - stack.add_user_message(self.generate_user_subtask_template(s)) + stack.add_assistant_message(self.assistant_subtask_template_provider(s)) + stack.add_user_message(self.user_subtask_template_provider(s)) if memory is not None: # inserting at index 1 to place memory right after system prompt @@ -127,7 +127,7 @@ def preprocess(self, structure: Structure) -> ToolkitTask: return self - def default_system_template_generator(self, _: PromptTask) -> str: + def default_system_template_provider(self, _: PromptTask) -> str: schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -140,13 +140,13 @@ def default_system_template_generator(self, _: PromptTask) -> str: stop_sequence=self.response_stop_sequence, ) - def default_assistant_subtask_template_generator(self, subtask: ActionsSubtask) -> str: + def default_assistant_subtask_template_provider(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/assistant_subtask.j2").render( stop_sequence=self.response_stop_sequence, subtask=subtask, ) - def default_user_subtask_template_generator(self, subtask: ActionsSubtask) -> str: + def default_user_subtask_template_provider(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/user_subtask.j2").render( stop_sequence=self.response_stop_sequence, subtask=subtask, diff --git a/griptape/tools/vector_store/tool.py b/griptape/tools/vector_store/tool.py index 71902b1c7b..d9c98ca4d7 100644 --- a/griptape/tools/vector_store/tool.py +++ b/griptape/tools/vector_store/tool.py @@ -21,7 +21,7 @@ class VectorStoreTool(BaseTool): description: LLM-friendly vector DB description. vector_store_driver: `BaseVectorStoreDriver`. query_params: Optional dictionary of vector store driver query parameters. - process_query_output_fn: Optional lambda for processing vector store driver query output `Entry`s. + query_output_processor: Optional lambda for processing vector store driver query output `Entry`s. """ DEFAULT_TOP_N = 5 @@ -29,7 +29,7 @@ class VectorStoreTool(BaseTool): description: str = field() vector_store_driver: BaseVectorStoreDriver = field() query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( + query_output_processor: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( default=Factory(lambda: lambda es: ListArtifact([e.to_artifact() for e in es])), ) @@ -50,6 +50,6 @@ def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] try: - return self.process_query_output_fn(self.vector_store_driver.query(query, **self.query_params)) + return self.query_output_processor(self.vector_store_driver.query(query, **self.query_params)) except Exception as e: return ErrorArtifact(f"error querying vector store: {e}") diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 802fd809dd..295a3dba01 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,8 +25,8 @@ class Chat: intro_text: Text to display when the chat starts. prompt_prefix: Prefix for the user's input. response_prefix: Prefix for the assistant's response. - input_fn: Function to get the user's input. - output_fn: Function to output text. Takes a `text` argument for the text to output. + input_handler: Function to get the user's input. + output_handler: Function to output text. Takes a `text` argument for the text to output. Also takes a `stream` argument which will be set to True when streaming Prompt Tasks are present. """ @@ -40,19 +40,19 @@ class ChatPrompt(Prompt): intro_text: Optional[str] = field(default=None, kw_only=True) prompt_prefix: str = field(default="User: ", kw_only=True) response_prefix: str = field(default="Assistant: ", kw_only=True) - input_fn: Callable[[str], str] = field( - default=Factory(lambda self: self.default_input_fn, takes_self=True), kw_only=True + input_handler: Callable[[str], str] = field( + default=Factory(lambda self: self.default_input_handler, takes_self=True), kw_only=True ) - output_fn: Callable[..., None] = field( - default=Factory(lambda self: self.default_output_fn, takes_self=True), + output_handler: Callable[..., None] = field( + default=Factory(lambda self: self.default_output_handler, takes_self=True), kw_only=True, ) logger_level: int = field(default=logging.ERROR, kw_only=True) - def default_input_fn(self, prompt_prefix: str) -> str: + def default_input_handler(self, prompt_prefix: str) -> str: return Chat.ChatPrompt.ask(prompt_prefix) - def default_output_fn(self, text: str, *, stream: bool = False) -> None: + def default_output_handler(self, text: str, *, stream: bool = False) -> None: if stream: rprint(text, end="", flush=True) else: @@ -66,26 +66,28 @@ def start(self) -> None: logging.getLogger(Defaults.logging_config.logger_name).setLevel(self.logger_level) if self.intro_text: - self.output_fn(self.intro_text) + self.output_handler(self.intro_text) has_streaming_tasks = self._has_streaming_tasks() while True: - question = self.input_fn(self.prompt_prefix) + question = self.input_handler(self.prompt_prefix) if question.lower() in self.exit_keywords: - self.output_fn(self.exiting_text) + self.output_handler(self.exiting_text) break if has_streaming_tasks: - self.output_fn(self.processing_text) + self.output_handler(self.processing_text) stream = Stream(self.structure).run(question) first_chunk = next(stream) - self.output_fn(self.response_prefix + first_chunk.value, stream=True) + self.output_handler(self.response_prefix + first_chunk.value, stream=True) for chunk in stream: - self.output_fn(chunk.value, stream=True) + self.output_handler(chunk.value, stream=True) else: - self.output_fn(self.processing_text) - self.output_fn(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}") + self.output_handler(self.processing_text) + self.output_handler( + f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}" + ) # Restore the original logger level logging.getLogger(Defaults.logging_config.logger_name).setLevel(old_logger_level) diff --git a/tests/integration/rules/test_rule.py b/tests/integration/rules/test_rule.py index a62263c576..91c4276533 100644 --- a/tests/integration/rules/test_rule.py +++ b/tests/integration/rules/test_rule.py @@ -5,7 +5,7 @@ class TestRule: @pytest.fixture( - autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.generate_prompt_driver_id ) def structure_tester(self, request): from griptape.rules import Rule diff --git a/tests/integration/tasks/test_csv_extraction_task.py b/tests/integration/tasks/test_csv_extraction_task.py index db58b96158..3e2186faeb 100644 --- a/tests/integration/tasks/test_csv_extraction_task.py +++ b/tests/integration/tasks/test_csv_extraction_task.py @@ -7,7 +7,7 @@ class TestCsvExtractionTask: @pytest.fixture( autouse=True, params=StructureTester.CSV_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.engines import CsvExtractionEngine diff --git a/tests/integration/tasks/test_json_extraction_task.py b/tests/integration/tasks/test_json_extraction_task.py index 115f805dab..e13fa7aa56 100644 --- a/tests/integration/tasks/test_json_extraction_task.py +++ b/tests/integration/tasks/test_json_extraction_task.py @@ -7,7 +7,7 @@ class TestJsonExtractionTask: @pytest.fixture( autouse=True, params=StructureTester.JSON_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from schema import Schema diff --git a/tests/integration/tasks/test_prompt_task.py b/tests/integration/tasks/test_prompt_task.py index 1d223b4ca8..95106a9a08 100644 --- a/tests/integration/tasks/test_prompt_task.py +++ b/tests/integration/tasks/test_prompt_task.py @@ -5,7 +5,9 @@ class TestPromptTask: @pytest.fixture( - autouse=True, params=StructureTester.PROMPT_TASK_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, + params=StructureTester.PROMPT_TASK_CAPABLE_PROMPT_DRIVERS, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tasks/test_rag_task.py b/tests/integration/tasks/test_rag_task.py index ce3a9140de..255e608f3f 100644 --- a/tests/integration/tasks/test_rag_task.py +++ b/tests/integration/tasks/test_rag_task.py @@ -9,7 +9,7 @@ class TestRagTask: @pytest.fixture( autouse=True, params=StructureTester.TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.artifacts import TextArtifact diff --git a/tests/integration/tasks/test_text_summary_task.py b/tests/integration/tasks/test_text_summary_task.py index ff6597ba00..811ec39f68 100644 --- a/tests/integration/tasks/test_text_summary_task.py +++ b/tests/integration/tasks/test_text_summary_task.py @@ -7,7 +7,7 @@ class TestTextSummaryTask: @pytest.fixture( autouse=True, params=StructureTester.TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.engines.summary.prompt_summary_engine import PromptSummaryEngine diff --git a/tests/integration/tasks/test_tool_task.py b/tests/integration/tasks/test_tool_task.py index 426dde995a..712e23d26e 100644 --- a/tests/integration/tasks/test_tool_task.py +++ b/tests/integration/tasks/test_tool_task.py @@ -5,7 +5,9 @@ class TestToolTask: @pytest.fixture( - autouse=True, params=StructureTester.TOOL_TASK_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, + params=StructureTester.TOOL_TASK_CAPABLE_PROMPT_DRIVERS, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index 50b4f2a977..7593c5391b 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -7,7 +7,7 @@ class TestToolkitTask: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): import os diff --git a/tests/integration/tools/test_calculator_tool.py b/tests/integration/tools/test_calculator_tool.py index c209a9a2cb..634b848033 100644 --- a/tests/integration/tools/test_calculator_tool.py +++ b/tests/integration/tools/test_calculator_tool.py @@ -7,7 +7,7 @@ class TestCalculator: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_file_manager_tool.py b/tests/integration/tools/test_file_manager_tool.py index 4b52991755..ce6b331c22 100644 --- a/tests/integration/tools/test_file_manager_tool.py +++ b/tests/integration/tools/test_file_manager_tool.py @@ -7,7 +7,7 @@ class TestFileManager: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_google_docs_tool.py b/tests/integration/tools/test_google_docs_tool.py index 7c8828dd30..e977e55237 100644 --- a/tests/integration/tools/test_google_docs_tool.py +++ b/tests/integration/tools/test_google_docs_tool.py @@ -9,7 +9,7 @@ class TestGoogleDocsTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_google_drive_tool.py b/tests/integration/tools/test_google_drive_tool.py index 7fd8b90470..fdd9fde892 100644 --- a/tests/integration/tools/test_google_drive_tool.py +++ b/tests/integration/tools/test_google_drive_tool.py @@ -9,7 +9,7 @@ class TestGoogleDriveTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index e56d35e90d..a7805819ee 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -9,13 +9,13 @@ @define class MockEventListenerDriver(BaseEventListenerDriver): - try_publish_event_payload_fn: Optional[Callable[[dict], None]] = field(default=None, kw_only=True) - try_publish_event_payload_batch_fn: Optional[Callable[[list[dict]], None]] = field(default=None, kw_only=True) + publish_event_payload_callback: Optional[Callable[[dict], None]] = field(default=None, kw_only=True) + publish_event_payload_batch_callback: Optional[Callable[[list[dict]], None]] = field(default=None, kw_only=True) def try_publish_event_payload(self, event_payload: dict) -> None: - if self.try_publish_event_payload_fn is not None: - self.try_publish_event_payload_fn(event_payload) + if self.publish_event_payload_callback is not None: + self.publish_event_payload_callback(event_payload) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: - if self.try_publish_event_payload_batch_fn is not None: - self.try_publish_event_payload_batch_fn(event_payload_batch) + if self.publish_event_payload_batch_callback is not None: + self.publish_event_payload_batch_callback(event_payload_batch) diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 365fb21996..9bed3ee610 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -59,7 +59,7 @@ def test__safe_publish_event_payload(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=False, - try_publish_event_payload_fn=mock_fn, + publish_event_payload_callback=mock_fn, ) mock_event_payload = MockEvent().to_dict() @@ -71,7 +71,7 @@ def test__safe_publish_event_payload_batch(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=True, - try_publish_event_payload_batch_fn=mock_fn, + publish_event_payload_batch_callback=mock_fn, ) mock_event_payloads = [MockEvent().to_dict() for _ in range(0, 3)] @@ -83,7 +83,7 @@ def test__safe_publish_event_payload_error(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=False, - try_publish_event_payload_fn=mock_fn, + publish_event_payload_callback=mock_fn, max_attempts=2, max_retry_delay=0.1, min_retry_delay=0.1, @@ -100,7 +100,7 @@ def test__safe_publish_event_payload_batch_error(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=True, - try_publish_event_payload_batch_fn=mock_fn, + publish_event_payload_batch_callback=mock_fn, max_attempts=2, max_retry_delay=0.1, min_retry_delay=0.1, diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index 0ece6c9767..4c24e90a13 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -10,11 +10,11 @@ def mock_pipeline(self, mocker): return mocker.patch("transformers.pipeline") @pytest.fixture(autouse=True) - def mock_generator(self, mock_pipeline): - mock_generator = mock_pipeline.return_value - mock_generator.task = "text-generation" - mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}] - return mock_generator + def mock_provider(self, mock_pipeline): + mock_provider = mock_pipeline.return_value + mock_provider.task = "text-generation" + mock_provider.return_value = [{"generated_text": [{"content": "model-output"}]}] + return mock_provider @pytest.fixture(autouse=True) def mock_autotokenizer(self, mocker): @@ -59,10 +59,10 @@ def test_try_stream(self, prompt_stack): assert e.value.args[0] == "streaming is not supported" @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack): + def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_provider, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) - mock_generator.return_value = choices + mock_provider.return_value = choices # When with pytest.raises(Exception) as e: @@ -71,10 +71,10 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener # Then assert e.value.args[0] == "completion with more than one choice is not supported yet" - def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack): + def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) - mock_generator.return_value = {} + mock_provider.return_value = {} # When with pytest.raises(Exception) as e: diff --git a/tests/unit/drivers/sql/test_snowflake_sql_driver.py b/tests/unit/drivers/sql/test_snowflake_sql_driver.py index a758bb3a23..c918bb13fb 100644 --- a/tests/unit/drivers/sql/test_snowflake_sql_driver.py +++ b/tests/unit/drivers/sql/test_snowflake_sql_driver.py @@ -63,32 +63,34 @@ def driver(self, mock_snowflake_engine, mock_snowflake_connection): def get_connection(): return mock_snowflake_connection - return SnowflakeSqlDriver(connection_func=get_connection, engine=mock_snowflake_engine) + return SnowflakeSqlDriver(connection_provider=get_connection, engine=mock_snowflake_engine) - def test_connection_function_wrong_return_type(self): + def test_connection_providertion_wrong_return_type(self): def get_connection() -> Any: return object with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(connection_provider=get_connection) def test_connection_validation_no_schema(self, mock_snowflake_connection_no_schema): def get_connection(): return mock_snowflake_connection_no_schema with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(connection_provider=get_connection) def test_connection_validation_no_database(self, mock_snowflake_connection_no_database): def get_connection(): return mock_snowflake_connection_no_database with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(connection_provider=get_connection) def test_engine_url_validation_wrong_engine(self, mock_snowflake_connection): with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=mock_snowflake_connection, engine=create_engine("sqlite:///:memory:")) + SnowflakeSqlDriver( + connection_provider=mock_snowflake_connection, engine=create_engine("sqlite:///:memory:") + ) def test_execute_query(self, driver): assert driver.execute_query("query") == [ diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 2dd68e24e3..c4edbfeb3b 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -9,7 +9,7 @@ class TestLocalStructureRunDriver: def test_run(self): pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent()) + driver = LocalStructureRunDriver(structure_factory=lambda: Agent()) task = StructureRunTask(driver=driver) @@ -22,7 +22,7 @@ def test_run_with_env(self, mock_config): mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) + driver = LocalStructureRunDriver(structure_factory=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) pipeline.add_task(task) diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 430f67ef92..61611904ba 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -15,7 +15,7 @@ def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): - system_message = module.default_system_template_generator( + system_message = module.default_system_template_provider( RagContext(query="test", before_query=["*RULESET*", "*META*"]), artifacts=[ TextArtifact("*TEXT SEGMENT 1*", reference=Reference(title="source 1")), diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index cc8d35f0e1..78b2aeace7 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -20,7 +20,7 @@ def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): - system_message = module.default_system_template_generator( + system_message = module.default_system_template_provider( RagContext(query="test"), artifacts=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")], ) diff --git a/tests/unit/mixins/test_futures_executor_mixin.py b/tests/unit/mixins/test_futures_executor_mixin.py index 3be336687b..d31c7bd63d 100644 --- a/tests/unit/mixins/test_futures_executor_mixin.py +++ b/tests/unit/mixins/test_futures_executor_mixin.py @@ -7,4 +7,4 @@ class TestFuturesExecutorMixin: def test_futures_executor(self): executor = futures.ThreadPoolExecutor() - assert MockFuturesExecutor(futures_executor_fn=lambda: executor).futures_executor == executor + assert MockFuturesExecutor(futures_executor_factory=lambda: executor).futures_executor == executor diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index f86c6330a1..7856976a38 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -18,14 +18,14 @@ def fn(task): time.sleep(2) return TextArtifact("done") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(run_handler=fn) @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(run_handler=fn) def test_init(self): pipeline = Pipeline(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index d3fb17906e..2f15593d53 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -17,14 +17,14 @@ def fn(task): time.sleep(2) return TextArtifact("done") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(run_handler=fn) @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(run_handler=fn) def test_init(self): workflow = Workflow(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index f0eb37ede0..c385fb9967 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -21,7 +21,7 @@ def deliberate_exception(task: CodeExecutionTask) -> BaseArtifact: class TestCodeExecutionTask: def test_hello_world_fn(self): - task = CodeExecutionTask(run_fn=hello_world) + task = CodeExecutionTask(run_handler=hello_world) assert task.run().value == "Hello World!" @@ -29,13 +29,13 @@ def test_hello_world_fn(self): # Overriding the input because we are implementing the task not the Pipeline def test_noop_fn(self): pipeline = Pipeline() - task = CodeExecutionTask("No Op", run_fn=non_outputting) + task = CodeExecutionTask("No Op", run_handler=non_outputting) pipeline.add_task(task) temp = task.run() assert temp.value == "No Op" def test_error_fn(self): - task = CodeExecutionTask(run_fn=deliberate_exception) + task = CodeExecutionTask(run_handler=deliberate_exception) with pytest.raises(ValueError): task.run() diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 2973c4a050..50c8e1a788 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -10,7 +10,7 @@ def test_run_single_input(self, mock_config): agent = Agent() mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + driver = LocalStructureRunDriver(structure_factory=lambda: agent) task = StructureRunTask(driver=driver) @@ -23,7 +23,7 @@ def test_run_multiple_inputs(self, mock_config): agent = Agent() mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + driver = LocalStructureRunDriver(structure_factory=lambda: agent) task = StructureRunTask(input=["foo", "bar", "baz"], driver=driver) diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index cb2a6b3415..88e869e332 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -226,7 +226,7 @@ def test_meta_memory(self): agent.add_task(task) - system_template = task.generate_system_template(task) + system_template = task.system_template_provider(task) assert "You have access to additional contextual information" in system_template diff --git a/tests/unit/tasks/test_toolkit_task.py b/tests/unit/tasks/test_toolkit_task.py index 24c715423c..d73e34a21f 100644 --- a/tests/unit/tasks/test_toolkit_task.py +++ b/tests/unit/tasks/test_toolkit_task.py @@ -365,7 +365,7 @@ def test_meta_memory(self): agent.add_task(task) - system_template = task.generate_system_template(PromptTask()) + system_template = task.system_template_provider(PromptTask()) assert "You have access to additional contextual information" in system_template diff --git a/tests/unit/tools/test_structure_run_tool.py b/tests/unit/tools/test_structure_run_tool.py index f62cdeea73..e5a1f418fd 100644 --- a/tests/unit/tools/test_structure_run_tool.py +++ b/tests/unit/tools/test_structure_run_tool.py @@ -10,9 +10,7 @@ class TestStructureRunTool: def client(self): agent = Agent() - return StructureRunTool( - description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) - ) + return StructureRunTool(description="foo bar", driver=LocalStructureRunDriver(structure_factory=lambda: agent)) def test_run_structure(self, client): assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output" diff --git a/tests/unit/tools/test_variation_image_generation_tool.py b/tests/unit/tools/test_variation_image_generation_tool.py index 5fd3513c1c..b7c55be40e 100644 --- a/tests/unit/tools/test_variation_image_generation_tool.py +++ b/tests/unit/tools/test_variation_image_generation_tool.py @@ -26,7 +26,7 @@ def image_loader(self) -> Mock: return loader @pytest.fixture() - def image_generator(self, image_generation_engine, image_loader) -> VariationImageGenerationTool: + def image_provider(self, image_generation_engine, image_loader) -> VariationImageGenerationTool: return VariationImageGenerationTool(engine=image_generation_engine, image_loader=image_loader) def test_validate_output_configs(self, image_generation_engine, image_loader) -> None: @@ -35,12 +35,12 @@ def test_validate_output_configs(self, image_generation_engine, image_loader) -> engine=image_generation_engine, output_dir="test", output_file="test", image_loader=image_loader ) - def test_image_variation(self, image_generator, path_from_resource_path) -> None: - image_generator.engine.run.return_value = Mock( + def test_image_variation(self, image_provider, path_from_resource_path) -> None: + image_provider.engine.run.return_value = Mock( value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) - image_artifact = image_generator.image_variation_from_file( + image_artifact = image_provider.image_variation_from_file( params={ "values": { "prompt": "test prompt", @@ -54,15 +54,15 @@ def test_image_variation(self, image_generator, path_from_resource_path) -> None def test_image_variation_with_outfile(self, image_generation_engine, image_loader, path_from_resource_path) -> None: outfile = f"{tempfile.gettempdir()}/{str(uuid.uuid4())}.png" - image_generator = VariationImageGenerationTool( + image_provider = VariationImageGenerationTool( engine=image_generation_engine, output_file=outfile, image_loader=image_loader ) - image_generator.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] + image_provider.engine.run.return_value = ImageArtifact( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512 ) - image_artifact = image_generator.image_variation_from_file( + image_artifact = image_provider.image_variation_from_file( params={ "values": { "prompt": "test prompt", @@ -76,16 +76,16 @@ def test_image_variation_with_outfile(self, image_generation_engine, image_loade assert os.path.exists(outfile) def test_image_variation_from_memory(self, image_generation_engine, image_artifact): - image_generator = VariationImageGenerationTool(engine=image_generation_engine) + image_provider = VariationImageGenerationTool(engine=image_generation_engine) memory = Mock() memory.load_artifacts = Mock(return_value=[image_artifact]) - image_generator.find_input_memory = Mock(return_value=memory) + image_provider.find_input_memory = Mock(return_value=memory) - image_generator.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] + image_provider.engine.run.return_value = Mock( # pyright: ignore[reportFunctionMemberAccess] value=b"image data", format="png", width=512, height=512, model="test model", prompt="test prompt" ) - image_artifact = image_generator.image_variation_from_memory( + image_artifact = image_provider.image_variation_from_memory( params={ "values": { "prompt": "test prompt", diff --git a/tests/unit/tools/test_vector_store_tool.py b/tests/unit/tools/test_vector_store_tool.py index 30596f09f8..1230c52941 100644 --- a/tests/unit/tools/test_vector_store_tool.py +++ b/tests/unit/tools/test_vector_store_tool.py @@ -23,12 +23,12 @@ def test_search_with_namespace(self): assert len(tool1.search({"values": {"query": "test"}})) == 2 assert len(tool2.search({"values": {"query": "test"}})) == 0 - def test_custom_process_query_output_fn(self): + def test_custom_query_output_processor(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) tool1 = VectorStoreTool( description="Test", vector_store_driver=driver, - process_query_output_fn=lambda es: ListArtifact([e.vector for e in es]), + query_output_processor=lambda es: ListArtifact([e.vector for e in es]), query_params={"include_vectors": True}, ) diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 5a7b4e069d..c24b3eb9ca 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -21,8 +21,8 @@ def test_init(self): intro_text="hello...", prompt_prefix="Question: ", response_prefix="Answer: ", - input_fn=input, - output_fn=logging.info, + input_handler=input, + output_handler=logging.info, logger_level=logging.INFO, ) assert chat.structure == agent @@ -31,8 +31,8 @@ def test_init(self): assert chat.intro_text == "hello..." assert chat.prompt_prefix == "Question: " assert chat.response_prefix == "Answer: " - assert callable(chat.input_fn) - assert callable(chat.output_fn) + assert callable(chat.input_handler) + assert callable(chat.output_handler) assert chat.logger_level == logging.INFO @patch("builtins.input", side_effect=["exit"]) @@ -57,16 +57,16 @@ def test_chat_prompt(self): @pytest.mark.parametrize("stream", [True, False]) @patch("builtins.input", side_effect=["foo", "exit"]) def test_start(self, mock_input, stream): - mock_output_fn = Mock() + mock_output_handler = Mock() agent = Agent(conversation_memory=ConversationMemory(), stream=stream) - chat = Chat(agent, intro_text="foo", output_fn=mock_output_fn) + chat = Chat(agent, intro_text="foo", output_handler=mock_output_handler) chat.start() mock_input.assert_has_calls([call(), call()]) if stream: - mock_output_fn.assert_has_calls( + mock_output_handler.assert_has_calls( [ call("foo"), call("Thinking..."), @@ -76,7 +76,7 @@ def test_start(self, mock_input, stream): ] ) else: - mock_output_fn.assert_has_calls( + mock_output_handler.assert_has_calls( [ call("foo"), call("Thinking..."), diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index c943525b64..a34871013c 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -224,7 +224,7 @@ class TesterPromptDriverOption: structure: Structure = field() @classmethod - def prompt_driver_id_fn(cls, prompt_driver) -> str: + def generate_prompt_driver_id(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: