Skip to content

Commit

Permalink
Add a singular GATConv layer
Browse files Browse the repository at this point in the history
  • Loading branch information
cpondoc committed Dec 12, 2024
1 parent 6bcb12a commit 2c1cd88
Showing 1 changed file with 58 additions and 1 deletion.
59 changes: 58 additions & 1 deletion relbench/modeling/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.nn import HeteroConv, LayerNorm, PositionalEncoding, SAGEConv
from torch_geometric.nn import HeteroConv, LayerNorm, PositionalEncoding, SAGEConv, GATConv
from torch_geometric.typing import EdgeType, NodeType


Expand Down Expand Up @@ -171,3 +171,60 @@ def forward(
x_dict = {key: x.relu() for key, x in x_dict.items()}

return x_dict

class HeteroGAT(torch.nn.Module):
"""
Implementation of heterogeneous GAT.
"""

def __init__(
self,
node_types: List[NodeType],
edge_types: List[EdgeType],
channels: int,
aggr: str = "mean",
num_layers: int = 2,
):
super().__init__()

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
conv = HeteroConv(
{
edge_type: GATConv(
(channels, channels), channels, heads=1, add_self_loops=False
)
for edge_type in edge_types
},
aggr="sum",
)
self.convs.append(conv)

self.norms = torch.nn.ModuleList()
for _ in range(num_layers):
norm_dict = torch.nn.ModuleDict()
for node_type in node_types:
norm_dict[node_type] = LayerNorm(channels, mode="node")
self.norms.append(norm_dict)

def reset_parameters(self):
for conv in self.convs:
conv.reset_parameters()
for norm_dict in self.norms:
for norm in norm_dict.values():
norm.reset_parameters()

def forward(
self,
x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[NodeType, Tensor],
num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
x_dict = conv(x_dict, edge_index_dict)
x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
x_dict = {key: x.relu() for key, x in x_dict.items()}

return x_dict

0 comments on commit 2c1cd88

Please sign in to comment.