diff --git a/CHANGELOG.md b/CHANGELOG.md index cf29ead4dc5f..d3887b05be3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added the `WebQSPDataset` dataset with an example for training G-Retriever (GNN+LLM) ([#8984](https://github.com/pyg-team/pytorch_geometric/pull/8984)) - Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131)) - Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090)) - Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))
diff --git a/examples/README.md b/examples/README.md index 336b3d816a82..a31bf851dacf 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,8 +14,10 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see - [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately ~1.6B edges. - [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph). -For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile). +For examples on co-training LLMs with GNNs, see the examples and README under [`examples/llm_plus_gnn`](./llm_plus_gnn). -For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu). +For examples on using `torch.compile`, see the examples and README under [`examples/compile`](./compile). -For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero). +For examples on scaling PyG up via multi-GPUs, see the examples and README under [`examples/multi_gpu`](./multi_gpu). + +For examples on working with heterogeneous data, see the examples and README under [`examples/hetero`](./hetero).
diff --git a/examples/llm_plus_gnn/README.md b/examples/llm_plus_gnn/README.md new file mode 100644 index 000000000000..5c7900479540 --- /dev/null +++ b/examples/llm_plus_gnn/README.md @@ -0,0 +1,5 @@ +# Examples for LLM and GNN co-training + +| Example | Description | | ------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | [`g_retriever.py`](./g_retriever.py) | Example of Retrieval-Augmented Generation (RAG) with a GNN+LLM architecture, co-training LLAMA2 with a GAT to answer questions based on knowledge graph information |
diff --git a/examples/llm_plus_gnn/g_retriever.py b/examples/llm_plus_gnn/g_retriever.py new file mode 100644 index 000000000000..b647dee8d1f0 --- /dev/null +++ b/examples/llm_plus_gnn/g_retriever.py @@ -0,0 +1,583 @@ +"""This example implements G-Retriever using PyG. +Original Paper: https://arxiv.org/abs/2402.07630 +“G-Retriever significantly reduces hallucinations +by 54% compared to the [LLM] baseline”.
+ +requirements on top of basic PyG: +pip install peft datasets transformers pcst_fast sentencepiece tqdm pandas +""" +import argparse +import gc +import math +import re +import time +from os import path + +import pandas as pd +import torch +import torch.nn as nn +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from torch.nn.utils import clip_grad_norm_ +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +import torch_geometric +from torch_geometric import seed_everything +from torch_geometric.data import Batch, DataLoader +from torch_geometric.datasets import WebQSPDataset +from torch_geometric.utils import scatter + +BOS = '[INST]' +EOS_USER = '[/INST]' +EOS = '[/s]' +IGNORE_INDEX = -100 +llama2_str_name = "meta-llama/Llama-2-7b-chat-hf" +max_txt_len = 512 +max_new_tokens = 32 +pad_token_id = 0 +padding_side = 'left' + + +def detect_hallucinate(pred, label): + try: + pred = pred.split('[/s]')[0].strip().split('|') + correct_hit = len(re.findall(pred[0], label)) > 0 + hallucination = not correct_hit + return hallucination + except: # noqa + return "skip" + + +def compute_accuracy(eval_output): + df = pd.concat([pd.DataFrame(d) for d in eval_output]) + all_hit = [] + all_precision = [] + all_recall = [] + all_f1 = [] + + for pred, label in zip(df.pred.tolist(), df.label.tolist()): + try: + pred = pred.split('[/s]')[0].strip().split('|') + hit = re.findall(pred[0], label) + all_hit.append(len(hit) > 0) + + label = label.split('|') + matches = set(pred).intersection(set(label)) + precision = len(matches) / len(set(label)) + recall = len(matches) / len(set(pred)) + if recall + precision == 0: + f1 = 0 + else: + f1 = 2 * precision * recall / (precision + recall) + + all_precision.append(precision) + all_recall.append(recall) + all_f1.append(f1) + + except Exception as e: + print(f'Label: {label}') + print(f'Pred: {pred}') + print(f'Exception: {e}') + print('------------------') + hit = sum(all_hit) / len(all_hit) + precision = sum(all_precision) / len(all_precision) + recall = sum(all_recall) / len(all_recall) + f1 = sum(all_f1) / len(all_f1) + + print(f'Hit: {hit:.4f}') + print(f'Precision: {precision:.4f}') + print(f'Recall: {recall:.4f}') + print(f'F1: {f1:.4f}') + + return hit + + +def get_llm_kwargs(): + assert torch.cuda.is_available(), "GPU needed!" 
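+    # A rough sketch of what the loop below produces (illustrative values,
+    # not measurements): it claims GPUs one by one until ~75 GiB of free
+    # VRAM has been gathered, then returns kwargs for
+    # AutoModelForCausalLM.from_pretrained() such as
+    #     {"revision": "main", "max_memory": {0: "79GiB"}, "device_map": "auto"}
+    # so that Hugging Face Accelerate shards the LLM across as few GPUs as
+    # possible.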
+ avail_gpus = torch.cuda.device_count() + kwargs = { + "revision": "main", + } + max_mem_dict = {} + avail_mem_dict = {} + mem_total = 0 + gpus_2_use_4_llm = 0 + for i in range(avail_gpus): + available_mem = int(torch.cuda.mem_get_info(0)[0] // 1024**3) + mem_total += available_mem + avail_mem_dict[i] = available_mem + gpus_2_use_4_llm += 1 + # We want to use the minimum number of GPUs that LLM can fit on + # this is to minimize the need for interGPU communications + # >= 75 GB VRAM in total is recommended + if mem_total >= 75: + break + + for i in range(gpus_2_use_4_llm): + max_mem_dict[i] = str(avail_mem_dict[i]) + "GiB" + kwargs["max_memory"] = max_mem_dict + kwargs["device_map"] = "auto" + return kwargs + + +class LLAMA2(nn.Module): + # Pure LLAMA2 LLM module for demo + def __init__(self): + super().__init__() + print('Loading LLAMA') + kwargs = get_llm_kwargs() + print("Setting up LLAMA w/ kwargs =", kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(llama2_str_name, + use_fast=False) + self.tokenizer.pad_token_id = pad_token_id + self.tokenizer.padding_side = padding_side + self.llm = AutoModelForCausalLM.from_pretrained( + llama2_str_name, torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, **kwargs) + self.llm_device = self.llm.device + self.word_embedding = self.llm.model.get_input_embeddings() + + def encode_inputs(self, samples: Batch): + batch_size = len(samples['question']) + questions = self.tokenizer(samples["question"], + add_special_tokens=False) + descriptions = self.tokenizer(samples["desc"], + add_special_tokens=False) + + # encode special tokens + eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False) + bos_embeds = self.word_embedding( + self.tokenizer(BOS, add_special_tokens=False, + return_tensors='pt').input_ids[0].to( + self.llm_device)) + pad_embeds = self.word_embedding( + torch.tensor(self.tokenizer.pad_token_id).to( + self.llm_device)).unsqueeze(0) + return batch_size, questions, descriptions, eos_user_tokens, bos_embeds, pad_embeds + + def inference(self, samples: Batch): + batch_size, questions, descriptions, eos_user_tokens, bos_embeds, pad_embeds = self.encode_inputs( + samples) + batch_inputs_embeds = [] + batch_attention_mask = [] + for i in range(batch_size): + # Add bos & eos token + input_ids = descriptions.input_ids[ + i][:max_txt_len] + questions.input_ids[ + i] + eos_user_tokens.input_ids + inputs_embeds = self.word_embedding( + torch.tensor(input_ids).to(self.llm_device)) + inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=0) + batch_inputs_embeds.append(inputs_embeds) + batch_attention_mask.append([1] * inputs_embeds.shape[0]) + + # pad inputs_embeds + max_length = max([x.shape[0] for x in batch_inputs_embeds]) + for i in range(batch_size): + pad_length = max_length - batch_inputs_embeds[i].shape[0] + batch_inputs_embeds[i] = torch.cat( + [pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]]) + batch_attention_mask[i] = [0 + ] * pad_length + batch_attention_mask[i] + + inputs_embeds = torch.stack(batch_inputs_embeds, + dim=0).to(self.llm_device) + attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=max_new_tokens, + attention_mask=attention_mask, + # do_sample=True, + use_cache=True # IMPORTANT! 
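+                # use_cache=True keeps the KV cache alive during generation,
+                # so earlier tokens are not re-encoded at every decoding step.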
+ ) + pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return { + 'pred': pred, + 'label': samples['label'], + 'question': samples['question'], + 'desc': samples['desc'], + } + + +class GAT_LLAMA(nn.Module): + def __init__(self, hidden_channels: int, num_gnn_layers: int): + super().__init__() + + self.llama2 = LLAMA2() + + print("Training LLAMA with LORA!") + self.llm = self.llama2.llm + self.llm_device = self.llama2.llm_device + self.llm = prepare_model_for_kbit_training(self.llm) + self.tokenizer = self.llama2.tokenizer + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules = [ + "q_proj", + "v_proj", + ] + config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + target_modules=lora_target_modules, + lora_dropout=lora_dropout, + bias="none", + task_type="CAUSAL_LM", + ) + self.llm = get_peft_model(self.llm, config) + + print('Finish loading LLAMA!') + + self.graph_encoder = torch_geometric.nn.models.GAT( + in_channels=1024, + out_channels=1024, + hidden_channels=hidden_channels, + num_layers=num_gnn_layers, + heads=4, + ).to(self.llm_device) + self.projector = nn.Sequential( + nn.Linear(1024, 2048), + nn.Sigmoid(), + nn.Linear(2048, 4096), + ).to(self.llm_device) + + self.word_embedding = self.llama2.word_embedding + + def encode_graphs(self, samples: Batch): + x = samples.x.to(self.llm_device) + edge_index = samples.edge_index.long().to(self.llm_device) + edge_attr = samples.edge_attr.to(self.llm_device) + n_embeds = self.graph_encoder(x, edge_index.long(), edge_attr) + batch = samples.batch.to(self.llm_device) + # mean pooling + g_embeds = scatter(n_embeds, batch, dim=0, reduce='mean') + return g_embeds + + def forward(self, samples: Batch): + batch_size, questions, descriptions, eos_user_tokens, bos_embeds, pad_embeds = self.llama2.encode_inputs( + samples) + # encode labels + labels = self.tokenizer(samples.label, add_special_tokens=False) + # encode training specific special token + eos_tokens = self.tokenizer(EOS, add_special_tokens=False) + + # encode graphs + graph_embeds = self.encode_graphs(samples) + graph_embeds = self.projector(graph_embeds) + batch_inputs_embeds = [] + batch_attention_mask = [] + batch_label_input_ids = [] + num_nodes_per_graph = samples.ptr[1:] - samples.ptr[:-1] + for i in range(batch_size): + # Add bos & eos token + label_input_ids = labels.input_ids[ + i][:max_new_tokens] + eos_tokens.input_ids + input_ids = descriptions.input_ids[ + i][:max_txt_len] + questions.input_ids[ + i] + eos_user_tokens.input_ids + label_input_ids + inputs_embeds = self.word_embedding( + torch.tensor(input_ids).to(self.llm_device)) + to_cat = [bos_embeds] + if num_nodes_per_graph[i] != 0: + to_cat.append(graph_embeds[i].unsqueeze(0)) + to_cat.append(inputs_embeds) + inputs_embeds = torch.cat(to_cat, dim=0) + batch_inputs_embeds.append(inputs_embeds) + batch_attention_mask.append([1] * inputs_embeds.shape[0]) + label_input_ids = [IGNORE_INDEX + ] * (inputs_embeds.shape[0] - + len(label_input_ids)) + label_input_ids + batch_label_input_ids.append(label_input_ids) + + # pad inputs_embeds + max_length = max([x.shape[0] for x in batch_inputs_embeds]) + for i in range(batch_size): + pad_length = max_length - batch_inputs_embeds[i].shape[0] + batch_inputs_embeds[i] = torch.cat( + [pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]]) + batch_attention_mask[i] = [0 + ] * pad_length + batch_attention_mask[i] + batch_label_input_ids[ + i] = [IGNORE_INDEX] * pad_length + batch_label_input_ids[i] + + inputs_embeds = 
torch.stack(batch_inputs_embeds, + dim=0).to(self.llm_device) + attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device) + label_input_ids = torch.tensor(batch_label_input_ids).to( + self.llm_device) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + outputs = self.llm( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=label_input_ids, + ) + return outputs.loss + + def inference(self, samples: Batch): + batch_size, questions, descriptions, eos_user_tokens, bos_embeds, pad_embeds = self.llama2.encode_inputs( + samples) + # encode graphs + graph_embeds = self.encode_graphs(samples) + graph_embeds = self.projector(graph_embeds) + + batch_inputs_embeds = [] + batch_attention_mask = [] + for i in range(batch_size): + # Add bos & eos token + input_ids = descriptions.input_ids[ + i][:max_txt_len] + questions.input_ids[ + i] + eos_user_tokens.input_ids + inputs_embeds = self.word_embedding( + torch.tensor(input_ids).to(self.llm_device)) + inputs_embeds = torch.cat( + [bos_embeds, graph_embeds[i].unsqueeze(0), inputs_embeds], + dim=0) + batch_inputs_embeds.append(inputs_embeds) + batch_attention_mask.append([1] * inputs_embeds.shape[0]) + + # pad inputs_embeds + max_length = max([x.shape[0] for x in batch_inputs_embeds]) + for i in range(batch_size): + pad_length = max_length - batch_inputs_embeds[i].shape[0] + batch_inputs_embeds[i] = torch.cat( + [pad_embeds.repeat(pad_length, 1), batch_inputs_embeds[i]]) + batch_attention_mask[i] = [0 + ] * pad_length + batch_attention_mask[i] + + inputs_embeds = torch.stack(batch_inputs_embeds, + dim=0).to(self.llm_device) + attention_mask = torch.tensor(batch_attention_mask).to(self.llm_device) + + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + outputs = self.llm.generate( + inputs_embeds=inputs_embeds, + max_new_tokens=max_new_tokens, + attention_mask=attention_mask, + # do_sample=True, + use_cache=True # IMPORTANT! 
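+                # Same left-padded batch decoding as in LLAMA2.inference(),
+                # except that the projected graph embedding was prepended as a
+                # soft token right after the [INST] embedding above.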
+ ) + pred = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + + return { + 'pred': pred, + 'label': samples['label'], + 'question': samples['question'], + 'desc': samples['desc'], + } + + def print_trainable_params(self): + trainable_params = 0 + all_param = 0 + for _, param in self.named_parameters(): + num_params = param.numel() + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def main(since: float, num_epochs: int, hidden_channels: int, + num_gnn_layers: int, batch_size: int, lr: float): + def adjust_learning_rate(param_group, LR, epoch): + # Decay the learning rate with half-cycle cosine after warmup + min_lr = 5e-6 + warmup_epochs = 1 + if epoch < warmup_epochs: + lr = LR + else: + lr = min_lr + (LR - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * (epoch - warmup_epochs) / + (num_epochs - warmup_epochs))) + param_group["lr"] = lr + return lr + + seed_everything(42) + + dataset = WebQSPDataset() + idx_split = dataset.split_idxs + + # Step 1: Build Node Classification Dataset + train_dataset = [dataset[i] for i in idx_split['train']] + val_dataset = [dataset[i] for i in idx_split['val']] + test_dataset = [dataset[i] for i in idx_split['test']] + + train_loader = DataLoader(train_dataset, batch_size=batch_size, + drop_last=True, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=batch_size, + drop_last=False, pin_memory=True, shuffle=False) + + # Step 2: Build Model + model = GAT_LLAMA(hidden_channels, num_gnn_layers) + + # Step 3 Set Optimizer + params = [p for _, p in model.named_parameters() if p.requires_grad] + optimizer = torch.optim.AdamW([ + { + 'params': params, + 'lr': lr, + 'weight_decay': .05 + }, + ], betas=(0.9, 0.95)) + grad_steps = 2 + trainable_params, all_param = model.print_trainable_params() + print(f"trainable params: {trainable_params} || \ + all params: {all_param} || \ + trainable%: {100 * trainable_params / all_param}") + + # Step 4 Training + for epoch in range(num_epochs): + model.train() + epoch_loss = 0. + if epoch == 0: + prep_time = round(time.time() - since, 2) + print("Total Prep Time (prep_time) =", prep_time) + print("Training beginning...") + epoch_str = f"Epoch: {epoch + 1}|{num_epochs}" + loader = tqdm(train_loader, desc=epoch_str) + for step, batch in enumerate(loader): + optimizer.zero_grad() + loss = model(batch) + loss.backward() + + clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + + if (step + 1) % grad_steps == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) + + optimizer.step() + epoch_loss = epoch_loss + loss.item() + + if (step + 1) % grad_steps == 0: + lr = optimizer.param_groups[0]["lr"] + train_loss = epoch_loss / len(train_loader) + print(epoch_str + f",Train Loss (Epoch Mean): {train_loss}") + + val_loss = 0. 
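+        # Per-epoch validation: mean LM loss over the validation split,
+        # computed without gradient tracking; this example does not use it
+        # for early stopping or checkpoint selection.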
+ eval_output = [] + model.eval() + with torch.no_grad(): + for step, batch in enumerate(val_loader): + loss = model(batch) + val_loss += loss.item() + val_loss = val_loss / len(val_loader) + print(epoch_str + f", Val Loss: {val_loss}") + + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + + # Step 5 Evaluating + print("Final Evaluation...") + model.eval() + eval_output = [] + progress_bar_test = tqdm(range(len(test_loader))) + for step, batch in enumerate(test_loader): + with torch.no_grad(): + output = model.inference(batch) + eval_output.append(output) + + progress_bar_test.update(1) + + # Step 6 Post-processing & compute metrics + acc = compute_accuracy(eval_output) + print(f'Test Acc {acc}') + # save model + torch.save(model, "gat_llama.pt") + return prep_time, dataset, model + + +def minimal_demo(model, dataset): + # Step 1: Define a single batch size test loader + idx_split = dataset.split_idxs + test_dataset = [dataset[i] for i in idx_split['test']] + # batch size 1 loader for simplicity + loader = DataLoader(test_dataset, batch_size=1, drop_last=False, + pin_memory=True, shuffle=False) + # define the pure pretrained LLM + pure_llm = LLAMA2() + + # Step loop through the loader and run both models + gnn_llm_hallucin_sum = 0 + pure_llm_hallucin_sum = 0 + final_prnt_str = "" + print("Checking LLM vs GNN+LLM for hallucinations...") + for batch in tqdm(loader): + question = batch.question[0] + correct_answer = batch.label[0] + gnn_llm_out = model.inference(batch) + pure_llm_out = pure_llm.inference(batch) + gnn_llm_pred = gnn_llm_out['pred'][0] + pure_llm_pred = pure_llm_out['pred'][0] + gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred, correct_answer) + pure_llm_hallucinates = detect_hallucinate(pure_llm_pred, + correct_answer) + if gnn_llm_hallucinates == "skip" or pure_llm_hallucinates == "skip": + # skipping since hard to evaluate if the answer's are hallucinations + continue + gnn_llm_hallucin_sum += bool(gnn_llm_hallucinates) + pure_llm_hallucin_sum += bool(pure_llm_hallucinates) + # showcase LLM hallucinations solved by GNN + if pure_llm_hallucinates and not gnn_llm_hallucinates: + final_prnt_str += "Question: " + question + "\n" + final_prnt_str += "Correct Answer: " + correct_answer + "\n" + final_prnt_str += "Pure LLM Prediction: " + pure_llm_pred + "\n" + final_prnt_str += "GNN+LLM Prediction:" + gnn_llm_pred + "\n" + final_prnt_str += "#" * 20 + "\n" + print("Total GNN+LLM Hallucinations:", gnn_llm_hallucin_sum) + print("Total Pure LLM Hallucinations:", pure_llm_hallucin_sum) + percent = 100.0 * round(1 - + (gnn_llm_hallucin_sum / pure_llm_hallucin_sum), 2) + print(f"GNN reduces hallucinations by: ~{percent}%") + print("Note: hallucinations detected by regex hence the ~") + print("Instances where GNN solves the hallucinations of Pure LLMs:") + print(final_prnt_str) + + +if __name__ == "__main__": + # check if saved model + if path.exists("gat_llama.pt"): + print("Existing trained model found.") + # ask if want to retrain or skip to demo + print("Would you like to retrain?") + user_input = str(input("(y/n):")).lower() + retrain = user_input == "y" + else: + retrain = True + if retrain: + parser = argparse.ArgumentParser() + parser.add_argument('--hidden_channels', type=int, default=1024) + parser.add_argument('--num_gnn_layers', type=int, default=4) + parser.add_argument('--lr', type=float, default=1e-5) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--batch_size', type=int, default=4) + args = parser.parse_args() + 
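+        # Example invocation (the values shown are the argparse defaults
+        # above; reduce --batch_size or --hidden_channels on smaller GPUs):
+        #   python g_retriever.py --hidden_channels 1024 --num_gnn_layers 4 \
+        #       --lr 1e-5 --epochs 10 --batch_size 4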
since = time.time() + prep_time, dataset, model = main(since, args.epochs, + args.hidden_channels, + args.num_gnn_layers, args.batch_size, + args.lr) + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + gc.collect() + e2e_time = round(time.time() - since, 2) + print("E2E time (e2e_time) =", e2e_time) + print("E2E time minus Prep Time =", e2e_time - prep_time) + else: + model = torch.load("gat_llama.pt") + dataset = WebQSPDataset() + print( + "Would you like a minimal demo showcasing how GNN+LLM can solve LLM hallucinations?" + ) + user_input = str(input("(y/n):")).lower() + if user_input == "y": + minimal_demo(model, dataset) diff --git a/torch_geometric/datasets/__init__.py b/torch_geometric/datasets/__init__.py index 9face4868c2f..de8d133c7971 100644 --- a/torch_geometric/datasets/__init__.py +++ b/torch_geometric/datasets/__init__.py @@ -76,6 +76,7 @@ from .wikidata import Wikidata5M from .myket import MyketDataset from .brca_tgca import BrcaTcga +from .web_qsp_dataset import WebQSPDataset from .dbp15k import DBP15K from .aminer import AMiner @@ -109,6 +110,7 @@ import torch_geometric.datasets.utils homo_datasets = [ + 'WebQSPDataset', 'KarateClub', 'TUDataset', 'GNNBenchmarkDataset', diff --git a/torch_geometric/datasets/web_qsp_dataset.py b/torch_geometric/datasets/web_qsp_dataset.py new file mode 100644 index 000000000000..2284321c2284 --- /dev/null +++ b/torch_geometric/datasets/web_qsp_dataset.py @@ -0,0 +1,358 @@ +from typing import Dict, List, Tuple, no_type_check + +import numpy as np + +try: + import pandas as pd + from pandas import DataFrame as df + WITH_PANDAS = True +except ImportError as e: # noqa + df = None + WITH_PANDAS = False +import torch +import torch.nn.functional as F + +try: + from pcst_fast import pcst_fast + WITH_PCST = True +except ImportError as e: # noqa + WITH_PCST = False +from torch.utils.data import DataLoader +from tqdm import tqdm + +try: + from transformers import AutoModel, AutoTokenizer + WITH_TRANSFORMERS = True +except ImportError as e: # noqa + WITH_TRANSFORMERS = False +try: + import datasets + WITH_DATASETS = True +except ImportError as e: # noqa + WITH_DATASETS = False + +from torch_geometric.data import Data, InMemoryDataset + + +@no_type_check +def retrieval_via_pcst(graph: Data, q_emb: torch.Tensor, textual_nodes: df, + textual_edges: df, topk: int = 3, topk_e: int = 3, + cost_e: float = 0.5) -> Tuple[Data, str]: + # from original G-Retriever work + # https://arxiv.org/abs/2402.07630 + c = 0.01 + if len(textual_nodes) == 0 or len(textual_edges) == 0: + desc = textual_nodes.to_csv(index=False) + "\n" + textual_edges.to_csv( + index=False, columns=["src", "edge_attr", "dst"]) + graph = Data(x=graph.x, edge_index=graph.edge_index, + edge_attr=graph.edge_attr, num_nodes=graph.num_nodes) + return graph, desc + + root = -1 # unrooted + num_clusters = 1 + pruning = "gw" + verbosity_level = 0 + if topk > 0: + n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.x) + topk = min(topk, graph.num_nodes) + _, topk_n_indices = torch.topk(n_prizes, topk, largest=True) + + n_prizes = torch.zeros_like(n_prizes) + n_prizes[topk_n_indices] = torch.arange(topk, 0, -1).float() + else: + n_prizes = torch.zeros(graph.num_nodes) + + if topk_e > 0: + e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, graph.edge_attr) + topk_e = min(topk_e, e_prizes.unique().size(0)) + + topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e, largest=True) + e_prizes[e_prizes < topk_e_values[-1]] = 0.0 + last_topk_e_value = topk_e + for k in range(topk_e): + 
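+            # Re-scale tied edge prizes so that, after this loop, edges with
+            # higher question similarity keep strictly larger prizes
+            # (following the original G-Retriever reference implementation).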
indices = e_prizes == topk_e_values[k] + value = min((topk_e - k) / sum(indices), last_topk_e_value - c) + e_prizes[indices] = value + last_topk_e_value = value + # cost_e = max(min(cost_e, e_prizes.max().item()-c), 0) + else: + e_prizes = torch.zeros(graph.num_edges) + + costs = [] + edges = [] + virtual_n_prizes = [] + virtual_edges = [] + virtual_costs = [] + mapping_n = {} + mapping_e = {} + for i, (src, dst) in enumerate(graph.edge_index.T.numpy()): + prize_e = e_prizes[i] + if prize_e <= cost_e: + mapping_e[len(edges)] = i + edges.append((src, dst)) + costs.append(cost_e - prize_e) + else: + virtual_node_id = graph.num_nodes + len(virtual_n_prizes) + mapping_n[virtual_node_id] = i + virtual_edges.append((src, virtual_node_id)) + virtual_edges.append((virtual_node_id, dst)) + virtual_costs.append(0) + virtual_costs.append(0) + virtual_n_prizes.append(prize_e - cost_e) + + prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)]) + num_edges = len(edges) + if len(virtual_costs) > 0: + costs = np.array(costs + virtual_costs) + edges = np.array(edges + virtual_edges) + + vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters, + pruning, verbosity_level) + + selected_nodes = vertices[vertices < graph.num_nodes] + selected_edges = [mapping_e[e] for e in edges if e < num_edges] + virtual_vertices = vertices[vertices >= graph.num_nodes] + if len(virtual_vertices) > 0: + virtual_vertices = vertices[vertices >= graph.num_nodes] + virtual_edges = [mapping_n[i] for i in virtual_vertices] + selected_edges = np.array(selected_edges + virtual_edges) + + edge_index = graph.edge_index[:, selected_edges] + selected_nodes = np.unique( + np.concatenate( + [selected_nodes, edge_index[0].numpy(), edge_index[1].numpy()])) + + n = textual_nodes.iloc[selected_nodes] + e = textual_edges.iloc[selected_edges] + desc = n.to_csv(index=False) + "\n" + e.to_csv( + index=False, columns=["src", "edge_attr", "dst"]) + + mapping = {n: i for i, n in enumerate(selected_nodes.tolist())} + + x = graph.x[selected_nodes] + edge_attr = graph.edge_attr[selected_edges] + src = [mapping[i] for i in edge_index[0].tolist()] + dst = [mapping[i] for i in edge_index[1].tolist()] + edge_index = torch.LongTensor([src, dst]) + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + num_nodes=len(selected_nodes)) + + return data, desc + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, input_ids: torch.Tensor, + attention_mask: torch.Tensor) -> None: + super().__init__() + self.data = { + "input_ids": input_ids, + "att_mask": attention_mask, + } + + def __len__(self) -> int: + return self.data["input_ids"].size(0) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + if isinstance(index, torch.Tensor): + index = index.item() + batch_data = dict() + for key in self.data.keys(): + if self.data[key] is not None: + batch_data[key] = self.data[key][index] + return batch_data + + +class Sentence_Transformer(torch.nn.Module): + def __init__(self, pretrained_repo: str) -> None: + super(Sentence_Transformer, self).__init__() + print(f"inherit model weights from {pretrained_repo}") + self.bert_model = AutoModel.from_pretrained(pretrained_repo) + + def mean_pooling(self, token_embeddings: torch.Tensor, + attention_mask: torch.Tensor) -> torch.Tensor: + data_type = token_embeddings.dtype + input_mask_expanded = attention_mask.unsqueeze(-1).expand( + token_embeddings.size()).to(data_type) + return torch.sum(token_embeddings * input_mask_expanded, + 1) / torch.clamp(input_mask_expanded.sum(1), 
min=1e-9) + + def forward(self, input_ids: torch.Tensor, + att_mask: torch.Tensor) -> torch.Tensor: + bert_out = self.bert_model(input_ids=input_ids, + attention_mask=att_mask) + + # First element of model_output contains all token embeddings + token_embeddings = bert_out[0] + sentence_embeddings = self.mean_pooling(token_embeddings, att_mask) + sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + +def sbert_text2embedding(model: Sentence_Transformer, + tokenizer: torch.nn.Module, device: torch.device, + text: List[str]) -> torch.Tensor: + try: + encoding = tokenizer(text, padding=True, truncation=True, + return_tensors="pt") + dataset = Dataset(input_ids=encoding.input_ids, + attention_mask=encoding.attention_mask) + + # DataLoader + dataloader = DataLoader(dataset, batch_size=256, shuffle=False) + + # Placeholder for storing the embeddings + all_embeddings_list = [] + + # Iterate through batches + with torch.no_grad(): + + for batch in dataloader: + # Move batch to the appropriate device + batch = {key: value.to(device) for key, value in batch.items()} + + # Forward pass + embeddings = model(input_ids=batch["input_ids"], + att_mask=batch["att_mask"]) + + # Append the embeddings to the list + all_embeddings_list.append(embeddings) + + # Concatenate the embeddings from all batches + all_embeddings = torch.cat(all_embeddings_list, dim=0).cpu() + except: # noqa + print( + "SBERT text embedding failed, returning torch.zeros((0, 1024))...") + return torch.zeros((0, 1024)) + + return all_embeddings + + +class WebQSPDataset(InMemoryDataset): + r"""The WebQuestionsSP dataset was released as part of + “The Value of Semantic Parse Labeling for Knowledge + Base Question Answering” + [Yih, Richardson, Meek, Chang & Suh, 2016]. + It contains semantic parses, vs. answers, for a set of questions + that originally comes from WebQuestions [Berant et al., 2013]." + Processing based on "G-Retriever: Retrieval-Augmented Generation + for Textual Graph Understanding and Question Answering". + Requires datasets and transformers from HuggingFace. + + Args: + root (str): Root directory where the dataset should be saved. + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + """ + def __init__( + self, + root: str = "", + force_reload: bool = False, + ) -> None: + missing_imports = False + missing_str_list = [] + if not WITH_PCST: + missing_str_list.append('pcst_fast') + missing_imports = True + if not WITH_TRANSFORMERS: + missing_str_list.append('transformers') + missing_imports = True + if not WITH_DATASETS: + missing_str_list.append('datasets') + missing_imports = True + if not WITH_PANDAS: + missing_str_list.append('pandas') + missing_imports = True + if missing_imports: + missing_str = ' '.join(missing_str_list) + error_out = f"`pip install {missing_str}` to use this dataset." + raise ImportError(error_out) + self.prompt = "Please answer the given question." 
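+        # Minimal usage sketch (mirrors examples/llm_plus_gnn/g_retriever.py;
+        # the root path is illustrative). The first run downloads
+        # "rmanluo/RoG-webqsp" from Hugging Face and encodes every graph with
+        # SBERT, so it is slow and benefits from a GPU:
+        #     dataset = WebQSPDataset(root="data/WebQSP")
+        #     train_idx = dataset.split_idxs["train"]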
+ self.graph = None + self.graph_type = "Knowledge Graph" + self.model_name = "sbert" + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu") + super().__init__(root, None, None, force_reload=force_reload) + self.load(self.processed_paths[0]) + + @property + def raw_file_names(self) -> List[str]: + return [] + + @property + def processed_file_names(self) -> List[str]: + return ["list_of_graphs.pt", "pre_filter.pt", "pre_transform.pt"] + + def download(self) -> None: + dataset = datasets.load_dataset("rmanluo/RoG-webqsp") + self.raw_dataset = datasets.concatenate_datasets( + [dataset["train"], dataset["validation"], dataset["test"]]) + self.split_idxs = { + "train": + torch.arange(len(dataset["train"])), + "val": + torch.arange(len(dataset["validation"])) + len(dataset["train"]), + "test": + torch.arange(len(dataset["test"])) + len(dataset["train"]) + + len(dataset["validation"]) + } + + def process(self) -> None: + pretrained_repo = "sentence-transformers/all-roberta-large-v1" + self.model = Sentence_Transformer(pretrained_repo) + self.model.to(self.device) + self.model.eval() + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_repo) + self.text2embedding = sbert_text2embedding + self.questions = [i["question"] for i in self.raw_dataset] + list_of_graphs = [] + # encode questions + print("Encoding questions...") + q_embs = self.text2embedding(self.model, self.tokenizer, self.device, + self.questions) + print("Encoding graphs...") + for index in tqdm(range(len(self.raw_dataset))): + data_i = self.raw_dataset[index] + raw_nodes: Dict[str, int] = {} + raw_edges = [] + for tri in data_i["graph"]: + h, r, t = tri + h = h.lower() + t = t.lower() + if h not in raw_nodes: + raw_nodes[h] = len(raw_nodes) + if t not in raw_nodes: + raw_nodes[t] = len(raw_nodes) + raw_edges.append({ + "src": raw_nodes[h], + "edge_attr": r, + "dst": raw_nodes[t] + }) + nodes = pd.DataFrame([{ + "node_id": v, + "node_attr": k + } for k, v in raw_nodes.items()], columns=["node_id", "node_attr"]) + edges = pd.DataFrame(raw_edges, + columns=["src", "edge_attr", "dst"]) + # encode nodes + nodes.node_attr.fillna("", inplace=True) + x = self.text2embedding(self.model, self.tokenizer, self.device, + nodes.node_attr.tolist()) + # encode edges + edge_attr = self.text2embedding(self.model, self.tokenizer, + self.device, + edges.edge_attr.tolist()) + edge_index = torch.LongTensor( + [edges.src.tolist(), edges.dst.tolist()]) + question = f"Question: {data_i['question']}\nAnswer: " + label = ("|").join(data_i["answer"]).lower() + raw_graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, + num_nodes=len(nodes)).to("cpu") + psct_subgraph, desc = retrieval_via_pcst(raw_graph, q_embs[index], + nodes, edges, topk=3, + topk_e=5, cost_e=0.5) + psct_subgraph["question"] = question + psct_subgraph["label"] = label + psct_subgraph["desc"] = desc + list_of_graphs.append(psct_subgraph.to("cpu")) + self.save(list_of_graphs, self.processed_paths[0])