Skip to content

Commit

Permalink
fix: multi-step + flashinfer with cuda graphs (#1036)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale authored Dec 27, 2024
1 parent 055c890 commit c951a54
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion aphrodite/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True)

Expand Down

0 comments on commit c951a54

Please sign in to comment.