Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extraction Engine Improvements #1097

Merged
merged 14 commits into from
Oct 4, 2024
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- `BaseConversationMemory.prompt_driver` for use with autopruning.
- Generic type support to `ListArtifact`.
- Iteration support to `ListArtifact`.

### Changed
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
- **BREAKING**: Split `BaseExtractionEngine.extract` into `extract` and `extract_artifacts` for consistency with `BaseSummaryEngine`.
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
- **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s.
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved
- **BREAKING**: `JsonExtractionEngine.template_schema` is now required.
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`.
- `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`.

### Fixed
- Parsing streaming response with some OpenAi compatible services.
Expand Down
27 changes: 18 additions & 9 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction

## CSV

The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content.

!!! info
The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted.
The CSV Extraction Engine is designed for extracting CSV-formatted content from unstructured data.
dylanholmes marked this conversation as resolved.
Show resolved Hide resolved

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py"
Expand All @@ -27,15 +24,27 @@ Charlie,40,Texas

## JSON

The JSON Extraction Engine is tailored for extracting data from JSON-formatted content.
The JSON Extraction Engine is designed for extracting JSON-formatted content from unstructed data.

!!! info
The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py"
```
```
{'name': 'Alice', 'age': 28, 'location': 'New York'}
{'name': 'Bob', 'age': 35, 'location': 'California'}
{
"model": "GPT-3.5",
"notes": [
"Part of OpenAI's GPT series.",
"Used in ChatGPT and Microsoft Copilot."
]
}
{
"model": "GPT-4",
"notes": [
"Part of OpenAI's GPT series.",
"Praised for increased accuracy and multimodal capabilities.",
"Architecture and number of parameters not revealed."
]
}
...Output truncated for brevity...
```
10 changes: 3 additions & 7 deletions docs/griptape-framework/engines/src/extraction_engines_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from griptape.artifacts import ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import CsvExtractionEngine

