NATTEN failure: cutlass error: Error Internal at: 102 #163
Could you please share your use case? (Problem shape, operation, dtype, GPU)
I am trying to replace 3D unfused NA with 2D unfused NA by decomposing a 3D kernel of size (5,5,5) into five 2D kernels of size (5,5).

NATTEN operations: `na2d_qk`, `na2d_av`, `na3d_qk`, `na3d_av`

```python
import os
import time

import natten
import torch
from einops import rearrange
from natten.functional import na2d, na2d_qk, na2d_av, na3d, na3d_qk, na3d_av
from torch.nn import functional as F

natten.use_fused_na(True)
natten.use_kv_parallelism_in_fused_na(True)
natten.set_memory_usage_preference("unrestricted")

# Assumed CUDA device; `device` was used but never defined in the posted snippet.
device = torch.device("cuda")

method = '2D'
b = 1
num_heads, head_dim = 8, 32
t, h, w = 5, 256, 256
kernel_size = (5, 5, 5)

print(f"\nUsing {method} NA for 3D data [{t},{h},{w}] with kernel size of {kernel_size}\n")

for i in range(20):
    print(f"----------\ntrial {i}")
    query = torch.randn(b, num_heads, t, h, w, head_dim, requires_grad=True).to(device)
    key, value = torch.randn_like(query).to(device), torch.randn_like(query).to(device)
    torch.cuda.synchronize()
    start = time.time()
    if method == '2D':
        # Gather the kernel_size[0] temporal neighbors for every frame (with
        # t == kernel_size[0], each frame's neighborhood is all frames), then
        # fold the temporal axis into the batch axis so each slice becomes an
        # independent 2D NA problem.
        idx = torch.arange(0, kernel_size[0]).reshape(1, -1).repeat(t, 1)
        query = rearrange(query, 'b head t h w c -> (b t) head h w c')
        key = rearrange(key[:, :, idx], 'b head t unfold h w c -> unfold (b t) head h w c')
        value = rearrange(value[:, :, idx], 'b head t unfold h w c -> unfold (b t) head h w c')
        # One 2D QK per temporal offset, a joint softmax over all offsets,
        # then one 2D AV per offset, summed.
        attn0 = na2d_qk(query, key[0], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        attn1 = na2d_qk(query, key[1], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        attn2 = na2d_qk(query, key[2], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        attn3 = na2d_qk(query, key[3], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        attn4 = na2d_qk(query, key[4], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        attn = torch.cat((attn0, attn1, attn2, attn3, attn4), dim=-1)
        attn = F.softmax(attn, dim=-1)
        attn0, attn1, attn2, attn3, attn4 = torch.chunk(attn, 5, dim=-1)
        output0 = na2d_av(attn0, value[0], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        output1 = na2d_av(attn1, value[1], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        output2 = na2d_av(attn2, value[2], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        output3 = na2d_av(attn3, value[3], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        output4 = na2d_av(attn4, value[4], kernel_size=(kernel_size[1], kernel_size[2]), dilation=1)
        output = output0 + output1 + output2 + output3 + output4
    elif method == '3D':
        attn = na3d_qk(query, key, kernel_size=kernel_size, dilation=1)
        attn = F.softmax(attn, dim=-1)
        output = na3d_av(attn, value, kernel_size=kernel_size, dilation=1)
    end = time.time()
    print(f"forward: {end - start:.3f}")
    loss = output.sum()
    torch.cuda.synchronize()
    start = time.time()
    loss.backward()
    torch.cuda.synchronize()
    end = time.time()
    print(f"backward: {end - start:.3f}")
```

And it also reports this error:

```
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
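As the error text suggests, synchronous kernel launches usually make the reported stack trace point at the op that actually failed. A minimal sketch of how one might enable that (the variable has to be set before the first CUDA call initializes the context, so it belongs at the very top of the script):

```python
import os

# Force synchronous CUDA launches so the illegal-memory-access error is
# raised at the offending op rather than at a later API call. Set this
# before any code initializes the CUDA context.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```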
I think the issue is non-contiguous views. The tensor slicing/manipulation you're doing could be producing tensors with layouts that the underlying ops don't expect but will try to visit anyway. Can you try adding a `.contiguous()` call? Also, just FYI: the FNA flags at the top won't affect this program, since you're using the unfused ops directly.
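For instance, a minimal sketch of applying that suggestion to the repro above (assuming the offending views are the `torch.chunk` outputs fed to `na2d_av`):

```python
# torch.chunk returns narrowed views of `attn`; copying them into contiguous
# buffers before na2d_av gives each op the dense layout it expects.
attn0, attn1, attn2, attn3, attn4 = (
    chunk.contiguous() for chunk in torch.chunk(attn, 5, dim=-1)
)
```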
You are right! By adding `.contiguous()`, the error goes away.

By testing, it seems that `attn` is the tensor that ends up non-contiguous.
Yes, that's what I thought. The ops make `q`, `k`, and `v` contiguous by default, but not `attn`. Let me take a closer look and see if I can find the issue. Could you also try commenting out the backward pass and see if it runs without the error? To be clear, there's nothing wrong with doing this; it should work without the explicit `.contiguous()` calls.
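The non-contiguity is easy to confirm in isolation; a quick standalone check (shapes chosen here to mirror the repro at a smaller spatial extent, not taken from the thread):

```python
import torch

# Chunking a contiguous parent along the last dim yields views whose last-dim
# extent no longer matches the parent's strides, so none of the chunks are
# contiguous -- not even the first one.
attn = torch.randn(1, 8, 16, 16, 125)
chunks = torch.chunk(attn, 5, dim=-1)
print(attn.is_contiguous())                  # True
print([c.is_contiguous() for c in chunks])   # [False, False, False, False, False]
```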
By the way, is decomposing 3D NA into multiple 2D NAs feasible? Do you see any potential problems? The decomposition is motivated by
Without the backward pass, it runs without the error.
No, I think it makes sense to do 3D with 2D if possible, and because the software predication is simpler for 2D than for 3D, it might end up being faster. But one thing I'll note is that if your 2D kernel size is square, which seems to be the case, then 2D will target the GEMM-based kernel, while 3D will hit the naive one. However, I think the extra `.contiguous()` calls and explicit data movement like that can easily become a worse bottleneck than the software predication in 3D. I am actually curious to see whether with this implementation you can train faster than with just using FNA, even in FP32.
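If the two paths do get compared, device-side timing avoids the skew that wrapping asynchronous launches in `time.time()` can introduce. A minimal sketch (the `fn` callables wrapping the 2D and 3D branches of the repro are hypothetical, not from the thread):

```python
import torch

def benchmark_ms(fn, warmup=5, iters=20):
    """Return the mean milliseconds per call of a CUDA workload,
    measured with device-side events."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```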
@pwangcs is your issue resolved? |
Hi, Ali. During training, I get the following error:

With only the forward pass during inference (no training), there is no such error.
What is the error?