GNN-RAG with PyG #9852

Draft
wants to merge 11 commits into
base: master
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -79,6 +79,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
- Added GNN-RAG PyG support ([#9852](https://github.com/pyg-team/pytorch_geometric/pull/9852))

### Changed

@@ -270,7 +271,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925))
- Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917))
- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
- Added support for floating-point slicing in `Dataset`, _e.g._, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
- Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))
- Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))
- Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864))
@@ -288,7 +289,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656))
- Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649))
- Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647))
- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700))
- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700))
- Added a `LightGCN` example on the `AmazonBook` dataset ([#7603](https://github.com/pyg-team/pytorch_geometric/pull/7603))
- Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594))
- Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601))
@@ -348,7 +349,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))
- Fixed `HeteroConv` for layers that have a non-default argument order, _e.g._, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))
- Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163))
- Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145))
- Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143))
@@ -357,7 +358,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942))
- Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737))
- Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955))
- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)
- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956))
- Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953))
- Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941))
- Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923))
95 changes: 95 additions & 0 deletions examples/gnn-rag/gnn-rag.py
@@ -0,0 +1,95 @@
import argparse
import os
import time

import numpy as np
import openai
import torch

from torch_geometric.nn.models import Trainer_KBQA
from torch_geometric.utils import create_logger

# Read the API key from the environment instead of hard-coding it.
openai.api_key = os.environ.get("OPENAI_API_KEY")

parser = argparse.ArgumentParser()
# NOTE: `add_parse_args` is neither defined nor imported in this file; it
# has to be provided before the script can run.
add_parse_args(parser)

args = parser.parse_args()
args.use_cuda = torch.cuda.is_available()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.experiment_name is None:
    timestamp = str(int(time.time()))
    args.experiment_name = "{}-{}-{}".format(
        args.dataset,
        args.model_name,
        timestamp,
    )


def chat(query):
    # Send a single user message and return the stripped reply text.
    # This uses the legacy `openai.ChatCompletion` interface (openai < 1.0).
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": query
            },
        ],
        temperature=0.7,
        max_tokens=200,
    )
    return response["choices"][0]["message"]["content"].strip()


def run_query():
    question = input("Please ask the model a question.")
    query = ("Please create a knowledge query for the following question, "
             "which leads with one or more relation queries from the "
             "question to the answer. " + question)

    # The generated knowledge query is not used yet (see the TODO below).
    response = chat(query)

    answers = ""
    # TODO: answers = Iterate knowledge graph

    query = ("You want to know: " + question +
             ". Give a simple answer to the question based on the "
             "information provided: " + answers +
             ". Do not provide any explanation. Only your factual response "
             "to the prompt.")

    return chat(query)


def main():
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    logger = create_logger(args)
    trainer = Trainer_KBQA(args=vars(args), model_name='ReaRev', logger=logger)
    if not args.is_eval:
        trainer.train(0, args.num_epoch - 1)
    else:
        # Evaluation requires a checkpoint to load.
        assert args.load_experiment is not None
        ckpt_path = os.path.join(args.checkpoint_dir, args.load_experiment)
        print(f"Loading pre-trained model from {ckpt_path}")
        trainer.evaluate_single(ckpt_path)


if __name__ == '__main__':
    main()
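The `answers` placeholder in `run_query` is left empty in this revision. The following is a minimal sketch of how retrieved reasoning paths could be turned into that string, assuming each path is a list of `(head, relation, tail)` triplets as produced in `dataset_load.py`. The helper names `format_paths` and `build_answer_prompt` are hypothetical and not part of this PR; only the prompt wording is taken from `run_query`.

```python
from typing import List, Tuple

Triplet = Tuple[str, str, str]


def format_paths(paths: List[List[Triplet]]) -> str:
    """Render each path as 'head -> relation -> tail' hops, one path per line."""
    lines = []
    for path in paths:
        hops = [f"{h} -> {r} -> {t}" for h, r, t in path]
        lines.append("; ".join(hops))
    return "\n".join(lines)


def build_answer_prompt(question: str, paths: List[List[Triplet]]) -> str:
    """Fill the second prompt of `run_query` with the retrieved paths."""
    answers = format_paths(paths)
    return ("You want to know: " + question +
            ". Give a simple answer to the question based on the "
            "information provided: " + answers +
            ". Do not provide any explanation. Only your factual response "
            "to the prompt.")


# Toy data, made up for illustration only.
paths = [[("Paris", "capital_of", "France")]]
prompt = build_answer_prompt("What is Paris the capital of?", paths)
```

Keeping the prompt assembly in a pure function like this would make it possible to test the retrieval-to-prompt step without calling the OpenAI API.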
154 changes: 154 additions & 0 deletions torch_geometric/data/dataset_load.py
@@ -0,0 +1,154 @@
import json
from collections import deque
from typing import List, Optional, Tuple, Union

import networkx as nx
import torch

from torch_geometric.data import Data
from torch_geometric.utils import coalesce, to_undirected

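# Load entity names.
# NOTE: this file is read at import time, relative to the current working
# directory, so importing this module fails if the file is missing.
with open('entities_names.json') as f:
    entities_names = json.load(f)
names_entities = {v: k for k, v in entities_names.items()}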


def build_pyg_graph(graph_data: List[Tuple[str, str, str]],
                    entities: Optional[List[str]] = None,
                    encrypt: bool = False) -> Data:
    """Construct a PyG Data object from a list of (head, relation, tail) triplets.
    """
    edges = []
    for h, r, t in graph_data:
        if encrypt:
            # Map entity names back to their IDs when those IDs are listed
            # in `entities`.
            if entities is not None and h in names_entities and \
                    names_entities[h] in entities:
                h = names_entities[h]
            if entities is not None and t in names_entities and \
                    names_entities[t] in entities:
                t = names_entities[t]
        edges.append((h, r.strip(), t))

    # Build node list and mapping
    node_set = {node for edge in edges for node in (edge[0], edge[2])}
    node_list = sorted(node_set)
    node_to_idx = {node: i for i, node in enumerate(node_list)}

    edge_index_list = []
    rel_list = []

    for h, r, t in edges:
        edge_index_list.append([node_to_idx[h], node_to_idx[t]])
        rel_list.append(r)

    edge_index = torch.tensor(edge_index_list,
                              dtype=torch.long).t().contiguous()

    # Make the graph undirected, then remove duplicate edges and sort them.
    # `coalesce` takes `num_nodes` as a keyword; its fourth positional
    # argument is `reduce`, not a second size.
    edge_index = to_undirected(edge_index)
    edge_weights = torch.ones(edge_index.size(1), dtype=torch.float32)
    edge_index, edge_weights = coalesce(edge_index, edge_weights,
                                        num_nodes=len(node_list))

    # Remap relations to edges. If a node pair carries several relations,
    # the last triplet wins.
    rel_map = {}
    for h, r, t in edges:
        u, v = node_to_idx[h], node_to_idx[t]
        rel_map[(u, v)] = r
        rel_map[(v, u)] = r

    final_rels = [
        rel_map[(edge_index[0, i].item(), edge_index[1, i].item())]
        for i in range(edge_index.size(1))
    ]

    data = Data(num_nodes=len(node_list), edge_index=edge_index,
                edge_weights=edge_weights)
    data.node_list = node_list  # Reference to original names
    data.node_to_idx = node_to_idx
    data.relations = final_rels  # List of relations, parallel to edge_index

    return data


def pyg_data_to_networkx(data: Data) -> nx.Graph:
    """Converts a PyG Data object to a NetworkX Graph.
    """
    G = nx.Graph()
    G.add_nodes_from(data.node_list)

    for i in range(data.edge_index.size(1)):
        u = data.node_list[data.edge_index[0, i].item()]
        v = data.node_list[data.edge_index[1, i].item()]
        rel = data.relations[i]
        G.add_edge(u, v, relation=rel)

    return G


def bfs_with_rule(data: Data, start_node: Union[str, int],
                  target_rule: List[str],
                  max_p: int = 10) -> List[List[Tuple[str, str, str]]]:
    """Perform BFS to find paths matching a sequence of relations (target_rule).

    Note: `max_p` is currently unused.
    """
    G = pyg_data_to_networkx(data)

    if isinstance(start_node, str) and start_node not in data.node_to_idx:
        return []
    if isinstance(start_node, int):
        start_node = data.node_list[start_node]

    result_paths = []
    queue: deque[Tuple[str, List[Tuple[str, str,
                                       str]]]] = deque([(start_node, [])])

    while queue:
        current_node, current_path = queue.popleft()

        # A path is complete once it has one hop per relation in the rule.
        if len(current_path) == len(target_rule):
            result_paths.append(current_path)
            continue

        if current_node in G:
            expected_rel = target_rule[len(current_path)]
            for neighbor in G.neighbors(current_node):
                rel = G[current_node][neighbor]['relation']
                if rel == expected_rel:
                    new_path = current_path + [(current_node, rel, neighbor)]
                    queue.append((neighbor, new_path))

    return result_paths


def get_truth_paths(q_entities: List[str], a_entities: List[str],
                    data: Data) -> List[List[Tuple[str, str, str]]]:
    """Retrieves all shortest paths between question entities and answer entities.
    """
    G = pyg_data_to_networkx(data)
    paths = []

    for q in q_entities:
        if q not in G:
            continue
        for a in a_entities:
            if a not in G:
                continue
            try:
                for path in nx.all_shortest_paths(G, q, a):
                    path_with_rels = [
                        (path[i], G[path[i]][path[i + 1]]['relation'],
                         path[i + 1]) for i in range(len(path) - 1)
                    ]
                    paths.append(path_with_rels)
            except nx.NetworkXNoPath:
                pass  # If no path exists, continue

    return paths


def verbalize_paths(paths: List[List[Tuple[str, str, str]]]) -> str:
    """Converts paths into a readable format.
    """
    return "\n".join(" → ".join(f"{edge[0]} → {edge[1]} → {edge[2]}"
                                for edge in path) for path in paths)
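To make the relation-constrained search in `bfs_with_rule` easier to review, here is a standard-library-only sketch of the same idea on a toy triplet list. It is not the PR's implementation: `build_adjacency` and `follow_rule` are hypothetical names, the data is made up, and the adjacency list keeps every relation between a node pair, whereas the `nx.Graph` used in the PR stores a single `relation` attribute per pair.

```python
from collections import defaultdict, deque
from typing import Dict, List, Tuple

Triplet = Tuple[str, str, str]
Adjacency = Dict[str, List[Tuple[str, str]]]


def build_adjacency(triplets: List[Triplet]) -> Adjacency:
    """Undirected adjacency list: node -> [(relation, neighbor), ...]."""
    adj = defaultdict(list)
    for h, r, t in triplets:
        adj[h].append((r, t))
        adj[t].append((r, h))
    return adj


def follow_rule(adj: Adjacency, start: str,
                rule: List[str]) -> List[List[Triplet]]:
    """BFS that only expands edges matching the next relation in `rule`."""
    results = []
    queue = deque([(start, [])])
    while queue:
        node, path = queue.popleft()
        # One hop per relation in the rule means the path is complete.
        if len(path) == len(rule):
            results.append(path)
            continue
        expected = rule[len(path)]
        for rel, neighbor in adj.get(node, []):
            if rel == expected:
                queue.append((neighbor, path + [(node, rel, neighbor)]))
    return results


# Toy knowledge graph, made up for illustration only.
triplets = [
    ("Alice", "born_in", "Paris"),
    ("Paris", "capital_of", "France"),
    ("Alice", "works_at", "Acme"),
]
adj = build_adjacency(triplets)
paths = follow_rule(adj, "Alice", ["born_in", "capital_of"])
print(paths)
# [[('Alice', 'born_in', 'Paris'), ('Paris', 'capital_of', 'France')]]
```

Because edges are treated as undirected, a rule can also be matched by walking a triplet from tail to head, which mirrors the `to_undirected` step in `build_pyg_graph`.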