Skip to content

Commit

Permalink
[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (v…
Browse files Browse the repository at this point in the history
…llm-project#9838)

Signed-off-by: mzusman <[email protected]>
  • Loading branch information
mzusman authored Oct 31, 2024
1 parent 55650c8 commit 9fb12f7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
34 changes: 32 additions & 2 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
}
out += kChunkSize;

int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
// in case the final state is separated between the last "smem_exchange" and
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
// (which occurs when `final_state_position` is a non-positivie index)
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
if (final_state_position < 0 && seqlen > kWidth){
input_t vals_load[kNElts] = {0};
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
// chunk = n_chunks - 2, a segment of the final state sits in the last index
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
#pragma unroll
for (int w = 0; w < -final_state_position; ++w){
conv_states[w] = vals_load[kNElts + final_state_position + w];
}
}
if ((chunk == n_chunks - 1) && tidx == 0){
// chunk = n_chunks - 1, the second segment of the final state first positions
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
for (int w = -final_state_position; w < kWidth - 1; ++w){
conv_states[w] = vals_load[w + final_state_position];
}
return;
}
}
}
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
Expand Down Expand Up @@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
}
else {
// in case the final state is in between the threads data
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
// illegal access error on H100.
// Therefore, we access last_thread + 1, only if the final state data sits there
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
}
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
conv_states[w] = x_vals_load[offset + w ];
Expand Down
7 changes: 5 additions & 2 deletions tests/kernels/test_causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize(
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
@pytest.mark.parametrize('dim', [64])
@pytest.mark.parametrize('batch', [1])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
Expand Down Expand Up @@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,

unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)

causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
padded_state_indices, has_initial_states,
Expand Down
6 changes: 3 additions & 3 deletions tests/kernels/test_mamba_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 7e-2, 7e-2
rtol, atol = 1e-1, 1e-1
if torch.version.hip:
atol *= 2
# set seed
Expand Down Expand Up @@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
dt_bias=dt_bias,
dt_softplus=True)

print("Output diff max", (out - out_ref[0]).max())
print("Output diff mean", (out - out_ref[0]).mean())
print("Output diff max", (out[:batch_size] - out_ref).max())
print("Output diff mean", (out[:batch_size] - out_ref).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean",
(state[state_indices, :] - state_ref).mean())
Expand Down

0 comments on commit 9fb12f7

Please sign in to comment.