Skip to content

Commit

Permalink
Standardize callables (#1275)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Oct 29, 2024
1 parent 90311a2 commit 3020a59
Show file tree
Hide file tree
Showing 72 changed files with 216 additions and 190 deletions.
21 changes: 21 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed `BaseTask.can_execute` to `BaseTool.can_run`.
- **BREAKING**: Renamed `BaseTool.run` to `BaseTool.try_run`.
- **BREAKING**: Renamed `BaseTool.execute` to `BaseTool.run`.
- **BREAKING**: Renamed callables throughout the framework for consistency:
- Renamed `LocalStructureRunDriver.structure_factory_fn` to `LocalStructureRunDriver.create_structure`.
- Renamed `SnowflakeSqlDriver.connection_func` to `SnowflakeSqlDriver.get_connection`.
- Renamed `CsvLoader.formatter_fn` to `CsvLoader.format_row`.
- Renamed `SqlLoader.formatter_fn` to `SqlLoader.format_row`.
- Renamed `CsvExtractionEngine.system_template_generator` to `CsvExtractionEngine.generate_system_template`.
- Renamed `CsvExtractionEngine.user_template_generator` to `CsvExtractionEngine.generate_user_template`.
- Renamed `JsonExtractionEngine.system_template_generator` to `JsonExtractionEngine.generate_system_template`.
- Renamed `JsonExtractionEngine.user_template_generator` to `JsonExtractionEngine.generate_user_template`.
- Renamed `PromptResponseRagModule.generate_system_template` to `PromptResponseRagModule.generate_system_template`.
- Renamed `PromptTask.generate_system_template` to `PromptTask.generate_system_template`.
- Renamed `ToolkitTask.generate_assistant_subtask_template` to `ToolkitTask.generate_assistant_subtask_template`.
- Renamed `JsonSchemaRule.template_generator` to `JsonSchemaRule.generate_template`.
- Renamed `ToolkitTask.generate_user_subtask_template` to `ToolkitTask.generate_user_subtask_template`.
- Renamed `TextLoaderRetrievalRagModule.process_query_output_fn` to `TextLoaderRetrievalRagModule.process_query_output`.
- Renamed `FuturesExecutorMixin.futures_executor_fn` to `FuturesExecutorMixin.create_futures_executor`.
- Renamed `VectorStoreTool.process_query_output_fn` to `VectorStoreTool.process_query_output`.
- Renamed `CodeExecutionTask.run_fn` to `CodeExecutionTask.on_run`.
- Renamed `Chat.input_fn` to `Chat.handle_input`.
- Renamed `Chat.output_fn` to `Chat.handle_output`.
- Renamed `EventListener.handler` to `EventListener.on_event`.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
Expand Down
5 changes: 5 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ Defaults.drivers_config = AnthropicDriversConfig(
)
```

### Renamed Callables

Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth.


### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/src/multi_agent_workflow_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
),
id="research",
driver=LocalStructureRunDriver(
structure_factory_fn=build_researcher,
create_structure=build_researcher,
),
),
)
Expand All @@ -150,7 +150,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent:
{{ parent_outputs["research"] }}""",
),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda writer=writer: build_writer(
create_structure=lambda writer=writer: build_writer(
role=writer["role"],
goal=writer["goal"],
backstory=writer["backstory"],
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/drivers/src/sql_drivers_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ def get_snowflake_connection() -> SnowflakeConnection:
)


driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection)
driver = SnowflakeSqlDriver(get_connection=get_snowflake_connection)

driver.execute_query("select * from people;")
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ def build_joke_rewriter() -> Agent:
tasks=[
StructureRunTask(
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_teller,
create_structure=build_joke_teller,
),
),
StructureRunTask(
("Rewrite this joke: {{ parent_output }}",),
driver=LocalStructureRunDriver(
structure_factory_fn=build_joke_rewriter,
create_structure=build_joke_rewriter,
),
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
StructureRunTask(
("Think of a question related to Retrieval Augmented Generation.",),
driver=LocalStructureRunDriver(
structure_factory_fn=lambda: Agent(
create_structure=lambda: Agent(
rules=[
Rule(
value="You are an expert in Retrieval Augmented Generation.",
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/engines/summary-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Summary engines are used to summarize text and collections of [TextArtifact](../

## Prompt

Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_generator), [user_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).
Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [generate_system_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_system_template), [generate_user_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_user_template), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker).

Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string.

Expand Down
12 changes: 6 additions & 6 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,21 @@ Assistant:
...
```

## `EventListenerDriver.handler` Return Value Behavior
## `EventListenerDriver.on_event` Return Value Behavior

The value that gets returned from the [`EventListener.handler`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.handler) will determine what gets sent to the `event_listener_driver`.
The value that gets returned from the [`EventListener.on_event`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.on_event) will determine what gets sent to the `event_listener_driver`.

### `EventListener.handler` is None
### `EventListener.on_event` is None

By default, the `EventListener.handler` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is.
By default, the `EventListener.on_event` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is.

### Return `BaseEvent` or `dict`

You can return a `BaseEvent` or `dict` object from `EventListener.handler`, and it will get sent to the `event_listener_driver`.
You can return a `BaseEvent` or `dict` object from `EventListener.on_event`, and it will get sent to the `event_listener_driver`.

### Return `None`

You can return `None` in the handler function to prevent the event from getting sent to the `event_listener_driver`.
You can return `None` in the on_event function to prevent the event from getting sent to the `event_listener_driver`.

```python
--8<-- "docs/griptape-framework/misc/src/events_no_publish.py"
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/misc/src/events_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from griptape.structures import Agent


def handler(event: BaseEvent) -> None:
def on_event(event: BaseEvent) -> None:
print(event.__class__)


EventBus.add_event_listeners(
[
EventListener(
handler,
on_event,
event_types=[
StartTaskEvent,
FinishTaskEvent,
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/misc/src/events_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from griptape.structures import Agent


def handler(event: BaseEvent) -> None:
def on_event(event: BaseEvent) -> None:
if isinstance(event, StartPromptEvent):
print("Prompt Stack Messages:")
for message in event.prompt_stack.messages:
print(f"{message.role}: {message.to_text()}")


EventBus.add_event_listeners([EventListener(handler=handler, event_types=[StartPromptEvent])])
EventBus.add_event_listeners([EventListener(on_event=on_event, event_types=[StartPromptEvent])])

agent = Agent()

Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/structures/src/tasks_10.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def character_counter(task: CodeExecutionTask) -> BaseArtifact:

pipeline.add_tasks(
# take the first argument from the pipeline `run` method
CodeExecutionTask(run_fn=character_counter),
CodeExecutionTask(on_run=character_counter),
# # take the output from the previous task and insert it into the prompt
PromptTask("{{args[0]}} using {{ parent_output }} characters"),
)
Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/tasks_16.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def build_writer() -> Agent:
"""Perform a detailed examination of the newest developments in AI as of 2024.
Pinpoint major trends, breakthroughs, and their implications for various industries.""",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_researcher),
driver=LocalStructureRunDriver(create_structure=build_researcher),
),
StructureRunTask(
(
Expand All @@ -122,7 +122,7 @@ def build_writer() -> Agent:
Keep the tone appealing and use simple language to make it less technical.""",
"{{parent_output}}",
),
driver=LocalStructureRunDriver(structure_factory_fn=build_writer),
driver=LocalStructureRunDriver(create_structure=build_writer),
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion griptape/artifacts/base_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BaseArtifact(SerializableMixin, ABC):
name: The name of the Artifact. Defaults to the id.
value: The value of the Artifact.
encoding: The encoding to use when encoding/decoding the value.
encoding_error_handler: The error handler to use when encoding/decoding the value.
encoding_error_handler: The error on_event to use when encoding/decoding the value.
"""

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
Expand Down
12 changes: 6 additions & 6 deletions griptape/drivers/sql/snowflake_sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@

@define
class SnowflakeSqlDriver(BaseSqlDriver):
connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True)
get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True)
_engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

@connection_func.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None:
snowflake_connection = connection_func()
@get_connection.validator # pyright: ignore[reportFunctionMemberAccess]
def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
snowflake_connection = get_connection()
snowflake = import_optional_dependency("snowflake")

if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
raise ValueError("The connection_func must return a SnowflakeConnection")
raise ValueError("The get_connection function must return a SnowflakeConnection")
if not snowflake_connection.schema or not snowflake_connection.database:
raise ValueError("Provide a schema and database for the Snowflake connection")

@lazy_property()
def engine(self) -> Engine:
return import_optional_dependency("sqlalchemy").create_engine(
"snowflake://not@used/db",
creator=self.connection_func,
creator=self.get_connection,
)

def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
Expand Down
8 changes: 4 additions & 4 deletions griptape/drivers/structure_run/local_structure_run_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

@define
class LocalStructureRunDriver(BaseStructureRunDriver):
structure_factory_fn: Callable[[], Structure] = field(kw_only=True)
create_structure: Callable[[], Structure] = field(kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact:
old_env = os.environ.copy()
try:
os.environ.update(self.env)
structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args])
structure = self.create_structure().run(*[arg.value for arg in args])
finally:
os.environ.clear()
os.environ.update(old_env)

if structure_factory_fn.output_task.output is not None:
return structure_factory_fn.output_task.output
if structure.output_task.output is not None:
return structure.output_task.output
else:
return InfoArtifact("No output found in response")
4 changes: 2 additions & 2 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class LocalVectorStoreDriver(BaseVectorStoreDriver):
entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict)
persist_file: Optional[str] = field(default=None)
relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

def __attrs_post_init__(self) -> None:
Expand Down Expand Up @@ -95,7 +95,7 @@ def query(
entries = self.entries

entries_and_relatednesses = [
(entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values())
(entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)
Expand Down
14 changes: 7 additions & 7 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[dict], str] = field(
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
format_row: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

Expand All @@ -45,7 +45,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row])))))
rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row])))))

return rows

Expand All @@ -57,11 +57,11 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.generate_system_template.render(
column_names=self.column_names,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.generate_user_template.render(
text=artifacts_text,
)

Expand All @@ -86,7 +86,7 @@ def _extract_rec(
return rows
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.generate_user_template.render(
text=chunks[0].value,
)

Expand Down
12 changes: 5 additions & 7 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract_artifacts(
self,
Expand Down Expand Up @@ -54,11 +52,11 @@ def _extract_rec(
rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
system_prompt = self.generate_system_template.render(
json_template_schema=json.dumps(self.template_schema),
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
user_prompt = self.generate_user_template.render(
text=artifacts_text,
)

Expand All @@ -82,7 +80,7 @@ def _extract_rec(
return extractions
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.user_template_generator.render(
partial_text = self.generate_user_template.render(
text=chunks[0].value,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TranslateQueryRagModule(BaseQueryRagModule):
prompt_driver: BasePromptDriver = field()
language: str = field()
generate_user_template: Callable[[str, str], str] = field(
default=Factory(lambda self: self.default_user_template_generator, takes_self=True),
default=Factory(lambda self: self.default_generate_user_template, takes_self=True),
)

def run(self, context: RagContext) -> RagContext:
Expand All @@ -28,5 +28,5 @@ def run(self, context: RagContext) -> RagContext:

return context

def default_user_template_generator(self, query: str, language: str) -> str:
def default_generate_user_template(self, query: str, language: str) -> str:
return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language)
Loading

0 comments on commit 3020a59

Please sign in to comment.