#12909: Skip users in FlashDecode based on index
An index of -1 causes FlashDecode to skip computation for that user. Based on cur_pos, tile reads for K and V chunks outside the valid range are skipped.

---------

Signed-off-by: Salar Hosseini <[email protected]>
Co-authored-by: Colman Glagovich <[email protected]>
Co-authored-by: Salar Hosseini <[email protected]>
3 people authored Sep 20, 2024
1 parent 613505d commit a3afddb
Showing 6 changed files with 88 additions and 22 deletions.
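
For orientation, the semantics the commit targets can be expressed as a minimal PyTorch sketch (illustrative only, not the op's implementation; function and variable names are made up): each user attends to key/value positions 0 through cur_pos, and a cur_pos of -1 marks a user whose computation is skipped entirely.

    import torch

    def reference_decode_attention(Q, K, V, cur_pos, scale):
        # Q: [1, b, nh, d], K/V: [b, nkv, s, d]; assumes nkv == 1 so K/V broadcast over heads.
        out = torch.zeros_like(Q)
        for u in range(Q.shape[1]):
            if cur_pos[u] == -1:
                continue  # skipped user: the op leaves this output undefined
            valid = cur_pos[u] + 1                    # attend to positions 0..cur_pos[u]
            q = Q[0, u].unsqueeze(1)                  # [nh, 1, d]
            k = K[u, :, :valid]                       # [nkv, valid, d]
            v = V[u, :, :valid]
            attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
            out[0, u] = (attn @ v).squeeze(1)         # [nh, d]
        return out
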
models/demos/t3000/llama2_70b/tt/llama_model_optimized.py (4 additions, 1 deletion)
@@ -248,7 +248,10 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode"):
             else:
                 cache_idxs = start_pos.to(dtype=torch.int64)
 
-            rot_mat = get_rotation_mat(self.rot_emb, cache_idxs, seq_len, batch=batch)
+            rot_cache_idxs = torch.maximum(
+                cache_idxs, torch.tensor(0, dtype=torch.int64)
+            ) # Ensure position indices are non-negative
+            rot_mat = get_rotation_mat(self.rot_emb, rot_cache_idxs, seq_len, batch=batch)
             assert rot_mat.size() == (1, batch, self.head_dim, self.head_dim)
 
             rot_mats = ttnn.as_tensor(
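
The clamp above exists because a start position of -1 would otherwise be used to gather rotation matrices. A small illustration with made-up values:

    import torch

    cache_idxs = torch.tensor([100, -1, 205, -1], dtype=torch.int64)  # -1 marks skipped users
    rot_cache_idxs = torch.maximum(cache_idxs, torch.tensor(0, dtype=torch.int64))
    print(rot_cache_idxs)  # tensor([100, 0, 205, 0])
    # Skipped users receive the position-0 rotation matrix as a harmless placeholder;
    # their attention output is never consumed because the decode kernels return early for them.
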
@@ -413,7 +413,8 @@ def run_test_sdpa_decode_single_iter(
     ) # b, nh, 1, d
     expect = expect.squeeze().unsqueeze(0)
 
-    out_pass, out_pcc = comp_pcc(expect, tt_back, min_pcc)
+    non_skip_indices = torch.tensor(start_indices) != -1
+    out_pass, out_pcc = comp_pcc(expect[:, non_skip_indices], tt_back[:, non_skip_indices], min_pcc)
 
     logger.debug(f"python vs pytorch: {out_pcc}")
     assert out_pass
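
Because the op leaves the output rows of skipped users undefined, the PCC check now masks them out before comparing. A small example with made-up indices:

    import torch

    start_indices = [100, -1, 100, -1]
    non_skip_indices = torch.tensor(start_indices) != -1  # tensor([True, False, True, False])
    # expect and tt_back are shaped [1, b, ...]; boolean indexing on the batch dimension
    # keeps only the valid users, so whatever sits in the skipped rows cannot fail the test.
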
@@ -460,6 +461,46 @@ def test_sdpa_decode(
     )
 
 
+@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
+@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
+@pytest.mark.parametrize(
+    "dtype, q_dtype",
+    [
+        [ttnn.bfloat16, ttnn.bfloat16],
+    ],
+    ids=[
+        "all_bfp16",
+    ],
+)
+@pytest.mark.parametrize(
+    "b, nh, nkv, s, d, grid_size, single_iter, cur_pos_tensor",
+    ([32, 8, 1, 32768, 128, (8, 6), True, True],), # Llama2-70B
+)
+def test_sdpa_decode_ignore_users(
+    device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, single_iter, cur_pos_tensor, use_program_cache
+):
+    ttnn.device.DisablePersistentKernelCache()
+
+    # Set odd users to -1 to test skipping users
+    start_indices = [100 if bb % 2 == 0 else -1 for bb in range(b)]
+
+    run_test_sdpa_decode_single_iter(
+        device,
+        b,
+        nh,
+        nkv,
+        s,
+        d,
+        dtype,
+        grid_size,
+        q_dtype,
+        cur_pos_tensor,
+        sharded_in=False,
+        sharded_out=False,
+        start_indices=start_indices,
+    )
+
+
 def run_test_sdpa_decode_paged_attention(
     device,
     b,
@@ -494,8 +494,8 @@ void MAIN {
 
     // Get cur_pos
     uint32_t cur_pos = 0;
-    // using 4294967295 (end of uint32 range) as a flag to indicate that cur_pos is not provided as a list
-    if (cur_pos_arg!=4294967295){
+    // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
+    if (cur_pos_arg != UINT32_MAX){
        cur_pos = cur_pos_arg;
     }
     else {
@@ -506,6 +506,11 @@ void MAIN {
         cur_pos = index_addr_ptr[4+cur_batch];
         cb_release_tile(cb_index_id);
     }
+
+    if (cur_pos == UINT32_MAX) {
+        // cur_pos of -1 indicates that the user should be skipped
+        return;
+    }
     // Sequence length assignment
     auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num, num_cores_per_batch, k_chunk_size);
     if (k_chunk_start == k_chunk_end) {
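
The comparison against UINT32_MAX works because the positions are provided as signed 32-bit values while the kernel reads the same bits as uint32_t; under two's complement, -1 reads back as 4294967295. A quick check of that equivalence:

    neg_one_as_uint32 = (-1) & 0xFFFFFFFF  # reinterpret the int32 bit pattern as unsigned
    assert neg_one_as_uint32 == 4294967295 == 0xFFFFFFFF  # i.e. UINT32_MAX
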
@@ -66,8 +66,8 @@ void kernel_main() {
     }
     // Get cur_pos
     uint32_t cur_pos = 0;
-    // using 4294967295 (end of uint32 range) as a flag to indicate that cur_pos is not provided as a list
-    if (cur_pos_arg!=4294967295){
+    // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
+    if (cur_pos_arg != UINT32_MAX){
        cur_pos = cur_pos_arg;
     }
     else {
@@ -87,6 +87,13 @@ void kernel_main() {
         volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_wr_ptr);
         cur_pos = index_ptr[cur_batch];
     }
+
+    if (cur_pos == UINT32_MAX) {
+        // cur_pos of -1 indicates that the user should be skipped
+        return;
+    }
+    const uint32_t valid_seq_len_tiles = (cur_pos + 1 + 32 - 1) / 32;
+
     volatile tt_l1_ptr uint32_t* page_table_ptr;
     if constexpr (is_paged_attention) {
         constexpr uint32_t cb_id_page_table = tt::CB::dataflow1;
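
valid_seq_len_tiles is a ceiling division by the 32-row tile height: the number of K/V tile rows needed to cover tokens 0 through cur_pos. A worked example:

    cur_pos = 100
    valid_seq_len_tiles = (cur_pos + 1 + 32 - 1) // 32  # ceil(101 / 32) = 4
    # Tile rows 0..3 cover token positions 0..127; the guarded K and V read loops below
    # only issue reads for tile rows up to this count and skip the rest of each chunk.
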
@@ -259,14 +266,15 @@ void kernel_main() {
             for (uint32_t col = 0; col < DHt; ++col) {
                 uint32_t k_tile_id = k_start_tile_id + col;
                 for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
-                    noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
+                    if (row <= valid_seq_len_tiles) {
+                        noc_async_read_tile(k_tile_id, k_reader, k_write_ptr);
+                        if (++barrier_count == barrier_threshold) {
+                            noc_async_read_barrier();
+                            barrier_count = 0;
+                        }
+                    }
                     k_tile_id += DHt;
                     k_write_ptr += k_tile_bytes;
-
-                    if (++barrier_count == barrier_threshold) {
-                        noc_async_read_barrier();
-                        barrier_count = 0;
-                    }
                 }
             }
             noc_async_read_barrier();
@@ -278,14 +286,17 @@ void kernel_main() {
             uint32_t v_write_ptr = get_write_ptr(cb_v_in);
             barrier_count = 0;
             uint32_t v_tile_id = v_start_tile_id;
-            for (uint32_t tile = 0; tile < k_chunk_tiles; ++tile) {
-                noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
-                v_tile_id++;
-                v_write_ptr += v_tile_bytes;
-
-                if (++barrier_count == barrier_threshold) {
-                    noc_async_read_barrier();
-                    barrier_count = 0;
-                }
+            for (uint32_t row = 0; row < Sk_chunk_t; ++row) {
+                for (uint32_t col = 0; col < DHt; ++col) {
+                    if (row <= valid_seq_len_tiles) {
+                        noc_async_read_tile(v_tile_id, v_reader, v_write_ptr);
+                        if (++barrier_count == barrier_threshold) {
+                            noc_async_read_barrier();
+                            barrier_count = 0;
+                        }
+                    }
+                    v_tile_id++;
+                    v_write_ptr += v_tile_bytes;
+                }
             }
             noc_async_read_barrier();
@@ -277,8 +277,8 @@ void kernel_main() {
     }
     // Get cur_pos
     uint32_t cur_pos = 0;
-    // using 4294967295 (end of uint32 range) as a flag to indicate that cur_pos is not provided as a list
-    if (cur_pos_arg!=4294967295){
+    // using UINT32_MAX as a flag to indicate that cur_pos is not provided as a list
+    if (cur_pos_arg != UINT32_MAX){
        cur_pos = cur_pos_arg;
     }
     else {
@@ -288,6 +288,11 @@ void kernel_main() {
         volatile tt_l1_ptr uint32_t* index_ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(index_cb_ptr);
         cur_pos = index_ptr[cur_batch];
     }
+
+    if (cur_pos == UINT32_MAX) {
+        // cur_pos of -1 indicates that the user should be skipped
+        return;
+    }
     // Sequence length assignment
     auto [PSt, k_num_chunks, k_chunk_start, k_chunk_end] = get_runtime_args(cur_pos, cur_batch, core_num, num_cores_per_batch, k_chunk_size);
 
@@ -26,6 +26,7 @@ void py_bind_sdpa_decode(py::module &module) {
         "output: [1 x b x pnh x dh]"
         "Accepts a `SDPAMultiCoreProgramConfig` which specifies the grid size and chunk tiles in the K/V/Mask sequence lengths (Q chunk tiles is not used). The op parallelizes over `b` and K/V/Mask's `s` dimension."
+        "If a position is given as (-1), compute for the corresponding index in the batch is skipped."
         )doc";
 
     using OperationType = decltype(ttnn::transformer::scaled_dot_product_attention_decode);
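
As a host-side illustration of the docstring above, the sketch below builds a per-user position tensor with -1 entries for users to skip; the ttnn calls are shown as comments because the keyword name cur_pos_tensor is an assumption taken from the test parameters, not a confirmed signature.

    import torch

    b = 32
    positions = torch.tensor([100 if u % 2 == 0 else -1 for u in range(b)], dtype=torch.int32)
    # cur_pos_tt = ttnn.from_torch(positions, dtype=ttnn.int32, device=device)
    # out = ttnn.transformer.scaled_dot_product_attention_decode(
    #     tt_Q, tt_K, tt_V, cur_pos_tensor=cur_pos_tt, program_config=program_config
    # )
    # Odd users are skipped; their rows in the output must not be treated as valid data.
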
