
Add static data and model for hybrid link gnn #27

Open · wants to merge 13 commits into base: master

Conversation

XinweiHe (Collaborator) commented Sep 15, 2024

Model performance

yelp2018:

Best test metrics: map 0.0103, precision 0.0183, recall 0.0419

yelp2018 + lhs embedding:

Best test metrics: map 0.0111, precision 0.0190, recall 0.0436

After extensive tuning of the number of neighbors ([16, 16, 16, 16]):

Best test recall: 0.049
After removing known labels from the train set, best test recall: 0.053

Baseline recall is 0.0649

Best ndcg@20: 0.0382
After removing known labels from the train set, best ndcg@20: 0.0413

Baseline ndcg@20 is 0.0530

amazon-book:

Best test metrics: map 0.0069, precision 0.0106, recall 0.0264

amazon-book + lhs embedding:

Best test metrics: map 0.0093, precision 0.0138, recall 0.0349

After extensive tuning of the number of neighbors ([16, 16, 8, 8]):

Best test recall: 0.0431
After removing known labels from the train set, best test recall: 0.04513

Baseline recall is 0.0419

Best ndcg@20: 0.0337
After removing known labels from the train set, best ndcg@20: 0.03765

Baseline ndcg@20 is 0.0320

gowalla:

Best test metrics: map 0.0373, precision 0.0352, recall 0.1343

gowalla + lhs embedding:

Best test metrics: map 0.0434, precision 0.0415, recall 0.1369

After extensive tuning of the number of neighbors ([16, 16, 16, 16]):

Best test recall: 0.1718
After removing known labels from the train set, best test recall: 0.1720

Baseline recall is 0.1830

Best ndcg@20: 0.1215
After removing known labels from the train set, best ndcg@20: 0.1275

Baseline ndcg@20 is 0.1554
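For context, the recall and ndcg@20 numbers above follow the standard top-k ranking definitions used by LightGCN-style benchmarks on yelp2018 / amazon-book / gowalla. A minimal sketch (the function names are illustrative, not the repo's API):

```python
import numpy as np

def recall_at_k(ranked, ground_truth, k=20):
    # Fraction of the user's held-out positives that appear in the top-k list.
    hits = len(set(ranked[:k]) & set(ground_truth))
    return hits / len(ground_truth)

def ndcg_at_k(ranked, ground_truth, k=20):
    # NDCG@k with binary relevance: hits are discounted by log2 of their rank,
    # normalized by the best achievable DCG for this many positives.
    gt = set(ground_truth)
    dcg = sum(1.0 / np.log2(i + 2) for i, item in enumerate(ranked[:k]) if item in gt)
    idcg = sum(1.0 / np.log2(i + 2) for i in range(min(len(gt), k)))
    return dcg / idcg

ranked = [5, 3, 9, 1, 7]   # model's top-5 items for one user
truth = [3, 7, 8]          # user's held-out positives
print(round(recall_at_k(ranked, truth, k=5), 4))  # 0.6667 (2 of 3 retrieved)
print(round(ndcg_at_k(ranked, truth, k=5), 4))    # 0.4776
```

The reported numbers are these quantities averaged over all test users.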

@yiweny (Contributor) left a comment:

@rusty1s Shall we assign pseudo timestamps or shall we just remove the time encoder in the model for the static tasks?

Contributor: I'd rather not check in those large files. Can you upload them to kumo-public-datasets?

Collaborator (author): Do we ever plan to make the hybridgnn repo public? If so, kumo-public-datasets makes less sense?

Contributor: Can we upload it to pyg? @rusty1s

examples/static_example.py (outdated comments, resolved)
model = HybridGNN(
data=data,
col_stats_dict=col_stats_dict,
rhs_emb_mode=RHSEmbeddingMode.FUSION,
Contributor: Same here. Can you let the user specify the rhs_embedding_mode?

Collaborator (author): Let's keep it like this, since rhs_emb_mode is not customized in relbench_example.py either.

f"node in any mini-batch. Try to increase the number "
f"of layers/hops and re-try. If you run into memory "
f"issues with deeper nets, decrease the batch size.")
return loss_accum / count_accum if count_accum > 0 else float("nan")
Contributor: Can we raise an error here instead?

Collaborator (author): Similarly, this follows relbench_example.py. In static_example.py I think we should focus on making sure the different models work end to end. When we need benchmarking results, we could create a separate static_link_prediction_benchmark.py.

examples/static_example.py (resolved)
Member: Let's host it somewhere else.

Collaborator (author): Could you elaborate a bit? I see that https://github.com/kuandeng/LightGCN/tree/master/Data saves data in the repo, which seems to be the most straightforward and convenient approach.

Comment on lines +153 to +161
def get_static_stype_proposal(
table_dict: Dict[str, pd.DataFrame]) -> Dict[str, Dict[str, stype]]:
r"""Infer the stype for each table column."""
inferred_col_to_stype_dict = {}
for table_name, df in table_dict.items():
df = df.sample(min(1_000, len(df)))
inferred_col_to_stype = infer_df_stype(df)
inferred_col_to_stype_dict[table_name] = inferred_col_to_stype
return inferred_col_to_stype_dict
Member: I thought we don't have any features?

# https://github.com/kumo-ai/hybridgnn/blob/master/examples/
# relbench_example.py#L86. This is to make sure each table contains
# at least one feature.
df = pd.DataFrame({"__const__": np.ones(len(table)), **fkey_dict})
Member: I would assume we want lookup embeddings for both LHS and RHS.

Collaborator (author): The implementation here follows https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py#L86. Based on the implementation of make_pkey_fkey_graph in relbench, the primary key and foreign keys are completely removed from the src, dst, and transaction tables, so I assumed lookup embeddings for LHS and RHS must be supported in some other way there, and correspondingly that LHS and RHS are supported here as well. Let me know if this is not the case.

Collaborator (author): Actually, we should already have an rhs embedding with the hybridgnn fusion model. We can potentially try adding an lhs embedding later.

Comment on lines 118 to 120
pseudo_times = pd.date_range(start=pd.Timestamp('1970-01-01'),
periods=len(train_df), freq='s')
train_df[PSEUDO_TIME] = pseudo_times
Member: I don't think this will work. We need to make sure that the link we are predicting is not part of the subgraph, while all other links in the training set are. We don't achieve this here:

  • The link we are predicting is part of the subgraph
  • A fraction of other links are not part of the subgraph.

It will be inherently challenging to move these static datasets to the rel-bench setting, and we probably shouldn't go this route. We need a pipeline here that

  • samples a positive link for a LHS node
  • creates LHS subgraph and removes the positive link from it
  • trains against this positive link and negative RHS nodes which are not part of ground-truth. Then, the ID-GNN logic is triggered if there exists a path to the positive link within the subgraph, and the ShallowRHS logic is triggered otherwise.
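The proposed pipeline could be sketched roughly as follows. This is a hypothetical illustration, not the repo's API: `adj` maps each LHS node to the tensor of its positive RHS nodes, and the subgraph construction is reduced to dropping the sampled positive from the message-passing edge set.

```python
import torch

def sample_training_example(lhs, adj, num_rhs, num_neg=16):
    """Illustrative sketch: positive sampling + leakage-free subgraph."""
    pos_all = adj[lhs]
    # 1. Sample a positive link for the LHS node.
    pos = pos_all[torch.randint(len(pos_all), (1,))].item()
    # 2. Build the LHS subgraph with the sampled positive removed, so the
    #    target edge is never visible during message passing.
    mp_rhs = pos_all[pos_all != pos]
    # 3. Sample negative RHS nodes outside the ground-truth set; ID-GNN logic
    #    then fires when a path to the positive exists in the subgraph,
    #    ShallowRHS logic otherwise.
    gt = set(pos_all.tolist())
    neg = []
    while len(neg) < num_neg:
        cand = torch.randint(num_rhs, (1,)).item()
        if cand not in gt:
            neg.append(cand)
    return pos, mp_rhs, torch.tensor(neg)

adj = {0: torch.tensor([1, 2, 3])}
pos, mp_rhs, neg = sample_training_example(0, adj, num_rhs=10, num_neg=4)
assert pos not in mp_rhs.tolist()                      # positive is hidden
assert all(n not in {1, 2, 3} for n in neg.tolist())   # negatives outside GT
```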

hybridgnn/nn/models/hybridgnn.py (outdated, resolved)
hybridgnn/nn/models/idgnn.py (outdated, resolved)
hybridgnn/nn/models/shallowrhsgnn.py (outdated, resolved)
# Shuffle test data
test_df = test_df.sample(frac=1, random_state=args.seed).reset_index(drop=True)
# Add pseudo time column
test_df[PSEUDO_TIME] = TRAIN_SET_TIMESTAMP + pd.Timedelta(days=1)
Member: Do we still need it?

Collaborator (author): Otherwise, there are too many places that would need to be changed in the model source code.

examples/static_example.py (outdated, resolved)
examples/static_example.py (resolved)
examples/static_example.py (outdated, resolved)
Comment on lines 402 to 405
edge_label_index_hash = edge_label_index[
0, :] * NUM_DST_NODES + edge_label_index[1, :]
edge_index_hash = edge_index[0, :] * NUM_DST_NODES + edge_index[
1, :]
Member: This doesn't look correct. edge_index refers to the sampled subgraph, while edge_label_index is global. Don't we need to map edge_label to global indices?

Collaborator (author): Yes. We probably can do something like pyg-team/pytorch_geometric#6923 (reply in thread).
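One way to reconcile the two index spaces is to translate the subgraph's local edge_index back to global node ids before hashing, using the sampler's local-to-global mapping. A sketch under assumed names (`n_id` plays the role of the vector a neighbor sampler returns, mapping local position i to global id `n_id[i]`):

```python
import torch

NUM_DST_NODES = 100

n_id = torch.tensor([7, 42, 3, 99])         # local id i -> global id n_id[i]
local_edge_index = torch.tensor([[0, 1],     # edges in local subgraph ids
                                 [2, 3]])
# Advanced indexing lifts every local id to its global id, making the
# edges directly comparable to the (global) edge_label_index.
global_edge_index = n_id[local_edge_index]

edge_index_hash = global_edge_index[0] * NUM_DST_NODES + global_edge_index[1]
edge_label_index = torch.tensor([[7], [3]])  # global supervision edge (7 -> 3)
edge_label_index_hash = (edge_label_index[0] * NUM_DST_NODES +
                         edge_label_index[1])

mask = ~torch.isin(edge_index_hash, edge_label_index_hash)
print(mask.tolist())  # [False, True]: only the supervision edge is filtered
```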

examples/static_example.py (resolved)

# Mask to filter out edges in edge_index_hash that are in
# edge_label_index_hash
mask = ~torch.isin(edge_index_hash, edge_label_index_hash)
Member: Note that we only want to filter out edge_label_index edges per example/subgraph, not globally; i.e., an entity in the batch should still be able to see other training edges.

Collaborator (author): Can you elaborate? I thought the purpose of mini-batching here was to sample both supervision edges and message-passing edges. Are you suggesting we sample only supervision edges but not message-passing edges, i.e. that we use all existing edges as message-passing edges in each mini-batch?

Member: I mean that during mini-batch sampling, you get a subgraph per entity. You need to make sure to filter out supervision edges only per example, in the corresponding subgraph, and not in all subgraphs that the mini-batch holds.

Collaborator (author): OK.
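The per-example filtering discussed above could look roughly like this. It is a sketch under the assumption that each sampled edge is tagged with the id of the seed example whose subgraph it belongs to (as with disjoint subgraph sampling); the names are illustrative, not the PR's code.

```python
import torch

NUM_DST_NODES = 100

edge_index = torch.tensor([[7, 7, 5],     # global (src, dst) pairs
                           [3, 3, 2]])
edge_example = torch.tensor([0, 1, 1])    # seed example each edge belongs to
edge_label_index = torch.tensor([[7, 5],  # supervision edge of examples 0, 1
                                 [3, 2]])

# Fold the example id into the hash, so an edge is removed only inside the
# subgraph whose supervision edge it is; other examples in the batch can
# still see it as a message-passing edge.
edge_hash = (edge_example * NUM_DST_NODES**2 +
             edge_index[0] * NUM_DST_NODES + edge_index[1])
label_hash = (torch.arange(edge_label_index.size(1)) * NUM_DST_NODES**2 +
              edge_label_index[0] * NUM_DST_NODES + edge_label_index[1])

mask = ~torch.isin(edge_hash, label_hash)
print(mask.tolist())  # [False, True, False]: (7 -> 3) survives in example 1
```

Note how the (7 -> 3) edge is dropped in example 0, where it is the supervision edge, but kept in example 1's subgraph.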
