Merge branch 'master' into feature/adding_ssma_aggregation

AlmogDavid authored Dec 4, 2024
2 parents 0b594f6 + 4670584 commit 25aba74
Showing 58 changed files with 6,391 additions and 15 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))
- Added a comment in `g_retriever.py` pointing to the `Neo4j` Graph DB integration demo ([#9797](https://github.com/pyg-team/pytorch_geometric/pull/9797))
- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710))
- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))
- Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748))
- Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9780](https://github.com/pyg-team/pytorch_geometric/pull/9780))
- Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation (see the sketch below) ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776))
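
A minimal sketch of the tetrahedral `FaceToEdge` usage noted above, assuming the `[4, num_faces]` support from [#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776); the two-tetrahedra input is hypothetical:

```python
import torch

from torch_geometric.data import Data
from torch_geometric.transforms import FaceToEdge

# Two tetrahedra sharing the face (0, 1, 2), stored column-wise as
# [4, num_faces].
face = torch.tensor([[0, 0],
                     [1, 1],
                     [2, 2],
                     [3, 4]])
data = Data(pos=torch.rand(5, 3), face=face)
data = FaceToEdge(remove_faces=False)(data)
print(data.edge_index)  # Every node pair within each tetrahedron is connected.
```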
@@ -18,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed

- Dropped Python 3.8 support ([#9606](https://github.com/pyg-team/pytorch_geometric/pull/9606))
- Added a check confirming that custom edge types passed to `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))

### Deprecated

12 changes: 9 additions & 3 deletions examples/llm/README.md
@@ -1,5 +1,11 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information |
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. |
| [`multihop_rag/`](./multihop_rag/)           | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA.                                                                                   |
| [`nvtx_examples/`](./nvtx_examples/)         | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis; see the sketch after this table.                                                                   |
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction |
| [`glem.py`](./glem.py)                       | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model that uses a variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results |
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text |
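
A minimal sketch of the NVTX wrapping idea from `nvtx_examples/`: it uses PyTorch's built-in `torch.cuda.nvtx` ranges rather than assuming the exact signature of PyG's `profiler.nvtxit` helper, and `model`/`batch` are placeholder names.

```python
import torch

def forward_with_nvtx(model, batch):
    # Mark the forward pass so it appears as a named range in Nsight Systems.
    torch.cuda.nvtx.range_push('gnn_forward')
    out = model(batch.x, batch.edge_index)
    torch.cuda.nvtx.range_pop()
    return out
```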
7 changes: 7 additions & 0 deletions examples/llm/g_retriever.py
@@ -6,8 +6,12 @@
Requirements:
`pip install datasets transformers pcst_fast sentencepiece accelerate`
Example repo for integration with Neo4j Graph DB:
https://github.com/neo4j-product-examples/neo4j-gnn-llm-example
"""
import argparse
import gc
import math
import os.path as osp
import re
@@ -142,6 +146,9 @@ def adjust_learning_rate(param_group, LR, epoch):
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size,
                         drop_last=False, pin_memory=True, shuffle=False)

# Free host and cached GPU memory left over from data preprocessing.
gc.collect()
torch.cuda.empty_cache()
gnn = GAT(
    in_channels=1024,
    hidden_channels=hidden_channels,
11 changes: 11 additions & 0 deletions examples/llm/g_retriever_utils/README.md
@@ -0,0 +1,11 @@
# Examples for LLM and GNN co-training

| Example | Description |
| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend |
| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore |
| [`rag_generate.py`](./rag_generate.py)                           | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom-defined retrieval algorithm (defaults to the provided FeatureStore and GraphStore); see the sketch after this table |
| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid-searching relevant architecture parameters and datasets.                                                                                   |

NOTE: Evaluating GRetriever with smaller sample sizes may result in subpar performance; it is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes.
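
A minimal sketch of the subgraph-retrieval idea behind `rag_generate.py` (referenced in the table above), using the stock `torch_geometric.utils.k_hop_subgraph` helper as a stand-in for the FeatureStore/GraphStore-backed retrieval; the seed-node selection is assumed to come from the query:

```python
import torch

from torch_geometric.utils import k_hop_subgraph

def retrieve_subgraph(seed_nodes: torch.Tensor, edge_index: torch.Tensor,
                      num_hops: int = 2):
    # Pull the k-hop neighborhood around the query's seed nodes and
    # relabel it as a standalone subgraph.
    subset, sub_edge_index, _, _ = k_hop_subgraph(
        seed_nodes, num_hops, edge_index, relabel_nodes=True)
    return subset, sub_edge_index
```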
105 changes: 105 additions & 0 deletions examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
@@ -0,0 +1,105 @@
"""Used to benchmark the performance of an untuned/fine tuned LLM against
GRetriever with various architectures and layer depths.
"""
# %%
import argparse
import sys

import torch

from torch_geometric.datasets import WebQSPDataset
from torch_geometric.nn.models import GAT, MLP, GRetriever

sys.path.append('..')
from minimal_demo import (  # noqa: E402 # isort:skip
    benchmark_models, get_loss, inference_step,
)

# %%
parser = argparse.ArgumentParser(
    description="""Benchmarker for GRetriever\n""" +
    """NOTE: Evaluating with smaller samples may result in poorer""" +
    """ performance for the trained models compared to """ +
    """untrained models.""")
parser.add_argument("--hidden_channels", type=int, default=1024)
parser.add_argument("--learning_rate", type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--tiny_llama", action='store_true')

parser.add_argument("--dataset_path", type=str, required=False)
# Default to WebQSP split
parser.add_argument("--num_train", type=int, default=2826)
parser.add_argument("--num_val", type=int, default=246)
parser.add_argument("--num_test", type=int, default=1628)

args = parser.parse_args()

# %%
hidden_channels = args.hidden_channels
lr = args.learning_rate
epochs = args.epochs
batch_size = args.batch_size
eval_batch_size = args.eval_batch_size

# %%
if not args.dataset_path:
    ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True)
else:
    # We just assume that the size of the dataset accommodates the
    # train/val/test split, because checking may be expensive.
    dataset = torch.load(args.dataset_path)

    class MockDataset:
        """Utility class to patch the fields in WebQSPDataset used by
        GRetriever.
        """
        def __init__(self) -> None:
            pass

        @property
        def split_idxs(self) -> dict:
            # Imitates the WebQSP split method
            return {
                "train":
                torch.arange(args.num_train),
                "val":
                torch.arange(args.num_val) + args.num_train,
                "test":
                torch.arange(args.num_test) + args.num_train + args.num_val,
            }

        def __getitem__(self, idx: int):
            return dataset[idx]

    ds = MockDataset()

# %%
model_names = []
model_classes = []
model_kwargs = []
model_type = ["GAT", "MLP"]
models = {"GAT": GAT, "MLP": MLP}
# Used to vary the depth of the GNN model
num_layers = [4]
# Used to vary the number of LLM tokens reserved for GNN output
num_tokens = [1]
for m_type in model_type:
    for n_layer in num_layers:
        for n_tokens in num_tokens:
            model_names.append(f"{m_type}_{n_layer}_{n_tokens}")
            model_classes.append(GRetriever)
            kwargs = dict(gnn_hidden_channels=hidden_channels,
                          num_gnn_layers=n_layer, gnn_to_use=models[m_type],
                          mlp_out_tokens=n_tokens)
            if args.tiny_llama:
                kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1'
                kwargs['mlp_out_dim'] = 2048
                kwargs['num_llm_params'] = 1
            model_kwargs.append(kwargs)

# %%
benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs,
                 batch_size, eval_batch_size, get_loss, inference_step,
                 skip_LLMs=False, tiny_llama=args.tiny_llama, force=True)
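
# A hypothetical invocation, using the flags defined by the argparse setup
# above:
#   python benchmark_model_archs_rag.py --tiny_llama --epochs 2 \
#       --batch_size 8 --eval_batch_size 16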