diff --git a/pyproject.toml b/pyproject.toml
index 94d0b2c25a47..ae337f7587f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,7 +45,7 @@ dependencies=[
 [project.optional-dependencies]
 graphgym=[
     "protobuf<4.21",
-    "pytorch-lightning",
+    "pytorch-lightning<2.3.0",
     "yacs",
 ]
 modelhub=[
@@ -77,7 +77,7 @@ full = [
     "h5py",
     "matplotlib",
     "networkx",
-    "numba",
+    "numba<0.60.0",
     "opt_einsum",
     "pandas",
     "pgmpy",
diff --git a/test/loader/test_zip_loader.py b/test/loader/test_zip_loader.py
index 4bb4a4d79e2b..724cb25dab75 100644
--- a/test/loader/test_zip_loader.py
+++ b/test/loader/test_zip_loader.py
@@ -21,6 +21,10 @@ def test_zip_loader(filter_per_worker):
     loader = ZipLoader(loaders, batch_size=10,
                        filter_per_worker=filter_per_worker)
 
+    batches = loader(torch.arange(5))
+    assert isinstance(batches, tuple)
+    assert len(batches) == 2
+
     assert str(loader) == ('ZipLoader(loaders=[NeighborLoader(), '
                            'NeighborLoader()])')
     assert len(loader) == 5
diff --git a/test/nn/nlp/test_sentence_transformer.py b/test/nn/nlp/test_sentence_transformer.py
index d778d5abc1ae..956df7e8fa71 100644
--- a/test/nn/nlp/test_sentence_transformer.py
+++ b/test/nn/nlp/test_sentence_transformer.py
@@ -8,8 +8,12 @@
 @onlyFullTest
 @withPackage('transformers')
 @pytest.mark.parametrize('batch_size', [None, 1])
-def test_sentence_transformer(batch_size, device):
-    model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)
+@pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls'])
+def test_sentence_transformer(batch_size, pooling_strategy, device):
+    model = SentenceTransformer(
+        model_name='prajjwal1/bert-tiny',
+        pooling_strategy=pooling_strategy,
+    ).to(device)
     assert model.device == device
     assert str(model) == 'SentenceTransformer(model_name=prajjwal1/bert-tiny)'
diff --git a/test/nn/test_model_hub.py b/test/nn/test_model_hub.py
index bd7c486bbdb5..fafe400f6435 100644
--- a/test/nn/test_model_hub.py
+++ b/test/nn/test_model_hub.py
@@ -68,11 +68,19 @@ def test_save_pretrained_with_push_to_hub(model, tmp_path):
     # Push to hub with repo_id
     model.save_pretrained(save_directory, push_to_hub=True, repo_id='CustomID',
                           config=CONFIG)
-    model.push_to_hub.assert_called_with(repo_id='CustomID', config=CONFIG)
+    model.push_to_hub.assert_called_with(
+        repo_id='CustomID',
+        model_card_kwargs={},
+        config=CONFIG,
+    )
 
     # Push to hub with default repo_id (based on dir name)
     model.save_pretrained(save_directory, push_to_hub=True, config=CONFIG)
-    model.push_to_hub.assert_called_with(repo_id=REPO_NAME, config=CONFIG)
+    model.push_to_hub.assert_called_with(
+        repo_id=REPO_NAME,
+        model_card_kwargs={},
+        config=CONFIG,
+    )
 
 
 @withPackage('huggingface_hub')
diff --git a/torch_geometric/data/lightning/datamodule.py b/torch_geometric/data/lightning/datamodule.py
index 0e9e3cafed2e..4889ec92a71f 100644
--- a/torch_geometric/data/lightning/datamodule.py
+++ b/torch_geometric/data/lightning/datamodule.py
@@ -13,7 +13,7 @@
 try:
     from pytorch_lightning import LightningDataModule as PLLightningDataModule
     no_pytorch_lightning = False
-except (ImportError, ModuleNotFoundError):
+except ImportError:
     PLLightningDataModule = object  # type: ignore
     no_pytorch_lightning = True
 
diff --git a/torch_geometric/loader/zip_loader.py b/torch_geometric/loader/zip_loader.py
index 60836579634c..2dd4beb5edfc 100644
--- a/torch_geometric/loader/zip_loader.py
+++ b/torch_geometric/loader/zip_loader.py
@@ -59,6 +59,16 @@ def __init__(
         self.loaders = loaders
         self.filter_per_worker = filter_per_worker
 
+    def __call__(
+        self,
+        index: Union[Tensor, List[int]],
+    ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:
+        r"""Samples subgraphs from a batch of input IDs."""
+        out = self.collate_fn(index)
+        if not self.filter_per_worker:
+            out = self.filter_fn(out)
+        return out
+
     def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:
         if not isinstance(index, Tensor):
             index = torch.tensor(index, dtype=torch.long)
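
For context: the new `__call__` makes a `ZipLoader` directly callable on a batch of input IDs, running `collate_fn` and, unless `filter_per_worker` is enabled, `filter_fn`, mirroring what iteration over the loader already does per mini-batch. A minimal usage sketch, assuming `data` is a `Data` object with at least five nodes; the `NeighborLoader` arguments are illustrative:

import torch
from torch_geometric.loader import NeighborLoader, ZipLoader

# Two loaders over the same input nodes (arguments are illustrative):
loaders = [
    NeighborLoader(data, num_neighbors=[5], input_nodes=None),
    NeighborLoader(data, num_neighbors=[5], input_nodes=None),
]
loader = ZipLoader(loaders, batch_size=10)

# Sample subgraphs for node IDs 0..4; returns one batch per wrapped loader:
batch_a, batch_b = loader(torch.arange(5))
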
diff --git a/torch_geometric/nn/conv/edge_conv.py b/torch_geometric/nn/conv/edge_conv.py
index b72324d5e00b..5381749a9563 100644
--- a/torch_geometric/nn/conv/edge_conv.py
+++ b/torch_geometric/nn/conv/edge_conv.py
@@ -3,13 +3,14 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import reset
 from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
diff --git a/torch_geometric/nn/conv/gravnet_conv.py b/torch_geometric/nn/conv/gravnet_conv.py
index e7afac36f99f..7d7c2298ba8f 100644
--- a/torch_geometric/nn/conv/gravnet_conv.py
+++ b/torch_geometric/nn/conv/gravnet_conv.py
@@ -4,14 +4,15 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import OptPairTensor  # noqa
 from torch_geometric.typing import OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
diff --git a/torch_geometric/nn/conv/spline_conv.py b/torch_geometric/nn/conv/spline_conv.py
index e363647b1292..dd4fec2e756c 100644
--- a/torch_geometric/nn/conv/spline_conv.py
+++ b/torch_geometric/nn/conv/spline_conv.py
@@ -5,17 +5,17 @@
 from torch import Tensor, nn
 from torch.nn import Parameter
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import uniform, zeros
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
 from torch_geometric.utils.repeat import repeat
 
-try:
+if torch_geometric.typing.WITH_TORCH_SPLINE_CONV:
     from torch_spline_conv import spline_basis, spline_weighting
-except (ImportError, OSError):  # Fail gracefully on GLIBC errors
-    spline_basis = None
-    spline_weighting = None
+else:
+    spline_basis = spline_weighting = None
 
 
 class SplineConv(MessagePassing):
diff --git a/torch_geometric/nn/conv/x_conv.py b/torch_geometric/nn/conv/x_conv.py
index 8a533946f0ba..a5d2af4c28ef 100644
--- a/torch_geometric/nn/conv/x_conv.py
+++ b/torch_geometric/nn/conv/x_conv.py
@@ -9,12 +9,13 @@
 from torch.nn import Linear as L
 from torch.nn import Sequential as S
 
+import torch_geometric.typing
 from torch_geometric.nn import Reshape
 from torch_geometric.nn.inits import reset
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn_graph
-except ImportError:
+else:
     knn_graph = None
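
These four files replace per-module `try: ... except ImportError` guards with the centralized `WITH_TORCH_CLUSTER` / `WITH_TORCH_SPLINE_CONV` flags from `torch_geometric.typing`, so failures of an optional dependency are detected in one place. A minimal sketch of how such a flag can be defined; this is illustrative, not the verbatim `torch_geometric/typing.py` implementation:

try:
    import torch_cluster  # noqa: F401
    WITH_TORCH_CLUSTER = True
except Exception:
    # Catching broadly also covers binary-incompatible installs that raise
    # OSError rather than ImportError at import time.
    WITH_TORCH_CLUSTER = False
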
diff --git a/torch_geometric/nn/nlp/sentence_transformer.py b/torch_geometric/nn/nlp/sentence_transformer.py
index f82b77c2e6f1..cfb26176fcc6 100644
--- a/torch_geometric/nn/nlp/sentence_transformer.py
+++ b/torch_geometric/nn/nlp/sentence_transformer.py
@@ -1,30 +1,45 @@
-from typing import List, Optional
+from enum import Enum
+from typing import List, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor
 
 
+class PoolingStrategy(Enum):
+    MEAN = 'mean'
+    LAST = 'last'
+    CLS = 'cls'
+
+
 class SentenceTransformer(torch.nn.Module):
-    def __init__(self, model_name: str) -> None:
+    def __init__(
+        self,
+        model_name: str,
+        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
+    ) -> None:
         super().__init__()
         self.model_name = model_name
+        self.pooling_strategy = PoolingStrategy(pooling_strategy)
 
         from transformers import AutoModel, AutoTokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModel.from_pretrained(model_name)
 
-    def mean_pooling(self, emb: Tensor, attention_mask: Tensor) -> Tensor:
-        mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
-        return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
-
     def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
         out = self.model(input_ids=input_ids, attention_mask=attention_mask)
 
         emb = out[0]  # First element contains all token embeddings.
-        emb = self.mean_pooling(emb, attention_mask)
+        if self.pooling_strategy == PoolingStrategy.MEAN:
+            emb = mean_pooling(emb, attention_mask)
+        elif self.pooling_strategy == PoolingStrategy.LAST:
+            emb = last_pooling(emb, attention_mask)
+        else:
+            assert self.pooling_strategy == PoolingStrategy.CLS
+            emb = emb[:, 0, :]
+
         emb = F.normalize(emb, p=2, dim=1)
         return emb
 
@@ -61,3 +76,19 @@ def encode(
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(model_name={self.model_name})'
+
+
+def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
+    mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
+    return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+
+
+def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
+    # Check whether the language model uses left padding (which is
+    # standard for decoder-only LLMs):
+    left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)
+    if left_padding:
+        return emb[:, -1]
+
+    seq_indices = attention_mask.sum(dim=1) - 1
+    return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]
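
A short usage sketch for the new `pooling_strategy` option; the checkpoint and inputs are illustrative, and `encode` is the existing helper that tokenizes and batches internally:

from torch_geometric.nn.nlp import SentenceTransformer

model = SentenceTransformer(
    model_name='prajjwal1/bert-tiny',  # illustrative checkpoint
    pooling_strategy='cls',  # or 'mean' (default) / 'last'
)
emb = model.encode(['hello world', 'graph learning'])  # shape: [2, hidden_dim]
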
diff --git a/torch_geometric/nn/pool/graclus.py b/torch_geometric/nn/pool/graclus.py
index 22b7d5a2af97..99c31f64a94f 100644
--- a/torch_geometric/nn/pool/graclus.py
+++ b/torch_geometric/nn/pool/graclus.py
@@ -2,9 +2,11 @@
 
 from torch import Tensor
 
-try:
+import torch_geometric.typing
+
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import graclus_cluster
-except ImportError:
+else:
     graclus_cluster = None
 
diff --git a/torch_geometric/nn/pool/voxel_grid.py b/torch_geometric/nn/pool/voxel_grid.py
index 3deaec418c97..d258bdf44c72 100644
--- a/torch_geometric/nn/pool/voxel_grid.py
+++ b/torch_geometric/nn/pool/voxel_grid.py
@@ -3,11 +3,12 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.utils.repeat import repeat
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import grid_cluster
-except ImportError:
+else:
     grid_cluster = None
 
diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py
index 40a21ef9f3ff..e0b16e138d2d 100644
--- a/torch_geometric/testing/decorators.py
+++ b/torch_geometric/testing/decorators.py
@@ -177,12 +177,16 @@ def has_package(package: str) -> bool:
     req = Requirement(package)
     if find_spec(req.name) is None:
         return False
-    module = import_module(req.name)
-    if not hasattr(module, '__version__'):
-        return True
-
-    version = Version(module.__version__).base_version
-    return version in req.specifier
+
+    try:
+        module = import_module(req.name)
+        if not hasattr(module, '__version__'):
+            return True
+
+        version = Version(module.__version__).base_version
+        return version in req.specifier
+    except Exception:
+        return False
 
 
 def withPackage(*args: str) -> Callable:
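
With the widened `try/except`, `has_package` now returns `False` when importing an installed package fails (e.g., a broken native extension) or when its `__version__` cannot be parsed, instead of letting the exception propagate into test collection. A hedged sketch of the resulting behavior; package names are illustrative:

from torch_geometric.testing.decorators import has_package

has_package('torch')            # True: importable (no version specifier given)
has_package('torch>=2.0')       # checks `torch.__version__` against the specifier
has_package('nonexistent_pkg')  # False: find_spec() returns None
# An installed package whose import raises now yields False instead of an error.
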