Commit

#0: Skip users in FlashDecode based on index
cglagovich committed Sep 19, 2024
1 parent bd3f53e commit e4dfcef
Showing 5 changed files with 58 additions and 1 deletion.
@@ -413,7 +413,8 @@ def run_test_sdpa_decode_single_iter(
) # b, nh, 1, d
expect = expect.squeeze().unsqueeze(0)

    out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc)
    non_skip_indices = torch.tensor(start_indices) != -1
    out_pass, out_pcc = comp_pcc(expect[:, non_skip_indices], tt_back[:, non_skip_indices], min_pcc)

logger.debug(f"python vs pytorch: {out_pcc}")
assert out_pass
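
The change above only compares outputs for users whose start index is not -1. A minimal standalone sketch of that masking idea, using a stand-in Pearson-correlation helper rather than the repo's comp_pcc:

import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    # Stand-in for comp_pcc: Pearson correlation of the flattened tensors.
    a, b = a.flatten().float(), b.flatten().float()
    return torch.corrcoef(torch.stack([a, b]))[0, 1].item()

b = 8
expect = torch.randn(1, b, 16)                       # reference output, one row per user
tt_back = expect + 0.01 * torch.randn_like(expect)   # device output for active users
tt_back[:, 1::2] = 0.0                               # skipped (odd) users return stale/undefined data

start_indices = [100 if bb % 2 == 0 else -1 for bb in range(b)]
non_skip_indices = torch.tensor(start_indices) != -1

print(pcc(expect, tt_back))                                            # dragged down by skipped rows
print(pcc(expect[:, non_skip_indices], tt_back[:, non_skip_indices]))  # ~1.0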
@@ -460,6 +461,46 @@ def test_sdpa_decode(
)


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype, q_dtype",
[
[ttnn.bfloat16, ttnn.bfloat16],
],
ids=[
"all_bfp16",
],
)
@pytest.mark.parametrize(
"b, nh, nkv, s, d, grid_size, single_iter, cur_pos_tensor",
([32, 8, 1, 32768, 128, (8, 6), True, True],), # Llama2-70B
)
def test_sdpa_decode_ignore_users(
device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, single_iter, cur_pos_tensor, use_program_cache
):
ttnn.device.DisablePersistentKernelCache()

# Set odd users to -1 to test skipping users
start_indices = [100 if bb % 2 == 0 else -1 for bb in range(b)]

run_test_sdpa_decode_single_iter(
device,
b,
nh,
nkv,
s,
d,
dtype,
grid_size,
q_dtype,
cur_pos_tensor,
sharded_in=False,
sharded_out=False,
start_indices=start_indices,
)
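
The new test forces the skip path by setting odd users' indices to -1. In a serving loop, the same convention could mark batch slots that are currently empty or finished; a hypothetical host-side sketch (active_positions and the slot layout are illustrative, not from this commit):

import torch

batch_size = 32
# Hypothetical bookkeeping: current decode position per occupied slot.
active_positions = {0: 100, 2: 57, 5: 203}

# Slots without an active user get -1 so FlashDecode skips them.
start_indices = [active_positions.get(slot, -1) for slot in range(batch_size)]
cur_pos_tensor = torch.tensor(start_indices, dtype=torch.int32)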


def run_test_sdpa_decode_paged_attention(
device,
b,
@@ -506,6 +506,11 @@ void MAIN {
cur_pos = index_addr_ptr[4+cur_batch];
cb_release_tile(cb_index_id);
}

if (cur_pos == (uint32_t) -1) {
// cur_pos of -1 indicates that the user should be skipped
return;
}
// Sequence length assignment
auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num, num_cores_per_batch, k_chunk_size);
if (k_chunk_start == k_chunk_end) {
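
The host writes positions as signed 32-bit integers, while the kernel reads them as uint32_t, so a -1 sentinel arrives as 0xFFFFFFFF, which is exactly what (uint32_t) -1 evaluates to. A quick check of that bit-pattern equivalence (illustrative only):

import numpy as np

host_values = np.array([100, -1], dtype=np.int32)  # positions as the host writes them
kernel_view = host_values.view(np.uint32)          # same bits, read back as uint32_t
print(kernel_view)                                 # [       100 4294967295]
assert kernel_view[1] == 0xFFFFFFFF                # matches (uint32_t) -1 in the kernel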
@@ -86,6 +86,11 @@ void kernel_main() {
volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_wr_ptr);
cur_pos = index_ptr[cur_batch];
}

if (cur_pos == (uint32_t) -1) {
// cur_pos of -1 indicates that the user should be skipped
return;
}
volatile tt_l1_ptr uint32_t* page_table_ptr;
if constexpr (is_paged_attention) {
constexpr uint32_t cb_id_page_table = tt::CB::dataflow1;
@@ -288,6 +288,11 @@ void kernel_main() {
volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_ptr);
cur_pos = index_ptr[cur_batch];
}

if (cur_pos == (uint32_t) -1) {
// cur_pos of -1 indicates that the user should be skipped
return;
}
// Sequence length assignment
auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num, num_cores_per_batch, k_chunk_size);

@@ -26,6 +26,7 @@ void py_bind_sdpa_decode(py::module &module) {
"output: [1 x b x pnh x dh]"
"Accepts a `SDPAMultiCoreProgramConfig` which specifies the grid size and chunk tiles in the K/V/Mask sequence lengths (Q chunk tiles is not used). The op parallelizes over `b` and K/V/Mask's `s` dimension."
"If a position is given as (-1), compute for the corresponding index in the batch is skipped."
)doc";

using OperationType = decltype(ttnn::transformer::scaled_dot_product_attention_decode);
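
A small torch reference of the skip semantics the docstring describes: one query position per user, and a position of -1 leaves that batch index untouched. This is an illustrative reference, not the ttnn op's implementation (shapes loosely follow the docstring's [1 x b x ...] layout):

import torch

def sdpa_decode_reference(q, k, v, cur_pos):
    # q: [1, b, nh, dh], k/v: [b, 1, s, dh], cur_pos: [b] of int (-1 means skip)
    out = torch.zeros_like(q)
    for user in range(q.shape[1]):
        pos = int(cur_pos[user])
        if pos == -1:
            continue  # skipped user: its output slice is left as-is
        keys = k[user, 0, : pos + 1]   # attend over positions 0..pos
        vals = v[user, 0, : pos + 1]
        scores = q[0, user] @ keys.T / keys.shape[-1] ** 0.5
        out[0, user] = torch.softmax(scores, dim=-1) @ vals
    return out

b, nh, s, dh = 4, 8, 64, 32
q = torch.randn(1, b, nh, dh)
k = torch.randn(b, 1, s, dh)
v = torch.randn(b, 1, s, dh)
out = sdpa_decode_reference(q, k, v, torch.tensor([10, -1, 30, -1]))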