diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py index a8285393c..be98b89ca 100644 --- a/outlines/generate/__init__.py +++ b/outlines/generate/__init__.py @@ -1,4 +1,5 @@ -from .api import SequenceGenerator, cfg +from .api import SequenceGenerator +from .cfg import cfg from .choice import choice from .format import format from .json import json diff --git a/outlines/generate/api.py b/outlines/generate/api.py index b14aae4a7..ca5b58962 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -4,14 +4,12 @@ import torch -from outlines.fsm.fsm import CFGFSM from outlines.generate.generator import ( GenerationState, init_generator_state, sequence_generator, token_generator, ) -from outlines.generate.samplers import Sampler, multinomial class SequenceGenerator: @@ -339,17 +337,3 @@ def token_generator() -> Iterator[Union[List[str], str]]: yield next_tokens return token_generator() - - -def cfg( - model, - cfg_str: str, - max_tokens: Optional[int] = None, - sampler: Sampler = multinomial, -): - fsm = CFGFSM(cfg_str, model.tokenizer) - - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device, max_tokens=max_tokens) - - return generator diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py new file mode 100644 index 000000000..ddedcef31 --- /dev/null +++ b/outlines/generate/cfg.py @@ -0,0 +1,39 @@ +from functools import singledispatch +from typing import List, Optional, Union + +from outlines.fsm.fsm import CFGFSM +from outlines.generate.api import SequenceGenerator +from outlines.generate.samplers import Sampler, multinomial +from outlines.models import OpenAI + + +@singledispatch +def cfg( + model, + cfg_str: str, + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + sampler: Sampler = multinomial, +): + fsm = CFGFSM(cfg_str, model.tokenizer) + + device = model.device + generator = SequenceGenerator( + fsm, model, sampler, device, max_tokens=max_tokens, stop_at=stop_at + ) + + return generator + + +@cfg.register(OpenAI) +def cfg_openai( + model, + cfg_str: str, + max_tokens: Optional[int] = None, + stop_at: Optional[Union[str, List[str]]] = None, + sampler: Sampler = multinomial, +): + raise NotImplementedError( + "Cannot use grammar-structured generation with an OpenAI model" + + "due to the limitations of the OpenAI API." + )