Skip to content

Commit

Permalink
Do not overwrite n in the config when samples not specified
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jan 26, 2024
1 parent ead42b3 commit 80cf3c6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __call__(
stop_at: Optional[Union[List[str], str]] = None,
*,
temperature: float = 1.0,
samples: int = 1,
samples: Optional[int] = None,
) -> np.ndarray:
"""Call the OpenAI API to generate text.
Expand All @@ -176,6 +176,9 @@ def __call__(
Up to 4 words where the API will stop the completion.
"""
if samples is None:
samples = self.config.n

config = replace(self.config, max_tokens=max_tokens, n=samples, stop=stop_at) # type: ignore

if isinstance(stop_at, list) and len(stop_at) > 4:
Expand Down
5 changes: 4 additions & 1 deletion outlines/models/openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __call__(
stop_at: Optional[Union[List[str], str]] = None,
*,
temperature: float = 1.0,
samples: int = 1,
samples: Optional[int] = None,
) -> np.ndarray:
"""Call the OpenAI compatible API to generate text.
Expand All @@ -124,6 +124,9 @@ def __call__(
Up to 4 words where the API will stop the completion.
"""
if samples is None:
samples = self.config.n

config = replace(self.config, max_tokens=max_tokens, n=samples, stop=stop_at, temperature=temperature) # type: ignore

# We assume it's using the chat completion API style as that's the most commonly supported
Expand Down

0 comments on commit 80cf3c6

Please sign in to comment.