Expand All @@ -15,10 +14,7 @@
"""

# Extract CSV rows using the engine
result = csv_engine.extract(sample_text)
result = csv_engine.extract_text(sample_text)

if isinstance(result, ListArtifact):
for row in result.value:
print(row.to_text())
else:
print(result.to_text())
for row in result:
print(row.to_text())
37 changes: 20 additions & 17 deletions docs/griptape-framework/engines/src/extraction_engines_2.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
from schema import Schema
import json

from griptape.artifacts.list_artifact import ListArtifact
from schema import Literal, Schema

from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import JsonExtractionEngine
from griptape.loaders import WebLoader

# Define a schema for extraction
user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema")


json_engine = JsonExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
template_schema=Schema(
{
Literal("model", description="Name of an LLM model."): str,
Literal("notes", description="Any notes of substance about the model."): Schema([str]),
}
).json_schema("ProductSchema"),
)

# Define some unstructured data
sample_json_text = """
Alice (Age 28) lives in New York.
Bob (Age 35) lives in California.
"""
# Load data from the web
web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model")

if isinstance(web_data, ErrorArtifact):
raise Exception(web_data.value)

# Extract data using the engine
result = json_engine.extract(sample_json_text)
result = json_engine.extract_artifacts(ListArtifact(web_data))

if isinstance(result, ListArtifact):
for artifact in result.value:
print(artifact.value)
else:
print(result.to_text())
for artifact in result:
print(json.dumps(artifact.value, indent=2))
4 changes: 2 additions & 2 deletions griptape/artifacts/json_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from griptape.artifacts import BaseArtifact

Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]


@define
class JsonArtifact(BaseArtifact):
Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]

Copy link
Member Author

@collindutter collindutter Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allows us to access via JsonArtifact.Json which may be useful in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, could you provide a hypothetical example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the near future we may offer some sort of structured output functionality that would allow users to pass in a JSON schema they'd like the LLM to output their answer in.

The type could be something like: output_format: JsonArtifact.Json

value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True})

def to_text(self) -> str:
Expand Down
23 changes: 14 additions & 9 deletions griptape/artifacts/list_artifact.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Generic, Optional, TypeVar

from attrs import Attribute, define, field

from griptape.artifacts import BaseArtifact

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Iterator, Sequence

T = TypeVar("T", bound=BaseArtifact, covariant=True)


@define
class ListArtifact(BaseArtifact):
value: Sequence[BaseArtifact] = field(factory=list, metadata={"serializable": True})
class ListArtifact(BaseArtifact, Generic[T]):
value: Sequence[T] = field(factory=list, metadata={"serializable": True})
item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True})
validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True})

@value.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_value(self, _: Attribute, value: list[BaseArtifact]) -> None:
def validate_value(self, _: Attribute, value: list[T]) -> None:
if self.validate_uniform_types and len(value) > 0:
first_type = type(value[0])

Expand All @@ -31,18 +33,21 @@ def child_type(self) -> Optional[type]:
else:
return None

def __getitem__(self, key: int) -> BaseArtifact:
def __getitem__(self, key: int) -> T:
return self.value[key]

def __bool__(self) -> bool:
return len(self) > 0

def __add__(self, other: BaseArtifact) -> ListArtifact[T]:
return ListArtifact(self.value + other.value)

def __iter__(self) -> Iterator[T]:
return iter(self.value)

def to_text(self) -> str:
return self.item_separator.join([v.to_text() for v in self.value])

def __add__(self, other: BaseArtifact) -> BaseArtifact:
return ListArtifact(self.value + other.value)

def is_type(self, target_type: type) -> bool:
if self.value:
return isinstance(self.value[0], target_type)
Expand Down
17 changes: 13 additions & 4 deletions griptape/engines/extraction/base_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from attrs import Attribute, Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker
from griptape.configs import Defaults

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import BasePromptDriver
from griptape.rules import Ruleset

Expand Down Expand Up @@ -47,11 +47,20 @@ def min_response_tokens(self) -> int:
- self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
)

def extract_text(
self,
text: str,
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact:
return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

@abstractmethod
def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact | ErrorArtifact: ...
) -> ListArtifact: ...
28 changes: 13 additions & 15 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import Factory, define, field

from griptape.artifacts import CsvRowArtifact, ErrorArtifact, ListArtifact, TextArtifact
from griptape.artifacts import CsvRowArtifact, ListArtifact, TextArtifact
from griptape.common import Message, PromptStack
from griptape.engines import BaseExtractionEngine
from griptape.utils import J2
Expand All @@ -17,27 +17,25 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(default=Factory(list), kw_only=True)
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact | ErrorArtifact:
try:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
[],
),
item_separator="\n",
)
except Exception as e:
return ErrorArtifact(f"error extracting CSV rows: {e}")
) -> ListArtifact[CsvRowArtifact]:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], artifacts.value),
[],
rulesets=rulesets,
),
item_separator="\n",
)

def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[CsvRowArtifact]:
rows = []
Expand Down
33 changes: 13 additions & 20 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from attrs import Factory, define, field

from griptape.artifacts import ErrorArtifact, ListArtifact, TextArtifact
from griptape.artifacts import JsonArtifact, ListArtifact, TextArtifact
from griptape.common import PromptStack
from griptape.common.prompt_stack.messages.message import Message
from griptape.engines import BaseExtractionEngine
Expand All @@ -20,46 +20,39 @@
class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(default=Factory(dict), kw_only=True)
template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact[TextArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact | ErrorArtifact:
try:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
[],
rulesets=rulesets,
),
item_separator="\n",
)
except Exception as e:
return ErrorArtifact(f"error extracting JSON: {e}")
) -> ListArtifact[JsonArtifact]:
return ListArtifact(
self._extract_rec(cast(list[TextArtifact], artifacts.value), [], rulesets=rulesets),
item_separator="\n",
)

def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]:
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

if json_matches:
return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])]
return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
else:
return []

def _extract_rec(
self,
artifacts: list[TextArtifact],
extractions: list[TextArtifact],
extractions: list[JsonArtifact],
*,
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
) -> list[JsonArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
system_prompt = self.system_template_generator.render(
json_template_schema=json.dumps(self.template_schema),
Expand Down
6 changes: 5 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from collections.abc import Sequence
from typing import Any, Literal, Union, _SpecialForm, get_args, get_origin
from typing import Any, Literal, TypeVar, Union, _SpecialForm, get_args, get_origin

import attrs
from marshmallow import INCLUDE, Schema, fields
Expand Down Expand Up @@ -56,6 +56,10 @@ def _get_field_for_type(cls, field_type: type) -> fields.Field | fields.Nested:

field_class, args, optional = cls._get_field_type_info(field_type)

# Resolve TypeVars to their bound type
if isinstance(field_class, TypeVar):
field_class = field_class.__bound__

if attrs.has(field_class):
if ABC in field_class.__bases__:
return fields.Nested(PolymorphicSchema(inner_class=field_class), allow_none=optional)
Expand Down
7 changes: 5 additions & 2 deletions griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

from attrs import define, field

from griptape.artifacts import ListArtifact
from griptape.tasks import BaseTextInputTask

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact
from griptape.engines import BaseExtractionEngine


Expand All @@ -17,4 +18,6 @@ class ExtractionTask(BaseTextInputTask):
args: dict = field(kw_only=True, factory=dict)

def run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
return self.extraction_engine.extract_artifacts(
ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args
)
Loading