GNN+LLM example (part 1) (#9350)
Splits #9167 into multiple PRs.

---------

Co-authored-by: puririshi98 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
5 people authored May 22, 2024
1 parent 9745df0 commit 9dc2302
Showing 6 changed files with 102 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Started work on GNN+LLM package ([#9350](https://github.com/pyg-team/pytorch_geometric/pull/9350))
- Added support for negative sampling in `LinkLoader` according to source and destination node weights ([#9316](https://github.com/pyg-team/pytorch_geometric/pull/9316))
- Added support for `EdgeIndex.unbind` ([#9298](https://github.com/pyg-team/pytorch_geometric/pull/9298))
- Integrate `torch_geometric.Index` into `torch_geometric.EdgeIndex` ([#9296](https://github.com/pyg-team/pytorch_geometric/pull/9296))
2 changes: 2 additions & 0 deletions examples/README.md
@@ -19,3 +19,5 @@ For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).
For examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu).

For examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero).

For examples on co-training LLMs with GNNs, see the examples under [`examples/llm`](./llm).
5 changes: 5 additions & 0 deletions examples/llm/README.md
@@ -0,0 +1,5 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| ------- | ----------- |
| | |
26 changes: 26 additions & 0 deletions test/nn/nlp/test_sentence_transformer.py
@@ -0,0 +1,26 @@
import pytest

from torch_geometric.nn.nlp import SentenceTransformer
from torch_geometric.testing import onlyFullTest, withCUDA


@withCUDA
@onlyFullTest
@pytest.mark.parametrize('batch_size', [None, 1])
def test_sentence_transformer(batch_size, device):
    model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)
    assert model.device == device
    assert str(model) == 'SentenceTransformer(model_name=prajjwal1/bert-tiny)'

    text = [
        "this is a basic english text",
        "PyG is the best open-source GNN library :)",
    ]

    out = model.encode(text, batch_size=batch_size)
    assert out.is_cpu
    assert out.size() == (2, 128)

    out = model.encode(text, batch_size=batch_size, output_device=device)
    assert out.device == device
    assert out.size() == (2, 128)
5 changes: 5 additions & 0 deletions torch_geometric/nn/nlp/__init__.py
@@ -0,0 +1,5 @@
from .sentence_transformer import SentenceTransformer

__all__ = classes = [
    'SentenceTransformer',
]
63 changes: 63 additions & 0 deletions torch_geometric/nn/nlp/sentence_transformer.py
@@ -0,0 +1,63 @@
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


class SentenceTransformer(torch.nn.Module):
    def __init__(self, model_name: str) -> None:
        super().__init__()

        self.model_name = model_name

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def mean_pooling(self, emb: Tensor, attention_mask: Tensor) -> Tensor:
        mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
        return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        emb = self.mean_pooling(emb, attention_mask)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[torch.device] = None,
    ) -> Tensor:
        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )

            emb = self(
                input_ids=token.input_ids.to(self.device),
                attention_mask=token.attention_mask.to(self.device),
            ).to(output_device or 'cpu')

            embs.append(emb)

        return torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'
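For reference, here is a minimal usage sketch of the new `SentenceTransformer` wrapper (not part of this commit). It assumes the `transformers` package is installed and reuses the `prajjwal1/bert-tiny` checkpoint from the test above:

```python
import torch

from torch_geometric.nn.nlp import SentenceTransformer

# The wrapper is a regular torch.nn.Module, so .to() moves it as usual.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SentenceTransformer(model_name='prajjwal1/bert-tiny').to(device)

text = [
    "this is a basic english text",
    "PyG is the best open-source GNN library :)",
]

# `encode` tokenizes in mini-batches, mean-pools the token embeddings while
# ignoring padded positions, and returns L2-normalized sentence embeddings.
emb = model.encode(text, batch_size=2, output_device=device)
print(emb.size())  # torch.Size([2, 128]) for bert-tiny (hidden size 128)
```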
