Commit 5ba8f0f

Merge branch 'master' into mz/griddata

rusty1s authored Jun 17, 2024
2 parents 7f485ff + c7f2cc6
Showing 14 changed files with 100 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -45,7 +45,7 @@ dependencies=[
 [project.optional-dependencies]
 graphgym=[
     "protobuf<4.21",
-    "pytorch-lightning",
+    "pytorch-lightning<2.3.0",
     "yacs",
 ]
 modelhub=[
@@ -77,7 +77,7 @@ full = [
     "h5py",
     "matplotlib",
     "networkx",
-    "numba",
+    "numba<0.60.0",
     "opt_einsum",
     "pandas",
     "pgmpy",
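Both edits pin an upper bound on an optional dependency. As a minimal sketch of how such specifiers are evaluated, here is the packaging library (the same machinery pip builds on); the probed versions are illustrative:

    from packaging.requirements import Requirement
    from packaging.version import Version

    req = Requirement('pytorch-lightning<2.3.0')
    print(Version('2.2.5') in req.specifier)  # True: still satisfies the new pin
    print(Version('2.3.0') in req.specifier)  # False: excluded by the upper bound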
4 changes: 4 additions & 0 deletions test/loader/test_zip_loader.py
@@ -21,6 +21,10 @@ def test_zip_loader(filter_per_worker):
     loader = ZipLoader(loaders, batch_size=10,
                        filter_per_worker=filter_per_worker)
 
+    batches = loader(torch.arange(5))
+    assert isinstance(batches, tuple)
+    assert len(batches) == 2
+
     assert str(loader) == ('ZipLoader(loaders=[NeighborLoader(), '
                            'NeighborLoader()])')
     assert len(loader) == 5
8 changes: 6 additions & 2 deletions test/nn/nlp/test_sentence_transformer.py
@@ -8,8 +8,12 @@
 @onlyFullTest
 @withPackage('transformers')
 @pytest.mark.parametrize('batch_size', [None, 1])
-def test_sentence_transformer(batch_size, device):
-    model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)
+@pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls'])
+def test_sentence_transformer(batch_size, pooling_strategy, device):
+    model = SentenceTransformer(
+        model_name='prajjwal1/bert-tiny',
+        pooling_strategy=pooling_strategy,
+    ).to(device)
     assert model.device == device
     assert str(model) == 'SentenceTransformer(model_name=prajjwal1/bert-tiny)'
 
12 changes: 10 additions & 2 deletions test/nn/test_model_hub.py
@@ -68,11 +68,19 @@ def test_save_pretrained_with_push_to_hub(model, tmp_path):
     # Push to hub with repo_id
     model.save_pretrained(save_directory, push_to_hub=True, repo_id='CustomID',
                           config=CONFIG)
-    model.push_to_hub.assert_called_with(repo_id='CustomID', config=CONFIG)
+    model.push_to_hub.assert_called_with(
+        repo_id='CustomID',
+        model_card_kwargs={},
+        config=CONFIG,
+    )
 
     # Push to hub with default repo_id (based on dir name)
     model.save_pretrained(save_directory, push_to_hub=True, config=CONFIG)
-    model.push_to_hub.assert_called_with(repo_id=REPO_NAME, config=CONFIG)
+    model.push_to_hub.assert_called_with(
+        repo_id=REPO_NAME,
+        model_card_kwargs={},
+        config=CONFIG,
+    )
 
 
 @withPackage('huggingface_hub')
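The expected calls gain a model_card_kwargs={} argument, presumably because newer huggingface_hub versions forward it through to push_to_hub. Note that assert_called_with matches the full argument list exactly, which is why the empty dict must appear in the expectation; a minimal standalone illustration with a plain mock:

    from unittest.mock import MagicMock

    push_to_hub = MagicMock()
    push_to_hub(repo_id='CustomID', model_card_kwargs={}, config={'num_layers': 3})

    # Passes only because every keyword, including the empty dict, matches:
    push_to_hub.assert_called_with(
        repo_id='CustomID',
        model_card_kwargs={},
        config={'num_layers': 3},
    )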
2 changes: 1 addition & 1 deletion torch_geometric/data/lightning/datamodule.py
@@ -13,7 +13,7 @@
 try:
     from pytorch_lightning import LightningDataModule as PLLightningDataModule
     no_pytorch_lightning = False
-except (ImportError, ModuleNotFoundError):
+except ImportError:
     PLLightningDataModule = object  # type: ignore
     no_pytorch_lightning = True
 
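The narrower clause loses nothing: ModuleNotFoundError has been a subclass of ImportError since Python 3.6, so the two-element tuple was redundant. A one-line check:

    assert issubclass(ModuleNotFoundError, ImportError)  # holds on Python >= 3.6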
10 changes: 10 additions & 0 deletions torch_geometric/loader/zip_loader.py
@@ -59,6 +59,16 @@ def __init__(
         self.loaders = loaders
         self.filter_per_worker = filter_per_worker
 
+    def __call__(
+        self,
+        index: Union[Tensor, List[int]],
+    ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:
+        r"""Samples subgraphs from a batch of input IDs."""
+        out = self.collate_fn(index)
+        if not self.filter_per_worker:
+            out = self.filter_fn(out)
+        return out
+
     def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:
         if not isinstance(index, Tensor):
             index = torch.tensor(index, dtype=torch.long)
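A minimal, self-contained usage sketch of the new __call__ entry point; the toy graph and loader settings are assumptions rather than part of the diff, and it presumes a sampling backend usable by NeighborLoader is installed:

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import NeighborLoader, ZipLoader

    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
    data = Data(x=torch.randn(5, 8), edge_index=edge_index)

    loaders = [
        NeighborLoader(data, num_neighbors=[2], batch_size=2),
        NeighborLoader(data, num_neighbors=[2], batch_size=2),
    ]
    loader = ZipLoader(loaders, batch_size=2)

    # __call__ samples one subgraph batch per wrapped loader for the given seeds:
    batches = loader(torch.arange(4))
    assert isinstance(batches, tuple) and len(batches) == 2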
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/edge_conv.py
@@ -3,13 +3,14 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.inits import reset
 from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
 
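This file, together with gravnet_conv.py, spline_conv.py, x_conv.py, graclus.py, and voxel_grid.py below, moves from a local try/except around the optional import to a central capability flag in torch_geometric.typing. One plausible motivation: the flag is computed once, and can absorb non-ImportError failures as well (such as the GLIBC OSError previously special-cased in spline_conv.py). A sketch of the resulting pattern; the helper function is hypothetical, for illustration only:

    import torch_geometric.typing

    if torch_geometric.typing.WITH_TORCH_CLUSTER:
        from torch_cluster import knn
    else:
        knn = None


    def assert_knn_available() -> None:  # hypothetical helper, not from the diff
        # Call sites can fail with a clear message instead of calling None:
        if knn is None:
            raise ImportError("'knn' requires the 'torch-cluster' package")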
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/gravnet_conv.py
@@ -4,14 +4,15 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import OptPairTensor  # noqa
 from torch_geometric.typing import OptTensor, PairOptTensor, PairTensor
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn
-except ImportError:
+else:
     knn = None
 
 
8 changes: 4 additions & 4 deletions torch_geometric/nn/conv/spline_conv.py
@@ -5,17 +5,17 @@
 from torch import Tensor, nn
 from torch.nn import Parameter
 
+import torch_geometric.typing
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.nn.inits import uniform, zeros
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
 from torch_geometric.utils.repeat import repeat
 
-try:
+if torch_geometric.typing.WITH_TORCH_SPLINE_CONV:
     from torch_spline_conv import spline_basis, spline_weighting
-except (ImportError, OSError):  # Fail gracefully on GLIBC errors
-    spline_basis = None
-    spline_weighting = None
+else:
+    spline_basis = spline_weighting = None
 
 
 class SplineConv(MessagePassing):
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/x_conv.py
@@ -9,12 +9,13 @@
 from torch.nn import Linear as L
 from torch.nn import Sequential as S
 
+import torch_geometric.typing
 from torch_geometric.nn import Reshape
 from torch_geometric.nn.inits import reset
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import knn_graph
-except ImportError:
+else:
     knn_graph = None
 
 
45 changes: 38 additions & 7 deletions torch_geometric/nn/nlp/sentence_transformer.py
@@ -1,30 +1,45 @@
-from typing import List, Optional
+from enum import Enum
+from typing import List, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor
 
 
+class PoolingStrategy(Enum):
+    MEAN = 'mean'
+    LAST = 'last'
+    CLS = 'cls'
+
+
 class SentenceTransformer(torch.nn.Module):
-    def __init__(self, model_name: str) -> None:
+    def __init__(
+        self,
+        model_name: str,
+        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
+    ) -> None:
         super().__init__()
 
         self.model_name = model_name
+        self.pooling_strategy = PoolingStrategy(pooling_strategy)
 
         from transformers import AutoModel, AutoTokenizer
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.model = AutoModel.from_pretrained(model_name)
 
-    def mean_pooling(self, emb: Tensor, attention_mask: Tensor) -> Tensor:
-        mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
-        return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
-
     def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
         out = self.model(input_ids=input_ids, attention_mask=attention_mask)
 
         emb = out[0]  # First element contains all token embeddings.
-        emb = self.mean_pooling(emb, attention_mask)
+        if self.pooling_strategy == PoolingStrategy.MEAN:
+            emb = mean_pooling(emb, attention_mask)
+        elif self.pooling_strategy == PoolingStrategy.LAST:
+            emb = last_pooling(emb, attention_mask)
+        else:
+            assert self.pooling_strategy == PoolingStrategy.CLS
+            emb = emb[:, 0, :]
 
         emb = F.normalize(emb, p=2, dim=1)
         return emb
@@ -61,3 +76,19 @@ def encode(
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(model_name={self.model_name})'
+
+
+def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
+    mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
+    return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+
+
+def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
+    # Check whether the language model uses left padding, which is always
+    # the case for decoder-only LLMs:
+    left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)
+    if left_padding:
+        return emb[:, -1]
+
+    seq_indices = attention_mask.sum(dim=1) - 1
+    return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]
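A hedged usage sketch of the new option, with the tiny model name taken from the test above (requires the transformers package): 'mean' averages the non-padded token embeddings, 'cls' takes the first token, and 'last' takes the final non-padded token, falling back to position -1 under left padding.

    from torch_geometric.nn.nlp import SentenceTransformer

    model = SentenceTransformer(
        model_name='prajjwal1/bert-tiny',
        pooling_strategy='last',  # or 'mean' (default) / 'cls'
    )
    emb = model.encode(['hello world', 'graph neural networks'])
    print(emb.shape)  # [2, hidden_dim]; rows are L2-normalized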
6 changes: 4 additions & 2 deletions torch_geometric/nn/pool/graclus.py
@@ -2,9 +2,11 @@
 
 from torch import Tensor
 
-try:
+import torch_geometric.typing
+
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import graclus_cluster
-except ImportError:
+else:
     graclus_cluster = None
 
 
5 changes: 3 additions & 2 deletions torch_geometric/nn/pool/voxel_grid.py
@@ -3,11 +3,12 @@
 import torch
 from torch import Tensor
 
+import torch_geometric.typing
 from torch_geometric.utils.repeat import repeat
 
-try:
+if torch_geometric.typing.WITH_TORCH_CLUSTER:
     from torch_cluster import grid_cluster
-except ImportError:
+else:
     grid_cluster = None
 
 
14 changes: 9 additions & 5 deletions torch_geometric/testing/decorators.py
@@ -177,12 +177,16 @@ def has_package(package: str) -> bool:
     req = Requirement(package)
     if find_spec(req.name) is None:
         return False
-    module = import_module(req.name)
-    if not hasattr(module, '__version__'):
-        return True
+    try:
+        module = import_module(req.name)
+        if not hasattr(module, '__version__'):
+            return True
 
-    version = Version(module.__version__).base_version
-    return version in req.specifier
+        version = Version(module.__version__).base_version
+        return version in req.specifier
+    except Exception:
+        return False
 
 
 def withPackage(*args: str) -> Callable:
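With the new guard, a package that find_spec locates but that explodes on import (or reports an unparsable __version__) makes has_package return False instead of crashing test collection. A sketch of the observable behavior, assuming has_package and withPackage are importable from torch_geometric.testing; the package names are illustrative:

    from torch_geometric.testing import has_package, withPackage

    print(has_package('torch'))            # True wherever PyTorch is installed
    print(has_package('torch>=1.0'))       # version specifiers are honored
    print(has_package('no_such_package'))  # False: find_spec() returns None

    @withPackage('no_such_package')        # test is skipped rather than errored
    def test_something() -> None:
        pass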
