Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add static data and model for hybrid link gnn #27

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
508 changes: 508 additions & 0 deletions examples/static_example.py

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions hybridgnn/nn/models/hybridgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static: Optional[bool] = False,
) -> None:
super().__init__(data, col_stats_dict, rhs_emb_mode, dst_entity_table,
num_nodes, embedding_dim)
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(
self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1)
self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1)
self.channels = channels
self.is_static = is_static

self.reset_parameters()

Expand Down Expand Up @@ -103,11 +105,12 @@ def forward(
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)
if self.is_static is not True:
XinweiHe marked this conversation as resolved.
Show resolved Hide resolved
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
Expand Down
11 changes: 7 additions & 4 deletions hybridgnn/nn/models/idgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static=True,
) -> None:
super().__init__()

Expand Down Expand Up @@ -65,6 +66,7 @@ def __init__(
)

self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.is_static = is_static
self.reset_parameters()

def reset_parameters(self) -> None:
Expand All @@ -85,11 +87,12 @@ def forward(
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)
if self.is_static is not True:
XinweiHe marked this conversation as resolved.
Show resolved Hide resolved
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
Expand Down
13 changes: 9 additions & 4 deletions hybridgnn/nn/models/shallowrhsgnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
norm: str = 'layer_norm',
torch_frame_model_cls: Type[torch.nn.Module] = ResNet,
torch_frame_model_kwargs: Optional[Dict[str, Any]] = None,
is_static: Optional[bool] = False,
) -> None:
super().__init__(data, col_stats_dict, rhs_emb_mode, dst_entity_table,
num_nodes, embedding_dim)
Expand Down Expand Up @@ -71,6 +72,8 @@ def __init__(
)
self.lhs_projector = torch.nn.Linear(channels, embedding_dim)
self.id_awareness_emb = torch.nn.Embedding(1, channels)
self.is_static = is_static

self.reset_parameters()

def reset_parameters(self) -> None:
Expand All @@ -94,11 +97,13 @@ def forward(
# Add ID-awareness to the root node
x_dict[entity_table][:seed_time.size(0
)] += self.id_awareness_emb.weight
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time
if self.is_static is not True:
XinweiHe marked this conversation as resolved.
Show resolved Hide resolved
rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
batch.batch_dict)

for node_type, rel_time in rel_time_dict.items():
x_dict[node_type] = x_dict[node_type] + rel_time

x_dict = self.gnn(
x_dict,
Expand Down
1 change: 1 addition & 0 deletions static_data/amazon-book/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Look for the full dataset? Please visit the [websit](http://jmcauley.ucsd.edu/data/amazon).
Loading
Loading