Cleanup (#9242)
Cleaning up to address review for #9167.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
puririshi98 and pre-commit-ci[bot] authored Apr 27, 2024
1 parent 8b327d0 commit 2463a43
Showing 8 changed files with 670 additions and 589 deletions.
30 changes: 25 additions & 5 deletions examples/llm_plus_gnn/g_retriever.py
@@ -21,7 +21,8 @@
from torch_geometric import seed_everything
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models.gnn_llm import GNN_LLM, LLM
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn.text import LLM


def detect_hallucinate(pred, label):
@@ -99,6 +100,23 @@ def load_params_dict(model, save_path):
return model


def get_loss(model, batch, model_save_name):
if model_save_name == "llm":
return model(batch.question, batch.label, batch.desc)
else:
return model(batch.question, batch.x, batch.edge_index, batch.batch,
batch.ptr, batch.label, batch.edge_attr, batch.desc)


def inference_step(model, batch, model_save_name):
if model_save_name == "llm":
return model.inference(batch.question, batch.desc)
else:
return model.inference(batch.question, batch.x, batch.edge_index,
batch.batch, batch.ptr, batch.edge_attr,
batch.desc)


def train(since, num_epochs, hidden_channels, num_gnn_layers, batch_size,
eval_batch_size, lr, model=None, dataset=None, checkpointing=False):
def adjust_learning_rate(param_group, LR, epoch):
@@ -170,7 +188,7 @@ def adjust_learning_rate(param_group, LR, epoch):
loader = tqdm(train_loader, desc=epoch_str)
for step, batch in enumerate(loader):
optimizer.zero_grad()
loss = model(batch)
loss = get_loss(model, batch, model_save_name)
loss.backward()

clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)
@@ -192,7 +210,7 @@ def adjust_learning_rate(param_group, LR, epoch):
model.eval()
with torch.no_grad():
for step, batch in enumerate(val_loader):
loss = model(batch)
loss = get_loss(model, batch, model_save_name)
val_loss += loss.item()
val_loss = val_loss / len(val_loader)
print(epoch_str + f", Val Loss: {val_loss}")
@@ -215,7 +233,8 @@ def adjust_learning_rate(param_group, LR, epoch):
print("Final Evaluation...")
for step, batch in enumerate(test_loader):
with torch.no_grad():
output = model.inference(batch)
output = inference_step(model, batch, model_save_name)
output["label"] = batch.label
eval_output.append(output)
progress_bar_test.update(1)

@@ -263,7 +282,8 @@ def minimal_demo(gnn_llm_eval_outs, dataset, lr, epochs, batch_size,
question = batch.question[0]
correct_answer = batch.label[0]
# GNN+LLM only using 32 tokens to answer, give untrained LLM more
pure_llm_out = pure_llm.inference(batch, max_out_tokens=256)
pure_llm_out = pure_llm.inference(batch.question, batch.desc,
max_out_tokens=256)
gnn_llm_pred = gnn_llm_preds[i]
pure_llm_pred = pure_llm_out['pred'][0]
gnn_llm_hallucinates = detect_hallucinate(gnn_llm_pred,
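The new get_loss() and inference_step() helpers above let a single loop drive either the pure LLM or the GRetriever model, dispatching on model_save_name. As a rough illustration (not part of this commit), one epoch built on those helpers might look like the sketch below; model construction is omitted because the constructor arguments are not shown in this diff, and run_one_epoch is a hypothetical name.

import torch


def run_one_epoch(model, optimizer, train_loader, val_loader, model_save_name):
    # Training: get_loss() hides the different forward signatures of the
    # pure LLM ("llm") and the GRetriever model.
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        loss = get_loss(model, batch, model_save_name)
        loss.backward()
        optimizer.step()

    # Validation uses the same dispatch helper.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += get_loss(model, batch, model_save_name).item()
    return val_loss / len(val_loader)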
119 changes: 9 additions & 110 deletions torch_geometric/datasets/web_qsp_dataset.py
@@ -1,29 +1,25 @@
from typing import Dict, List, Tuple, Union, no_type_check
from typing import Dict, List, Tuple, no_type_check

import numpy as np

try:
import pandas as pd
from pandas import DataFrame as df
WITH_PANDAS = True
except ImportError as e: # noqa
df = None
WITH_PANDAS = False
import torch
import torch.nn.functional as F

try:
from pcst_fast import pcst_fast
WITH_PCST = True
except ImportError as e: # noqa
WITH_PCST = False
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
from transformers import AutoModel, AutoTokenizer
WITH_TRANSFORMERS = True
except ImportError as e: # noqa
WITH_TRANSFORMERS = False
from torch_geometric.nn.text import SentenceTransformer, text2embedding

try:
import datasets
WITH_DATASETS = True
@@ -140,94 +136,6 @@ def retrieval_via_pcst(graph: Data, q_emb: torch.Tensor, textual_nodes: df,
return data, desc


class Dataset(torch.utils.data.Dataset):
def __init__(self, input_ids: torch.Tensor,
attention_mask: torch.Tensor) -> None:
super().__init__()
self.data = {
"input_ids": input_ids,
"att_mask": attention_mask,
}

def __len__(self) -> int:
return self.data["input_ids"].size(0)

def __getitem__(
self, index: Union[int, torch.Tensor]) -> Dict[str, torch.Tensor]:
if isinstance(index, torch.Tensor):
index = index.item()
batch_data = dict()
for key in self.data.keys():
if self.data[key] is not None:
batch_data[key] = self.data[key][index]
return batch_data


class SentenceTransformer(torch.nn.Module):
def __init__(self, pretrained_repo: str) -> None:
super().__init__()
print(f"inherit model weights from {pretrained_repo}")
self.bert_model = AutoModel.from_pretrained(pretrained_repo)

def mean_pooling(self, token_embeddings: torch.Tensor,
attention_mask: torch.Tensor) -> torch.Tensor:
data_type = token_embeddings.dtype
input_mask_expanded = attention_mask.unsqueeze(-1).expand(
token_embeddings.size()).to(data_type)
return torch.sum(token_embeddings * input_mask_expanded,
1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def forward(self, input_ids: torch.Tensor,
att_mask: torch.Tensor) -> torch.Tensor:
bert_out = self.bert_model(input_ids=input_ids,
attention_mask=att_mask)

# First element of model_output contains all token embeddings
token_embeddings = bert_out[0]
sentence_embeddings = self.mean_pooling(token_embeddings, att_mask)
sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
return sentence_embeddings


def sbert_text2embedding(model: SentenceTransformer,
tokenizer: torch.nn.Module, device: torch.device,
text: List[str]) -> torch.Tensor:
try:
encoding = tokenizer(text, padding=True, truncation=True,
return_tensors="pt")
dataset = Dataset(input_ids=encoding.input_ids,
attention_mask=encoding.attention_mask)

# DataLoader
dataloader = DataLoader(dataset, batch_size=256, shuffle=False)

# Placeholder for storing the embeddings
all_embeddings_list = []

# Iterate through batches
with torch.no_grad():

for batch in dataloader:
# Move batch to the appropriate device
batch = {key: value.to(device) for key, value in batch.items()}

# Forward pass
embeddings = model(input_ids=batch["input_ids"],
att_mask=batch["att_mask"])

# Append the embeddings to the list
all_embeddings_list.append(embeddings)

# Concatenate the embeddings from all batches
all_embeddings = torch.cat(all_embeddings_list, dim=0).cpu()
except: # noqa
print(
"SBERT text embedding failed, returning torch.zeros((0, 1024))...")
return torch.zeros((0, 1024))

return all_embeddings


class WebQSPDataset(InMemoryDataset):
r"""The WebQuestionsSP dataset was released as part of
“The Value of Semantic Parse Labeling for Knowledge
@@ -254,9 +162,6 @@ def __init__(
if not WITH_PCST:
missing_str_list.append('pcst_fast')
missing_imports = True
if not WITH_TRANSFORMERS:
missing_str_list.append('transformers')
missing_imports = True
if not WITH_DATASETS:
missing_str_list.append('datasets')
missing_imports = True
@@ -267,7 +172,6 @@ def __init__(
missing_str = ' '.join(missing_str_list)
error_out = f"`pip install {missing_str}` to use this dataset."
raise ImportError(error_out)
self.prompt = "Please answer the given question."
self.graph = None
self.graph_type = "Knowledge Graph"
self.model_name = "sbert"
@@ -299,19 +203,15 @@ def download(self) -> None:
}

def process(self) -> None:
import pandas
pretrained_repo = "sentence-transformers/all-roberta-large-v1"
self.model = SentenceTransformer(pretrained_repo)
self.model.to(self.device)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_repo)
self.text2embedding = sbert_text2embedding
self.questions = [i["question"] for i in self.raw_dataset]
list_of_graphs = []
# encode questions
print("Encoding questions...")
q_embs = self.text2embedding(self.model, self.tokenizer, self.device,
self.questions)
q_embs = text2embedding(self.model, self.device, self.questions)
print("Encoding graphs...")
for index in tqdm(range(len(self.raw_dataset))):
data_i = self.raw_dataset[index]
@@ -338,12 +238,11 @@ def process(self) -> None:
columns=["src", "edge_attr", "dst"])
# encode nodes
nodes.node_attr.fillna("", inplace=True)
x = self.text2embedding(self.model, self.tokenizer, self.device,
nodes.node_attr.tolist())
x = text2embedding(self.model, self.device,
nodes.node_attr.tolist())
# encode edges
edge_attr = self.text2embedding(self.model, self.tokenizer,
self.device,
edges.edge_attr.tolist())
edge_attr = text2embedding(self.model, self.device,
edges.edge_attr.tolist())
edge_index = torch.LongTensor(
[edges.src.tolist(), edges.dst.tolist()])
question = f"Question: {data_i['question']}\nAnswer: "
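With the local SentenceTransformer and sbert_text2embedding implementations removed, the dataset now delegates text embedding to torch_geometric.nn.text. The following is a minimal sketch of the new call pattern, inferred from the process() changes above; the exact signatures live in torch_geometric.nn.text and may differ, and the sample question is only illustrative.

import torch

from torch_geometric.nn.text import SentenceTransformer, text2embedding

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentenceTransformer("sentence-transformers/all-roberta-large-v1")
model.to(device)
model.eval()

# text2embedding() now takes the model, a device, and a list of strings;
# the tokenizer is no longer threaded through by the caller.
questions = ["what is the name of justin bieber brother"]
q_embs = text2embedding(model, device, questions)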
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/__init__.py
@@ -28,7 +28,7 @@
from .pmlp import PMLP
from .neural_fingerprint import NeuralFingerprint
from .visnet import ViSNet
from .gnn_llm import GNN_LLM
from .g_retriever import GRetriever
# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
captum_output_to_dicts)
@@ -74,5 +74,5 @@
'PMLP',
'NeuralFingerprint',
'ViSNet',
'GNN_LLM',
'GRetriever',
]
