Skip to content

Commit

Permalink
Extraction Engine Improvements (#1097)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Oct 4, 2024
1 parent 9cd4199 commit 79f976b
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 83 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Update `redis` dependency to `^5.1.0`.
- **BREAKING**: Remove `torch` extra from `transformers` dependency. This must be installed separately.
- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags
- **BREAKING**: Split `BaseExtractionEngine.extract` into `extract_text` and `extract_artifacts` for consistency with `BaseSummaryEngine`.
- **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s.
- **BREAKING**: `JsonExtractionEngine.template_schema` is now required.
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`.
- `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`.

### Fixed
- Anthropic native Tool calling
Expand Down
27 changes: 18 additions & 9 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction

## CSV

The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content.

!!! info
The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted.
The CSV Extraction Engine extracts tabular content from unstructured text.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py"
Expand All @@ -27,15 +24,27 @@ Charlie,40,Texas

## JSON

The JSON Extraction Engine is tailored for extracting data from JSON-formatted content.
The JSON Extraction Engine extracts JSON-formatted content from unstructured text.

!!! info
The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py"
```
```
{'name': 'Alice', 'age': 28, 'location': 'New York'}
{'name': 'Bob', 'age': 35, 'location': 'California'}
{
"model": "GPT-3.5",
"notes": [
"Part of OpenAI's GPT series.",
"Used in ChatGPT and Microsoft Copilot."
]
}
{
"model": "GPT-4",
"notes": [
"Part of OpenAI's GPT series.",
"Praised for increased accuracy and multimodal capabilities.",
"Architecture and number of parameters not revealed."
]
}
...Output truncated for brevity...
```
10 changes: 3 additions & 7 deletions docs/griptape-framework/engines/src/extraction_engines_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from griptape.artifacts import ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import CsvExtractionEngine

Expand All @@ -15,10 +14,7 @@
"""

# Extract CSV rows using the engine
result = csv_engine.extract(sample_text)
result = csv_engine.extract_text(sample_text)

if isinstance(result, ListArtifact):
for row in result.value:
print(row.to_text())
else:
print(result.to_text())
for row in result:
print(row.to_text())
37 changes: 20 additions & 17 deletions docs/griptape-framework/engines/src/extraction_engines_2.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from schema import Schema
import json

from griptape.artifacts.list_artifact import ListArtifact
from schema import Literal, Schema

from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import JsonExtractionEngine
from griptape.loaders import WebLoader

# Define a schema for extraction
user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema")


json_engine = JsonExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
template_schema=Schema(
{
Literal("model", description="Name of an LLM model."): str,
Literal("notes", description="Any notes of substance about the model."): Schema([str]),
}
).json_schema("ProductSchema"),
)

# Define some unstructured data
sample_json_text = """
Alice (Age 28) lives in New York.
Bob (Age 35) lives in California.
"""
# Load data from the web
web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model")

if isinstance(web_data, ErrorArtifact):
raise Exception(web_data.value)

# Extract data using the engine
result = json_engine.extract(sample_json_text)
result = json_engine.extract_artifacts(ListArtifact(web_data))

if isinstance(result, ListArtifact):
for artifact in result.value:
print(artifact.value)
else:
print(result.to_text())
for artifact in result:
print(json.dumps(artifact.value, indent=2))
24 changes: 12 additions & 12 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ class ListArtifact(BaseArtifact, Generic[T]):
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
Expand All @@ -45,6 +33,18 @@ def child_type(self) -> Optional[type]:
else:
return None

def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

def to_text(self) -> str:
return self.item_separator.join([v.to_text() for v in self.value])

Expand Down
15 changes: 12 additions & 3 deletions griptape/engines/extraction/base_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from attrs import Attribute, Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker
from griptape.configs import Defaults

if TYPE_CHECKING:
from griptape.artifacts import ListArtifact
from griptape.drivers import BasePromptDriver
from griptape.rules import Ruleset

Expand Down Expand Up @@ -47,10 +47,19 @@ def min_response_tokens(self) -> int:
- self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
)

def extract_text(
self,
text: str,
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

@abstractmethod
def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(default=Factory(list), kw_only=True)
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
) -> ListArtifact[TextArtifact]:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
[],
rulesets=rulesets,
),
Expand Down
24 changes: 10 additions & 14 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.artifacts import JsonArtifact, ListArtifact, TextArtifact
from griptape.common import PromptStack
from griptape.common.prompt_stack.messages.message import Message
from griptape.engines import BaseExtractionEngine
Expand All @@ -20,43 +20,39 @@
class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(default=Factory(dict), kw_only=True)
template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
) -> ListArtifact[JsonArtifact]:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
[],
rulesets=rulesets,
),
self._extract_rec(cast(list[TextArtifact], artifacts.value), [], rulesets=rulesets),
item_separator="\n",
)

def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]:
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

if json_matches:
return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])]
return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
else:
return []

def _extract_rec(
self,
artifacts: list[TextArtifact],
extractions: list[TextArtifact],
extractions: list[JsonArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
json_template_schema=json.dumps(self.template_schema),
Expand Down
7 changes: 5 additions & 2 deletions griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from attrs import define, field

from griptape.artifacts import ListArtifact
from griptape.tasks import BaseTextInputTask

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact
from griptape.engines import BaseExtractionEngine


Expand All @@ -17,4 +18,6 @@ class ExtractionTask(BaseTextInputTask):
args: dict = field(kw_only=True, factory=dict)

def run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
return self.extraction_engine.extract_artifacts(
ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args
)
2 changes: 1 addition & 1 deletion griptape/tools/extraction/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact:
else:
return ErrorArtifact("memory not found")

return self.extraction_engine.extract(artifacts)
return self.extraction_engine.extract_artifacts(artifacts)
4 changes: 2 additions & 2 deletions tests/unit/engines/extraction/test_csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class TestCsvExtractionEngine:
def engine(self):
return CsvExtractionEngine(column_names=["test1"])

def test_extract(self, engine):
result = engine.extract("foo")
def test_extract_text(self, engine):
result = engine.extract_text("foo")

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
Expand Down
23 changes: 16 additions & 7 deletions tests/unit/engines/extraction/test_json_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from os.path import dirname, join, normpath
from pathlib import Path

import pytest
from schema import Schema

Expand All @@ -15,24 +18,30 @@ def engine(self):
template_schema=Schema({"foo": "bar"}).json_schema("TemplateSchema"),
)

def test_extract(self, engine):
result = engine.extract("foo")
def test_extract_text(self, engine):
result = engine.extract_text("foo")

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_chunked_extract_text(self, engine):
large_text = Path(normpath(join(dirname(__file__), "../../../resources", "test.txt"))).read_text()

extracted = engine.extract_text(large_text * 50)
assert len(extracted) == 354
assert extracted[0].value == {"test_key_1": "test_value_1"}

def test_extract_error(self, engine):
engine.template_schema = lambda: "non serializable"

with pytest.raises(TypeError):
engine.extract("foo")
engine.extract_text("foo")

def test_json_to_text_artifacts(self, engine):
assert [
a.value
for a in engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]')
] == ['{"test_key_1": "test_value_1"}', '{"test_key_2": "test_value_2"}']
] == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]

def test_json_to_text_artifacts_no_matches(self, engine):
assert engine.json_to_text_artifacts("asdfasdfasdf") == []
8 changes: 4 additions & 4 deletions tests/unit/tools/test_extraction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def test_json_extract_artifacts(self, json_tool):
)

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_json_extract_content(self, json_tool):
result = json_tool.extract({"values": {"data": "foo"}})

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_csv_extract_artifacts(self, csv_tool):
csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz"))
Expand Down

0 comments on commit 79f976b

Please sign in to comment.