Fix TorchScript support in case torch-sparse is not installed (#8738)
rusty1s authored Jan 8, 2024
Parent: 5230418 · Commit: 32e3872
Showing 2 changed files with 19 additions and 9 deletions.
test/nn/conv/test_fa_conv.py — 9 additions, 9 deletions

@@ -24,9 +24,9 @@ def test_fa_conv():
 
     if torch_geometric.typing.WITH_TORCH_SPARSE:
         adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
-        assert torch.allclose(conv(x, x_0, adj2.t()), out)
+        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)
         assert conv._cached_adj_t is not None
-        assert torch.allclose(conv(x, x_0, adj2.t()), out)
+        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)
 
     if is_full_test():
 
@@ -44,7 +44,7 @@ def forward(
                 return self.conv(x, x_0, edge_index)
 
         jit = torch.jit.script(MyModule())
-        assert torch.allclose(jit(x, x_0, edge_index), out)
+        assert torch.allclose(jit(x, x_0, edge_index), out, atol=1e-6)
 
         if torch_geometric.typing.WITH_TORCH_SPARSE:
             assert torch.allclose(jit(x, x_0, adj2.t()), out)
@@ -56,13 +56,13 @@ def forward(
     # Test without caching:
     conv.cached = False
     out = conv(x, x_0, edge_index)
-    assert torch.allclose(conv(x, x_0, adj1.t()), out)
+    assert torch.allclose(conv(x, x_0, adj1.t()), out, atol=1e-6)
     if torch_geometric.typing.WITH_TORCH_SPARSE:
-        assert torch.allclose(conv(x, x_0, adj2.t()), out)
+        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)
 
     # Test `return_attention_weights`:
     result = conv(x, x_0, edge_index, return_attention_weights=True)
-    assert torch.allclose(result[0], out)
+    assert torch.allclose(result[0], out, atol=1e-6)
     assert result[1][0].size() == (2, 10)
     assert result[1][1].size() == (10, )
     assert conv._alpha is None
@@ -75,7 +75,7 @@ def forward(
 
     if torch_geometric.typing.WITH_TORCH_SPARSE:
         result = conv(x, x_0, adj2.t(), return_attention_weights=True)
-        assert torch.allclose(result[0], out)
+        assert torch.allclose(result[0], out, atol=1e-6)
         assert result[1].sizes() == [4, 4] and result[1].nnz() == 10
         assert conv._alpha is None
 
@@ -97,7 +97,7 @@ def forward(
 
         jit = torch.jit.script(MyModule())
         result = jit(x, x_0, edge_index)
-        assert torch.allclose(result[0], out)
+        assert torch.allclose(result[0], out, atol=1e-6)
         assert result[1][0].size() == (2, 10)
         assert result[1][1].size() == (10, )
         assert conv._alpha is None
@@ -120,6 +120,6 @@ def forward(
 
         jit = torch.jit.script(MyModule())
         result = jit(x, x_0, adj2.t())
-        assert torch.allclose(result[0], out)
+        assert torch.allclose(result[0], out, atol=1e-6)
         assert result[1].sizes() == [4, 4] and result[1].nnz() == 10
         assert conv._alpha is None
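A note on the test changes above: each exact torch.allclose(...) comparison was relaxed to atol=1e-6, presumably because the dense edge_index path and the SparseTensor path can accumulate slightly different floating-point rounding. Near zero, the relative-tolerance term of allclose vanishes, so the default atol=1e-8 decides the outcome. A minimal, self-contained illustration (the tensor values here are hypothetical, not from the test):

    import torch

    # Two results that agree to within 5e-7, e.g. the dense vs. sparse
    # code path of the same layer. Near zero, rtol * |b| is negligible,
    # so the comparison hinges on atol.
    a = torch.tensor([0.0, 1.0])
    b = torch.tensor([5e-7, 1.0])

    assert not torch.allclose(a, b)          # default atol=1e-8 rejects
    assert torch.allclose(a, b, atol=1e-6)   # relaxed tolerance accepts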
torch_geometric/typing.py — 10 additions, 0 deletions

@@ -126,6 +126,12 @@ def __init__(
         ):
             raise ImportError("'SparseStorage' requires 'torch-sparse'")
 
+        def value(self) -> Optional[Tensor]:
+            raise ImportError("'SparseStorage' requires 'torch-sparse'")
+
+        def rowcount(self) -> Tensor:
+            raise ImportError("'SparseStorage' requires 'torch-sparse'")
+
     class SparseTensor:  # type: ignore
         def __init__(
             self,
@@ -150,6 +156,10 @@ def from_edge_index(
         ) -> 'SparseTensor':
             raise ImportError("'SparseTensor' requires 'torch-sparse'")
 
+        @property
+        def storage(self) -> SparseStorage:
+            raise ImportError("'SparseTensor' requires 'torch-sparse'")
+
         @classmethod
         def from_dense(self, mat: Tensor,
                        has_value: bool = True) -> 'SparseTensor':
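For context on why these stubs matter (my reading of the change, not part of the commit itself): layers such as FAConv reference SparseTensor/SparseStorage members like storage, value(), and rowcount() inside their forward, and torch.jit.script type-checks every branch it can see, including branches that never run when torch-sparse is absent. The fallback classes therefore need these members with matching signatures so that scripting succeeds, and only an actual call raises ImportError. A minimal sketch of the pattern, using a simplified stand-in module (Demo is illustrative, not the real FAConv):

    from typing import Optional

    import torch
    from torch import Tensor

    class SparseStorage:  # fallback stub, mirroring the diff above
        def value(self) -> Optional[Tensor]:
            raise ImportError("'SparseStorage' requires 'torch-sparse'")

        def rowcount(self) -> Tensor:
            raise ImportError("'SparseStorage' requires 'torch-sparse'")

    class Demo(torch.nn.Module):
        # TorchScript compiles this method even if it is never called with
        # a SparseStorage, so value()/rowcount() must exist on the stub.
        def forward(self, storage: SparseStorage) -> Tensor:
            return storage.rowcount()

    # Scripting now succeeds without torch-sparse installed; invoking the
    # scripted module with a stub instance raises ImportError at run time.
    jit = torch.jit.script(Demo())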
