From 747f0eac84b3b9e7bba25df68b428f2466b71ebe Mon Sep 17 00:00:00 2001 From: Rishi Puri Date: Wed, 11 Dec 2024 17:48:49 -0800 Subject: [PATCH] Improve system prompt for TXT2KG (#9848) improve https://github.com/pyg-team/pytorch_geometric/pull/9846 --------- Co-authored-by: riship Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../g_retriever_utils/rag_backend_utils.py | 224 ------------------ .../g_retriever_utils/rag_feature_store.py | 189 --------------- .../llm/g_retriever_utils/rag_graph_store.py | 107 --------- examples/llm/hotpot_qa.py | 15 +- torch_geometric/nn/nlp/__init__.py | 1 + torch_geometric/nn/nlp/txt2kg.py | 2 +- torch_geometric/utils/rag/backend_utils.py | 8 + 7 files changed, 18 insertions(+), 528 deletions(-) delete mode 100644 examples/llm/g_retriever_utils/rag_backend_utils.py delete mode 100644 examples/llm/g_retriever_utils/rag_feature_store.py delete mode 100644 examples/llm/g_retriever_utils/rag_graph_store.py diff --git a/examples/llm/g_retriever_utils/rag_backend_utils.py b/examples/llm/g_retriever_utils/rag_backend_utils.py deleted file mode 100644 index 0f1c0e1b87ec..000000000000 --- a/examples/llm/g_retriever_utils/rag_backend_utils.py +++ /dev/null @@ -1,224 +0,0 @@ -from dataclasses import dataclass -from enum import Enum, auto -from typing import ( - Any, - Callable, - Dict, - Iterable, - Optional, - Protocol, - Tuple, - Type, - runtime_checkable, -) - -import torch -from torch import Tensor -from torch.nn import Module - -from torch_geometric.data import ( - FeatureStore, - GraphStore, - LargeGraphIndexer, - TripletLike, -) -from torch_geometric.data.large_graph_indexer import EDGE_RELATION -from torch_geometric.distributed import ( - LocalFeatureStore, - LocalGraphStore, - Partitioner, -) -from torch_geometric.typing import EdgeType, NodeType - -RemoteGraphBackend = Tuple[FeatureStore, GraphStore] - -# TODO: Make everything compatible with Hetero graphs aswell - - -# Adapted from LocalGraphStore -@runtime_checkable -class ConvertableGraphStore(Protocol): - @classmethod - def from_data( - cls, - edge_id: Tensor, - edge_index: Tensor, - num_nodes: int, - is_sorted: bool = False, - ) -> GraphStore: - ... - - @classmethod - def from_hetero_data( - cls, - edge_id_dict: Dict[EdgeType, Tensor], - edge_index_dict: Dict[EdgeType, Tensor], - num_nodes_dict: Dict[NodeType, int], - is_sorted: bool = False, - ) -> GraphStore: - ... - - @classmethod - def from_partition(cls, root: str, pid: int) -> GraphStore: - ... - - -# Adapted from LocalFeatureStore -@runtime_checkable -class ConvertableFeatureStore(Protocol): - @classmethod - def from_data( - cls, - node_id: Tensor, - x: Optional[Tensor] = None, - y: Optional[Tensor] = None, - edge_id: Optional[Tensor] = None, - edge_attr: Optional[Tensor] = None, - ) -> FeatureStore: - ... - - @classmethod - def from_hetero_data( - cls, - node_id_dict: Dict[NodeType, Tensor], - x_dict: Optional[Dict[NodeType, Tensor]] = None, - y_dict: Optional[Dict[NodeType, Tensor]] = None, - edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None, - edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None, - ) -> FeatureStore: - ... - - @classmethod - def from_partition(cls, root: str, pid: int) -> FeatureStore: - ... 
- - -class RemoteDataType(Enum): - DATA = auto() - PARTITION = auto() - - -@dataclass -class RemoteGraphBackendLoader: - """Utility class to load triplets into a RAG Backend.""" - path: str - datatype: RemoteDataType - graph_store_type: Type[ConvertableGraphStore] - feature_store_type: Type[ConvertableFeatureStore] - - def load(self, pid: Optional[int] = None) -> RemoteGraphBackend: - if self.datatype == RemoteDataType.DATA: - data_obj = torch.load(self.path) - graph_store = self.graph_store_type.from_data( - edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index, - num_nodes=data_obj.num_nodes) - feature_store = self.feature_store_type.from_data( - node_id=data_obj['node_id'], x=data_obj.x, - edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr) - elif self.datatype == RemoteDataType.PARTITION: - if pid is None: - assert pid is not None, \ - "Partition ID must be defined for loading from a " \ - + "partitioned store." - graph_store = self.graph_store_type.from_partition(self.path, pid) - feature_store = self.feature_store_type.from_partition( - self.path, pid) - else: - raise NotImplementedError - return (feature_store, graph_store) - - -# TODO: make profilable -def create_remote_backend_from_triplets( - triplets: Iterable[TripletLike], node_embedding_model: Module, - edge_embedding_model: Module | None = None, - graph_db: Type[ConvertableGraphStore] = LocalGraphStore, - feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore, - node_method_to_call: str = "forward", - edge_method_to_call: str | None = None, - pre_transform: Callable[[TripletLike], TripletLike] | None = None, - path: str = '', n_parts: int = 1, - node_method_kwargs: Optional[Dict[str, Any]] = None, - edge_method_kwargs: Optional[Dict[str, Any]] = None -) -> RemoteGraphBackendLoader: - """Utility function that can be used to create a RAG Backend from triplets. - - Args: - triplets (Iterable[TripletLike]): Triplets to load into the RAG - Backend. - node_embedding_model (Module): Model to embed nodes into a feature - space. - edge_embedding_model (Module | None, optional): Model to embed edges - into a feature space. Defaults to the node model. - graph_db (Type[ConvertableGraphStore], optional): GraphStore class to - use. Defaults to LocalGraphStore. - feature_db (Type[ConvertableFeatureStore], optional): FeatureStore - class to use. Defaults to LocalFeatureStore. - node_method_to_call (str, optional): method to call for embeddings on - the node model. Defaults to "forward". - edge_method_to_call (str | None, optional): method to call for - embeddings on the edge model. Defaults to the node method. - pre_transform (Callable[[TripletLike], TripletLike] | None, optional): - optional preprocessing function for triplets. Defaults to None. - path (str, optional): path to save resulting stores. Defaults to ''. - n_parts (int, optional): Number of partitons to store in. - Defaults to 1. - node_method_kwargs (Optional[Dict[str, Any]], optional): args to pass - into node encoding method. Defaults to None. - edge_method_kwargs (Optional[Dict[str, Any]], optional): args to pass - into edge encoding method. Defaults to None. - - Returns: - RemoteGraphBackendLoader: Loader to load RAG backend from disk or - memory. 
- """ - # Will return attribute errors for missing attributes - if not issubclass(graph_db, ConvertableGraphStore): - getattr(graph_db, "from_data") - getattr(graph_db, "from_hetero_data") - getattr(graph_db, "from_partition") - elif not issubclass(feature_db, ConvertableFeatureStore): - getattr(feature_db, "from_data") - getattr(feature_db, "from_hetero_data") - getattr(feature_db, "from_partition") - - # Resolve callable methods - node_method_kwargs = node_method_kwargs \ - if node_method_kwargs is not None else dict() - - edge_embedding_model = edge_embedding_model \ - if edge_embedding_model is not None else node_embedding_model - edge_method_to_call = edge_method_to_call \ - if edge_method_to_call is not None else node_method_to_call - edge_method_kwargs = edge_method_kwargs \ - if edge_method_kwargs is not None else node_method_kwargs - - # These will return AttributeErrors if they don't exist - node_model = getattr(node_embedding_model, node_method_to_call) - edge_model = getattr(edge_embedding_model, edge_method_to_call) - - indexer = LargeGraphIndexer.from_triplets(triplets, - pre_transform=pre_transform) - - node_feats = node_model(indexer.get_node_features(), **node_method_kwargs) - indexer.add_node_feature('x', node_feats) - - edge_feats = edge_model( - indexer.get_unique_edge_features(feature_name=EDGE_RELATION), - **edge_method_kwargs) - indexer.add_edge_feature(new_feature_name="edge_attr", - new_feature_vals=edge_feats, - map_from_feature=EDGE_RELATION) - - data = indexer.to_data(node_feature_name='x', - edge_feature_name='edge_attr') - - if n_parts == 1: - torch.save(data, path) - return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db, - feature_db) - else: - partitioner = Partitioner(data=data, num_parts=n_parts, root=path) - partitioner.generate_partition() - return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION, - graph_db, feature_db) diff --git a/examples/llm/g_retriever_utils/rag_feature_store.py b/examples/llm/g_retriever_utils/rag_feature_store.py deleted file mode 100644 index e01e9e59bb88..000000000000 --- a/examples/llm/g_retriever_utils/rag_feature_store.py +++ /dev/null @@ -1,189 +0,0 @@ -import gc -from collections.abc import Iterable, Iterator -from typing import Any, Dict, Optional, Type, Union - -import torch -from torch import Tensor -from torch.nn import Module -from torchmetrics.functional import pairwise_cosine_similarity - -from torch_geometric.data import Data, HeteroData -from torch_geometric.distributed import LocalFeatureStore -from torch_geometric.nn.nlp import SentenceTransformer -from torch_geometric.nn.pool import ApproxMIPSKNNIndex -from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput -from torch_geometric.typing import InputEdges, InputNodes - - -# NOTE: Only compatible with Homogeneous graphs for now -class KNNRAGFeatureStore(LocalFeatureStore): - def __init__(self, enc_model: Type[Module], - model_kwargs: Optional[Dict[str, - Any]] = None, *args, **kwargs): - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - self.enc_model = enc_model(*args, **kwargs).to(self.device) - self.enc_model.eval() - self.model_kwargs = \ - model_kwargs if model_kwargs is not None else dict() - super().__init__() - - @property - def x(self) -> Tensor: - return self.get_tensor(group_name=None, attr_name='x') - - @property - def edge_attr(self) -> Tensor: - return self.get_tensor(group_name=(None, None), attr_name='edge_attr') - - def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> 
InputNodes: - result = next(self._retrieve_seed_nodes_batch([query], k_nodes)) - gc.collect() - torch.cuda.empty_cache() - return result - - def _retrieve_seed_nodes_batch(self, query: Iterable[Any], - k_nodes: int) -> Iterator[InputNodes]: - if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): - raise NotImplementedError - - query_enc = self.enc_model.encode(query, - **self.model_kwargs).to(self.device) - prizes = pairwise_cosine_similarity(query_enc, self.x.to(self.device)) - topk = min(k_nodes, len(self.x)) - for q in prizes: - _, indices = torch.topk(q, topk, largest=True) - yield indices - - def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges: - result = next(self._retrieve_seed_edges_batch([query], k_edges)) - gc.collect() - torch.cuda.empty_cache() - return result - - def _retrieve_seed_edges_batch(self, query: Iterable[Any], - k_edges: int) -> Iterator[InputEdges]: - if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): - raise NotImplementedError - - query_enc = self.enc_model.encode(query, - **self.model_kwargs).to(self.device) - - prizes = pairwise_cosine_similarity(query_enc, - self.edge_attr.to(self.device)) - topk = min(k_edges, len(self.edge_attr)) - for q in prizes: - _, indices = torch.topk(q, topk, largest=True) - yield indices - - def load_subgraph( - self, sample: Union[SamplerOutput, HeteroSamplerOutput] - ) -> Union[Data, HeteroData]: - - if isinstance(sample, HeteroSamplerOutput): - raise NotImplementedError - - # NOTE: torch_geometric.loader.utils.filter_custom_store can be used - # here if it supported edge features - node_id = sample.node - edge_id = sample.edge - edge_index = torch.stack((sample.row, sample.col), dim=0) - x = self.x[node_id] - edge_attr = self.edge_attr[edge_id] - - return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, - node_idx=node_id, edge_idx=edge_id) - - -# TODO: Refactor because composition >> inheritance - - -def _add_features_to_knn_index(knn_index: ApproxMIPSKNNIndex, emb: Tensor, - device: torch.device, batch_size: int = 2**20): - """Add new features to the existing KNN index in batches. - - Args: - knn_index (ApproxMIPSKNNIndex): Index to add features to. - emb (Tensor): Embeddings to add. - device (torch.device): Device to store in - batch_size (int, optional): Batch size to iterate by. - Defaults to 2**20, which equates to 4GB if working with - 1024 dim floats. - """ - for i in range(0, emb.size(0), batch_size): - if emb.size(0) - i >= batch_size: - emb_batch = emb[i:i + batch_size].to(device) - else: - emb_batch = emb[i:].to(device) - knn_index.add(emb_batch) - - -class ApproxKNNRAGFeatureStore(KNNRAGFeatureStore): - def __init__(self, enc_model: Type[Module], - model_kwargs: Optional[Dict[str, - Any]] = None, *args, **kwargs): - # TODO: Add kwargs for approx KNN to parameters here. 
- super().__init__(enc_model, model_kwargs, *args, **kwargs) - self.node_knn_index = None - self.edge_knn_index = None - - def _retrieve_seed_nodes_batch(self, query: Iterable[Any], - k_nodes: int) -> Iterator[InputNodes]: - if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): - raise NotImplementedError - - enc_model = self.enc_model.to(self.device) - query_enc = enc_model.encode(query, - **self.model_kwargs).to(self.device) - del enc_model - gc.collect() - torch.cuda.empty_cache() - - if self.node_knn_index is None: - self.node_knn_index = ApproxMIPSKNNIndex(num_cells=100, - num_cells_to_visit=100, - bits_per_vector=4) - # Need to add in batches to avoid OOM - _add_features_to_knn_index(self.node_knn_index, self.x, - self.device) - - output = self.node_knn_index.search(query_enc, k=k_nodes) - yield from output.index - - def _retrieve_seed_edges_batch(self, query: Iterable[Any], - k_edges: int) -> Iterator[InputEdges]: - if isinstance(self.meta, dict) and self.meta.get("is_hetero", False): - raise NotImplementedError - - enc_model = self.enc_model.to(self.device) - query_enc = enc_model.encode(query, - **self.model_kwargs).to(self.device) - del enc_model - gc.collect() - torch.cuda.empty_cache() - - if self.edge_knn_index is None: - self.edge_knn_index = ApproxMIPSKNNIndex(num_cells=100, - num_cells_to_visit=100, - bits_per_vector=4) - # Need to add in batches to avoid OOM - _add_features_to_knn_index(self.edge_knn_index, self.edge_attr, - self.device) - - output = self.edge_knn_index.search(query_enc, k=k_edges) - yield from output.index - - -# TODO: These two classes should be refactored -class SentenceTransformerFeatureStore(KNNRAGFeatureStore): - def __init__(self, *args, **kwargs): - kwargs['model_name'] = kwargs.get( - 'model_name', 'sentence-transformers/all-roberta-large-v1') - super().__init__(SentenceTransformer, *args, **kwargs) - - -class SentenceTransformerApproxFeatureStore(ApproxKNNRAGFeatureStore): - def __init__(self, *args, **kwargs): - kwargs['model_name'] = kwargs.get( - 'model_name', 'sentence-transformers/all-roberta-large-v1') - super().__init__(SentenceTransformer, *args, **kwargs) diff --git a/examples/llm/g_retriever_utils/rag_graph_store.py b/examples/llm/g_retriever_utils/rag_graph_store.py deleted file mode 100644 index 48473f287233..000000000000 --- a/examples/llm/g_retriever_utils/rag_graph_store.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Optional, Union - -import torch -from torch import Tensor - -from torch_geometric.data import FeatureStore -from torch_geometric.distributed import LocalGraphStore -from torch_geometric.sampler import ( - HeteroSamplerOutput, - NeighborSampler, - NodeSamplerInput, - SamplerOutput, -) -from torch_geometric.sampler.neighbor_sampler import NumNeighborsType -from torch_geometric.typing import EdgeTensorType, InputEdges, InputNodes - - -class NeighborSamplingRAGGraphStore(LocalGraphStore): - def __init__(self, feature_store: Optional[FeatureStore] = None, - num_neighbors: NumNeighborsType = [1], **kwargs): - self.feature_store = feature_store - self._num_neighbors = num_neighbors - self.sample_kwargs = kwargs - self._sampler_is_initialized = False - super().__init__() - - def _init_sampler(self): - if self.feature_store is None: - raise AttributeError("Feature store not registered yet.") - self.sampler = NeighborSampler(data=(self.feature_store, self), - num_neighbors=self._num_neighbors, - **self.sample_kwargs) - self._sampler_is_initialized = True - - def register_feature_store(self, feature_store: 
FeatureStore): - self.feature_store = feature_store - self._sampler_is_initialized = False - - def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool: - ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs) - self._sampler_is_initialized = False - return ret - - @property - def edge_index(self): - return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs) - - def put_edge_index(self, edge_index: EdgeTensorType, *args, - **kwargs) -> bool: - ret = super().put_edge_index(edge_index, *args, **kwargs) - # HACK - self.edge_idx_args = args - self.edge_idx_kwargs = kwargs - self._sampler_is_initialized = False - return ret - - @property - def num_neighbors(self): - return self._num_neighbors - - @num_neighbors.setter - def num_neighbors(self, num_neighbors: NumNeighborsType): - self._num_neighbors = num_neighbors - if hasattr(self, 'sampler'): - self.sampler.num_neighbors = num_neighbors - - def sample_subgraph( - self, seed_nodes: InputNodes, seed_edges: InputEdges, - num_neighbors: Optional[NumNeighborsType] = None - ) -> Union[SamplerOutput, HeteroSamplerOutput]: - """Sample the graph starting from the given nodes and edges using the - in-built NeighborSampler. - - Args: - seed_nodes (InputNodes): Seed nodes to start sampling from. - seed_edges (InputEdges): Seed edges to start sampling from. - num_neighbors (Optional[NumNeighborsType], optional): Parameters - to determine how many hops and number of neighbors per hop. - Defaults to None. - - Returns: - Union[SamplerOutput, HeteroSamplerOutput]: NeighborSamplerOutput - for the input. - """ - if not self._sampler_is_initialized: - self._init_sampler() - if num_neighbors is not None: - self.num_neighbors = num_neighbors - - # FIXME: Right now, only input nodes/edges as tensors are be supported - if not isinstance(seed_nodes, Tensor): - raise NotImplementedError - if not isinstance(seed_edges, Tensor): - raise NotImplementedError - device = seed_nodes.device - - # TODO: Call sample_from_edges for seed_edges - # Turning them into nodes for now. 
- seed_edges = self.edge_index.to(device).T[seed_edges.to( - device)].reshape(-1) - seed_nodes = torch.cat((seed_nodes, seed_edges), dim=0) - - seed_nodes = seed_nodes.unique().contiguous() - node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes) - out = self.sampler.sample_from_nodes(node_sample_input) - - return out diff --git a/examples/llm/hotpot_qa.py b/examples/llm/hotpot_qa.py index c52557aee02c..5bab32f9b50e 100644 --- a/examples/llm/hotpot_qa.py +++ b/examples/llm/hotpot_qa.py @@ -7,12 +7,12 @@ from tqdm import tqdm from torch_geometric import seed_everything -from torch_geometric.datasets.web_qsp_dataset import preprocess_triplet from torch_geometric.loader import RAGQueryLoader from torch_geometric.nn.nlp import TXT2KG, SentenceTransformer from torch_geometric.utils.rag.backend_utils import ( create_remote_backend_from_triplets, make_pcst_filter, + preprocess_triplet, ) from torch_geometric.utils.rag.feature_store import ( SentenceTransformerFeatureStore, @@ -28,6 +28,7 @@ parser.add_argument('--local_lm', action="store_true") parser.add_argument('--percent_data', type=float, default=1.0) parser.add_argument('--chunk_size', type=int, default=512) + parser.add_argument('--verbose', action="store_true") args = parser.parse_args() assert args.percent_data <= 100 and args.percent_data > 0 if args.local_lm: @@ -110,12 +111,12 @@ q = QA_pair[0] retrieved_subgraph = query_loader.query(q) retrieved_triples = retrieved_subgraph.triples - ########## - # for debug - # print("Q=", q) - # print("A=", QA_pair[1]) - # print("retrieved_triples =", retrieved_triples) - #### + + if args.verbose: + print("Q=", q) + print("A=", QA_pair[1]) + print("retrieved_triples =", retrieved_triples) + num_relevant_out_of_retrieved = float( sum([ int(bool(retrieved_triple in golden_triples)) diff --git a/torch_geometric/nn/nlp/__init__.py b/torch_geometric/nn/nlp/__init__.py index 62163ddbfd1c..cbda10d6b82b 100644 --- a/torch_geometric/nn/nlp/__init__.py +++ b/torch_geometric/nn/nlp/__init__.py @@ -1,6 +1,7 @@ from .sentence_transformer import SentenceTransformer from .vision_transformer import VisionTransformer from .llm import LLM +from .txt2kg import TXT2KG __all__ = classes = [ 'SentenceTransformer', diff --git a/torch_geometric/nn/nlp/txt2kg.py b/torch_geometric/nn/nlp/txt2kg.py index 48ee8babc733..aef406436765 100644 --- a/torch_geometric/nn/nlp/txt2kg.py +++ b/torch_geometric/nn/nlp/txt2kg.py @@ -9,7 +9,7 @@ CLIENT = None GLOBAL_NIM_KEY = "" -SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Seperate each with a new line. Do not output anything else.”" +SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Seperate each with a new line. Do not output anything else. 
Try to focus on key triples that form a connected graph.”"
 
 
 class TXT2KG():
diff --git a/torch_geometric/utils/rag/backend_utils.py b/torch_geometric/utils/rag/backend_utils.py
index 58b13bb7c978..bfbec9281659 100644
--- a/torch_geometric/utils/rag/backend_utils.py
+++ b/torch_geometric/utils/rag/backend_utils.py
@@ -43,6 +43,14 @@
 # TODO: Make everything compatible with Hetero graphs aswell
 
 
+# TODO: once Zack's WebQSP PR is merged
+# (https://github.com/pyg-team/pytorch_geometric/pull/9806),
+# update WebQSP in this branch to use preprocess_triplet from here
+def preprocess_triplet(triplet: TripletLike) -> TripletLike:
+    h, r, t = triplet
+    return str(h).lower(), str(r), str(t).lower()
+
+
 # Adapted from LocalGraphStore
 @runtime_checkable
 class ConvertableGraphStore(Protocol):
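
For orientation, below is a minimal sketch of how the relocated preprocess_triplet helper interacts with create_remote_backend_from_triplets, both importable from torch_geometric.utils.rag.backend_utils as in the updated examples/llm/hotpot_qa.py imports above. The sample triples, model name, and output path are placeholders, and the sketch omits the custom feature/graph store classes and method kwargs that the full example script passes:

# Minimal, illustrative sketch (not part of the patch): the sample triples,
# model name, and output path are placeholders; the call signatures follow
# the code shown above.
import torch

from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.utils.rag.backend_utils import (
    create_remote_backend_from_triplets,
    preprocess_triplet,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Triples in (head, relation, tail) form, e.g. as produced by TXT2KG:
triples = [
    ('Scott Derrickson', 'nationality', 'American'),
    ('Ed Wood Jr.', 'occupation', 'Film director'),
]

# preprocess_triplet lower-cases head and tail (the relation is kept as-is),
# so entity matching during retrieval becomes case-insensitive:
print(preprocess_triplet(triples[0]))
# ('scott derrickson', 'nationality', 'american')

# Embed nodes and relations with a SentenceTransformer and persist the
# resulting stores to disk (default LocalFeatureStore/LocalGraphStore):
model = SentenceTransformer(
    model_name='sentence-transformers/all-roberta-large-v1').to(device)
loader = create_remote_backend_from_triplets(
    triplets=triples,
    node_embedding_model=model,
    node_method_to_call='encode',
    pre_transform=preprocess_triplet,
    path='rag_backend.pt',
)
feature_store, graph_store = loader.load()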