Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize callables #1275

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed `BaseTask.can_execute` to `BaseTool.can_run`.
- **BREAKING**: Renamed `BaseTool.run` to `BaseTool.try_run`.
- **BREAKING**: Renamed `BaseTool.execute` to `BaseTool.run`.
- **BREAKING**: Renamed callables throughout the framework for consistency:
- Renamed `LocalStructureRunDriver.structure_factory_fn` to `LocalStructureRunDriver.create_structure`.
- Renamed `SnowflakeSqlDriver.connection_func` to `SnowflakeSqlDriver.get_connection`.
- Renamed `CsvLoader.formatter_fn` to `CsvLoader.format_row`.
- Renamed `SqlLoader.formatter_fn` to `SqlLoader.format_row`.
- Renamed `CsvExtractionEngine.system_template_generator` to `CsvExtractionEngine.generate_system_template`.
- Renamed `CsvExtractionEngine.user_template_generator` to `CsvExtractionEngine.generate_user_template`.
- Renamed `JsonExtractionEngine.system_template_generator` to `JsonExtractionEngine.generate_system_template`.
- Renamed `JsonExtractionEngine.user_template_generator` to `JsonExtractionEngine.generate_user_template`.
- Renamed `PromptResponseRagModule.generate_system_template` to `PromptResponseRagModule.generate_system_template`.
- Renamed `PromptTask.generate_system_template` to `PromptTask.generate_system_template`.
- Renamed `ToolkitTask.generate_assistant_subtask_template` to `ToolkitTask.generate_assistant_subtask_template`.
- Renamed `JsonSchemaRule.template_generator` to `JsonSchemaRule.generate_template`.
- Renamed `ToolkitTask.generate_user_subtask_template` to `ToolkitTask.generate_user_subtask_template`.
- Renamed `TextLoaderRetrievalRagModule.process_query_output_fn` to `TextLoaderRetrievalRagModule.process_query_output`.
- Renamed `FuturesExecutorMixin.futures_executor_fn` to `FuturesExecutorMixin.create_futures_executor`.
- Renamed `VectorStoreTool.process_query_output_fn` to `VectorStoreTool.process_query_output`.
- Renamed `CodeExecutionTask.run_fn` to `CodeExecutionTask.on_run`.
- Renamed `Chat.input_fn` to `Chat.handle_input`.
- Renamed `Chat.output_fn` to `Chat.handle_output`.
- Renamed `EventListener.handler` to `EventListener.on_event`.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
Expand Down
5 changes: 5 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ Defaults.drivers_config = AnthropicDriversConfig(
)
```

### Renamed Callables

Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth.


### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/src/multi_agent_workflow_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
),
id="research",
driver=LocalStructureRunDriver(
structure_factory_fn=build_researcher,
create_structure=build_researcher,
),
),
)
Expand All @@ -150,7 +150,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
{{ parent_outputs["research"] }}""",
),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda writer=writer: build_writer(
create_structure=lambda writer=writer: build_writer(
role=writer["role"],
goal=writer["goal"],
backstory=writer["backstory"],
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/src/sql_drivers_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def get_snowflake_connection() -> SnowflakeConnection:
)


driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection)
driver = SnowflakeSqlDriver(get_connection=get_snowflake_connection)

driver.execute_query("select * from people;")
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def build_joke_rewriter() -> Agent:
tasks=[
StructureRunTask(
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_teller,
create_structure=build_joke_teller,
),
),
StructureRunTask(
("Rewrite this joke: {{ parent_output }}",),
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_rewriter,
create_structure=build_joke_rewriter,
),
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StructureRunTask(
("Think of a question related to Retrieval Augmented Generation.",),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda: Agent(
create_structure=lambda: Agent(
rules=[
Rule(
value="You are an expert in Retrieval Augmented Generation.",
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/engines/summary-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Summary engines are used to summarize text and collections of [TextArtifact](../

## Prompt

Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_generator), [user_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).
Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [generate_system_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_system_template), [generate_user_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_user_template), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).

Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string.

Expand Down
12 changes: 6 additions & 6 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,21 @@ Assistant:
...
```

## `EventListenerDriver.handler` Return Value Behavior
## `EventListenerDriver.on_event` Return Value Behavior

The value that gets returned from the [`EventListener.handler`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.handler) will determine what gets sent to the `event_listener_driver`.
The value that gets returned from the [`EventListener.on_event`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.on_event) will determine what gets sent to the `event_listener_driver`.

### `EventListener.handler` is None
### `EventListener.on_event` is None

By default, the `EventListener.handler` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is.
By default, the `EventListener.on_event` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is.

### Return `BaseEvent` or `dict`

You can return a `BaseEvent` or `dict` object from `EventListener.handler`, and it will get sent to the `event_listener_driver`.
You can return a `BaseEvent` or `dict` object from `EventListener.on_event`, and it will get sent to the `event_listener_driver`.

### Return `None`

You can return `None` in the handler function to prevent the event from getting sent to the `event_listener_driver`.
You can return `None` in the on_event function to prevent the event from getting sent to the `event_listener_driver`.

