Skip to content

Commit

Permalink
drafting
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 committed Jun 11, 2024
2 parents 7901b4a + 196301c commit fac0577
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 77 deletions.
8 changes: 4 additions & 4 deletions examples/llm_plus_gnn/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Examples for LLM and GNN co-training

| Example | Description |
| ------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training LLAMA2 with GAT for answering questions based on knowledge graph information |
| [`interactive_g_retriever_demo.py`](./interactive_g_retriever_demo.py) | Example of how to interact with a G Retriever model.|
| Example | Description |
| ---------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`g_retriever.py`](./g_retriever.py) | Example for Retrieval-Augmented Generation (RAG) w/ GNN+LLM by co-training LLAMA2 with GAT for answering questions based on knowledge graph information |
| [`interactive_g_retriever_demo.py`](./interactive_g_retriever_demo.py) | Example of how to interact with a G Retriever model. |
9 changes: 6 additions & 3 deletions examples/llm_plus_gnn/g_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn.nlp import LLM
from torch_geometric.nn.nlp.llm import max_new_tokens
from torch_geometric.nn.nlp.llm import max_new_tokens


def detect_hallucinate(pred, label):
try:
Expand Down Expand Up @@ -107,9 +108,11 @@ def get_loss(model, batch, model_save_name) -> torch.Tensor:
batch.ptr, batch.label, batch.edge_attr, batch.desc)


def inference_step(model, batch, model_save_name, max_out_tokens=max_new_tokens):
def inference_step(model, batch, model_save_name,
max_out_tokens=max_new_tokens):
if model_save_name == "llm":
return model.inference(batch.question, batch.desc, max_out_tokens=max_out_tokens)
return model.inference(batch.question, batch.desc,
max_out_tokens=max_out_tokens)
else:
return model.inference(batch.question, batch.x, batch.edge_index,
batch.batch, batch.ptr, batch.edge_attr,
Expand Down
152 changes: 82 additions & 70 deletions examples/llm_plus_gnn/interactive_g_retriever_demo.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,97 @@
import gc
from typing import List, Tuple

import torch
from torch_geometric.data import Data
from g_retriever import inference_step, load_params_dict
from typing import List, Tuple
from torch_geometric.nn.nlp import SentenceTransformer, LLM
from torch_geometric.nn.models import GRetriever
import gc

from torch_geometric import seed_everything
from torch_geometric.data import Data
from torch_geometric.nn.models import GRetriever
from torch_geometric.nn.nlp import LLM, SentenceTransformer


def make_data_obj(text_encoder: SentenceTransformer, question: str,
                  nodes: List[Tuple[str, str]],
                  edges: List[Tuple[str, str, str]]) -> Data:
    """Build a single-item PyG ``Data`` batch from user-supplied graph text.

    Args:
        text_encoder: Sentence encoder used to embed node/edge attribute text.
        question: The (already formatted) question string for the LLM.
        nodes: ``(node_id, node_attr)`` pairs; n_ids are assumed to be
            ``0 .. len(nodes) - 1`` in order.
        edges: ``(src_id, edge_attr, dst_id)`` triples; src/dst ids must be
            parseable as ints.

    Returns:
        A ``Data`` object mimicking a batch of size 1 as produced by a
        ``DataLoader`` (``batch``, ``ptr``, ``x``, ``edge_index``,
        ``edge_attr``, ``question``, ``desc``).
    """
    data = Data()
    # list of size 1 to simulate batchsize=1
    # true inference setting would batch user queries
    data.question = [question]
    data.num_nodes = len(nodes)
    data.n_id = torch.arange(data.num_nodes).to(torch.int64)
    # Model expects batches sampled from Dataloader
    # hardcoding values for single item batch
    data.batch = torch.zeros(data.num_nodes).to(torch.int64)
    data.ptr = torch.tensor([0, data.num_nodes]).to(torch.int64)

    # Textual CSV-style description of the graph, fed to the LLM as context.
    graph_text_description = "node_id,node_attr" + "\n"
    # collect node attributes
    to_encode = []
    for node_id, node_attr in nodes:
        to_encode.append(node_attr)
        graph_text_description += str(node_id) + "," + str(node_attr) + "\n"

    # collect edge info
    data.num_edges = len(edges)
    graph_text_description += "src,edge_attr,dst" + "\n"
    src_ids, dst_ids, e_attrs = [], [], []
    for src_id, e_attr, dst_id in edges:
        src_ids.append(int(src_id))
        dst_ids.append(int(dst_id))
        e_attrs.append(e_attr)
        graph_text_description += str(src_id) + "," + str(e_attr) + "," + str(
            dst_id) + "\n"
    to_encode += e_attrs

    # Encode node and edge attribute text in one pass; the first
    # ``num_nodes`` rows are node embeddings, the rest edge embeddings.
    encoded_text = text_encoder.encode(to_encode)

    # store processed data
    data.x = encoded_text[:data.num_nodes]
    data.edge_attr = encoded_text[data.num_nodes:data.num_nodes +
                                  data.num_edges]
    data.edge_index = torch.tensor([src_ids, dst_ids]).to(torch.int64)
    data.desc = [graph_text_description[:-1]]  # remove last newline

    return data

def user_input_data():
    """Interactively collect a question, nodes, and edges from stdin and
    return the resulting ``Data`` object built by :func:`make_data_obj`.

    The user enters one node per line as ``n_id,textual_node_attribute`` and
    one edge per line as ``src_id,textual_edge_attribute,dst_id``, ending
    each list with the literal line ``[[stop]]``.
    """
    q_input = input("Please enter your Question:\n")
    question = f"Question: {q_input}\nAnswer: "
    print(
        "\nPlease enter the node attributes with format 'n_id,textual_node_attribute'."
    )  # noqa
    print("Please ensure to order n_ids from 0, 1, 2, ..., num_nodes-1.")
    print("Use [[stop]] to stop inputting.")
    nodes = []
    while True:
        most_recent_node = input()
        if most_recent_node == "[[stop]]":
            break
        # split into (n_id, node_attr); validation is left to make_data_obj
        nodes.append(tuple(most_recent_node.split(',')))
    print(
        "\nPlease enter the edge attributes with format 'src_id,textual_edge_attribute,dst_id'"
    )  # noqa
    print("Use [[stop]] to stop inputting.")
    edges = []
    while True:
        most_recent_edge = input()
        if most_recent_edge == "[[stop]]":
            break
        # split into (src_id, edge_attr, dst_id)
        edges.append(tuple(most_recent_edge.split(',')))
    print("Creating data object...")
    text_encoder = SentenceTransformer()
    data_obj = make_data_obj(text_encoder, question, nodes, edges)
    print("Done!")
    print("data =", data_obj)
    return data_obj


if __name__ == "__main__":
Expand Down

0 comments on commit fac0577

Please sign in to comment.