Skip to content

Commit

Permalink
Refactor Artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Sep 4, 2024
1 parent 2904e50 commit 08d6818
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 53 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Parameter `meta: dict` on `BaseEvent`.
- `TableArtifact` for storing tabular data such as CSV content, SQL query results, and DataFrame records.

### Changed
- **BREAKING**: Drivers, Loaders, and Engines now raise exceptions rather than returning `ErrorArtifacts`.
Expand All @@ -39,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: `BaseConversationMemoryDriver.load` now returns `tuple[list[Run], dict]`. This represents the runs and metadata.
- **BREAKING**: `BaseConversationMemoryDriver.store` now takes `runs: list[Run]` and `metadata: dict` as input.
- **BREAKING**: Parameter `file_path` on `LocalConversationMemoryDriver` renamed to `persist_file` and is now type `Optional[str]`.
- **BREAKING**: `CsvLoader` now returns a `TableArtifact` instead of a `list[CsvRowArtifact]`.
- `Defaults.drivers_config.conversation_memory_driver` now defaults to `LocalConversationMemoryDriver` instead of `None`.
- `CsvRowArtifact.to_text()` now includes the header.

Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .audio_artifact import AudioArtifact
from .json_artifact import JsonArtifact
from .action_artifact import ActionArtifact

from .generic_artifact import GenericArtifact

from .error_artifact import ErrorArtifact
Expand All @@ -27,4 +28,5 @@
"AudioArtifact",
"ActionArtifact",
"GenericArtifact",
"TableArtifact",
]
3 changes: 2 additions & 1 deletion griptape/artifacts/blob_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from attrs import define, field
from attrs import Converter, define, field

from griptape.artifacts import BaseSystemArtifact

Expand All @@ -27,6 +27,7 @@ class BlobArtifact(BaseSystemArtifact):
value: bytes = field(converter=value_to_bytes, metadata={"serializable": True})
encoding: str = field(default="utf-8", kw_only=True)
encoding_error_handler: str = field(default="strict", kw_only=True)
media_type: str = field(default="application/octet-stream", kw_only=True)

def to_bytes(self) -> bytes:
return self.value
Expand Down
6 changes: 2 additions & 4 deletions griptape/artifacts/json_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import json
from typing import Any, Union

from attrs import define, field
from attrs import Converter, define, field

from griptape.artifacts import BaseArtifact

Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]
from griptape.artifacts.text_artifact import TextArtifact


def value_to_json(value: Any) -> Json:
Expand Down
2 changes: 2 additions & 0 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from griptape.artifacts import BaseArtifact


@define
class ListArtifact(BaseSystemArtifact):
Expand Down
41 changes: 41 additions & 0 deletions griptape/artifacts/table_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

import csv
import io
from typing import TYPE_CHECKING, Optional

from attrs import define, field

from griptape.artifacts.text_artifact import TextArtifact

if TYPE_CHECKING:
from collections.abc import Sequence


@define
class TableArtifact(TextArtifact):
    """Artifact whose value is a list of row dicts and whose text form is delimited (CSV-style) text.

    The text form is produced by ``csv.DictWriter`` using the ``delimiter``, ``quoting`` and
    ``line_terminator`` configured on the instance.
    """

    # Table rows; each dict maps a column name to that row's cell value.
    value: list[dict] = field(factory=list, metadata={"serializable": True})
    # Separator written between cells.
    delimiter: str = field(default=",", kw_only=True, metadata={"serializable": True})
    # Column names, in output order. When empty or None they are taken from the first row.
    fieldnames: Optional[Sequence[str]] = field(factory=list, metadata={"serializable": True})
    # One of the csv.QUOTE_* constants, passed straight to csv.DictWriter.
    quoting: int = field(default=csv.QUOTE_MINIMAL, kw_only=True, metadata={"serializable": True})
    # String written after every line (header and rows).
    line_terminator: str = field(default="\n", kw_only=True, metadata={"serializable": True})

    def __bool__(self) -> bool:
        """Return True when the table holds at least one row."""
        return len(self.value) > 0

    def to_text(self) -> str:
        """Render the table as delimited text: one header line followed by one line per row.

        Returns:
            The delimited text without the final line terminator. An artifact with no rows and
            no fieldnames renders as an empty string.

        Raises:
            ValueError: Raised by ``csv.DictWriter`` when a row contains a key that is not in
                the resolved fieldnames.
        """
        # `fieldnames` defaults to an empty list rather than None, so a plain `is None` check
        # never falls back to the row keys. Treat "empty" and "None" the same way.
        if self.fieldnames:
            fieldnames = list(self.fieldnames)
        elif self.value:
            fieldnames = list(self.value[0].keys())
        else:
            fieldnames = []

        with io.StringIO() as csvfile:
            writer = csv.DictWriter(
                csvfile,
                fieldnames=fieldnames,
                quoting=self.quoting,
                delimiter=self.delimiter,
                lineterminator=self.line_terminator,
            )

            writer.writeheader()
            writer.writerows(self.value)

            text = csvfile.getvalue()

        # Drop only the final line terminator. A blanket strip() would also eat trailing empty
        # cells when the delimiter is whitespace (e.g. "\t") and significant whitespace in cells.
        if self.line_terminator and text.endswith(self.line_terminator):
            text = text[: -len(self.line_terminator)]

        return text
