diff --git a/relbench/modeling/nn.py b/relbench/modeling/nn.py index 40d26987..0d4c9f2f 100644 --- a/relbench/modeling/nn.py +++ b/relbench/modeling/nn.py @@ -5,7 +5,7 @@ from torch import Tensor from torch_frame.data.stats import StatType from torch_frame.nn.models import ResNet -from torch_geometric.nn import HeteroConv, LayerNorm, PositionalEncoding, SAGEConv +from torch_geometric.nn import HeteroConv, LayerNorm, PositionalEncoding, SAGEConv, GATConv from torch_geometric.typing import EdgeType, NodeType @@ -171,3 +171,60 @@ def forward( x_dict = {key: x.relu() for key, x in x_dict.items()} return x_dict + +class HeteroGAT(torch.nn.Module): + """ + Implementation of heterogeneous GAT. + """ + + def __init__( + self, + node_types: List[NodeType], + edge_types: List[EdgeType], + channels: int, + aggr: str = "mean", + num_layers: int = 2, + ): + super().__init__() + + self.convs = torch.nn.ModuleList() + for _ in range(num_layers): + conv = HeteroConv( + { + edge_type: GATConv( + (channels, channels), channels, heads=1, add_self_loops=False + ) + for edge_type in edge_types + }, + aggr="sum", + ) + self.convs.append(conv) + + self.norms = torch.nn.ModuleList() + for _ in range(num_layers): + norm_dict = torch.nn.ModuleDict() + for node_type in node_types: + norm_dict[node_type] = LayerNorm(channels, mode="node") + self.norms.append(norm_dict) + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for norm_dict in self.norms: + for norm in norm_dict.values(): + norm.reset_parameters() + + def forward( + self, + x_dict: Dict[NodeType, Tensor], + edge_index_dict: Dict[NodeType, Tensor], + num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None, + num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None, + ) -> Dict[NodeType, Tensor]: + for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)): + x_dict = conv(x_dict, edge_index_dict) + x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()} + x_dict = {key: x.relu() for key, x in x_dict.items()} + + return x_dict +