```python
--8<-- "docs/griptape-framework/misc/src/events_no_publish.py"
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/misc/src/events_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from griptape.structures import Agent


def handler(event: BaseEvent) -> None:
def on_event(event: BaseEvent) -> None:
print(event.__class__)


EventBus.add_event_listeners(
[
EventListener(
handler,
on_event,
event_types=[
StartTaskEvent,
FinishTaskEvent,
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/misc/src/events_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from griptape.structures import Agent


def handler(event: BaseEvent) -> None:
def on_event(event: BaseEvent) -> None:
if isinstance(event, StartPromptEvent):
print("Prompt Stack Messages:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.to_text()}")


EventBus.add_event_listeners([EventListener(handler=handler, event_types=[StartPromptEvent])])
EventBus.add_event_listeners([EventListener(on_event=on_event, event_types=[StartPromptEvent])])

agent = Agent()

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/tasks_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def character_counter(task: CodeExecutionTask) -> BaseArtifact:

pipeline.add_tasks(
# take the first argument from the pipeline `run` method
CodeExecutionTask(run_fn=character_counter),
CodeExecutionTask(on_run=character_counter),
# # take the output from the previous task and insert it into the prompt
PromptTask("{{args[0]}} using {{ parent_output }} characters"),
)
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/tasks_16.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def build_writer() -> Agent:
"""Perform a detailed examination of the newest developments in AI as of 2024.
Pinpoint major trends, breakthroughs, and their implications for various industries.""",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_researcher),
driver=LocalStructureRunDriver(create_structure=build_researcher),
),
StructureRunTask(
(
Expand All @@ -122,7 +122,7 @@ def build_writer() -> Agent:
Keep the tone appealing and use simple language to make it less technical.""",
"{{parent_output}}",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_writer),
driver=LocalStructureRunDriver(create_structure=build_writer),
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BaseArtifact(SerializableMixin, ABC):
name: The name of the Artifact. Defaults to the id.
value: The value of the Artifact.
encoding: The encoding to use when encoding/decoding the value.
encoding_error_handler: The error handler to use when encoding/decoding the value.
encoding_error_handler: The error on_event to use when encoding/decoding the value.
"""

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
Expand Down
12 changes: 6 additions & 6 deletions griptape/drivers/sql/snowflake_sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@

@define
class SnowflakeSqlDriver(BaseSqlDriver):
connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True)
get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True)
_engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

@connection_func.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None:
snowflake_connection = connection_func()
@get_connection.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
snowflake_connection = get_connection()
snowflake = import_optional_dependency("snowflake")

if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
raise ValueError("The connection_func must return a SnowflakeConnection")
raise ValueError("The get_connection function must return a SnowflakeConnection")
if not snowflake_connection.schema or not snowflake_connection.database:
raise ValueError("Provide a schema and database for the Snowflake connection")

@lazy_property()
def engine(self) -> Engine:
return import_optional_dependency("sqlalchemy").create_engine(
"snowflake://not@used/db",
creator=self.connection_func,
creator=self.get_connection,
)

def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure_factory_fn: Callable[[], Structure] = field(kw_only=True)
create_structure: Callable[[], Structure] = field(kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
try:
os.environ.update(self.env)
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
structure = self.create_structure().run(*[arg.value for arg in args])
finally:
os.environ.clear()
os.environ.update(old_env)

if structure_factory_fn.output_task.output is not None:
return structure_factory_fn.output_task.output
if structure.output_task.output is not None:
return structure.output_task.output
else:
return InfoArtifact("No output found in response")
4 changes: 2 additions & 2 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class LocalVectorStoreDriver(BaseVectorStoreDriver):
entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict)
persist_file: Optional[str] = field(default=None)
relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -95,7 +95,7 @@ def query(
entries = self.entries

entries_and_relatednesses = [
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values())
(entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)
Expand Down
14 changes: 7 additions & 7 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[dict], str] = field(
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
format_row: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

Expand All @@ -45,7 +45,7 @@

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row])))))
rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row])))))

return rows

Expand All @@ -57,11 +57,11 @@
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.generate_system_template.render(
column_names=self.column_names,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.generate_user_template.render(
text=artifacts_text,
)

Expand All @@ -86,7 +86,7 @@
return rows
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.generate_user_template.render(

Check warning on line 89 in griptape/engines/extraction/csv_extraction_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/extraction/csv_extraction_engine.py#L89

Added line #L89 was not covered by tests
text=chunks[0].value,
)

Expand Down
12 changes: 5 additions & 7 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract_artifacts(
self,
Expand Down Expand Up @@ -54,11 +52,11 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.generate_system_template.render(
json_template_schema=json.dumps(self.template_schema),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.generate_user_template.render(
text=artifacts_text,
)

Expand All @@ -82,7 +80,7 @@ def _extract_rec(
return extractions
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.generate_user_template.render(
text=chunks[0].value,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TranslateQueryRagModule(BaseQueryRagModule):
prompt_driver: BasePromptDriver = field()
language: str = field()
generate_user_template: Callable[[str, str], str] = field(
default=Factory(lambda self: self.default_user_template_generator, takes_self=True),
default=Factory(lambda self: self.default_generate_user_template, takes_self=True),
)

def run(self, context: RagContext) -> RagContext:
Expand All @@ -28,5 +28,5 @@ def run(self, context: RagContext) -> RagContext:

return context

def default_user_template_generator(self, query: str, language: str) -> str:
def default_generate_user_template(self, query: str, language: str) -> str:
return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)
Loading
Loading