diff --git a/.gitignore b/.gitignore index c4834bfc..197961be 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ workloads/* *pytest_cache* *ruff_cache* motion-old* +venv/ # dependencies website/node_modules diff --git a/docetl/operations/utils/validation.py b/docetl/operations/utils/validation.py index f193d93d..5bbc3be6 100644 --- a/docetl/operations/utils/validation.py +++ b/docetl/operations/utils/validation.py @@ -103,6 +103,10 @@ def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: if "gemini" not in model: result["additionalProperties"] = False return result + elif value.startswith("enum[") and value.endswith("]"): + enum_values = value[5:-1].strip().split(",") + enum_values = [v.strip() for v in enum_values] + return {"type": "string", "enum": enum_values} else: raise ValueError(f"Unsupported value type: {value}") diff --git a/docs/concepts/operators.md b/docs/concepts/operators.md index 43ef2200..f5ee3e23 100644 --- a/docs/concepts/operators.md +++ b/docs/concepts/operators.md @@ -88,7 +88,7 @@ prompt: | ## Output Schema -The `output` attribute defines the structure of the LLM's response. It supports various data types: +The `output` attribute defines the structure of the LLM's response. It supports various data types (see [schemas](../concepts/schemas.md) for more details): - `string` (or `str`, `text`, `varchar`): For text data - `integer` (or `int`): For whole numbers @@ -96,6 +96,7 @@ The `output` attribute defines the structure of the LLM's response. It supports - `boolean` (or `bool`): For true/false values - `list`: For arrays or sequences of items - objects: Using notation `{field: type}` +- `enum`: For a string restricted to a fixed set of allowed values, using notation `enum[a, b, c]` Example: diff --git a/docs/concepts/schemas.md b/docs/concepts/schemas.md index 8f6b103f..6f359ac4 100644 --- a/docs/concepts/schemas.md +++ b/docs/concepts/schemas.md @@ -22,6 +22,7 @@ Schemas are defined in the `output` section of an operator. 
They support various | `integer` | `int` | For whole numbers | | `number` | `float`, `decimal` | For decimal numbers | | `boolean` | `bool` | For true/false values | +| `enum` | - | For a string restricted to a fixed set of allowed values (notation: `enum[a, b, c]`) | | `list` | - | For arrays or sequences of items (must specify element type) | | Objects | - | Using notation `{field: type}` | @@ -72,6 +73,24 @@ Objects are defined using curly braces and must have typed fields: Make sure that you put the type in quotation marks, if it references an object type (i.e., has curly braces)! Otherwise the yaml won't compile! +## Enum Types + +You can also specify enum types, which restrict a field to one of a fixed set of string values, written as a comma-separated list inside `enum[...]`. Suppose we have an operation to extract sentiments from a document, and we want to ensure that the sentiment is one of three possible values. Our schema would look like this: + +```yaml +output: + schema: + sentiment: "enum[positive, negative, neutral]" +``` + +You can also specify a list of enum types (say, if we wanted to extract _multiple_ sentiments from a document): + +```yaml +output: + schema: + possible_sentiments: "list[enum[positive, negative, neutral]]" +``` + ## Structured Outputs and Tool API DocETL uses structured outputs or tool API to enforce schema typing. This ensures that the LLM outputs adhere to the specified schema, making the results more consistent and easier to process in subsequent operations. 
diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py
index a1919f6f..cec7c06f 100644
--- a/tests/basic/test_basic_map.py
+++ b/tests/basic/test_basic_map.py
@@ -218,6 +218,44 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra
 any(vs in result["sentiment"] for vs in valid_sentiments) for result in results )
+def test_map_with_enum_output(simple_map_config, map_sample_data, api_wrapper):
+    """Check that `enum[...]` output schemas constrain the values a map returns.
+
+    Covers a scalar enum field and a list of enum values.
+    """
+    allowed_sentiments = {"positive", "negative", "neutral"}
+    map_config_with_enum_output = {
+        **simple_map_config,
+        "output": {"schema": {"sentiment": "enum[positive, negative, neutral]"}},
+        "bypass_cache": True,
+    }
+
+    operation = MapOperation(api_wrapper, map_config_with_enum_output, "gpt-4o-mini", 4)
+    results, cost = operation.execute(map_sample_data)
+
+    assert cost > 0
+    assert len(results) == len(map_sample_data)
+    assert all("sentiment" in result for result in results)
+    assert all(result["sentiment"] in allowed_sentiments for result in results)
+
+    # Try list of enum types; build a fresh config rather than mutating the first.
+    map_config_with_enum_list_output = {
+        **map_config_with_enum_output,
+        "output": {"schema": {"possible_sentiments": "list[enum[positive, negative, neutral]]"}},
+    }
+    operation = MapOperation(api_wrapper, map_config_with_enum_list_output, "gpt-4o-mini", 4)
+    results, cost = operation.execute(map_sample_data)
+
+    assert cost > 0
+    assert len(results) == len(map_sample_data)
+    assert all("possible_sentiments" in result for result in results)
+    for result in results:
+        assert isinstance(result["possible_sentiments"], list)
+        for ps in result["possible_sentiments"]:
+            assert ps in allowed_sentiments
+
+
+
 def
test_map_operation_with_batch_processing(simple_map_config, map_sample_data, api_wrapper): # Add batch processing configuration map_config_with_batch = {