diff --git a/CHANGELOG.md b/CHANGELOG.md index a92cf9ee2d73..69ec38aaa4ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) +- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730)) +- Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797)) +- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710)) +- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) +- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) - Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9779](https://github.com/pyg-team/pytorch_geometric/pull/9780)) - Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776)) @@ -18,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606)) +- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807)) ### Deprecated diff --git a/examples/llm/README.md b/examples/llm/README.md index f1f01428d991..4503e28ce6ee 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,5 +1,11 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| Example | Description | +| -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training `LLAMA2` with `GAT` for answering questions based on knowledge graph information | +| [`g_retriever_utils/`](./g_retriever_utils/) | Contains multiple scripts for benchmarking GRetriever's architecture and evaluating different retrieval methods. | +| [`multihop_rag/`](./multihop_rag/) | Contains starter code and an example run for building a Multi-hop dataset using WikiHop5M and 2WikiMultiHopQA | +| [`nvtx_examples/`](./nvtx_examples/) | Contains examples of how to wrap functions using the NVTX profiler for CUDA runtime analysis. | +| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction | +| [`glem.py`](./glem.py) | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results | +| [`git_mol.py`](./git_mol.py) | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text | diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 1735d17f5249..a48901f1ff0e 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -6,8 +6,12 @@ Requirements: `pip install datasets transformers pcst_fast sentencepiece accelerate` + +Example repo for integration with Neo4j Graph DB: +https://github.com/neo4j-product-examples/neo4j-gnn-llm-example """ import argparse +import gc import math import os.path as osp import re @@ -142,6 +146,9 @@ def adjust_learning_rate(param_group, LR, epoch): test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, drop_last=False, pin_memory=True, shuffle=False) + # To clean up after Data Preproc + gc.collect() + torch.cuda.empty_cache() gnn = GAT( in_channels=1024, hidden_channels=hidden_channels, diff --git a/examples/llm/g_retriever_utils/README.md b/examples/llm/g_retriever_utils/README.md new file mode 100644 index 000000000000..e072e6746b7c --- /dev/null +++ b/examples/llm/g_retriever_utils/README.md @@ -0,0 +1,11 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| [`rag_feature_store.py`](./rag_feature_store.py) | A Proof of Concept Implementation of a RAG enabled FeatureStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_graph_store.py`](./rag_graph_store.py) | A Proof of Concept Implementation of a RAG enabled GraphStore that can serve as a starting point for implementing a custom RAG Remote Backend | +| [`rag_backend_utils.py`](./rag_backend_utils.py) | Utility functions used for loading a series of Knowledge Graph Triplets into the Remote Backend defined by a FeatureStore and GraphStore | +| [`rag_generate.py`](./rag_generate.py) | Script for generating a unique set of subgraphs from the WebQSP dataset using a custom defined retrieval algorithm (defaults to the FeatureStore and GraphStore provided) | +| [`benchmark_model_archs_rag.py`](./benchmark_model_archs_rag.py) | Script for running a GNN/LLM benchmark on GRetriever while grid searching relevent architecture parameters and datasets. | + +NOTE: Evaluating performance on GRetriever with smaller sample sizes may result in subpar performance. It is not unusual for the fine-tuned model/LLM to perform worse than an untrained LLM on very small sample sizes. diff --git a/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py new file mode 100644 index 000000000000..6522aafca68b --- /dev/null +++ b/examples/llm/g_retriever_utils/benchmark_model_archs_rag.py @@ -0,0 +1,105 @@ +"""Used to benchmark the performance of an untuned/fine tuned LLM against +GRetriever with various architectures and layer depths. +""" +# %% +import argparse +import sys + +import torch + +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.nn.models import GAT, MLP, GRetriever + +sys.path.append('..') +from minimal_demo import ( # noqa: E402 # isort:skip + benchmark_models, get_loss, inference_step, +) + +# %% +parser = argparse.ArgumentParser( + description="""Benchmarker for GRetriever\n""" + + """NOTE: Evaluating with smaller samples may result in poorer""" + + """ performance for the trained models compared to """ + + """untrained models.""") +parser.add_argument("--hidden_channels", type=int, default=1024) +parser.add_argument("--learning_rate", type=float, default=1e-5) +parser.add_argument("--epochs", type=int, default=2) +parser.add_argument("--batch_size", type=int, default=8) +parser.add_argument("--eval_batch_size", type=int, default=16) +parser.add_argument("--tiny_llama", action='store_true') + +parser.add_argument("--dataset_path", type=str, required=False) +# Default to WebQSP split +parser.add_argument("--num_train", type=int, default=2826) +parser.add_argument("--num_val", type=int, default=246) +parser.add_argument("--num_test", type=int, default=1628) + +args = parser.parse_args() + +# %% +hidden_channels = args.hidden_channels +lr = args.learning_rate +epochs = args.epochs +batch_size = args.batch_size +eval_batch_size = args.eval_batch_size + +# %% +if not args.dataset_path: + ds = WebQSPDataset('benchmark_archs', verbose=True, force_reload=True) +else: + # We just assume that the size of the dataset accomodates the + # train/val/test split, because checking may be expensive. + dataset = torch.load(args.dataset_path) + + class MockDataset: + """Utility class to patch the fields in WebQSPDataset used by + GRetriever. + """ + def __init__(self) -> None: + pass + + @property + def split_idxs(self) -> dict: + # Imitates the WebQSP split method + return { + "train": + torch.arange(args.num_train), + "val": + torch.arange(args.num_val) + args.num_train, + "test": + torch.arange(args.num_test) + args.num_train + args.num_val, + } + + def __getitem__(self, idx: int): + return dataset[idx] + + ds = MockDataset() + +# %% +model_names = [] +model_classes = [] +model_kwargs = [] +model_type = ["GAT", "MLP"] +models = {"GAT": GAT, "MLP": MLP} +# Use to vary the depth of the GNN model +num_layers = [4] +# Use to vary the number of LLM tokens reserved for GNN output +num_tokens = [1] +for m_type in model_type: + for n_layer in num_layers: + for n_tokens in num_tokens: + model_names.append(f"{m_type}_{n_layer}_{n_tokens}") + model_classes.append(GRetriever) + kwargs = dict(gnn_hidden_channels=hidden_channels, + num_gnn_layers=n_layer, gnn_to_use=models[m_type], + mlp_out_tokens=n_tokens) + if args.tiny_llama: + kwargs['llm_to_use'] = 'TinyLlama/TinyLlama-1.1B-Chat-v0.1' + kwargs['mlp_out_dim'] = 2048 + kwargs['num_llm_params'] = 1 + model_kwargs.append(kwargs) + +# %% +benchmark_models(model_classes, model_names, model_kwargs, ds, lr, epochs, + batch_size, eval_batch_size, get_loss, inference_step, + skip_LLMs=False, tiny_llama=args.tiny_llama, force=True) diff --git a/examples/llm/g_retriever_utils/minimal_demo.py b/examples/llm/g_retriever_utils/minimal_demo.py new file mode 100644 index 000000000000..bdd78c3180cb --- /dev/null +++ b/examples/llm/g_retriever_utils/minimal_demo.py @@ -0,0 +1,638 @@ +"""This example implements the G-Retriever model +(https://arxiv.org/abs/2402.07630) using PyG. + +G-Retriever significantly reduces hallucinations by 54% compared to the +stand-alone LLM baseline. + +Requirements: +`pip install datasets transformers pcst_fast sentencepiece accelerate` +""" +import argparse +import gc +import math +import multiprocessing as mp +import re +import sys +import time +from os import path +from typing import Any, Callable, Dict, List, Type + +import pandas as pd +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.data import Dataset +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn.models import GAT, GRetriever +from torch_geometric.nn.nlp import LLM + +# NOTE: This used to be merged in the G-Retriever example. +# FIXME: Getting the demos working like before is a WIP +sys.path.append('..') +from g_retriever import ( # noqa: E402 # isort:skip + compute_metrics, load_params_dict, save_params_dict, +) + + +def _detect_hallucinate(inp): + pred, label = inp + try: + split_pred = pred.split('[/s]')[0].strip().split('|') + correct_hit = len(re.findall(split_pred[0], label)) > 0 + correct_hit = correct_hit or any( + [label_i in pred.lower() for label_i in label.split('|')]) + hallucination = not correct_hit + return hallucination + except: # noqa + return "skip" + + +def detect_hallucinate(pred_batch, label_batch): + r"""An approximation for the unsolved task of detecting hallucinations. + We define a hallucination as an output that contains no instances of + acceptable label. + """ + with mp.Pool(len(pred_batch)) as p: + res = p.map(_detect_hallucinate, zip(pred_batch, label_batch)) + return res + + +def compute_n_parameters(model: torch.nn.Module) -> int: + return sum([p.numel() for p in model.parameters() if p.requires_grad]) + + +def get_loss(model, batch, model_save_name) -> Tensor: + if model_save_name == 'llm': + return model(batch.question, batch.label, batch.desc) + else: + return model(batch.question, batch.x, batch.edge_index, batch.batch, + batch.label, batch.edge_attr, batch.desc) + + +def inference_step(model, batch, model_save_name): + if model_save_name == 'llm': + return model.inference(batch.question, batch.desc) + else: + return model.inference(batch.question, batch.x, batch.edge_index, + batch.batch, batch.edge_attr, batch.desc) + + +# TODO: Merge with G-Retriever example and make sure changes still work +def train( + num_epochs, + hidden_channels, + num_gnn_layers, + batch_size, + eval_batch_size, + lr, + checkpointing=False, + tiny_llama=False, + model=None, + dataset=None, + model_save_name=None, +): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group['lr'] = lr + return lr + + start_time = time.time() + seed_everything(42) + if dataset is None: + dataset = WebQSPDataset() + gc.collect() + elif not isinstance(dataset, Dataset) and callable(dataset): + dataset = dataset() + gc.collect() + idx_split = dataset.split_idxs + + # Step 1: Build Node Classification Dataset + train_dataset = [dataset[i] for i in idx_split['train']] + val_dataset = [dataset[i] for i in idx_split['val']] + test_dataset = [dataset[i] for i in idx_split['test']] + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + if model is None: + gc.collect() + gnn = GAT( + in_channels=1024, + hidden_channels=hidden_channels, + out_channels=1024, + num_layers=num_gnn_layers, + heads=4, + ) + if tiny_llama: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + ) + model = GRetriever(llm=llm, gnn=gnn, mlp_out_channels=2048) + else: + llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) + model = GRetriever(llm=llm, gnn=gnn) + + if model_save_name is None: + model_save_name = 'gnn_llm' if num_gnn_layers is not None else 'llm' + + model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm' + if model_save_name == 'llm': + model = llm + + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': 0.05 + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + model.train() + epoch_loss = 0 + if epoch == 0: + print(f"Total Preparation Time: {time.time() - start_time:2f}s") + start_time = time.time() + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + loader = tqdm(train_loader, desc=epoch_str) + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = get_loss(model, batch, model_save_name) + loss.backward() + + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss = epoch_loss + float(loss) + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]['lr'] + train_loss = epoch_loss / len(train_loader) + print(epoch_str + f', Train Loss: {train_loss:4f}') + + val_loss = 0 + eval_output = [] + model.eval() + with torch.no_grad(): + for step, batch in enumerate(val_loader): + loss = get_loss(model, batch, model_save_name) + val_loss += loss.item() + val_loss = val_loss / len(val_loader) + print(epoch_str + f", Val Loss: {val_loss:4f}") + if checkpointing and val_loss < best_val_loss: + print("Checkpointing best model...") + best_val_loss = val_loss + best_epoch = epoch + save_params_dict(model, f'{model_save_name}_best_val_loss_ckpt.pt') + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + if checkpointing and best_epoch != num_epochs - 1: + print("Loading best checkpoint...") + model = load_params_dict( + model, + f'{model_save_name}_best_val_loss_ckpt.pt', + ) + + model.eval() + eval_output = [] + print("Final evaluation...") + progress_bar_test = tqdm(range(len(test_loader))) + for step, batch in enumerate(test_loader): + with torch.no_grad(): + pred = inference_step(model, batch, model_save_name) + eval_data = { + 'pred': pred, + 'question': batch.question, + 'desc': batch.desc, + 'label': batch.label + } + eval_output.append(eval_data) + progress_bar_test.update(1) + + # Step 6 Post-processing & compute metrics + compute_metrics(eval_output) + print(f"Total Training Time: {time.time() - start_time:2f}s") + save_params_dict(model, f'{model_save_name}.pt') + torch.save(eval_output, f'{model_save_name}_eval_outs.pt') + print("Done!") + return prep_time, dataset, eval_output # noqa: F821 + + +def _eval_hallucinations_on_loader(outs, loader, eval_batch_size): + model_save_list = [] + model_preds = [] + for out in outs: + model_preds += out['pred'] + for i, batch in enumerate(loader): + correct_answer = batch.label + + model_pred = model_preds[i * eval_batch_size:(i + 1) * eval_batch_size] + model_hallucinates = detect_hallucinate(model_pred, correct_answer) + model_save_list += [tup for tup in zip(model_pred, model_hallucinates)] + return model_save_list + + +def benchmark_models(models: List[Type[nn.Module]], model_names: List[str], + model_kwargs: List[Dict[str, Any]], dataset: Dataset, + lr: float, epochs: int, batch_size: int, + eval_batch_size: int, loss_fn: Callable, + inference_fn: Callable, skip_LLMs: bool = True, + tiny_llama: bool = False, checkpointing: bool = True, + force: bool = False, root_dir='.'): + """Utility function for creating a model benchmark for GRetriever that + grid searches over hyperparameters. Produces a DataFrame containing + metrics for each model. + + Args: + models (List[Type[nn.Module]]): Models to be benchmarked. + model_names (List[str]): Name of save files for model checkpoints + model_kwargs (List[Dict[str, Any]]): Parameters to use for each + particular model. + dataset (Dataset): Input dataset to train on. + lr (float): Learning rate + epochs (int): Number of epochs + batch_size (int): Batch size for training + eval_batch_size (int): Batch size for eval. Also determines + hallucination detection concurrancy. + loss_fn (Callable): Loss function + inference_fn (Callable): Inference function + skip_LLMs (bool, optional): Whether to skip LLM-only runs. + Defaults to True. + tiny_llama (bool, optional): Whether to use tiny llama as LLM. + Defaults to False. + checkpointing (bool, optional): Whether to checkpoint models. + Defaults to True. + force (bool, optional): Whether to rerun already existing results. + Defaults to False. + root_dir (str, optional): Dir to save results and checkpoints in. + Defaults to '.'. + """ + model_log: Dict[str, Dict[str, Any]] = dict() + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + if not skip_LLMs: + if tiny_llama: + pure_llm = LLM( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1", + num_params=1, + ) + else: + pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf", + num_params=7) + + if force or not path.exists(root_dir + "/pure_llm_model_log.pt"): + model_log["pure_llm"] = dict() + + pure_preds = [] + for batch in tqdm(loader): + pure_llm_preds = pure_llm.inference(batch.question, batch.desc, + max_tokens=256) + pure_preds += pure_llm_preds + pure_preds = [{"pred": pred} for pred in pure_preds] + + model_log["pure_llm"]["preds"] = pure_preds + model_log["pure_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(pure_preds, loader, + eval_batch_size) + model_log["pure_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["pure_llm"], + root_dir + "/pure_llm_model_log.pt") + else: + model_log["pure_llm"] = \ + torch.load(root_dir+"/pure_llm_model_log.pt") + + # LORA + if force or not path.exists(root_dir + "/tuned_llm_model_log.pt"): + model_log["tuned_llm"] = dict() + since = time.time() + gc.collect() + prep_time, _, lora_eval_outs = train(since, epochs, None, None, + batch_size, eval_batch_size, + lr, loss_fn, inference_fn, + model=pure_llm, + dataset=dataset) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log["tuned_llm"]["prep_time"] = prep_time + model_log["tuned_llm"]["e2e_time"] = e2e_time + model_log["tuned_llm"]["eval_output"] = lora_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + + model_log["tuned_llm"]["hallucinates_list"] = \ + _eval_hallucinations_on_loader(lora_eval_outs, loader, + eval_batch_size) + model_log["tuned_llm"]["n_params"] = compute_n_parameters(pure_llm) + torch.save(model_log["tuned_llm"], + root_dir + "/tuned_llm_model_log.pt") + else: + model_log["tuned_llm"] = \ + torch.load(root_dir+"/tuned_llm_model_log.pt") + + del pure_llm + gc.collect() + + # All other models + for name, Model, kwargs in zip(model_names, models, model_kwargs): + model_log[name] = dict() + train_model = True + if path.exists(root_dir + f"/{name}.pt") and not force: + print(f"Model {name} appears to already exist.") + print("Would you like to retrain?") + train_model = str(input("(y/n):")).lower() == "y" + + if train_model: + since = time.time() + gc.collect() + model = Model(**kwargs) + prep_time, _, model_eval_outs = train( + since=since, num_epochs=epochs, hidden_channels=None, + num_gnn_layers=None, batch_size=batch_size, + eval_batch_size=eval_batch_size, lr=lr, loss_fn=loss_fn, + inference_fn=inference_fn, checkpointing=checkpointing, + tiny_llama=tiny_llama, dataset=dataset, + model_save_name=root_dir + '/' + name, model=model) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + model_log[name]["prep_time"] = prep_time + model_log[name]["e2e_time"] = e2e_time + model_log[name]["eval_output"] = model_eval_outs + print("E2E time (e2e_time) =", e2e_time, "seconds") + print("E2E tme minus Prep Time =", e2e_time - prep_time, "seconds") + model_log[name]["n_params"] = compute_n_parameters(model) + del model + gc.collect() + else: + model_eval_outs = torch.load(root_dir + f"/{name}_eval_outs.pt") + + # Calculate Hallucinations + skip_hallucination_detection = False + + if path.exists(root_dir + f"/{name}_model_log.pt") and not force: + print(f"Saved outputs for {name} have been found.") + print("Would you like to redo?") + user_input = str(input("(y/n):")).lower() + skip_hallucination_detection = user_input != "y" + + if not skip_hallucination_detection: + model_save_list = _eval_hallucinations_on_loader( + model_eval_outs, loader, eval_batch_size) + + model_log[name]["hallucinates_list"] = model_save_list + torch.save(model_log[name], root_dir + f"/{name}_model_log.pt") + else: + model_log[name]["hallucinates_list"] = \ + torch.load( + root_dir+f"/{name}_model_log.pt" + )["hallucinates_list"] + + hal_dict = { + k: [tup[1] for tup in v["hallucinates_list"]] + for (k, v) in model_log.items() + } + hallucinates_df = pd.DataFrame(hal_dict).astype(str) + hallucinates_df = hallucinates_df.apply(pd.Series.value_counts).transpose() + hallucinates_df['e2e_time'] = pd.Series( + {k: v.get('e2e_time') + for (k, v) in model_log.items()}) + hallucinates_df['n_params'] = pd.Series( + {k: v.get('n_params') + for (k, v) in model_log.items()}) + print(hallucinates_df) + hallucinates_df.to_csv(root_dir + "/hallucinates_df.csv", index=False) + + +def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size, + eval_batch_size, loss_fn, inference_fn, + skip_pretrained_LLM=False, tiny_llama=False): + if not skip_pretrained_LLM: + print("First comparing against a pretrained LLM...") + # Step 1: Define a single batch size test loader + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + # batch size 1 loader for simplicity + loader = DataLoader(test_dataset, batch_size=eval_batch_size, + drop_last=False, pin_memory=True, shuffle=False) + if tiny_llama: + pure_llm = LLM( + model_name="TinyLlama/TinyLlama-1.1B-Chat-v0.1", + num_params=1, + ) + else: + pure_llm = LLM(model_name="meta-llama/Llama-2-7b-chat-hf", + num_params=7) + if path.exists("demo_save_dict.pt"): + print("Saved outputs for the first step of the demo found.") + print("Would you like to redo?") + user_input = str(input("(y/n):")).lower() + skip_step_one = user_input == "n" + else: + skip_step_one = False + + if not skip_step_one: + gnn_llm_hallucin_sum = 0 + pure_llm_hallucin_sum = 0 + gnn_save_list = [] + untuned_llm_save_list = [] + gnn_llm_preds = [] + for out in gnn_llm_eval_outs: + gnn_llm_preds += out['pred'] + if skip_pretrained_LLM: + print("Checking GNN+LLM for hallucinations...") + else: + print( + "Checking pretrained LLM vs trained GNN+LLM for hallucinations..." # noqa + ) + for i, batch in enumerate(tqdm(loader)): + question = batch.question + correct_answer = batch.label + + gnn_llm_pred = gnn_llm_preds[i * eval_batch_size:(i + 1) * + eval_batch_size] + gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred, + correct_answer) + gnn_save_list += [ + tup for tup in zip(gnn_llm_pred, gnn_llm_hallucinates) + ] + + if not skip_pretrained_LLM: + # GNN+LLM only using 32 tokens to answer. + # Allow more output tokens for untrained LLM + pure_llm_pred = pure_llm.inference(batch.question, batch.desc, + max_tokens=256) + pure_llm_hallucinates = detect_hallucinate( + pure_llm_pred, correct_answer) + else: + pure_llm_pred = [''] * len(gnn_llm_hallucinates) + pure_llm_hallucinates = [False] * len(gnn_llm_hallucinates) + untuned_llm_save_list += [ + tup for tup in zip(pure_llm_pred, pure_llm_hallucinates) + ] + + for gnn_llm_hal, pure_llm_hal in zip(gnn_llm_hallucinates, + pure_llm_hallucinates): + if gnn_llm_hal == "skip" or pure_llm_hal == "skip": # noqa + # skipping when hallucination is hard to eval + continue + gnn_llm_hallucin_sum += int(gnn_llm_hal) + pure_llm_hallucin_sum += int(pure_llm_hal) + if not skip_pretrained_LLM: + print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum) + print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum) + percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2) + print(f"GNN reduces pretrained LLM hallucinations by: ~{percent}%") + print("Note: hallucinations detected by regex hence the ~") + print("Now we see how the LLM compares when finetuned...") + print("Saving outputs of GNN+LLM and pretrained LLM...") + save_dict = { + "gnn_save_list": gnn_save_list, + "untuned_llm_save_list": untuned_llm_save_list, + "gnn_llm_hallucin_sum": gnn_llm_hallucin_sum, + "pure_llm_hallucin_sum": pure_llm_hallucin_sum + } + torch.save(save_dict, "demo_save_dict.pt") + print("Done!") + else: + save_dict = torch.load("demo_save_dict.pt") + gnn_save_list = save_dict["gnn_save_list"] + untuned_llm_save_list = save_dict["untuned_llm_save_list"] + gnn_llm_hallucin_sum = save_dict["gnn_llm_hallucin_sum"] + pure_llm_hallucin_sum = save_dict["pure_llm_hallucin_sum"] + + trained_llm_hallucin_sum = 0 + untuned_llm_hallucin_sum = pure_llm_hallucin_sum + final_prnt_str = "" + if path.exists("llm.pt") and path.exists("llm_eval_outs.pt"): + print("Existing finetuned LLM found.") + print("Would you like to retrain?") + user_input = str(input("(y/n):")).lower() + retrain = user_input == "y" + else: + retrain = True + if retrain: + print("Finetuning LLM...") + since = time.time() + _, _, pure_llm_eval_outputs = train(since, epochs, None, None, + batch_size, eval_batch_size, lr, + loss_fn, inference_fn, + model=pure_llm, dataset=dataset) + e2e_time = round(time.time() - since, 2) + print("E2E time (e2e_time) =", e2e_time, "seconds") + else: + pure_llm_eval_outputs = torch.load("llm_eval_outs.pt") + pure_llm_preds = [] + for out in pure_llm_eval_outputs: + pure_llm_preds += out['pred'] + print("Final comparison between all models...") + for i, batch in enumerate(tqdm(loader)): + question = batch.question + correct_answer = batch.label + gnn_llm_pred, gnn_llm_hallucinates = list( + zip(*gnn_save_list[i * eval_batch_size:(i + 1) * eval_batch_size])) + untuned_llm_pred, untuned_llm_hallucinates = list( + zip(*untuned_llm_save_list[i * eval_batch_size:(i + 1) * + eval_batch_size])) + pure_llm_pred = pure_llm_preds[i * eval_batch_size:(i + 1) * + eval_batch_size] + pure_llm_hallucinates = detect_hallucinate(pure_llm_pred, + correct_answer) + for j in range(len(gnn_llm_pred)): + if skip_pretrained_LLM: + # we did not check the untrained LLM, so do not decide to demo + # based on this. + # HACK + untuned_llm_hallucinates = {j: True} + if gnn_llm_hallucinates[j] == "skip" or untuned_llm_hallucinates[ + j] == "skip" or pure_llm_hallucinates[j] == "skip": + continue + trained_llm_hallucin_sum += int(pure_llm_hallucinates[j]) + if untuned_llm_hallucinates[j] and pure_llm_hallucinates[ + j] and not gnn_llm_hallucinates[j]: # noqa + final_prnt_str += "Prompt: '" + question[j] + "'\n" + final_prnt_str += "Label: '" + correct_answer[j] + "'\n" + if not skip_pretrained_LLM: + final_prnt_str += "Untuned LLM Output: '" \ + + untuned_llm_pred[j] + "'\n" # noqa + final_prnt_str += "Tuned LLM Output: '" + pure_llm_pred[ + j] + "'\n" + final_prnt_str += "GNN+LLM Output: '" + gnn_llm_pred[j] + "'\n" + final_prnt_str += "\n" + "#" * 20 + "\n\n" + if not skip_pretrained_LLM: + print("Total untuned LLM Hallucinations:", untuned_llm_hallucin_sum) + print("Total tuned LLM Hallucinations:", trained_llm_hallucin_sum) + print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum) + if not skip_pretrained_LLM: + percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / untuned_llm_hallucin_sum), 2) + print(f"GNN reduces untuned LLM hallucinations by: ~{percent}%") + tuned_percent = 100.0 * round( + 1 - (gnn_llm_hallucin_sum / trained_llm_hallucin_sum), 2) + print(f"GNN reduces tuned LLM hallucinations by: ~{tuned_percent}%") + print("Note: hallucinations detected by regex hence the ~") + print("Potential instances where GNN solves the hallucinations of LLM:") + print(final_prnt_str) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--gnn_hidden_channels', type=int, default=1024) + parser.add_argument('--num_gnn_layers', type=int, default=4) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--epochs', type=int, default=2) + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--eval_batch_size', type=int, default=16) + parser.add_argument('--checkpointing', action='store_true') + parser.add_argument('--tiny_llama', action='store_true') + parser.add_argument( + "--skip_pretrained_llm_eval", action="store_true", + help="This flag will skip the evaluation of the pretrained LLM.") + args = parser.parse_args() + + start_time = time.time() + train( + args.epochs, + args.gnn_hidden_channels, + args.num_gnn_layers, + args.batch_size, + args.eval_batch_size, + args.lr, + checkpointing=args.checkpointing, + tiny_llama=args.tiny_llama, + ) + print(f"Total Time: {time.time() - start_time:2f}s") diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py new file mode 100644 index 000000000000..0f1c0e1b87ec --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_backend_utils.py @@ -0,0 +1,224 @@ +from dataclasses import dataclass +from enum import Enum, auto +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Protocol, + Tuple, + Type, + runtime_checkable, +) + +import torch +from torch import Tensor +from torch.nn import Module + +from torch_geometric.data import ( + FeatureStore, + GraphStore, + LargeGraphIndexer, + TripletLike, +) +from torch_geometric.data.large_graph_indexer import EDGE_RELATION +from torch_geometric.distributed import ( + LocalFeatureStore, + LocalGraphStore, + Partitioner, +) +from torch_geometric.typing import EdgeType, NodeType + +RemoteGraphBackend = Tuple[FeatureStore, GraphStore] + +# TODO: Make everything compatible with Hetero graphs aswell + + +# Adapted from LocalGraphStore +@runtime_checkable +class ConvertableGraphStore(Protocol): + @classmethod + def from_data( + cls, + edge_id: Tensor, + edge_index: Tensor, + num_nodes: int, + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_hetero_data( + cls, + edge_id_dict: Dict[EdgeType, Tensor], + edge_index_dict: Dict[EdgeType, Tensor], + num_nodes_dict: Dict[NodeType, int], + is_sorted: bool = False, + ) -> GraphStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> GraphStore: + ... + + +# Adapted from LocalFeatureStore +@runtime_checkable +class ConvertableFeatureStore(Protocol): + @classmethod + def from_data( + cls, + node_id: Tensor, + x: Optional[Tensor] = None, + y: Optional[Tensor] = None, + edge_id: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_hetero_data( + cls, + node_id_dict: Dict[NodeType, Tensor], + x_dict: Optional[Dict[NodeType, Tensor]] = None, + y_dict: Optional[Dict[NodeType, Tensor]] = None, + edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, + edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, + ) -> FeatureStore: + ... + + @classmethod + def from_partition(cls, root: str, pid: int) -> FeatureStore: + ... + + +class RemoteDataType(Enum): + DATA = auto() + PARTITION = auto() + + +@dataclass +class RemoteGraphBackendLoader: + """Utility class to load triplets into a RAG Backend.""" + path: str + datatype: RemoteDataType + graph_store_type: Type[ConvertableGraphStore] + feature_store_type: Type[ConvertableFeatureStore] + + def load(self, pid: Optional[int] = None) -> RemoteGraphBackend: + if self.datatype == RemoteDataType.DATA: + data_obj = torch.load(self.path) + graph_store = self.graph_store_type.from_data( + edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index, + num_nodes=data_obj.num_nodes) + feature_store = self.feature_store_type.from_data( + node_id=data_obj['node_id'], x=data_obj.x, + edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr) + elif self.datatype == RemoteDataType.PARTITION: + if pid is None: + assert pid is not None, \ + "Partition ID must be defined for loading from a " \ + + "partitioned store." + graph_store = self.graph_store_type.from_partition(self.path, pid) + feature_store = self.feature_store_type.from_partition( + self.path, pid) + else: + raise NotImplementedError + return (feature_store, graph_store) + + +# TODO: make profilable +def create_remote_backend_from_triplets( + triplets: Iterable[TripletLike], node_embedding_model: Module, + edge_embedding_model: Module | None = None, + graph_db: Type[ConvertableGraphStore] = LocalGraphStore, + feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore, + node_method_to_call: str = "forward", + edge_method_to_call: str | None = None, + pre_transform: Callable[[TripletLike], TripletLike] | None = None, + path: str = '', n_parts: int = 1, + node_method_kwargs: Optional[Dict[str, Any]] = None, + edge_method_kwargs: Optional[Dict[str, Any]] = None +) -> RemoteGraphBackendLoader: + """Utility function that can be used to create a RAG Backend from triplets. + + Args: + triplets (Iterable[TripletLike]): Triplets to load into the RAG + Backend. + node_embedding_model (Module): Model to embed nodes into a feature + space. + edge_embedding_model (Module | None, optional): Model to embed edges + into a feature space. Defaults to the node model. + graph_db (Type[ConvertableGraphStore], optional): GraphStore class to + use. Defaults to LocalGraphStore. + feature_db (Type[ConvertableFeatureStore], optional): FeatureStore + class to use. Defaults to LocalFeatureStore. + node_method_to_call (str, optional): method to call for embeddings on + the node model. Defaults to "forward". + edge_method_to_call (str | None, optional): method to call for + embeddings on the edge model. Defaults to the node method. + pre_transform (Callable[[TripletLike], TripletLike] | None, optional): + optional preprocessing function for triplets. Defaults to None. + path (str, optional): path to save resulting stores. Defaults to ''. + n_parts (int, optional): Number of partitons to store in. + Defaults to 1. + node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into node encoding method. Defaults to None. + edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass + into edge encoding method. Defaults to None. + + Returns: + RemoteGraphBackendLoader: Loader to load RAG backend from disk or + memory. + """ + # Will return attribute errors for missing attributes + if not issubclass(graph_db, ConvertableGraphStore): + getattr(graph_db, "from_data") + getattr(graph_db, "from_hetero_data") + getattr(graph_db, "from_partition") + elif not issubclass(feature_db, ConvertableFeatureStore): + getattr(feature_db, "from_data") + getattr(feature_db, "from_hetero_data") + getattr(feature_db, "from_partition") + + # Resolve callable methods + node_method_kwargs = node_method_kwargs \ + if node_method_kwargs is not None else dict() + + edge_embedding_model = edge_embedding_model \ + if edge_embedding_model is not None else node_embedding_model + edge_method_to_call = edge_method_to_call \ + if edge_method_to_call is not None else node_method_to_call + edge_method_kwargs = edge_method_kwargs \ + if edge_method_kwargs is not None else node_method_kwargs + + # These will return AttributeErrors if they don't exist + node_model = getattr(node_embedding_model, node_method_to_call) + edge_model = getattr(edge_embedding_model, edge_method_to_call) + + indexer = LargeGraphIndexer.from_triplets(triplets, + pre_transform=pre_transform) + + node_feats = node_model(indexer.get_node_features(), **node_method_kwargs) + indexer.add_node_feature('x', node_feats) + + edge_feats = edge_model( + indexer.get_unique_edge_features(feature_name=EDGE_RELATION), + **edge_method_kwargs) + indexer.add_edge_feature(new_feature_name="edge_attr", + new_feature_vals=edge_feats, + map_from_feature=EDGE_RELATION) + + data = indexer.to_data(node_feature_name='x', + edge_feature_name='edge_attr') + + if n_parts == 1: + torch.save(data, path) + return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db, + feature_db) + else: + partitioner = Partitioner(data=data, num_parts=n_parts, root=path) + partitioner.generate_partition() + return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, + graph_db, feature_db) diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py new file mode 100644 index 000000000000..e01e9e59bb88 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_feature_store.py @@ -0,0 +1,189 @@ +import gc +from collections.abc import Iterable, Iterator +from typing import Any, Dict, Optional, Type, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torchmetrics.functional import pairwise_cosine_similarity + +from torch_geometric.data import Data, HeteroData +from torch_geometric.distributed import LocalFeatureStore +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.nn.pool import ApproxMIPSKNNIndex +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +# NOTE: Only compatible with Homogeneous graphs for now +class KNNRAGFeatureStore(LocalFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + self.enc_model = enc_model(*args, **kwargs).to(self.device) + self.enc_model.eval() + self.model_kwargs = \ + model_kwargs if model_kwargs is not None else dict() + super().__init__() + + @property + def x(self) -> Tensor: + return self.get_tensor(group_name=None, attr_name='x') + + @property + def edge_attr(self) -> Tensor: + return self.get_tensor(group_name=(None, None), attr_name='edge_attr') + + def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes: + result = next(self._retrieve_seed_nodes_batch([query], k_nodes)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device)) + topk = min(k_nodes, len(self.x)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges: + result = next(self._retrieve_seed_edges_batch([query], k_edges)) + gc.collect() + torch.cuda.empty_cache() + return result + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + query_enc = self.enc_model.encode(query, + **self.model_kwargs).to(self.device) + + prizes = pairwise_cosine_similarity(query_enc, + self.edge_attr.to(self.device)) + topk = min(k_edges, len(self.edge_attr)) + for q in prizes: + _, indices = torch.topk(q, topk, largest=True) + yield indices + + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + + if isinstance(sample, HeteroSamplerOutput): + raise NotImplementedError + + # NOTE: torch_geometric.loader.utils.filter_custom_store can be used + # here if it supported edge features + node_id = sample.node + edge_id = sample.edge + edge_index = torch.stack((sample.row, sample.col), dim=0) + x = self.x[node_id] + edge_attr = self.edge_attr[edge_id] + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + node_idx=node_id, edge_idx=edge_id) + + +# TODO: Refactor because composition >> inheritance + + +def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor, + device: torch.device, batch_size: int = 2**20): + """Add new features to the existing KNN index in batches. + + Args: + knn_index (ApproxMIPSKNNIndex): Index to add features to. + emb (Tensor): Embeddings to add. + device (torch.device): Device to store in + batch_size (int, optional): Batch size to iterate by. + Defaults to 2**20, which equates to 4GB if working with + 1024 dim floats. + """ + for i in range(0, emb.size(0), batch_size): + if emb.size(0) - i >= batch_size: + emb_batch = emb[i:i + batch_size].to(device) + else: + emb_batch = emb[i:].to(device) + knn_index.add(emb_batch) + + +class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore): + def __init__(self, enc_model: Type[Module], + model_kwargs: Optional[Dict[str, + Any]] = None, *args, **kwargs): + # TODO: Add kwargs for approx KNN to parameters here. + super().__init__(enc_model, model_kwargs, *args, **kwargs) + self.node_knn_index = None + self.edge_knn_index = None + + def _retrieve_seed_nodes_batch(self, query: Iterable[Any], + k_nodes: int) -> Iterator[InputNodes]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.node_knn_index is None: + self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.node_knn_index, self.x, + self.device) + + output = self.node_knn_index.search(query_enc, k=k_nodes) + yield from output.index + + def _retrieve_seed_edges_batch(self, query: Iterable[Any], + k_edges: int) -> Iterator[InputEdges]: + if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): + raise NotImplementedError + + enc_model = self.enc_model.to(self.device) + query_enc = enc_model.encode(query, + **self.model_kwargs).to(self.device) + del enc_model + gc.collect() + torch.cuda.empty_cache() + + if self.edge_knn_index is None: + self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100, + num_cells_to_visit=100, + bits_per_vector=4) + # Need to add in batches to avoid OOM + _add_features_to_knn_index(self.edge_knn_index, self.edge_attr, + self.device) + + output = self.edge_knn_index.search(query_enc, k=k_edges) + yield from output.index + + +# TODO: These two classes should be refactored +class SentenceTransformerFeatureStore(KNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) + + +class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore): + def __init__(self, *args, **kwargs): + kwargs['model_name'] = kwargs.get( + 'model_name', 'sentence-transformers/all-roberta-large-v1') + super().__init__(SentenceTransformer, *args, **kwargs) diff --git a/examples/llm/g_retriever_utils/rag_generate.py b/examples/llm/g_retriever_utils/rag_generate.py new file mode 100644 index 000000000000..896fbd7598b1 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_generate.py @@ -0,0 +1,139 @@ +# %% +import argparse +from itertools import chain +from typing import Tuple + +import pandas as pd +import torch +import tqdm +from rag_backend_utils import create_remote_backend_from_triplets +from rag_feature_store import SentenceTransformerFeatureStore +from rag_graph_store import NeighborSamplingRAGGraphStore + +from torch_geometric.data import Data +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +# %% +parser = argparse.ArgumentParser( + description="""Generate new WebQSP subgraphs\n""" + + """NOTE: Evaluating with smaller samples may result in""" + + """ poorer performance for the trained models compared""" + + """ to untrained models.""") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--use_pcst", action="store_true") +parser.add_argument("--num_samples", type=int, default=4700) +parser.add_argument("--out_file", default="subg_results.pt") +args = parser.parse_args() + +# %% +ds = WebQSPDataset("dataset", limit=args.num_samples, verbose=True, + force_reload=True) + +# %% +triplets = chain.from_iterable(d['graph'] for d in ds.raw_dataset) + +# %% +questions = ds.raw_dataset['question'] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +def apply_retrieval_with_text(graph: Data, query: str) -> Tuple[Data, str]: + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + desc = ( + textual_nodes.to_csv(index=False) + "\n" + + textual_edges.to_csv(index=False, columns=["src", "edge_attr", "dst"])) + graph["desc"] = desc + return graph + + +transform = apply_retrieval_via_pcst \ + if args.use_pcst else apply_retrieval_with_text + +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 5}, + seed_edges_kwargs={"k_edges": 5}, + sampler_kwargs={"num_neighbors": [50] * 2}, + local_filter=transform) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% +retrieval_stats = {"precision": [], "recall": [], "accuracy": []} +subgs = [] +node_len = [] +edge_len = [] +for subg in tqdm.tqdm(query_loader.query(q) for q in questions): + subgs.append(subg) + node_len.append(subg['x'].shape[0]) + edge_len.append(subg['edge_attr'].shape[0]) + +for i, subg in enumerate(subgs): + subg['question'] = questions[i] + subg['label'] = ds[i]['label'] + +pd.DataFrame.from_dict(retrieval_stats).to_csv( + args.out_file.split('.')[0] + '_metadata.csv') +torch.save(subgs, args.out_file) diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py new file mode 100644 index 000000000000..48473f287233 --- /dev/null +++ b/examples/llm/g_retriever_utils/rag_graph_store.py @@ -0,0 +1,107 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + +from torch_geometric.data import FeatureStore +from torch_geometric.distributed import LocalGraphStore +from torch_geometric.sampler import ( + HeteroSamplerOutput, + NeighborSampler, + NodeSamplerInput, + SamplerOutput, +) +from torch_geometric.sampler.neighbor_sampler import NumNeighborsType +from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes + + +class NeighborSamplingRAGGraphStore(LocalGraphStore): + def __init__(self, feature_store: Optional[FeatureStore] = None, + num_neighbors: NumNeighborsType = [1], **kwargs): + self.feature_store = feature_store + self._num_neighbors = num_neighbors + self.sample_kwargs = kwargs + self._sampler_is_initialized = False + super().__init__() + + def _init_sampler(self): + if self.feature_store is None: + raise AttributeError("Feature store not registered yet.") + self.sampler = NeighborSampler(data=(self.feature_store, self), + num_neighbors=self._num_neighbors, + **self.sample_kwargs) + self._sampler_is_initialized = True + + def register_feature_store(self, feature_store: FeatureStore): + self.feature_store = feature_store + self._sampler_is_initialized = False + + def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool: + ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs) + self._sampler_is_initialized = False + return ret + + @property + def edge_index(self): + return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs) + + def put_edge_index(self, edge_index: EdgeTensorType, *args, + **kwargs) -> bool: + ret = super().put_edge_index(edge_index, *args, **kwargs) + # HACK + self.edge_idx_args = args + self.edge_idx_kwargs = kwargs + self._sampler_is_initialized = False + return ret + + @property + def num_neighbors(self): + return self._num_neighbors + + @num_neighbors.setter + def num_neighbors(self, num_neighbors: NumNeighborsType): + self._num_neighbors = num_neighbors + if hasattr(self, 'sampler'): + self.sampler.num_neighbors = num_neighbors + + def sample_subgraph( + self, seed_nodes: InputNodes, seed_edges: InputEdges, + num_neighbors: Optional[NumNeighborsType] = None + ) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample the graph starting from the given nodes and edges using the + in-built NeighborSampler. + + Args: + seed_nodes (InputNodes): Seed nodes to start sampling from. + seed_edges (InputEdges): Seed edges to start sampling from. + num_neighbors (Optional[NumNeighborsType], optional): Parameters + to determine how many hops and number of neighbors per hop. + Defaults to None. + + Returns: + Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput + for the input. + """ + if not self._sampler_is_initialized: + self._init_sampler() + if num_neighbors is not None: + self.num_neighbors = num_neighbors + + # FIXME: Right now, only input nodes/edges as tensors are be supported + if not isinstance(seed_nodes, Tensor): + raise NotImplementedError + if not isinstance(seed_edges, Tensor): + raise NotImplementedError + device = seed_nodes.device + + # TODO: Call sample_from_edges for seed_edges + # Turning them into nodes for now. + seed_edges = self.edge_index.to(device).T[seed_edges.to( + device)].reshape(-1) + seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0) + + seed_nodes = seed_nodes.unique().contiguous() + node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) + out = self.sampler.sample_from_nodes(node_sample_input) + + return out diff --git a/examples/llm/git_mol.py b/examples/llm/git_mol.py new file mode 100644 index 000000000000..d05104db050c --- /dev/null +++ b/examples/llm/git_mol.py @@ -0,0 +1,133 @@ +"""This example implements the GIT-Mol model +(https://arxiv.org/abs/2308.06911) using PyG. +""" +import argparse +import os.path as osp + +import torch +from accelerate import Accelerator +from torch.optim.lr_scheduler import StepLR +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.datasets import GitMolDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn.models import GITMol + + +@torch.no_grad() +def eval(model, data_loader): + model.eval() + loss = 0 + + for batch in data_loader: + batch_loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.image, + batch.caption) + loss += batch_loss.item() / len(data_loader) + return loss + + +def train( + num_epochs: int, + lr: float, + weight_decay: float, + batch_size: int, + checkpointing: bool, +): + # Load dataset ================================================ + path = osp.dirname(osp.realpath(__file__)) + path = osp.join(path, '..', '..', 'data', 'GITMol') + train_dataset = GitMolDataset(path, split=0) + val_dataset = GitMolDataset(path, split=1) + test_dataset = GitMolDataset(path, split=2) + + seed_everything(42) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Create model =============================================== + accelerator = Accelerator() + device = accelerator.device + model = GITMol().to(device) + optimizer = torch.optim.AdamW( + [p for p in model.parameters() if p.requires_grad], lr=lr, + weight_decay=weight_decay) + scheduler = StepLR(optimizer, step_size=1, gamma=0.1) + model, optimizer, train_loader, scheduler = accelerator.prepare( + model, optimizer, train_loader, scheduler) + val_loader = accelerator.prepare_data_loader(val_loader, + device_placement=True) + test_loader = accelerator.prepare_data_loader(test_loader, + device_placement=True) + + # Train and eval ============================================ + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + # Train + model.train() + epoch_loss = 0 + if epoch == 0: + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + + for batch in tqdm(train_loader, desc=epoch_str): + optimizer.zero_grad() + loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.image, + batch.caption) + accelerator.backward(loss) + + optimizer.step() + epoch_loss += loss.item() + + train_loss = epoch_loss / len(train_loader) + + # Eval + val_loss = eval(model, val_loader) + print( + f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 + ) + + if checkpointing and val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + torch.save( + { + 'model_state_dict': + accelerator.unwrap_model(model).state_dict(), + 'best_loss': + best_val_loss + }, + f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + # Test + test_loss = eval(model, test_loader) + print(f'Test loss: {test_loss:4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument("--weight_decay", type=float, default=0.01) + parser.add_argument('--checkpointing', type=bool, default=True) + args = parser.parse_args() + + train( + args.epochs, + args.lr, + args.weight_decay, + args.batch_size, + args.checkpointing, + ) diff --git a/examples/llm/glem.py b/examples/llm/glem.py new file mode 100644 index 000000000000..ec76cef4c010 --- /dev/null +++ b/examples/llm/glem.py @@ -0,0 +1,443 @@ +"""This example run GLEM model using PyG. +Original Paper: https://arxiv.org/abs/2210.14709 +“Learning on Large-scale Text-attributed Graphs via Variational Inference“. +Requirements on top of basic PyG: +`pip install ogb transformers peft tqdm`. +GLEM is a data augmentation co-training strategy for LM and GNN, our +implementation extended original implementation from LM to LLM and opt for LoRA +from peft. + +``note:: + use addtional trick, please add your external prediction by assigning + `ext_pred_path` and combine it into pretraining phase and node features +""" + +import argparse +import os +import os.path as osp +import time + +import torch +from ogb.nodeproppred import Evaluator, PygNodePropPredDataset + +from torch_geometric import seed_everything +from torch_geometric.data import download_google_url +from torch_geometric.datasets import TAGDataset +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GAT, GCN, GLEM, GraphSAGE + + +def get_n_params(model): + pp = 0 + for p in list(model.parameters()): + nn = 1 + for s in list(p.size()): + nn = nn * s + pp += nn + return pp + + +def main(args): + gpu = args.gpu + dataset_name = args.dataset + root = osp.join('data', 'ogb') + hf_model = args.hf_model + pl_ratio = args.pl_ratio + gnn_lr = args.gnn_lr + lm_lr = args.lm_lr + em_order = args.em_order + gnn_epochs = args.gnn_epochs + lm_epochs = args.lm_epochs + patience = args.patience + verbose = args.verbose + out_dir = args.out_dir + lm_batch_size = args.lm_batch_size + gnn_batch_size = args.gnn_batch_size + lm_use_lora = args.lm_use_lora + token_on_disk = args.token_on_disk + num_em_iters = args.num_em_iters + start_time = time.time() + train_without_ext_pred = args.train_without_ext_pred + ext_pred = None + pretrain_augmented = False + ext_pseudo_labels = None + device = torch.device( + f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu') + print(f'Running on: {torch.cuda.get_device_name({gpu})}') + torch.cuda.empty_cache() + + if not train_without_ext_pred: + ext_pred_path = download_google_url( + id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY', + folder='data/ogb/ogbn_products/ext_preds', + filename='giant_sagn_scr.pt', log=True) + ext_pred = torch.load(ext_pred_path, map_location=device) + ext_pseudo_labels = ext_pred.argmax(dim=-1) + pretrain_augmented = True + + seed_everything(42) + + dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) + split_idx = dataset.get_idx_split() + data = dataset.data + + tag_dataset = TAGDataset(root, dataset, hf_model, + token_on_disk=token_on_disk) + text_dataset = tag_dataset.to_text_dataset() + print(tag_dataset.num_classes, tag_dataset.raw_file_names) + + num_classes = tag_dataset.num_classes + num_features = data.num_features + # =========================== LM Data split =============================== + split_idx = tag_dataset.get_idx_split() + + # GLEM train with augmented data, mark original train data as gold data, + gold_idx = split_idx['train'] + split_idx['valid'] + test_idx = split_idx['test'] + + # randome sample pseudo labels nodes, generate their index + num_pseudo_labels = int(gold_idx.numel() * pl_ratio) + idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels] + pseudo_labels_idx = test_idx[idx_to_select] + train_idx = torch.cat( + (gold_idx, pseudo_labels_idx)) # augmented train_indx + + print(f'train_idx: {train_idx.size(0)}, ' + f'gold_idx: {gold_idx.size(0)}, ' + f'pseudo labels ratio: {pl_ratio}, ' + f'{train_idx.size(0)/gold_idx.size(0) - 1.0}') + gold_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=gold_idx) + train_dataset = torch.utils.data.Subset(dataset=text_dataset, + indices=train_idx) + # ========================== LM Data Loader =============================== + + print('Building language model dataloader...', end='-->') + + # if set train_without_ext_pred == True, use this for pretrain + text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + # training with augmented data, + text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, + shuffle=True) + text_test_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4, + drop_last=False, pin_memory=True, + shuffle=False) + print('done') + + # =========================== GNN Data Loader ============================= + initial_memory = torch.cuda.memory_allocated() + data = data.to(device) + if ext_pred is not None: + data.x = torch.cat((data.x, ext_pred), dim=1) + num_features += ext_pred.size(1) + current_memory_1 = torch.cuda.max_memory_allocated() + # 1 GB = 1073741824 Byte + gpu_usage = float(current_memory_1 - initial_memory) / 1073741824 + # Print the maximum memory usage after running the model + print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB') + + print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->') + + # train on gold data w/o pseudo labels + graph_pretrain_loader = NeighborLoader( + data, + input_nodes=gold_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # graph data loader w/ pseudo labels in M-step + graph_train_loader = NeighborLoader( + data, + input_nodes=train_idx, + num_neighbors=[15, 10, 5], + batch_size=gnn_batch_size, + shuffle=True, + num_workers=12, + persistent_workers=True, + ) + + # for gnn inference + subgraph_loader = NeighborLoader( + data, + input_nodes=None, + num_neighbors=[-1], + batch_size=gnn_batch_size * 4, + num_workers=12, + persistent_workers=True, + ) + # =========================== internal function =========================== + + evaluator = Evaluator(name=f'ogbn-{dataset_name}') + + def evaluate(out, split): + y_true = data.y.cpu() + y_pred = out.argmax(dim=-1, keepdim=True) + train_acc, val_acc, test_acc = None, None, None + if 'train' in split: + train_acc = evaluator.eval({ + 'y_true': y_true[split_idx['train']], + 'y_pred': y_pred[split_idx['train']], + })['acc'] + if 'valid' in split: + val_acc = evaluator.eval({ + 'y_true': y_true[split_idx['valid']], + 'y_pred': y_pred[split_idx['valid']], + })['acc'] + if 'test' in split: + test_acc = evaluator.eval({ + 'y_true': y_true[split_idx['test']], + 'y_pred': y_pred[split_idx['test']], + })['acc'] + + return train_acc, val_acc, test_acc + + # =========================== Build GNN Model ============================= + gnn = None + if args.gnn_model == 'SAGE': + gnn = GraphSAGE( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + elif args.gnn_model == 'GAT': + gnn = GAT(in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, heads=args.gat_heads) + else: + gnn = GCN( + in_channels=num_features, + hidden_channels=args.gnn_hidden_channels, + num_layers=args.gnn_num_layers, + out_channels=dataset.num_classes, + ) + + print("# GNN Params:", get_n_params(gnn)) + # =========================== Build LM Model ============================== + + model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes, + lm_use_lora=lm_use_lora, device=device) + lm = model.lm + print("# LM Params:", get_n_params(lm)) + gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr) + lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr) + + def load_model(em_phase): + print(f'Move {em_phase} model from cpu memory') + if em_phase == 'lm': + model.lm = model.lm.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr) + if em_phase == 'gnn': + model.gnn = model.gnn.to(device, non_blocking=True) + optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr) + return optimizer + + # ================================= Run GLEM ============================== + preds_filename = 'lm_pretrain' + preds_dir = f'{out_dir}preds/{dataset_name}/' + gnn_test_acc = 0.0 + lm_test_acc = 0.0 + # =============================== GLEM pretraining ======================== + pretrain_phase = 'lm' + if em_order == 'lm': + pretrain_phase = 'gnn' + pretrain_start_time = time.time() + # pretraining + pretrain_loader = graph_pretrain_loader + test_loader = subgraph_loader + pretrain_num_epochs = gnn_epochs + pretrain_opt = gnn_opt + if pretrain_phase == 'gnn': + model.gnn = model.gnn.to(device) + print('pretraining gnn to generate pseudo labels') + if not train_without_ext_pred: + pretrain_loader = graph_train_loader + preds_filename = 'gnn_pretrain' + elif pretrain_phase == 'lm': + model.lm = model.lm.to(device) + print('pretraining lm to generate pseudo labels') + pretrain_num_epochs = lm_epochs + pretrain_loader = text_pretrain_loader + test_loader = text_test_loader + pretrain_opt = lm_opt + if not train_without_ext_pred: + pretrain_loader = text_train_loader + preds_filename = 'lm_pretrain' + + early_stopping = 0 + best_val_acc = 0.0 + for epoch in range(1, pretrain_num_epochs + 1): + acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt, + ext_pseudo_labels, epoch, pretrain_augmented, + verbose) + if epoch >= 5 or epoch == pretrain_num_epochs: + pretrain_preds = model.inference(pretrain_phase, test_loader, + verbose=verbose) + train_acc, val_acc, _ = evaluate(pretrain_preds, + ['train', 'valid']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Pretrain Early stopped by Epoch: {epoch}') + break + else: + best_val_acc = val_acc + preds = model.inference(pretrain_phase, test_loader, verbose=verbose) + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + if pretrain_phase == 'gnn': + gnn_test_acc = max(gnn_test_acc, test_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + else: + lm_test_acc = max(lm_test_acc, test_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + torch.cuda.empty_cache() + + pretrain_phase_time = time.time() - pretrain_start_time + print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s') + os.makedirs(osp.dirname(preds_dir), exist_ok=True) + torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt')) + print( + f'Saved predictions to {osp.join(preds_dir, f"{preds_filename}.pt")}') + train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test']) + print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, ' + f'Test: {test_acc:.4f}') + + # EM iterations + + em_phase = em_order + """ + We run E-step(LM training) and M-Step(GNN training) alternatively in each + em iterations, so the total number of iterations is num_em_iter * 2 and + we switch the em_phase at end of each iteration in following loop + """ + gnn_val_acc = lm_val_acc = 0.0 + for em_it in range(1, num_em_iters * 2 + 1): + pseudo_labels = preds.argmax(dim=-1) + best_val_acc = 0.0 + print(f'EM iteration: {em_it}, EM phase: {em_phase}') + optimizer = load_model(em_phase) + num_epochs = lm_epochs + train_loader = text_train_loader + test_loader = text_test_loader + early_stopping = 0 + if em_phase == 'gnn': + train_loader = graph_train_loader + num_epochs = gnn_epochs + test_loader = subgraph_loader + for epoch in range(1, num_epochs + 1): + acc, loss = model.train(em_phase, train_loader, optimizer, + pseudo_labels, epoch, True, verbose) + if epoch >= 5 or epoch == num_epochs: + cur_preds = model.inference(em_phase, test_loader, + verbose=verbose) + train_acc, val_acc, _ = evaluate(cur_preds, ['train', 'valid']) + + print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f},') + + if val_acc <= best_val_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'''Early stopped by Epoch: {epoch}, \ + Best acc: {best_val_acc}''') + break + else: + best_val_acc = val_acc + + preds = model.inference(em_phase, test_loader, verbose=verbose) + if em_phase == 'gnn': + gnn_val_acc = max(gnn_val_acc, best_val_acc) + model.gnn = model.gnn.to('cpu', non_blocking=True) + em_phase = 'lm' + else: + lm_val_acc = max(lm_val_acc, best_val_acc) + model.lm = model.lm.to('cpu', non_blocking=True) + em_phase = 'gnn' + torch.cuda.empty_cache() + print(f'Best GNN validation acc: {gnn_val_acc},' + f'LM validation acc: {lm_val_acc}') + print('============================') + if gnn_val_acc > lm_val_acc: + em_phase = 'gnn' + model.gnn = model.gnn.to(device, non_blocking=True) + else: + em_phase = 'lm' + model.lm = model.lm.to(device, non_blocking=True) + test_preds = model.inference(em_phase, test_loader, verbose=verbose) + train_acc, val_acc, test_acc = evaluate(test_preds, + ['train', 'valid', 'test']) + final_test_acc = max(gnn_test_acc, max(lm_test_acc, test_acc)) + print(f'Best test acc: {final_test_acc}, model: {em_phase}') + end_time = time.time() + running_time = (end_time - start_time) / 3600 + print(f'Total running time: {running_time:.2f} hours') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='GLEM Example:') + parser.add_argument('--gpu', type=int, default=0) + parser.add_argument('--num_runs', type=int, default=10, + help='number of runs') + parser.add_argument('--num_em_iters', type=int, default=1, + help='number of iterations') + parser.add_argument("--dataset", type=str, default='products', + help='arxiv or products') + parser.add_argument("--pl_ratio", type=float, default=0.5, + help="pseudo labels ratio") + parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', + help='huggingface model repo id') + parser.add_argument( + '--gnn_model', type=str, default='SAGE', + help='gnn model for node classification,' + 'options: SAGE, GAT, GCN') + parser.add_argument('--gnn_hidden_channels', type=int, default=256) + parser.add_argument('--gnn_num_layers', type=int, default=3) + parser.add_argument('--gat_heads', type=int, default=4, + help='Number of multi-head-attentions for GAT ') + parser.add_argument('--lm_batch_size', type=int, default=256) + parser.add_argument('--gnn_batch_size', type=int, default=1024) + parser.add_argument( + '--external_pred_path', type=str, default=None, + help="Other model's output logits during the " + "pretraining phase or simply concatenate it with" + "node features as augmented data for gnn") + parser.add_argument('--alpha', type=float, default=0.5, + help='pseudo label weight in E-step') + parser.add_argument('--beta', type=float, default=0.5, + help='pseudo label weight in M-step') + parser.add_argument('--lm_epochs', type=int, default=10) + parser.add_argument('--gnn_epochs', type=int, default=50) + parser.add_argument('--gnn_lr', type=float, default=0.002) + parser.add_argument('--lm_lr', type=float, default=0.001) + parser.add_argument('--patience', type=int, default=3, + help='Patience for early stopping') + parser.add_argument('--verbose', action='store_true', + help='show progress bar during training or not') + parser.add_argument('--em_order', type=str, default='lm', + help='decide train LM first or GNN first') + parser.add_argument('--lm_use_lora', action='store_true', + help='use Lora to fine-tune model or not') + parser.add_argument( + '--token_on_disk', action='store_true', + help='save token on disk and load token from disk' + 'for reducing duplicated tokenizing') + parser.add_argument('--out_dir', type=str, default='output/', + help='output directory') + parser.add_argument( + '--train_without_ext_pred', action='store_true', + help='train glem without using additional pseudo labels ' + 'for augmenting data only available for ogbn-products') + args = parser.parse_args() + print(args) + main(args) diff --git a/examples/llm/molecule_gpt.py b/examples/llm/molecule_gpt.py new file mode 100644 index 000000000000..8f6c6024014d --- /dev/null +++ b/examples/llm/molecule_gpt.py @@ -0,0 +1,193 @@ +"""This example implements the MoleculeGPT model +(https://ai4d3.github.io/papers/34.pdf) using PyG. +""" +import argparse +import math +import os.path as osp +import time + +import torch +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm + +from torch_geometric import seed_everything +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.loader import DataLoader +from torch_geometric.nn import GINEConv +from torch_geometric.nn.models import MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer + + +def save_params_dict(model, save_path): + state_dict = model.state_dict() + param_grad_dict = { + k: v.requires_grad + for (k, v) in model.named_parameters() + } + for k in list(state_dict.keys()): + if k in param_grad_dict.keys() and not param_grad_dict[k]: + del state_dict[k] # Delete parameters that do not require gradient + torch.save(state_dict, save_path) + + +@torch.no_grad() +def eval(model, data_loader): + model.eval() + loss = 0 + + for batch in data_loader: + batch_loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss += batch_loss.item() / len(data_loader) + return loss + + +def train( + num_epochs: int, + lr: float, + batch_size: int, + checkpointing: bool, +): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group['lr'] = lr + return lr + + start_time = time.time() + # Load dataset ================================================ + path = osp.dirname(osp.realpath(__file__)) + path = osp.join(path, '..', '..', 'data', 'MoleculeGPT') + dataset = MoleculeGPTDataset(path) + train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset)) + train_dataset = dataset[:train_size] + val_dataset = dataset[train_size:train_size + val_size] + test_dataset = dataset[train_size + val_size:] + + seed_everything(42) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Create model =============================================== + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv( + nn=torch.nn.Sequential( + torch.nn.Linear(6, 768), + torch.nn.ReLU(), + torch.nn.Linear(768, 768), + ), + train_eps=True, + edge_dim=4, + ) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + # Train and eval ============================================ + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': 0.05, + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + + best_epoch = 0 + best_val_loss = float('inf') + for epoch in range(num_epochs): + # Train + model.train() + epoch_loss = 0 + if epoch == 0: + print(f"Total Preparation Time: {time.time() - start_time:2f}s") + start_time = time.time() + print("Training beginning...") + epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' + loader = tqdm(train_loader, desc=epoch_str) + + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = model(batch.x, batch.edge_index, batch.batch, + batch.edge_attr, batch.smiles, batch.instruction, + batch.y) + loss.backward() + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss += loss.item() + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]['lr'] + train_loss = epoch_loss / len(train_loader) + + # Eval + val_loss = eval(model, val_loader) + print( + f'{epoch_str}, Train loss: {train_loss:4f}, Val loss: {val_loss:4f}' # noqa: E501 + ) + + if checkpointing and val_loss < best_val_loss: + best_val_loss = val_loss + best_epoch = epoch + save_params_dict( + model, + f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:4f}_ckpt.pt' # noqa: E501 + ) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + print(f"Total Training Time: {time.time() - start_time:2f}s") + # Test + test_loss = eval(model, test_loader) + print(f'Test loss: {test_loss:4f}') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--batch_size', type=int, default=2) + parser.add_argument('--checkpointing', type=bool, default=True) + args = parser.parse_args() + + start_time = time.time() + train( + args.epochs, + args.lr, + args.batch_size, + args.checkpointing, + ) + print(f'Total Time: {time.time() - start_time:2f}s') diff --git a/examples/llm/multihop_rag/README.md b/examples/llm/multihop_rag/README.md new file mode 100644 index 000000000000..ff43b16a2c05 --- /dev/null +++ b/examples/llm/multihop_rag/README.md @@ -0,0 +1,9 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | +| [`multihop_download.sh`](./multihop_download.sh) | Downloads all the components of the multihop dataset. | +| [`multihop_preprocess.py`](./multihop_preprocess.py) | Preprocesses the dataset to pair questions/answers with components in the knowledge graph. Contains documentation to describe the process. | +| [`rag_generate_multihop.py`](./rag_generate_multihop.py) | Utilizes the sample remote backend in [`g_retriever_utils`](../g_retriever_utils/) to generate subgraphs for the multihop dataset. | + +NOTE: Performance of GRetriever on this dataset has not been evaluated. diff --git a/examples/llm/multihop_rag/multihop_download.sh b/examples/llm/multihop_rag/multihop_download.sh new file mode 100644 index 000000000000..3c1970d39440 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_download.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +# Wikidata5m + +wget -O "wikidata5m_alias.tar.gz" "https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz" +tar -xvf "wikidata5m_alias.tar.gz" +wget -O "wikidata5m_all_triplet.txt.gz" "https://www.dropbox.com/s/563omb11cxaqr83/wikidata5m_all_triplet.txt.gz" +gzip -d "wikidata5m_all_triplet.txt.gz" -f + +# 2Multihopqa +wget -O "data_ids_april7.zip" "https://www.dropbox.com/s/ms2m13252h6xubs/data_ids_april7.zip" +unzip -o "data_ids_april7.zip" diff --git a/examples/llm/multihop_rag/multihop_preprocess.py b/examples/llm/multihop_rag/multihop_preprocess.py new file mode 100644 index 000000000000..46052bdf1b15 --- /dev/null +++ b/examples/llm/multihop_rag/multihop_preprocess.py @@ -0,0 +1,276 @@ +"""Example workflow for downloading and assembling a multihop QA dataset.""" + +import argparse +import json +from subprocess import call + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import LargeGraphIndexer + +# %% [markdown] +# # Encoding A Large Knowledge Graph Part 2 + +# %% [markdown] +# In this notebook, we will continue where we left off by building a new +# multi-hop QA dataset based on Wikidata. + +# %% [markdown] +# ## Example 2: Building a new Dataset from Questions and an already-existing +# Knowledge Graph + +# %% [markdown] +# ### Motivation + +# %% [markdown] +# One potential application of knowledge graph structural encodings is +# capturing the relationships between different entities that are multiple +# hops apart. This can be challenging for an LLM to recognize from prepended +# graph information. Here's a motivating example (credit to @Rishi Puri): + +# %% [markdown] +# In this example, the question can only be answered by reasoning about the +# relationships between the entities in the knowledge graph. + +# %% [markdown] +# ### Building a Multi-Hop QA Dataset + +# %% [markdown] +# To start, we need to download the raw data of a knowledge graph. +# In this case, we use WikiData5M +# ([Wang et al] +# (https://paperswithcode.com/paper/kepler-a-unified-model-for-knowledge)). +# Here we download the raw triplets and their entity codes. Information about +# this dataset can be found +# [here](https://deepgraphlearning.github.io/project/wikidata5m). + +# %% [markdown] +# The following download contains the ID to plaintext mapping for all the +# entities and relations in the knowledge graph: + +rv = call("./multihop_download.sh") + +# %% [markdown] +# To start, we are going to preprocess the knowledge graph to substitute each +# of the entity/relation codes with their plaintext aliases. This makes it +# easier to use a pre-trained textual encoding model to create triplet +# embeddings, as such a model likely won't understand how to properly embed +# the entity codes. + +# %% + +# %% +parser = argparse.ArgumentParser(description="Preprocess wikidata5m") +parser.add_argument("--n_triplets", type=int, default=-1) +args = parser.parse_args() + +# %% +# Substitute entity codes with their aliases +# Picking the first alias for each entity (rather arbitrarily) +alias_map = {} +rel_alias_map = {} +for line in open('wikidata5m_entity.txt'): + parts = line.strip().split('\t') + entity_id = parts[0] + aliases = parts[1:] + alias_map[entity_id] = aliases[0] +for line in open('wikidata5m_relation.txt'): + parts = line.strip().split('\t') + relation_id = parts[0] + relation_name = parts[1] + rel_alias_map[relation_id] = relation_name + +# %% +full_graph = [] +missing_total = 0 +total = 0 +limit = None if args.n_triplets == -1 else args.n_triplets +i = 0 + +for line in tqdm.tqdm(open('wikidata5m_all_triplet.txt')): + if limit is not None and i >= limit: + break + src, rel, dst = line.strip().split('\t') + if src not in alias_map: + missing_total += 1 + if dst not in alias_map: + missing_total += 1 + if rel not in rel_alias_map: + missing_total += 1 + total += 3 + full_graph.append([ + alias_map.get(src, src), + rel_alias_map.get(rel, rel), + alias_map.get(dst, dst) + ]) + i += 1 +print(f"Missing aliases: {missing_total}/{total}") + +# %% [markdown] +# Now `full_graph` represents the knowledge graph triplets in +# understandable plaintext. + +# %% [markdown] +# Next, we need a set of multi-hop questions that the Knowledge Graph will +# provide us with context for. We utilize a subset of +# [HotPotQA](https://hotpotqa.github.io/) +# ([Yang et. al.](https://arxiv.org/pdf/1809.09600)) called +# [2WikiMultiHopQA](https://github.com/Alab-NII/2wikimultihop) +# ([Ho et. al.](https://aclanthology.org/2020.coling-main.580.pdf)), +# which includes a subgraph of entities that serve as the ground truth +# justification for answering each multi-hop question: + +# %% +with open('train.json') as f: + train_data = json.load(f) +train_df = pd.DataFrame(train_data) +train_df['split_type'] = 'train' + +with open('dev.json') as f: + dev_data = json.load(f) +dev_df = pd.DataFrame(dev_data) +dev_df['split_type'] = 'dev' + +with open('test.json') as f: + test_data = json.load(f) +test_df = pd.DataFrame(test_data) +test_df['split_type'] = 'test' + +df = pd.concat([train_df, dev_df, test_df]) + +# %% [markdown] +# Now we need to extract the subgraphs + +# %% +df['graph_size'] = df['evidences_id'].apply(lambda row: len(row)) + +# %% [markdown] +# (Optional) We take only questions where the evidence graph is greater than +# 0. (Note: this gets rid of the test set): + +# %% +# df = df[df['graph_size'] > 0] + +# %% +refined_df = df[[ + '_id', 'question', 'answer', 'split_type', 'evidences_id', 'type', + 'graph_size' +]] + +# %% [markdown] +# Checkpoint: + +# %% +refined_df.to_csv('wikimultihopqa_refined.csv', index=False) + +# %% [markdown] +# Now we need to check that all the entities mentioned in the question/answer +# set are also present in the Wikidata graph: + +# %% +relation_map = {} +with open('wikidata5m_relation.txt') as f: + for line in tqdm.tqdm(f): + parts = line.strip().split('\t') + for i in range(1, len(parts)): + if parts[i] not in relation_map: + relation_map[parts[i]] = [] + relation_map[parts[i]].append(parts[0]) + +# %% +entity_set = set() +with open('wikidata5m_entity.txt') as f: + for line in tqdm.tqdm(f): + entity_set.add(line.strip().split('\t')[0]) + +# %% +missing_entities = set() +missing_entity_idx = set() +for i, row in enumerate(refined_df.itertuples()): + for trip in row.evidences_id: + entities = trip[0], trip[2] + for entity in entities: + if entity not in entity_set: + # print( + # f'The following entity was not found in the KG: {entity}' + # ) + missing_entities.add(entity) + missing_entity_idx.add(i) + +# %% [markdown] +# Right now, we drop the missing entity entries. Additional preprocessing can +# be done here to resolve the entity/relation collisions, but that is out of +# the scope for this notebook. + +# %% +# missing relations are ok, but missing entities cannot be mapped to +# plaintext, so they should be dropped. +refined_df.reset_index(inplace=True, drop=True) + +# %% +cleaned_df = refined_df.drop(missing_entity_idx) + +# %% [markdown] +# Now we save the resulting graph and questions/answers dataset: + +# %% +cleaned_df.to_csv('wikimultihopqa_cleaned.csv', index=False) + +# %% + +# %% +torch.save(full_graph, 'wikimultihopqa_full_graph.pt') + +# %% [markdown] +# ### Question: How do we extract a contextual subgraph for a given query? + +# %% [markdown] +# The chosen retrieval algorithm is a critical component in the pipeline for +# affecting RAG performance. In the next section (1), we will demonstrate a +# naive method of retrieval for a large knowledge graph, and how to apply it +# to this dataset along with WebQSP. + +# %% [markdown] +# ### Preparing a Textualized Graph for LLM + +# %% [markdown] +# For now however, we need to prepare the graph data to be used as a plaintext +# prefix to the LLM. In order to do this, we want to prompt the LLM to use the +# unique nodes, and unique edge triplets of a given subgraph. In order to do +# this, we prepare a unique indexed node df and edge df for the knowledge +# graph now. This process occurs trivially with the LargeGraphIndexer: + +# %% + +# %% +indexer = LargeGraphIndexer.from_triplets(full_graph) + +# %% +# Node DF +textual_nodes = pd.DataFrame.from_dict( + {"node_attr": indexer.get_node_features()}) +textual_nodes["node_id"] = textual_nodes.index +textual_nodes = textual_nodes[["node_id", "node_attr"]] + +# %% [markdown] +# Notice how LargeGraphIndexer ensures that there are no duplicate indices: + +# %% +# Edge DF +textual_edges = pd.DataFrame(indexer.get_edge_features(), + columns=["src", "edge_attr", "dst"]) +textual_edges["src"] = [indexer._nodes[h] for h in textual_edges["src"]] +textual_edges["dst"] = [indexer._nodes[h] for h in textual_edges["dst"]] + +# %% [markdown] +# Note: The edge table refers to each node by its index in the node table. +# We will see how this gets utilized later when indexing a subgraph. + +# %% [markdown] +# Now we can save the result + +# %% +textual_nodes.to_csv('wikimultihopqa_textual_nodes.csv', index=False) +textual_edges.to_csv('wikimultihopqa_textual_edges.csv', index=False) diff --git a/examples/llm/multihop_rag/rag_generate_multihop.py b/examples/llm/multihop_rag/rag_generate_multihop.py new file mode 100644 index 000000000000..de93a9e75dd1 --- /dev/null +++ b/examples/llm/multihop_rag/rag_generate_multihop.py @@ -0,0 +1,88 @@ +# %% +import argparse +import sys +from typing import Tuple + +import pandas as pd +import torch +import tqdm + +from torch_geometric.data import Data +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import RAGQueryLoader +from torch_geometric.nn.nlp import SentenceTransformer + +sys.path.append('..') + +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerApproxFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +parser = argparse.ArgumentParser( + description="Generate new multihop dataset for rag") +# TODO: Add more arguments for configuring rag params +parser.add_argument("--num_samples", type=int) +args = parser.parse_args() + +# %% +triplets = torch.load('wikimultihopqa_full_graph.pt') + +# %% +df = pd.read_csv('wikimultihopqa_cleaned.csv') +questions = df['question'][:args.num_samples] +labels = df['answer'][:args.num_samples] + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer( + model_name='sentence-transformers/all-roberta-large-v1').to(device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerApproxFeatureStore).load() + +# %% + +all_textual_nodes = pd.read_csv('wikimultihopqa_textual_nodes.csv') +all_textual_edges = pd.read_csv('wikimultihopqa_textual_edges.csv') + + +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = all_textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = all_textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return out_graph + + +# %% +query_loader = RAGQueryLoader(data=(fs, gs), seed_nodes_kwargs={"k_nodes": 10}, + seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": [40] * 3}, + local_filter=apply_retrieval_via_pcst) + +# %% +subgs = [] +for q, l in tqdm.tqdm(zip(questions, labels)): + subg = query_loader.query(q) + subg['question'] = q + subg['label'] = l + subgs.append(subg) + +torch.save(subgs, 'subg_results.pt') diff --git a/examples/llm/nvtx_examples/README.md b/examples/llm/nvtx_examples/README.md new file mode 100644 index 000000000000..aa4f070d9824 --- /dev/null +++ b/examples/llm/nvtx_examples/README.md @@ -0,0 +1,7 @@ +# Examples for LLM and GNN co-training + +| Example | Description | +| -------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------- | +| [`nvtx_run.sh`](./nvtx_run.sh) | Runs nsys profiler on a given Python file that contains NVTX calls. | +| [`nvtx_rag_backend_example.py`](./nvtx_rag_backend_example.py) | Example script for nsys profiling a RAG Backend such as that used in [`rag_generate.py`](../g_retriever_utils/rag_generate.py). | +| [`nvtx_webqsp_example.py`](./nvtx_webqsp_example.py) | Example script for nsys profiling the WebQSP dataset. | diff --git a/examples/llm/nvtx_examples/nvtx_rag_backend_example.py b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py new file mode 100644 index 000000000000..b30e34b8c7b1 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_rag_backend_example.py @@ -0,0 +1,144 @@ +# %% +import argparse +import sys +from itertools import chain +from typing import Tuple + +import torch + +from torch_geometric.data import Data, get_features_for_triplets_groups +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.datasets.web_qsp_dataset import ( + preprocess_triplet, + retrieval_via_pcst, +) +from torch_geometric.loader import rag_loader +from torch_geometric.nn.nlp import SentenceTransformer +from torch_geometric.profile.nvtx import nvtxit + +sys.path.append('..') +from g_retriever_utils.rag_backend_utils import \ + create_remote_backend_from_triplets # noqa: E402 +from g_retriever_utils.rag_feature_store import \ + SentenceTransformerFeatureStore # noqa: E402 +from g_retriever_utils.rag_graph_store import \ + NeighborSamplingRAGGraphStore # noqa: E402 + +# %% +# Patch FeatureStore and GraphStore + +SentenceTransformerFeatureStore.retrieve_seed_nodes = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_nodes) +SentenceTransformerFeatureStore.retrieve_seed_edges = nvtxit()( + SentenceTransformerFeatureStore.retrieve_seed_edges) +SentenceTransformerFeatureStore.load_subgraph = nvtxit()( + SentenceTransformerFeatureStore.load_subgraph) +NeighborSamplingRAGGraphStore.sample_subgraph = nvtxit()( + NeighborSamplingRAGGraphStore.sample_subgraph) +rag_loader.RAGQueryLoader.query = nvtxit()(rag_loader.RAGQueryLoader.query) + +# %% +ds = WebQSPDataset("small_ds_1", force_reload=True, limit=10) + +# %% +triplets = list(chain.from_iterable(d['graph'] for d in ds.raw_dataset)) + +# %% +questions = ds.raw_dataset['question'] + +# %% +ground_truth_graphs = get_features_for_triplets_groups( + ds.indexer, (d['graph'] for d in ds.raw_dataset), + pre_transform=preprocess_triplet) +num_edges = len(ds.indexer._edges) + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = SentenceTransformer('sentence-transformers/all-roberta-large-v1').to( + device) + +# %% +fs, gs = create_remote_backend_from_triplets( + triplets=triplets, node_embedding_model=model, + node_method_to_call="encode", path="backend", + pre_transform=preprocess_triplet, node_method_kwargs={ + "batch_size": 256 + }, graph_db=NeighborSamplingRAGGraphStore, + feature_db=SentenceTransformerFeatureStore).load() + +# %% + + +@nvtxit() +def apply_retrieval_via_pcst(graph: Data, query: str, topk: int = 3, + topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + q_emb = model.encode(query) + textual_nodes = ds.textual_nodes.iloc[graph["node_idx"]].reset_index() + textual_edges = ds.textual_edges.iloc[graph["edge_idx"]].reset_index() + out_graph, desc = retrieval_via_pcst(graph, q_emb, textual_nodes, + textual_edges, topk, topk_e, cost_e) + out_graph["desc"] = desc + return graph + + +# %% +query_loader = rag_loader.RAGQueryLoader( + data=(fs, gs), seed_nodes_kwargs={"k_nodes": + 10}, seed_edges_kwargs={"k_edges": 10}, + sampler_kwargs={"num_neighbors": + [40] * 10}, local_filter=apply_retrieval_via_pcst) + + +# %% +# Accuracy Metrics to be added to Profiler +def _eidx_helper(subg: Data, ground_truth: Data): + subg_eidx, gt_eidx = subg.edge_idx, ground_truth.edge_idx + if isinstance(subg_eidx, torch.Tensor): + subg_eidx = subg_eidx.tolist() + if isinstance(gt_eidx, torch.Tensor): + gt_eidx = gt_eidx.tolist() + subg_e = set(subg_eidx) + gt_e = set(gt_eidx) + return subg_e, gt_e + + +def check_retrieval_accuracy(subg: Data, ground_truth: Data, num_edges: int): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + total_e = set(range(num_edges)) + tp = len(subg_e & gt_e) + tn = len(total_e - (subg_e | gt_e)) + return (tp + tn) / num_edges + + +def check_retrieval_precision(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(subg_e) + + +def check_retrieval_recall(subg: Data, ground_truth: Data): + subg_e, gt_e = _eidx_helper(subg, ground_truth) + return len(subg_e & gt_e) / len(gt_e) + + +# %% + + +@nvtxit() +def _run_eval(): + for subg, gt in zip((query_loader.query(q) for q in questions), + ground_truth_graphs): + print(check_retrieval_accuracy(subg, gt, num_edges), + check_retrieval_precision(subg, gt), + check_retrieval_recall(subg, gt)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + _run_eval() + else: + _run_eval() diff --git a/examples/llm/nvtx_examples/nvtx_run.sh b/examples/llm/nvtx_examples/nvtx_run.sh new file mode 100644 index 000000000000..4c6fce7c8224 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_run.sh @@ -0,0 +1,27 @@ +#!/bin/sh + +# Check if the user provided a Python file +if [ -z "$1" ]; then + echo "Usage: $0 " + exit 1 +fi + +# Check if the provided file exists +if [[ ! -f "$1" ]]; then + echo "Error: File '$1' does not exist." + exit 1 +fi + +# Check if the provided file is a Python file +if [[ ! "$1" == *.py ]]; then + echo "Error: '$1' is not a Python file." + exit 1 +fi + +# Get the base name of the Python file +python_file=$(basename "$1") + +# Run nsys profile on the Python file +nsys profile -c cudaProfilerApi --capture-range-end repeat -t cuda,nvtx,osrt,cudnn,cublas --cuda-memory-usage true --cudabacktrace all --force-overwrite true --output=profile_${python_file%.py} python "$1" + +echo "Profile data saved as profile_${python_file%.py}.nsys-rep" diff --git a/examples/llm/nvtx_examples/nvtx_webqsp_example.py b/examples/llm/nvtx_examples/nvtx_webqsp_example.py new file mode 100644 index 000000000000..5a9aad27f1c0 --- /dev/null +++ b/examples/llm/nvtx_examples/nvtx_webqsp_example.py @@ -0,0 +1,22 @@ +import argparse + +import torch + +from torch_geometric.datasets import web_qsp_dataset +from torch_geometric.profile import nvtxit + +# Apply Patches +web_qsp_dataset.retrieval_via_pcst = nvtxit()( + web_qsp_dataset.retrieval_via_pcst) +web_qsp_dataset.WebQSPDataset.process = nvtxit()( + web_qsp_dataset.WebQSPDataset.process) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--capture-torch-kernels", "-k", action="store_true") + args = parser.parse_args() + if args.capture_torch_kernels: + with torch.autograd.profiler.emit_nvtx(): + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') + else: + ds = web_qsp_dataset.WebQSPDataset('baseline', split='val') diff --git a/examples/multi_gpu/papers100m_gcn_cugraph.py b/examples/multi_gpu/papers100m_gcn_cugraph.py index 5413492a5bc5..799b6317c374 100644 --- a/examples/multi_gpu/papers100m_gcn_cugraph.py +++ b/examples/multi_gpu/papers100m_gcn_cugraph.py @@ -86,8 +86,8 @@ def run(rank, data, world_size, cugraph_id, model, epochs, batch_size, fan_out, )] = ixr feature_store = TensorDictFeatureStore() - feature_store['node', 'x'] = data.x - feature_store['node', 'y'] = data.y + feature_store['node', 'x', None] = data.x + feature_store['node', 'y', None] = data.y dist.barrier() diff --git a/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py b/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py index 4ea78eb64ad6..eb074defeafe 100644 --- a/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py +++ b/examples/multi_gpu/papers100m_gcn_cugraph_multinode.py @@ -142,9 +142,9 @@ def load_partitioned_data(rank, edge_path, feature_path, label_path, meta_path, split_idx[split] = fs.torch_load(path) path = osp.join(feature_path, f'rank={rank}_x.pt') - feature_store['node', 'x'] = fs.torch_load(path) + feature_store['node', 'x', None] = fs.torch_load(path) path = osp.join(feature_path, f'rank={rank}_y.pt') - feature_store['node', 'y'] = fs.torch_load(path) + feature_store['node', 'y', None] = fs.torch_load(path) eix = fs.torch_load(osp.join(edge_path, f'rank={rank}.pt')) graph_store[dict( diff --git a/examples/ogbn_papers_100m_cugraph.py b/examples/ogbn_papers_100m_cugraph.py index 7c1da866056a..8ae35cd776a4 100644 --- a/examples/ogbn_papers_100m_cugraph.py +++ b/examples/ogbn_papers_100m_cugraph.py @@ -63,8 +63,8 @@ )] = data.edge_index feature_store = cugraph_pyg.data.TensorDictFeatureStore() -feature_store['node', 'x'] = data.x -feature_store['node', 'y'] = data.y +feature_store['node', 'x', None] = data.x +feature_store['node', 'y', None] = data.y data = (feature_store, graph_store) diff --git a/test/data/test_large_graph_indexer.py b/test/data/test_large_graph_indexer.py new file mode 100644 index 000000000000..b98fe7d7ddbf --- /dev/null +++ b/test/data/test_large_graph_indexer.py @@ -0,0 +1,177 @@ +import random +import string +from typing import List + +import pytest +import torch + +from torch_geometric.data import ( + Data, + LargeGraphIndexer, + TripletLike, + get_features_for_triplets, +) +from torch_geometric.data.large_graph_indexer import ( + EDGE_PID, + EDGE_RELATION, + NODE_PID, +) +from torch_geometric.typing import WITH_PT20 + +# create possible nodes and edges for graph +strkeys = string.ascii_letters + string.digits +NODE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(1000)}) +EDGE_POOL = list({"".join(random.sample(strkeys, 10)) for i in range(50)}) + + +def featurize(s: str) -> int: + return int.from_bytes(s.encode(), 'little') + + +def sample_triplets(amount: int = 1) -> List[TripletLike]: + trips = [] + for i in range(amount): + h, t = random.sample(NODE_POOL, k=2) + r = random.sample(EDGE_POOL, k=1)[0] + trips.append(tuple([h, r, t])) + return trips + + +def preprocess_triplet(triplet: TripletLike) -> TripletLike: + h, r, t = triplet + return h.lower(), r, t.lower() + + +def test_basic_collate(): + graphs = [sample_triplets(1000) for i in range(2)] + + indexer_0 = LargeGraphIndexer.from_triplets( + graphs[0], pre_transform=preprocess_triplet) + indexer_1 = LargeGraphIndexer.from_triplets( + graphs[1], pre_transform=preprocess_triplet) + + big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1]) + + assert len(indexer_0._nodes) + len( + indexer_1._nodes) - len(indexer_0._nodes.keys() + & indexer_1._nodes.keys()) == len( + big_indexer._nodes) + assert len(indexer_0._edges) + len( + indexer_1._edges) - len(indexer_0._edges.keys() + & indexer_1._edges.keys()) == len( + big_indexer._edges) + + assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes) + assert len(set(big_indexer._edges.values())) == len(big_indexer._edges) + + for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()): + assert big_indexer.node_attr[NODE_PID][ + big_indexer._nodes[node]] == node + + +def test_large_graph_index(): + graphs = [sample_triplets(1000) for i in range(100)] + + # Preprocessing of trips lowercases nodes but not edges + node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + def encode_graph_from_trips(triplets: List[TripletLike]) -> Data: + seen_nodes = dict() + edge_attrs = list() + edge_idx = [] + for trip in triplets: + trip = preprocess_triplet(trip) + h, r, t = trip + seen_nodes[h] = len( + seen_nodes) if h not in seen_nodes else seen_nodes[h] + seen_nodes[t] = len( + seen_nodes) if t not in seen_nodes else seen_nodes[t] + edge_attrs.append(edge_feature_vecs[r]) + edge_idx.append((seen_nodes[h], seen_nodes[t])) + + x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()]) + edge_idx = torch.LongTensor(edge_idx).T + edge_attrs = torch.Tensor(edge_attrs) + return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs) + + naive_graph_ds = [ + encode_graph_from_trips(triplets=trips) for trips in graphs + ] + + indexer = LargeGraphIndexer.collate([ + LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet) + for g in graphs + ]) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + large_graph_ds = [ + get_features_for_triplets(indexer=indexer, triplets=g, + node_feature_name='x', + edge_feature_name='edge_attr', + pre_transform=preprocess_triplet) + for g in graphs + ] + + for ds in large_graph_ds: + assert NODE_PID in ds + assert EDGE_PID in ds + assert "node_idx" in ds + assert "edge_idx" in ds + + def results_are_close_enough(ground_truth: Data, new_method: Data, + thresh=.99): + def _sorted_tensors_are_close(tensor1, tensor2): + return torch.all( + torch.isclose(tensor1.sort()[0], + tensor2.sort()[0]) > thresh) + + def _graphs_are_same(tensor1, tensor2): + if not WITH_PT20: + pytest.skip( + "This test requires a PyG version with NetworkX as a " + + "dependency.") + import networkx as nx + return nx.weisfeiler_lehman_graph_hash(nx.Graph( + tensor1.T)) == nx.weisfeiler_lehman_graph_hash( + nx.Graph(tensor2.T)) + return True + return _sorted_tensors_are_close( + ground_truth.x, new_method.x) \ + and _sorted_tensors_are_close( + ground_truth.edge_attr, new_method.edge_attr) \ + and _graphs_are_same( + ground_truth.edge_index, new_method.edge_index) + + for dsets in zip(naive_graph_ds, large_graph_ds): + assert results_are_close_enough(*dsets) + + +def test_save_load(tmp_path): + graph = sample_triplets(1000) + + node_feature_vecs = {s: featurize(s) for s in NODE_POOL} + edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL} + + indexer = LargeGraphIndexer.from_triplets(graph) + indexer_nodes = indexer.get_unique_node_features() + indexer_node_vals = torch.Tensor( + [node_feature_vecs[n] for n in indexer_nodes]) + indexer_edges = indexer.get_unique_edge_features( + feature_name=EDGE_RELATION) + indexer_edge_vals = torch.Tensor( + [edge_feature_vecs[e] for e in indexer_edges]) + indexer.add_node_feature('x', indexer_node_vals) + indexer.add_edge_feature('edge_attr', indexer_edge_vals, + map_from_feature=EDGE_RELATION) + + indexer.save(str(tmp_path)) + assert indexer == LargeGraphIndexer.from_disk(str(tmp_path)) diff --git a/test/datasets/test_git_mol_dataset.py b/test/datasets/test_git_mol_dataset.py new file mode 100644 index 000000000000..f4e652b6ae43 --- /dev/null +++ b/test/datasets/test_git_mol_dataset.py @@ -0,0 +1,22 @@ +from typing import Tuple + +import pytest + +from torch_geometric.datasets import GitMolDataset +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('torchvision', 'rdkit', 'PIL') +@pytest.mark.parametrize('split', [ + (0, 3610), + (1, 451), + (2, 451), +]) +def test_git_mol_dataset(split: Tuple[int, int]) -> None: + dataset = GitMolDataset(root='./data/GITMol', split=split[0]) + + assert len(dataset) == split[1] + assert dataset[0].image.size() == (1, 3, 224, 224) + assert dataset[0].num_node_features == 9 + assert dataset[0].num_edge_features == 3 diff --git a/test/datasets/test_molecule_gpt_dataset.py b/test/datasets/test_molecule_gpt_dataset.py new file mode 100644 index 000000000000..7c00c5efc1b6 --- /dev/null +++ b/test/datasets/test_molecule_gpt_dataset.py @@ -0,0 +1,10 @@ +from torch_geometric.datasets import MoleculeGPTDataset +from torch_geometric.testing import withPackage + + +@withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit') +def test_molecule_gpt_dataset(): + dataset = MoleculeGPTDataset(root='./data/MoleculeGPT') + assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})' + assert dataset.num_edge_features == 4 + assert dataset.num_node_features == 6 diff --git a/test/nn/attention/test_qformer.py b/test/nn/attention/test_qformer.py new file mode 100644 index 000000000000..0de023708fd8 --- /dev/null +++ b/test/nn/attention/test_qformer.py @@ -0,0 +1,13 @@ +import torch + +from torch_geometric.nn.attention import QFormer + + +def test_qformer(): + x = torch.randn(1, 4, 16) + attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4, + num_layers=2) + out = attn(x) + + assert out.shape == (1, 4, 32) + assert str(attn) == ('QFormer(num_heads=4, num_layers=2)') diff --git a/test/nn/models/test_g_retriever.py b/test/nn/models/test_g_retriever.py index 899e70730cc9..24a74d1b6f6e 100644 --- a/test/nn/models/test_g_retriever.py +++ b/test/nn/models/test_g_retriever.py @@ -51,3 +51,52 @@ def test_g_retriever() -> None: # Test inference: pred = model.inference(question, x, edge_index, batch, edge_attr) assert len(pred) == 1 + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_g_retriever_many_tokens() -> None: + llm = LLM( + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.float16, + ) + + gnn = GAT( + in_channels=1024, + out_channels=1024, + hidden_channels=1024, + num_layers=2, + heads=4, + norm='batch_norm', + ) + + model = GRetriever( + llm=llm, + gnn=gnn, + mlp_out_channels=2048, + mlp_out_tokens=2, + ) + assert str(model) == ('GRetriever(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' gnn=GAT(1024, 1024, num_layers=2),\n' + ')') + + x = torch.randn(10, 1024) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 1024) + batch = torch.zeros(x.size(0), dtype=torch.long) + + question = ["Is PyG the best open-source GNN library?"] + label = ["yes!"] + + # Test train: + loss = model(question, x, edge_index, batch, label, edge_attr) + assert loss >= 0 + + # Test inference: + pred = model.inference(question, x, edge_index, batch, edge_attr) + assert len(pred) == 1 diff --git a/test/nn/models/test_git_mol.py b/test/nn/models/test_git_mol.py new file mode 100644 index 000000000000..ee557bfaa9fc --- /dev/null +++ b/test/nn/models/test_git_mol.py @@ -0,0 +1,24 @@ +import torch + +from torch_geometric.nn.models import GITMol +from torch_geometric.testing import withPackage + + +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_git_mol(): + model = GITMol() + + x = torch.ones(10, 16, dtype=torch.long) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 0, 6, 7, 8, 9, 5], + ]) + edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long) + batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O'] + captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.'] + images = torch.randn(1, 3, 224, 224) + + # Test train: + loss = model(x, edge_index, batch, edge_attr, smiles, images, captions) + assert loss >= 0 diff --git a/test/nn/models/test_molecule_gpt.py b/test/nn/models/test_molecule_gpt.py new file mode 100644 index 000000000000..c9f0a53403ee --- /dev/null +++ b/test/nn/models/test_molecule_gpt.py @@ -0,0 +1,60 @@ +import torch +from torch.nn import Linear as Lin +from torch.nn import ReLU +from torch.nn import Sequential as Seq + +from torch_geometric.nn import GINEConv, MoleculeGPT +from torch_geometric.nn.nlp import LLM, SentenceTransformer +from torch_geometric.testing import onlyFullTest, withPackage + + +@onlyFullTest +@withPackage('transformers', 'sentencepiece', 'accelerate') +def test_molecule_gpt() -> None: + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + + graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)), + train_eps=True, edge_dim=16) + + smiles_encoder = SentenceTransformer( + model_name='DeepChem/ChemBERTa-77M-MTR', + pooling_strategy='last_hidden_state', + ) + + model = MoleculeGPT( + llm=llm, + graph_encoder=graph_encoder, + smiles_encoder=smiles_encoder, + ) + + assert str(model) == ( + 'MoleculeGPT(\n' + ' llm=LLM(TinyLlama/TinyLlama-1.1B-Chat-v0.1),\n' + ' graph=GINEConv,\n' + ' smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\n' # noqa: E501 + ')') + + x = torch.randn(10, 16) + edge_index = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + ]) + edge_attr = torch.randn(edge_index.size(1), 16) + batch = torch.zeros(x.size(0), dtype=torch.long) + smiles = ['CCCCCCCCCC'] + instructions = ['What is ∼ functional related to?'] + label = ['I do not know!'] + + # Test train: + loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label) + assert loss >= 0 + + # Test inference: + pred = model.inference(x, edge_index, batch, edge_attr, smiles, + instructions) + assert len(pred) == 1 diff --git a/test/nn/nlp/test_vision_transformer.py b/test/nn/nlp/test_vision_transformer.py new file mode 100644 index 000000000000..7500ebc7fd0e --- /dev/null +++ b/test/nn/nlp/test_vision_transformer.py @@ -0,0 +1,26 @@ +import torch + +from torch_geometric.nn.nlp import VisionTransformer +from torch_geometric.testing import onlyFullTest, withCUDA, withPackage + + +@withCUDA +@onlyFullTest +@withPackage('transformers') +def test_vision_transformer(device): + model = VisionTransformer( + model_name='microsoft/swin-base-patch4-window7-224', ).to(device) + assert model.device == device + assert str( + model + ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)' + + images = torch.randn(2, 3, 224, 224).to(device) + + out = model(images) + assert out.device == device + assert out.size() == (2, 49, 1024) + + out = model(images, output_device='cpu') + assert out.is_cpu + assert out.size() == (2, 49, 1024) diff --git a/test/profile/test_nvtx.py b/test/profile/test_nvtx.py new file mode 100644 index 000000000000..56e28a9c2e59 --- /dev/null +++ b/test/profile/test_nvtx.py @@ -0,0 +1,136 @@ +from unittest.mock import call, patch + +from torch_geometric.profile import nvtxit + + +def _setup_mock(torch_cuda_mock): + torch_cuda_mock.is_available.return_value = True + torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None + torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None + return torch_cuda_mock + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_base(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit() + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_rename(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit() + def call_b(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return 42 + + @nvtxit('a_nvtx') + def call_a(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_b() + + def dummy_func(): + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + return call_a() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + dummy_func() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('a_nvtx_0'), call('call_b_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_iters(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_iters=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_b_0'), call('call_a_0') + ] + + +@patch('torch_geometric.profile.nvtx.torch.cuda') +def test_nvtxit_warmups(torch_cuda_mock): + torch_cuda_mock = _setup_mock(torch_cuda_mock) + + # dummy func calls a calls b + + @nvtxit(n_warmups=1) + def call_b(): + return 42 + + @nvtxit() + def call_a(): + return call_b() + + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + + call_b() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0 # noqa: E501 + call_a() + assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1 # noqa: E501 + assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1 # noqa: E501 + + assert torch_cuda_mock.nvtx.range_push.call_args_list == [ + call('call_a_0'), call('call_b_1') + ] diff --git a/test/sampler/test_sampler_base.py b/test/sampler/test_sampler_base.py index dc8142176bf6..41a7da25534f 100644 --- a/test/sampler/test_sampler_base.py +++ b/test/sampler/test_sampler_base.py @@ -49,6 +49,9 @@ def test_heterogeneous_num_neighbors_dict_and_default(): num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1]) + with pytest.raises(ValueError, match="Not all edge types"): + num_neighbors.get_values([('A', 'C'), ('B', 'A')]) + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]} diff --git a/torch_geometric/data/__init__.py b/torch_geometric/data/__init__.py index 821ef9c5c063..fee215b1a357 100644 --- a/torch_geometric/data/__init__.py +++ b/torch_geometric/data/__init__.py @@ -16,6 +16,7 @@ from .makedirs import makedirs from .download import download_url, download_google_url from .extract import extract_tar, extract_zip, extract_bz2, extract_gz +from .large_graph_indexer import LargeGraphIndexer, TripletLike, get_features_for_triplets, get_features_for_triplets_groups from torch_geometric.lazy_loader import LazyLoader @@ -27,6 +28,8 @@ 'Dataset', 'InMemoryDataset', 'OnDiskDataset', + 'LargeGraphIndexer', + 'TripletLike', ] remote_backend_classes = [ @@ -50,6 +53,8 @@ 'extract_zip', 'extract_bz2', 'extract_gz', + 'get_features_for_triplets', + "get_features_for_triplets_groups", ] __all__ = data_classes + remote_backend_classes + helper_functions diff --git a/torch_geometric/data/dataset.py b/torch_geometric/data/dataset.py index dd5239eec7c2..9df4359f1450 100644 --- a/torch_geometric/data/dataset.py +++ b/torch_geometric/data/dataset.py @@ -383,7 +383,7 @@ def to_datapipe(self) -> Any: r"""Converts the dataset into a :class:`torch.utils.data.DataPipe`. The returned instance can then be used with :pyg:`PyG's` built-in - :class:`DataPipes` for baching graphs as follows: + :class:`DataPipes` for batching graphs as follows: .. code-block:: python diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py new file mode 100644 index 000000000000..0644e2543303 --- /dev/null +++ b/torch_geometric/data/large_graph_indexer.py @@ -0,0 +1,677 @@ +import os +import pickle as pkl +import shutil +from dataclasses import dataclass +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Hashable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import torch +from torch import Tensor +from tqdm import tqdm + +from torch_geometric.data import Data +from torch_geometric.typing import WITH_PT24 + +TripletLike = Tuple[Hashable, Hashable, Hashable] + +KnowledgeGraphLike = Iterable[TripletLike] + + +def ordered_set(values: Iterable[Hashable]) -> List[Hashable]: + return list(dict.fromkeys(values)) + + +# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum? + +NODE_PID = "pid" + +NODE_KEYS = {NODE_PID} + +EDGE_PID = "e_pid" +EDGE_HEAD = "h" +EDGE_RELATION = "r" +EDGE_TAIL = "t" +EDGE_INDEX = "edge_idx" + +EDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX} + +FeatureValueType = Union[Sequence[Any], Tensor] + + +@dataclass +class MappedFeature: + name: str + values: FeatureValueType + + def __eq__(self, value: "MappedFeature") -> bool: + eq = self.name == value.name + if isinstance(self.values, torch.Tensor): + eq &= torch.equal(self.values, value.values) + else: + eq &= self.values == value.values + return eq + + +if WITH_PT24: + torch.serialization.add_safe_globals([MappedFeature]) + + +class LargeGraphIndexer: + """For a dataset that consists of mulitiple subgraphs that are assumed to + be part of a much larger graph, collate the values into a large graph store + to save resources. + """ + def __init__( + self, + nodes: Iterable[Hashable], + edges: KnowledgeGraphLike, + node_attr: Optional[Dict[str, List[Any]]] = None, + edge_attr: Optional[Dict[str, List[Any]]] = None, + ) -> None: + r"""Constructs a new index that uniquely catalogs each node and edge + by id. Not meant to be used directly. + + Args: + nodes (Iterable[Hashable]): Node ids in the graph. + edges (KnowledgeGraphLike): Edge ids in the graph. + node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node + attribute name and list of their values in order of unique node + ids. Defaults to None. + edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge + attribute name and list of their values in order of unique edge + ids. Defaults to None. + """ + self._nodes: Dict[Hashable, int] = dict() + self._edges: Dict[TripletLike, int] = dict() + + self._mapped_node_features: Set[str] = set() + self._mapped_edge_features: Set[str] = set() + + if len(nodes) != len(set(nodes)): + raise AttributeError("Nodes need to be unique") + if len(edges) != len(set(edges)): + raise AttributeError("Edges need to be unique") + + if node_attr is not None: + # TODO: Validity checks btw nodes and node_attr + self.node_attr = node_attr + if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS: + raise AttributeError( + "Invalid node_attr object. Missing " + + f"{NODE_KEYS - set(self.node_attr.keys())}") + elif self.node_attr[NODE_PID] != nodes: + raise AttributeError( + "Nodes provided do not match those in node_attr") + else: + self.node_attr = dict() + self.node_attr[NODE_PID] = nodes + + for i, node in enumerate(self.node_attr[NODE_PID]): + self._nodes[node] = i + + if edge_attr is not None: + # TODO: Validity checks btw edges and edge_attr + self.edge_attr = edge_attr + + if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS: + raise AttributeError( + "Invalid edge_attr object. Missing " + + f"{EDGE_KEYS - set(self.edge_attr.keys())}") + elif self.node_attr[EDGE_PID] != edges: + raise AttributeError( + "Edges provided do not match those in edge_attr") + + else: + self.edge_attr = dict() + for default_key in EDGE_KEYS: + self.edge_attr[default_key] = list() + self.edge_attr[EDGE_PID] = edges + + for i, tup in enumerate(edges): + h, r, t = tup + self.edge_attr[EDGE_HEAD].append(h) + self.edge_attr[EDGE_RELATION].append(r) + self.edge_attr[EDGE_TAIL].append(t) + self.edge_attr[EDGE_INDEX].append( + (self._nodes[h], self._nodes[t])) + + for i, tup in enumerate(edges): + self._edges[tup] = i + + @classmethod + def from_triplets( + cls, + triplets: KnowledgeGraphLike, + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + ) -> "LargeGraphIndexer": + r"""Generate a new index from a series of triplets that represent edge + relations between nodes. + Formatted like (source_node, edge, dest_node). + + Args: + triplets (KnowledgeGraphLike): Series of triplets representing + knowledge graph relations. + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing function to apply to triplets. + Defaults to None. + + Returns: + LargeGraphIndexer: Index of unique nodes and edges. + """ + # NOTE: Right now assumes that all trips can be loaded into memory + nodes = set() + edges = set() + + if pre_transform is not None: + + def apply_transform( + trips: KnowledgeGraphLike) -> Iterator[TripletLike]: + for trip in trips: + yield pre_transform(trip) + + triplets = apply_transform(triplets) + + for h, r, t in triplets: + + for node in (h, t): + nodes.add(node) + + edge_idx = (h, r, t) + edges.add(edge_idx) + + return cls(list(nodes), list(edges)) + + @classmethod + def collate(cls, + graphs: Iterable["LargeGraphIndexer"]) -> "LargeGraphIndexer": + r"""Combines a series of large graph indexes into a single large graph + index. + + Args: + graphs (Iterable["LargeGraphIndexer"]): Indices to be + combined. + + Returns: + LargeGraphIndexer: Singular unique index for all nodes and edges + in input indices. + """ + # FIXME Needs to merge node attrs and edge attrs? + trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) + return cls.from_triplets(trips) + + def get_unique_node_features( + self, feature_name: str = NODE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific node attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to NODE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_node_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_node_features(feature_name)) + + except KeyError: + raise AttributeError( + f"Nodes do not have a feature called {feature_name}") + + def add_node_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = NODE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique node in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that + new feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to NODE_PID. + """ + if new_feature_name in self.node_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_node_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_node_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + "Expected encodings for {len(feature_keys)} unique features," + + f" but got {len(new_feature_vals)} encodings.") + + if map_from_feature == NODE_PID: + self.node_attr[new_feature_name] = new_feature_vals + else: + self.node_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_node_features.add(new_feature_name) + + def get_node_features( + self, + feature_name: str = NODE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get node feature values for a given set of unique node ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. Defaults + to NODE_PID. + pids (Optional[Iterable[Hashable]], optional): Node ids to fetch + for. Defaults to None, which fetches all nodes. + + Returns: + List[Any]: Node features corresponding to the specified ids. + """ + if feature_name in self._mapped_node_features: + values = self.node_attr[feature_name].values + else: + values = self.node_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_node_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_node_features_iter(feature_name, pids)) + + def get_node_features_iter( + self, + feature_name: str = NODE_PID, + pids: Optional[Iterable[Hashable]] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_node_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.node_attr[NODE_PID] + + if feature_name in self._mapped_node_features: + feature_map_info = self.node_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_node_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._nodes[pid] + from_feature_val = self.node_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._nodes[pid] + if index_only: + yield idx + else: + yield self.node_attr[feature_name][idx] + + def get_unique_edge_features( + self, feature_name: str = EDGE_PID) -> List[Hashable]: + r"""Get all the unique values for a specific edge attribute. + + Args: + feature_name (str, optional): Name of feature to get. + Defaults to EDGE_PID. + + Returns: + List[Hashable]: List of unique values for the specified feature. + """ + try: + if feature_name in self._mapped_edge_features: + raise IndexError( + "Only non-mapped features can be retrieved uniquely.") + return ordered_set(self.get_edge_features(feature_name)) + except KeyError: + raise AttributeError( + f"Edges do not have a feature called {feature_name}") + + def add_edge_feature( + self, + new_feature_name: str, + new_feature_vals: FeatureValueType, + map_from_feature: str = EDGE_PID, + ) -> None: + r"""Adds a new feature that corresponds to each unique edge in + the graph. + + Args: + new_feature_name (str): Name to call the new feature. + new_feature_vals (FeatureValueType): Values to map for that new + feature. + map_from_feature (str, optional): Key of feature to map from. + Size must match the number of feature values. + Defaults to EDGE_PID. + """ + if new_feature_name in self.edge_attr: + raise AttributeError("Features cannot be overridden once created") + if map_from_feature in self._mapped_edge_features: + raise AttributeError( + f"{map_from_feature} is already a feature mapping.") + + feature_keys = self.get_unique_edge_features(map_from_feature) + if len(feature_keys) != len(new_feature_vals): + raise AttributeError( + f"Expected encodings for {len(feature_keys)} unique features, " + + f"but got {len(new_feature_vals)} encodings.") + + if map_from_feature == EDGE_PID: + self.edge_attr[new_feature_name] = new_feature_vals + else: + self.edge_attr[new_feature_name] = MappedFeature( + name=map_from_feature, values=new_feature_vals) + self._mapped_edge_features.add(new_feature_name) + + def get_edge_features( + self, + feature_name: str = EDGE_PID, + pids: Optional[Iterable[Hashable]] = None, + ) -> List[Any]: + r"""Get edge feature values for a given set of unique edge ids. + Returned values are not necessarily unique. + + Args: + feature_name (str, optional): Name of feature to fetch. + Defaults to EDGE_PID. + pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + for. Defaults to None, which fetches all edges. + + Returns: + List[Any]: Node features corresponding to the specified ids. + """ + if feature_name in self._mapped_edge_features: + values = self.edge_attr[feature_name].values + else: + values = self.edge_attr[feature_name] + + # TODO: torch_geometric.utils.select + if isinstance(values, torch.Tensor): + idxs = list( + self.get_edge_features_iter(feature_name, pids, + index_only=True)) + return values[idxs] + return list(self.get_edge_features_iter(feature_name, pids)) + + def get_edge_features_iter( + self, + feature_name: str = EDGE_PID, + pids: Optional[KnowledgeGraphLike] = None, + index_only: bool = False, + ) -> Iterator[Any]: + """Iterator version of get_edge_features. If index_only is True, + yields indices instead of values. + """ + if pids is None: + pids = self.edge_attr[EDGE_PID] + + if feature_name in self._mapped_edge_features: + feature_map_info = self.edge_attr[feature_name] + from_feature_name, to_feature_vals = ( + feature_map_info.name, + feature_map_info.values, + ) + from_feature_vals = self.get_unique_edge_features( + from_feature_name) + feature_mapping = {k: i for i, k in enumerate(from_feature_vals)} + + for pid in pids: + idx = self._edges[pid] + from_feature_val = self.edge_attr[from_feature_name][idx] + to_feature_idx = feature_mapping[from_feature_val] + if index_only: + yield to_feature_idx + else: + yield to_feature_vals[to_feature_idx] + else: + for pid in pids: + idx = self._edges[pid] + if index_only: + yield idx + else: + yield self.edge_attr[feature_name][idx] + + def to_triplets(self) -> Iterator[TripletLike]: + return iter(self.edge_attr[EDGE_PID]) + + def save(self, path: str) -> None: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + with open(path + "/edges", "wb") as f: + pkl.dump(self._edges, f) + with open(path + "/nodes", "wb") as f: + pkl.dump(self._nodes, f) + + with open(path + "/mapped_edges", "wb") as f: + pkl.dump(self._mapped_edge_features, f) + with open(path + "/mapped_nodes", "wb") as f: + pkl.dump(self._mapped_node_features, f) + + node_attr_path = path + "/node_attr" + os.makedirs(node_attr_path, exist_ok=True) + for attr_name, vals in self.node_attr.items(): + torch.save(vals, node_attr_path + f"/{attr_name}.pt") + + edge_attr_path = path + "/edge_attr" + os.makedirs(edge_attr_path, exist_ok=True) + for attr_name, vals in self.edge_attr.items(): + torch.save(vals, edge_attr_path + f"/{attr_name}.pt") + + @classmethod + def from_disk(cls, path: str) -> "LargeGraphIndexer": + indexer = cls(list(), list()) + with open(path + "/edges", "rb") as f: + indexer._edges = pkl.load(f) + with open(path + "/nodes", "rb") as f: + indexer._nodes = pkl.load(f) + + with open(path + "/mapped_edges", "rb") as f: + indexer._mapped_edge_features = pkl.load(f) + with open(path + "/mapped_nodes", "rb") as f: + indexer._mapped_node_features = pkl.load(f) + + node_attr_path = path + "/node_attr" + for fname in os.listdir(node_attr_path): + full_fname = f"{node_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.node_attr[key] = torch.load(full_fname) + + edge_attr_path = path + "/edge_attr" + for fname in os.listdir(edge_attr_path): + full_fname = f"{edge_attr_path}/{fname}" + key = fname.split(".")[0] + indexer.edge_attr[key] = torch.load(full_fname) + + return indexer + + def to_data(self, node_feature_name: str, + edge_feature_name: Optional[str] = None) -> Data: + """Return a Data object containing all the specified node and + edge features and the graph. + + Args: + node_feature_name (str): Feature to use for nodes + edge_feature_name (Optional[str], optional): Feature to use for + edges. Defaults to None. + + Returns: + Data: Data object containing the specified node and + edge features and the graph. + """ + x = torch.Tensor(self.get_node_features(node_feature_name)) + node_id = torch.LongTensor(range(len(x))) + + edge_index = torch.t( + torch.LongTensor(self.get_edge_features(EDGE_INDEX))) + + edge_attr = (self.get_edge_features(edge_feature_name) + if edge_feature_name is not None else None) + edge_id = torch.LongTensor(range(len(edge_attr))) + + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + edge_id=edge_id, node_id=node_id) + + def __eq__(self, value: "LargeGraphIndexer") -> bool: + eq = True + eq &= self._nodes == value._nodes + eq &= self._edges == value._edges + eq &= self.node_attr.keys() == value.node_attr.keys() + eq &= self.edge_attr.keys() == value.edge_attr.keys() + eq &= self._mapped_node_features == value._mapped_node_features + eq &= self._mapped_edge_features == value._mapped_edge_features + + for k in self.node_attr: + eq &= isinstance(self.node_attr[k], type(value.node_attr[k])) + if isinstance(self.node_attr[k], torch.Tensor): + eq &= torch.equal(self.node_attr[k], value.node_attr[k]) + else: + eq &= self.node_attr[k] == value.node_attr[k] + for k in self.edge_attr: + eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k])) + if isinstance(self.edge_attr[k], torch.Tensor): + eq &= torch.equal(self.edge_attr[k], value.edge_attr[k]) + else: + eq &= self.edge_attr[k] == value.edge_attr[k] + return eq + + +def get_features_for_triplets_groups( + indexer: LargeGraphIndexer, + triplet_groups: Iterable[KnowledgeGraphLike], + node_feature_name: str = "x", + edge_feature_name: str = "edge_attr", + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + verbose: bool = False, +) -> Iterator[Data]: + """Given an indexer and a series of triplet groups (like a dataset), + retrieve the specified node and edge features for each triplet from the + index. + + Args: + indexer (LargeGraphIndexer): Indexer containing desired features + triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of + triplets to fetch features for + node_feature_name (str, optional): Node feature to fetch. + Defaults to "x". + edge_feature_name (str, optional): edge feature to fetch. + Defaults to "edge_attr". + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing to perform on triplets. + Defaults to None. + verbose (bool, optional): Whether to print progress. Defaults to False. + + Yields: + Iterator[Data]: For each triplet group, yield a data object containing + the unique graph and features from the index. + """ + if pre_transform is not None: + + def apply_transform(trips): + for trip in trips: + yield pre_transform(tuple(trip)) + + # TODO: Make this safe for large amounts of triplets? + triplet_groups = (list(apply_transform(triplets)) + for triplets in triplet_groups) + + node_keys = [] + edge_keys = [] + edge_index = [] + + for triplets in tqdm(triplet_groups, disable=not verbose): + small_graph_indexer = LargeGraphIndexer.from_triplets( + triplets, pre_transform=pre_transform) + + node_keys.append(small_graph_indexer.get_node_features()) + edge_keys.append(small_graph_indexer.get_edge_features(pids=triplets)) + edge_index.append( + small_graph_indexer.get_edge_features(EDGE_INDEX, triplets)) + + node_feats = indexer.get_node_features(feature_name=node_feature_name, + pids=chain.from_iterable(node_keys)) + edge_feats = indexer.get_edge_features(feature_name=edge_feature_name, + pids=chain.from_iterable(edge_keys)) + + last_node_idx, last_edge_idx = 0, 0 + for (nkeys, ekeys, eidx) in zip(node_keys, edge_keys, edge_index): + nlen, elen = len(nkeys), len(ekeys) + x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen]) + last_node_idx += len(nkeys) + + edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx + + elen]) + last_edge_idx += len(ekeys) + + edge_idx = torch.LongTensor(eidx).T + + data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx) + data_obj[NODE_PID] = node_keys + data_obj[EDGE_PID] = edge_keys + data_obj["node_idx"] = [indexer._nodes[k] for k in nkeys] + data_obj["edge_idx"] = [indexer._edges[e] for e in ekeys] + + yield data_obj + + +def get_features_for_triplets( + indexer: LargeGraphIndexer, + triplets: KnowledgeGraphLike, + node_feature_name: str = "x", + edge_feature_name: str = "edge_attr", + pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None, + verbose: bool = False, +) -> Data: + """For a given set of triplets retrieve a Data object containing the + unique graph and features from the index. + + Args: + indexer (LargeGraphIndexer): Indexer containing desired features + triplets (KnowledgeGraphLike): Triplets to fetch features for + node_feature_name (str, optional): Feature to use for node features. + Defaults to "x". + edge_feature_name (str, optional): Feature to use for edge features. + Defaults to "edge_attr". + pre_transform (Optional[Callable[[TripletLike], TripletLike]]): + Optional preprocessing function for triplets. Defaults to None. + verbose (bool, optional): Whether to print progress. Defaults to False. + + Returns: + Data: Data object containing the unique graph and features from the + index for the given triplets. + """ + gen = get_features_for_triplets_groups(indexer, [triplets], + node_feature_name, + edge_feature_name, pre_transform, + verbose) + return next(gen) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 96d51032d818..12895ad1dbac 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -77,6 +77,9 @@ from .brca_tgca import BrcaTcga from .neurograph import NeuroGraphDataset from .web_qsp_dataset import WebQSPDataset +from .git_mol_dataset import GitMolDataset +from .molecule_gpt_dataset import MoleculeGPTDataset +from .tag_dataset import TAGDataset from .dbp15k import DBP15K from .aminer import AMiner @@ -190,6 +193,9 @@ 'BrcaTcga', 'NeuroGraphDataset', 'WebQSPDataset', + 'GitMolDataset', + 'MoleculeGPTDataset', + 'TAGDataset', ] hetero_datasets = [ diff --git a/torch_geometric/datasets/git_mol_dataset.py b/torch_geometric/datasets/git_mol_dataset.py new file mode 100644 index 000000000000..4b7cfa78117c --- /dev/null +++ b/torch_geometric/datasets/git_mol_dataset.py @@ -0,0 +1,263 @@ +import sys +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from torch_geometric.data import ( + Data, + InMemoryDataset, + download_google_url, + extract_zip, +) +from torch_geometric.io import fs + + +def safe_index(lst: List[Any], e: int) -> int: + return lst.index(e) if e in lst else len(lst) - 1 + + +class GitMolDataset(InMemoryDataset): + r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model + for Molecular Science with Graph, Image, and Text" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + split (int, optional): Datasets split, train/valid/test=0/1/2. + (default: :obj:`0`) + """ + + raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg' + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + split: int = 0, + ): + from torchvision import transforms + + self.split = split + + if self.split == 0: + self.img_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomRotation(15), + transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + else: + self.img_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl'] + + @property + def processed_file_names(self) -> str: + return ['train.pt', 'valid.pt', 'test.pt'][self.split] + + def download(self) -> None: + file_path = download_google_url( + self.raw_url_id, + self.raw_dir, + 'gitmol.zip', + ) + extract_zip(file_path, self.raw_dir) + + def process(self) -> None: + import pandas as pd + from PIL import Image + + try: + from rdkit import Chem, RDLogger + RDLogger.DisableLog('rdApp.*') # type: ignore + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + allowable_features: Dict[str, List[Any]] = { + 'possible_atomic_num_list': + list(range(1, 119)) + ['misc'], + 'possible_formal_charge_list': + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], + 'possible_chirality_list': [ + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER + ], + 'possible_hybridization_list': [ + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc' + ], + 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], + 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], + 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], + 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], + 'possible_is_aromatic_list': [False, True], + 'possible_is_in_ring_list': [False, True], + 'possible_bond_type_list': [ + Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, + Chem.rdchem.BondType.ZERO + ], + 'possible_bond_dirs': [ # only for double bond stereo information + Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT + ], + 'possible_bond_stereo_list': [ + Chem.rdchem.BondStereo.STEREONONE, + Chem.rdchem.BondStereo.STEREOZ, + Chem.rdchem.BondStereo.STEREOE, + Chem.rdchem.BondStereo.STEREOCIS, + Chem.rdchem.BondStereo.STEREOTRANS, + Chem.rdchem.BondStereo.STEREOANY, + ], + 'possible_is_conjugated_list': [False, True] + } + + data = pd.read_pickle( + f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}') + + data_list = [] + for _, r in tqdm(data.iterrows(), total=data.shape[0]): + smiles = r['isosmiles'] + mol = Chem.MolFromSmiles(smiles.strip('\n')) + if mol is not None: + # text + summary = r['summary'] + # image + cid = r['cid'] + img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png' + img = Image.open(img_file).convert('RGB') + img = self.img_transform(img).unsqueeze(0) + # graph + atom_features_list = [] + for atom in mol.GetAtoms(): # type: ignore + atom_feature = [ + safe_index( + allowable_features['possible_atomic_num_list'], + atom.GetAtomicNum()), + allowable_features['possible_chirality_list'].index( + atom.GetChiralTag()), + safe_index(allowable_features['possible_degree_list'], + atom.GetTotalDegree()), + safe_index( + allowable_features['possible_formal_charge_list'], + atom.GetFormalCharge()), + safe_index(allowable_features['possible_numH_list'], + atom.GetTotalNumHs()), + safe_index( + allowable_features[ + 'possible_number_radical_e_list'], + atom.GetNumRadicalElectrons()), + safe_index( + allowable_features['possible_hybridization_list'], + atom.GetHybridization()), + allowable_features['possible_is_aromatic_list'].index( + atom.GetIsAromatic()), + allowable_features['possible_is_in_ring_list'].index( + atom.IsInRing()), + ] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), + dtype=torch.long) + + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): # type: ignore + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_feature = [ + safe_index( + allowable_features['possible_bond_type_list'], + bond.GetBondType()), + allowable_features['possible_bond_stereo_list'].index( + bond.GetStereo()), + allowable_features['possible_is_conjugated_list']. + index(bond.GetIsConjugated()), + ] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + edge_index = torch.tensor( + np.array(edges_list).T, + dtype=torch.long, + ) + edge_attr = torch.tensor( + np.array(edge_features_list), + dtype=torch.long, + ) + + data = Data( + x=x, + edge_index=edge_index, + smiles=smiles, + edge_attr=edge_attr, + image=img, + caption=summary, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/torch_geometric/datasets/molecule_gpt_dataset.py b/torch_geometric/datasets/molecule_gpt_dataset.py new file mode 100644 index 000000000000..b1da09f38570 --- /dev/null +++ b/torch_geometric/datasets/molecule_gpt_dataset.py @@ -0,0 +1,480 @@ +import gzip +import json +import multiprocessing +import os +import sys +from collections import defaultdict +from multiprocessing import Pool +from typing import Callable, List, Optional, Tuple + +import numpy as np +import requests +import torch +from tqdm import tqdm + +from torch_geometric.data import Data, InMemoryDataset, download_url +from torch_geometric.io import fs +from torch_geometric.nn.nlp import LLM +from torch_geometric.utils import one_hot + + +def clean_up_description(description: str) -> str: + description = description + " " + + # extra adj Pure + if description.startswith("Pure "): + description = description.replace("Pure ", "") + # fix typo + if description.startswith("Mercurycombines"): + description = description.replace("Mercurycombines", + "Mercury combines") + + # a special case + description = description.replace( + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ", + "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ") + + # a special case + description = description.replace("5-Thymidylic acid. ", + "5-Thymidylic acid. is ") + + # a special case + description = description.replace( + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ") + + # a special case + description = description.replace( + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid. "), + ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride" + " with phosphorothioic acid is ")) + + # a special case + description = description.replace("5'-Uridylic acid. ", + "5'-Uridylic acid is ") + + # a special case + description = description.replace("5'-Adenylic acid, ", + "5'-Adenylic acid is ") + + # a special case + description = description.replace( + "Uridine 5'-(tetrahydrogen triphosphate). ", + "Uridine 5'-(tetrahydrogen triphosphate). is ") + + # a special case + description = description.replace("Inosine 5'-Monophosphate. ", + "Inosine 5'-Monophosphate. is ") + + # a special case + description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ", + "Pivaloyloxymethyl butyrate (AN-9) is ") + + # a special case + description = description.replace( + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ") + + # a special case + description = description.replace( + "Cardamonin (also known as Dihydroxymethoxychalcone), ", + "Cardamonin (also known as Dihydroxymethoxychalcone) is ") + + # a special case + description = description.replace("Lithium has been used to treat ", + "Lithium is ") + + # a special case + description = description.replace("4,4'-Methylenebis ", + "4,4'-Methylenebis is ") + + # a special case + description = description.replace( + "2,3,7,8-Tetrachlorodibenzo-p-dioxin", + "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ") + + # a special case + description = description.replace("Exposure to 2,4,5-trichlorophenol ", + "2,4,5-Trichlorophenol exposure ") + + index = 0 + L = len(description) + if description.startswith('C.I. '): + start_index = len('C.I. ') + elif description.startswith('Nectriapyrone. D '): + start_index = len('Nectriapyrone. D ') + elif description.startswith( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide'): + start_index = len( + 'Salmonella enterica sv. Minnesota LPS core oligosaccharide') + else: + start_index = 0 + for index in range(start_index, L - 1): + if index < L - 2: + if description[index] == '.' and description[ + index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z': + break + elif index == L - 2: + break + + first_sentence = description[:index + 1] + return first_sentence + + +def extract_name(name_raw: str, description: str) -> Tuple[str, str, str]: + first_sentence = clean_up_description(description) + + splitter = ' -- -- ' + if ' are ' in first_sentence or ' were ' in first_sentence: + replaced_words = 'These molecules' + else: + replaced_words = 'This molecule' + + first_sentence = first_sentence.replace(' is ', splitter) + first_sentence = first_sentence.replace(' are ', splitter) + first_sentence = first_sentence.replace(' was ', splitter) + first_sentence = first_sentence.replace(' were ', splitter) + first_sentence = first_sentence.replace(' appears ', splitter) + first_sentence = first_sentence.replace(' occurs ', splitter) + first_sentence = first_sentence.replace(' stands for ', splitter) + first_sentence = first_sentence.replace(' belongs to ', splitter) + first_sentence = first_sentence.replace(' exists ', + splitter) # only for CID=11443 + first_sentence = first_sentence.replace(' has been used in trials ', + splitter) + first_sentence = first_sentence.replace(' has been investigated ', + splitter) + first_sentence = first_sentence.replace(' has many uses ', splitter) + + if splitter in first_sentence: + extracted_name = first_sentence.split(splitter, 1)[0] + elif first_sentence.startswith(name_raw): + extracted_name = name_raw + elif name_raw in first_sentence: + extracted_name = name_raw + extracted_name = None + print("=====", name_raw) + print("first sentence: ", first_sentence) + else: + extracted_name = None + + if extracted_name is not None: + extracted_description = description.replace(extracted_name, + replaced_words) + else: + extracted_description = description + + return extracted_name, extracted_description, first_sentence + + +class MoleculeGPTDataset(InMemoryDataset): + r"""The dataset from the `"MoleculeGPT: Instruction Following Large + Language Models for Molecular Property Prediction" + `_ paper. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`torch_geometric.data.Data` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`torch_geometric.data.Data` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`torch_geometric.data.Data` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + total_page_num (int, optional): The number of pages from PubChem. + (default: :obj:`10`) + total_block_num (int, optional): The blocks of SDF files from PubChem. + (default: :obj:`1`) + """ + description_url = ( + 'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/' + 'heading/json?heading_type=Compound&heading=Record+Description&page={}' + ) + compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/' + 'CURRENT-Full/SDF') + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + total_page_num: int = 10, + total_block_num: int = 1, + ): + self.total_page_num = total_page_num + self.total_block_num = total_block_num + + super().__init__(root, transform, pre_transform, pre_filter, + force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return ['pubchem.csv'] + + @property + def processed_file_names(self) -> List[str]: + return ['data.pt'] + + def download(self) -> None: + # Step 01. Extract description + step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description" + if not os.path.exists(step1_folder): + os.makedirs(step1_folder) + valid_CID_set = set() + CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict( + list) + CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict( + list) + + for page_index in tqdm(range(self.total_page_num)): + page_num = page_index + 1 + f_out = open( + f"{step1_folder}/Compound_description_{page_num}.txt", "w") + + description_data = requests.get( + self.description_url.format(page_num)).json() + + description_data = description_data["Annotations"] + assert description_data["Page"] == page_num + + record_list = description_data["Annotation"] + + for record in record_list: + try: + CID = record["LinkedRecords"]["CID"][0] + if "Name" in record: + name_raw = record["Name"] + CID2name_raw[CID].append(name_raw) + else: + name_raw = None + + data_list = record["Data"] + for data in data_list: + description = data["Value"]["StringWithMarkup"][0][ + "String"].strip() + + extracted_name, extracted_description, _ = extract_name( # noqa: E501 + name_raw, description) + if extracted_name is not None: + CID2name_extracted[CID].append(extracted_name) + + CID2text_raw[CID].append(description) + CID2text_extracted[CID].append( + extracted_description) + + valid_CID_set.add(CID) + f_out.write(f"{CID}\n") + f_out.write(f"{extracted_description}\n\n") + except Exception: + continue + + valid_CID_list = sorted(list(valid_CID_set)) + print(f"Total CID (with raw name) {len(CID2name_raw)}") + print(f"Total CID (with extracted name) {len(CID2name_extracted)}") + print(f"Total CID {len(valid_CID_list)}") + + with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f: + json.dump(CID2name_raw, f) + + with open(f"{self.raw_dir}/CID2name.json", "w") as f: + json.dump(CID2name_extracted, f) + + with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f: + json.dump(CID2text_raw, f) + + with open(f"{self.raw_dir}/CID2text.json", "w") as f: + json.dump(CID2text_extracted, f) + + # Step 02. Download SDF Files + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + if not os.path.exists(step2_folder): + for block_id in tqdm(range(self.total_block_num)): + block_size = 500000 + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + download_url(f"{self.compound_url}/{compound_file_name}", + step2_folder) + + def process(self, use_mp: bool = False) -> None: + try: + from rdkit import Chem + from rdkit.Chem.rdchem import BondType as BT + WITH_RDKIT = True + + except ImportError: + WITH_RDKIT = False + + if not WITH_RDKIT: + print(("Using a pre-processed version of the dataset. Please " + "install 'rdkit' to alternatively process the raw data."), + file=sys.stderr) + + data_list = fs.torch_load(self.raw_paths[0]) + data_list = [Data(**data_dict) for data_dict in data_list] + + if self.pre_filter is not None: + data_list = [d for d in data_list if self.pre_filter(d)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(d) for d in data_list] + + self.save(data_list, self.processed_paths[0]) + return + + # Step 03. Filter out SDF + step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF" + step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered" + if not os.path.exists(step3_folder): + os.makedirs(step3_folder) + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + + block_size = 500000 + + def extract_one_SDF_file(block_id: int) -> None: + valid_mol_count = 0 + + writer = Chem.SDWriter( + f'{step3_folder}/filtered_{block_id}.sdf') + l_id = block_id * block_size + 1 + r_id = (block_id + 1) * block_size + + compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz" + gzip_loader = gzip.open(f"{step2_folder}/{compound_file_name}") + suppl = Chem.ForwardSDMolSupplier(gzip_loader) + + for mol in tqdm(suppl): + if mol is None: + continue + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + + if cid not in target_CID_list: + continue + + writer.write(mol) + valid_mol_count += 1 + + print(f"block id: {block_id}\nfound {valid_mol_count}\n\n") + sys.stdout.flush() + return + + if use_mp: + num_process = multiprocessing.cpu_count() + print(f"{num_process} CPUs") + num_process = 8 + p = Pool(num_process) + + block_id_list = np.arange(self.total_block_num) + with p: + p.map(extract_one_SDF_file, block_id_list) + else: + for block_id in range(self.total_block_num): + extract_one_SDF_file(block_id) + + # Step 04. Merge SDF + with open(f"{self.raw_dir}/CID2text.json") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + print(f'The length of target_CID_list: {len(target_CID_list)}') + + writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf') + + found_CID_set = set() + for block_id in range(self.total_block_num + 1): + compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf" + try: + suppl = Chem.SDMolSupplier(compound_file_path) + + for mol in tqdm(suppl): + writer.write(mol) + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + found_CID_set.add(cid) + except Exception: + print(f"block id: {block_id} with 0 valid SDF file") + continue + + print(f"In total: {len(found_CID_set)} molecules") + + # Step 05. Convert to PyG data format + types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + data_list = [] + # Real data + CID2text_file = f'{self.raw_dir}/CID2text.json' + + with open(CID2text_file) as f: + CID2text_data = json.load(f) + + suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf') + + llm = LLM( + # model_name='lmsys/vicuna-7b-v1.5', + model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', + num_params=1, + dtype=torch.bfloat16, + ) + prompt = ("Propose a question regarding the molecule '∼' " + "whose answer is: {}:") + for mol in tqdm(suppl): + if mol.HasProp('PUBCHEM_COMPOUND_CID'): + CID = mol.GetProp("PUBCHEM_COMPOUND_CID") + CAN_SMILES = mol.GetProp("PUBCHEM_OPENEYE_CAN_SMILES") + + m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES) + if m is None: + continue + RDKit_CAN_SMILES = Chem.MolToSmiles(m) + + ground_truth = CID2text_data[CID][0] + + instruction = llm.inference([prompt.format(ground_truth)])[0] + + x: torch.Tensor = torch.tensor([ + types[atom.GetSymbol()] if atom.GetSymbol() in types else 5 + for atom in m.GetAtoms() # type: ignore + ]) + x = one_hot(x, num_classes=len(types), dtype=torch.float) + + rows, cols, edge_types = [], [], [] + for bond in m.GetBonds(): # type: ignore + i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_types += [bonds[bond.GetBondType()]] * 2 + rows += [i, j] + cols += [j, i] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_type = torch.tensor(edge_types, dtype=torch.long) + edge_attr = one_hot(edge_type, num_classes=len(bonds)) + + data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + smiles=RDKit_CAN_SMILES, + instruction=instruction, + y=ground_truth, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py new file mode 100644 index 000000000000..f25992ced989 --- /dev/null +++ b/torch_geometric/datasets/tag_dataset.py @@ -0,0 +1,350 @@ +import os +import os.path as osp +from collections.abc import Sequence +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor +from tqdm import tqdm + +from torch_geometric.data import InMemoryDataset, download_google_url +from torch_geometric.data.data import BaseData + +try: + from pandas import DataFrame, read_csv + WITH_PANDAS = True +except ImportError: + WITH_PANDAS = False + +IndexType = Union[slice, Tensor, np.ndarray, Sequence] + + +class TAGDataset(InMemoryDataset): + r"""The Text Attributed Graph datasets from the + `"Learning on Large-scale Text-attributed Graphs via Variational Inference + " `_ paper. + This dataset is aiming on transform `ogbn products`, `ogbn arxiv` + into Text Attributed Graph that each node in graph is associate with a + raw text, that dataset can be adapt to DataLoader (for LM training) and + NeighborLoader(for GNN training). In addition, this class can be use as a + wrapper class by convert a InMemoryDataset with Tokenizer and text into + Text Attributed Graph. + + Args: + root (str): Root directory where the dataset should be saved. + dataset (InMemoryDataset): The name of the dataset + (:obj:`"ogbn-products"`, :obj:`"ogbn-arxiv"`). + tokenizer_name (str): The tokenizer name for language model, + Be sure to use same tokenizer name as your `model id` of model repo + on huggingface.co. + text (List[str]): list of raw text associate with node, the order of + list should be align with node list + split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary, + for saving split index, it is required that if your dataset doesn't + have get_split_idx function + tokenize_batch_size (int): batch size of tokenizing text, the + tokenizing process will run on cpu, default: 256 + token_on_disk (bool): save token as .pt file on disk or not, + default: False + text_on_disk (bool): save given text(list of str) as dataframe on disk + or not, default: False + force_reload (bool): default: False + .. note:: + See `example/llm_plus_gnn/glem.py` for example usage + """ + raw_text_id = { + 'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3', + 'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt' + } + + def __init__(self, root: str, dataset: InMemoryDataset, + tokenizer_name: str, text: Optional[List[str]] = None, + split_idx: Optional[Dict[str, Tensor]] = None, + tokenize_batch_size: int = 256, token_on_disk: bool = False, + text_on_disk: bool = False, + force_reload: bool = False) -> None: + # list the vars you want to pass in before run download & process + self.name = dataset.name + self.text = text + self.tokenizer_name = tokenizer_name + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.dir_name = '_'.join(dataset.name.split('-')) + self.root = osp.join(root, self.dir_name) + missing_str_list = [] + if not WITH_PANDAS: + missing_str_list.append('pandas') + if len(missing_str_list) > 0: + missing_str = ' '.join(missing_str_list) + error_out = f"`pip install {missing_str}` to use this dataset." + raise ImportError(error_out) + if hasattr(dataset, 'get_idx_split'): + self.split_idx = dataset.get_idx_split() + elif split_idx is not None: + self.split_idx = split_idx + else: + raise ValueError("TAGDataset need split idx for generating " + "is_gold mask, please pass splited index " + "in format of dictionaty with 'train', 'valid' " + "'test' index tensor to 'split_idx'") + if text is not None and text_on_disk: + self.save_node_text(text) + self.text_on_disk = text_on_disk + # init will call download and process + super().__init__(self.root, transform=None, pre_transform=None, + pre_filter=None, force_reload=force_reload) + # after processing and download + # Dataset has to have BaseData as _data + assert dataset._data is not None + self._data = dataset._data # reassign reference + assert self._data is not None + assert dataset._data.y is not None + assert isinstance(self._data, BaseData) + assert self._data.num_nodes is not None + assert isinstance(dataset._data.num_nodes, int) + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + is_good_tensor = self.load_gold_mask() + self._is_gold = is_good_tensor.squeeze() + self._data['is_gold'] = is_good_tensor + if self.text is not None and len(self.text) != self._data.num_nodes: + raise ValueError("The number of text sequence in 'text' should be " + "equal to number of nodes!") + self.token_on_disk = token_on_disk + self.tokenize_batch_size = tokenize_batch_size + self._token = self.tokenize_graph(self.tokenize_batch_size) + self.__num_classes__ = dataset.num_classes + + @property + def num_classes(self) -> int: + return self.__num_classes__ + + @property + def raw_file_names(self) -> List[str]: + file_names = [] + for root, _, files in os.walk(osp.join(self.root, 'raw')): + for file in files: + file_names.append(file) + return file_names + + @property + def processed_file_names(self) -> List[str]: + return [ + 'geometric_data_processed.pt', 'pre_filter.pt', + 'pre_transformed.pt' + ] + + @property + def token(self) -> Dict[str, Tensor]: + if self._token is None: # lazy load + self._token = self.tokenize_graph() + return self._token + + # load is_gold after init + @property + def is_gold(self) -> Tensor: + if self._is_gold is None: + print('lazy load is_gold!!') + self._is_gold = self.load_gold_mask() + return self._is_gold + + def get_n_id(self, node_idx: IndexType) -> Tensor: + if self._n_id is None: + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + self._n_id = torch.arange(self._data.num_nodes) + return self._n_id[node_idx] + + def load_gold_mask(self) -> Tensor: + r"""Use original train split as gold split, generating is_gold mask + for picking ground truth labels and pseudo labels. + """ + train_split_idx = self.get_idx_split()['train'] + assert self._data is not None + assert self._data.num_nodes is not None + assert isinstance(self._data.num_nodes, int) + is_good_tensor = torch.zeros(self._data.num_nodes, + dtype=torch.bool).view(-1, 1) + is_good_tensor[train_split_idx] = True + return is_good_tensor + + def get_gold(self, node_idx: IndexType) -> Tensor: + r"""Get gold mask for given node_idx. + + Args: + node_idx (torch.tensor): a tensor contain node idx + """ + if self._is_gold is None: + self._is_gold = self.is_gold + return self._is_gold[node_idx] + + def get_idx_split(self) -> Dict[str, Tensor]: + return self.split_idx + + def download(self) -> None: + print('downloading raw text') + raw_text_path = download_google_url(id=self.raw_text_id[self.name], + folder=f'{self.root}/raw', + filename='node-text.csv.gz', + log=True) + text_df = read_csv(raw_text_path) + self.text = list(text_df['text']) + + def process(self) -> None: + if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): + text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) + self.text = list(text_df['text']) + elif self.name in self.raw_text_id: + self.download() + else: + print('The dataset is not ogbn-products nor ogbn-arxiv,' + 'please pass in your raw text string list to `text`') + if self.text is None: + raise ValueError("The TAGDataset only have ogbn-products and " + "ogbn-arxiv raw text in default " + "The raw text of each node is not specified" + "Please pass in 'text' when convert your dataset " + "to Text Attribute Graph Dataset") + + def save_node_text(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') + if osp.exists(node_text_path): + print(f'The raw text is existed at {node_text_path}') + else: + print(f'Saving raw text file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + + def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + r"""Tokenizing the text associate with each node, running in cpu. + + Args: + batch_size (Optional[int]): batch size of list of text for + generating emebdding + Returns: + Dict[str, torch.Tensor]: tokenized graph + """ + data_len = 0 + if self.text is not None: + data_len = len(self.text) + else: + raise ValueError("The TAGDataset need text for tokenization") + token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] + path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + # Check if the .pt files already exist + token_files_exist = any( + os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) + + if token_files_exist and self.token_on_disk: + print('Found tokenized file, loading may take several minutes...') + all_encoded_token = { + k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True) + for k in token_keys + if os.path.exists(os.path.join(path, f'{k}.pt')) + } + return all_encoded_token + + all_encoded_token = {k: [] for k in token_keys} + pbar = tqdm(total=data_len) + + pbar.set_description('Tokenizing Text Attributed Graph') + for i in range(0, data_len, batch_size): + end_index = min(data_len, i + batch_size) + token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], + padding='max_length', truncation=True, + max_length=512, return_tensors="pt") + for k in token.keys(): + all_encoded_token[k].append(token[k]) + pbar.update(end_index - i) + pbar.close() + + all_encoded_token = { + k: torch.cat(v) + for k, v in all_encoded_token.items() if len(v) > 0 + } + if self.token_on_disk: + os.makedirs(path, exist_ok=True) + print('Saving tokens on Disk') + for k, tensor in all_encoded_token.items(): + torch.save(tensor, os.path.join(path, f'{k}.pt')) + print('Token saved:', os.path.join(path, f'{k}.pt')) + os.environ["TOKENIZERS_PARALLELISM"] = 'true' # supressing warning + return all_encoded_token + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + class TextDataset(torch.utils.data.Dataset): + r"""This nested dataset provides textual data for each node in + the graph. Factory method to create TextDataset from TAGDataset. + + Args: + tag_dataset (TAGDataset): the parent dataset + """ + def __init__(self, tag_dataset: 'TAGDataset') -> None: + self.tag_dataset = tag_dataset + self.token = tag_dataset.token + assert tag_dataset._data is not None + self._data = tag_dataset._data + + assert tag_dataset._data.y is not None + self.labels = tag_dataset._data.y + + def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: + r"""This function will be called in __getitem__(). + + Args: + node_idx (IndexType): selected node idx in each batch + Returns: + items (Dict[str, Tensor]): input for LM + """ + items = {k: v[node_idx] for k, v in self.token.items()} + return items + + # for LM training + def __getitem__( + self, node_id: IndexType + ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: + r"""This function will override the function in + torch.utils.data.Dataset, and will be called when you + iterate batch in the dataloader, make sure all following + key value pairs are present in the return dict. + + Args: + node_id (List[int]): list of node idx for selecting tokens, + labels etc. when iterating data loader for LM + Returns: + items (dict): input k,v pairs for Language model training and + inference + """ + item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {} + item['input'] = self.get_token(node_id) + item['labels'] = self.labels[node_id] + item['is_gold'] = self.tag_dataset.get_gold(node_id) + item['n_id'] = self.tag_dataset.get_n_id(node_id) + return item + + def __len__(self) -> int: + assert self._data.num_nodes is not None + return self._data.num_nodes + + def get(self, idx: int) -> BaseData: + return self._data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + def to_text_dataset(self) -> TextDataset: + r"""Factory Build text dataset from Text Attributed Graph Dataset + each data point is node's associated text token. + """ + return TAGDataset.TextDataset(self) diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 266f498a113b..7e83c35befb6 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,6 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin +from .rag_loader import RAGQueryLoader __all__ = classes = [ 'DataLoader', @@ -50,6 +51,7 @@ 'PrefetchLoader', 'CachedLoader', 'AffinityMixin', + 'RAGQueryLoader', ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 341f2f5a23b6..5814724f0c48 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -14,7 +14,7 @@ class NeighborLoader(NodeLoader): This loader allows for mini-batch training of GNNs on large-scale graphs where full-batch training is not feasible. - More specifically, :obj:`num_neighbors` denotes how much neighbors are + More specifically, :obj:`num_neighbors` denotes how many neighbors are sampled for each node in each iteration. :class:`~torch_geometric.loader.NeighborLoader` takes in this list of :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py new file mode 100644 index 000000000000..33d6cf0e868e --- /dev/null +++ b/torch_geometric/loader/rag_loader.py @@ -0,0 +1,106 @@ +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union + +from torch_geometric.data import Data, FeatureStore, HeteroData +from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput +from torch_geometric.typing import InputEdges, InputNodes + + +class RAGFeatureStore(Protocol): + """Feature store for remote GNN RAG backend.""" + @abstractmethod + def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: + """Makes a comparison between the query and all the nodes to get all + the closest nodes. Return the indices of the nodes that are to be seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges: + """Makes a comparison between the query and all the edges to get all + the closest nodes. Returns the edge indices that are to be the seeds + for the RAG Sampler. + """ + ... + + @abstractmethod + def load_subgraph( + self, sample: Union[SamplerOutput, HeteroSamplerOutput] + ) -> Union[Data, HeteroData]: + """Combines sampled subgraph output with features in a Data object.""" + ... + + +class RAGGraphStore(Protocol): + """Graph store for remote GNN RAG backend.""" + @abstractmethod + def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, + **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: + """Sample a subgraph using the seeded nodes and edges.""" + ... + + @abstractmethod + def register_feature_store(self, feature_store: FeatureStore): + """Register a feature store to be used with the sampler. Samplers need + info from the feature store in order to work properly on HeteroGraphs. + """ + ... + + +# TODO: Make compatible with Heterographs + + +class RAGQueryLoader: + def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], + local_filter: Optional[Callable[[Data, Any], Data]] = None, + seed_nodes_kwargs: Optional[Dict[str, Any]] = None, + seed_edges_kwargs: Optional[Dict[str, Any]] = None, + sampler_kwargs: Optional[Dict[str, Any]] = None, + loader_kwargs: Optional[Dict[str, Any]] = None): + """Loader meant for making queries from a remote backend. + + Args: + data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote FeatureStore + and GraphStore to load from. Assumed to conform to the + protocols listed above. + local_filter (Optional[Callable[[Data, Any], Data]], optional): + Optional local transform to apply to data after retrieval. + Defaults to None. + seed_nodes_kwargs (Optional[Dict[str, Any]], optional): Paramaters + to pass into process for fetching seed nodes. Defaults to None. + seed_edges_kwargs (Optional[Dict[str, Any]], optional): Parameters + to pass into process for fetching seed edges. Defaults to None. + sampler_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for sampling graph. Defaults to None. + loader_kwargs (Optional[Dict[str, Any]], optional): Parameters to + pass into process for loading graph features. Defaults to None. + """ + fstore, gstore = data + self.feature_store = fstore + self.graph_store = gstore + self.graph_store.register_feature_store(self.feature_store) + self.local_filter = local_filter + self.seed_nodes_kwargs = seed_nodes_kwargs or {} + self.seed_edges_kwargs = seed_edges_kwargs or {} + self.sampler_kwargs = sampler_kwargs or {} + self.loader_kwargs = loader_kwargs or {} + + def query(self, query: Any) -> Data: + """Retrieve a subgraph associated with the query with all its feature + attributes. + """ + seed_nodes = self.feature_store.retrieve_seed_nodes( + query, **self.seed_nodes_kwargs) + seed_edges = self.feature_store.retrieve_seed_edges( + query, **self.seed_edges_kwargs) + + subgraph_sample = self.graph_store.sample_subgraph( + seed_nodes, seed_edges, **self.sampler_kwargs) + + data = self.feature_store.load_subgraph(sample=subgraph_sample, + **self.loader_kwargs) + + if self.local_filter: + data = self.local_filter(data, query) + return data diff --git a/torch_geometric/nn/attention/__init__.py b/torch_geometric/nn/attention/__init__.py index 947d5850173b..6b4064cd34b9 100644 --- a/torch_geometric/nn/attention/__init__.py +++ b/torch_geometric/nn/attention/__init__.py @@ -1,3 +1,7 @@ from .performer import PerformerAttention +from .qformer import QFormer -__all__ = ['PerformerAttention'] +__all__ = [ + 'PerformerAttention', + 'QFormer', +] diff --git a/torch_geometric/nn/attention/qformer.py b/torch_geometric/nn/attention/qformer.py new file mode 100644 index 000000000000..3a8f512d3f83 --- /dev/null +++ b/torch_geometric/nn/attention/qformer.py @@ -0,0 +1,71 @@ +from typing import Callable + +import torch + + +class QFormer(torch.nn.Module): + r"""The Querying Transformer (Q-Former) from + `"BLIP-2: Bootstrapping Language-Image Pre-training + with Frozen Image Encoders and Large Language Models" + `_ paper. + + Args: + input_dim (int): The number of features in the input. + hidden_dim (int): The dimension of the fnn in the encoder layer. + output_dim (int): The final output dimension. + num_heads (int): The number of multi-attention-heads. + num_layers (int): The number of sub-encoder-layers in the encoder. + dropout (int): The dropout value in each encoder layer. + + + .. note:: + This is a simplified version of the original Q-Former implementation. + """ + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_heads: int, + num_layers: int, + dropout: float = 0.0, + activation: Callable = torch.nn.ReLU(), + ) -> None: + + super().__init__() + self.num_layers = num_layers + self.num_heads = num_heads + + self.layer_norm = torch.nn.LayerNorm(input_dim) + self.encoder_layer = torch.nn.TransformerEncoderLayer( + d_model=input_dim, + nhead=num_heads, + dim_feedforward=hidden_dim, + dropout=dropout, + activation=activation, + batch_first=True, + ) + self.encoder = torch.nn.TransformerEncoder( + self.encoder_layer, + num_layers=num_layers, + ) + self.project = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r"""Forward pass. + + Args: + x (torch.Tensor): Input sequence to the encoder layer. + :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with + batch-size :math:`B`, sequence length :math:`N`, + and feature dimension :math:`F`. + """ + x = self.layer_norm(x) + x = self.encoder(x) + out = self.project(x) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'num_heads={self.num_heads}, ' + f'num_layers={self.num_layers})') diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 7cfadf0143b2..9ade58cebc05 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -29,7 +29,9 @@ from .neural_fingerprint import NeuralFingerprint from .visnet import ViSNet from .g_retriever import GRetriever - +from .git_mol import GITMol +from .molecule_gpt import MoleculeGPT +from .glem import GLEM # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -77,4 +79,7 @@ 'NeuralFingerprint', 'ViSNet', 'GRetriever', + 'GITMol', + 'MoleculeGPT', + 'GLEM', ] diff --git a/torch_geometric/nn/models/g_retriever.py b/torch_geometric/nn/models/g_retriever.py index 6f8fbcc644dc..f7529ae721b7 100644 --- a/torch_geometric/nn/models/g_retriever.py +++ b/torch_geometric/nn/models/g_retriever.py @@ -21,6 +21,8 @@ class GRetriever(torch.nn.Module): (default: :obj:`False`) mlp_out_channels (int, optional): The size of each graph embedding after projection. (default: :obj:`4096`) + mlp_out_tokens (int, optional): Number of LLM prefix tokens to + reserve for GNN output. (default: :obj:`1`) .. warning:: This module has been tested with the following HuggingFace models @@ -43,6 +45,7 @@ def __init__( gnn: torch.nn.Module, use_lora: bool = False, mlp_out_channels: int = 4096, + mlp_out_tokens: int = 1, ) -> None: super().__init__() @@ -77,7 +80,9 @@ def __init__( self.projector = torch.nn.Sequential( torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels), torch.nn.Sigmoid(), - torch.nn.Linear(mlp_hidden_channels, mlp_out_channels), + torch.nn.Linear(mlp_hidden_channels, + mlp_out_channels * mlp_out_tokens), + torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)), ).to(self.llm.device) def encode( @@ -126,6 +131,9 @@ def forward( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) @@ -182,6 +190,9 @@ def inference( x = self.projector(x) xs = x.split(1, dim=0) + # Handle case where theres more than one embedding for each sample + xs = [x.squeeze(0) for x in xs] + # Handle questions without node features: batch_unique = batch.unique() batch_size = len(question) diff --git a/torch_geometric/nn/models/git_mol.py b/torch_geometric/nn/models/git_mol.py new file mode 100644 index 000000000000..c06b44671931 --- /dev/null +++ b/torch_geometric/nn/models/git_mol.py @@ -0,0 +1,336 @@ +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential + +from torch_geometric.nn import GINEConv +from torch_geometric.nn.nlp import SentenceTransformer, VisionTransformer +from torch_geometric.utils import add_self_loops, to_dense_batch + + +class GraphEncoder(torch.nn.Module): + def __init__( + self, + num_layers: int, + in_channels: int, + dropout: float = 0., + num_atom_type: int = 120, + num_chirality_tag: int = 3, + num_bond_type: int = 6, + num_bond_direction: int = 3, + ) -> None: + super().__init__() + + self.num_layers = num_layers + self.dropout = dropout + + self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels) + self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels) + self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels) + self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels) + + self.gnns = torch.nn.ModuleList() + self.batch_norms = torch.nn.ModuleList() + for _ in range(num_layers): + self.gnns.append( + GINEConv( + nn=Sequential( + Linear(in_channels, in_channels * 2), + ReLU(), + Linear(in_channels * 2, in_channels), + ), + train_eps=True, + edge_dim=in_channels, + )) + self.batch_norms.append(BatchNorm1d(in_channels)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.x_embed1.weight.data) + torch.nn.init.xavier_uniform_(self.x_embed2.weight.data) + torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data) + torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data) + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Tensor, + ) -> Tensor: + x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long()) + edge_index, edge_attr = add_self_loops( + edge_index, + edge_attr, + fill_value=0, + num_nodes=x.size(0), + ) + edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2( + edge_attr[:, 1]) + for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)): + x = gnn(x, edge_index, edge_attr) + x = bn(x) + if i < self.num_layers - 1: + x = F.relu(x) + x = F.dropout(x, self.dropout, training=self.training) + + x, mask = to_dense_batch(x, batch) + return x, mask + + +class GITFormer(torch.nn.Module): + def __init__( + self, + num_query_token: int, + vision_graph_width: int, + cross_attention_freq: int = 2, + ): + super().__init__() + from transformers import AutoConfig, AutoModel + + config = AutoConfig.from_pretrained("allenai/scibert_scivocab_uncased") + config.encoder_width = vision_graph_width + # insert cross-attention layer every other block + config.add_cross_attention = True + config.is_decoder = True + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + self.Qformer = AutoModel.from_pretrained( + "allenai/scibert_scivocab_uncased", config=config) + self.query_tokens = torch.nn.Parameter( + torch.zeros(1, num_query_token, config.hidden_size)) + self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range) + + +class GITMol(torch.nn.Module): + r"""The GITMol model from the `"GIT-Mol: A Multi-modal Large Language + Model for Molecular Science with Graph, Image, and Text" + `_ paper. + + .. note:: + For an example of using :class:`GITMol`, see + `examples/llm/git_mol.py `_. + """ + def __init__(self) -> None: + super().__init__() + # graph + self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16) + self.graph_proj = Linear(16, 768) + self.ln_graph = LayerNorm(768) + # text + self.text_encoder = SentenceTransformer( + model_name='allenai/scibert_scivocab_uncased', + pooling_strategy='last_hidden_state', + ) + self.text_proj = Linear(768, 768) + self.ln_text = LayerNorm(768) + # vision + self.vision_encoder = VisionTransformer( + model_name='microsoft/swin-base-patch4-window7-224', ) + self.vision_proj = Linear(1024, 768) + self.ln_vision = LayerNorm(768) + # cross-attention + self.gitformer = GITFormer(384, 768) + + self.xtm_head = torch.nn.ModuleDict({ + 'image': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + 'graph': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + 'cs_text': + Linear(self.gitformer.Qformer.config.hidden_size, 2), + }) + + self.xtc_proj = torch.nn.ModuleDict({ + 'image': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + 'graph': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + 'cs_text': + Linear(self.gitformer.Qformer.config.hidden_size, 768), + }) + self.temp = torch.nn.Parameter(0.07 * torch.ones([])) + self.model_freeze() + + def model_freeze(self) -> None: + for param in self.graph_encoder.parameters(): + param.requires_grad = False + + for param in self.vision_encoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + images: Tensor, + captions: List[str], + ) -> Tensor: + batch_size = len(smiles) + + x_vision = self.vision_encoder(images) + x_vision = self.vision_proj(x_vision) + x_vision = self.ln_vision(x_vision) # [bs, patch_len, d] + vision_atts = torch.ones(x_vision.size()[:-1], + dtype=torch.long).to(x_vision.device) + vision_targets = torch.arange(batch_size).to(x_vision.device) + + x_graph, graph_atts = self.graph_encoder(x, edge_index, batch, + edge_attr) + x_graph = self.graph_proj(x_graph) + x_graph = self.ln_graph(x_graph) # [bs, node_len, d] + graph_targets = torch.arange(batch_size).to(x_graph.device) + + x_smiles = self.text_encoder.encode(smiles) # [bs, seq_len, d] + smiles_atts = torch.ones(x_smiles.size()[:-1], + dtype=torch.long).to(x_smiles.device) + smiles_targets = torch.arange(batch_size).to(x_smiles.device) + + caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids( # noqa: E501 + captions) + + text_output = self.gitformer.Qformer( + caption_input_ids, + attention_mask=caption_attention_masks, + return_dict=True, + ) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1) + + loss = 0 + for x_embed, x_atts, x_targets, modal in zip( + [x_graph, x_smiles, x_vision], + [graph_atts, smiles_atts, vision_atts], + [graph_targets, smiles_targets, vision_targets], + ['graph', 'cs_text', 'image'], + ): + loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat, + modal) + loss += self._calc_xtm_loss(x_embed, caption_input_ids, + caption_attention_masks, modal) + + return loss / 6 + + def _calc_xtm_loss( + self, + x_embeds: Tensor, + input_ids: Tensor, + attention_mask: Tensor, + modal: str, + ) -> Tensor: + # Initializing lists to hold the original and negative samples + x_embeds_list = [] + text_input_ids_list = [] + text_attention_mask_list = [] + + batch_size = x_embeds.size(0) + for i in range(batch_size): + # Original samples + x_embeds_list.append(x_embeds[i]) + text_input_ids_list.append(input_ids[i, :]) + text_attention_mask_list.append(attention_mask[i, :]) + + if batch_size > 1: + # Negative samples (neg_text_input_ids corresponds to x_embeds) + neg_text_input_ids = input_ids[i - 1 if i == batch_size - + 1 else i + 1, :] + neg_text_attention_mask = attention_mask[i - + 1 if i == batch_size - + 1 else i + 1, :] + text_input_ids_list.append(neg_text_input_ids) + text_attention_mask_list.append(neg_text_attention_mask) + x_embeds_list.append(x_embeds[i, :]) + + # Negative samples (text_input_ids corresponds to neg_x_embeds) + neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i + + 1, :] + x_embeds_list.append(neg_x_embeds) + text_input_ids_list.append(input_ids[i, :]) + text_attention_mask_list.append(attention_mask[i, :]) + + # Stack all samples into two large tensors + x_embeds_all = torch.stack(x_embeds_list, dim=1) \ + .reshape(-1, x_embeds.size(1), x_embeds.size(2)) + text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \ + .reshape(-1, input_ids.size(1)) + # Create image attention masks for the concatenated tensor + image_attns_all = torch.ones(x_embeds_all.size()[:-1], + dtype=torch.long).to(x_embeds_all.device) + query_tokens_xtm = self.gitformer.query_tokens.expand( + text_input_ids_all.shape[0], -1, -1) + query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1], + dtype=torch.long).to(x_embeds_all.device) + + output_xtm = self.gitformer.Qformer( + inputs_embeds=query_tokens_xtm, + attention_mask=query_attns_xtm, + encoder_hidden_states=x_embeds_all, + encoder_attention_mask=image_attns_all, + return_dict=True, + ).last_hidden_state + + xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :] + + xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1) + # Create labels: 1 for the original samples, 0 for the negative samples + if batch_size > 1: + labels = torch.cat( + [torch.ones(batch_size), + torch.zeros(batch_size * 2)], dim=0) + else: + labels = torch.ones(batch_size) + labels = labels.long().to(xtm_logit.device) + + # Calculate cross entropy loss + return F.cross_entropy(xtm_logit, labels) + + def _calc_xtc_loss( + self, + x_embeds: Tensor, + x_atts: Tensor, + x_targets: Tensor, + text_feat: Tensor, + modal: str, + ) -> Tensor: + query_tokens = self.gitformer.query_tokens.expand( + x_embeds.shape[0], -1, -1) + + query_output = self.gitformer.Qformer( + inputs_embeds=query_tokens, + encoder_hidden_states=x_embeds, + encoder_attention_mask=x_atts, + return_dict=True, + ).last_hidden_state + + x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1) + + sim_q2t = torch.matmul( + x_feats.unsqueeze(1), + text_feat.unsqueeze(-1), + ).squeeze(-1) + + # modal-text similarity: aggregate across all query tokens + sim_x2t, _ = sim_q2t.max(-1) + sim_x2t = sim_x2t / self.temp + + # text-query similarity + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), + x_feats.permute(0, 2, 1), + ).squeeze(-2) + + # text-modal similarity: aggregate across all query tokens + sim_t2x, _ = sim_t2q.max(-1) + sim_t2x = sim_t2x / self.temp + + loss_itc = ( + F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2 + + return loss_itc diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py new file mode 100644 index 000000000000..afc8b09d77c7 --- /dev/null +++ b/torch_geometric/nn/models/glem.py @@ -0,0 +1,384 @@ +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from tqdm import tqdm + +from torch_geometric.loader import DataLoader, NeighborLoader +from torch_geometric.nn.models import GraphSAGE, basic_gnn + + +class GLEM(torch.nn.Module): + r"""This GNN+LM co-training model is based on GLEM from the `"Learning on + Large-scale Text-attributed Graphs via Variational Inference" + `_ paper. + + Args: + lm_to_use (str): A TextEncoder from huggingface model repo + with a classifier(default: TinyBERT) + gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE) + out_channels (int): output channels for LM and GNN, should be same + num_gnn_heads Optional[int]: Number of heads for attention, if needed + num_gnn_layers (int): number of gnn layers + gnn_loss: loss function for gnn, (default: CrossEntropyLoss) + lm_loss: loss function for Language Model, (default: CrossEntropyLoss) + alpha (float): pseudo label weight of E-step, LM optimization, + (default: 0.5) + beta (float): pseudo label weight of M-step, GNN optimization, + (default: 0.5) + lm_dtype (torch.dtype): the data type once you load LM into memory, + (default: torch.bfloat16) + lm_use_lora (bool): choose if LM use Lora peft for fine tune, + (default: True) + lora_target_modules: The names of the target modules to apply the lora + adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None) + + .. note:: + See `examples/llm_plus_gnn/glem.py` for example usage. + """ + def __init__( + self, + lm_to_use: str = 'prajjwal1/bert-tiny', + gnn_to_use: basic_gnn = GraphSAGE, + out_channels: int = 47, + gnn_loss=nn.CrossEntropyLoss(reduction='mean'), + lm_loss=nn.CrossEntropyLoss(reduction='mean'), + alpha: float = 0.5, + beta: float = 0.5, + lm_dtype: torch.dtype = torch.bfloat16, + lm_use_lora: bool = True, + lora_target_modules: Optional[Union[List[str], str]] = None, + device: Union[str, torch.device] = torch.device('cpu'), + ): + super().__init__() + self.device = device + self.lm_loss = lm_loss + self.gnn = gnn_to_use + self.gnn_loss = gnn_loss + self.alpha = alpha + self.beta = beta + self.gnn_loss = gnn_loss + self.lm = lm_to_use + from transformers import AutoModelForSequenceClassification + self.lm = AutoModelForSequenceClassification.from_pretrained( + lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype, + offload_folder="offload", trust_remote_code=True) + if lm_use_lora: + from peft import ( + LoraConfig, + TaskType, + get_peft_model, + prepare_model_for_kbit_training, + ) + print("Training LM with LORA!") + self.lm = prepare_model_for_kbit_training(self.lm) + config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16, + lora_alpha=16, lora_dropout=0.05, bias="none", + target_modules=lora_target_modules) + self.lm = get_peft_model(self.lm, config) + self.lm.print_trainable_parameters() + self.lm.config.pad_token_id = self.lm.config.eos_token_id + self.lm_device = self.lm.device + + if self.lm.num_labels != self.gnn.out_channels: + raise ValueError('''The output channel of language model \ + and gnn should be the same''') + + def pre_train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain GNN, optional steps if you do not have pseudo labels. + best_acc = 0 + early_stopping = 0 + # training only based on gold data + for epoch in range(0, num_epochs): + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, + verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def pre_train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, num_epochs: int, + patience: int, ext_pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + # Pretrain language model + best_acc = 0 + early_stopping = 0 + for epoch in range(1, num_epochs + 1): + acc, loss = self.train_lm(train_loader, optimizer, epoch, + ext_pseudo_labels, is_augmented, verbose) + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + + def train(self, em_phase: str, train_loader: Union[DataLoader, + NeighborLoader], + optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor, + epoch: int, is_augmented: bool = False, verbose: bool = False): + r"""GLEM training step, EM steps. + + Args: + em_phase(str): 'gnn' or 'lm' choose which phase you are training on + train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for + lm training, include tokenized data, labels is_gold mask. + use NeighborLoader for gnn training, include x, edge_index. + optimizer (torch.optim.Optimizer): optimizer for training + pseudo_labels(torch.Tensor): the predicted labels used as pseudo + labels + epoch (int): current epoch + is_augmented (bool): will use pseudo_labels or not + verbose (bool): print training progress bar or not + + Returns: + acc (float): training accuracy + loss (float): loss value + """ + pseudo_labels = pseudo_labels.to(self.device) + if em_phase == 'gnn': + acc, loss = self.train_gnn(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + if em_phase == 'lm': + acc, loss = self.train_lm(train_loader, optimizer, epoch, + pseudo_labels, is_augmented, verbose) + return acc, loss + + def train_lm(self, train_loader: DataLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""Language model Training in every epoch. + + Args: + train_loader (loader.dataloader.DataLoader): text token dataloader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn + is_augmented (bool): train with pseudo labels or not + verbose (bool): print training progress bar or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + + """ + all_out = [] + total_loss = total_correct = 0 + num_nodes = train_loader.dataset.indices.size(0) + self.lm.train() + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + for batch in train_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + out = self.lm(**inputs).logits + labels = batch['labels'].to(self.device).squeeze() + # training with pseudo labels or not + if is_augmented: + pl_batch = pseudo_labels[batch['n_id']].to(self.device) + else: + pl_batch = None + loss = self.loss(out, labels, self.lm_loss, + batch['is_gold'].to(self.device), pl_batch, + self.alpha, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + all_out.append(out) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + total_loss += float(loss) + if verbose: + pbar.update(batch['n_id'].size(0)) + + all_out = torch.cat(all_out, dim=0) + approx_acc = total_correct / num_nodes + loss = total_loss / len(train_loader) + if verbose: + pbar.close() + print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + def train_gnn(self, train_loader: NeighborLoader, + optimizer: torch.optim.Optimizer, epoch: int, + pseudo_labels: torch.Tensor = None, + is_augmented: bool = False, verbose: bool = True): + r"""GNN training step in every epoch. + + Args: + train_loader (loader.NeighborLoader): gnn Neighbor node loader + optimizer (torch.optim.Optimizer): model optimizer + epoch (int): current train epoch + pseudo_labels(torch.tensor): 1-D tensor, predictions from lm + is_augmented(bool): use pseudo labeled node or not + verbose (bool): print training progress or not + + Returns: + approx_acc (torch.tensor): training accuracy + loss (torch.float): loss value + """ + self.gnn.train() + num_nodes = train_loader.input_nodes.size(0) + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + total_loss = total_correct = 0 + all_out = [] + for batch in train_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + all_out.append(out) + labels = batch.y[:batch.batch_size].squeeze() + is_gold_batch = batch.is_gold[:batch.batch_size].squeeze() + # training with pseudo labels or not + if is_augmented and pseudo_labels is not None: + pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]] + else: + pl_batch = None + loss = self.loss(out, labels, self.gnn_loss, is_gold_batch, + pl_batch, self.beta, is_augmented) + loss.backward() + optimizer.step() + optimizer.zero_grad() + total_loss += float(loss) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + if verbose: + pbar.update(batch.batch_size) + + all_out = torch.cat(all_out, dim=0) + loss = total_loss / len(train_loader) + approx_acc = total_correct / num_nodes + if verbose: + pbar.close() + print(f'Epoch: {epoch:02d} Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + return approx_acc, loss + + @torch.no_grad() + def inference(self, em_phase: str, data_loader: Union[NeighborLoader, + DataLoader], + verbose: bool = False): + r"""GLEM inference step. + + Args: + em_phase(str): 'gnn' or 'lm' + data_loader(dataloader or Neighborloader): + dataloader: for lm training, include tokenized data + nodeloader: for gnn training, include x, edge_index + verbose(bool): print inference progress or not + + Returns: + out (torch.Tensor): n * m tensor, m is number of classes, + n is number of nodes + """ + out = None + if em_phase == 'gnn': + self.gnn.eval() + out = self.inference_gnn(data_loader, verbose) + elif em_phase == 'lm': + self.lm.eval() + out = self.inference_lm(data_loader, verbose) + return out + + @torch.no_grad() + def inference_lm(self, data_loader: DataLoader, verbose: bool = True): + r"""LM inference step. + + Args: + data_loader (Dataloader): include token, labels, and gold mask + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, convert to pseudo labels + by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.dataset._data.num_nodes) + pbar.set_description('LM inference stage') + self.lm.eval() + preds = [] + for batch in data_loader: + inputs = {k: v.to(self.device) for k, v in batch['input'].items()} + logits = self.lm(**inputs).logits + preds.append(logits) + if verbose: + pbar.update(batch['n_id'].size(0)) + if verbose: + pbar.close() + preds = torch.cat(preds) + return preds + + @torch.no_grad() + def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True): + r"""GNN inference step. + + Args: + data_loader(NeighborLoader): include x, edge_index, + verbose (bool): print progress bar or not + + Returns: + preds (tensor): prediction from GNN, + convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1) + """ + if verbose: + pbar = tqdm(total=data_loader.data.num_nodes) + pbar.set_description('GNN inference stage') + preds = [] + self.gnn.eval() + for batch in data_loader: + batch = batch.to(self.device) + out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size] + preds.append(out) + if verbose: + pbar.update(batch.batch_size) + if verbose: + pbar.close() + preds = torch.cat(preds, dim=0) + return preds + + def loss(self, logits: torch.Tensor, labels: torch.Tensor, + loss_func: torch.nn.functional, is_gold: torch.Tensor, + pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5, + is_augmented: bool = True): + r"""Core function of variational EM inference, this function is aming + on combining loss value on gold(original train) and loss value on + pseudo labels. + + Reference: + # noqa + + Args: + logits(torch.tensor): predict results from LM or GNN + labels(torch.tensor): combined node labels from ground truth and + pseudo labels(if provided) + loss_func(torch.nn.modules.loss): loss function for classification + is_gold(tensor): a tensor with bool value that mask ground truth + label and during training, thus ~is_gold mask pseudo labels + pseudo_labels(torch.tensor): predictions from other model + pl_weight: the pseudo labels used in E-step and M-step optimization + alpha in E-step, beta in M-step respectively + is_augmented: use EM or just train GNN and LM with gold data + + """ + def deal_nan(x): + return 0 if torch.isnan(x) else x + + if is_augmented and (sum(~is_gold) > 0): + mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold])) + # all other labels beside from ground truth(gold labels) + pseudo_label_loss = deal_nan( + loss_func(logits[~is_gold], pseudo_labels[~is_gold])) + loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss + else: + loss = loss_func(logits, labels) + return loss diff --git a/torch_geometric/nn/models/molecule_gpt.py b/torch_geometric/nn/models/molecule_gpt.py new file mode 100644 index 000000000000..a0ac73ad9abb --- /dev/null +++ b/torch_geometric/nn/models/molecule_gpt.py @@ -0,0 +1,222 @@ +from typing import List, Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.attention import QFormer +from torch_geometric.nn.nlp.llm import BOS, LLM, MAX_NEW_TOKENS +from torch_geometric.utils import to_dense_batch + + +def pad_or_truncate(embeddings: Tensor, max_seq_len: int, + padding_value: int = 0) -> Tensor: + batch_size, current_seq_len, d = embeddings.size() + + if current_seq_len > max_seq_len: + return embeddings[:, :max_seq_len, :] + elif current_seq_len < max_seq_len: + pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d), + padding_value, dtype=embeddings.dtype, + device=embeddings.device) + return torch.cat([embeddings, pad_tensor], dim=1) + else: + return embeddings + + +class MoleculeGPT(torch.nn.Module): + r"""The MoleculeGPT model from the `"MoleculeGPT: Instruction + Following Large Language Models for Molecular Property Prediction" + `_ paper. + + Args: + llm (LLM): The LLM to use. + graph_encoder (torch.nn.Module): Encode 2D molecule graph. + smiles_encoder (torch.nn.Module): Encode 1D SMILES. + mlp_out_channels (int, optional): The size of each embedding + after qformer encoding. (default: :obj:`32`) + max_tokens (int, optional): Max output tokens of 1D/2D encoder. + (default: :obj:`20`) + + .. warning:: + This module has been tested with the following HuggingFace models + + * :obj:`llm_to_use="lmsys/vicuna-7b-v1.5"` + + and may not work with other models. See other models at `HuggingFace + Models `_ and let us know if you + encounter any issues. + + .. note:: + For an example of using :class:`MoleculeGPT`, see + `examples/llm/molecule_gpt.py `_. + """ + def __init__( + self, + llm: LLM, + graph_encoder: torch.nn.Module, + smiles_encoder: torch.nn.Module, + mlp_out_channels: int = 32, + max_tokens: Optional[int] = 20, + ) -> None: + super().__init__() + self.llm = llm + self.graph_encoder = graph_encoder.to(self.llm.device) + self.smiles_encoder = smiles_encoder.to(self.llm.device) + + self.graph_qformer = QFormer( + input_dim=self.graph_encoder.nn[-1].out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.smiles_qformer = QFormer( + input_dim=self.smiles_encoder.model.pooler.dense.out_features, + hidden_dim=mlp_out_channels, + output_dim=mlp_out_channels, + num_heads=4, + num_layers=2, + ).to(self.llm.device) + + self.max_tokens = max_tokens + + self.word_embedding = self.llm.word_embedding + self.llm_generator = self.llm.llm + + # LLMs + in_dim = 2 * mlp_out_channels * max_tokens + out_dim = self.llm.llm.model.embed_tokens.embedding_dim + self.projector = torch.nn.Sequential( + torch.nn.Linear(in_dim, in_dim), + torch.nn.Sigmoid(), + torch.nn.Linear(in_dim, out_dim), + ).to(self.llm.device) + + def encode( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + ) -> Tensor: + batch_size = len(smiles) + # 2D Graph Branch: [bs, node_len, d] + x = x.to(self.llm.device) + edge_index = edge_index.to(self.llm.device) + if edge_attr is not None: + edge_attr = edge_attr.to(self.llm.device) + batch = batch.to(self.llm.device) + + x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr) + x_graph = to_dense_batch(x_graph, batch)[0] + out_graph = self.graph_qformer(x_graph) + out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens, + padding_value=0) + out_graph = out_graph.view(batch_size, -1) + + # 1D SMILES Branch: [bs, seq_len, d] + x_smiles = self.smiles_encoder.encode(smiles, + output_device=self.llm.device) + out_smiles = self.smiles_qformer(x_smiles) + out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens, + padding_value=0) + out_smiles = out_smiles.view(batch_size, -1) + + # Merge into LLMs + x_cat = torch.cat([out_graph, out_smiles], dim=1) + return x_cat + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + label: List[str], + additional_text_context: Optional[List[str]] = None, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + ( + inputs_embeds, + attention_mask, + label_input_ids, + ) = self.llm._get_embeds(instructions, additional_text_context, xs, + label) + + with self.llm.autocast_context: + outputs = self.llm_generator( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=label_input_ids, + ) + + return outputs.loss + + @torch.no_grad() + def inference( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_attr: Optional[Tensor], + smiles: List[str], + instructions: List[str], + additional_text_context: Optional[List[str]] = None, + max_out_tokens: Optional[int] = MAX_NEW_TOKENS, + ): + x = self.encode(x, edge_index, batch, edge_attr, smiles) + x = self.projector(x) + xs = x.split(1, dim=0) + + # Handle questions without node features: + batch_unique = batch.unique() + batch_size = len(instructions) + if len(batch_unique) < batch_size: + xs = [ + xs[i] if i in batch_unique else None for i in range(batch_size) + ] + + inputs_embeds, attention_mask, _ = self.llm._get_embeds( + instructions, additional_text_context, xs) + + bos_token = self.llm.tokenizer( + BOS, + add_special_tokens=False, + ).input_ids[0] + + with self.llm.autocast_context: + outputs = self.llm_generator.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=max_out_tokens, + attention_mask=attention_mask, + bos_token_id=bos_token, + use_cache=True # Important to set! + ) + + return self.llm.tokenizer.batch_decode( + outputs, + skip_special_tokens=True, + ) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(\n' + f' llm={self.llm},\n' + f' graph={self.graph_encoder.__class__.__name__},\n' + f' smiles={self.smiles_encoder},\n' + f')') diff --git a/torch_geometric/nn/nlp/__init__.py b/torch_geometric/nn/nlp/__init__.py index c101a359e3f5..434619352460 100644 --- a/torch_geometric/nn/nlp/__init__.py +++ b/torch_geometric/nn/nlp/__init__.py @@ -1,7 +1,9 @@ from .sentence_transformer import SentenceTransformer +from .vision_transformer import VisionTransformer from .llm import LLM __all__ = classes = [ 'SentenceTransformer', + 'VisionTransformer', 'LLM', ] diff --git a/torch_geometric/nn/nlp/llm.py b/torch_geometric/nn/nlp/llm.py index b58059f8e098..d18aa42382f7 100644 --- a/torch_geometric/nn/nlp/llm.py +++ b/torch_geometric/nn/nlp/llm.py @@ -56,7 +56,7 @@ class LLM(torch.nn.Module): allocate the correct number of GPUs needed, given the available GPU memory of your GPUs. dtype (torch.dtype, optional): The data type to use for the LLM. - (default :obj: `torch.bloat16`) + (default :obj: `torch.bfloat16`) """ def __init__( self, diff --git a/torch_geometric/nn/nlp/sentence_transformer.py b/torch_geometric/nn/nlp/sentence_transformer.py index c66677e8fa24..6d904b8e0fbf 100644 --- a/torch_geometric/nn/nlp/sentence_transformer.py +++ b/torch_geometric/nn/nlp/sentence_transformer.py @@ -10,6 +10,7 @@ class PoolingStrategy(Enum): MEAN = 'mean' LAST = 'last' CLS = 'cls' + LAST_HIDDEN_STATE = 'last_hidden_state' class SentenceTransformer(torch.nn.Module): @@ -38,6 +39,8 @@ def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: emb = mean_pooling(emb, attention_mask) elif self.pooling_strategy == PoolingStrategy.LAST: emb = last_pooling(emb, attention_mask) + elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE: + emb = out.last_hidden_state else: assert self.pooling_strategy == PoolingStrategy.CLS emb = emb[:, 0, :] @@ -45,6 +48,36 @@ def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: emb = F.normalize(emb, p=2, dim=1) return emb + def get_input_ids( + self, + text: List[str], + batch_size: Optional[int] = None, + output_device: Optional[Union[torch.device, str]] = None, + ) -> Tensor: + is_empty = len(text) == 0 + text = ['dummy'] if is_empty else text + + batch_size = len(text) if batch_size is None else batch_size + + input_ids: List[Tensor] = [] + attention_masks: List[Tensor] = [] + for start in range(0, len(text), batch_size): + token = self.tokenizer( + text[start:start + batch_size], + padding=True, + truncation=True, + return_tensors='pt', + ) + input_ids.append(token.input_ids.to(self.device)) + attention_masks.append(token.attention_mask.to(self.device)) + + def _out(x: List[Tensor]) -> Tensor: + out = torch.cat(x, dim=0) if len(x) > 1 else x[0] + out = out[:0] if is_empty else out + return out.to(output_device) + + return _out(input_ids), _out(attention_masks) + @property def device(self) -> torch.device: return next(iter(self.model.parameters())).device diff --git a/torch_geometric/nn/nlp/vision_transformer.py b/torch_geometric/nn/nlp/vision_transformer.py new file mode 100644 index 000000000000..517a524f4d84 --- /dev/null +++ b/torch_geometric/nn/nlp/vision_transformer.py @@ -0,0 +1,33 @@ +from typing import Optional, Union + +import torch +from torch import Tensor + + +class VisionTransformer(torch.nn.Module): + def __init__( + self, + model_name: str, + ) -> None: + super().__init__() + self.model_name = model_name + + from transformers import SwinConfig, SwinModel + + self.config = SwinConfig.from_pretrained(model_name) + self.model = SwinModel(self.config) + + @torch.no_grad() + def forward( + self, + images: Tensor, + output_device: Optional[Union[torch.device, str]] = None, + ) -> Tensor: + return self.model(images).last_hidden_state.to(output_device) + + @property + def device(self) -> torch.device: + return next(iter(self.model.parameters())).device + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(model_name={self.model_name})' diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py index 833ee657d0e7..22d3039f4c83 100644 --- a/torch_geometric/profile/__init__.py +++ b/torch_geometric/profile/__init__.py @@ -20,6 +20,7 @@ get_gpu_memory_from_nvidia_smi, get_model_size, ) +from .nvtx import nvtxit __all__ = [ 'profileit', @@ -38,6 +39,7 @@ 'get_gpu_memory_from_nvidia_smi', 'get_gpu_memory_from_ipex', 'benchmark', + 'nvtxit', ] classes = __all__ diff --git a/torch_geometric/profile/nvtx.py b/torch_geometric/profile/nvtx.py new file mode 100644 index 000000000000..8dbce375ae5a --- /dev/null +++ b/torch_geometric/profile/nvtx.py @@ -0,0 +1,66 @@ +from functools import wraps +from typing import Optional + +import torch + +CUDA_PROFILE_STARTED = False + + +def begin_cuda_profile(): + global CUDA_PROFILE_STARTED + prev_state = CUDA_PROFILE_STARTED + if prev_state is False: + CUDA_PROFILE_STARTED = True + torch.cuda.cudart().cudaProfilerStart() + return prev_state + + +def end_cuda_profile(prev_state: bool): + global CUDA_PROFILE_STARTED + CUDA_PROFILE_STARTED = prev_state + if prev_state is False: + torch.cuda.cudart().cudaProfilerStop() + + +def nvtxit(name: Optional[str] = None, n_warmups: int = 0, + n_iters: Optional[int] = None): + """Enables NVTX profiling for a function. + + Args: + name (Optional[str], optional): Name to give the reference frame for + the function being wrapped. Defaults to the name of the + function in code. + n_warmups (int, optional): Number of iters to call that function + before starting. Defaults to 0. + n_iters (Optional[int], optional): Number of iters of that function to + record. Defaults to all of them. + """ + def nvtx(func): + + nonlocal name + iters_so_far = 0 + if name is None: + name = func.__name__ + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal iters_so_far + if not torch.cuda.is_available(): + return func(*args, **kwargs) + elif iters_so_far < n_warmups: + iters_so_far += 1 + return func(*args, **kwargs) + elif n_iters is None or iters_so_far < n_iters + n_warmups: + prev_state = begin_cuda_profile() + torch.cuda.nvtx.range_push(f"{name}_{iters_so_far}") + result = func(*args, **kwargs) + torch.cuda.nvtx.range_pop() + end_cuda_profile(prev_state) + iters_so_far += 1 + return result + else: + return func(*args, **kwargs) + + return wrapper + + return nvtx diff --git a/torch_geometric/sampler/base.py b/torch_geometric/sampler/base.py index d67ddd5af79b..1bd2e4346e1d 100644 --- a/torch_geometric/sampler/base.py +++ b/torch_geometric/sampler/base.py @@ -425,6 +425,14 @@ def _get_values( else: assert False + # Confirm that `values` only hold valid edge types: + if isinstance(self.values, dict): + edge_types_str = {EdgeTypeStr(key) for key in edge_types} + invalid_edge_types = set(self.values.keys()) - edge_types_str + if len(invalid_edge_types) > 0: + raise ValueError("Not all edge types specified in " + "'num_neighbors' exist in the graph") + out = {} for edge_type in edge_types: edge_type_str = EdgeTypeStr(edge_type)