Skip to content

Commit

Permalink
added enum support (#254)
Browse files Browse the repository at this point in the history
* added enum support

* tests: add test for enum type output

* docs: update docs to support enum type schemas

---------

Co-authored-by: Shreya Shankar <[email protected]>
  • Loading branch information
rrawatt and shreyashankar authored Dec 26, 2024
1 parent a377c20 commit 0e077aa
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ workloads/*
*pytest_cache*
*ruff_cache*
motion-old*
venv/

# dependencies
website/node_modules
Expand Down
4 changes: 4 additions & 0 deletions docetl/operations/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]:
if "gemini" not in model:
result["additionalProperties"] = False
return result
elif value.startswith("enum[") and value.endswith("]"):
enum_values = value[5:-1].strip().split(",")
enum_values = [v.strip() for v in enum_values]
return {"type": "string", "enum": enum_values}
else:
raise ValueError(f"Unsupported value type: {value}")

Expand Down
3 changes: 2 additions & 1 deletion docs/concepts/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@ prompt: |
## Output Schema
The `output` attribute defines the structure of the LLM's response. It supports various data types:
The `output` attribute defines the structure of the LLM's response. It supports various data types (see [schemas](../concepts/schemas.md) for more details):

- `string` (or `str`, `text`, `varchar`): For text data
- `integer` (or `int`): For whole numbers
- `number` (or `float`, `decimal`): For decimal numbers
- `boolean` (or `bool`): For true/false values
- `list`: For arrays or sequences of items
- objects: Using notation `{field: type}`
- `enum`: For a set of possible values

Example:

Expand Down
19 changes: 19 additions & 0 deletions docs/concepts/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Schemas are defined in the `output` section of an operator. They support various
| `integer` | `int` | For whole numbers |
| `number` | `float`, `decimal` | For decimal numbers |
| `boolean` | `bool` | For true/false values |
| `enum` | - | For a set of possible values |
| `list` | - | For arrays or sequences of items (must specify element type) |
| Objects | - | Using notation `{field: type}` |

Expand Down Expand Up @@ -72,6 +73,24 @@ Objects are defined using curly braces and must have typed fields:

Make sure that you put the type in quotation marks, if it references an object type (i.e., has curly braces)! Otherwise the yaml won't compile!

## Enum Types

You can also specify enum types, which will be validated against a set of possible values. Suppose we have an operation to extract sentiments from a document, and we want to ensure that the sentiment is one of the three possible values. Our schema would look like this:

```yaml
output:
schema:
sentiment: "enum[positive, negative, neutral]"
```

You can also specify a list of enum types (say, if we wanted to extract _multiple_ sentiments from a document):

```yaml
output:
schema:
possible_sentiments: "list[enum[positive, negative, neutral]]"
```

## Structured Outputs and Tool API

DocETL uses structured outputs or tool API to enforce schema typing. This ensures that the LLM outputs adhere to the specified schema, making the results more consistent and easier to process in subsequent operations.
Expand Down
38 changes: 38 additions & 0 deletions tests/basic/test_basic_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,44 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra
any(vs in result["sentiment"] for vs in valid_sentiments) for result in results
)

def test_map_with_enum_output(simple_map_config, map_sample_data, api_wrapper):
map_config_with_enum_output = {
**simple_map_config,
"output": {"schema": {"sentiment": "enum[positive, negative, neutral]"}},
"bypass_cache": True,
}

operation = MapOperation(api_wrapper, map_config_with_enum_output, "gpt-4o-mini", 4)
results, cost = operation.execute(map_sample_data)

assert len(results) == len(map_sample_data)
assert all("sentiment" in result for result in results)
assert all(result["sentiment"] in ["positive", "negative", "neutral"] for result in results)

# # Try gemini model
# map_config_with_enum_output["model"] = "gemini/gemini-1.5-flash"
# operation = MapOperation(api_wrapper, map_config_with_enum_output, "gemini/gemini-1.5-flash", 4)
# results, cost = operation.execute(map_sample_data)

# assert len(results) == len(map_sample_data)
# assert all("sentiment" in result for result in results)
# assert all(result["sentiment"] in ["positive", "negative", "neutral"] for result in results)
# assert cost > 0

# Try list of enum types
map_config_with_enum_output["output"] = {"schema": {"possible_sentiments": "list[enum[positive, negative, neutral]]"}}
operation = MapOperation(api_wrapper, map_config_with_enum_output, "gpt-4o-mini", 4)
results, cost = operation.execute(map_sample_data)
assert cost > 0

assert len(results) == len(map_sample_data)
assert all("possible_sentiments" in result for result in results)
for result in results:
for ps in result["possible_sentiments"]:
assert ps in ["positive", "negative", "neutral"]



def test_map_operation_with_batch_processing(simple_map_config, map_sample_data, api_wrapper):
# Add batch processing configuration
map_config_with_batch = {
Expand Down

0 comments on commit 0e077aa

Please sign in to comment.