From 0409e151d98c8230308d15246f9eda2984eae6fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 20 Nov 2023 08:08:47 +0100 Subject: [PATCH] Refactor the OpenAI integration --- examples/math_generate_code.py | 2 +- examples/pick_odd_one_out.py | 2 +- examples/react.py | 8 +- examples/self_consistency.py | 2 +- outlines/models/__init__.py | 2 +- outlines/models/openai.py | 575 ++++++++++++++++----------- outlines/text/generate/sequence.py | 4 +- tests/models/test_openai.py | 27 ++ tests/text/generate/test_sequence.py | 4 +- 9 files changed, 379 insertions(+), 247 deletions(-) create mode 100644 tests/models/test_openai.py diff --git a/examples/math_generate_code.py b/examples/math_generate_code.py index 507a76ec0..df1418188 100644 --- a/examples/math_generate_code.py +++ b/examples/math_generate_code.py @@ -35,6 +35,6 @@ def execute_code(code): prompt = answer_with_code_prompt(question, examples) -answer = models.openai("text-davinci-003")(prompt) +answer = models.openai("gpt-4")(prompt) result = execute_code(answer) print(f"It takes Carla {result:.0f} minutes to download the file.") diff --git a/examples/pick_odd_one_out.py b/examples/pick_odd_one_out.py index 286125037..676c7e56e 100644 --- a/examples/pick_odd_one_out.py +++ b/examples/pick_odd_one_out.py @@ -29,7 +29,7 @@ def build_ooo_prompt(options): """ -model = models.openai("text-davinci-003") +model = models.openai("gpt-3.5-turbo") options = ["sea", "mountains", "plains", "sock"] prompt = build_ooo_prompt(options) diff --git a/examples/react.py b/examples/react.py index c3964cfa2..2a4a52627 100644 --- a/examples/react.py +++ b/examples/react.py @@ -45,17 +45,19 @@ def search_wikipedia(query: str): prompt = build_reAct_prompt("Where is Apple Computers headquarted? ") -complete = models.openai("gpt-3.5-turbo", temperature=1.0) +complete = models.openai("gpt-3.5-turbo") for i in range(1, 10): - mode = complete(prompt, is_in=["Tho", "Act"], max_tokens=128) + mode = complete.generate_choice(prompt, choices=["Tho", "Act"], max_tokens=128) prompt = add_mode(i, mode, "", prompt) if mode == "Tho": thought = complete(prompt, stop_at="\n", max_tokens=128) prompt += f"{thought}" elif mode == "Act": - action = complete(prompt, is_in=["Search", "Finish"], max_tokens=128) + action = complete.generate_choice( + prompt, choices=["Search", "Finish"], max_tokens=128 + ) prompt += f"{action} '" subject = complete( diff --git a/examples/self_consistency.py b/examples/self_consistency.py index 6aded6e67..396c1a45f 100644 --- a/examples/self_consistency.py +++ b/examples/self_consistency.py @@ -55,7 +55,7 @@ def few_shots(question, examples): """ -model = models.openai("text-davinci-003") +model = models.openai("gpt-3.5-turbo") prompt = few_shots(question, examples) answers = model(prompt, samples=100) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index e0bc748f8..d0b344a5c 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -5,5 +5,5 @@ codebase. 
""" -from .openai import OpenAIAPI, openai +from .openai import OpenAI, openai from .transformers import Transformers, transformers diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 4d5e35535..8c6ea0647 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,29 +1,102 @@ """Integration with OpenAI's API.""" import functools import os -from collections import deque +import textwrap +from dataclasses import asdict, dataclass, field, replace from itertools import zip_longest from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import outlines -from outlines.caching import cache -__all__ = ["OpenAIAPI", "openai"] +__all__ = ["OpenAI", "openai"] if TYPE_CHECKING: from openai import AsyncOpenAI -class OpenAIAPI: +@dataclass(frozen=True) +class OpenAIConfig: + """Represents the parameters of the OpenAI API. + + The information was last fetched on 2023/11/20. We document below the + properties that are specific to the OpenAI API. Not all these properties are + supported by Outlines. + + Properties + ---------- + model_name + The name of the model. Available models can be found on OpenAI's website. + frequence_penalty + Number between 2.0 and -2.0. Positive values penalize new tokens based on + their existing frequency in the text, + logit_bias + Modifies the likelihood of specified tokens to appear in the completion. + Number between -100 (forbid) and +100 (only allows). + n + The number of completions to return for each prompt. + presence_penalty + Similar to frequency penalty. + response_format + Specifies the format the model must output. `{"type": "json_object"}` + enables JSON mode. + seed + Two completions with the same `seed` value should return the same + completion. This is however not guaranteed. + stop + Up to 4 words where the API will stop the completion. + temperature + Number between 0 and 2. Higher values make the output more random, while + lower values make it more deterministic. + top_p + Number between 0 and 1. Parameter for nucleus sampling. + user + A unique identifier for the end-user. + + """ + + model: str + frequency_penalty: float = 0 + logit_bias: Dict[int, int] = field(default_factory=dict) + max_tokens: Optional[int] = None + n: int = 1 + presence_penalty: float = 0 + response_format: Optional[Dict[str, str]] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + temperature: Optional[float] = None + top_p: int = 1 + user: str = field(default_factory=str) + + +class OpenAI: + """An object that represents the OpenAI API.""" + def __init__( self, model_name: str, - api_key: Optional[str] = os.getenv("OPENAI_API_KEY"), - temperature: float = 1.0, + api_key: Optional[str] = None, max_retries: int = 6, + config: Optional[OpenAIConfig] = None, ): + """Create an `OpenAI` instance. + + Parameters + ---------- + model_name + Model to use, as defined in OpenAI's documentation + api_key + Secret key to use with the OpenAI API. One can also set the + `OPENAI_API_KEY` environment variable, or the value of + `openai.api_key`. + max_retries + The maximum number of retries when calls to the API fail. + config + An instance of `OpenAIConfig`. Can be useful to specify some + parameters that cannot be set by calling this class' methods. + + """ try: import openai except ImportError: @@ -31,222 +104,296 @@ def __init__( "The `openai` library needs to be installed in order to use Outlines' OpenAI integration." 
             )
 
-        try:
-            client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
-        except openai.OpenAIError as e:
-            raise e
+        if api_key is None:
+            if os.getenv("OPENAI_API_KEY") is not None:
+                api_key = os.getenv("OPENAI_API_KEY")
+            elif openai.api_key is not None:
+                api_key = openai.api_key
+            else:
+                raise ValueError(
+                    "You must specify an API key to use the OpenAI API integration."
+                )
 
-        @error_handler
-        @cache
-        async def cached_call_completion_api(*args, **kwargs):
-            response = await call_completion_api(client, *args, **kwargs)
-            return response
-
-        @error_handler
-        @cache
-        async def cached_call_chat_completion_api(*args, **kwargs):
-            response = await call_chat_completion_api(client, *args, **kwargs)
-            return response
-
-        if "text-" in model_name:
-            call_api = cached_call_completion_api
-            format_prompt = lambda x: x
-            extract_choice = lambda x: x["text"]
-        elif "gpt-" in model_name:
-            call_api = cached_call_chat_completion_api
-            format_prompt = lambda x: [{"role": "user", "content": x}]
-            extract_choice = lambda x: x["message"]["content"]
+        if config is not None:
+            self.config = replace(config, model=model_name)  # type: ignore
         else:
-            raise NameError(
-                f"The model {model_name} requested is not available. Only the completion and chat completion models are available for OpenAI."
-            )
+            self.config = OpenAIConfig(model=model_name)
+
+        self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
 
-        @functools.partial(outlines.vectorize, signature="(),(),(m),()->(s)")
-        async def generate_base(
-            prompt: str,
-            max_tokens: int,
-            stop_at: List[Optional[str]],
-            samples: int,
-        ) -> str:
-            responses = await call_api(
-                model_name,
-                format_prompt(prompt),
-                int(max_tokens),
-                temperature,
-                stop_at,
-                {},
-                samples,
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        max_tokens: Optional[int] = None,
+        *,
+        temperature: float = 1.0,
+        samples: int = 1,
+        stop_at: Optional[Union[List[str], str]] = None,
+    ) -> np.ndarray:
+        """Call the OpenAI API to generate text.
+
+        Parameters
+        ----------
+        prompt
+            A string or list of strings that will be used to prompt the model.
+        max_tokens
+            The maximum number of tokens to generate.
+        temperature
+            The value of the temperature used to sample tokens.
+        samples
+            The number of completions to generate for each prompt.
+        stop_at
+            Up to 4 sequences where the API will stop the completion.
+
+        """
+        config = replace(self.config, max_tokens=max_tokens, temperature=temperature, n=samples, stop=stop_at)  # type: ignore
+
+        if "text-" in self.config.model:
+            raise NotImplementedError(
+                textwrap.dedent(
+                    "Most models that support the legacy completion endpoints will be "
+                    "deprecated in January 2024. Use Chat models instead.\n"
+                    "The list of chat models is available at https://platform.openai.com/docs/guides/text-generation."
+ ) ) + if "gpt-" in self.config.model: + return generate_chat(prompt, self.client, config) + + """ + def generate_choice_greedy(transposed, max_tokens_left): + mask = {token: 100 for token in transposed.popleft()} + config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) + response = generate_chat(prompt, config) + prefix, _ = find_common_prefix(response, choices_left) + return prefix + + + def generate_choice_optimistic(transposed, max_tokens_left): + mask = build_optimistic_mask(transposed) + config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) + response = generate_chat(prompt, config) + return response + + while len(choices_left) > 0: + if greedy == True: + prefix = generate_choice_greedy() + choices_left = find_choices_left(prefix, choices_left) + if len(choices_left) == 1: + return choices_left[0] + else: + decoded.append(prefix) + greedy = False + else: + remainder = generate_choice_optimistic() + if remainder in choices_left: # Not exactly true + return remainder + else: + prefix, _ = find_common_prefix(remainder, choices_left) + decoded.append(prefix) + greedy = True + """ + + def generate_choice( + self, prompt: str, choices: List[str], max_tokens: Optional[int] = None + ) -> str: + """Call the OpenAI API to generate one of several choices. + + Parameters + ---------- + prompt + A string or list of strings that will be used to prompt the model + choices + The list of strings between which we ask the model to choose + max_tokens + The maximum number of tokens to generate + + """ + try: + import tiktoken + except ImportError: + raise ImportError( + "The `tiktoken` library needs to be installed in order to choose `outlines.models.openai` with `is_in`" + ) + + config = replace(self.config, max_tokens=max_tokens) + + tokenizer = tiktoken.encoding_for_model(self.config.model) - if samples == 1: - results = np.array([extract_choice(responses["choices"][0])]) + greedy = False + decoded: List[str] = [] + encoded_choices_left: List[List[int]] = [ + tokenizer.encode(word) for word in choices + ] + + while len(encoded_choices_left) > 0: + max_tokens_left = max([len(tokens) for tokens in encoded_choices_left]) + transposed_choices_left: List[Set] = [ + {item for item in subset if item is not None} + for subset in zip_longest(*encoded_choices_left) + ] + + if not greedy: + mask = build_optimistic_mask(transposed_choices_left) else: - results = np.array( - [extract_choice(responses["choices"][i]) for i in range(samples)] - ) + mask = {} + for token in transposed_choices_left[0]: # build greedy mask + mask[token] = 100 + + if len(mask) == 0: + break - return results - - def longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]: - i = 0 - while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]: - i += 1 - return tokens1[:i] - - def get_choices_with_longest_common_prefix( - response: List[int], is_in: List[List[int]] - ) -> Tuple[List[int], List[List[int]]]: - max_len_prefix = 0 - is_in_left = [] - prefix = [] - for i in range(len(is_in)): - len_prefix = len(longest_common_prefix(response, is_in[i])) - - if len_prefix > max_len_prefix: - max_len_prefix = len_prefix - is_in_left = [is_in[i][len_prefix:]] - prefix = is_in[i][:len_prefix] - - elif len_prefix == max_len_prefix: - is_in_left.append(is_in[i][len_prefix:]) - - return prefix, is_in_left - - def build_optimistic_mask(transposed: deque[Set]) -> Dict: - # build the biggest mask possible, adding tokens left to right - to_mask: Set[int] = set() - while len(transposed) > 0 
and len(to_mask | transposed[0]) <= 300: - to_mask = to_mask | transposed.popleft() - - return {token: 100 for token in to_mask} - - @functools.partial(outlines.vectorize, signature="(),(m),()->(s)") - async def generate_choice( - prompt: str, - is_in: List[str], - samples: int, - ) -> Union[List[str], str]: - """Generate a sequence that must be one of many options. - - .. warning:: - - Worst case, this function may call the API as many times as tokens are in the response. - - With the optimistic approach, we activate all tokens that could form all answers. If the solution returned - does not match any of the answers, we the call the API again only with the tokens that can be accepted as - next-token. In average, this approach returns a solution consuming less calls to the API. - - """ - try: - import tiktoken - except ImportError: - raise ImportError( - "The `tiktoken` library needs to be installed in order to choose `outlines.models.openai` with `is_in`" + config = replace(config, logit_bias=mask, max_tokens=max_tokens_left) + response = generate_chat(prompt, self.client, config) + encoded_response = tokenizer.encode(response) + + if encoded_response in encoded_choices_left: + decoded.append(response) + break + else: + ( + encoded_response, + encoded_choices_left, + ) = find_response_choices_intersection( + encoded_response, encoded_choices_left ) - tokenizer = tiktoken.encoding_for_model(model_name) + if len(encoded_response) == 0: + greedy = True # next iteration will be "greedy" + continue + else: + decoded.append("".join(tokenizer.decode(encoded_response))) + + if len(encoded_choices_left) == 1: # only one choice left + choice_left = tokenizer.decode(encoded_choices_left[0]) + decoded.append(choice_left) + break - decoded_samples = [] - for _ in range(samples): - is_in_left = is_in.copy() - decoded: List[str] = [] + greedy = False # after each success, stay with (or switch to) "optimistic" approach - greedy = False # we try to generate the full response at each iteration + prompt = prompt + "".join(decoded) - while len(is_in_left) > 0: - encoded: List[List[int]] = [ - tokenizer.encode(word) for word in is_in_left - ] + choice = "".join(decoded) - max_tokens_left = max([len(tokens) for tokens in encoded]) - transposed: deque[Set] = deque( - [ - {item for item in subset if item is not None} - for subset in zip_longest(*encoded) - ] - ) + return choice - if not greedy: - mask = build_optimistic_mask(transposed) - else: - mask = {} - for token in transposed.popleft(): # build greedy mask - mask[token] = 100 + def generate_json(self): + """Call the OpenAI API to generate a JSON object.""" + raise NotImplementedError - if len(mask) == 0: - break + def __str__(self): + return self.__class__.__name__ + " API" - response = await call_api( - model_name, - format_prompt(prompt), - max_tokens_left if not greedy else 1, - temperature, - [], - mask, - 1, - ) + def __repr__(self): + return str(self.config) - current_resp = extract_choice(response["choices"][0]) - if current_resp in is_in_left: - decoded.append(current_resp) - break - else: - # map response to tokens - tokenized_resp = tokenizer.encode(current_resp) - ( - tokenized_resp, - encoded, - ) = get_choices_with_longest_common_prefix( - tokenized_resp, encoded - ) +@functools.partial(outlines.vectorize, signature="(),(),()->(s)") +async def generate_chat( + prompt: str, client: "AsyncOpenAI", config: OpenAIConfig +) -> np.ndarray: + responses = await client.chat.completions.create( + messages=[{"role": "user", "content": prompt}], 
**asdict(config) # type: ignore + ) - if len(tokenized_resp) == 0: - greedy = True # next iteration will be "greedy" - continue - else: - decoded.append("".join(tokenizer.decode(tokenized_resp))) + if config.n == 1: + results = np.array([responses.choices[0].message.content]) + else: + results = np.array( + [responses.choices[i].message.content for i in range(config.n)] + ) - # map back to words - is_in_left = [ - "".join(tokenizer.decode(tokens)) for tokens in encoded - ] + return results - if len(is_in_left) == 1: # only one choice left - decoded.append(is_in_left[0]) - break - greedy = False # after each success, stay with (or switch to) "optimistic" approach +openai = OpenAI - prompt = prompt + "".join(decoded) - decoded_samples.append("".join(decoded)) +def find_longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]: + i = 0 + while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]: + i += 1 + return tokens1[:i] - return np.array(decoded_samples) - self.generate_base = generate_base - self.generate_choice = generate_choice +def find_response_choices_intersection_new(response, choices): + """Find the longest intersection between the response and the different + choices. - def __call__( - self, - prompt: str, - max_tokens: int = 500, - *, - samples=1, - stop_at: Union[List[Optional[str]], str] = [], - is_in: Optional[List[str]] = None, - ): - if is_in is not None and stop_at: - raise TypeError("You cannot set `is_in` and `stop_at` at the same time.") - elif is_in is not None: - return self.generate_choice(prompt, is_in, samples) + Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices + `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2]` as the + intersection, and `[1, 2, 3]` as the choice that is left. + + TODO: Implement a test for this + + Parameters + ---------- + response + The model's response + choices + The remaining possible choices + + Returns + ------- + A tuple that contains the longest intersection between the response and the + different choices, and the choices which start with this intersection. + + """ + max_choice_length = min([len(choice) for choice in choices]) + choices_left = choices + + for i, token in enumerate(response): + if i == max_choice_length: + return response[:i], choices_left + + remaining = [choice for choice in choices_left if choice[i] == token] + if len(remaining) == 0: + return response[:i], [choice[i:] for choice in choices_left] else: - if isinstance(stop_at, str): - stop_at = [stop_at] - return self.generate_base(prompt, max_tokens, stop_at, samples) + choices_left = remaining + + return response, [choice[i + 1 :] for choice in choices_left if len(choice) > i + 1] -openai = OpenAIAPI +def find_response_choices_intersection( + response: List[int], choices: List[List[int]] +) -> Tuple[List[int], List[List[int]]]: + max_len_prefix = 0 + choices_left = [] + prefix = [] + for i in range(len(choices)): + len_prefix = len(find_longest_common_prefix(response, choices[i])) + + if len_prefix > max_len_prefix: + max_len_prefix = len_prefix + choices_left = [choices[i][len_prefix:]] + prefix = choices[i][:len_prefix] + + elif len_prefix == max_len_prefix: + choices_left.append(choices[i][len_prefix:]) + + return prefix, choices_left + + +def build_optimistic_mask(transposed: List[Set]) -> Dict: + """We build the largest mask possible. + + Tokens are added from left to right, so if the encoded choices are e.g. + `[[1,2], [3,4]]`, `1` and `3` will be added before `2` and `4`. 
+ + Parameters + ---------- + transposed + A list of lists that contain the nth token of each choice. + + """ + tokens_to_mask: Set[int] = set() + for tokens in transposed: + if len(tokens_to_mask) + len(tokens) <= 300: + tokens_to_mask = tokens_to_mask | tokens + else: + break + + return {token: 100 for token in tokens_to_mask} def error_handler(api_call_fn: Callable) -> Callable: @@ -274,47 +421,3 @@ def call(*args, **kwargs): raise e return call - - -async def call_completion_api( - client: "AsyncOpenAI", - model: str, - prompt: str, - max_tokens: int, - temperature: float, - stop_sequences: List[str], - logit_bias: Dict[str, int], - num_samples: int, -) -> dict: - response = await client.completions.create( - model=model, - prompt=prompt, - temperature=temperature, - max_tokens=max_tokens, - stop=list(stop_sequences) if len(stop_sequences) > 0 else None, - logit_bias=logit_bias, - n=int(num_samples), - ) - return response.model_dump() - - -async def call_chat_completion_api( - client: "AsyncOpenAI", - model: str, - messages: List[Dict[str, str]], - max_tokens: int, - temperature: float, - stop_sequences: List[str], - logit_bias: Dict[str, int], - num_samples: int, -) -> dict: - response = await client.chat.completions.create( - model=model, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - stop=list(stop_sequences) if len(stop_sequences) > 0 else None, - logit_bias=logit_bias, - n=int(num_samples), - ) - return response.model_dump() diff --git a/outlines/text/generate/sequence.py b/outlines/text/generate/sequence.py index 8550c2e8a..538579580 100644 --- a/outlines/text/generate/sequence.py +++ b/outlines/text/generate/sequence.py @@ -3,7 +3,7 @@ import torch -from outlines.models import OpenAIAPI +from outlines.models import OpenAI if TYPE_CHECKING: from outlines.models.transformers import KVCacheType, Transformers @@ -35,7 +35,7 @@ def __init__( such functions. 
""" - if isinstance(model, OpenAIAPI): + if isinstance(model, OpenAI): raise TypeError("Cannot use guided generation with the OpenAI API.") self.model = model diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py new file mode 100644 index 000000000..8b29f89d6 --- /dev/null +++ b/tests/models/test_openai.py @@ -0,0 +1,27 @@ +import pytest + +from outlines.models.openai import find_response_choices_intersection_new + + +@pytest.mark.parametrize( + "response,choice,expected_intersection,expected_choices_left", + ( + ([1, 2, 3, 4], [[5, 6]], [], [[5, 6]]), + ([1, 2, 3, 4], [[5, 6], [7, 8]], [], [[5, 6], [7, 8]]), + ([1, 2, 3, 4], [[1, 2], [7, 8]], [1, 2], [[1, 2]]), + ([1, 2], [[1, 2, 3, 4], [1, 2]], [1, 2], [[3, 4]]), + ([1, 2, 3], [[1, 2, 3, 4], [1, 2]], [1, 2], [[4]]), + ), +) +def test_find_reponse_choices_new( + response, choice, expected_intersection, expected_choices_left +): + intersection, choices_left = find_response_choices_intersection_new( + response, choice + ) + assert intersection == expected_intersection + assert choices_left == expected_choices_left + + +def test_build_optimistic_mask(): + raise NotImplementedError diff --git a/tests/text/generate/test_sequence.py b/tests/text/generate/test_sequence.py index f4fd52c03..e5ede8c53 100644 --- a/tests/text/generate/test_sequence.py +++ b/tests/text/generate/test_sequence.py @@ -5,13 +5,13 @@ import pytest import torch -from outlines.models import OpenAIAPI +from outlines.models import OpenAI from outlines.models.tokenizer import Tokenizer from outlines.text.generate.sequence import Sequence def test_openai_error(): - class Mock(OpenAIAPI): + class Mock(OpenAI): def __init__(self): pass