Skip to content

Commit

Permalink
Change CsvExtractionEngine behavior (#1334)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Nov 12, 2024
1 parent d0a1891 commit c746174
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 33 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`.
- `CsvExtractionEngine.format_header` parameter to format the header row.

### Changed

Expand All @@ -18,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed `BaseTask.is_executing()` to `BaseTask.is_running()`.
- **BREAKING**: Renamed `Structure.is_executing()` to `Structure.is_running()`.
- **BREAKING**: Removed ability to pass bytes to `BaseFileLoader.fetch`.
- **BREAKING**: Updated `CsvExtractionEngine.format_row` to format rows as comma-separated values instead of newline-separated key-value pairs.
- Improved `CsvExtractionEngine` prompts.
- Removed `azure-core` and `azure-storage-blob` dependencies.
- `GriptapeCloudFileManagerDriver` no longer requires `drivers-file-manager-griptape-cloud` extra.
- `TrafilaturaWebScraperDriver` no longer sets `no_ssl` to `True` by default.
Expand Down
11 changes: 6 additions & 5 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ search:

## Overview

Extraction Engines in Griptape facilitate the extraction of data from text formats such as CSV and JSON.
These engines play a crucial role in the functionality of [Extraction Tasks](../../griptape-framework/structures/tasks.md).
Extraction Engines enable the extraction of structured data from unstructured text.
These Engines can be used with Structures via [Extraction Tasks](../../griptape-framework/structures/tasks.md).
As of now, Griptape supports two types of Extraction Engines: the CSV Extraction Engine and the JSON Extraction Engine.

## CSV

The CSV Extraction Engine extracts tabular content from unstructured text.
The CSV Extraction Engine extracts comma-separated values from unstructured text.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py"
Expand All @@ -21,12 +21,13 @@ The CSV Extraction Engine extracts tabular content from unstructured text.
name,age,location
Alice,28,New York
Bob,35,California
Charlie,40,Texas
Charlie,40,
Collin,28,San Francisco
```

## JSON

The JSON Extraction Engine extracts JSON-formatted content from unstructured text.
The JSON Extraction Engine extracts JSON-formatted values from unstructured text.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py"
Expand Down
8 changes: 5 additions & 3 deletions docs/griptape-framework/engines/src/extraction_engines_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

# Define some unstructured data
sample_text = """
Alice, 28, lives in New York.
Bob, 35 lives in California.
Charlie is 40 and lives in Texas.
Alice, 28, lives in New
York.
Bob, 35 lives in California
Charlie is 40
Collin is 28 and lives in San Francisco
"""

# Extract CSV rows using the engine
Expand Down
15 changes: 6 additions & 9 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(kw_only=True)
generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
format_row: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)
format_header: Callable[[list[str]], str] = field(default=lambda value: ",".join(value), kw_only=True)
format_row: Callable[[dict], str] = field(default=lambda value: ",".join(value.values()), kw_only=True)

def extract_artifacts(
self,
Expand All @@ -34,18 +33,18 @@ def extract_artifacts(
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], artifacts.value),
[],
[TextArtifact(self.format_header(self.column_names))],
rulesets=rulesets,
),
item_separator="\n",
)

def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]:
def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
rows = []

with io.StringIO(text) as f:
for row in csv.reader(f):
rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row])))))
for row in csv.DictReader(f):
rows.append(TextArtifact(self.format_row(row)))

return rows

Expand Down Expand Up @@ -79,7 +78,6 @@ def _extract_rec(
]
)
).value,
self.column_names,
),
)

Expand All @@ -100,7 +98,6 @@ def _extract_rec(
]
)
).value,
self.column_names,
),
)

Expand Down
2 changes: 1 addition & 1 deletion griptape/templates/engines/extraction/csv/system.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Always add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Column Names: """{{ column_names }}"""

{% if rulesets %}
Expand Down
18 changes: 11 additions & 7 deletions tests/unit/engines/extraction/test_csv_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import pytest

from griptape.engines import CsvExtractionEngine
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestCsvExtractionEngine:
@pytest.fixture()
def engine(self):
return CsvExtractionEngine(column_names=["test1"])
return CsvExtractionEngine(
column_names=["header"], prompt_driver=MockPromptDriver(mock_output="header\nmock output")
)

def test_extract_text(self, engine):
result = engine.extract_text("foo")
result = engine.extract_text("mock output")

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "header"
assert result.value[1].value == "mock output"

def test_text_to_csv_rows(self, engine):
result = engine.text_to_csv_rows("foo,bar\nbaz,maz", ["test1", "test2"])
result = engine.text_to_csv_rows("key,value\nfoo,bar\nbaz,maz")

assert len(result) == 2
assert result[0].value == "test1: foo\ntest2: bar"
assert result[1].value == "test1: baz\ntest2: maz"
assert result[0].value == "foo,bar"
assert result[1].value == "baz,maz"
12 changes: 9 additions & 3 deletions tests/unit/tasks/test_extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
from griptape.engines import CsvExtractionEngine
from griptape.structures import Agent
from griptape.tasks import ExtractionTask
from tests.mocks.mock_prompt_driver import MockPromptDriver


class TestExtractionTask:
@pytest.fixture()
def task(self):
return ExtractionTask(extraction_engine=CsvExtractionEngine(column_names=["test1"]))
return ExtractionTask(
extraction_engine=CsvExtractionEngine(
column_names=["test1"], prompt_driver=MockPromptDriver(mock_output="header\nmock output")
)
)

def test_run(self, task):
agent = Agent()
Expand All @@ -17,5 +22,6 @@ def test_run(self, task):

result = task.run()

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"
12 changes: 7 additions & 5 deletions tests/unit/tools/test_extraction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def csv_tool(self):
return ExtractionTool(
input_memory=[defaults.text_task_memory("TestMemory")],
extraction_engine=CsvExtractionEngine(
prompt_driver=MockPromptDriver(),
prompt_driver=MockPromptDriver(mock_output="header\nmock output"),
column_names=["test1"],
),
)
Expand Down Expand Up @@ -57,11 +57,13 @@ def test_csv_extract_artifacts(self, csv_tool):
{"values": {"data": {"memory_name": csv_tool.input_memory[0].name, "artifact_namespace": "foo"}}}
)

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"

def test_csv_extract_content(self, csv_tool):
result = csv_tool.extract({"values": {"data": "foo"}})

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
assert len(result.value) == 2
assert result.value[0].value == "test1"
assert result.value[1].value == "mock output"

0 comments on commit c746174

Please sign in to comment.