Skip to content

Commit

Permalink
Merge branch 'master' of personal:michailmelonas/pytorch_geometric into add-token-gt
Browse files (browse the repository at this point in the history)
  • Loading branch information
michailramp committed Dec 12, 2024
2 parents 5f6c41d + 2b1b327 commit 58fa8ed
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 27 deletions.
38 changes: 19 additions & 19 deletions torch_geometric/data/large_graph_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Any,
Callable,
Dict,
Hashable,
Iterable,
Iterator,
List,
Expand All @@ -25,12 +24,13 @@
from torch_geometric.data import Data
from torch_geometric.typing import WITH_PT24

TripletLike = Tuple[Hashable, Hashable, Hashable]
# Could be any hashable type
TripletLike = Tuple[str, str, str]

KnowledgeGraphLike = Iterable[TripletLike]


def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
def ordered_set(values: Iterable[str]) -> List[str]:
return list(dict.fromkeys(values))


Expand Down Expand Up @@ -70,13 +70,13 @@ def __eq__(self, value: "MappedFeature") -> bool:


class LargeGraphIndexer:
"""For a dataset that consists of mulitiple subgraphs that are assumed to
"""For a dataset that consists of multiple subgraphs that are assumed to
be part of a much larger graph, collate the values into a large graph store
to save resources.
"""
def __init__(
self,
nodes: Iterable[Hashable],
nodes: Iterable[str],
edges: KnowledgeGraphLike,
node_attr: Optional[Dict[str, List[Any]]] = None,
edge_attr: Optional[Dict[str, List[Any]]] = None,
Expand All @@ -85,7 +85,7 @@ def __init__(
by id. Not meant to be used directly.
Args:
nodes (Iterable[Hashable]): Node ids in the graph.
nodes (Iterable[str]): Node ids in the graph.
edges (KnowledgeGraphLike): Edge ids in the graph.
node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
attribute name and list of their values in order of unique node
Expand All @@ -94,7 +94,7 @@ def __init__(
attribute name and list of their values in order of unique edge
ids. Defaults to None.
"""
self._nodes: Dict[Hashable, int] = dict()
self._nodes: Dict[str, int] = dict()
self._edges: Dict[TripletLike, int] = dict()

