Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extraction Engine Improvements #1097

Merged
merged 14 commits into from
Oct 4, 2024
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Update `redis` dependency to `^5.1.0`.
- **BREAKING**: Remove `torch` extra from `transformers` dependency. This must be installed separately.
- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags
- **BREAKING**: Split `BaseExtractionEngine.extract` into `extract_text` and `extract_artifacts` for consistency with `BaseSummaryEngine`.
- **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s.
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
- **BREAKING**: `JsonExtractionEngine.template_schema` is now required.
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`.
- `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`.

### Fixed
- Anthropic native Tool calling
Expand Down
27 changes: 18 additions & 9 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction

## CSV

The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content.

!!! info
The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted.
The CSV Extraction Engine extracts tabular content from unstructured text.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py"
Expand All @@ -27,15 +24,27 @@ Charlie,40,Texas

## JSON

The JSON Extraction Engine is tailored for extracting data from JSON-formatted content.
The JSON Extraction Engine extracts JSON-formatted content from unstructured text.

!!! info
The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py"
```
```
{'name': 'Alice', 'age': 28, 'location': 'New York'}
{'name': 'Bob', 'age': 35, 'location': 'California'}
{
"model": "GPT-3.5",
"notes": [
"Part of OpenAI's GPT series.",
"Used in ChatGPT and Microsoft Copilot."
]
}
{
"model": "GPT-4",
"notes": [
"Part of OpenAI's GPT series.",
"Praised for increased accuracy and multimodal capabilities.",
"Architecture and number of parameters not revealed."
]
}
...Output truncated for brevity...
```
10 changes: 3 additions & 7 deletions docs/griptape-framework/engines/src/extraction_engines_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from griptape.artifacts import ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import CsvExtractionEngine

