diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 57b3f33da..0de579046 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -248,5 +248,11 @@ def json(
         regex_str = build_regex_from_object(schema)
         generator = regex(model, regex_str, max_tokens, sampler)
         generator.format_sequence = lambda x: pyjson.loads(x)
+    else:
+        raise ValueError(
+            f"Cannot parse schema {schema_object}. The schema must be either "
+            + "a Pydantic object, a function or a string that contains the JSON "
+            + "Schema specification"
+        )
 
     return generator
diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py
index b0aaf2b41..b1fdf2409 100644
--- a/tests/generate/test_generator.py
+++ b/tests/generate/test_generator.py
@@ -21,10 +21,10 @@ def test_sequence_generator_class():
     class MockFSM:
-        def next_state(self, state, next_token_ids):
+        def next_state(self, state, next_token_ids, _):
             return 4
 
-        def allowed_token_ids(self, _):
+        def allowed_token_ids(self, *_):
             return [4]
 
         def is_final_state(self, _, idx=0):
diff --git a/tests/generate/test_integration_transfomers.py b/tests/generate/test_integration_transfomers.py
index 94d668416..1bfbeba94 100644
--- a/tests/generate/test_integration_transfomers.py
+++ b/tests/generate/test_integration_transfomers.py
@@ -276,6 +276,30 @@ class Spam(BaseModel):
     assert len(result.spam) <= 10
 
 
+def test_transformers_json_schema():
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Output some JSON "
+
+    schema = """{
+        "title": "spam",
+        "type": "object",
+        "properties": {
+            "foo" : {"type": "integer"},
+            "bar": {"type": "string", "maxLength": 4}
+        }
+    }
+    """
+
+    rng = torch.Generator()
+    rng.manual_seed(0)  # make sure that `bar` is not an int
+
+    result = generate.json(model, schema, max_tokens=500)(prompt, rng=rng)
+    assert isinstance(result, dict)
+    assert isinstance(result["foo"], int)
+    assert isinstance(result["bar"], str)
+
+
 def test_transformers_json_batch():
     model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
     model = models.transformers(model_name, device="cpu")
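
For reference, a minimal usage sketch of the string-schema path exercised by the new `test_transformers_json_schema` test above. The import paths and the error-path snippet at the end are assumptions based on the library layout, not part of this diff.

```python
# Sketch only: mirrors the new test above. Import paths are assumed
# from the library layout and are not part of this diff.
import torch

import outlines.generate as generate
import outlines.models as models

model = models.transformers(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu"
)

# A raw JSON Schema string can be passed directly to `generate.json`,
# in addition to a Pydantic model or a function.
schema = """{
    "title": "spam",
    "type": "object",
    "properties": {
        "foo": {"type": "integer"},
        "bar": {"type": "string", "maxLength": 4}
    }
}"""

rng = torch.Generator()
rng.manual_seed(0)

result = generate.json(model, schema, max_tokens=500)("Output some JSON ", rng=rng)
print(result["foo"], result["bar"])

# Anything that is not a Pydantic model, a callable, or a JSON Schema string
# now fails fast with the ValueError added above (illustrative call):
try:
    generate.json(model, 42)
except ValueError as exc:
    print(exc)
```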