Skip to content

Commit

Permalink
Revert some, update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Aug 21, 2024
1 parent e92e085 commit 317ec50
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 46 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- **BREAKING**: Split `BaseExtractionEngine.extract` into `extract` and `extract_artifacts` for consistency with `BaseSummaryEngine`.
- **BREAKING**: `BaseExtractionEngine` no longer catches exceptions and returns `ErrorArtifact`s.
- `JsonExtractionEngine` to extract either a JSON object or array depending on the provided schema.
- **BREAKING**: `JsonExtractionEngine.template_schema` is now required.
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- `JsonExtractionEngine.extract_artifacts` now returns a `ListArtifact[JsonArtifact]`.
- `CsvExtractionEngine.extract_artifacts` now returns a `ListArtifact[CsvRowArtifact]`.

Expand Down
27 changes: 18 additions & 9 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ As of now, Griptape supports two types of Extraction Engines: the CSV Extraction

## CSV

The CSV Extraction Engine is designed specifically for extracting data from CSV-formatted content.

!!! info
The CSV Extraction Engine requires the `column_names` parameter for specifying the columns to be extracted.
The CSV Extraction Engine is designed for extracting CSV-formatted content from unstructured data.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_1.py"
Expand All @@ -27,15 +24,27 @@ Charlie,40,Texas

## JSON

The JSON Extraction Engine is tailored for extracting data from JSON-formatted content.
The JSON Extraction Engine is designed for extracting JSON-formatted content from unstructed data.

!!! info
The JSON Extraction Engine requires the `template_schema` parameter for specifying the structure to be extracted.

```python
--8<-- "docs/griptape-framework/engines/src/extraction_engines_2.py"
```
```
{'name': 'Alice', 'age': 28, 'location': 'New York'}
{'name': 'Bob', 'age': 35, 'location': 'California'}
{
"model": "GPT-3.5",
"notes": [
"Part of OpenAI's GPT series.",
"Used in ChatGPT and Microsoft Copilot."
]
}
{
"model": "GPT-4",
"notes": [
"Part of OpenAI's GPT series.",
"Praised for increased accuracy and multimodal capabilities.",
"Architecture and number of parameters not revealed."
]
}
...Output truncated for brevity...
```
29 changes: 18 additions & 11 deletions docs/griptape-framework/engines/src/extraction_engines_2.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
from schema import Schema
import json

from schema import Literal, Schema

from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.drivers import OpenAiChatPromptDriver
from griptape.engines import JsonExtractionEngine
from griptape.loaders import WebLoader

# Define a schema for extraction
user_schema = Schema([{"name": str, "age": int, "location": str}]).json_schema("UserSchema")

json_engine = JsonExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), template_schema=user_schema
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o"),
template_schema=Schema(
{
Literal("model", description="Name of an LLM model."): str,
Literal("notes", description="Any notes of substance about the model."): Schema([str]),
}
).json_schema("ProductSchema"),
)

# Define some unstructured data
sample_json_text = """
Alice (Age 28) lives in New York.
Bob (Age 35) lives in California.
"""
# Load data from the web
web_data = WebLoader().load("https://en.wikipedia.org/wiki/Large_language_model")

if isinstance(web_data, ErrorArtifact):
raise Exception(web_data.value)

# Extract data using the engine
result = json_engine.extract_text(sample_json_text)
result = json_engine.extract_artifacts(ListArtifact(web_data))

for artifact in result:
print(artifact.value)
print(json.dumps(artifact.value, indent=2))
7 changes: 3 additions & 4 deletions docs/griptape-framework/structures/src/tasks_6.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from griptape.tasks import ExtractionTask

# Instantiate the CSV extraction engine
csv_extraction_engine = CsvExtractionEngine(prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"))
csv_extraction_engine = CsvExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"), column_names=["Name", "Age", "Address"]
)

# Define some unstructured data and columns
csv_data = """
Expand All @@ -13,15 +15,12 @@
Charlie is 40 and lives in Texas.
"""

columns = ["Name", "Age", "Address"]


# Create an agent and add the ExtractionTask to it
agent = Agent()
agent.add_task(
ExtractionTask(
extraction_engine=csv_extraction_engine,
args={"column_names": columns},
)
)

Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/structures/src/tasks_7.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@
# Instantiate the json extraction engine
json_extraction_engine = JsonExtractionEngine(
prompt_driver=OpenAiChatPromptDriver(model="gpt-3.5-turbo"),
template_schema=Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema"),
)

