Skip to content

Commit

Permalink
Allow stop keyword in Regex
Browse files (browse the repository at this point in the history)
  • Loading branch information
brandonwillard committed Oct 12, 2023
1 parent f6e33dd commit fb465db
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions outlines/text/generate/regex.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(
model,
regex_string: str,
max_tokens: Optional[int] = None,
*,
sampler: Optional["Sampler"] = None,
stop: Union[str, List[str]] = [],
allow_empty_tokens: bool = True,
initial_state: Optional[int] = None,
final_states: Optional[Set[int]] = None,
Expand All @@ -62,6 +64,8 @@ def __init__(
`outlines.text.generate.sample.multinomial`. See
`outlines.text.generate.sample.Sampler` for the expected form of
such functions.
stop
Optional stopping string(s).
allow_empty_tokens
Allow sampling of tokens corresponding to empty strings.
states_to_token_maps
Expand All @@ -71,7 +75,7 @@ def __init__(
Pre-computed set of token ids for tokens that are empty strings.
"""
super().__init__(model, max_tokens, sampler)
super().__init__(model, max_tokens, sampler, stop)

if (
states_to_token_maps is None
Expand Down Expand Up @@ -248,7 +252,13 @@ def regex(
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(model, regex_string, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_string,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def integer(
Expand Down Expand Up @@ -284,7 +294,13 @@ def integer(
Allow sampling of tokens corresponding to empty strings.
"""
return Regex(model, r"[-+]?\d+", max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
r"[-+]?\d+",
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def float(
Expand Down Expand Up @@ -324,8 +340,8 @@ def float(
model,
r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))",
max_tokens,
sampler,
allow_empty_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


Expand Down Expand Up @@ -359,7 +375,13 @@ def choice(
Allow sampling of tokens corresponding to empty strings.
"""
regex_str = r"(" + r"|".join(choices) + r")"
return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)


def json(
Expand Down Expand Up @@ -399,4 +421,10 @@ def json(

regex_str = build_regex_from_schema(schema)

return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens)
return Regex(
model,
regex_str,
max_tokens,
sampler=sampler,
allow_empty_tokens=allow_empty_tokens,
)

0 comments on commit fb465db

Please sign in to comment.