Skip to content

Commit

Permalink
Test GNNExplainer with TransformerConv (#9451)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jun 24, 2024
1 parent 7bde377 commit 9b7874b
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions test/explain/algorithm/test_gnn_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,17 @@
ModelReturnType,
ModelTaskLevel,
)
from torch_geometric.nn import AttentiveFP, ChebConv, GCNConv, global_add_pool
from torch_geometric.nn import (
AttentiveFP,
ChebConv,
GCNConv,
TransformerConv,
global_add_pool,
)


class GCN(torch.nn.Module):
def __init__(self, model_config: ModelConfig):
class GNN(torch.nn.Module):
def __init__(self, Conv, model_config: ModelConfig):
super().__init__()
self.model_config = model_config

Expand All @@ -23,8 +29,8 @@ def __init__(self, model_config: ModelConfig):
else:
out_channels = 1

self.conv1 = GCNConv(3, 16)
self.conv2 = GCNConv(16, out_channels)
self.conv1 = Conv(3, 16)
self.conv2 = Conv(16, out_channels)

# Add unused parameter:
self.param = torch.nn.Parameter(torch.empty(1))
Expand Down Expand Up @@ -71,6 +77,7 @@ def forward(self, x, edge_index, batch=None, edge_label_index=None):
edge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])


@pytest.mark.parametrize('Conv', [GCNConv, TransformerConv])
@pytest.mark.parametrize('edge_mask_type', edge_mask_types)
@pytest.mark.parametrize('node_mask_type', node_mask_types)
@pytest.mark.parametrize('explanation_type', explanation_types)
Expand All @@ -81,6 +88,7 @@ def forward(self, x, edge_index, batch=None, edge_label_index=None):
])
@pytest.mark.parametrize('index', indices)
def test_gnn_explainer_binary_classification(
Conv,
edge_mask_type,
node_mask_type,
explanation_type,
Expand All @@ -95,7 +103,7 @@ def test_gnn_explainer_binary_classification(
return_type=return_type,
)

model = GCN(model_config)
model = GNN(Conv, model_config)

target = None
if explanation_type == ExplanationType.phenomenon:
Expand Down Expand Up @@ -130,6 +138,7 @@ def test_gnn_explainer_binary_classification(
check_explanation(explanation, node_mask_type, edge_mask_type)


@pytest.mark.parametrize('Conv', [GCNConv])
@pytest.mark.parametrize('edge_mask_type', edge_mask_types)
@pytest.mark.parametrize('node_mask_type', node_mask_types)
@pytest.mark.parametrize('explanation_type', explanation_types)
Expand All @@ -141,6 +150,7 @@ def test_gnn_explainer_binary_classification(
])
@pytest.mark.parametrize('index', indices)
def test_gnn_explainer_multiclass_classification(
Conv,
edge_mask_type,
node_mask_type,
explanation_type,
Expand All @@ -155,7 +165,7 @@ def test_gnn_explainer_multiclass_classification(
return_type=return_type,
)

model = GCN(model_config)
model = GNN(Conv, model_config)

target = None
if explanation_type == ExplanationType.phenomenon:
Expand Down Expand Up @@ -186,12 +196,14 @@ def test_gnn_explainer_multiclass_classification(
check_explanation(explanation, node_mask_type, edge_mask_type)


@pytest.mark.parametrize('Conv', [GCNConv])
@pytest.mark.parametrize('edge_mask_type', edge_mask_types)
@pytest.mark.parametrize('node_mask_type', node_mask_types)
@pytest.mark.parametrize('explanation_type', explanation_types)
@pytest.mark.parametrize('task_level', task_levels)
@pytest.mark.parametrize('index', indices)
def test_gnn_explainer_regression(
Conv,
edge_mask_type,
node_mask_type,
explanation_type,
Expand All @@ -204,7 +216,7 @@ def test_gnn_explainer_regression(
task_level=task_level,
)

model = GCN(model_config)
model = GNN(Conv, model_config)

target = None
if explanation_type == ExplanationType.phenomenon:
Expand Down

0 comments on commit 9b7874b

Please sign in to comment.