Add static data and model for hybrid link gnn #27
base: master
Conversation
@rusty1s Shall we assign pseudo timestamps, or should we just remove the time encoder in the model for the static tasks?
I'd rather not check in those large files. Can you upload them to kumo-public-datasets?
Do we ever plan to make the hybridgnn repo public? If so, kumo-public-datasets makes less sense?
Can we upload it to pyg? @rusty1s
examples/static_example.py (outdated):

```python
model = HybridGNN(
    data=data,
    col_stats_dict=col_stats_dict,
    rhs_emb_mode=RHSEmbeddingMode.FUSION,
```
Same here. Can you let the user specify the rhs_embedding_mode?
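For illustration, a minimal sketch of what that suggestion could look like; the mode names in `choices` are assumptions, not verified members of `RHSEmbeddingMode`:

```python
import argparse

# Hypothetical sketch: expose rhs_emb_mode on the CLI instead of
# hard-coding RHSEmbeddingMode.FUSION. The choices below are assumed
# mode names, not confirmed enum members.
parser = argparse.ArgumentParser()
parser.add_argument("--rhs_emb_mode", type=str, default="fusion",
                    choices=["fusion", "lookup", "feature"])
args = parser.parse_args()
# Then, e.g.: rhs_emb_mode=RHSEmbeddingMode[args.rhs_emb_mode.upper()]
```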
Let's keep it like this, since rhs_emb_mode is not customized in relbench_example.py either.
f"node in any mini-batch. Try to increase the number " | ||
f"of layers/hops and re-try. If you run into memory " | ||
f"issues with deeper nets, decrease the batch size.") | ||
return loss_accum / count_accum if count_accum > 0 else float("nan") |
Can we raise an error here instead?
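A minimal sketch of that suggestion; the error type and message wording are placeholders:

```python
# Hypothetical sketch: fail loudly instead of returning NaN when no
# destination node was reached in any mini-batch.
if count_accum == 0:
    raise RuntimeError(
        "No destination node was sampled in any mini-batch; try "
        "increasing the number of layers/hops.")
return loss_accum / count_accum
```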
Similarly, this is following relbench_example.py. In static_example.py, I think we can focus on making sure different models work e2e. When we need some benchmarking results, we could create another static_link_prediction_benchmark.py.
Let's host it somewhere else.
Could you elaborate a bit? I see that https://github.com/kuandeng/LightGCN/tree/master/Data saves data in the repo, which seems to be the most straightforward and convenient approach.
```python
def get_static_stype_proposal(
        table_dict: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, stype]]:
    r"""Infer the stype for table columns."""
    inferred_col_to_stype_dict = {}
    for table_name, df in table_dict.items():
        df = df.sample(min(1_000, len(df)))
        inferred_col_to_stype = infer_df_stype(df)
        inferred_col_to_stype_dict[table_name] = inferred_col_to_stype
    return inferred_col_to_stype_dict
```
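For context, a minimal usage sketch of this helper; the toy table is invented, and infer_df_stype is assumed to come from torch_frame as elsewhere in the file:

```python
import pandas as pd

# Hypothetical toy input: one table with a numerical and a text column.
table_dict = {
    "item": pd.DataFrame({
        "price": [1.0, 2.5, 3.0],
        "title": ["a", "b", "c"],
    }),
}
stypes = get_static_stype_proposal(table_dict)
# e.g. {'item': {'price': stype.numerical, 'title': ...}}
```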
I thought we don't have any features?
examples/static_example.py (outdated):

```python
# https://github.com/kumo-ai/hybridgnn/blob/master/examples/
# relbench_example.py#L86. This is to make sure each table contains
# at least one feature.
df = pd.DataFrame({"__const__": np.ones(len(table)), **fkey_dict})
```
I would assume we want lookup embeddings for both LHS and RHS.
The implementation here follows https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py#L86. Based on the implementation of make_pkey_fkey_graph in relbench, the primary key and foreign keys are completely removed from the src table, dst table, and transaction table. I assumed lookup embeddings for LHS and RHS must be supported in some other way in relbench_example.py, and correspondingly LHS and RHS are supported here as well. Let me know if this is not the case.
Actually, we should already have the RHS embedding with the HybridGNN fusion model. We can potentially try adding an LHS embedding later.
examples/static_example.py (outdated):

```python
pseudo_times = pd.date_range(start=pd.Timestamp('1970-01-01'),
                             periods=len(train_df), freq='s')
train_df[PSEUDO_TIME] = pseudo_times
```
I don't think this will work. We need to make sure that the link we are predicting is not part of the subgraph, while all other links in the training set are. We don't achieve this here:
- The link we are predicting is part of the subgraph.
- A fraction of other links are not part of the subgraph.

It will be inherently challenging to move these static datasets to the RelBench setting, and we probably shouldn't go this route. We need a pipeline here that (see the toy sketch after this list):
- samples a positive link for an LHS node;
- creates the LHS subgraph and removes the positive link from it;
- trains against this positive link and negative RHS nodes which are not part of the ground truth. The ID-GNN logic is then triggered if there exists a path to the positive link within the subgraph, and the ShallowRHS logic is triggered otherwise.
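A self-contained toy sketch of the "remove the positive link from the LHS subgraph" step described above; all names are stand-ins, not the repo's API:

```python
import torch

def remove_supervision_edge(edge_index: torch.Tensor, src: int,
                            dst: int) -> torch.Tensor:
    """Drop the sampled positive (src, dst) link from the message-passing
    graph so the model cannot trivially see the edge it is predicting."""
    keep = ~((edge_index[0] == src) & (edge_index[1] == dst))
    return edge_index[:, keep]

# Toy LHS -> RHS links; we train against the positive link (0, 2).
edge_index = torch.tensor([[0, 0, 1], [2, 3, 3]])
mp_edge_index = remove_supervision_edge(edge_index, src=0, dst=2)
# mp_edge_index == [[0, 1], [3, 3]]: the target link is hidden, while all
# other training links remain visible for message passing.
```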
```python
# Shuffle test data
test_df = test_df.sample(frac=1, random_state=args.seed).reset_index(drop=True)
# Add pseudo time column
test_df[PSEUDO_TIME] = TRAIN_SET_TIMESTAMP + pd.Timedelta(days=1)
```
Do we still need it?
Otherwise, there are too many places that would need to be changed in the model source code.
examples/static_example.py (outdated):

```python
edge_label_index_hash = (edge_label_index[0, :] * NUM_DST_NODES +
                         edge_label_index[1, :])
edge_index_hash = edge_index[0, :] * NUM_DST_NODES + edge_index[1, :]
```
This doesn't look correct. edge_index refers to the sampled subgraph, while edge_label_index is global. Don't we need to map edge_index to global indices?
Yes. We probably can do something like pyg-team/pytorch_geometric#6923 (reply in thread).
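A hedged sketch of that idea, assuming a PyG NeighborLoader-style batch where n_id maps each local node index back to its global index (the variable names are assumptions):

```python
import torch

def global_edge_hash(edge_index: torch.Tensor, src_n_id: torch.Tensor,
                     dst_n_id: torch.Tensor,
                     num_dst_nodes: int) -> torch.Tensor:
    """Hash local subgraph edges after mapping both endpoints to global
    node IDs, so they are comparable to the global edge_label_index hash."""
    return (src_n_id[edge_index[0]] * num_dst_nodes
            + dst_n_id[edge_index[1]])

# Toy check: local edge (0, 1), where local node 0 is global 7 and local
# node 1 is global 3, hashes to 7 * num_dst_nodes + 3.
src_n_id = torch.tensor([7, 5])
dst_n_id = torch.tensor([9, 3])
edge_index = torch.tensor([[0], [1]])
assert global_edge_hash(edge_index, src_n_id, dst_n_id, 10).item() == 73
```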
examples/static_example.py (outdated):

```python
# Mask to filter out edges in edge_index_hash that are in
# edge_label_index_hash
mask = ~torch.isin(edge_index_hash, edge_label_index_hash)
```
Note that we only want to filter out edge_label_index edges per example/subgraph, and not globally, i.e., an entity in the batch should still be able to see other training edges.
Can you elaborate? I thought the purpose of mini-batching here is to sample both supervision edges and message-passing edges. Are you suggesting we only sample supervision edges but not message-passing edges, i.e., we use all existing edges as message-passing edges in each mini-batch?
I mean that during mini-batch sampling, you get a subgraph per entity. You need to make sure here to only filter out supervision edges per example in the corresponding subgraph, and not in all subgraphs that the mini-batch holds.
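A sketch of that per-example filtering: fold the seed-entity ID each edge was sampled for into the key, so a supervision edge is only dropped inside its own example's subgraph. The seed tensors are hypothetical bookkeeping, not the repo's API:

```python
import torch

def per_example_mask(edge_hash: torch.Tensor, edge_seed: torch.Tensor,
                     label_hash: torch.Tensor, label_seed: torch.Tensor,
                     num_pairs: int) -> torch.Tensor:
    """Keep an edge unless the *same example* lists it as a supervision
    edge. `num_pairs` must exceed the largest possible edge hash, so
    (seed, hash) keys never collide across examples."""
    edge_key = edge_seed * num_pairs + edge_hash
    label_key = label_seed * num_pairs + label_hash
    return ~torch.isin(edge_key, label_key)

# Toy check: the same global edge (hash 5) is dropped for seed 0, which
# supervises it, but kept for seed 1, which does not.
mask = per_example_mask(torch.tensor([5, 5]), torch.tensor([0, 1]),
                        torch.tensor([5]), torch.tensor([0]), num_pairs=100)
assert mask.tolist() == [False, True]
```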
ok
examples/static_example.py: script to run the hybrid link GNN for these three static datasets.

Model performance (per-run result figures omitted; only the quoted baselines survive):

- yelp2018 / yelp2018 + lhs embedding, after extensive number-of-neighbors tuning [16, 16, 16, 16]: baselines are 0.0649 and 0.0530.
- amazon-book / amazon-book + lhs embedding, after extensive number-of-neighbors tuning [16, 16, 8, 8]: baselines are 0.0419 and 0.0320.
- gowalla / gowalla + lhs embedding, after extensive number-of-neighbors tuning [16, 16, 16, 16]: baselines are 0.1830 and 0.1554.