From c746174b6fa41f255b3bee8434f0609c644fd098 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 12 Nov 2024 16:33:53 +0000 Subject: [PATCH] Change CsvExtractionEngine behavior (#1334) --- CHANGELOG.md | 3 +++ .../engines/extraction-engines.md | 11 ++++++----- .../engines/src/extraction_engines_1.py | 8 +++++--- .../extraction/csv_extraction_engine.py | 15 ++++++--------- .../templates/engines/extraction/csv/system.j2 | 2 +- .../extraction/test_csv_extraction_engine.py | 18 +++++++++++------- tests/unit/tasks/test_extraction_task.py | 12 +++++++++--- tests/unit/tools/test_extraction_tool.py | 12 +++++++----- 8 files changed, 48 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d5491586..f9ea4640c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `TrafilaturaWebScraperDriver.no_ssl` parameter to disable SSL verification. Defaults to `False`. +- `CsvExtractionEngine.format_header` parameter to format the header row. ### Changed @@ -18,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed `BaseTask.is_executing()` to `BaseTask.is_running()`. - **BREAKING**: Renamed `Structure.is_executing()` to `Structure.is_running()`. - **BREAKING**: Removed ability to pass bytes to `BaseFileLoader.fetch`. +- **BREAKING**: Updated `CsvExtractionEngine.format_row` to format rows as comma-separated values instead of newline-separated key-value pairs. +- Improved `CsvExtractionEngine` prompts. - Removed `azure-core` and `azure-storage-blob` dependencies. - `GriptapeCloudFileManagerDriver` no longer requires `drivers-file-manager-griptape-cloud` extra. - `TrafilaturaWebScraperDriver` no longer sets `no_ssl` to `True` by default. 
diff --git a/docs/griptape-framework/engines/extraction-engines.md b/docs/griptape-framework/engines/extraction-engines.md index 09b1d5ca1..5d9c80ba2 100644 --- a/docs/griptape-framework/engines/extraction-engines.md +++ b/docs/griptape-framework/engines/extraction-engines.md @@ -5,13 +5,13 @@ search: ## Overview -Extraction Engines in Griptape facilitate the extraction of data from text formats such as CSV and JSON. -These engines play a crucial role in the functionality of [Extraction Tasks](../../griptape-framework/structures/tasks.md). +Extraction Engines enable the extraction of structured data from unstructured text. +These Engines can be used with Structures via [Extraction Tasks](../../griptape-framework/structures/tasks.md). As of now, Griptape supports two types of Extraction Engines: the CSV Extraction Engine and the JSON Extraction Engine. ## CSV -The CSV Extraction Engine extracts tabular content from unstructured text. +The CSV Extraction Engine extracts comma-separated values from unstructured text. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py" @@ -21,12 +21,13 @@ The CSV Extraction Engine extracts tabular content from unstructured text. name,age,location Alice,28,New York Bob,35,California -Charlie,40,Texas +Charlie,40, +Collin,28,San Francisco ``` ## JSON -The JSON Extraction Engine extracts JSON-formatted content from unstructured text. +The JSON Extraction Engine extracts JSON-formatted values from unstructured text. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py" diff --git a/docs/griptape-framework/engines/src/extraction_engines_1.py b/docs/griptape-framework/engines/src/extraction_engines_1.py index 45ccfd3e0..731e53b2b 100644 --- a/docs/griptape-framework/engines/src/extraction_engines_1.py +++ b/docs/griptape-framework/engines/src/extraction_engines_1.py @@ -8,9 +8,11 @@ # Define some unstructured data sample_text = """ -Alice, 28, lives in New York. 
-Bob, 35 lives in California. -Charlie is 40 and lives in Texas. +Alice, 28, lives in New +York. +Bob, 35 lives in California +Charlie is 40 +Collin is 28 and lives in San Francisco """ # Extract CSV rows using the engine diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 7f4647d65..c27501fa7 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -20,9 +20,8 @@ class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(kw_only=True) generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) - format_row: Callable[[dict], str] = field( - default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True - ) + format_header: Callable[[list[str]], str] = field(default=lambda value: ",".join(value), kw_only=True) + format_row: Callable[[dict], str] = field(default=lambda value: ",".join(value.values()), kw_only=True) def extract_artifacts( self, @@ -34,18 +33,18 @@ def extract_artifacts( return ListArtifact( self._extract_rec( cast(list[TextArtifact], artifacts.value), - [], + [TextArtifact(self.format_header(self.column_names))], rulesets=rulesets, ), item_separator="\n", ) - def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtifact]: + def text_to_csv_rows(self, text: str) -> list[TextArtifact]: rows = [] with io.StringIO(text) as f: - for row in csv.reader(f): - rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row]))))) + for row in csv.DictReader(f): + rows.append(TextArtifact(self.format_row(row))) return rows @@ -79,7 +78,6 @@ def _extract_rec( ] ) ).value, - self.column_names, ), ) @@ -100,7 +98,6 @@ def _extract_rec( ] ) ).value, - 
self.column_names, ), ) diff --git a/griptape/templates/engines/extraction/csv/system.j2 b/griptape/templates/engines/extraction/csv/system.j2 index 7c5776257..9e10ff1f9 100644 --- a/griptape/templates/engines/extraction/csv/system.j2 +++ b/griptape/templates/engines/extraction/csv/system.j2 @@ -1,4 +1,4 @@ -Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. +Always add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. Column Names: """{{ column_names }}""" {% if rulesets %} diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index 36c58788d..c3bff6d7d 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -1,22 +1,26 @@ import pytest from griptape.engines import CsvExtractionEngine +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestCsvExtractionEngine: @pytest.fixture() def engine(self): - return CsvExtractionEngine(column_names=["test1"]) + return CsvExtractionEngine( + column_names=["header"], prompt_driver=MockPromptDriver(mock_output="header\nmock output") + ) def test_extract_text(self, engine): - result = engine.extract_text("foo") + result = engine.extract_text("mock output") - assert len(result.value) == 1 - assert result.value[0].value == "test1: mock output" + assert len(result.value) == 2 + assert result.value[0].value == "header" + assert result.value[1].value == "mock output" def test_text_to_csv_rows(self, engine): - result = engine.text_to_csv_rows("foo,bar\nbaz,maz", ["test1", "test2"]) + result = engine.text_to_csv_rows("key,value\nfoo,bar\nbaz,maz") assert len(result) == 2 - assert result[0].value == "test1: foo\ntest2: bar" - assert 
result[1].value == "test1: baz\ntest2: maz" + assert result[0].value == "foo,bar" + assert result[1].value == "baz,maz" diff --git a/tests/unit/tasks/test_extraction_task.py b/tests/unit/tasks/test_extraction_task.py index 06d444f9b..10c55b246 100644 --- a/tests/unit/tasks/test_extraction_task.py +++ b/tests/unit/tasks/test_extraction_task.py @@ -3,12 +3,17 @@ from griptape.engines import CsvExtractionEngine from griptape.structures import Agent from griptape.tasks import ExtractionTask +from tests.mocks.mock_prompt_driver import MockPromptDriver class TestExtractionTask: @pytest.fixture() def task(self): - return ExtractionTask(extraction_engine=CsvExtractionEngine(column_names=["test1"])) + return ExtractionTask( + extraction_engine=CsvExtractionEngine( + column_names=["test1"], prompt_driver=MockPromptDriver(mock_output="header\nmock output") + ) + ) def test_run(self, task): agent = Agent() @@ -17,5 +22,6 @@ def test_run(self, task): result = task.run() - assert len(result.value) == 1 - assert result.value[0].value == "test1: mock output" + assert len(result.value) == 2 + assert result.value[0].value == "test1" + assert result.value[1].value == "mock output" diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index 7e66aa24d..3c6bf6a6c 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -27,7 +27,7 @@ def csv_tool(self): return ExtractionTool( input_memory=[defaults.text_task_memory("TestMemory")], extraction_engine=CsvExtractionEngine( - prompt_driver=MockPromptDriver(), + prompt_driver=MockPromptDriver(mock_output="header\nmock output"), column_names=["test1"], ), ) @@ -57,11 +57,13 @@ def test_csv_extract_artifacts(self, csv_tool): {"values": {"data": {"memory_name": csv_tool.input_memory[0].name, "artifact_namespace": "foo"}}} ) - assert len(result.value) == 1 - assert result.value[0].value == "test1: mock output" + assert len(result.value) == 2 + assert 
result.value[0].value == "test1" + assert result.value[1].value == "mock output" def test_csv_extract_content(self, csv_tool): result = csv_tool.extract({"values": {"data": "foo"}}) - assert len(result.value) == 1 - assert result.value[0].value == "test1: mock output" + assert len(result.value) == 2 + assert result.value[0].value == "test1" + assert result.value[1].value == "mock output"