spspmm leads to error: PyTorch CUDA error: an illegal memory access was encountered. #314
What version of |
Thank you for your reply. |
You mean without upgrading CUDA? You should be able to install from wheels via |
I mean that if I want to install torch-sparse 0.6.16, I need the following dependency chain: torch 1.13 -> CUDA 11.6. |
The CUDA version needs to match the one PyTorch was installed with, not necessarily your system CUDA. |
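A minimal sketch of picking the matching wheel index, assuming the naming scheme `https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html` from the torch-sparse README, where the CUDA tag is the one PyTorch was built with (the `+cu116` suffix of `torch.__version__`), not the system toolkit. The helper `pyg_wheel_index` is hypothetical, and the "patch version set to 0" rule is an assumption about how the index is organized.

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Build a PyG wheel-index URL from a torch version string like '1.13.1+cu116'.

    Assumptions: wheels are grouped per minor release with the patch set to 0
    (e.g. torch-1.13.0+cu116.html), and a build without a local suffix is CPU-only.
    """
    base, _, local = torch_version.partition("+")
    major, minor = base.split(".")[:2]
    # The CUDA tag comes from the PyTorch build itself, not from the system CUDA.
    cuda_tag = local if local else "cpu"
    return f"https://data.pyg.org/whl/torch-{major}.{minor}.0+{cuda_tag}.html"


# In practice the argument would be torch.__version__.
print(pyg_wheel_index("1.13.1+cu116"))
```

The resulting URL would then be passed as `pip install torch-sparse -f <url>`, so the wheel is built against the same CUDA version as the installed PyTorch.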
I tried upgrading to torch-sparse 0.6.16; however, I got a new error when running the previous code. Is there any solution? RuntimeError (note: full exception trace is shown but execution is paused at: _run_module_as_main) |
What does torch_sparse.__version__ return? |
Hi, do you have any idea about this problem? |
Mh, can you show me the content of |
Mh, looks like this is an issue with PyTorch then, not with
import torch
A = torch.randn(5, 5).to_torch_coo_tensor().cuda()
torch.sparse.mm(A, A)
Does this also fail for you? |
Yes, running the above code reports the following error.
Needs to be |
The above code runs successfully :(
Then I am at a loss :( What happens if you run adj_l @ adj_r in your code above? |
Say I have five different pairs of adj_l and adj_r. The first four run without any problem, but the fifth one reports the error mentioned at the beginning of my question. After I upgraded torch and torch_sparse, I ran it again and got the second error:
Do you have a reproducible example? Happy to look into it. |
I have uploaded adj_l and adj_r to Google Cloud. You can download these two files and run:
Thanks. Will look into it. |
Hi, sorry to bother you.
I can reproduce this :( I assume that your matrices are too large for
adj_l = adj_l.to_torch_sparse_csr_tensor()
adj_r = adj_r.to_torch_sparse_csr_tensor()
out = adj_l @ adj_r
also fails, while something like
adj_l = adj_l[:10000]
adj_r = adj_r[:, :10000]
out = adj_l @ adj_r
works. I suggest creating a similar issue in https://github.com/pytorch/pytorch. |
This dataset is actually the paper-field (PF) and field-paper (FP) relations from ogbn-mag. I noticed that your work also appears on the ogbn-mag leaderboard, so maybe I'll study your previous work to see how to use torch_sparse to support the MAG dataset. |
We never used sparse-sparse matrix multiplication in our benchmarks, so this issue never came up for us. |
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved? |
Hi, sorry to bother you.
We are using the PyTorch routine now for SpSpMM, so this is either no longer an issue or needs to be routed to the PyTorch team directly. |
Hi, I'm having the same problem as #174.
I have two large adjacency matrices; the details are as follows:
adj_l
SparseTensor(row=tensor([ 0, 0, 0, ..., 736388, 736388, 736388], device='cuda:2'),
col=tensor([ 145, 2215, 3205, ..., 21458, 22283, 31934], device='cuda:2'),
val=tensor([0.0909, 0.0909, 0.0909, ..., 0.1000, 0.1000, 0.1000], device='cuda:2'),
size=(736389, 59965), nnz=7505078, density=0.02%)
adj_r
SparseTensor(row=tensor([ 0, 0, 0, ..., 59962, 59963, 59964], device='cuda:2'),
col=tensor([222683, 370067, 430465, ..., 38176, 514545, 334613], device='cuda:2'),
val=tensor([0.1429, 0.1429, 0.1429, ..., 0.5000, 1.0000, 1.0000], device='cuda:2'),
size=(59965, 736389), nnz=7505078, density=0.02%)
I convert them to index/value (COO) format and use the following code:
import torch
from torch_sparse import spspmm
# Extract COO indices and values from both SparseTensors
rowA, colA, valueA = adj_l.coo()
rowB, colB, valueB = adj_r.coo()
indexA = torch.stack((rowA, colA))
indexB = torch.stack((rowB, colB))
# C = A @ B with shapes (m x k) @ (k x n); coo() output is already coalesced
indexC, valueC = spspmm(indexA, valueA, indexB, valueB, adj_l.size(0), adj_l.size(1), adj_r.size(1), coalesced=True)
Then the following error is reported:
CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Even with CUDA_LAUNCH_BLOCKING=1 there is no further information. I believe this is caused by the two sparse matrices requiring too much memory. Is there any way to run it on the GPU?
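A common workaround, not confirmed in this thread, is to split the product over row blocks of adj_l, since C[start:end] = adj_l[start:end] @ adj_r; the maintainer's observation here that sliced inputs (adj_l[:10000], adj_r[:, :10000]) multiply fine is consistent with size being the problem. With torch_sparse this would presumably be a loop over `adj_l[start:end] @ adj_r`, moving each partial result off the GPU. The sketch below shows only the blocking logic in pure Python on COO lists; the function name `spspmm_row_blocks` and its parameters are hypothetical.

```python
from collections import defaultdict


def spspmm_row_blocks(rowA, colA, valA, rowB, colB, valB, m, block_rows):
    """Compute C = A @ B block by block over the rows of A (COO in, COO out).

    Each block only touches rows [start, start + block_rows) of A, so the
    working set is bounded by one block's output instead of all of C.
    Returns (rowC, colC, valC) sorted row-major.
    """
    # Index B by row: k -> list of (j, B[k, j]).
    b_rows = defaultdict(list)
    for k, j, v in zip(rowB, colB, valB):
        b_rows[k].append((j, v))

    # Group A's non-zeros by row so each block can be gathered directly.
    a_rows = defaultdict(list)
    for i, k, v in zip(rowA, colA, valA):
        a_rows[i].append((k, v))

    rowC, colC, valC = [], [], []
    for start in range(0, m, block_rows):
        block = {}  # (i, j) -> accumulated value for this block only
        for i in range(start, min(start + block_rows, m)):
            for k, a in a_rows.get(i, ()):
                for j, b in b_rows.get(k, ()):
                    block[(i, j)] = block.get((i, j), 0.0) + a * b
        # On a GPU, this is where a finished block would be moved off the device.
        for i, j in sorted(block):
            rowC.append(i)
            colC.append(j)
            valC.append(block[(i, j)])
    return rowC, colC, valC


rowA, colA, valA = [0, 1, 2, 2], [0, 0, 0, 1], [1.0, 2.0, 3.0, 4.0]
rowB, colB, valB = [0, 0, 1], [0, 2, 1], [1.0, 1.0, 0.5]
print(spspmm_row_blocks(rowA, colA, valA, rowB, colB, valB, m=3, block_rows=2))
# prints ([0, 0, 1, 1, 2, 2, 2], [0, 2, 0, 2, 0, 1, 2], [1.0, 1.0, 2.0, 2.0, 3.0, 2.0, 3.0])
```

The result does not depend on `block_rows`; smaller blocks only lower peak memory. If the full result is itself too large to hold, the blocks would have to be consumed or written out one at a time rather than concatenated.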