diff --git a/outlines/text/generate/regex.py b/outlines/text/generate/regex.py index 7eedb2abc..af87f9f51 100644 --- a/outlines/text/generate/regex.py +++ b/outlines/text/generate/regex.py @@ -40,7 +40,9 @@ def __init__( model, regex_string: str, max_tokens: Optional[int] = None, + *, sampler: Optional["Sampler"] = None, + stop: Union[str, List[str]] = [], allow_empty_tokens: bool = True, initial_state: Optional[int] = None, final_states: Optional[Set[int]] = None, @@ -62,6 +64,8 @@ def __init__( `outlines.text.generate.sample.multinomial`. See `outlines.text.generate.sample.Sampler` for the expected form of such functions. + stop + Optional stopping string(s). allow_empty_tokens Allow sampling of tokens corresponding to empty strings. states_to_token_maps @@ -71,7 +75,7 @@ def __init__( Pre-computed set of token ids for tokens that are empty strings. """ - super().__init__(model, max_tokens, sampler) + super().__init__(model, max_tokens, sampler, stop) if ( states_to_token_maps is None @@ -248,7 +252,13 @@ def regex( Allow sampling of tokens corresponding to empty strings. """ - return Regex(model, regex_string, max_tokens, sampler, allow_empty_tokens) + return Regex( + model, + regex_string, + max_tokens, + sampler=sampler, + allow_empty_tokens=allow_empty_tokens, + ) def integer( @@ -284,7 +294,13 @@ def integer( Allow sampling of tokens corresponding to empty strings. """ - return Regex(model, r"[-+]?\d+", max_tokens, sampler, allow_empty_tokens) + return Regex( + model, + r"[-+]?\d+", + max_tokens, + sampler=sampler, + allow_empty_tokens=allow_empty_tokens, + ) def float( @@ -324,8 +340,8 @@ def float( model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens, - sampler, - allow_empty_tokens, + sampler=sampler, + allow_empty_tokens=allow_empty_tokens, ) @@ -359,7 +375,13 @@ def choice( Allow sampling of tokens corresponding to empty strings. """ regex_str = r"(" + r"|".join(choices) + r")" - return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens) + return Regex( + model, + regex_str, + max_tokens, + sampler=sampler, + allow_empty_tokens=allow_empty_tokens, + ) def json( @@ -399,4 +421,10 @@ def json( regex_str = build_regex_from_schema(schema) - return Regex(model, regex_str, max_tokens, sampler, allow_empty_tokens) + return Regex( + model, + regex_str, + max_tokens, + sampler=sampler, + allow_empty_tokens=allow_empty_tokens, + )