# Define some unstructured data and a schema
json_data = """
Alice (Age 28) lives in New York.
Bob (Age 35) lives in California.
"""
user_schema = Schema({"users": [{"name": str, "age": int, "location": str}]}).json_schema("UserSchema")


agent = Agent()
agent.add_task(
ExtractionTask(
extraction_engine=json_extraction_engine,
args={"template_schema": user_schema},
)
)

Expand Down
4 changes: 2 additions & 2 deletions griptape/artifacts/json_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from griptape.artifacts import BaseArtifact

Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]


@define
class JsonArtifact(BaseArtifact):
Json = Union[dict[str, "Json"], list["Json"], str, int, float, bool, None]

value: Json = field(converter=lambda v: json.loads(json.dumps(v)), metadata={"serializable": True})

def to_text(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
column_names: list[str] = field(default=Factory(list), kw_only=True)
column_names: list[str] = field(kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)

Expand Down
6 changes: 3 additions & 3 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

@define
class JsonExtractionEngine(BaseExtractionEngine):
JSON_PATTERN = r"(?s)(\{.*\}|\[.*\])"
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

template_schema: dict = field(default=Factory(dict), kw_only=True)
template_schema: dict = field(kw_only=True)
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True
)
Expand All @@ -42,7 +42,7 @@ def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

if json_matches:
return [JsonArtifact(json.loads(e)) for e in json_matches]
return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
else:
return []

Expand Down
2 changes: 0 additions & 2 deletions griptape/templates/engines/extraction/json/system.j2
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
Your answer MUST be JSON that successfully validates against the Extraction Template JSON Schema.

Extraction Template JSON Schema: """{{ json_template_schema }}"""

{% if rulesets %}
Expand Down
4 changes: 2 additions & 2 deletions griptape/templates/engines/extraction/json/user.j2
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Extract information from the Text.
Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects.
Text: """{{ text }}"""

JSON:
JSON Array:
12 changes: 7 additions & 5 deletions tests/unit/engines/extraction/test_json_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ def engine(self):
def test_extract_text(self, engine):
result = engine.extract_text("foo")

assert len(result.value) == 1
assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]
assert len(result.value) == 2
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_chunked_extract_text(self, engine):
large_text = Path(normpath(join(dirname(__file__), "../../../resources", "test.txt"))).read_text()

extracted = engine.extract_text(large_text * 50)
assert len(extracted) == 177
assert extracted[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]
assert len(extracted) == 354
assert extracted[0].value == {"test_key_1": "test_value_1"}

def test_extract_error(self, engine):
engine.template_schema = lambda: "non serializable"
Expand All @@ -38,7 +39,8 @@ def test_extract_error(self, engine):

def test_json_to_text_artifacts(self, engine):
extracted = engine.json_to_text_artifacts('[{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]')
assert extracted[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]
assert extracted[0].value == {"test_key_1": "test_value_1"}
assert extracted[1].value == {"test_key_2": "test_value_2"}

def test_json_to_text_artifacts_no_matches(self, engine):
assert engine.json_to_text_artifacts("asdfasdfasdf") == []
10 changes: 6 additions & 4 deletions tests/unit/tools/test_extraction_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,16 @@ def test_json_extract_artifacts(self, json_tool):
{"values": {"data": {"memory_name": json_tool.input_memory[0].name, "artifact_namespace": "foo"}}}
)

assert len(result.value) == 1
assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]
assert len(result.value) == 2
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_json_extract_content(self, json_tool):
result = json_tool.extract({"values": {"data": "foo"}})

assert len(result.value) == 1
assert result.value[0].value == [{"test_key_1": "test_value_1"}, {"test_key_2": "test_value_2"}]
assert len(result.value) == 2
assert result.value[0].value == {"test_key_1": "test_value_1"}
assert result.value[1].value == {"test_key_2": "test_value_2"}

def test_csv_extract_artifacts(self, csv_tool):
csv_tool.input_memory[0].store_artifact("foo", TextArtifact("foo,bar\nbaz,maz"))
Expand Down

0 comments on commit 317ec50

Please sign in to comment.