Problem when passing a SparseTensor to PyG GCNConv #382

Open
yiming421 opened this issue Sep 8, 2024 · 0 comments

@yiming421

The problem occurs when I pass a SparseTensor to PyG's GCNConv. I'm using Python 3.10, CUDA 12.1, torch 2.2.0, PyG 2.5.2, and torch_sparse 0.6.18, all installed via conda on an Ubuntu server, and the forward call fails with the RuntimeError shown below. No matter how I construct the SparseTensor object, the problem persists. I'm wondering whether this is a version-compatibility issue or something wrong with my environment setup (which is very simple: I only installed torch, PyG, and torch_sparse). Has anyone run into a similar problem, or does anyone have an idea why this happens?
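Since the open question is whether this is a version-compatibility problem, it can help to report the versions Python actually imports at runtime, which may differ from what conda lists. A minimal sketch follows; `describe_environment` is a hypothetical helper (not part of torch, PyG, or torch_sparse) that only reads each package's `__version__` plus `torch.version.cuda`:

```python
import importlib
from typing import Dict


def describe_environment() -> Dict[str, str]:
    """Hypothetical helper: report the package versions that actually import."""
    info: Dict[str, str] = {}
    for name in ("torch", "torch_geometric", "torch_sparse"):
        try:
            module = importlib.import_module(name)
        except Exception as exc:
            # A compiled extension built against a different torch/CUDA
            # build may fail to load (e.g. ImportError, OSError, RuntimeError);
            # record the failure instead of stopping.
            info[name] = f"unavailable ({type(exc).__name__})"
            continue
        info[name] = str(getattr(module, "__version__", "unknown"))
        if name == "torch":
            # CUDA version this torch build was compiled with (None on CPU-only builds)
            info["torch CUDA build"] = str(module.version.cuda)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

Catching every exception is deliberate here: a mismatched compiled extension can fail at import time in more than one way, and the goal is to capture that failure in the report rather than crash.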
I think you can reproduce the issue by running the following code:
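```python
import torch
from torch_geometric.nn import GCNConv
from torch_sparse import SparseTensor


def test():
    # Three (source, target) pairs (2, 1), (3, 2), (4, 3) in COO format, on GPU 0
    ei = torch.tensor([[2, 3, 4], [1, 2, 3]]).cuda(0)
    sp = SparseTensor.from_edge_index(ei, sparse_sizes=(5, 5))
    model = GCNConv(2, 2).cuda(0)
    x = torch.tensor([[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]).float().cuda(0)
    print(x, sp)
    model(x, sp)  # raises the RuntimeError below (gcn_norm -> fill_diag)
    print('success')


test()
```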
Here is the error message:
```
Traceback (most recent call last):
  File "/.../debug.py", line 15, in <module>
    test()
  File "/.../debug.py", line 13, in test
    model(x, sp)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 252, in forward
    edge_index = gcn_norm( # yapf: disable
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 64, in gcn_norm
    adj_t = torch_sparse.fill_diag(adj_t, fill_value)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 92, in fill_diag
    return set_diag(src, value.new_full(sizes, fill_value), k)
  File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 49, in set_diag
    new_row[mask] = row
RuntimeError: shape mismatch: value tensor of shape [3] cannot be broadcast to indexing result of shape [0]
```
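Reading the last two frames: `fill_diag` calls `set_diag`, which fails at `new_row[mask] = row`. Assuming `mask` marks which slots of the merged list (existing off-diagonal entries plus one new diagonal entry per row, in row-major order) hold an original entry, it should select exactly as many slots as `row` has elements, here 3; the error says it selected none. The pure-Python sketch below computes what that mask would look like for the 5 x 5 graph in the reproduction. The function `reference_non_diag_mask` is hypothetical, handles only the main diagonal, and is not torch_sparse's implementation (in torch_sparse 0.6.x the mask appears to come from a compiled `non_diag_mask` operator).

```python
from typing import List, Tuple


def reference_non_diag_mask(row: List[int], col: List[int],
                            num_rows: int, num_cols: int) -> List[bool]:
    """Hypothetical reference for the main diagonal only (k = 0).

    Assumes (row, col) holds no diagonal entries. Merges those entries with
    one diagonal entry (i, i) per row, sorts everything row-major, and returns
    True for each slot holding an original entry, False for each diagonal slot.
    """
    merged: List[Tuple[int, int, bool]] = [(r, c, True) for r, c in zip(row, col)]
    merged += [(i, i, False) for i in range(min(num_rows, num_cols))]
    merged.sort(key=lambda entry: (entry[0], entry[1]))
    return [is_original for _, _, is_original in merged]


# Same graph as the reproduction: row = ei[0], col = ei[1], size 5 x 5
row, col = [2, 3, 4], [1, 2, 3]
mask = reference_non_diag_mask(row, col, 5, 5)
print(mask)  # [False, False, True, False, True, False, True, False]

# `new_row[mask] = row` needs exactly len(row) True entries in the mask
assert sum(mask) == len(row)
```

If the mask computed on the GPU in this environment selects zero slots instead of three, the fault would lie in the compiled extension rather than in how the `SparseTensor` is constructed, which would be consistent with a build or version mismatch. This is an inference from the traceback, not a confirmed diagnosis.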
