Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM-Enhanced Text-Attributed Graph (TAG) Representation Learning #9428

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Added documentation on Environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
- Added an example for large language model (LLM) enhanced text-attributed graph (TAG) representation learning ([#9361](https://github.com/pyg-team/pytorch_geometric/pull/9361))

### Changed

Expand Down
6 changes: 3 additions & 3 deletions examples/llm/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Examples for Co-training LLMs and GNNs

| Example | Description |
| ------- | ----------- |
| | |
| Example | Description |
| --------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| [tape](./tape/) | [Harnessing Explanations: LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation Learning](https://arxiv.org/abs/2305.19523) |
88 changes: 88 additions & 0 deletions examples/llm/tape/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Harnessing Explanations: LLM-to-LM Interpreter for Enhanced Text Attributed Graph Representation Learning

This repository implements the methodology introduced in the [paper](https://arxiv.org/abs/2305.19523) that leverages large language models (LLMs) to enhance text-attributed graph (TAG) representation learning, boosting graph neural network (GNN) performance on downstream tasks.

## Framework Overview

1. **Node Feature Extraction**

- Prepare prompts containing the article information (title and abstract) for each node.
- Query an LLM with these prompts to generate a ranked label prediction list and explanation.

1. **Node Feature Encoder**

- Fine-tune a language model (LM) on a sequence classification task with the article title and abstract as input.

1. **GNN Trainer**

- Train a GNN model using the following features, with node features updated by the fine-tuned LM encoder:
1. Title & Abstract (TA)
1. Prediction (P) - Using a PyTorch `nn.Embedding` layer for top-k ranked features.
1. Explanation (E)

1. **Model Ensemble**

- Fuse predictions from the trained GNN models on TA, P, and E by averaging them.

> \[!Note\]
> Fine-tuning an LM is optional and not currently supported. Instead, you can use any open-weight fine-tuned embedding model, significantly reducing time and cost while achieving comparable results.

## Usage

### Setup the environment

```bash
# Replace the 'cu118' CUDA version according to your system
pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118
pip install torch_geometric
pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu118.html

# For online LLM inference
$ poetry install
# For offline LLM inference
$ poetry install --extras "llm_offline"
```

### Training

```bash
$ python train.py --config=train_config.yaml
# You can also provide CLI arguments to overwrite values in the `train_config.yaml` file
$ python train.py --help
```

- The [train_config.yaml](./train_config.yaml) utilizes the online LLM engine with the model [huggingface/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct).
- Predictions generated by this model for the PubMed dataset have been uploaded to [Hugging Face](https://huggingface.co/datasets/devanshamin/PubMedDiabetes-LLM-Predictions), which will be downloaded and used instead of calling the LLM during training.
- This optimization significantly accelerates the training process and demonstrates end-to-end training with tape.
- Instead of fine-tuning an LM on the PubMed dataset, the [train_config.yaml](./train_config.yaml) uses a general-purpose embedding model [avsolatorio/GIST-Embedding-v0](https://huggingface.co/avsolatorio/GIST-Embedding-v0).
- With LLM predictions, you can expect the following run time and accuracy when training the GNN for the PubMed dataset using the feature type `TAPE`:

```markdown
When the LM embeddings cache for the dataset is empty,
Feature_type Test_accuracy
TITLE_ABSTRACT (TA) 0.908722
PREDICTION (P) 0.889959
EXPLANATION (E) 0.914807
TAPE (TAPE) 0.946501
Run time: 11 minutes and 14.59 seconds

When the LM embeddings cache for the dataset is present,
Feature_type Test_accuracy
TITLE_ABSTRACT (TA) 0.915061
PREDICTION (P) 0.889452
EXPLANATION (E) 0.923174
TAPE (TAPE) 0.952333
Run time: 1 minute and 0.31 seconds
```

In summary,

| | Current Implementation | Author Implementation |
| -------------- | -------------------------------------- | --------------------------------------- |
| Dataset | PubMed | PubMed |
| LLM | `meta-llama/Meta-Llama-3-8B-Instruct` | `openai/gpt-3.5-turbo-0301` |
| LM fine-tuning | ✖ | ✔ |
| GNN layer | `SAGEConv` | `SAGEConv` |
| GNN hparams | `layers=4, hidden_dim=64, dropout=0.1` | `layers=3, hidden_dim=256, dropout=0.5` |
| Seed runs | 4 | 4 |
| Accuracy | `0.9573 ± 0.0032` | `0.9618 ± 0.0053` |
32 changes: 32 additions & 0 deletions examples/llm/tape/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[tool.poetry]
name = "tape"
version = "0.1.0"
description = "LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation Learning"
authors = ["Devansh Amin"]
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
pandas = "*"
requests = "*"
tqdm = "*"
python-dotenv = "^1.0.1"
gdown = "^5.2.0"
numpy = "*"
jinja2 = "*"
tenacity = "*"
ogb = "^1.3.6"
jsonargparse = {extras = ["omegaconf"], version = "^4.29.0"} # Combining dataclasses + YAML + CLI
instructor = {extras = ["litellm"], version = "^1.3.2"} # Getting structured outputs from LLM
diskcache = "^5.6.3" # Caching LLM responses
transformers = "^4.41.2" # LM inference
sentence-transformers = "^3.0.1" # LM inference
datasets = "^2.19.2" # Loading LLM predictions from Hugging Face
vllm = { version = "0.5.0.post1", optional = true }

[tool.poetry.extras]
llm_offline = ["vllm"]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Empty file.
13 changes: 13 additions & 0 deletions examples/llm/tape/tape/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from enum import Enum


class DatasetName(str, Enum):
PUBMED = 'pubmed'
OGBN_ARXIV = 'ogbn_arxiv'


class FeatureType(str, Enum):
TITLE_ABSTRACT = 'TA'
PREDICTION = 'P'
EXPLANATION = 'E'
TAPE = 'TAPE'
Empty file.
153 changes: 153 additions & 0 deletions examples/llm/tape/tape/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from typing import Optional, Union

import torch
from tape.config import DatasetName, FeatureType
from tape.dataset import parser
from tape.dataset.llm.engine import LlmOfflineEngineArgs, LlmOnlineEngineArgs
from tape.dataset.lm_encoder import LmEncoder, LmEncoderArgs

from torch_geometric.data import Data


class GraphDataset:
def __init__(
self, dataset_name: DatasetName, feature_type: FeatureType,
lm_encoder_args: LmEncoderArgs,
llm_online_engine_args: Optional[LlmOnlineEngineArgs] = None,
llm_offline_engine_args: Optional[LlmOfflineEngineArgs] = None,
device: Optional[Union[str, torch.device]] = None,
seed: Optional[int] = 42, cache_dir: str = '.cache') -> None:

self.seed = seed
self.dataset_name = dataset_name
self.feature_type = feature_type
self.llm_online_engine_args = llm_online_engine_args
self.llm_offline_engine_args = llm_offline_engine_args
self.cache_dir = cache_dir

assert llm_online_engine_args or llm_offline_engine_args, (
'LLM online/offline engine arguments cannot be empty!'
'Please provide either one of them.')

lm_encoder_args.device = device
self.lm_encoder = LmEncoder(args=lm_encoder_args)

self._parser = None
self._dataset = None
self._topk = None

@property
def dataset(self) -> Data:
if self._dataset is None:
self.load_dataset()
self.update_node_features()
return self._dataset

@property
def num_classes(self) -> int:
return self._parser.graph.n_classes

@property
def topk(self) -> int:
"""TopK ranked LLM predictions."""
if self._topk is None:
_ = self.dataset
self._topk = min(self._parser.graph.n_classes, 5)
return self._topk

def load_dataset(self) -> None:
if self.dataset_name == DatasetName.PUBMED:
cls = parser.PubmedParser
elif self.dataset_name == DatasetName.OGBN_ARXIV:
cls = parser.OgbnArxivParser
else:
raise ValueError(f'Invalid dataset name "{self.dataset_name}"!')

self._parser = cls(seed=self.seed, cache_dir=self.cache_dir)
self._parser.parse()
self._dataset = self._parser.graph.dataset

def update_node_features(self) -> None:
"""Update original node features with Language Model (LM) features."""
ftype = self.feature_type
print('Generating node features for feature type '
f"'{ftype.name} ({ftype.value})'...")
graph = self._parser.graph
articles = graph.articles

if ftype == FeatureType.TITLE_ABSTRACT:
sentences = [
f'Title: {article.title}\nAbstract: {article.abstract}'
for article in articles
]
features = self.lm_encoder(sentences)
features = torch.stack(features)
self.lm_encoder.save_cache()
else:
responses = self._get_llm_responses()

if ftype == FeatureType.EXPLANATION:
features = self.lm_encoder(
sentences=[resp.reason for resp in responses])
features = torch.stack(features)
self.lm_encoder.save_cache()
else:
# FeatureType.PREDICTION
label2id = {
v['label'] if isinstance(v, dict) else v: k
for k, v in graph.class_id_to_label.items()
}
features = torch.zeros((self._dataset.num_nodes, self.topk))
for i, resp in enumerate(responses):
# Convert the predicted labels (which are strings) to their
# corresponding integer IDs.
preds = [label2id[label] for label in resp.label]

# Assign the converted predictions to the corresponding row
# in the features tensor.
# `preds` can have fewer elements than `topk`, so we only
# fill as many elements as we have in `preds`.
# We add 1 to each ID because the nn.Embedding layer
# typically expects non-zero indices to learn embeddings.
# Zero can be used to represent padding or a non-existent
# class.
features[i][:len(preds)] = torch.tensor(
preds, dtype=torch.long) + 1

# Explanation of why we add 1 to the labels:
# The OGBN-Arxiv dataset contains LLM predictions where
# the labels are fixed topk values.
# In contrast, the PubMed dataset contains LLM predictions
# where the labels can be either a single value or
# multiple values.
# During GNN training, the features tensor is passed to an
# `nn.Embedding` layer.
# If we had topk=3 and preds = [0], initializing the
# features with zeros would make it difficult to
# distinguish between "no prediction" and
# "prediction of class 0". To denote that the class is
# present, we increment the value by 1.

self._dataset.x = features

def _get_llm_responses(self):

graph = self._parser.graph

if self.llm_online_engine_args:
from tape.dataset.llm import online as engine

args = self.llm_online_engine_args
else:
from tape.dataset.llm import offline as engine

args = self.llm_offline_engine_args

if self.dataset_name == DatasetName.PUBMED:
cls = engine.LlmPubmedResponses
elif self.dataset_name == DatasetName.OGBN_ARXIV:
cls = engine.LlmOgbnArxivResponses

llm = cls(args=args, class_id_to_label=graph.class_id_to_label)
responses = llm.get_responses_from_articles(articles=graph.articles)
return responses
Empty file.
52 changes: 52 additions & 0 deletions examples/llm/tape/tape/dataset/llm/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

from dotenv import load_dotenv
from pydantic import BaseModel
from tape.dataset.parser import Article

load_dotenv()


@dataclass
class LlmOnlineEngineArgs:
model: str
max_retries: int = 5
# Arguments for OpenAI's `client.chat.completions.create` method
sampling_kwargs: Optional[Dict] = None
rate_limit_per_minute: Optional[int] = None # Requests per minute (RPM)
cache_dir: str = '.cache'

def __post_init__(self) -> None:
if self.cache_dir:
self.cache_dir = str(Path.cwd() / self.cache_dir)


@dataclass
class LlmOfflineEngineArgs(LlmOnlineEngineArgs):
batch_size: int = 100
# sampling_kwargs ➜ Arguments for `vllm.EngineArgs`
engine_kwargs: Optional[Dict] = None # Arguments for `vllm.EngineArgs`


class LlmResponseModel(BaseModel, ABC):
label: List[str]
reason: str


class LlmEngine(ABC):
def __init__(
self, args: Union[LlmOnlineEngineArgs,
LlmOfflineEngineArgs]) -> None:
self.args = args

@abstractmethod
def __call__(self) -> Optional[LlmResponseModel]:
pass

@abstractmethod
def get_responses_from_articles(
self, articles: List[Article]) -> List[LlmResponseModel]:
pass
4 changes: 4 additions & 0 deletions examples/llm/tape/tape/dataset/llm/offline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .ogbn_arxiv import LlmOgbnArxivResponses
from .pubmed import LlmPubmedResponses

__all__ = ['LlmOgbnArxivResponses', 'LlmPubmedResponses']
Loading