19 changes: 7 additions & 12 deletions griptape/loaders/csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,28 @@ class CsvLoader(BaseLoader):
delimiter: str = field(default=",", kw_only=True)
encoding: str = field(default="utf-8", kw_only=True)

def load(self, source: bytes | str, *args, **kwargs) -> list[TextArtifact]:
artifacts = []

def load(self, source: bytes | str, *args, **kwargs) -> TableArtifact:
if isinstance(source, bytes):
source = source.decode(encoding=self.encoding)
elif isinstance(source, (bytearray, memoryview)):
raise ValueError(f"Unsupported source type: {type(source)}")

reader = csv.DictReader(StringIO(source), delimiter=self.delimiter)
chunks = [TextArtifact(row, meta={"row": row_num}) for row_num, row in enumerate(reader)]

if self.embedding_driver:
for chunk in chunks:
chunk.generate_embedding(self.embedding_driver)
artifact = TableArtifact(list(reader), delimiter=self.delimiter, fieldnames=reader.fieldnames)

for chunk in chunks:
artifacts.append(chunk)
if self.embedding_driver:
artifact.generate_embedding(self.embedding_driver)

return artifacts
return artifact

def load_collection(
self,
sources: list[bytes | str],
*args,
**kwargs,
) -> dict[str, list[TextArtifact]]:
) -> dict[str, TableArtifact]:
return cast(
dict[str, list[TextArtifact]],
dict[str, TableArtifact],
super().load_collection(sources, *args, **kwargs),
)
20 changes: 7 additions & 13 deletions griptape/loaders/dataframe_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import define, field

from griptape.artifacts import TextArtifact
from griptape.artifacts import TableArtifact
from griptape.loaders import BaseLoader
from griptape.utils import import_optional_dependency
from griptape.utils.hash import str_to_hash
Expand All @@ -19,22 +19,16 @@
class DataFrameLoader(BaseLoader):
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)

def load(self, source: DataFrame, *args, **kwargs) -> list[TextArtifact]:
artifacts = []

chunks = [TextArtifact(row) for row in source.to_dict(orient="records")]
def load(self, source: DataFrame, *args, **kwargs) -> TableArtifact:
artifact = TableArtifact(list(source.to_dict(orient="records")))

if self.embedding_driver:
for chunk in chunks:
chunk.generate_embedding(self.embedding_driver)

for chunk in chunks:
artifacts.append(chunk)
artifact.generate_embedding(self.embedding_driver)

return artifacts
return artifact

def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, list[TextArtifact]]:
return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[DataFrame], *args, **kwargs) -> dict[str, TableArtifact]:
return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs))

def to_key(self, source: DataFrame, *args, **kwargs) -> str:
hash_pandas_object = import_optional_dependency("pandas.core.util.hashing").hash_pandas_object
Expand Down
20 changes: 7 additions & 13 deletions griptape/loaders/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from attrs import define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.artifacts import TableArtifact
from griptape.loaders import BaseLoader

if TYPE_CHECKING:
Expand All @@ -16,20 +16,14 @@ class SqlLoader(BaseLoader):
sql_driver: BaseSqlDriver = field(kw_only=True)
embedding_driver: Optional[BaseEmbeddingDriver] = field(default=None, kw_only=True)

def load(self, source: str, *args, **kwargs) -> list[TextArtifact]:
def load(self, source: str, *args, **kwargs) -> TableArtifact:
rows = self.sql_driver.execute_query(source)
artifacts = []

