Skip to content

Commit

Permalink
Change CsvExtractionEngine behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Nov 11, 2024
1 parent d5fc21e commit 495a7b3
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`.
- `CsvExtractionEngine.format_header` parameter to format the header row.

### Changed

Expand All @@ -18,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed `BaseTask.is_executing()` to `BaseTask.is_running()`.
- **BREAKING**: Renamed `Structure.is_executing()` to `Structure.is_running()`.
- **BREAKING**: Removed ability to pass bytes to `BaseFileLoader.fetch`.
- **BREAKING**: Updated `CsvExtractionEngine.format_row` to format rows as comma-separated values instead of newline-separated key-value pairs.
- Improved `CsvExtractionEngine` prompts.
- Removed `azure-core` and `azure-storage-blob` dependencies.
- `GriptapeCloudFileManagerDriver` no longer requires `drivers-file-manager-griptape-cloud` extra.
- `TrafilaturaWebScraperDriver` no longer sets `no_ssl` to `True` by default.
Expand Down
8 changes: 5 additions & 3 deletions docs/griptape-framework/engines/src/extraction_engines_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

# Define some unstructured data
sample_text = """
Alice, 28, lives in New York.
Bob, 35 lives in California.
Charlie is 40 and lives in Texas.
Alice, 28, lives in New
York.
Bob, 35 lives in California
Charlie is 40
Collin is 28 and lives in San Francisco
"""

# Extract CSV rows using the engine
Expand Down
15 changes: 6 additions & 9 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(kw_only=True)
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
format_row: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)
format_header: Callable[[list[str]], str] = field(default=lambda value: ",".join(value), kw_only=True)
format_row: Callable[[dict], str] = field(default=lambda value: ",".join(value.values()), kw_only=True)

def extract_artifacts(
self,
Expand All @@ -34,18 +33,18 @@ def extract_artifacts(
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], artifacts.value),
[],
[TextArtifact(self.format_header(self.column_names))],
rulesets=rulesets,
),
item_separator="\n",
)

def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]:
def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
rows = []

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row])))))
for row in csv.DictReader(f):
rows.append(TextArtifact(self.format_row(row)))

return rows

Expand Down Expand Up @@ -79,7 +78,6 @@ def _extract_rec(
]
)
).value,
self.column_names,
),
)

Expand All @@ -100,7 +98,6 @@ def _extract_rec(
]
)
).value,
self.column_names,
),
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/templates/engines/extraction/csv/system.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Always add the header column. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Column Names: """{{ column_names }}"""

{% if rulesets %}
Expand Down
18 changes: 11 additions & 7 deletions tests/unit/engines/extraction/test_csv_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import pytest

from griptape.engines import CsvExtractionEngine
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestCsvExtractionEngine:
@pytest.fixture()
def engine(self):
return CsvExtractionEngine(column_names=["test1"])
return CsvExtractionEngine(
column_names=["header"], prompt_driver=MockPromptDriver(mock_output="header\nmock output")
)

def test_extract_text(self, engine):
result = engine.extract_text("foo")
result = engine.extract_text("mock output")

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "header"
assert result.value[1].value == "mock output"

def test_text_to_csv_rows(self, engine):
result = engine.text_to_csv_rows("foo,bar\nbaz,maz", ["test1", "test2"])
result = engine.text_to_csv_rows("key,value\nfoo,bar\nbaz,maz")

assert len(result) == 2
assert result[0].value == "test1: foo\ntest2: bar"
assert result[1].value == "test1: baz\ntest2: maz"
assert result[0].value == "foo,bar"
assert result[1].value == "baz,maz"
12 changes: 9 additions & 3 deletions tests/unit/tasks/test_extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
from griptape.engines import CsvExtractionEngine
from griptape.structures import Agent
from griptape.tasks import ExtractionTask
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestExtractionTask:
@pytest.fixture()
def task(self):
return ExtractionTask(extraction_engine=CsvExtractionEngine(column_names=["test1"]))
return ExtractionTask(
extraction_engine=CsvExtractionEngine(
column_names=["test1"], prompt_driver=MockPromptDriver(mock_output="header\nmock output")
)
)

def test_run(self, task):
agent = Agent()
Expand All @@ -17,5 +22,6 @@ def test_run(self, task):

result = task.run()

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"
12 changes: 7 additions & 5 deletions tests/unit/tools/test_extraction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def csv_tool(self):
return ExtractionTool(
input_memory=[defaults.text_task_memory("TestMemory")],
extraction_engine=CsvExtractionEngine(
prompt_driver=MockPromptDriver(),
prompt_driver=MockPromptDriver(mock_output="header\nmock output"),
column_names=["test1"],
),
)
Expand Down Expand Up @@ -57,11 +57,13 @@ def test_csv_extract_artifacts(self, csv_tool):
{"values": {"data": {"memory_name": csv_tool.input_memory[0].name, "artifact_namespace": "foo"}}}
)

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"

def test_csv_extract_content(self, csv_tool):
result = csv_tool.extract({"values": {"data": "foo"}})

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"

0 comments on commit 495a7b3

Please sign in to comment.