Merge pull request #5 from shubhamugare/multi_batch
Multi batch generation
shubhamugare authored Oct 30, 2023
2 parents eb679fb + 83382e3 commit 6d7f691
Showing 2 changed files with 64 additions and 65 deletions.
14 changes: 7 additions & 7 deletions llm_cfg/language_model.py
@@ -74,16 +74,16 @@ def generate_batch_completion_grammar(self, prompt, batch_size) -> list[str]:
             logits_processor=self.logit_processors
         )

-        last_token_id = len(generated_ids[0])
-
+        last_token_id = [len(generated_ids[i]) for i in range(batch_size)]
         if self.logit_processors is not None:
             python_decoder = self.logit_processors[0]
-            last_token_id = python_decoder.last_valid_stage
-
-        batch_completions = self.tokenizer.batch_decode(
-            [ids[input_ids_cutoff:last_token_id] for ids in generated_ids],
-            skip_special_tokens=True,
-        )
+            last_token_id = [python_decoder.last_valid_state[i] for i in range(batch_size)]
+
+        batch_completions = [
+            self.tokenizer.decode(ids[input_ids_cutoff:last_token_id[i]], skip_special_tokens=True) for i, ids in enumerate(generated_ids)
+        ]

         if self.logit_processors is not None:
             python_decoder = self.logit_processors[0]
         print(f"Time taken for generation: {time.time() - start_time:.2f}s")
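To illustrate the per-sequence truncation introduced above, here is a small self-contained sketch (toy token ids and a hypothetical last_valid_state list, not part of this commit): each sequence in the batch is now sliced at its own last syntactically valid index rather than at a single shared cutoff.

# Toy illustration (hypothetical data): per-sequence cutoff for batched decoding.
generated_ids = [
    [101, 7, 8, 9, 10, 11],   # sequence 0: 1 prompt token + 5 generated tokens
    [101, 7, 8, 12, 13],      # sequence 1: 1 prompt token + 4 generated tokens
]
input_ids_cutoff = 1          # length of the shared prompt
last_valid_state = [5, 4]     # last syntactically valid index, one entry per sequence

batch_completions = [
    ids[input_ids_cutoff:last_valid_state[i]] for i, ids in enumerate(generated_ids)
]
print(batch_completions)      # [[7, 8, 9, 10], [7, 8, 12]]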
115 changes: 57 additions & 58 deletions llm_cfg/python_decoder.py
@@ -6,97 +6,96 @@


 class PythonDecoder(LogitsProcessor):
     """
     This class is used to filter the logits of the model to only allow syntactically valid tokens for Python.
     """
     def __init__(self, tokenizer: PreTrainedTokenizer, **kwargs):
         time_start = time.time()
         self.tokenizer = tokenizer
-        self.inc_parser = IncrementalParser()

+        self.batch_size = None # We update this in the first call to __call__
+        self.inc_parsers = None

+        # For backtracking to syntactically valid completions
+        self.partial_codes_trace = []
+        self.last_valid_stage = 0

         # For profiling
         self.token_cnt = 0
         self.accept_tokens_sizes = []
         self.non_matching_token_cnt = 0
-        self.partial_codes = []
-        self.last_valid_stage = 0
-        self.terminals_nfa = common.load_nfa(tokenizer=self.tokenizer, inc_parser=self.inc_parser, use_cache=True)

-        # Iterate through the vocabulary and create a map of (tokenizer token -> grammar terminal)
-        # Note: It may happen that many tokens do not fall in any category
-        self.terminal_to_mask = {}
-        self.uncategorezed_mask = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
+        # Load NFA
+        self.terminals_nfa = common.load_nfa(tokenizer=self.tokenizer, use_cache=True)

         self.start_time = time.time()
         self.prev_time = self.start_time
         self.debug = True

+        print(f"Time taken for preprocessing: {time.time() - time_start:.2f}s")

-        for i in range(tokenizer.vocab_size):
-            token = tokenizer.decode(torch.tensor([i]), skip_special_tokens=True)
-            token_types = []

-            token_type = self.inc_parser.get_matching_terminal(token)
-            prefix_token_types = self.inc_parser.get_prefix_terminals_match(token)
+    def _print_current_status(self, partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal):
+        print('partial code:\n', repr(partial_code))
+        print('inc:', repr(incomplete_terminal), '\n', 'cur:', cur_ac_terminals, '\n', 'next:', next_ac_terminals)

-            token_types.append(token_type)
-            token_types += prefix_token_types

-            for token_type in token_types:
-                if not token_type in self.terminal_to_mask:
-                    self.terminal_to_mask[token_type] = torch.zeros(tokenizer.vocab_size, dtype=torch.bool)
-                self.terminal_to_mask[token_type][i] = True

-        print(f"Time taken for preprocessing: {time.time() - time_start:.2f}s")

     def _reset(self):
+        self.token_cnt = 0
+        self.inc_parsers = [IncrementalParser() for _ in range(self.batch_size)]
+        self.last_valid_state = [0 for _ in range(self.batch_size)]
         self.accept_tokens_sizes = []
-        self.partial_codes = []
-        self.last_valid_stage = 0
-        self.inc_parser = IncrementalParser()
+        self.partial_codes_trace = []
-        self.token_cnt = 0


     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        partial_code = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        partial_codes = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)

-        if len(self.partial_codes) > 0 and self.partial_codes[-1] not in partial_code:
+        if self.batch_size == None or (len(self.partial_codes_trace) > 0 and self.partial_codes_trace[0] not in partial_codes[0]):
+            self.batch_size = len(partial_codes)
             self._reset()

-        self.token_cnt += 1
+        self.token_cnt += self.batch_size
         greedy_grammar_token = None
-        try:
-            compilation_start_time = time.time()
-            self.partial_codes.append(partial_code)

-            # returns the names of the Terminals that are currently accepted.
-            cur_ac_terminals, next_ac_terminals, incomplete_terminal = self.inc_parser.get_acceptable_next_terminals(partial_code)
+        for i, partial_code in enumerate(partial_codes):
+            try:
+                compilation_start_time = time.time()
+                self.partial_codes_trace.append(partial_code)

-            greedy_token = self.tokenizer.decode(scores.argmax(dim=-1), skip_special_tokens=True) # For debugging - remove later
+                # returns the names of the Terminals that are currently accepted.
+                cur_ac_terminals, next_ac_terminals, incomplete_terminal = self.inc_parsers[i].get_acceptable_next_terminals(partial_code)

-            if 'EOF' in next_ac_terminals:
-                self.last_valid_stage = len(input_ids[0])
+                greedy_token = self.tokenizer.decode(scores.argmax(dim=-1), skip_special_tokens=True) # For debugging - remove later

-            self.accept_tokens_sizes.append(len(cur_ac_terminals)) # For profiling
+                if 'EOF' in next_ac_terminals:
+                    self.last_valid_state[i] = len(input_ids[i])

-            accept_mask = self.terminals_nfa.get_overapprox_tokens_mask(incomplete_terminal, next_ac_terminals)
+                self.accept_tokens_sizes.append(len(cur_ac_terminals)) # For profiling

-            print('partial code:')
-            print(repr(partial_code))
-            print('inc:', repr(incomplete_terminal))
-            print(next_ac_terminals)
-            scores = scores.masked_fill(~accept_mask.to(scores.device), -float("inf"))
+                print(i, 'Time taken for compilation:', time.time() - compilation_start_time)
+                accept_mask = self.terminals_nfa.get_overapprox_tokens_mask(incomplete_terminal, next_ac_terminals)

-            if self.debug and self.token_cnt%50==0:
-                print('Time taken for compilation:', time.time() - compilation_start_time)
+                print(i, 'Time taken for overapproximation:', time.time() - compilation_start_time)
+                if self.debug and self.token_cnt%50==0:
+                    self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)

+                if torch.sum(accept_mask) != 0: # If there are acceptable tokens for the current partial code
+                    scores[i] = scores[i].masked_fill(~accept_mask.to(scores.device), -float("inf"))
+                else: # Otherwise, report the error and mask no tokens
+                    print('No acceptable tokens for the current partial code!')
+                    self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)

-            greedy_grammar_token = self.tokenizer.decode(scores.argmax(dim=-1), skip_special_tokens=True)
+                print(i, 'Time taken for masking:', time.time() - compilation_start_time)

-            if greedy_token != greedy_grammar_token:
-                print('Greedy token:', repr(greedy_token), scores.argmax(dim=-1))
-                print('Greedy grammar-based token:', repr(greedy_grammar_token), scores.argmax(dim=-1))
-                print('Current acceptable terminals:', cur_ac_terminals)
-                print('Next acceptable terminals:', next_ac_terminals)
-                print('Partial code:', partial_code)
-                self.non_matching_token_cnt += 1

-        except Exception as e:
-            print("-"*80, '\n', 'Code lenght:', len(partial_code), '\n', partial_code, '\n', repr(partial_code), '\n', 'Error:', e)
+                greedy_grammar_token = self.tokenizer.decode(scores.argmax(dim=-1), skip_special_tokens=True)

+                if greedy_token != greedy_grammar_token:
+                    print('Greedy token:', repr(greedy_token), scores.argmax(dim=-1))
+                    print('Greedy grammar-based token:', repr(greedy_grammar_token), scores.argmax(dim=-1))
+                    self._print_current_status(partial_code, cur_ac_terminals, next_ac_terminals, incomplete_terminal)
+                    self.non_matching_token_cnt += 1
+            except Exception as e:
+                print("-"*80, '\n', 'Code lenght:', len(partial_code), '\n', partial_code, '\n', repr(partial_code), '\n', 'Error:', e)

         return scores
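Since PythonDecoder is a standard HuggingFace LogitsProcessor, the batched version above can be plugged directly into model.generate. The following is a hedged usage sketch; the checkpoint name, prompts, and padding setup are illustrative assumptions rather than part of this commit.

# Hypothetical usage sketch (placeholder checkpoint and prompts, not from this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("some/code-model")   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("some/code-model")
tokenizer.pad_token = tokenizer.eos_token                      # assumed: no pad token by default
tokenizer.padding_side = "left"                                # left-pad so generation continues from each prompt

decoder = PythonDecoder(tokenizer=tokenizer)                   # the LogitsProcessor defined above

prompts = ["def add(a, b):", "def is_even(n):"]                # toy batch of size 2
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# The decoder masks logits row by row, so each completion in the batch is kept
# within syntactically valid Python prefixes.
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    logits_processor=LogitsProcessorList([decoder]),
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))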
