-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into feature/adding_ssma_aggregation
- Loading branch information
Showing
58 changed files
with
6,391 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
# Examples for Co-training LLMs and GNNs | ||
|
||
| Example | Description | | ||
| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | | ||
| Example | Description | | ||
| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | | ||
| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. | | ||
| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | | ||
| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. | | ||
| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | | ||
| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | | ||
| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Examples for LLM and GNN co-training | ||
|
||
| Example | Description | | ||
| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ||
| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend | | ||
| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend | | ||
| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore | | ||
| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) | | ||
| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. | | ||
|
||
NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes. |
105 changes: 105 additions & 0 deletions
105
examples/llm/g_retriever_utils/benchmark_model_archs_rag.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
"""Used to benchmark the performance of an untuned/fine tuned LLM against | ||
GRetriever with various architectures and layer depths. | ||
""" | ||
# %% | ||
import argparse | ||
import sys | ||
|
||
import torch | ||
|
||
from torch_geometric.datasets import WebQSPDataset | ||
from torch_geometric.nn.models import GAT, MLP, GRetriever | ||
|
||
sys.path.append('..') | ||
from minimal_demo import ( # noqa: E402 # isort:skip | ||
benchmark_models, get_loss, inference_step, | ||
) | ||
|
||
# %% | ||
parser = argparse.ArgumentParser( | ||
description="""Benchmarker for GRetriever\n""" + | ||
"""NOTE: Evaluating with smaller samples may result in poorer""" + | ||
""" performance for the trained models compared to """ + | ||
"""untrained models.""") | ||
parser.add_argument("--hidden_channels", type=int, default=1024) | ||
parser.add_argument("--learning_rate", type=float, default=1e-5) | ||
parser.add_argument("--epochs", type=int, default=2) | ||
parser.add_argument("--batch_size", type=int, default=8) | ||
parser.add_argument("--eval_batch_size", type=int, default=16) | ||
parser.add_argument("--tiny_llama", action='store_true') | ||
|
||
parser.add_argument("--dataset_path", type=str, required=False) | ||
# Default to WebQSP split | ||
parser.add_argument("--num_train", type=int, default=2826) | ||
parser.add_argument("--num_val", type=int, default=246) | ||
parser.add_argument("--num_test", type=int, default=1628) | ||
|
||
args = parser.parse_args() | ||
|
||
# %% | ||
hidden_channels = args.hidden_channels | ||
lr = args.learning_rate | ||
epochs = args.epochs | ||
batch_size = args.batch_size | ||
eval_batch_size = args.eval_batch_size | ||
|
||
# %% | ||
if not args.dataset_path: | ||
ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True) | ||
else: | ||
# We just assume that the size of the dataset accomodates the | ||
# train/val/test split, because checking may be expensive. | ||
dataset = torch.load(args.dataset_path) | ||
|
||
class MockDataset: | ||
"""Utility class to patch the fields in WebQSPDataset used by | ||
GRetriever. | ||
""" | ||
def __init__(self) -> None: | ||
pass | ||
|
||
@property | ||
def split_idxs(self) -> dict: | ||
# Imitates the WebQSP split method | ||
return { | ||
"train": | ||
torch.arange(args.num_train), | ||
"val": | ||
torch.arange(args.num_val) + args.num_train, | ||
"test": | ||
torch.arange(args.num_test) + args.num_train + args.num_val, | ||
} | ||
|
||
def __getitem__(self, idx: int): | ||
return dataset[idx] | ||
|
||
ds = MockDataset() | ||
|
||
# %% | ||
model_names = [] | ||
model_classes = [] | ||
model_kwargs = [] | ||
model_type = ["GAT", "MLP"] | ||
models = {"GAT": GAT, "MLP": MLP} | ||
# Use to vary the depth of the GNN model | ||
num_layers = [4] | ||
# Use to vary the number of LLM tokens reserved for GNN output | ||
num_tokens = [1] | ||
for m_type in model_type: | ||
for n_layer in num_layers: | ||
for n_tokens in num_tokens: | ||
model_names.append(f"{m_type}_{n_layer}_{n_tokens}") | ||
model_classes.append(GRetriever) | ||
kwargs = dict(gnn_hidden_channels=hidden_channels, | ||
num_gnn_layers=n_layer, gnn_to_use=models[m_type], | ||
mlp_out_tokens=n_tokens) | ||
if args.tiny_llama: | ||
kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1' | ||
kwargs['mlp_out_dim'] = 2048 | ||
kwargs['num_llm_params'] = 1 | ||
model_kwargs.append(kwargs) | ||
|
||
# %% | ||
benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs, | ||
batch_size, eval_batch_size, get_loss, inference_step, | ||
skip_LLMs=False, tiny_llama=args.tiny_llama, force=True) |
Oops, something went wrong.