From 76cfc611dfc66d7d03cf466718bb440dd146e0f2 Mon Sep 17 00:00:00 2001 From: Ivan Herreros Date: Tue, 14 Nov 2023 12:30:52 +0100 Subject: [PATCH] Improve multiple-choice selection for the OpenAI API The current approach is greedy, in the sense that it generates a single token at each steps, asking the API to only generate valid next tokens. This mean having to pay for the prompt tokens for every token generated. This commit takes a more optimistic approach. It starts with allowing all tokens present in the sequences, and limiting the length of the generation to the number of tokens in the longest sequence. If the completion is not satisfactory it then takes one greedy step before switching back to the optimistic mode. On average this new approach consumes less tokens than the current one. --- outlines/models/openai.py | 116 +++++++++++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 19 deletions(-) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index e845b940e..4d5e35535 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -1,7 +1,9 @@ """Integration with OpenAI's API.""" import functools import os -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union +from collections import deque +from itertools import zip_longest +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np @@ -85,6 +87,39 @@ async def generate_base( return results + def longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]: + i = 0 + while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]: + i += 1 + return tokens1[:i] + + def get_choices_with_longest_common_prefix( + response: List[int], is_in: List[List[int]] + ) -> Tuple[List[int], List[List[int]]]: + max_len_prefix = 0 + is_in_left = [] + prefix = [] + for i in range(len(is_in)): + len_prefix = len(longest_common_prefix(response, is_in[i])) + + if len_prefix > max_len_prefix: + max_len_prefix = len_prefix + is_in_left = [is_in[i][len_prefix:]] + prefix = is_in[i][:len_prefix] + + elif len_prefix == max_len_prefix: + is_in_left.append(is_in[i][len_prefix:]) + + return prefix, is_in_left + + def build_optimistic_mask(transposed: deque[Set]) -> Dict: + # build the biggest mask possible, adding tokens left to right + to_mask: Set[int] = set() + while len(transposed) > 0 and len(to_mask | transposed[0]) <= 300: + to_mask = to_mask | transposed.popleft() + + return {token: 100 for token in to_mask} + @functools.partial(outlines.vectorize, signature="(),(m),()->(s)") async def generate_choice( prompt: str, @@ -95,12 +130,11 @@ async def generate_choice( .. warning:: - This function will call the API once for every token generated. + Worst case, this function may call the API as many times as tokens are in the response. - We tokenize every choice, iterate over the token lists, create a mask - with the current tokens and generate one token. We progressively - eliminate the choices that don't start with the currently decoded - sequence. + With the optimistic approach, we activate all tokens that could form all answers. If the solution returned + does not match any of the answers, we the call the API again only with the tokens that can be accepted as + next-token. In average, this approach returns a solution consuming less calls to the API. """ try: @@ -111,20 +145,33 @@ async def generate_choice( ) tokenizer = tiktoken.encoding_for_model(model_name) - encoded: List[List[int]] = [tokenizer.encode(word) for word in is_in] decoded_samples = [] for _ in range(samples): + is_in_left = is_in.copy() decoded: List[str] = [] - for i in range(max([len(word) for word in encoded])): - mask = {} - for word, tokenized_word in zip(is_in, encoded): - if not word.startswith("".join(decoded)): - continue - try: - mask[tokenized_word[i]] = 100 - except IndexError: - pass + + greedy = False # we try to generate the full response at each iteration + + while len(is_in_left) > 0: + encoded: List[List[int]] = [ + tokenizer.encode(word) for word in is_in_left + ] + + max_tokens_left = max([len(tokens) for tokens in encoded]) + transposed: deque[Set] = deque( + [ + {item for item in subset if item is not None} + for subset in zip_longest(*encoded) + ] + ) + + if not greedy: + mask = build_optimistic_mask(transposed) + else: + mask = {} + for token in transposed.popleft(): # build greedy mask + mask[token] = 100 if len(mask) == 0: break @@ -132,15 +179,46 @@ async def generate_choice( response = await call_api( model_name, format_prompt(prompt), - 1, + max_tokens_left if not greedy else 1, temperature, [], mask, 1, ) - decoded.append(extract_choice(response["choices"][0])) - prompt = prompt + "".join(decoded) + current_resp = extract_choice(response["choices"][0]) + + if current_resp in is_in_left: + decoded.append(current_resp) + break + else: + # map response to tokens + tokenized_resp = tokenizer.encode(current_resp) + ( + tokenized_resp, + encoded, + ) = get_choices_with_longest_common_prefix( + tokenized_resp, encoded + ) + + if len(tokenized_resp) == 0: + greedy = True # next iteration will be "greedy" + continue + else: + decoded.append("".join(tokenizer.decode(tokenized_resp))) + + # map back to words + is_in_left = [ + "".join(tokenizer.decode(tokens)) for tokens in encoded + ] + + if len(is_in_left) == 1: # only one choice left + decoded.append(is_in_left[0]) + break + + greedy = False # after each success, stay with (or switch to) "optimistic" approach + + prompt = prompt + "".join(decoded) decoded_samples.append("".join(decoded))