added enum support #254

Merged 3 commits on Dec 26, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ workloads/*
*pytest_cache*
*ruff_cache*
motion-old*
venv/

# dependencies
website/node_modules
4 changes: 4 additions & 0 deletions docetl/operations/utils/validation.py
@@ -103,6 +103,10 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]:
        if "gemini" not in model:
            result["additionalProperties"] = False
        return result
    elif value.startswith("enum[") and value.endswith("]"):
        enum_values = value[5:-1].strip().split(",")
        enum_values = [v.strip() for v in enum_values]
        return {"type": "string", "enum": enum_values}
    else:
        raise ValueError(f"Unsupported value type: {value}")
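
To see the new branch in isolation, here is a minimal standalone sketch of the same parsing logic. The function name `enum_to_schema` is hypothetical and not part of DocETL; it only mirrors the `enum[...]` branch added to `convert_val` in this hunk.

```python
from typing import Any, Dict


def enum_to_schema(value: str) -> Dict[str, Any]:
    """Convert a string like 'enum[a, b]' into a JSON Schema fragment.

    Hypothetical helper that mirrors the enum branch in this diff.
    """
    if not (value.startswith("enum[") and value.endswith("]")):
        raise ValueError(f"Unsupported value type: {value}")
    # Drop the leading 'enum[' (5 characters) and the trailing ']'.
    inner = value[5:-1].strip()
    # Split on commas and trim whitespace around each allowed value.
    enum_values = [v.strip() for v in inner.split(",")]
    return {"type": "string", "enum": enum_values}


schema = enum_to_schema("enum[positive, negative, neutral]")
print(schema)
# {'type': 'string', 'enum': ['positive', 'negative', 'neutral']}
```

Because the values are split on commas, an allowed value cannot itself contain a comma under this notation.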

3 changes: 2 additions & 1 deletion docs/concepts/operators.md
@@ -88,14 +88,15 @@ prompt: |

## Output Schema

- The `output` attribute defines the structure of the LLM's response. It supports various data types:
+ The `output` attribute defines the structure of the LLM's response. It supports various data types (see [schemas](../concepts/schemas.md) for more details):

- `string` (or `str`, `text`, `varchar`): For text data
- `integer` (or `int`): For whole numbers
- `number` (or `float`, `decimal`): For decimal numbers
- `boolean` (or `bool`): For true/false values
- `list`: For arrays or sequences of items
- objects: Using notation `{field: type}`
- `enum`: For a set of possible values

Example:

19 changes: 19 additions & 0 deletions docs/concepts/schemas.md
@@ -22,6 +22,7 @@ Schemas are defined in the `output` section of an operator. They support various
| `integer` | `int` | For whole numbers |
| `number` | `float`, `decimal` | For decimal numbers |
| `boolean` | `bool` | For true/false values |
| `enum` | - | For a set of possible values |
| `list` | - | For arrays or sequences of items (must specify element type) |
| Objects | - | Using notation `{field: type}` |

@@ -72,6 +73,24 @@ Objects are defined using curly braces and must have typed fields:

Make sure to put the type in quotation marks if it references an object type (i.e., has curly braces). Otherwise the YAML won't parse!

## Enum Types

You can also specify enum types, which restrict a field to a fixed set of allowed values. Suppose we have an operation that extracts the sentiment of a document, and we want to ensure that the sentiment is one of three possible values. The schema would look like this:

```yaml
output:
  schema:
    sentiment: "enum[positive, negative, neutral]"
```

You can also specify a list of enum types (say, if we wanted to extract _multiple_ sentiments from a document):

```yaml
output:
  schema:
    possible_sentiments: "list[enum[positive, negative, neutral]]"
```
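
This diff does not show how `list[...]` types are converted, so the following is only a sketch under the assumption that a list type wraps the schema of its element type in a JSON Schema `array`. The name `type_to_schema` is hypothetical and is not DocETL's API.

```python
from typing import Any, Dict


def type_to_schema(value: str) -> Dict[str, Any]:
    """Hypothetical converter covering only list[...] and enum[...] types."""
    value = value.strip()
    if value.startswith("list[") and value.endswith("]"):
        # Assumption: a list wraps the schema of its element type.
        return {"type": "array", "items": type_to_schema(value[5:-1])}
    if value.startswith("enum[") and value.endswith("]"):
        # Same comma-splitting approach as the enum branch in validation.py.
        allowed = [v.strip() for v in value[5:-1].split(",")]
        return {"type": "string", "enum": allowed}
    raise ValueError(f"Unsupported value type: {value}")


print(type_to_schema("list[enum[positive, negative, neutral]]"))
# {'type': 'array', 'items': {'type': 'string', 'enum': ['positive', 'negative', 'neutral']}}
```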

## Structured Outputs and Tool API

DocETL uses structured outputs or the tool API to enforce schema typing. This ensures that the LLM outputs adhere to the specified schema, making the results more consistent and easier to process in subsequent operations.
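
To illustrate what adhering to the schema means for enum fields, here is a small conformance check written in plain Python. It is not how DocETL or a model provider enforces schemas. The function name `conforms` is hypothetical, and it assumes schema fragments of the shape `{"type": "string", "enum": [...]}`, optionally nested inside an `array`.

```python
from typing import Any, Dict


def conforms(value: Any, schema: Dict[str, Any]) -> bool:
    """Return True if value satisfies a string/enum or array schema fragment."""
    schema_type = schema.get("type")
    if schema_type == "array":
        # Every element must satisfy the item schema.
        return isinstance(value, list) and all(
            conforms(item, schema["items"]) for item in value
        )
    if schema_type == "string":
        if not isinstance(value, str):
            return False
        # Without an "enum" key, any string is accepted.
        return "enum" not in schema or value in schema["enum"]
    raise ValueError(f"Unsupported schema: {schema}")


sentiment = {"type": "string", "enum": ["positive", "negative", "neutral"]}
print(conforms("positive", sentiment))   # True
print(conforms("mixed", sentiment))      # False
print(conforms(["positive", "neutral"], {"type": "array", "items": sentiment}))  # True
```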
38 changes: 38 additions & 0 deletions tests/basic/test_basic_map.py
@@ -218,6 +218,44 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wrapper):
        any(vs in result["sentiment"] for vs in valid_sentiments) for result in results
    )

def test_map_with_enum_output(simple_map_config, map_sample_data, api_wrapper):
    map_config_with_enum_output = {
        **simple_map_config,
        "output": {"schema": {"sentiment": "enum[positive, negative, neutral]"}},
        "bypass_cache": True,
    }

    operation = MapOperation(api_wrapper, map_config_with_enum_output, "gpt-4o-mini", 4)
    results, cost = operation.execute(map_sample_data)

    assert len(results) == len(map_sample_data)
    assert all("sentiment" in result for result in results)
    assert all(result["sentiment"] in ["positive", "negative", "neutral"] for result in results)

    # # Try gemini model
    # map_config_with_enum_output["model"] = "gemini/gemini-1.5-flash"
    # operation = MapOperation(api_wrapper, map_config_with_enum_output, "gemini/gemini-1.5-flash", 4)
    # results, cost = operation.execute(map_sample_data)

    # assert len(results) == len(map_sample_data)
    # assert all("sentiment" in result for result in results)
    # assert all(result["sentiment"] in ["positive", "negative", "neutral"] for result in results)
    # assert cost > 0

    # Try list of enum types
    map_config_with_enum_output["output"] = {"schema": {"possible_sentiments": "list[enum[positive, negative, neutral]]"}}
    operation = MapOperation(api_wrapper, map_config_with_enum_output, "gpt-4o-mini", 4)
    results, cost = operation.execute(map_sample_data)
    assert cost > 0

    assert len(results) == len(map_sample_data)
    assert all("possible_sentiments" in result for result in results)
    for result in results:
        for ps in result["possible_sentiments"]:
            assert ps in ["positive", "negative", "neutral"]



def test_map_operation_with_batch_processing(simple_map_config, map_sample_data, api_wrapper):
    # Add batch processing configuration
    map_config_with_batch = {