diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 9e1a201a1..61909b638 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -158,7 +158,7 @@ def __call__( stop_at: Optional[Union[List[str], str]] = None, *, temperature: float = 1.0, - samples: int = 1, + samples: Optional[int] = None, ) -> np.ndarray: """Call the OpenAI API to generate text. @@ -176,6 +176,9 @@ def __call__( Up to 4 words where the API will stop the completion. """ + if samples is None: + samples = self.config.n + config = replace(self.config, max_tokens=max_tokens, n=samples, stop=stop_at) # type: ignore if isinstance(stop_at, list) and len(stop_at) > 4: diff --git a/outlines/models/openai_compatible.py b/outlines/models/openai_compatible.py index 81ce744ea..40cefa298 100644 --- a/outlines/models/openai_compatible.py +++ b/outlines/models/openai_compatible.py @@ -106,7 +106,7 @@ def __call__( stop_at: Optional[Union[List[str], str]] = None, *, temperature: float = 1.0, - samples: int = 1, + samples: Optional[int] = None, ) -> np.ndarray: """Call the OpenAI compatible API to generate text. @@ -124,6 +124,9 @@ def __call__( Up to 4 words where the API will stop the completion. """ + if samples is None: + samples = self.config.n + config = replace(self.config, max_tokens=max_tokens, n=samples, stop=stop_at, temperature=temperature) # type: ignore # We assume it's using the chat completion API style as that's the most commonly supported