self._mapped_node_features: Set[str] = set()
Expand Down Expand Up @@ -201,7 +201,7 @@ def collate(cls,
index.
Args:
graphs (Iterable["LargeGraphIndexer"]): Indices to be
graphs (Iterable[LargeGraphIndexer]): Indices to be
combined.
Returns:
Expand All @@ -212,16 +212,16 @@ def collate(cls,
trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
return cls.from_triplets(trips)

def get_unique_node_features(
self, feature_name: str = NODE_PID) -> List[Hashable]:
def get_unique_node_features(self,
feature_name: str = NODE_PID) -> List[str]:
r"""Get all the unique values for a specific node attribute.
Args:
feature_name (str, optional): Name of feature to get.
Defaults to NODE_PID.
Returns:
List[Hashable]: List of unique values for the specified feature.
List[str]: List of unique values for the specified feature.
"""
try:
if feature_name in self._mapped_node_features:
Expand Down Expand Up @@ -272,15 +272,15 @@ def add_node_feature(
def get_node_features(
self,
feature_name: str = NODE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
) -> List[Any]:
r"""Get node feature values for a given set of unique node ids.
Returned values are not necessarily unique.
Args:
feature_name (str, optional): Name of feature to fetch. Defaults
to NODE_PID.
pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
pids (Optional[Iterable[str]], optional): Node ids to fetch
for. Defaults to None, which fetches all nodes.
Returns:
Expand All @@ -302,7 +302,7 @@ def get_node_features(
def get_node_features_iter(
self,
feature_name: str = NODE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
index_only: bool = False,
) -> Iterator[Any]:
"""Iterator version of get_node_features. If index_only is True,
Expand Down Expand Up @@ -337,16 +337,16 @@ def get_node_features_iter(
else:
yield self.node_attr[feature_name][idx]

def get_unique_edge_features(
self, feature_name: str = EDGE_PID) -> List[Hashable]:
def get_unique_edge_features(self,
feature_name: str = EDGE_PID) -> List[str]:
r"""Get all the unique values for a specific edge attribute.
Args:
feature_name (str, optional): Name of feature to get.
Defaults to EDGE_PID.
Returns:
List[Hashable]: List of unique values for the specified feature.
List[str]: List of unique values for the specified feature.
"""
try:
if feature_name in self._mapped_edge_features:
Expand Down Expand Up @@ -396,15 +396,15 @@ def add_edge_feature(
def get_edge_features(
self,
feature_name: str = EDGE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
) -> List[Any]:
r"""Get edge feature values for a given set of unique edge ids.
Returned values are not necessarily unique.
Args:
feature_name (str, optional): Name of feature to fetch.
Defaults to EDGE_PID.
pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
pids (Optional[Iterable[str]], optional): Edge ids to fetch
for. Defaults to None, which fetches all edges.
Returns:
Expand Down
4 changes: 3 additions & 1 deletion torch_geometric/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .prefetch import PrefetchLoader
from .cache import CachedLoader
from .mixin import AffinityMixin
from .rag_loader import RAGQueryLoader
from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore

__all__ = classes = [
'DataLoader',
Expand Down Expand Up @@ -52,6 +52,8 @@
'CachedLoader',
'AffinityMixin',
'RAGQueryLoader',
'RAGFeatureStore',
'RAGGraphStore'
]

RandomNodeSampler = deprecated(
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/loader/rag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class RAGFeatureStore(Protocol):
"""Feature store for remote GNN RAG backend."""
"""Feature store template for remote GNN RAG backend."""
@abstractmethod
def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
"""Makes a comparison between the query and all the nodes to get all
Expand All @@ -33,7 +33,7 @@ def load_subgraph(


class RAGGraphStore(Protocol):
"""Graph store for remote GNN RAG backend."""
"""Graph store template for remote GNN RAG backend."""
@abstractmethod
def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
**kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
Expand All @@ -52,6 +52,7 @@ def register_feature_store(self, feature_store: FeatureStore):


class RAGQueryLoader:
"""Loader meant for making RAG queries from a remote backend."""
def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
local_filter: Optional[Callable[[Data, Any], Data]] = None,
seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
19 changes: 14 additions & 5 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,34 +307,43 @@ class EdgeTypeStr(str):
r"""A helper class to construct serializable edge types by merging an edge
type tuple into a single string.
"""
edge_type: tuple[str, str, str]

def __new__(cls, *args: Any) -> 'EdgeTypeStr':
if isinstance(args[0], (list, tuple)):
# Unwrap `EdgeTypeStr((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
args = tuple(args[0])

if len(args) == 1 and isinstance(args[0], str):
arg = args[0] # An edge type string was passed.
edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
if len(edge_type) != 3:
raise ValueError(f"Cannot convert the edge type '{arg}' to a "
f"tuple since it holds invalid characters")

elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
# A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
edge_type = (args[0], DEFAULT_REL, args[1])
arg = EDGE_TYPE_STR_SPLIT.join(edge_type)

elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
# A `(src, rel, dst)` edge type was passed:
edge_type = tuple(args)
arg = EDGE_TYPE_STR_SPLIT.join(args)

else:
raise ValueError(f"Encountered invalid edge type '{args}'")

return str.__new__(cls, arg)
out = str.__new__(cls, arg)
out.edge_type = edge_type # type: ignore
return out

def to_tuple(self) -> EdgeType:
r"""Returns the original edge type."""
out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
if len(out) != 3:
if len(self.edge_type) != 3:
raise ValueError(f"Cannot convert the edge type '{self}' to a "
f"tuple since it holds invalid characters")
return out
return self.edge_type


# There exist some short-cuts to query edge-types (given that the full triplet
Expand Down

0 comments on commit 58fa8ed

Please sign in to comment.