chunks = [TextArtifact(row.cells, meta={"row": row_num}) for row_num, row in enumerate(rows)] if rows else []
artifact = TableArtifact([row.cells for row in rows] if rows else [])

if self.embedding_driver:
for chunk in chunks:
chunk.generate_embedding(self.embedding_driver)

for chunk in chunks:
artifacts.append(chunk)
artifact.generate_embedding(self.embedding_driver)

return artifacts
return artifact

def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, list[TextArtifact]]:
return cast(dict[str, list[TextArtifact]], super().load_collection(sources, *args, **kwargs))
def load_collection(self, sources: list[str], *args, **kwargs) -> dict[str, TableArtifact]:
return cast(dict[str, TableArtifact], super().load_collection(sources, *args, **kwargs))
2 changes: 1 addition & 1 deletion griptape/templates/memory/tool.j2
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}"
Output of "{{ tool_name }}.{{ activity_name }}" was stored in memory with memory_name "{{ memory_name }}" and artifact_namespace "{{ artifact_namespace }}"
7 changes: 4 additions & 3 deletions griptape/tools/sql/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from attrs import define, field
from schema import Schema

from griptape.artifacts import ErrorArtifact, InfoArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact, InfoArtifact
from griptape.tools import BaseTool
from griptape.utils.decorators import activity

if TYPE_CHECKING:
from griptape.artifacts import TableArtifact
from griptape.loaders import SqlLoader


Expand Down Expand Up @@ -43,14 +44,14 @@ def table_schema(self) -> Optional[str]:
"schema": Schema({"sql_query": str}),
},
)
def execute_query(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact:
def execute_query(self, params: dict) -> TableArtifact | InfoArtifact | ErrorArtifact:
try:
query = params["values"]["sql_query"]
rows = self.sql_loader.load(query)
except Exception as e:
return ErrorArtifact(f"error executing query: {e}")

if len(rows) > 0:
return ListArtifact(rows)
return rows
else:
return InfoArtifact("No results found")
8 changes: 4 additions & 4 deletions tests/unit/loaders/test_csv_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ def loader(self, request):
if encoding is None:
return CsvLoader(embedding_driver=MockEmbeddingDriver())
else:
return CsvLoader(embedding_driver=MockEmbeddingDriver(), encoding=encoding)
return CsvLoader(encoding=encoding, embedding_driver=MockEmbeddingDriver())

@pytest.fixture()
def loader_with_pipe_delimiter(self):
return CsvLoader(embedding_driver=MockEmbeddingDriver(), delimiter="|")
return CsvLoader(delimiter="|", embedding_driver=MockEmbeddingDriver())

@pytest.fixture(params=["bytes_from_resource_path", "str_from_resource_path"])
def create_source(self, request):
Expand All @@ -24,7 +24,7 @@ def create_source(self, request):
def test_load(self, loader, create_source):
source = create_source("test-1.csv")

artifacts = loader.load(source)
artifact = loader.load(source)

assert len(artifacts) == 10
first_artifact = artifacts[0]
Expand All @@ -34,7 +34,7 @@ def test_load(self, loader, create_source):
def test_load_delimiter(self, loader_with_pipe_delimiter, create_source):
source = create_source("test-pipe.csv")

artifacts = loader_with_pipe_delimiter.load(source)
artifact = loader_with_pipe_delimiter.load(source)

assert len(artifacts) == 10
first_artifact = artifacts[0]
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/loaders/test_sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def loader(self):
return sql_loader

def test_load(self, loader):
artifacts = loader.load("SELECT * FROM test_table;")
artifact = loader.load("SELECT * FROM test_table;")

assert len(artifacts) == 3
assert artifacts[0].value == "id: 1\nname: Alice\nage: 25\ncity: New York"
assert artifacts[1].value == "id: 2\nname: Bob\nage: 30\ncity: Los Angeles"
assert artifacts[2].value == "id: 3\nname: Charlie\nage: 22\ncity: Chicago"

assert artifacts[0].embedding == [0, 1]
assert artifact.embedding == [0, 1]

def test_load_collection(self, loader):
sources = ["SELECT * FROM test_table LIMIT 1;", "SELECT * FROM test_table LIMIT 2;"]
Expand Down

0 comments on commit 08d6818

Please sign in to comment.