Skip to content

Commit

Permalink
Dispatch cfg based on the model passed by the user
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jan 26, 2024
1 parent f96ae5c commit 6c93534
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 17 deletions.
3 changes: 2 additions & 1 deletion outlines/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .api import SequenceGenerator, cfg
from .api import SequenceGenerator
from .cfg import cfg
from .choice import choice
from .format import format
from .json import json
Expand Down
16 changes: 0 additions & 16 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

import torch

from outlines.fsm.fsm import CFGFSM
from outlines.generate.generator import (
GenerationState,
init_generator_state,
sequence_generator,
token_generator,
)
from outlines.generate.samplers import Sampler, multinomial


class SequenceGenerator:
Expand Down Expand Up @@ -339,17 +337,3 @@ def token_generator() -> Iterator[Union[List[str], str]]:
yield next_tokens

return token_generator()


def cfg(
    model,
    cfg_str: str,
    max_tokens: Optional[int] = None,
    sampler: Sampler = multinomial,
):
    """Create a sequence generator driven by a grammar FSM.

    Parameters
    ----------
    model
        Object exposing ``tokenizer`` and ``device`` attributes; both are
        read here and the model itself is handed to the generator.
    cfg_str
        Grammar definition passed to ``CFGFSM``.
    max_tokens
        Forwarded to ``SequenceGenerator`` as ``max_tokens``.
    sampler
        Sampler forwarded to ``SequenceGenerator``; defaults to
        ``multinomial``.

    Returns
    -------
    SequenceGenerator
        Generator built from the grammar FSM, the model and the sampler.

    """
    # The FSM is built from the grammar string and the model's tokenizer.
    grammar_fsm = CFGFSM(cfg_str, model.tokenizer)

    return SequenceGenerator(
        grammar_fsm,
        model,
        sampler,
        model.device,
        max_tokens=max_tokens,
    )
39 changes: 39 additions & 0 deletions outlines/generate/cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import singledispatch
from typing import List, Optional, Union

from outlines.fsm.fsm import CFGFSM
from outlines.generate.api import SequenceGenerator
from outlines.generate.samplers import Sampler, multinomial
from outlines.models import OpenAI


@singledispatch
def cfg(
    model,
    cfg_str: str,
    max_tokens: Optional[int] = None,
    stop_at: Optional[Union[str, List[str]]] = None,
    sampler: Sampler = multinomial,
):
    """Create a grammar-structured sequence generator for ``model``.

    This is the default implementation of a ``functools.singledispatch``
    generic function: it runs for every ``model`` type that has no more
    specific implementation registered.

    Parameters
    ----------
    model
        Object exposing ``tokenizer`` and ``device`` attributes; both are
        read here and the model itself is handed to the generator.
    cfg_str
        Grammar definition passed to ``CFGFSM``.
    max_tokens
        Forwarded to ``SequenceGenerator`` as ``max_tokens``.
    stop_at
        A string or list of strings, forwarded to ``SequenceGenerator`` as
        ``stop_at``.
    sampler
        Sampler forwarded to ``SequenceGenerator``; defaults to
        ``multinomial``.

    Returns
    -------
    SequenceGenerator
        Generator built from the grammar FSM, the model and the sampler.

    """
    # The FSM is built from the grammar string and the model's tokenizer.
    grammar_fsm = CFGFSM(cfg_str, model.tokenizer)

    return SequenceGenerator(
        grammar_fsm,
        model,
        sampler,
        model.device,
        max_tokens=max_tokens,
        stop_at=stop_at,
    )


@cfg.register(OpenAI)
def cfg_openai(
    model,
    cfg_str: str,
    max_tokens: Optional[int] = None,
    stop_at: Optional[Union[str, List[str]]] = None,
    sampler: Sampler = multinomial,
):
    """Reject grammar-structured generation for ``OpenAI`` models.

    Registered as the ``OpenAI`` implementation of the ``cfg`` generic
    function. The parameters mirror those of ``cfg`` but none of them is
    used: the function raises unconditionally.

    Raises
    ------
    NotImplementedError
        Always.

    """
    # Implicit literal concatenation; the trailing space after "model" keeps
    # the two halves of the sentence from running together in the message.
    raise NotImplementedError(
        "Cannot use grammar-structured generation with an OpenAI model "
        "due to the limitations of the OpenAI API."
    )

0 comments on commit 6c93534

Please sign in to comment.