Test EdgeIndex.matmul backward (#8481)
rusty1s authored Nov 29, 2023
1 parent 1d11179 commit b6d16f1
Showing 2 changed files with 43 additions and 10 deletions.
32 changes: 31 additions & 1 deletion test/data/test_edge_index.py
@@ -6,7 +6,7 @@
 from torch import Tensor
 
 import torch_geometric
-from torch_geometric.data.edge_index import EdgeIndex, to_dense
+from torch_geometric.data.edge_index import EdgeIndex, matmul, to_dense
 from torch_geometric.testing import (
     disableExtensions,
     onlyCUDA,
@@ -323,6 +323,36 @@ def test_matmul():
     assert torch.allclose(out.to_dense(), adj2.to_dense() @ adj2.to_dense())
 
 
+def test_matmul_input_value():
+    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
+
+    x = torch.randn(3, 1)
+    value = torch.randn(4)
+
+    out = matmul(adj, x, input_value=value)
+    assert torch.allclose(out, to_dense(adj, value=value) @ x)
+
+
+def test_matmul_grad():
+    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
+
+    x1 = torch.randn(3, 1, requires_grad=True)
+    value = torch.randn(4, requires_grad=True)
+
+    out = matmul(adj, x1, input_value=value)
+    grad_out = torch.randn_like(out)
+    out.backward(grad_out)
+
+    x2 = x1.detach().requires_grad_()
+    dense_adj = to_dense(adj, value=value).detach().requires_grad_()
+    out = dense_adj @ x2
+    out.backward(grad_out)
+
+    assert torch.allclose(x1.grad, x2.grad)
+    if torch_geometric.typing.WITH_PT21:  # TODO Investigate.
+        assert torch.allclose(value.grad, dense_adj.grad[adj[0], adj[1]])
+
+
 def test_save_and_load(tmp_path):
     adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
     adj.fill_cache()
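Aside: the value-gradient assertion in test_matmul_grad leans on a simple identity: for out = A @ x with A dense, the gradient with respect to a per-edge value is just the dense adjacency gradient read out at that edge's (row, col) position. A minimal sketch of that identity in plain PyTorch (the row/col fixture mirrors the test above; illustrative only, no torch_geometric required):

    import torch

    # Edge list matching the test fixture above.
    row = torch.tensor([0, 1, 1, 2])
    col = torch.tensor([1, 0, 2, 1])
    value = torch.randn(4)

    # Build the dense adjacency from per-edge values and track its gradient.
    dense_adj = torch.zeros(3, 3)
    dense_adj[row, col] = value
    dense_adj = dense_adj.detach().requires_grad_()

    x = torch.randn(3, 1)
    (dense_adj @ x).sum().backward()

    # The per-edge gradient is the dense gradient sampled at the edge
    # positions, which is exactly what `dense_adj.grad[adj[0], adj[1]]`
    # extracts in the test.
    edge_grad = dense_adj.grad[row, col]
    print(edge_grad.shape)  # torch.Size([4])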
21 changes: 12 additions & 9 deletions torch_geometric/data/edge_index.py
@@ -629,7 +629,7 @@ def to_sparse_csr(tensor: EdgeIndex, value: Optional[Tensor] = None) -> Tensor:
     return torch.sparse_csr_tensor(
         crow_indices=tensor.get_rowptr(),
         col_indices=tensor[1],
-        values=_get_value(tensor.size(1), device=tensor.device),
+        values=value,
         size=tensor.get_sparse_size(),
         device=tensor.device,
     )
@@ -695,8 +695,8 @@ def to_sparse(tensor: EdgeIndex, value: Optional[Tensor] = None) -> Tensor:
 def matmul(
     input: EdgeIndex,
     other: Union[Tensor, EdgeIndex],
-    input_weight: Optional[Tensor] = None,
-    other_weight: Optional[Tensor] = None,
+    input_value: Optional[Tensor] = None,
+    other_value: Optional[Tensor] = None,
     reduce: Literal['sum'] = 'sum',
 ) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:
 
@@ -705,15 +705,18 @@ def matmul(
     # TODO Utilize available `CSC` representation for faster backward passes.
 
     if input._sort_order == SortOrder.COL:
-        input = to_sparse_csc(input, other_weight)
+        input = to_sparse_csc(input, input_value)
     else:
-        input = to_sparse_csr(input, other_weight)
+        input = to_sparse_csr(input, input_value)
 
     if isinstance(other, EdgeIndex):
         if other._sort_order == SortOrder.COL:
-            other = to_sparse_csc(other, other_weight)
+            other = to_sparse_csc(other, other_value)
         else:
-            other = to_sparse_csr(other, other_weight)
+            other = to_sparse_csr(other, other_value)
+
+    elif other_value is not None:
+        raise ValueError("'other_value' not supported for sparse-dense "
+                         "matrix multiplication")
 
-    return Tensor.__torch_function__(  #
-        Tensor.matmul, (Tensor, ), (input, other))
+    return torch.matmul(input, other)
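For context, the differentiable path this change exercises reduces to PyTorch's native sparse CSR support: a torch.sparse_csr_tensor built from values that require grad, multiplied against a dense matrix via torch.matmul. A minimal sketch without torch_geometric, assuming a recent PyTorch (the value gradient is the part the test guards behind WITH_PT21, so it may be unsupported on older versions):

    import torch

    # CSR form of the same 3x3 adjacency used in the tests:
    # rowptr [0, 1, 3, 4] encodes rows [0, 1, 1, 2]; cols stay [1, 0, 2, 1].
    crow_indices = torch.tensor([0, 1, 3, 4])
    col_indices = torch.tensor([1, 0, 2, 1])
    value = torch.randn(4, requires_grad=True)

    adj = torch.sparse_csr_tensor(crow_indices, col_indices, value,
                                  size=(3, 3))

    x = torch.randn(3, 1, requires_grad=True)
    out = torch.matmul(adj, x)  # sparse-dense matmul, differentiable

    out.sum().backward()
    print(value.grad.shape, x.grad.shape)  # torch.Size([4]) torch.Size([3, 1])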
