From 317ec502814dec87a2c461bd8003347cc4e272d0 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 21 Aug 2024 14:47:51 -0700 Subject: [PATCH] Revert some, update docs --- CHANGELOG.md | 3 +- .../engines/extraction-engines.md | 27 +++++++++++------ .../engines/src/extraction_engines_2.py | 29 ++++++++++++------- .../structures/src/tasks_6.py | 7 ++--- .../structures/src/tasks_7.py | 4 +-- griptape/artifacts/json_artifact.py | 4 +-- .../extraction/csv_extraction_engine.py | 2 +- .../extraction/json_extraction_engine.py | 6 ++-- .../engines/extraction/json/system.j2 | 2 -- .../templates/engines/extraction/json/user.j2 | 4 +-- .../extraction/test_json_extraction_engine.py | 12 ++++---- tests/unit/tools/test_extraction_tool.py | 10 ++++--- 12 files changed, 64 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09e82ef682..0b9ccbe591 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **BREAKING**: Split `BaseExtractionEngine.extract` into `extract` and `extract_artifacts` for consistency with `BaseSummaryEngine`. - **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s. -- `JsonExtractionEngine` to extract either a JSON object or array depending on the provided schema. +- **BREAKING**: `JsonExtractionEngine.template_schema` is now required. +- **BREAKING**: `CsvExtractionEngine.column_names` is now required. - `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`. - `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`. diff --git a/docs/griptape-framework/engines/extraction-engines.md b/docs/griptape-framework/engines/extraction-engines.md index b971e63cc4..43333e4b02 100644 --- a/docs/griptape-framework/engines/extraction-engines.md +++ b/docs/griptape-framework/engines/extraction-engines.md @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction ## CSV -The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content. - -!!! info - The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted. +The CSV Extraction Engine is designed for extracting CSV-formatted content from unstructured data. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py" @@ -27,15 +24,27 @@ Charlie,40,Texas ## JSON -The JSON Extraction Engine is tailored for extracting data from JSON-formatted content. +The JSON Extraction Engine is designed for extracting JSON-formatted content from unstructed data. -!!! info - The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted. ```python --8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py" ``` ``` -{'name': 'Alice', 'age': 28, 'location': 'New York'} -{'name': 'Bob', 'age': 35, 'location': 'California'} +{ + "model": "GPT-3.5", + "notes": [ + "Part of OpenAI's GPT series.", + "Used in ChatGPT and Microsoft Copilot." + ] +} +{ + "model": "GPT-4", + "notes": [ + "Part of OpenAI's GPT series.", + "Praised for increased accuracy and multimodal capabilities.", + "Architecture and number of parameters not revealed." + ] +} +...Output truncated for brevity... ``` diff --git a/docs/griptape-framework/engines/src/extraction_engines_2.py b/docs/griptape-framework/engines/src/extraction_engines_2.py index 2fb8cd8b0b..35bbe53cdd 100644 --- a/docs/griptape-framework/engines/src/extraction_engines_2.py +++ b/docs/griptape-framework/engines/src/extraction_engines_2.py @@ -1,24 +1,31 @@ -from schema import Schema +import json +from schema import Literal, Schema + +from griptape.artifacts import ErrorArtifact, ListArtifact from griptape.drivers import OpenAiChatPromptDriver from griptape.engines import JsonExtractionEngine +from griptape.loaders import WebLoader # Define a schema for extraction -user_schema = Schema([{"name": str, "age": int, "location": str}]).json_schema("UserSchema") - json_engine = JsonExtractionEngine( - prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"), + template_schema=Schema( + { + Literal("model", description="Name of an LLM model."): str, + Literal("notes", description="Any notes of substance about the model."): Schema([str]), + } + ).json_schema("ProductSchema"), ) -# Define some unstructured data -sample_json_text = """ -Alice (Age 28) lives in New York. -Bob (Age 35) lives in California. -""" +# Load data from the web +web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model") +if isinstance(web_data, ErrorArtifact): + raise Exception(web_data.value) # Extract data using the engine -result = json_engine.extract_text(sample_json_text) +result = json_engine.extract_artifacts(ListArtifact(web_data)) for artifact in result: - print(artifact.value) + print(json.dumps(artifact.value, indent=2)) diff --git a/docs/griptape-framework/structures/src/tasks_6.py b/docs/griptape-framework/structures/src/tasks_6.py index a1b84e44da..ecd6f354f8 100644 --- a/docs/griptape-framework/structures/src/tasks_6.py +++ b/docs/griptape-framework/structures/src/tasks_6.py @@ -4,7 +4,9 @@ from griptape.tasks import ExtractionTask # Instantiate the CSV extraction engine -csv_extraction_engine = CsvExtractionEngine(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo")) +csv_extraction_engine = CsvExtractionEngine( + prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), column_names=["Name", "Age", "Address"] +) # Define some unstructured data and columns csv_data = """ @@ -13,15 +15,12 @@ Charlie is 40 and lives in Texas. """ -columns = ["Name", "Age", "Address"] - # Create an agent and add the ExtractionTask to it agent = Agent() agent.add_task( ExtractionTask( extraction_engine=csv_extraction_engine, - args={"column_names": columns}, ) ) diff --git a/docs/griptape-framework/structures/src/tasks_7.py b/docs/griptape-framework/structures/src/tasks_7.py index 909d000845..1a3a32b29d 100644 --- a/docs/griptape-framework/structures/src/tasks_7.py +++ b/docs/griptape-framework/structures/src/tasks_7.py @@ -8,6 +8,7 @@ # Instantiate the json extraction engine json_extraction_engine = JsonExtractionEngine( prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), + template_schema=Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema"), ) # Define some unstructured data and a schema @@ -15,13 +16,12 @@ Alice (Age 28) lives in New York. Bob (Age 35) lives in California. """ -user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema") + agent = Agent() agent.add_task( ExtractionTask( extraction_engine=json_extraction_engine, - args={"template_schema": user_schema}, ) ) diff --git a/griptape/artifacts/json_artifact.py b/griptape/artifacts/json_artifact.py index b292879a9c..26f69da7f4 100644 --- a/griptape/artifacts/json_artifact.py +++ b/griptape/artifacts/json_artifact.py @@ -7,11 +7,11 @@ from griptape.artifacts import BaseArtifact -Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] - @define class JsonArtifact(BaseArtifact): + Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None] + value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True}) def to_text(self) -> str: diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 40c9058ec5..70977f1f88 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -17,7 +17,7 @@ @define class CsvExtractionEngine(BaseExtractionEngine): - column_names: list[str] = field(default=Factory(list), kw_only=True) + column_names: list[str] = field(kw_only=True) system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index d3928c3ea7..c817efd5f7 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -18,9 +18,9 @@ @define class JsonExtractionEngine(BaseExtractionEngine): - JSON_PATTERN = r"(?s)(\{.*\}|\[.*\])" + JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" - template_schema: dict = field(default=Factory(dict), kw_only=True) + template_schema: dict = field(kw_only=True) system_template_generator: J2 = field( default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True ) @@ -42,7 +42,7 @@ def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]: json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) if json_matches: - return [JsonArtifact(json.loads(e)) for e in json_matches] + return [JsonArtifact(e) for e in json.loads(json_matches[-1])] else: return [] diff --git a/griptape/templates/engines/extraction/json/system.j2 b/griptape/templates/engines/extraction/json/system.j2 index b6bac028a7..987ff19a9e 100644 --- a/griptape/templates/engines/extraction/json/system.j2 +++ b/griptape/templates/engines/extraction/json/system.j2 @@ -1,5 +1,3 @@ -Your answer MUST be JSON that successfully validates against the Extraction Template JSON Schema. - Extraction Template JSON Schema: """{{ json_template_schema }}""" {% if rulesets %} diff --git a/griptape/templates/engines/extraction/json/user.j2 b/griptape/templates/engines/extraction/json/user.j2 index 9b0a611f0b..6d6311ab5f 100644 --- a/griptape/templates/engines/extraction/json/user.j2 +++ b/griptape/templates/engines/extraction/json/user.j2 @@ -1,4 +1,4 @@ -Extract information from the Text. +Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects. Text: """{{ text }}""" -JSON: +JSON Array: diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index 938438af46..3cac363736 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -21,15 +21,16 @@ def engine(self): def test_extract_text(self, engine): result = engine.extract_text("foo") - assert len(result.value) == 1 - assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] + assert len(result.value) == 2 + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} def test_chunked_extract_text(self, engine): large_text = Path(normpath(join(dirname(__file__), "../../../resources", "test.txt"))).read_text() extracted = engine.extract_text(large_text * 50) - assert len(extracted) == 177 - assert extracted[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] + assert len(extracted) == 354 + assert extracted[0].value == {"test_key_1": "test_value_1"} def test_extract_error(self, engine): engine.template_schema = lambda: "non serializable" @@ -38,7 +39,8 @@ def test_extract_error(self, engine): def test_json_to_text_artifacts(self, engine): extracted = engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]') - assert extracted[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] + assert extracted[0].value == {"test_key_1": "test_value_1"} + assert extracted[1].value == {"test_key_2": "test_value_2"} def test_json_to_text_artifacts_no_matches(self, engine): assert engine.json_to_text_artifacts("asdfasdfasdf") == [] diff --git a/tests/unit/tools/test_extraction_tool.py b/tests/unit/tools/test_extraction_tool.py index edf1663d21..1dbb9def26 100644 --- a/tests/unit/tools/test_extraction_tool.py +++ b/tests/unit/tools/test_extraction_tool.py @@ -39,14 +39,16 @@ def test_json_extract_artifacts(self, json_tool): {"values": {"data": {"memory_name": json_tool.input_memory[0].name, "artifact_namespace": "foo"}}} ) - assert len(result.value) == 1 - assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] + assert len(result.value) == 2 + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} def test_json_extract_content(self, json_tool): result = json_tool.extract({"values": {"data": "foo"}}) - assert len(result.value) == 1 - assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}] + assert len(result.value) == 2 + assert result.value[0].value == {"test_key_1": "test_value_1"} + assert result.value[1].value == {"test_key_2": "test_value_2"} def test_csv_extract_artifacts(self, csv_tool): csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz"))