Expand All @@ -15,10 +14,7 @@
"""

# Extract CSV rows using the engine
result = csv_engine.extract(sample_text)
result = csv_engine.extract_text(sample_text)

if isinstance(result, ListArtifact):
for row in result.value:
print(row.to_text())
else:
print(result.to_text())
for row in result:
print(row.to_text())
37 changes: 20 additions & 17 deletions docs/griptape-framework/engines/src/extraction_engines_2.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from schema import Schema
import json

from griptape.artifacts.list_artifact import ListArtifact
from schema import Literal, Schema

from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import JsonExtractionEngine
from griptape.loaders import WebLoader

# Define a schema for extraction
user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema")


json_engine = JsonExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
template_schema=Schema(
{
Literal("model", description="Name of an LLM model."): str,
Literal("notes", description="Any notes of substance about the model."): Schema([str]),
}
).json_schema("ProductSchema"),
)

# Define some unstructured data
sample_json_text = """
Alice (Age 28) lives in New York.
Bob (Age 35) lives in California.
"""
# Load data from the web
web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model")

if isinstance(web_data, ErrorArtifact):
raise Exception(web_data.value)

# Extract data using the engine
result = json_engine.extract(sample_json_text)
result = json_engine.extract_artifacts(ListArtifact(web_data))

if isinstance(result, ListArtifact):
for artifact in result.value:
print(artifact.value)
else:
print(result.to_text())
for artifact in result:
print(json.dumps(artifact.value, indent=2))
24 changes: 12 additions & 12 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ class ListArtifact(BaseArtifact, Generic[T]):
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
Expand All @@ -45,6 +33,18 @@ def child_type(self) -> Optional[type]:
else:
return None

def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

def to_text(self) -> str:
return self.item_separator.join([v.to_text() for v in self.value])

Expand Down
15 changes: 12 additions & 3 deletions griptape/engines/extraction/base_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from attrs import Attribute, Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker
from griptape.configs import Defaults

if TYPE_CHECKING:
from griptape.artifacts import ListArtifact
from griptape.drivers import BasePromptDriver
from griptape.rules import Ruleset

Expand Down Expand Up @@ -47,10 +47,19 @@ def min_response_tokens(self) -> int:
- self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
)

def extract_text(
self,
text: str,
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

@abstractmethod
def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(default=Factory(list), kw_only=True)
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
formatter_fn: Callable[[dict], str] = field(
default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True
)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
) -> ListArtifact[TextArtifact]:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
[],
rulesets=rulesets,
),
Expand Down
24 changes: 10 additions & 14 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.artifacts import JsonArtifact, ListArtifact, TextArtifact
from griptape.common import PromptStack
from griptape.common.prompt_stack.messages.message import Message
from griptape.engines import BaseExtractionEngine
Expand All @@ -20,43 +20,39 @@
class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(default=Factory(dict), kw_only=True)
template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
) -> ListArtifact[JsonArtifact]:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
[],
rulesets=rulesets,
),
self._extract_rec(cast(list[TextArtifact], artifacts.value), [], rulesets=rulesets),
item_separator="\n",
)

def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]:
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

if json_matches:
return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])]
return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
else:
return []

def _extract_rec(
self,
artifacts: list[TextArtifact],
extractions: list[TextArtifact],
extractions: list[JsonArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
json_template_schema=json.dumps(self.template_schema),
Expand Down
7 changes: 5 additions & 2 deletions griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from attrs import define, field

from griptape.artifacts import ListArtifact
from griptape.tasks import BaseTextInputTask

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact
from griptape.engines import BaseExtractionEngine


Expand All @@ -17,4 +18,6 @@ class ExtractionTask(BaseTextInputTask):
args: dict = field(kw_only=True, factory=dict)

def run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
return self.extraction_engine.extract_artifacts(
ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args
)
2 changes: 1 addition & 1 deletion griptape/tools/extraction/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ def extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact:
else:
return ErrorArtifact("memory not found")

return self.extraction_engine.extract(artifacts)
return self.extraction_engine.extract_artifacts(artifacts)
4 changes: 2 additions & 2 deletions tests/unit/engines/extraction/test_csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ class TestCsvExtractionEngine:
def engine(self):
return CsvExtractionEngine(column_names=["test1"])

def test_extract(self, engine):
result = engine.extract("foo")
def test_extract_text(self, engine):
result = engine.extract_text("foo")

assert len(result.value) == 1
assert result.value[0].value == "test1: mock output"
Expand Down
23 changes: 16 additions & 7 deletions tests/unit/engines/extraction/test_json_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from os.path import dirname, join, normpath
from pathlib import Path

import pytest
from schema import Schema

Expand All @@ -15,24 +18,30 @@ def engine(self):
template_schema=Schema({"foo": "bar"}).json_schema("TemplateSchema"),
)

def test_extract(self, engine):
result = engine.extract("foo")
def test_extract_text(self, engine):
result = engine.extract_text("foo")

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_chunked_extract_text(self, engine):
large_text = Path(normpath(join(dirname(__file__), "../../../resources", "test.txt"))).read_text()

extracted = engine.extract_text(large_text * 50)
assert len(extracted) == 354
assert extracted[0].value == {"test_key_1": "test_value_1"}

def test_extract_error(self, engine):
engine.template_schema = lambda: "non serializable"

with pytest.raises(TypeError):
engine.extract("foo")
engine.extract_text("foo")

def test_json_to_text_artifacts(self, engine):
assert [
a.value
for a in engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]')
] == ['{"test_key_1": "test_value_1"}', '{"test_key_2": "test_value_2"}']
] == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]

def test_json_to_text_artifacts_no_matches(self, engine):
assert engine.json_to_text_artifacts("asdfasdfasdf") == []
8 changes: 4 additions & 4 deletions tests/unit/tools/test_extraction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def test_json_extract_artifacts(self, json_tool):
)

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_json_extract_content(self, json_tool):
result = json_tool.extract({"values": {"data": "foo"}})

assert len(result.value) == 2
assert result.value[0].value == '{"test_key_1": "test_value_1"}'
assert result.value[1].value == '{"test_key_2": "test_value_2"}'
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_csv_extract_artifacts(self, csv_tool):
csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz"))
Expand Down
Loading