Commit
Splits #9167 into multiple PRs.

---------

Co-authored-by: puririshi98 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 9745df0 · commit 9dc2302
Showing 6 changed files with 102 additions and 0 deletions.
@@ -0,0 +1,5 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| ------- | ----------- |
|         |             |
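The example table is still empty in this commit. As a rough sketch of the kind of example this directory is meant to hold (not part of the commit; the model name, toy graph, and choice of GNN layer are placeholder assumptions), text node features could be encoded with the new SentenceTransformer wrapper and passed into a PyG convolution:

import torch

from torch_geometric.nn import GCNConv
from torch_geometric.nn.nlp import SentenceTransformer

# Toy inputs (assumed): one text per node and a small edge list.
texts = ["paper about GNNs", "paper about LLMs", "survey on graph learning"]
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])

encoder = SentenceTransformer(model_name='prajjwal1/bert-tiny')
x = encoder.encode(texts)        # [3, 128] node features from the LLM
conv = GCNConv(x.size(-1), 16)   # GNN layer consuming the text embeddings
out = conv(x, edge_index)        # [3, 16] node representations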
@@ -0,0 +1,26 @@
import pytest

from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.testing import onlyFullTest, withCUDA


@withCUDA
@onlyFullTest
@pytest.mark.parametrize('batch_size', [None, 1])
def test_sentence_transformer(batch_size, device):
    model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)
    assert model.device == device
    assert str(model) == 'SentenceTransformer(model_name=prajjwal1/bert-tiny)'

    text = [
        "this is a basic english text",
        "PyG is the best open-source GNN library :)",
    ]

    out = model.encode(text, batch_size=batch_size)
    assert out.is_cpu
    assert out.size() == (2, 128)

    out = model.encode(text, batch_size=batch_size, output_device=device)
    assert out.device == device
    assert out.size() == (2, 128)
@@ -0,0 +1,5 @@
from .sentence_transformer import SentenceTransformer

__all__ = classes = [
    'SentenceTransformer',
]
@@ -0,0 +1,63 @@
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


class SentenceTransformer(torch.nn.Module):
    def __init__(self, model_name: str) -> None:
        super().__init__()

        self.model_name = model_name

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, emb: Tensor, attention_mask: Tensor) -> Tensor:
        mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
        return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        emb = self.mean_pooling(emb, attention_mask)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[torch.device] = None,
    ) -> Tensor:
        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )

            emb = self(
                input_ids=token.input_ids.to(self.device),
                attention_mask=token.attention_mask.to(self.device),
            ).to(output_device or 'cpu')

            embs.append(emb)

        return torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
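For reference, a minimal usage sketch of the wrapper added above (not part of this commit; it reuses the bert-tiny checkpoint from the test and assumes CUDA may or may not be available):

import torch

from torch_geometric.nn.nlp import SentenceTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)

emb = model.encode(
    ["this is a basic english text", "PyG is the best open-source GNN library :)"],
    batch_size=1,            # tokenize and embed one sentence per forward pass
    output_device=device,    # keep the embeddings on the model's device
)
print(emb.size())  # torch.Size([2, 128]) for bert-tiny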