
Commit

Merge branch 'master' into feature/adding_ssma_aggregation
AlmogDavid authored Dec 24, 2024
2 parents 25aba74 + ab2b458 commit 67730a8
Showing 6 changed files with 58 additions and 184 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
172 changes: 15 additions & 157 deletions docker/Dockerfile
@@ -1,163 +1,21 @@
FROM ubuntu:18.04
FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04

# metainformation
LABEL org.opencontainers.image.version = "2.3.1"
LABEL org.opencontainers.image.authors = "Matthias Fey"
LABEL org.opencontainers.image.source = "https://github.com/pyg-team/pytorch_geometric"
LABEL org.opencontainers.image.licenses = "MIT"
LABEL org.opencontainers.image.base.name="docker.io/library/ubuntu:18.04"
# Based on NGC PyG 24.09 image:
# https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09

RUN apt-get update && apt-get install -y apt-transport-https ca-certificates && \
rm -rf /var/lib/apt/lists/*
# install pip
RUN apt-get update && apt-get install -y python3-pip

RUN apt-get update && apt-get install -y --no-install-recommends apt-utils gnupg2 curl && \
curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub | apt-key add - && \
echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list &&\
apt-get purge --autoremove -y curl && \
rm -rf /var/lib/apt/lists/*
# install PyTorch - latest stable version
RUN pip install torch torchvision torchaudio

ENV CUDA_VERSION 10.1.243
ENV NCCL_VERSION 2.4.8
ENV CUDA_PKG_VERSION 10-1=$CUDA_VERSION-1
ENV CUDNN_VERSION 7.6.5.32
# install graphviz - latest stable version
RUN apt-get install -y graphviz graphviz-dev
RUN pip install pygraphviz

RUN apt-get update && apt-get install -y --no-install-recommends \
cuda-cudart-$CUDA_PKG_VERSION \
cuda-compat-10-1 && \
ln -s cuda-10.1 /usr/local/cuda && \
rm -rf /var/lib/apt/lists/*
# install python packages with NGC PyG 24.09 image versions
RUN pip install torch_geometric==2.6.0
RUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5

RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \
cuda-libraries-$CUDA_PKG_VERSION \
cuda-nvtx-$CUDA_PKG_VERSION \
libcublas10=10.2.1.243-1 \
libnccl2=$NCCL_VERSION-1+cuda10.1 && \
apt-mark hold libnccl2 && \
rm -rf /var/lib/apt/lists/*

RUN apt-get update && apt-get install -y --allow-unauthenticated --no-install-recommends \
cuda-libraries-dev-$CUDA_PKG_VERSION \
cuda-nvml-dev-$CUDA_PKG_VERSION \
cuda-minimal-build-$CUDA_PKG_VERSION \
cuda-command-line-tools-$CUDA_PKG_VERSION \
libnccl-dev=$NCCL_VERSION-1+cuda10.1 \
libcublas-dev=10.2.1.243-1 \
&& \
rm -rf /var/lib/apt/lists/*

RUN apt-get update && apt-get install -y --no-install-recommends \
libcudnn7=$CUDNN_VERSION-1+cuda10.1 \
libcudnn7-dev=$CUDNN_VERSION-1+cuda10.1 \
&& \
apt-mark hold libcudnn7 && \
rm -rf /var/lib/apt/lists/*


ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs

# NVIDIA docker 1.0.
LABEL com.nvidia.volumes.needed="nvidia_driver"
LABEL com.nvidia.cuda.version="${CUDA_VERSION}"

RUN echo "/usr/local/nvidia/lib" >> /etc/ld.so.conf.d/nvidia.conf && \
echo "/usr/local/nvidia/lib64" >> /etc/ld.so.conf.d/nvidia.conf

ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
ENV LD_LIBRARY_PATH /usr/local/nvidia/lib:/usr/local/nvidia/lib64

# NVIDIA container runtime.
ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV NVIDIA_REQUIRE_CUDA "cuda>=10.0 brand=tesla,driver>=384,driver<385 brand=tesla,driver>=410,driver<411"

# PyTorch (Geometric) installation
RUN rm /etc/apt/sources.list.d/cuda.list && \
rm /etc/apt/sources.list.d/nvidia-ml.list

RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
vim \
sudo \
git \
bzip2 \
libx11-6 \
&& rm -rf /var/lib/apt/lists/*

# Create a working directory.
RUN mkdir /app
WORKDIR /app

# Create a non-root user and switch to it.
RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
&& chown -R user:user /app
RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
USER user

# All users can use /home/user as their home directory.
ENV HOME=/home/user
RUN chmod 777 /home/user

# Install Miniconda.
RUN curl -so ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& chmod +x ~/miniconda.sh \
&& ~/miniconda.sh -b -p ~/miniconda \
&& rm ~/miniconda.sh
ENV PATH=/home/user/miniconda/bin:$PATH
ENV CONDA_AUTO_UPDATE_CONDA=false

# Create a Python 3.6 environment.
RUN /home/user/miniconda/bin/conda install conda-build \
&& /home/user/miniconda/bin/conda create -y --name py36 python=3.6.5 \
&& /home/user/miniconda/bin/conda clean -ya
ENV CONDA_DEFAULT_ENV=py36
ENV CONDA_PREFIX=/home/user/miniconda/envs/$CONDA_DEFAULT_ENV
ENV PATH=$CONDA_PREFIX/bin:$PATH

# CUDA 10.0-specific steps.
RUN conda install -y -c pytorch \
cudatoolkit=10.1 \
"pytorch=1.4.0=py3.6_cuda10.1.243_cudnn7.6.3_0" \
torchvision=0.5.0=py36_cu101 \
&& conda clean -ya

# Install HDF5 Python bindings.
RUN conda install -y h5py=2.8.0 \
&& conda clean -ya
RUN pip install h5py-cache==1.0

# Install TorchNet, a high-level framework for PyTorch.
RUN pip install torchnet==0.0.4

# Install Requests, a Python library for making HTTP requests.
RUN conda install -y requests=2.19.1 \
&& conda clean -ya

# Install Graphviz.
RUN conda install -y graphviz=2.40.1 python-graphviz=0.8.4 \
&& conda clean -ya

# Install OpenCV3 Python bindings.
RUN sudo apt-get update && sudo apt-get install -y --no-install-recommends \
libgtk2.0-0 \
libcanberra-gtk-module \
&& sudo rm -rf /var/lib/apt/lists/*
RUN conda install -y -c menpo opencv3=3.1.0 \
&& conda clean -ya

# Install PyG.
RUN CPATH=/usr/local/cuda/include:$CPATH \
&& LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
&& DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH

RUN pip install scipy

RUN pip install --no-index torch_scatter -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
&& pip install --no-index torch_sparse -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
&& pip install --no-index torch_cluster -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
&& pip install --no-index torch_spline_conv -f https://data.pyg.org/whl/torch-1.4.0+cu101.html \
&& pip install torch-geometric

# Set the default command to python3.
CMD ["python3"]
# install cugraph
RUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com
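
As a quick sanity check of the rebuilt image, a minimal smoke test along these lines can be run inside the container. This is only a sketch: the expected versions come from the pins in the `RUN pip install` lines above, and the CUDA check assumes the container was started with GPU access (e.g. `docker run --gpus all`).

```python
# Minimal smoke test for the image built from this Dockerfile (run inside the container).
# It only touches packages installed above; expected versions follow the pins in the RUN lines.
import scipy
import torch
import torch_geometric

print("torch:", torch.__version__)
print("torch_geometric:", torch_geometric.__version__)  # pinned to 2.6.0 above
print("scipy:", scipy.__version__)                       # pinned to 1.14.0 above
print("CUDA available:", torch.cuda.is_available())      # expects GPU passthrough at `docker run`
```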
38 changes: 19 additions & 19 deletions torch_geometric/data/large_graph_indexer.py
@@ -7,7 +7,6 @@
Any,
Callable,
Dict,
Hashable,
Iterable,
Iterator,
List,
@@ -25,12 +24,13 @@
from torch_geometric.data import Data
from torch_geometric.typing import WITH_PT24

TripletLike = Tuple[Hashable, Hashable, Hashable]
# Could be any hashable type
TripletLike = Tuple[str, str, str]

KnowledgeGraphLike = Iterable[TripletLike]


def ordered_set(values: Iterable[Hashable]) -> List[Hashable]:
def ordered_set(values: Iterable[str]) -> List[str]:
return list(dict.fromkeys(values))


@@ -70,13 +70,13 @@ def __eq__(self, value: "MappedFeature") -> bool:


class LargeGraphIndexer:
"""For a dataset that consists of mulitiple subgraphs that are assumed to
"""For a dataset that consists of multiple subgraphs that are assumed to
be part of a much larger graph, collate the values into a large graph store
to save resources.
"""
def __init__(
self,
nodes: Iterable[Hashable],
nodes: Iterable[str],
edges: KnowledgeGraphLike,
node_attr: Optional[Dict[str, List[Any]]] = None,
edge_attr: Optional[Dict[str, List[Any]]] = None,
@@ -85,7 +85,7 @@ def __init__(
by id. Not meant to be used directly.
Args:
nodes (Iterable[Hashable]): Node ids in the graph.
nodes (Iterable[str]): Node ids in the graph.
edges (KnowledgeGraphLike): Edge ids in the graph.
node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node
attribute name and list of their values in order of unique node
@@ -94,7 +94,7 @@ def __init__(
attribute name and list of their values in order of unique edge
ids. Defaults to None.
"""
self._nodes: Dict[Hashable, int] = dict()
self._nodes: Dict[str, int] = dict()
self._edges: Dict[TripletLike, int] = dict()

self._mapped_node_features: Set[str] = set()
@@ -201,7 +201,7 @@ def collate(cls,
index.
Args:
graphs (Iterable["LargeGraphIndexer"]): Indices to be
graphs (Iterable[LargeGraphIndexer]): Indices to be
combined.
Returns:
@@ -212,16 +212,16 @@ def collate(cls,
trips = chain.from_iterable([graph.to_triplets() for graph in graphs])
return cls.from_triplets(trips)

def get_unique_node_features(
self, feature_name: str = NODE_PID) -> List[Hashable]:
def get_unique_node_features(self,
feature_name: str = NODE_PID) -> List[str]:
r"""Get all the unique values for a specific node attribute.
Args:
feature_name (str, optional): Name of feature to get.
Defaults to NODE_PID.
Returns:
List[Hashable]: List of unique values for the specified feature.
List[str]: List of unique values for the specified feature.
"""
try:
if feature_name in self._mapped_node_features:
@@ -272,15 +272,15 @@ def add_node_feature(
def get_node_features(
self,
feature_name: str = NODE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
) -> List[Any]:
r"""Get node feature values for a given set of unique node ids.
Returned values are not necessarily unique.
Args:
feature_name (str, optional): Name of feature to fetch. Defaults
to NODE_PID.
pids (Optional[Iterable[Hashable]], optional): Node ids to fetch
pids (Optional[Iterable[str]], optional): Node ids to fetch
for. Defaults to None, which fetches all nodes.
Returns:
@@ -302,7 +302,7 @@ def get_node_features(
def get_node_features_iter(
self,
feature_name: str = NODE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
index_only: bool = False,
) -> Iterator[Any]:
"""Iterator version of get_node_features. If index_only is True,
@@ -337,16 +337,16 @@ def get_node_features_iter(
else:
yield self.node_attr[feature_name][idx]

def get_unique_edge_features(
self, feature_name: str = EDGE_PID) -> List[Hashable]:
def get_unique_edge_features(self,
feature_name: str = EDGE_PID) -> List[str]:
r"""Get all the unique values for a specific edge attribute.
Args:
feature_name (str, optional): Name of feature to get.
Defaults to EDGE_PID.
Returns:
List[Hashable]: List of unique values for the specified feature.
List[str]: List of unique values for the specified feature.
"""
try:
if feature_name in self._mapped_edge_features:
@@ -396,15 +396,15 @@ def add_edge_feature(
def get_edge_features(
self,
feature_name: str = EDGE_PID,
pids: Optional[Iterable[Hashable]] = None,
pids: Optional[Iterable[str]] = None,
) -> List[Any]:
r"""Get edge feature values for a given set of unique edge ids.
Returned values are not necessarily unique.
Args:
feature_name (str, optional): Name of feature to fetch.
Defaults to EDGE_PID.
pids (Optional[Iterable[Hashable]], optional): Edge ids to fetch
pids (Optional[Iterable[str]], optional): Edge ids to fetch
for. Defaults to None, which fetches all edges.
Returns:
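For context on the type change in this file (node ids are now plain `str` rather than arbitrary `Hashable` values), here is a hedged usage sketch with string triplets. The import path and the `from_triplets`, `collate`, and `get_unique_node_features` names are taken from this file; the toy data and the overall workflow are illustrative assumptions, not an official recipe.

```python
# Hedged sketch: collating two small knowledge subgraphs into one LargeGraphIndexer.
# Triplets are (head, relation, tail) strings, matching TripletLike = Tuple[str, str, str].
from torch_geometric.data.large_graph_indexer import LargeGraphIndexer

subgraph_a = [("paris", "capital_of", "france"), ("france", "in", "europe")]
subgraph_b = [("berlin", "capital_of", "germany"), ("germany", "in", "europe")]

# One indexer per subgraph, then collate them into a single large index.
indexers = [LargeGraphIndexer.from_triplets(t) for t in (subgraph_a, subgraph_b)]
merged = LargeGraphIndexer.collate(indexers)

# Node ids are deduplicated across subgraphs ("europe" appears only once).
print(merged.get_unique_node_features())
```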
4 changes: 3 additions & 1 deletion torch_geometric/loader/__init__.py
@@ -22,7 +22,7 @@
from .prefetch import PrefetchLoader
from .cache import CachedLoader
from .mixin import AffinityMixin
from .rag_loader import RAGQueryLoader
from .rag_loader import RAGQueryLoader, RAGFeatureStore, RAGGraphStore

__all__ = classes = [
'DataLoader',
@@ -52,6 +52,8 @@
'CachedLoader',
'AffinityMixin',
'RAGQueryLoader',
'RAGFeatureStore',
'RAGGraphStore'
]

RandomNodeSampler = deprecated(
5 changes: 3 additions & 2 deletions torch_geometric/loader/rag_loader.py
@@ -7,7 +7,7 @@


class RAGFeatureStore(Protocol):
"""Feature store for remote GNN RAG backend."""
"""Feature store template for remote GNN RAG backend."""
@abstractmethod
def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:
"""Makes a comparison between the query and all the nodes to get all
@@ -33,7 +33,7 @@ def load_subgraph(


class RAGGraphStore(Protocol):
"""Graph store for remote GNN RAG backend."""
"""Graph store template for remote GNN RAG backend."""
@abstractmethod
def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,
**kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:
@@ -52,6 +52,7 @@ def register_feature_store(self, feature_store: FeatureStore):


class RAGQueryLoader:
"""Loader meant for making RAG queries from a remote backend."""
def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
local_filter: Optional[Callable[[Data, Any], Data]] = None,
seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
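For orientation, a hedged sketch of how these pieces are meant to fit together: a backend implements the two protocols and the pair is handed to `RAGQueryLoader`. Only the method names visible in this file (`retrieve_seed_nodes`, `load_subgraph`, `sample_subgraph`, `register_feature_store`) and the `__init__` signature are taken from the source; the stub classes, the `retrieve_seed_edges` name, and the `query` entry point are assumptions for illustration.

```python
# Hedged sketch: RAGFeatureStore / RAGGraphStore are typing Protocols, so any objects
# exposing the right methods can back a RAGQueryLoader. The stubs below only show the
# expected shape; they are hypothetical placeholders, not a working remote backend.
from torch_geometric.loader import RAGFeatureStore, RAGGraphStore, RAGQueryLoader


class ToyFeatureStore:
    def retrieve_seed_nodes(self, query, **kwargs):
        raise NotImplementedError  # compare the query against node features, return seed nodes

    def retrieve_seed_edges(self, query, **kwargs):  # name assumed, not shown in this diff
        raise NotImplementedError

    def load_subgraph(self, sample, **kwargs):
        raise NotImplementedError  # materialize a Data object from a sampler output


class ToyGraphStore:
    def sample_subgraph(self, seed_nodes, seed_edges, **kwargs):
        raise NotImplementedError  # run a (possibly remote) sampler around the seeds

    def register_feature_store(self, feature_store):
        self.feature_store = feature_store


# Wiring follows the __init__ signature shown above; uncomment once real stores exist:
# loader = RAGQueryLoader(data=(ToyFeatureStore(), ToyGraphStore()), local_filter=None)
# subgraph = loader.query("...")  # retrieval entry point name assumed, not shown in this diff
```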
