diff --git a/CHANGELOG.md b/CHANGELOG.md index e5de54e3be..627ba90c6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Parameter `meta: dict` on `BaseEvent`. +- `TableArtifact` for storing CSV data. ### Changed - **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`. @@ -39,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], dict]`. This represents the runs and metadata. - **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: dict` as input. - **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`. +- **BREAKING**: `CsvLoader`, `DataFrameLoader`, and `SqlLoader` now return a `TableArtifact` instead of a `list[TextArtifact]`. - `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`. - `CsvRowArtifact.to_text()` now includes the header. 
diff --git a/griptape/artifacts/__init__.py b/griptape/artifacts/__init__.py index a647d05782..2099a78de2 100644 --- a/griptape/artifacts/__init__.py +++ b/griptape/artifacts/__init__.py @@ -7,6 +7,7 @@ from .audio_artifact import AudioArtifact from .json_artifact import JsonArtifact from .action_artifact import ActionArtifact +from .table_artifact import TableArtifact from .generic_artifact import GenericArtifact from .error_artifact import ErrorArtifact @@ -27,4 +28,5 @@ "AudioArtifact", "ActionArtifact", "GenericArtifact", + "TableArtifact", ] diff --git a/griptape/artifacts/blob_artifact.py b/griptape/artifacts/blob_artifact.py index 49c7a4a802..26e657005e 100644 --- a/griptape/artifacts/blob_artifact.py +++ b/griptape/artifacts/blob_artifact.py @@ -2,7 +2,7 @@ from typing import Any -from attrs import define, field +from attrs import Converter, define, field from griptape.artifacts import BaseSystemArtifact @@ -27,6 +27,7 @@ class BlobArtifact(BaseSystemArtifact): value: bytes = field(converter=value_to_bytes, metadata={"serializable": True}) encoding: str = field(default="utf-8", kw_only=True) encoding_error_handler: str = field(default="strict", kw_only=True) + media_type: str = field(default="application/octet-stream", kw_only=True) def to_bytes(self) -> bytes: return self.value diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b3d45e4d93..fd514ce9c6 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -3,11 +3,9 @@ import json from typing import Any, Union -from attrs import define, field +from attrs import Converter, define, field -from griptape.artifacts import BaseArtifact - -Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] +from griptape.artifacts.text_artifact import TextArtifact def value_to_json(value: Any) -> Json: diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 00df837898..6b22793a90 100644 --- a/griptape/artifacts/list_artifact.py +++ 
b/griptape/artifacts/list_artifact.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from griptape.artifacts import BaseArtifact + @define class ListArtifact(BaseSystemArtifact): diff --git a/griptape/artifacts/table_artifact.py b/griptape/artifacts/table_artifact.py new file mode 100644 index 0000000000..d3ca003fa6 --- /dev/null +++ b/griptape/artifacts/table_artifact.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import csv +import io +from typing import TYPE_CHECKING, Optional + +from attrs import define, field + +from griptape.artifacts.text_artifact import TextArtifact + +if TYPE_CHECKING: + from collections.abc import Sequence + + +@define +class TableArtifact(TextArtifact): + value: list[dict] = field(factory=list, metadata={"serializable": True}) + delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True}) + fieldnames: Optional[Sequence[str]] = field(default=None, metadata={"serializable": True}) + quoting: int = field(default=csv.QUOTE_MINIMAL, kw_only=True, metadata={"serializable": True}) + line_terminator: str = field(default="\n", kw_only=True, metadata={"serializable": True}) + + def __bool__(self) -> bool: + return len(self.value) > 0 + + def to_text(self) -> str: + with io.StringIO() as csvfile: + fieldnames = (self.value[0].keys() if self.value else []) if self.fieldnames is None else self.fieldnames + + writer = csv.DictWriter( + csvfile, + fieldnames=fieldnames, + quoting=self.quoting, + delimiter=self.delimiter, + lineterminator=self.line_terminator, + ) + + writer.writeheader() + writer.writerows(self.value) + + return csvfile.getvalue().strip() diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 20ac237c56..ab00a3216b 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -19,33 +19,28 @@ class CsvLoader(BaseLoader): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - 
def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]: - artifacts = [] - + def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact: if isinstance(source, bytes): source = source.decode(encoding=self.encoding) elif isinstance(source, (bytearray, memoryview)): raise ValueError(f"Unsupported source type: {type(source)}") reader = csv.DictReader(StringIO(source), delimiter=self.delimiter) - chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)] - if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) + artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames) - for chunk in chunks: - artifacts.append(chunk) + if self.embedding_driver: + artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact def load_collection( self, sources: list[bytes | str], *args, **kwargs, - ) -> dict[str, list[TextArtifact]]: + ) -> dict[str, TableArtifact]: return cast( - dict[str, list[TextArtifact]], + dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs), ) diff --git a/griptape/loaders/dataframe_loader.py b/griptape/loaders/dataframe_loader.py index 5ecb35ecd5..5fbbd51d16 100644 --- a/griptape/loaders/dataframe_loader.py +++ b/griptape/loaders/dataframe_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts import TextArtifact +from griptape.artifacts import TableArtifact from griptape.loaders import BaseLoader from griptape.utils import import_optional_dependency from griptape.utils.hash import str_to_hash @@ -19,22 +19,16 @@ class DataFrameLoader(BaseLoader): embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]: - artifacts = [] - - chunks = [TextArtifact(row) for row in source.to_dict(orient="records")] + def load(self, source: DataFrame, *args, 
**kwargs) -> TableArtifact: + artifact = TableArtifact(list(source.to_dict(orient="records"))) if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) + artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact - def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, TableArtifact]: + return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) def to_key(self, source: DataFrame, *args, **kwargs) -> str: hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 14320911eb..d723a5a9fb 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -4,7 +4,7 @@ from attrs import define, field -from griptape.artifacts.text_artifact import TextArtifact +from griptape.artifacts import TableArtifact from griptape.loaders import BaseLoader if TYPE_CHECKING: @@ -16,20 +16,14 @@ class SqlLoader(BaseLoader): sql_driver: BaseSqlDriver = field(kw_only=True) embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True) - def load(self, source: str, *args, **kwargs) -> list[TextArtifact]: + def load(self, source: str, *args, **kwargs) -> TableArtifact: rows = self.sql_driver.execute_query(source) - artifacts = [] - - chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else [] + artifact = TableArtifact([row.cells for row in rows] if rows else []) if self.embedding_driver: - for chunk in chunks: - chunk.generate_embedding(self.embedding_driver) - - for chunk in chunks: - artifacts.append(chunk) + 
artifact.generate_embedding(self.embedding_driver) - return artifacts + return artifact - def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]: - return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs)) + def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, TableArtifact]: + return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs)) diff --git a/griptape/templates/memory/tool.j2 b/griptape/templates/memory/tool.j2 index d58b26f5d7..10214b91d0 100644 --- a/griptape/templates/memory/tool.j2 +++ b/griptape/templates/memory/tool.j2 @@ -1 +1 @@ -Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}" \ No newline at end of file +Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}" diff --git a/griptape/tools/sql/tool.py b/griptape/tools/sql/tool.py index a84bb87bed..aca41b4ac6 100644 --- a/griptape/tools/sql/tool.py +++ b/griptape/tools/sql/tool.py @@ -5,11 +5,12 @@ from attrs import define, field from schema import Schema -from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact +from griptape.artifacts import ErrorArtifact, InfoArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity if TYPE_CHECKING: + from griptape.artifacts import TableArtifact from griptape.loaders import SqlLoader @@ -43,7 +44,7 @@ def table_schema(self) -> Optional[str]: "schema": Schema({"sql_query": str}), }, ) - def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + def execute_query(self, params: dict) -> TableArtifact | InfoArtifact | ErrorArtifact: try: query = params["values"]["sql_query"] rows = self.sql_loader.load(query) @@ -51,6 +52,6 @@ def execute_query(self, params: 
dict) -> ListArtifact | InfoArtifact | ErrorArti return ErrorArtifact(f"error executing query: {e}") if len(rows) > 0: - return ListArtifact(rows) + return rows else: return InfoArtifact("No results found") diff --git a/tests/unit/loaders/test_csv_loader.py b/tests/unit/loaders/test_csv_loader.py index d08b939197..cfe251ff32 100644 --- a/tests/unit/loaders/test_csv_loader.py +++ b/tests/unit/loaders/test_csv_loader.py @@ -11,11 +11,11 @@ def loader(self, request): if encoding is None: return CsvLoader(embedding_driver=MockEmbeddingDriver()) else: - return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding) + return CsvLoader(encoding=encoding, embedding_driver=MockEmbeddingDriver()) @pytest.fixture() def loader_with_pipe_delimiter(self): - return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|") + return CsvLoader(delimiter="|", embedding_driver=MockEmbeddingDriver()) @pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"]) def create_source(self, request): @@ -24,7 +24,7 @@ def create_source(self, request): def test_load(self, loader, create_source): source = create_source("test-1.csv") - artifacts = loader.load(source) + artifact = loader.load(source) assert len(artifacts) == 10 first_artifact = artifacts[0] @@ -34,7 +34,7 @@ def test_load(self, loader, create_source): def test_load_delimiter(self, loader_with_pipe_delimiter, create_source): source = create_source("test-pipe.csv") - artifacts = loader_with_pipe_delimiter.load(source) + artifact = loader_with_pipe_delimiter.load(source) assert len(artifacts) == 10 first_artifact = artifacts[0] diff --git a/tests/unit/loaders/test_sql_loader.py b/tests/unit/loaders/test_sql_loader.py index 2ff6c7fafd..16ca2af5ae 100644 --- a/tests/unit/loaders/test_sql_loader.py +++ b/tests/unit/loaders/test_sql_loader.py @@ -35,14 +35,14 @@ def loader(self): return sql_loader def test_load(self, loader): - artifacts = loader.load("SELECT * FROM test_table;") + artifact = 
loader.load("SELECT * FROM test_table;") assert len(artifacts) == 3 assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York" assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles" assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago" - assert artifacts[0].embedding == [0, 1] + assert artifact.embedding == [0, 1] def test_load_collection(self, loader): sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]