diff --git a/CHANGELOG.md b/CHANGELOG.md
index 69ec38aaa4ca..b4ff83664a29 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
 - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
diff --git a/docker/Dockerfile b/docker/Dockerfile
index d4f37f061d68..d7a879ba1157 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,163 +1,21 @@
-FROM ubuntu:18.04
+FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04
 
-# metainformation
-LABEL org.opencontainers.image.version = "2.3.1"
-LABEL org.opencontainers.image.authors = "Matthias Fey"
-LABEL org.opencontainers.image.source = "https://github.com/pyg-team/pytorch_geometric"
-LABEL org.opencontainers.image.licenses = "MIT"
-LABEL org.opencontainers.image.base.name="docker.io/library/ubuntu:18.04"
+# Based on NGC PyG 24.09 image:
+# https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09
 
-RUN apt-get update && apt-get install -y apt-transport-https ca-certificates && \
-    rm -rf /var/lib/apt/lists/*
+# install pip
+RUN apt-get update && apt-get install -y python3-pip
 
-RUN apt-get update && apt-get install -y --no-install-recommends apt-utils gnupg2 curl && \
-    curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
-    echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
-    echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list &&\
-    apt-get purge --autoremove -y curl && \
-rm -rf /var/lib/apt/lists/*
+# install PyTorch - latest stable version
+RUN pip install torch torchvision torchaudio
 
-ENV CUDA_VERSION 10.1.243
-ENV NCCL_VERSION 2.4.8
-ENV CUDA_PKG_VERSION 10-1=$CUDA_VERSION-1
-ENV CUDNN_VERSION 7.6.5.32
+# install graphviz - latest stable version
+RUN apt-get install -y graphviz graphviz-dev
+RUN pip install pygraphviz
 
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    cuda-cudart-$CUDA_PKG_VERSION \
-    cuda-compat-10-1 && \
-    ln -s cuda-10.1 /usr/local/cuda && \
-    rm -rf /var/lib/apt/lists/*
+# install python packages with NGC PyG 24.09 image versions
+RUN pip install torch_geometric==2.6.0
+RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5
 
-RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \
-    cuda-libraries-$CUDA_PKG_VERSION \
-    cuda-nvtx-$CUDA_PKG_VERSION \
-    libcublas10=10.2.1.243-1 \
-    libnccl2=$NCCL_VERSION-1+cuda10.1 && \
-    apt-mark hold libnccl2 && \
-    rm -rf /var/lib/apt/lists/*
-
-RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \
-    cuda-libraries-dev-$CUDA_PKG_VERSION \
-    cuda-nvml-dev-$CUDA_PKG_VERSION \
-    cuda-minimal-build-$CUDA_PKG_VERSION \
-    cuda-command-line-tools-$CUDA_PKG_VERSION \
-    libnccl-dev=$NCCL_VERSION-1+cuda10.1 \
-    libcublas-dev=10.2.1.243-1 \
-    && \
-    rm -rf /var/lib/apt/lists/*
-
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    libcudnn7=$CUDNN_VERSION-1+cuda10.1 \
-    libcudnn7-dev=$CUDNN_VERSION-1+cuda10.1 \
-    && \
-    apt-mark hold libcudnn7 && \
-    rm -rf /var/lib/apt/lists/*
-
-
-ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs
-
-# NVIDIA docker 1.0.
-LABEL com.nvidia.volumes.needed="nvidia_driver"
-LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
-
-RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
-    echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf
-
-ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
-ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64
-
-# NVIDIA container runtime.
-ENV NVIDIA_VISIBLE_DEVICES all
-ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
-ENV NVIDIA_REQUIRE_CUDA "cuda>=10.0 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=410,driver<411"
-
-# PyTorch (Geometric) installation
-RUN rm /etc/apt/sources.list.d/cuda.list && \
-    rm /etc/apt/sources.list.d/nvidia-ml.list
-
-RUN apt-get update && apt-get install -y \
-    curl \
-    ca-certificates \
-    vim \
-    sudo \
-    git \
-    bzip2 \
-    libx11-6 \
-    && rm -rf /var/lib/apt/lists/*
-
-# Create a working directory.
-RUN mkdir /app
-WORKDIR /app
-
-# Create a non-root user and switch to it.
-RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
-    && chown -R user:user /app
-RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
-USER user
-
-# All users can use /home/user as their home directory.
-ENV HOME=/home/user
-RUN chmod 777 /home/user
-
-# Install Miniconda.
-RUN curl -so ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
-    && chmod +x ~/miniconda.sh \
-    && ~/miniconda.sh -b -p ~/miniconda \
-    && rm ~/miniconda.sh
-ENV PATH=/home/user/miniconda/bin:$PATH
-ENV CONDA_AUTO_UPDATE_CONDA=false
-
-# Create a Python 3.6 environment.
-RUN /home/user/miniconda/bin/conda install conda-build \
-    && /home/user/miniconda/bin/conda create -y --name py36 python=3.6.5 \
-    && /home/user/miniconda/bin/conda clean -ya
-ENV CONDA_DEFAULT_ENV=py36
-ENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV
-ENV PATH=$CONDA_PREFIX/bin:$PATH
-
-# CUDA 10.0-specific steps.
-RUN conda install -y -c pytorch \
-    cudatoolkit=10.1 \
-    "pytorch=1.4.0=py3.6_cuda10.1.243_cudnn7.6.3_0" \
-    torchvision=0.5.0=py36_cu101 \
-    && conda clean -ya
-
-# Install HDF5 Python bindings.
-RUN conda install -y h5py=2.8.0 \
-    && conda clean -ya
-RUN pip install h5py-cache==1.0
-
-# Install TorchNet, a high-level framework for PyTorch.
-RUN pip install torchnet==0.0.4
-
-# Install Requests, a Python library for making HTTP requests.
-RUN conda install -y requests=2.19.1 \
-    && conda clean -ya
-
-# Install Graphviz.
-RUN conda install -y graphviz=2.40.1 python-graphviz=0.8.4 \
-    && conda clean -ya
-
-# Install OpenCV3 Python bindings.
-RUN sudo apt-get update && sudo apt-get install -y --no-install-recommends \
-    libgtk2.0-0 \
-    libcanberra-gtk-module \
-    && sudo rm -rf /var/lib/apt/lists/*
-RUN conda install -y -c menpo opencv3=3.1.0 \
-    && conda clean -ya
-
-# Install PyG.
-RUN CPATH=/usr/local/cuda/include:$CPATH \
-    && LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
-    && DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH
-
-RUN pip install scipy
-
-RUN pip install --no-index torch_scatter -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
-    && pip install --no-index torch_sparse -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
-    && pip install --no-index torch_cluster -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
-    && pip install --no-index torch_spline_conv -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
-    && pip install torch-geometric
-
-# Set the default command to python3.
-CMD ["python3"]
+# install cugraph
+RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com
diff --git a/torch_geometric/data/large_graph_indexer.py b/torch_geometric/data/large_graph_indexer.py
index 0644e2543303..d2cb30378908 100644
--- a/torch_geometric/data/large_graph_indexer.py
+++ b/torch_geometric/data/large_graph_indexer.py
@@ -7,7 +7,6 @@
     Any,
     Callable,
     Dict,
-    Hashable,
     Iterable,
     Iterator,
     List,
@@ -25,12 +24,13 @@
 from torch_geometric.data import Data
 from torch_geometric.typing import WITH_PT24
 
-TripletLike = Tuple[Hashable, Hashable, Hashable]
+# Could be any hashable type
+TripletLike = Tuple[str, str, str]
 
 KnowledgeGraphLike = Iterable[TripletLike]
 
 
-def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
+def ordered_set(values: Iterable[str]) -> List[str]:
     return list(dict.fromkeys(values))
 
 
@@ -70,13 +70,13 @@ def __eq__(self, value: "MappedFeature") -> bool:
 
 
 class LargeGraphIndexer:
-    """For a dataset that consists of mulitiple subgraphs that are assumed to
+    """For a dataset that consists of multiple subgraphs that are assumed to
     be part of a much larger graph, collate the values into a large graph
     store to save resources.
     """
     def __init__(
         self,
-        nodes: Iterable[Hashable],
+        nodes: Iterable[str],
         edges: KnowledgeGraphLike,
         node_attr: Optional[Dict[str, List[Any]]] = None,
         edge_attr: Optional[Dict[str, List[Any]]] = None,
@@ -85,7 +85,7 @@ def __init__(
         by id. Not meant to be used directly.
 
         Args:
-            nodes (Iterable[Hashable]): Node ids in the graph.
+            nodes (Iterable[str]): Node ids in the graph.
             edges (KnowledgeGraphLike): Edge ids in the graph.
             node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
                 attribute name and list of their values in order of unique node
@@ -94,7 +94,7 @@ def __init__(
                 attribute name and list of their values in order of unique edge
                 ids. Defaults to None.
         """
-        self._nodes: Dict[Hashable, int] = dict()
+        self._nodes: Dict[str, int] = dict()
         self._edges: Dict[TripletLike, int] = dict()
 
         self._mapped_node_features: Set[str] = set()
@@ -201,7 +201,7 @@ def collate(cls,
             index.
 
         Args:
-            graphs (Iterable["LargeGraphIndexer"]): Indices to be
+            graphs (Iterable[LargeGraphIndexer]): Indices to be
                 combined.
 
         Returns:
@@ -212,8 +212,8 @@ def collate(cls,
         trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
         return cls.from_triplets(trips)
 
-    def get_unique_node_features(
-            self, feature_name: str = NODE_PID) -> List[Hashable]:
+    def get_unique_node_features(self,
+                                 feature_name: str = NODE_PID) -> List[str]:
         r"""Get all the unique values for a specific node attribute.
 
         Args:
@@ -221,7 +221,7 @@
                 Defaults to NODE_PID.
 
         Returns:
-            List[Hashable]: List of unique values for the specified feature.
+            List[str]: List of unique values for the specified feature.
         """
         try:
             if feature_name in self._mapped_node_features:
@@ -272,7 +272,7 @@ def add_node_feature(
 
     def get_node_features(
         self,
         feature_name: str = NODE_PID,
-        pids: Optional[Iterable[Hashable]] = None,
+        pids: Optional[Iterable[str]] = None,
     ) -> List[Any]:
         r"""Get node feature values for a given set of unique node ids.
         Returned values are not necessarily unique.
@@ -280,7 +280,7 @@ def get_node_features(
         Args:
             feature_name (str, optional): Name of feature to fetch.
                 Defaults to NODE_PID.
-            pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
+            pids (Optional[Iterable[str]], optional): Node ids to fetch
                 for. Defaults to None, which fetches all nodes.
 
         Returns:
@@ -302,7 +302,7 @@ def get_node_features(
 
     def get_node_features_iter(
         self,
         feature_name: str = NODE_PID,
-        pids: Optional[Iterable[Hashable]] = None,
+        pids: Optional[Iterable[str]] = None,
         index_only: bool = False,
     ) -> Iterator[Any]:
         """Iterator version of get_node_features. If index_only is True,
@@ -337,8 +337,8 @@ def get_node_features_iter(
             else:
                 yield self.node_attr[feature_name][idx]
 
-    def get_unique_edge_features(
-            self, feature_name: str = EDGE_PID) -> List[Hashable]:
+    def get_unique_edge_features(self,
+                                 feature_name: str = EDGE_PID) -> List[str]:
         r"""Get all the unique values for a specific edge attribute.
 
         Args:
@@ -346,7 +346,7 @@
                 Defaults to EDGE_PID.
 
         Returns:
-            List[Hashable]: List of unique values for the specified feature.
+            List[str]: List of unique values for the specified feature.
         """
         try:
             if feature_name in self._mapped_edge_features:
@@ -396,7 +396,7 @@ def add_edge_feature(
 
     def get_edge_features(
         self,
         feature_name: str = EDGE_PID,
-        pids: Optional[Iterable[Hashable]] = None,
+        pids: Optional[Iterable[str]] = None,
     ) -> List[Any]:
         r"""Get edge feature values for a given set of unique edge ids.
         Returned values are not necessarily unique.
@@ -404,7 +404,7 @@ def get_edge_features(
         Args:
            feature_name (str, optional): Name of feature to fetch.
                 Defaults to EDGE_PID.
-            pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
+            pids (Optional[Iterable[str]], optional): Edge ids to fetch
                 for. Defaults to None, which fetches all edges.
 
         Returns:
diff --git a/torch_geometric/loader/__init__.py b/torch_geometric/loader/__init__.py
index 7e83c35befb6..75dbe9178681 100644
--- a/torch_geometric/loader/__init__.py
+++ b/torch_geometric/loader/__init__.py
@@ -22,7 +22,7 @@
 from .prefetch import PrefetchLoader
 from .cache import CachedLoader
 from .mixin import AffinityMixin
-from .rag_loader import RAGQueryLoader
+from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore
 
 __all__ = classes = [
     'DataLoader',
@@ -52,6 +52,8 @@
     'CachedLoader',
     'AffinityMixin',
     'RAGQueryLoader',
+    'RAGFeatureStore',
+    'RAGGraphStore'
 ]
 
 RandomNodeSampler = deprecated(
diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py
index 33d6cf0e868e..4ab457ef7072 100644
--- a/torch_geometric/loader/rag_loader.py
+++ b/torch_geometric/loader/rag_loader.py
@@ -7,7 +7,7 @@
 
 
 class RAGFeatureStore(Protocol):
-    """Feature store for remote GNN RAG backend."""
+    """Feature store template for remote GNN RAG backend."""
     @abstractmethod
     def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
         """Makes a comparison between the query and all the nodes to get all
@@ -33,7 +33,7 @@ def load_subgraph(
 
 
 class RAGGraphStore(Protocol):
-    """Graph store for remote GNN RAG backend."""
+    """Graph store template for remote GNN RAG backend."""
     @abstractmethod
     def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
                         **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
@@ -52,6 +52,7 @@ def register_feature_store(self, feature_store: FeatureStore):
 
 
 class RAGQueryLoader:
+    """Loader meant for making RAG queries from a remote backend."""
     def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                  local_filter: Optional[Callable[[Data, Any], Data]] = None,
                  seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
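
A rough sketch, not part of the patch, of how the newly exported names compose: any pair of backend objects implementing the protocol methods above (`retrieve_seed_nodes`, `load_subgraph`, `sample_subgraph`, `register_feature_store`, ...) can be handed to `RAGQueryLoader` as a `(feature_store, graph_store)` tuple. The `seed_nodes_kwargs` keys are forwarded to the user-supplied backend, so the retrieval budget shown here is purely hypothetical.

from typing import Tuple

from torch_geometric.loader import (
    RAGFeatureStore,
    RAGGraphStore,
    RAGQueryLoader,
)


def build_rag_loader(stores: Tuple[RAGFeatureStore, RAGGraphStore]) -> RAGQueryLoader:
    # Mirrors the __init__ signature above: a (feature_store, graph_store)
    # pair plus optional per-stage keyword arguments.
    return RAGQueryLoader(
        data=stores,
        seed_nodes_kwargs={'k_nodes': 10},  # hypothetical, backend-specific
    )
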
diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py
index e63b1849b65c..513f041847b1 100644
--- a/torch_geometric/typing.py
+++ b/torch_geometric/typing.py
@@ -307,6 +307,8 @@ class EdgeTypeStr(str):
     r"""A helper class to construct serializable edge types by merging an
     edge type tuple into a single string.
     """
+    edge_type: tuple[str, str, str]
+
     def __new__(cls, *args: Any) -> 'EdgeTypeStr':
         if isinstance(args[0], (list, tuple)):
             # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
@@ -314,27 +316,37 @@ def __new__(cls, *args: Any) -> 'EdgeTypeStr':
 
         if len(args) == 1 and isinstance(args[0], str):
             arg = args[0]  # An edge type string was passed.
+            edge_type = tuple(arg.split(EDGE_TYPE_STR_SPLIT))
+            if len(edge_type) != 3:
+                raise ValueError(f"Cannot convert the edge type '{arg}' to a "
+                                 f"tuple since it holds invalid characters")
 
         elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
             # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
-            arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))
+            edge_type = (args[0], DEFAULT_REL, args[1])
+            arg = EDGE_TYPE_STR_SPLIT.join(edge_type)
 
         elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
             # A `(src, rel, dst)` edge type was passed:
+            edge_type = tuple(args)
             arg = EDGE_TYPE_STR_SPLIT.join(args)
 
         else:
             raise ValueError(f"Encountered invalid edge type '{args}'")
 
-        return str.__new__(cls, arg)
+        out = str.__new__(cls, arg)
+        out.edge_type = edge_type  # type: ignore
+        return out
 
     def to_tuple(self) -> EdgeType:
         r"""Returns the original edge type."""
-        out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
-        if len(out) != 3:
+        if len(self.edge_type) != 3:
             raise ValueError(f"Cannot convert the edge type '{self}' to a "
                              f"tuple since it holds invalid characters")
-        return out
+        return self.edge_type
+
+    def __reduce__(self) -> tuple[Any, Any]:
+        return (self.__class__, (self.edge_type, ))
 
 
 # There exist some short-cuts to query edge-types (given that the full triplet
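
For illustration, not part of the patch: caching `edge_type` lets the string/tuple round trip skip re-splitting, and the new `__reduce__` keeps `EdgeTypeStr` picklable. Assuming the patch is applied and `EDGE_TYPE_STR_SPLIT` is the usual '__' separator, the behaviour looks roughly like this:

import pickle

from torch_geometric.typing import EdgeTypeStr

edge = EdgeTypeStr('author', 'writes', 'paper')
assert edge == 'author__writes__paper'  # still a plain str subclass
assert edge.to_tuple() == ('author', 'writes', 'paper')  # served from the cached tuple

# __reduce__ rebuilds the object from the cached tuple on unpickling:
assert pickle.loads(pickle.dumps(edge)) == edge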