From 1519e9fa9cd0d23ee7b64d80563a45b10599a0c6 Mon Sep 17 00:00:00 2001 From: zaristei Date: Tue, 10 Dec 2024 14:01:29 -0500 Subject: [PATCH 1/2] Fix Docstring Typos for LargeGraphIndexer (#9837) Fix some issues with the docstrings of LargeGraphIndexer. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Zachary Aristei --- torch_geometric/data/large_graph_indexer.py | 38 ++++++++++----------- torch_geometric/loader/__init__.py | 4 ++- torch_geometric/loader/rag_loader.py | 5 +-- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py index 0644e2543303..d2cb30378908 100644 --- a/torch_geometric/data/large_graph_indexer.py +++ b/torch_geometric/data/large_graph_indexer.py @@ -7,7 +7,6 @@ Any, Callable, Dict, - Hashable, Iterable, Iterator, List, @@ -25,12 +24,13 @@ from torch_geometric.data import Data from torch_geometric.typing import WITH_PT24 -TripletLike = Tuple[Hashable, Hashable, Hashable] +# Could be any hashable type +TripletLike = Tuple[str, str, str] KnowledgeGraphLike = Iterable[TripletLike] -def ordered_set(values: Iterable[Hashable]) -> List[Hashable]: +def ordered_set(values: Iterable[str]) -> List[str]: return list(dict.fromkeys(values)) @@ -70,13 +70,13 @@ def __eq__(self, value: "MappedFeature") -> bool: class LargeGraphIndexer: - """For a dataset that consists of mulitiple subgraphs that are assumed to + """For a dataset that consists of multiple subgraphs that are assumed to be part of a much larger graph, collate the values into a large graph store to save resources. """ def __init__( self, - nodes: Iterable[Hashable], + nodes: Iterable[str], edges: KnowledgeGraphLike, node_attr: Optional[Dict[str, List[Any]]] = None, edge_attr: Optional[Dict[str, List[Any]]] = None, @@ -85,7 +85,7 @@ def __init__( by id. Not meant to be used directly. Args: - nodes (Iterable[Hashable]): Node ids in the graph. + nodes (Iterable[str]): Node ids in the graph. edges (KnowledgeGraphLike): Edge ids in the graph. node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node attribute name and list of their values in order of unique node @@ -94,7 +94,7 @@ def __init__( attribute name and list of their values in order of unique edge ids. Defaults to None. """ - self._nodes: Dict[Hashable, int] = dict() + self._nodes: Dict[str, int] = dict() self._edges: Dict[TripletLike, int] = dict() self._mapped_node_features: Set[str] = set() @@ -201,7 +201,7 @@ def collate(cls, index. Args: - graphs (Iterable["LargeGraphIndexer"]): Indices to be + graphs (Iterable[LargeGraphIndexer]): Indices to be combined. Returns: @@ -212,8 +212,8 @@ def collate(cls, trips = chain.from_iterable([graph.to_triplets() for graph in graphs]) return cls.from_triplets(trips) - def get_unique_node_features( - self, feature_name: str = NODE_PID) -> List[Hashable]: + def get_unique_node_features(self, + feature_name: str = NODE_PID) -> List[str]: r"""Get all the unique values for a specific node attribute. Args: @@ -221,7 +221,7 @@ def get_unique_node_features( Defaults to NODE_PID. Returns: - List[Hashable]: List of unique values for the specified feature. + List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_node_features: @@ -272,7 +272,7 @@ def add_node_feature( def get_node_features( self, feature_name: str = NODE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get node feature values for a given set of unique node ids. Returned values are not necessarily unique. @@ -280,7 +280,7 @@ def get_node_features( Args: feature_name (str, optional): Name of feature to fetch. Defaults to NODE_PID. - pids (Optional[Iterable[Hashable]], optional): Node ids to fetch + pids (Optional[Iterable[str]], optional): Node ids to fetch for. Defaults to None, which fetches all nodes. Returns: @@ -302,7 +302,7 @@ def get_node_features( def get_node_features_iter( self, feature_name: str = NODE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, index_only: bool = False, ) -> Iterator[Any]: """Iterator version of get_node_features. If index_only is True, @@ -337,8 +337,8 @@ def get_node_features_iter( else: yield self.node_attr[feature_name][idx] - def get_unique_edge_features( - self, feature_name: str = EDGE_PID) -> List[Hashable]: + def get_unique_edge_features(self, + feature_name: str = EDGE_PID) -> List[str]: r"""Get all the unique values for a specific edge attribute. Args: @@ -346,7 +346,7 @@ def get_unique_edge_features( Defaults to EDGE_PID. Returns: - List[Hashable]: List of unique values for the specified feature. + List[str]: List of unique values for the specified feature. """ try: if feature_name in self._mapped_edge_features: @@ -396,7 +396,7 @@ def add_edge_feature( def get_edge_features( self, feature_name: str = EDGE_PID, - pids: Optional[Iterable[Hashable]] = None, + pids: Optional[Iterable[str]] = None, ) -> List[Any]: r"""Get edge feature values for a given set of unique edge ids. Returned values are not necessarily unique. @@ -404,7 +404,7 @@ def get_edge_features( Args: feature_name (str, optional): Name of feature to fetch. Defaults to EDGE_PID. - pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch + pids (Optional[Iterable[str]], optional): Edge ids to fetch for. Defaults to None, which fetches all edges. Returns: diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py index 7e83c35befb6..75dbe9178681 100644 --- a/torch_geometric/loader/__init__.py +++ b/torch_geometric/loader/__init__.py @@ -22,7 +22,7 @@ from .prefetch import PrefetchLoader from .cache import CachedLoader from .mixin import AffinityMixin -from .rag_loader import RAGQueryLoader +from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore __all__ = classes = [ 'DataLoader', @@ -52,6 +52,8 @@ 'CachedLoader', 'AffinityMixin', 'RAGQueryLoader', + 'RAGFeatureStore', + 'RAGGraphStore' ] RandomNodeSampler = deprecated( diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py index 33d6cf0e868e..4ab457ef7072 100644 --- a/torch_geometric/loader/rag_loader.py +++ b/torch_geometric/loader/rag_loader.py @@ -7,7 +7,7 @@ class RAGFeatureStore(Protocol): - """Feature store for remote GNN RAG backend.""" + """Feature store template for remote GNN RAG backend.""" @abstractmethod def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes: """Makes a comparison between the query and all the nodes to get all @@ -33,7 +33,7 @@ def load_subgraph( class RAGGraphStore(Protocol): - """Graph store for remote GNN RAG backend.""" + """Graph store template for remote GNN RAG backend.""" @abstractmethod def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges, **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]: @@ -52,6 +52,7 @@ def register_feature_store(self, feature_store: FeatureStore): class RAGQueryLoader: + """Loader meant for making RAG queries from a remote backend.""" def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore], local_filter: Optional[Callable[[Data, Any], Data]] = None, seed_nodes_kwargs: Optional[Dict[str, Any]] = None, From 2b1b32719d5d87c9f34ed467bf44778c567c732a Mon Sep 17 00:00:00 2001 From: Manan Shah Date: Wed, 11 Dec 2024 00:25:58 -0800 Subject: [PATCH 2/2] feat: store reverse mapping within `EdgeTypeStr` (#9844) To avoid issues when node types contain the `EDGE_TYPE_STR_SPLIT` delimiter. --------- Co-authored-by: rusty1s --- torch_geometric/typing.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index e63b1849b65c..468f37abfaed 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -307,6 +307,8 @@ class EdgeTypeStr(str): r"""A helper class to construct serializable edge types by merging an edge type tuple into a single string. """ + edge_type: tuple[str, str, str] + def __new__(cls, *args: Any) -> 'EdgeTypeStr': if isinstance(args[0], (list, tuple)): # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: @@ -314,27 +316,34 @@ def __new__(cls, *args: Any) -> 'EdgeTypeStr': if len(args) == 1 and isinstance(args[0], str): arg = args[0] # An edge type string was passed. + edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT)) + if len(edge_type) != 3: + raise ValueError(f"Cannot convert the edge type '{arg}' to a " + f"tuple since it holds invalid characters") elif len(args) == 2 and all(isinstance(arg, str) for arg in args): # A `(src, dst)` edge type was passed - add `DEFAULT_REL`: - arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1])) + edge_type = (args[0], DEFAULT_REL, args[1]) + arg = EDGE_TYPE_STR_SPLIT.join(edge_type) elif len(args) == 3 and all(isinstance(arg, str) for arg in args): # A `(src, rel, dst)` edge type was passed: + edge_type = tuple(args) arg = EDGE_TYPE_STR_SPLIT.join(args) else: raise ValueError(f"Encountered invalid edge type '{args}'") - return str.__new__(cls, arg) + out = str.__new__(cls, arg) + out.edge_type = edge_type # type: ignore + return out def to_tuple(self) -> EdgeType: r"""Returns the original edge type.""" - out = tuple(self.split(EDGE_TYPE_STR_SPLIT)) - if len(out) != 3: + if len(self.edge_type) != 3: raise ValueError(f"Cannot convert the edge type '{self}' to a " f"tuple since it holds invalid characters") - return out + return self.edge_type # There exist some short-cuts to query edge-types (given that the full triplet