diff --git a/CHANGELOG.md b/CHANGELOG.md index 16062edb5..e9dfc1650 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Update `redis` dependency to `^5.1.0`. - **BREAKING**: Remove `torch` extra from `transformers` dependency. This must be installed separately. - `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags +- **BREAKING**: Split `BaseExtractionEngine.extract` into `extract_text` and `extract_artifacts` for consistency with `BaseSummaryEngine`. +- **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s. +- **BREAKING**: `JsonExtractionEngine.template_schema` is now required. +- **BREAKING**: `CsvExtractionEngine.column_names` is now required. +- `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`. +- `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`. ### Fixed - Anthropic native Tool calling diff --git a/docs/griptape-framework/engines/extraction-engines.md b/docs/griptape-framework/engines/extraction-engines.md index b971e63cc..c00352691 100644 --- a/docs/griptape-framework/engines/extraction-engines.md +++ b/docs/griptape-framework/engines/extraction-engines.md @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction ## CSV -The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content. - -!!! info - The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted. +The CSV Extraction Engine extracts tabular content from unstructured text. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py" @@ -27,15 +24,27 @@ Charlie,40,Texas ## JSON -The JSON Extraction Engine is tailored for extracting data from JSON-formatted content. +The JSON Extraction Engine extracts JSON-formatted content from unstructured text. -!!! info - The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py" ``` ``` -{'name': 'Alice', 'age': 28, 'location': 'New York'} -{'name': 'Bob', 'age': 35, 'location': 'California'} +{ + "model": "GPT-3.5", + "notes": [ + "Part of OpenAI's GPT series.", + "Used in ChatGPT and Microsoft Copilot." + ] +} +{ + "model": "GPT-4", + "notes": [ + "Part of OpenAI's GPT series.", + "Praised for increased accuracy and multimodal capabilities.", + "Architecture and number of parameters not revealed." + ] +} +...Output truncated for brevity... ``` diff --git a/docs/griptape-framework/engines/src/extraction_engines_1.py b/docs/griptape-framework/engines/src/extraction_engines_1.py index 17644ebf2..45ccfd3e0 100644 --- a/docs/griptape-framework/engines/src/extraction_engines_1.py +++ b/docs/griptape-framework/engines/src/extraction_engines_1.py @@ -1,4 +1,3 @@ -from griptape.artifacts import ListArtifact from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import CsvExtractionEngine @@ -15,10 +14,7 @@ """ # Extract CSV rows using the engine -result = csv_engine.extract(sample_text) +result = csv_engine.extract_text(sample_text) -if isinstance(result, ListArtifact): - for row in result.value: - print(row.to_text()) -else: - print(result.to_text()) +for row in result: + print(row.to_text()) diff --git a/docs/griptape-framework/engines/src/extraction_engines_2.py b/docs/griptape-framework/engines/src/extraction_engines_2.py index a100754b3..35bbe53cd 100644 --- a/docs/griptape-framework/engines/src/extraction_engines_2.py +++ b/docs/griptape-framework/engines/src/extraction_engines_2.py @@ -1,28 +1,31 @@ -from schema import Schema +import json -from griptape.artifacts.list_artifact import ListArtifact +from schema import Literal, Schema + +from griptape.artifacts import ErrorArtifact, ListArtifact from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import JsonExtractionEngine +from griptape.loaders import WebLoader # Define a schema for extraction -user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema") - - json_engine = JsonExtractionEngine( - prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), + template_schema=Schema( + { + Literal("model", description="Name of an LLM model."): str, + Literal("notes", description="Any notes of substance about the model."): Schema([str]), + } + ).json_schema("ProductSchema"), ) -# Define some unstructured data -sample_json_text = """ -Alice (Age 28) lives in New York. -Bob (Age 35) lives in California. -""" +# Load data from the web +web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model") + +if isinstance(web_data, ErrorArtifact): + raise Exception(web_data.value) # Extract data using the engine -result = json_engine.extract(sample_json_text) +result = json_engine.extract_artifacts(ListArtifact(web_data)) -if isinstance(result, ListArtifact): - for artifact in result.value: - print(artifact.value) -else: - print(result.to_text()) +for artifact in result: + print(json.dumps(artifact.value, indent=2)) diff --git a/griptape/artifacts/list_artifact.py b/griptape/artifacts/list_artifact.py index 0e6f81ca5..02dd295cd 100644 --- a/griptape/artifacts/list_artifact.py +++ b/griptape/artifacts/list_artifact.py @@ -18,18 +18,6 @@ class ListArtifact(BaseArtifact, Generic[T]): item_separator: str = field(default="\n\n", kw_only=True, metadata={"serializable": True}) validate_uniform_types: bool = field(default=False, kw_only=True, metadata={"serializable": True}) - def __getitem__(self, key: int) -> T: - return self.value[key] - - def __bool__(self) -> bool: - return len(self) > 0 - - def __add__(self, other: BaseArtifact) -> ListArtifact[T]: - return ListArtifact(self.value + other.value) - - def __iter__(self) -> Iterator[T]: - return iter(self.value) - @value.validator # pyright: ignore[reportAttributeAccessIssue] def validate_value(self, _: Attribute, value: list[T]) -> None: if self.validate_uniform_types and len(value) > 0: @@ -45,6 +33,18 @@ def child_type(self) -> Optional[type]: else: return None + def __getitem__(self, key: int) -> T: + return self.value[key] + + def __bool__(self) -> bool: + return len(self) > 0 + + def __add__(self, other: BaseArtifact) -> ListArtifact[T]: + return ListArtifact(self.value + other.value) + + def __iter__(self) -> Iterator[T]: + return iter(self.value) + def to_text(self) -> str: return self.item_separator.join([v.to_text() for v in self.value]) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index d3a50585d..10ba5d142 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -5,11 +5,11 @@ from attrs import Attribute, Factory, define, field +from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker from griptape.configs import Defaults if TYPE_CHECKING: - from griptape.artifacts import ListArtifact from griptape.drivers import BasePromptDriver from griptape.rules import Ruleset @@ -47,10 +47,19 @@ def min_response_tokens(self) -> int: - self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier, ) + def extract_text( + self, + text: str, + *, + rulesets: Optional[list[Ruleset]] = None, + **kwargs, + ) -> ListArtifact: + return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs) + @abstractmethod - def extract( + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, **kwargs, diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index e7be73c73..7fb2a164b 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -17,23 +17,23 @@ @define class CsvExtractionEngine(BaseExtractionEngine): - column_names: list[str] = field(default=Factory(list), kw_only=True) + column_names: list[str] = field(kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) formatter_fn: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) - def extract( + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, **kwargs, - ) -> ListArtifact: + ) -> ListArtifact[TextArtifact]: return ListArtifact( self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], + cast(list[TextArtifact], artifacts.value), [], rulesets=rulesets, ), diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index a4cd3a438..c817efd5f 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -6,7 +6,7 @@ from attrs import Factory, define, field -from griptape.artifacts import ListArtifact, TextArtifact +from griptape.artifacts import JsonArtifact, ListArtifact, TextArtifact from griptape.common import PromptStack from griptape.common.prompt_stack.messages.message import Message from griptape.engines import BaseExtractionEngine @@ -20,43 +20,39 @@ class JsonExtractionEngine(BaseExtractionEngine): JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" - template_schema: dict = field(default=Factory(dict), kw_only=True) + template_schema: dict = field(kw_only=True) system_template_generator: J2 = field( default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True ) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) - def extract( + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact[TextArtifact], *, rulesets: Optional[list[Ruleset]] = None, **kwargs, - ) -> ListArtifact: + ) -> ListArtifact[JsonArtifact]: return ListArtifact( - self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], - [], - rulesets=rulesets, - ), + self._extract_rec(cast(list[TextArtifact], artifacts.value), [], rulesets=rulesets), item_separator="\n", ) - def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: + def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]: json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) if json_matches: - return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])] + return [JsonArtifact(e) for e in json.loads(json_matches[-1])] else: return [] def _extract_rec( self, artifacts: list[TextArtifact], - extractions: list[TextArtifact], + extractions: list[JsonArtifact], *, rulesets: Optional[list[Ruleset]] = None, - ) -> list[TextArtifact]: + ) -> list[JsonArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) system_prompt = self.system_template_generator.render( json_template_schema=json.dumps(self.template_schema), diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index c74c3ac49..234840401 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -4,10 +4,11 @@ from attrs import define, field +from griptape.artifacts import ListArtifact from griptape.tasks import BaseTextInputTask if TYPE_CHECKING: - from griptape.artifacts import ErrorArtifact, ListArtifact + from griptape.artifacts import ErrorArtifact from griptape.engines import BaseExtractionEngine @@ -17,4 +18,6 @@ class ExtractionTask(BaseTextInputTask): args: dict = field(kw_only=True, factory=dict) def run(self) -> ListArtifact | ErrorArtifact: - return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) + return self.extraction_engine.extract_artifacts( + ListArtifact([self.input]), rulesets=self.all_rulesets, **self.args + ) diff --git a/griptape/tools/extraction/tool.py b/griptape/tools/extraction/tool.py index 3cb46a670..7e52dad69 100644 --- a/griptape/tools/extraction/tool.py +++ b/griptape/tools/extraction/tool.py @@ -57,4 +57,4 @@ def extract(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: else: return ErrorArtifact("memory not found") - return self.extraction_engine.extract(artifacts) + return self.extraction_engine.extract_artifacts(artifacts) diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index 056df2d5a..36c58788d 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -8,8 +8,8 @@ class TestCsvExtractionEngine: def engine(self): return CsvExtractionEngine(column_names=["test1"]) - def test_extract(self, engine): - result = engine.extract("foo") + def test_extract_text(self, engine): + result = engine.extract_text("foo") assert len(result.value) == 1 assert result.value[0].value == "test1: mock output" diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 9d6442579..b4aa58b75 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -1,3 +1,6 @@ +from os.path import dirname, join, normpath +from pathlib import Path + import pytest from schema import Schema @@ -15,24 +18,30 @@ def engine(self): template_schema=Schema({"foo": "bar"}).json_schema("TemplateSchema"), ) - def test_extract(self, engine): - result = engine.extract("foo") + def test_extract_text(self, engine): + result = engine.extract_text("foo") assert len(result.value) == 2 - assert result.value[0].value == '{"test_key_1": "test_value_1"}' - assert result.value[1].value == '{"test_key_2": "test_value_2"}' + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} + + def test_chunked_extract_text(self, engine): + large_text = Path(normpath(join(dirname(__file__), "../../../resources", "test.txt"))).read_text() + + extracted = engine.extract_text(large_text * 50) + assert len(extracted) == 354 + assert extracted[0].value == {"test_key_1": "test_value_1"} def test_extract_error(self, engine): engine.template_schema = lambda: "non serializable" - with pytest.raises(TypeError): - engine.extract("foo") + engine.extract_text("foo") def test_json_to_text_artifacts(self, engine): assert [ a.value for a in engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]') - ] == ['{"test_key_1": "test_value_1"}', '{"test_key_2": "test_value_2"}'] + ] == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] def test_json_to_text_artifacts_no_matches(self, engine): assert engine.json_to_text_artifacts("asdfasdfasdf") == [] diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index 3f783e8a4..7e66aa24d 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -40,15 +40,15 @@ def test_json_extract_artifacts(self, json_tool): ) assert len(result.value) == 2 - assert result.value[0].value == '{"test_key_1": "test_value_1"}' - assert result.value[1].value == '{"test_key_2": "test_value_2"}' + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} def test_json_extract_content(self, json_tool): result = json_tool.extract({"values": {"data": "foo"}}) assert len(result.value) == 2 - assert result.value[0].value == '{"test_key_1": "test_value_1"}' - assert result.value[1].value == '{"test_key_2": "test_value_2"}' + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} def test_csv_extract_artifacts(self, csv_tool): csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz"))