diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b210f63b67b..40b9fd912178 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983)) - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952)) - Added documentation on Environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407)) +- Added an example for large language model (LLM) enhanced text-attributed graph (TAG) representation learning ([#9361](https://github.com/pyg-team/pytorch_geometric/pull/9361)) ### Changed diff --git a/examples/llm/README.md b/examples/llm/README.md index 9899d2ba4869..13c52ae8e260 100644 --- a/examples/llm/README.md +++ b/examples/llm/README.md @@ -1,5 +1,5 @@ # Examples for Co-training LLMs and GNNs -| Example | Description | -| ------- | ----------- | -| | | +| Example | Description | +| --------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | +| [tape](./tape/) | [Harnessing Explanations: LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation Learning](https://arxiv.org/abs/2305.19523) | diff --git a/examples/llm/tape/README.md b/examples/llm/tape/README.md new file mode 100644 index 000000000000..32fd245c9fc2 --- /dev/null +++ b/examples/llm/tape/README.md @@ -0,0 +1,88 @@ +# Harnessing Explanations: LLM-to-LM Interpreter for Enhanced Text Attributed Graph Representation Learning + +This repository implements the methodology introduced in the [paper](https://arxiv.org/abs/2305.19523) that leverages large language models (LLMs) to enhance text-attributed graph (TAG) representation learning, boosting graph neural network (GNN) performance on downstream tasks. + +## Framework Overview + +1. **Node Feature Extraction** + + - Prepare prompts containing the article information (title and abstract) for each node. + - Query an LLM with these prompts to generate a ranked label prediction list and explanation. + +1. **Node Feature Encoder** + + - Fine-tune a language model (LM) on a sequence classification task with the article title and abstract as input. + +1. **GNN Trainer** + + - Train a GNN model using the following features, with node features updated by the fine-tuned LM encoder: + 1. Title & Abstract (TA) + 1. Prediction (P) - Using a PyTorch `nn.Embedding` layer for top-k ranked features. + 1. Explanation (E) + +1. **Model Ensemble** + + - Fuse predictions from the trained GNN models on TA, P, and E by averaging them. + +> \[!Note\] +> Fine-tuning an LM is optional and not currently supported. Instead, you can use any open-weight fine-tuned embedding model, significantly reducing time and cost while achieving comparable results. 
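+
+The **Model Ensemble** step simply averages the class logits produced by the TA, P, and E models. A minimal sketch of the fusion (with random tensors standing in for the per-model test logits):
+
+```python
+import torch
+
+# Hypothetical logits of shape [num_nodes, num_classes] from the TA, P and E models.
+logits_ta, logits_p, logits_e = (torch.randn(4, 3) for _ in range(3))
+
+# Fuse by averaging the logits and predict the class with the highest fused score.
+fused = torch.stack([logits_ta, logits_p, logits_e]).mean(dim=0)
+y_pred = fused.argmax(dim=1)
+```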
+ +## Usage + +### Setup the environment + +```bash +# Replace the 'cu118' CUDA version according to your system +pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu118 +pip install torch_geometric +pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu118.html + +# For online LLM inference +$ poetry install +# For offline LLM inference +$ poetry install --extras "llm_offline" +``` + +### Training + +```bash +$ python train.py --config=train_config.yaml +# You can also provide CLI arguments to overwrite values in the `train_config.yaml` file +$ python train.py --help +``` + +- The [train_config.yaml](./train_config.yaml) utilizes the online LLM engine with the model [huggingface/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct). +- Predictions generated by this model for the PubMed dataset have been uploaded to [Hugging Face](https://huggingface.co/datasets/devanshamin/PubMedDiabetes-LLM-Predictions), which will be downloaded and used instead of calling the LLM during training. +- This optimization significantly accelerates the training process and demonstrates end-to-end training with tape. +- Instead of fine-tuning an LM on the PubMed dataset, the [train_config.yaml](./train_config.yaml) uses a general-purpose embedding model [avsolatorio/GIST-Embedding-v0](https://huggingface.co/avsolatorio/GIST-Embedding-v0). +- With LLM predictions, you can expect the following run time and accuracy when training the GNN for the PubMed dataset using the feature type `TAPE`: + +```markdown +When the LM embeddings cache for the dataset is empty, +Feature_type Test_accuracy +TITLE_ABSTRACT (TA) 0.908722 +PREDICTION (P) 0.889959 +EXPLANATION (E) 0.914807 +TAPE (TAPE) 0.946501 +Run time: 11 minutes and 14.59 seconds + +When the LM embeddings cache for the dataset is present, +Feature_type Test_accuracy +TITLE_ABSTRACT (TA) 0.915061 +PREDICTION (P) 0.889452 +EXPLANATION (E) 0.923174 +TAPE (TAPE) 0.952333 +Run time: 1 minute and 0.31 seconds +``` + +In summary, + +| | Current Implementation | Author Implementation | +| -------------- | -------------------------------------- | --------------------------------------- | +| Dataset | PubMed | PubMed | +| LLM | `meta-llama/Meta-Llama-3-8B-Instruct` | `openai/gpt-3.5-turbo-0301` | +| LM fine-tuning | ✖ | ✔ | +| GNN layer | `SAGEConv` | `SAGEConv` | +| GNN hparams | `layers=4, hidden_dim=64, dropout=0.1` | `layers=3, hidden_dim=256, dropout=0.5` | +| Seed runs | 4 | 4 | +| Accuracy | `0.9573 ± 0.0032` | `0.9618 ± 0.0053` | diff --git a/examples/llm/tape/pyproject.toml b/examples/llm/tape/pyproject.toml new file mode 100644 index 000000000000..d5eb6f92fef4 --- /dev/null +++ b/examples/llm/tape/pyproject.toml @@ -0,0 +1,32 @@ +[tool.poetry] +name = "tape" +version = "0.1.0" +description = "LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation Learning" +authors = ["Devansh Amin"] +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.9,<3.12" +pandas = "*" +requests = "*" +tqdm = "*" +python-dotenv = "^1.0.1" +gdown = "^5.2.0" +numpy = "*" +jinja2 = "*" +tenacity = "*" +ogb = "^1.3.6" +jsonargparse = {extras = ["omegaconf"], version = "^4.29.0"} # Combining dataclasses + YAML + CLI +instructor = {extras = ["litellm"], version = "^1.3.2"} # Getting structured outputs from LLM +diskcache = "^5.6.3" # Caching LLM responses +transformers = "^4.41.2" # LM inference +sentence-transformers = "^3.0.1" # LM inference 
+datasets = "^2.19.2" # Loading LLM predictions from Hugging Face +vllm = { version = "0.5.0.post1", optional = true } + +[tool.poetry.extras] +llm_offline = ["vllm"] + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/examples/llm/tape/tape/__init__.py b/examples/llm/tape/tape/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/llm/tape/tape/config.py b/examples/llm/tape/tape/config.py new file mode 100644 index 000000000000..3c64af40c54a --- /dev/null +++ b/examples/llm/tape/tape/config.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class DatasetName(str, Enum): + PUBMED = 'pubmed' + OGBN_ARXIV = 'ogbn_arxiv' + + +class FeatureType(str, Enum): + TITLE_ABSTRACT = 'TA' + PREDICTION = 'P' + EXPLANATION = 'E' + TAPE = 'TAPE' diff --git a/examples/llm/tape/tape/dataset/__init__.py b/examples/llm/tape/tape/dataset/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/llm/tape/tape/dataset/dataset.py b/examples/llm/tape/tape/dataset/dataset.py new file mode 100644 index 000000000000..5c8565866400 --- /dev/null +++ b/examples/llm/tape/tape/dataset/dataset.py @@ -0,0 +1,153 @@ +from typing import Optional, Union + +import torch +from tape.config import DatasetName, FeatureType +from tape.dataset import parser +from tape.dataset.llm.engine import LlmOfflineEngineArgs, LlmOnlineEngineArgs +from tape.dataset.lm_encoder import LmEncoder, LmEncoderArgs + +from torch_geometric.data import Data + + +class GraphDataset: + def __init__( + self, dataset_name: DatasetName, feature_type: FeatureType, + lm_encoder_args: LmEncoderArgs, + llm_online_engine_args: Optional[LlmOnlineEngineArgs] = None, + llm_offline_engine_args: Optional[LlmOfflineEngineArgs] = None, + device: Optional[Union[str, torch.device]] = None, + seed: Optional[int] = 42, cache_dir: str = '.cache') -> None: + + self.seed = seed + self.dataset_name = dataset_name + self.feature_type = feature_type + self.llm_online_engine_args = llm_online_engine_args + self.llm_offline_engine_args = llm_offline_engine_args + self.cache_dir = cache_dir + + assert llm_online_engine_args or llm_offline_engine_args, ( + 'LLM online/offline engine arguments cannot be empty!' 
+ 'Please provide either one of them.') + + lm_encoder_args.device = device + self.lm_encoder = LmEncoder(args=lm_encoder_args) + + self._parser = None + self._dataset = None + self._topk = None + + @property + def dataset(self) -> Data: + if self._dataset is None: + self.load_dataset() + self.update_node_features() + return self._dataset + + @property + def num_classes(self) -> int: + return self._parser.graph.n_classes + + @property + def topk(self) -> int: + """TopK ranked LLM predictions.""" + if self._topk is None: + _ = self.dataset + self._topk = min(self._parser.graph.n_classes, 5) + return self._topk + + def load_dataset(self) -> None: + if self.dataset_name == DatasetName.PUBMED: + cls = parser.PubmedParser + elif self.dataset_name == DatasetName.OGBN_ARXIV: + cls = parser.OgbnArxivParser + else: + raise ValueError(f'Invalid dataset name "{self.dataset_name}"!') + + self._parser = cls(seed=self.seed, cache_dir=self.cache_dir) + self._parser.parse() + self._dataset = self._parser.graph.dataset + + def update_node_features(self) -> None: + """Update original node features with Language Model (LM) features.""" + ftype = self.feature_type + print('Generating node features for feature type ' + f"'{ftype.name} ({ftype.value})'...") + graph = self._parser.graph + articles = graph.articles + + if ftype == FeatureType.TITLE_ABSTRACT: + sentences = [ + f'Title: {article.title}\nAbstract: {article.abstract}' + for article in articles + ] + features = self.lm_encoder(sentences) + features = torch.stack(features) + self.lm_encoder.save_cache() + else: + responses = self._get_llm_responses() + + if ftype == FeatureType.EXPLANATION: + features = self.lm_encoder( + sentences=[resp.reason for resp in responses]) + features = torch.stack(features) + self.lm_encoder.save_cache() + else: + # FeatureType.PREDICTION + label2id = { + v['label'] if isinstance(v, dict) else v: k + for k, v in graph.class_id_to_label.items() + } + features = torch.zeros((self._dataset.num_nodes, self.topk)) + for i, resp in enumerate(responses): + # Convert the predicted labels (which are strings) to their + # corresponding integer IDs. + preds = [label2id[label] for label in resp.label] + + # Assign the converted predictions to the corresponding row + # in the features tensor. + # `preds` can have fewer elements than `topk`, so we only + # fill as many elements as we have in `preds`. + # We add 1 to each ID because the nn.Embedding layer + # typically expects non-zero indices to learn embeddings. + # Zero can be used to represent padding or a non-existent + # class. + features[i][:len(preds)] = torch.tensor( + preds, dtype=torch.long) + 1 + + # Explanation of why we add 1 to the labels: + # The OGBN-Arxiv dataset contains LLM predictions where + # the labels are fixed topk values. + # In contrast, the PubMed dataset contains LLM predictions + # where the labels can be either a single value or + # multiple values. + # During GNN training, the features tensor is passed to an + # `nn.Embedding` layer. + # If we had topk=3 and preds = [0], initializing the + # features with zeros would make it difficult to + # distinguish between "no prediction" and + # "prediction of class 0". To denote that the class is + # present, we increment the value by 1. 
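+                    # For example, with topk=3 and preds=[2], the row
+                    # becomes [3., 0., 0.], where 0 marks "no prediction"
+                    # and 3 encodes class ID 2 shifted by 1.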
+ + self._dataset.x = features + + def _get_llm_responses(self): + + graph = self._parser.graph + + if self.llm_online_engine_args: + from tape.dataset.llm import online as engine + + args = self.llm_online_engine_args + else: + from tape.dataset.llm import offline as engine + + args = self.llm_offline_engine_args + + if self.dataset_name == DatasetName.PUBMED: + cls = engine.LlmPubmedResponses + elif self.dataset_name == DatasetName.OGBN_ARXIV: + cls = engine.LlmOgbnArxivResponses + + llm = cls(args=args, class_id_to_label=graph.class_id_to_label) + responses = llm.get_responses_from_articles(articles=graph.articles) + return responses diff --git a/examples/llm/tape/tape/dataset/llm/__init__.py b/examples/llm/tape/tape/dataset/llm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/llm/tape/tape/dataset/llm/engine.py b/examples/llm/tape/tape/dataset/llm/engine.py new file mode 100644 index 000000000000..e3698ef7cde1 --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/engine.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional, Union + +from dotenv import load_dotenv +from pydantic import BaseModel +from tape.dataset.parser import Article + +load_dotenv() + + +@dataclass +class LlmOnlineEngineArgs: + model: str + max_retries: int = 5 + # Arguments for OpenAI's `client.chat.completions.create` method + sampling_kwargs: Optional[Dict] = None + rate_limit_per_minute: Optional[int] = None # Requests per minute (RPM) + cache_dir: str = '.cache' + + def __post_init__(self) -> None: + if self.cache_dir: + self.cache_dir = str(Path.cwd() / self.cache_dir) + + +@dataclass +class LlmOfflineEngineArgs(LlmOnlineEngineArgs): + batch_size: int = 100 + # sampling_kwargs ➜ Arguments for `vllm.EngineArgs` + engine_kwargs: Optional[Dict] = None # Arguments for `vllm.EngineArgs` + + +class LlmResponseModel(BaseModel, ABC): + label: List[str] + reason: str + + +class LlmEngine(ABC): + def __init__( + self, args: Union[LlmOnlineEngineArgs, + LlmOfflineEngineArgs]) -> None: + self.args = args + + @abstractmethod + def __call__(self) -> Optional[LlmResponseModel]: + pass + + @abstractmethod + def get_responses_from_articles( + self, articles: List[Article]) -> List[LlmResponseModel]: + pass diff --git a/examples/llm/tape/tape/dataset/llm/offline/__init__.py b/examples/llm/tape/tape/dataset/llm/offline/__init__.py new file mode 100644 index 000000000000..7f73a44ab97d --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/offline/__init__.py @@ -0,0 +1,4 @@ +from .ogbn_arxiv import LlmOgbnArxivResponses +from .pubmed import LlmPubmedResponses + +__all__ = ['LlmOgbnArxivResponses', 'LlmPubmedResponses'] diff --git a/examples/llm/tape/tape/dataset/llm/offline/base.py b/examples/llm/tape/tape/dataset/llm/offline/base.py new file mode 100644 index 000000000000..9c51f0c6cd53 --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/offline/base.py @@ -0,0 +1,119 @@ +import json +import os +from abc import abstractmethod +from functools import partial +from pathlib import Path +from typing import Dict, List, Union + +from jinja2 import Environment, FileSystemLoader +from tape.dataset.llm.engine import ( + LlmEngine, + LlmOfflineEngineArgs, + LlmResponseModel, +) +from tape.dataset.parser import Article +from tqdm import tqdm +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + + +class LlmOfflineEngine(LlmEngine): + def __init__(self, args: 
LlmOfflineEngineArgs): + super().__init__(args) + # Update `huggingface_hub` default cache dir + os.environ['HF_HOME'] = args.cache_dir + self.tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=args.model, cache_dir=args.cache_dir) + self.llm = LLM(model=args.model, **(args.engine_kwargs or {})) + self.sampling_params = SamplingParams(**(args.sampling_kwargs or {})) + self._system_prompt = None + + @property + def system_prompt(self) -> str: + if self._system_prompt is None: + self._system_prompt = self.get_system_prompt() + return self._system_prompt + + @abstractmethod + def get_system_prompt(self) -> str: + pass + + def _prepare_conversation( + self, articles: List[Article]) -> List[List[Dict[str, str]]]: + conversation = [] + for article in articles: + prompt = f'Title: {article.title}\nAbstract: {article.abstract}' + messages = [ + { + "role": "system", + "content": self.system_prompt + }, + { + "role": "user", + "content": prompt + }, + ] + conversation.append(messages) + return conversation + + def __call__( + self, + articles: Union[Article, List[Article]], + return_prompt: bool = False, + strict: bool = False, # Whether to use strict json parsing + max_retries: int = 3, + ) -> Union[Dict[str, Union[str, Dict]], List[Dict[str, Union[str, Dict]]]]: + + single_article = isinstance(articles, Article) + if single_article: + articles = [articles] + conversation = self._prepare_conversation(articles) + + apply_chat_template = partial(self.tokenizer.apply_chat_template, + add_generation_prompt=True, + tokenize=not return_prompt) + kwargs = dict(use_tqdm=False, sampling_params=self.sampling_params) + prompt_key = 'prompts' if return_prompt else 'prompt_token_ids' + kwargs[prompt_key] = apply_chat_template(conversation) + + outputs = self.llm.generate(**kwargs) + max_retries = self.args.max_retries or max_retries + for i in range(len(outputs)): + output = outputs[i].outputs[0].text + if strict: + retries = 0 + json_output = None + while retries < max_retries: + try: + json_output = json.loads(output) + break + except Exception: + retries += 1 + print(f'Retry {retries}/{max_retries} ' + 'after exception: {e}') + kwargs[prompt_key] = apply_chat_template( + conversation[i:i + 1]) + output = self.llm.generate(**kwargs)[0].outputs[0].text + output = json_output + outputs[i] = dict(input=outputs[i].prompt, output=output) + + return outputs[0] if single_article else outputs + + def get_responses_from_articles( + self, articles: List[Article]) -> List[LlmResponseModel]: + responses = [None] * len(articles) + batch_size = self.args.batch_size + for start in tqdm(range(0, len(articles), batch_size), + total=len(articles) // batch_size): + results = self(articles[start:start + batch_size], strict=True) + for idx, result in zip(range(start, start + batch_size), results): + if result and result['output']: + responses[idx] = LlmResponseModel(**result['output']) + return responses + + def load_system_prompt_from_template(self, **kwargs) -> str: + file_loader = FileSystemLoader(Path(__file__, '..').resolve()) + env = Environment(loader=file_loader) + template = env.get_template('prompt.jinja') + prompt = template.render(**kwargs) + return prompt diff --git a/examples/llm/tape/tape/dataset/llm/offline/ogbn_arxiv.py b/examples/llm/tape/tape/dataset/llm/offline/ogbn_arxiv.py new file mode 100644 index 000000000000..f824b279197e --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/offline/ogbn_arxiv.py @@ -0,0 +1,27 @@ +from typing import Dict + +from tape.dataset.llm.engine import 
LlmOfflineEngineArgs +from tape.dataset.llm.offline.base import LlmOfflineEngine + + +class LlmOgbnArxivResponses(LlmOfflineEngine): + def __init__(self, args: LlmOfflineEngineArgs, + class_id_to_label: Dict) -> None: + super().__init__(args) + self.class_id_to_label = class_id_to_label + + def get_system_prompt(self) -> str: + topk = 5 + categories = [] + for v in self.class_id_to_label.values(): + category = v['category'].replace('-', ' ').replace(',', '') + categories.append(f"{v['label']} // {category}") + kwargs = dict( + role="You're an experienced computer scientist.", + categories=categories, + label_description=( + f'Contains {topk} arXiv CS sub-categories ordered ' + 'from most to least likely.', ), + ) + prompt = self.load_system_prompt_from_template(**kwargs) + return prompt diff --git a/examples/llm/tape/tape/dataset/llm/offline/prompt.jinja b/examples/llm/tape/tape/dataset/llm/offline/prompt.jinja new file mode 100644 index 000000000000..8d1fbb79e8db --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/offline/prompt.jinja @@ -0,0 +1,15 @@ +{{ role }} Given an article with title and abstract, classify the article into the following categories: +{%- for category in categories %} +{{ loop.index }}. {{ category }} +{%- endfor %} + +Ensure that the "label" is a list of str selected from the categories and that both "label" and "reason" are non-empty. + +Return the output in a JSON format as following: +```json +{ + "label": [...], # {{ label_description }} + "reason": "..." # A detailed explanation with quotes from the article explaining why the article is related to the chosen label based on the ranking. +} +``` +Do not return anything else except the JSON string. diff --git a/examples/llm/tape/tape/dataset/llm/offline/pubmed.py b/examples/llm/tape/tape/dataset/llm/offline/pubmed.py new file mode 100644 index 000000000000..44169b5def1b --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/offline/pubmed.py @@ -0,0 +1,22 @@ +from typing import Dict + +from tape.dataset.llm.engine import LlmOfflineEngineArgs +from tape.dataset.llm.offline.base import LlmOfflineEngine + + +class LlmPubmedResponses(LlmOfflineEngine): + def __init__(self, args: LlmOfflineEngineArgs, + class_id_to_label: Dict) -> None: + super().__init__(args) + self.class_id_to_label = class_id_to_label + + def get_system_prompt(self) -> str: + kwargs = dict( + role="You're an experienced medical doctor.", + categories=[v['label'] for v in self.class_id_to_label.values()], + label_description=( + 'Contains the category (or categories if multiple options ' + 'apply) ordered from most to least likely.'), + ) + prompt = self.load_system_prompt_from_template(**kwargs) + return prompt diff --git a/examples/llm/tape/tape/dataset/llm/online/__init__.py b/examples/llm/tape/tape/dataset/llm/online/__init__.py new file mode 100644 index 000000000000..7f73a44ab97d --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/__init__.py @@ -0,0 +1,4 @@ +from .ogbn_arxiv import LlmOgbnArxivResponses +from .pubmed import LlmPubmedResponses + +__all__ = ['LlmOgbnArxivResponses', 'LlmPubmedResponses'] diff --git a/examples/llm/tape/tape/dataset/llm/online/base.py b/examples/llm/tape/tape/dataset/llm/online/base.py new file mode 100644 index 000000000000..aa950b1f3d1e --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/base.py @@ -0,0 +1,89 @@ +import random +from abc import abstractmethod +from pathlib import Path +from typing import List, Optional + +import instructor +from litellm import completion +from tape.config 
import DatasetName +from tape.dataset.llm.engine import ( + LlmEngine, + LlmOnlineEngineArgs, + LlmResponseModel, +) +from tape.dataset.llm.online.cache import llm_responses_cache, setup_cache +from tape.dataset.parser import Article +from tenacity import retry, stop_after_attempt, wait_random_exponential +from tqdm import tqdm + +from torch_geometric.template import module_from_template + + +class LlmOnlineEngine(LlmEngine): + def __init__(self, args: LlmOnlineEngineArgs, + dataset_name: DatasetName) -> None: + super().__init__(args) + self.dataset_name = dataset_name.value + self.client = instructor.from_litellm(completion) + setup_cache(cache_dir=Path(args.cache_dir) / + f'tape_llm_responses/{dataset_name.value}') + self._response_model = None + + @abstractmethod + def get_response_model(self) -> LlmResponseModel: + pass + + @property + def response_model(self) -> LlmResponseModel: + if self._response_model is None: + self._response_model = self.get_response_model() + return self._response_model + + def __call__(self, article: Article) -> Optional[LlmResponseModel]: + messages = [ + dict(role='system', content=self.system_message), + dict( + role='user', content='Title: {}\nAbstract: {}'.format( + article.title, article.abstract)) + ] + response = None + rpm = self.args.rate_limit_per_minute + try: + response = self._completion_with_backoff( + model=self.args.model, + messages=messages, + response_model=self.response_model, + delay=60.0 / rpm if rpm else + 0, # Adding delay to a request to avoid hitting the rate limit + **self.args.sampling_kwargs) + except Exception as e: + print('Max retries reached. Failed to get a successful response. ' + f'Error: {e}') + + return response + + @retry(wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6)) + @llm_responses_cache + def _completion_with_backoff(self, **kwargs): + return self.client.chat.completions.create_with_completion(**kwargs) + + def get_responses_from_articles( + self, articles: List[Article]) -> List[LlmResponseModel]: + responses = [] + for article in tqdm(articles, total=len(articles), + desc='Fetching LLM responses'): + if not (response := self(article)): + raise ValueError('LLM response cannot be empty!') + response.label = response.label.value # Convert Enum to str + responses.append(response) + return responses + + def load_response_model_from_template(self, **kwargs) -> LlmResponseModel: + uid = '%06x' % random.randrange(16**6) + path = Path(__file__, '..').resolve() + module = module_from_template( + module_name=f'response_model-{self.dataset_name}-{uid}', + template_path=path / 'response_model.jinja', + tmp_dirname='response_model', **kwargs) + return module.Classification diff --git a/examples/llm/tape/tape/dataset/llm/online/cache.py b/examples/llm/tape/tape/dataset/llm/online/cache.py new file mode 100644 index 000000000000..40797509e61a --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/cache.py @@ -0,0 +1,58 @@ +import functools +import json +import time +from pathlib import Path +from typing import Union + +import diskcache +from litellm import ModelResponse +from pydantic import BaseModel +from tape.utils import generate_string_hash + +CACHE = None + + +def setup_cache(cache_dir: Union[str, Path]) -> None: + global CACHE + CACHE = diskcache.Cache(cache_dir) + + +def llm_responses_cache(func): + """Cache a function that returns a Pydantic model.""" + @functools.wraps(func) + def wrapper(*args, **kwargs): + assert CACHE is not None, ( + 'Cache is not set! 
Please call `set_cache(...)` ' + 'before calling the function.') + delay = kwargs.pop('delay', 0) + response_model = kwargs['response_model'] + key = f'{func.__name__}-{make_key(args, kwargs)}' + if (cached := CACHE.get(key)) is not None: + # Deserialize from JSON based on the return type + usage = ModelResponse(**json.loads(cached)) + data = usage.choices[0].message.tool_calls[0].function.arguments + response = response_model.model_validate_json(data) + else: + time.sleep(delay) + # Call the function and cache its result + response, usage = func(*args, **kwargs) + serialize_usage = usage.model_dump_json() + CACHE.set(key, serialize_usage) + return response + + return wrapper + + +def make_key(args, kwargs): + def convert(v): + if isinstance(v, BaseModel): + return str(v.model_json_schema()) + return str(v) + + data = '' + for arg in args: + data += convert(arg) + for k, v in kwargs.items(): + data += k + convert(v) + input_hash = generate_string_hash(data) + return input_hash diff --git a/examples/llm/tape/tape/dataset/llm/online/ogbn_arxiv.py b/examples/llm/tape/tape/dataset/llm/online/ogbn_arxiv.py new file mode 100644 index 000000000000..24167c7e089a --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/ogbn_arxiv.py @@ -0,0 +1,44 @@ +from typing import Dict + +from tape.config import DatasetName +from tape.dataset.llm.engine import LlmOnlineEngineArgs, LlmResponseModel +from tape.dataset.llm.online.base import LlmOnlineEngine + + +class LlmOgbnArxivResponses(LlmOnlineEngine): + def __init__(self, args: LlmOnlineEngineArgs, + class_id_to_label: Dict) -> None: + super().__init__(args=args, dataset_name=DatasetName.OGBN_ARXIV) + self.class_id_to_label = class_id_to_label + self.system_message = ('Which arXiv CS sub-category does this ' + 'paper belong to?') + + def get_response_model(self) -> LlmResponseModel: + topk = 5 + class_labels = { + v['category'].replace('-', ' ').replace(',', ''): v['label'] + for v in self.class_id_to_label.values() + } + labels_list = list(class_labels.values()) + kwargs = dict( + class_labels=class_labels, label_description=( + f'Provide {topk} likely arXiv CS sub-categories ordered ' + 'from most to least likely.'), + label_examples=[labels_list[:topk], labels_list[topk:topk * 2]], + reason_examples=[ + ('The paper is about a new dataset for scene text ' + 'detection and recognition, which is a topic ' + 'related to computer vision (cs.CV). ' + 'The paper also mentions the use of deep learning ' + 'techniques such as DeconvNet, which falls under ' + 'the sub-category of artificial intelligence (cs.AI).' + ' The dataset is annotated and involves text ' + 'recognition, which could also fall under the ' + 'sub-categories of information retrieval (cs.IR) and' + ' natural language processing (cs.CL). 
Finally, the ' + 'paper discusses the effectiveness of different ' + 'solutions, which could be evaluated using machine ' + 'learning techniques, falling under the sub-category' + ' of machine learning (cs.LG).'), + ]) + return self.load_response_model_from_template(**kwargs) diff --git a/examples/llm/tape/tape/dataset/llm/online/pubmed.py b/examples/llm/tape/tape/dataset/llm/online/pubmed.py new file mode 100644 index 000000000000..521289a08952 --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/pubmed.py @@ -0,0 +1,59 @@ +from typing import Dict, List + +from datasets import load_dataset +from tape.config import DatasetName +from tape.dataset.llm.engine import LlmOnlineEngineArgs, LlmResponseModel +from tape.dataset.llm.online.base import LlmOnlineEngine +from tape.dataset.parser import Article + + +class LlmPubmedResponses(LlmOnlineEngine): + def __init__(self, args: LlmOnlineEngineArgs, + class_id_to_label: Dict) -> None: + super().__init__(args=args, dataset_name=DatasetName.PUBMED) + self.class_id_to_label = class_id_to_label + self.system_message = ( + 'Classify a scientific publication (containing title' + ' and abstract) into provided categories.') + + def get_response_model(self) -> LlmResponseModel: + class_labels = { + v['label']: v['label'] + for v in self.class_id_to_label.values() + } + labels_list = list(class_labels.values()) + kwargs = dict( + class_labels=class_labels, + label_description=( + 'Provide the most likely category (or categories ' + 'if multiple options apply) ordered ' + 'from most to least likely.'), + label_examples=[labels_list[:1], labels_list[:2]], + reason_examples=[ + # Example containing multiple paper categories + # ➜ Type 1 Diabetes & Experimental Diabetes + ('Type 1 diabetes is present in the abstract as the study ' + 'was conducted on cardiac mitochondria from type-I diabetic ' + 'rats. 
Experimentally induced diabetes is also present in ' + 'the abstract as the study involved inducing diabetes in ' + 'rats and comparing the mitochondrial function of these ' + 'rats to control rats.'), + ]) + return self.load_response_model_from_template(**kwargs) + + def get_responses_from_articles( + self, articles: List[Article]) -> List[LlmResponseModel]: + model = 'huggingface/meta-llama/Meta-Llama-3-8B-Instruct' + if self.args.model == model: + dataset = load_dataset( + "devanshamin/PubMedDiabetes-LLM-Predictions", + cache_dir=self.args.cache_dir, split='train') + responses = [] + for sample in dataset: + response = self.response_model( + label=sample['predicted_ranked_labels'].split('; '), + reason=sample['explanation']) + responses.append(response) + else: + responses = super().get_responses_from_articles(articles) + return responses diff --git a/examples/llm/tape/tape/dataset/llm/online/response_model.jinja b/examples/llm/tape/tape/dataset/llm/online/response_model.jinja new file mode 100644 index 000000000000..e8c0bc1910aa --- /dev/null +++ b/examples/llm/tape/tape/dataset/llm/online/response_model.jinja @@ -0,0 +1,28 @@ +from enum import Enum +from typing import List + +from pydantic import Field + +from tape.dataset.llm.engine import LlmResponseModel + + +class PaperCategory(str, Enum): +{%- for key, value in class_labels.items() %} + {{ key.replace(' ', '_')|upper }} = "{{ value }}" +{%- endfor %} + + +class Classification(LlmResponseModel): + label: List[PaperCategory] = Field( + ..., + description="{{ label_description }}", + examples={{ label_examples }} + ) + reason: str = Field( + ..., + description=( + 'Give a detailed explanation with quotes from the abstract explaining why ' + 'the paper is related to the chosen label based on the ranking.' 
+ ), + examples={{ reason_examples }} + ) diff --git a/examples/llm/tape/tape/dataset/lm_encoder.py b/examples/llm/tape/tape/dataset/lm_encoder.py new file mode 100644 index 000000000000..f898090d08a3 --- /dev/null +++ b/examples/llm/tape/tape/dataset/lm_encoder.py @@ -0,0 +1,145 @@ +import warnings +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import List, Literal, Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from tape.config import DatasetName, FeatureType +from tape.utils import generate_string_hash +from tqdm import tqdm + +warnings.filterwarnings('ignore') # Ignore HuggingFace libraries warnings + + +@dataclass +class TransformersTokenizerArgs: + batch_size: int = 32 + truncation: bool = True + padding: bool = True + max_length: int = 512 + + +@dataclass +class SentenceTransformerArgs: + batch_size: int = 32 + show_progress_bar: bool = True + precision: Literal['float32', 'int8', 'uint8', 'binary', + 'ubinary'] = 'float32' + + +@dataclass +class LmEncoderArgs: + dataset_name: DatasetName # Used for creating file name to save embeddings + feature_type: FeatureType # Used for creating file name to save embeddings + model_name_or_path: str + model_library: Literal['transformers', 'sentence_transformer'] + transformers_encoder_args: Optional[TransformersTokenizerArgs] = None + sentence_transformer_encoder_args: Optional[SentenceTransformerArgs] = None + device: Optional[str] = None + cache_dir: str = '.cache' + + +class LmEncoder: + """Language model article encoder.""" + def __init__(self, args: LmEncoderArgs) -> None: + self.args = args + self.device = args.device or ('cuda' + if torch.cuda.is_available() else 'cpu') + self.model = None + self.tokenizer = None + cache_dir = Path.cwd() / args.cache_dir + + embd_cache_dir = cache_dir / 'embeddings' + embd_cache_dir.mkdir(exist_ok=True, parents=True) + file_name = ( + f'{args.feature_type.value}_{args.dataset_name.value}' + f'_{args.model_name_or_path.replace("/", "--")}.safetensors') + self.embd_cache_path = embd_cache_dir / file_name + self._sent_hash_to_embedding = self._load_cache() + + if args.model_library == 'transformers': + from transformers import AutoModel, AutoTokenizer + + self.model = AutoModel.from_pretrained( + args.model_name_or_path, cache_dir=cache_dir).to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, cache_dir=cache_dir) + elif args.model_library == 'sentence_transformer': + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer(args.model_name_or_path, + device=self.device, + cache_folder=cache_dir) + else: + raise Exception('Invalid model library!') + + def _load_cache(self): + input_hash_to_embedding = {} + if self.embd_cache_path.exists(): + print('Loading cached embeddings...') + with safe_open(str(self.embd_cache_path), framework="pt", + device=self.device) as f: + for k in f.keys(): + input_hash_to_embedding[k] = f.get_tensor(k) + return input_hash_to_embedding + + def save_cache(self) -> None: + save_file(self._sent_hash_to_embedding, str(self.embd_cache_path)) + print(f'Saved embedding file to "{self.embd_cache_path}"') + + @torch.inference_mode() + def _hf_encoder(self, sentences: List[str], **kwargs): + encoded_sentences = self.tokenizer( + sentences, + truncation=kwargs.get('truncation', True), + padding=kwargs.get('padding', True), + return_tensors='pt', + max_length=kwargs.get('max_length', 512), + ).to(self.device) + # Encode the 
queries (use the [CLS] last hidden states + # as the representations) + embeddings = self.model(**encoded_sentences).last_hidden_state[:, 0, :] + torch.cuda.empty_cache() + return embeddings + + def _get_embeddings(self, sentences: List[str], **kwargs): + if self.args.model_library == 'transformers': + _kwargs = asdict(self.args.transformers_encoder_args) + _kwargs.update(kwargs) # kwargs overrides the default config + batch_size = _kwargs['batch_size'] + embeddings = [] + for step in tqdm(range(0, len(sentences), batch_size), + total=len(sentences) // batch_size, + desc='Batches'): + embeddings.append( + self._hf_encoder( + sentences=sentences[step:step + batch_size], + **_kwargs)) + embeddings = torch.cat(embeddings) + elif self.args.model_library == 'sentence_transformer': + _kwargs = asdict(self.args.sentence_transformer_encoder_args) + _kwargs.update(kwargs) # kwargs overrides the default config + _kwargs.pop('convert_to_tensor', None) + embeddings = self.model.encode(sentences, convert_to_tensor=True, + **_kwargs) + return embeddings + + def __call__(self, sentences: List[str], **kwargs) -> torch.Tensor: + missing_sentences_idxs = [] + embeddings = [] + for idx, sent_hash in enumerate(map(generate_string_hash, sentences)): + if (embd := self._sent_hash_to_embedding.get(sent_hash)) is None: + missing_sentences_idxs.append((idx, sent_hash)) + else: + sentences[idx] = None + embeddings.append(embd) + if missing_sentences_idxs: + missing_embeddings = self._get_embeddings( + sentences=list(filter(None, sentences)), **kwargs) + for (idx, sent_hash), embedding in zip(missing_sentences_idxs, + missing_embeddings): + embeddings[idx] = embedding + self._sent_hash_to_embedding[sent_hash] = embedding + return embeddings diff --git a/examples/llm/tape/tape/dataset/parser/__init__.py b/examples/llm/tape/tape/dataset/parser/__init__.py new file mode 100644 index 000000000000..abe574399964 --- /dev/null +++ b/examples/llm/tape/tape/dataset/parser/__init__.py @@ -0,0 +1,5 @@ +from .base import Article +from .pubmed import PubmedParser +from .ogbn_arxiv import OgbnArxivParser + +__all__ = ['Article', 'PubmedParser', 'OgbnArxivParser'] diff --git a/examples/llm/tape/tape/dataset/parser/base.py b/examples/llm/tape/tape/dataset/parser/base.py new file mode 100644 index 000000000000..9147b46bd45d --- /dev/null +++ b/examples/llm/tape/tape/dataset/parser/base.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class Article: + paper_id: str + title: str + abstract: str + + +class Parser(ABC): + def __init__(self, seed: int = 42, cache_dir: str = '.cache') -> None: + self.seed = seed + self.cache_dir = Path.cwd() / cache_dir + + @abstractmethod + def parse(self): + pass + + @abstractmethod + def download_data(self): + pass diff --git a/examples/llm/tape/tape/dataset/parser/ogbn_arxiv.py b/examples/llm/tape/tape/dataset/parser/ogbn_arxiv.py new file mode 100644 index 000000000000..d423ba66f66d --- /dev/null +++ b/examples/llm/tape/tape/dataset/parser/ogbn_arxiv.py @@ -0,0 +1,165 @@ +import re +import zipfile +from pathlib import Path +from typing import Dict, List, Optional + +import gdown +import pandas as pd +import requests +import torch +from ogb.nodeproppred import PygNodePropPredDataset +from tape.dataset.parser.base import Article, Parser + +import torch_geometric.transforms as T +from torch_geometric.data import Data + + +class OgbnArxivParser(Parser): + """Parser for [OGB 
arXiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) dataset.""" # noqa + + urls = { + 'original': + 'https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz', # noqa + 'llm_responses': + 'https://drive.google.com/file/d/1A6mZSFzDIhJU795497R6mAAM2Y9qutI5/view?usp=sharing', # noqa + } + + def __init__(self, seed: int = 0, cache_dir: str = '.cache') -> None: + super().__init__(seed, cache_dir) + self._dtype_to_path = self.download_data() + self.graph = OgbnArxivGraph(dir_path=self._dtype_to_path['original'], + cache_dir=self.cache_dir) + self.split = None + + def parse(self) -> None: + self.graph.load() + + def download_data(self) -> Dict[str, Path]: + dtype_to_path = {} + for dtype, url in OgbnArxivParser.urls.items(): + save_dir = self.cache_dir / dtype + save_dir.mkdir(exist_ok=True, parents=True) + dtype_to_path[dtype] = save_dir / ( + 'ogbn-arxiv' + ('_orig' if dtype == 'original' else '')) + if url.endswith('.tsv.gz'): + file_name = url.split('/')[-1] + dtype_to_path[dtype] /= file_name + + if not dtype_to_path[dtype].exists(): + if 'drive.google.com' in url: + zip_file_path = save_dir / 'ogbn-arxiv.zip' + file_id = url.split('/d/')[1].split('/')[0] + download_url = f'https://drive.google.com/uc?export=download&id={file_id}' # noqa + gdown.download(download_url, str(zip_file_path), + quiet=False) + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + zip_ref.extractall(str(save_dir)) + zip_file_path.unlink() + else: + response = requests.get(url, stream=True) + dtype_to_path[dtype].parent.mkdir(exist_ok=True) + with open(dtype_to_path[dtype], 'wb') as f: + for chunk in response.iter_content(32_768): + if chunk: + f.write(chunk) + + return dtype_to_path + + +class OgbnArxivGraph: + def __init__(self, dir_path: Path, cache_dir: Path) -> None: + + self.dir_path = dir_path + self.cache_dir = cache_dir + + self.dataset: Optional[Data] = None + self.n_classes = 40 + self.n_nodes = 169_343 + self.n_features = 128 + self.class_id_to_label: Optional[Dict] = None + self.articles: Optional[List[Article]] = None + # Split containing train/val/test node ids + self.split: Optional[Dict] = None + + def load(self): + self._load_ogb_dataset() + self._load_articles() + self.class_id_to_label = self._load_class_label_mapping() + + def _load_ogb_dataset(self): + print('Loading OGB dataset...') + + dataset = PygNodePropPredDataset(name='ogbn-arxiv', + root=self.cache_dir, + transform=T.ToSparseTensor()) + self.split = dataset.get_idx_split() + + data = dataset[0] + + train_mask = torch.zeros(data.num_nodes).bool() + train_mask[self.split['train']] = True + data.train_mask = train_mask + + val_mask = torch.zeros(data.num_nodes).bool() + val_mask[self.split['valid']] = True + data.val_mask = val_mask + + test_mask = torch.zeros(data.num_nodes).bool() + test_mask[self.split['test']] = True + data.test_mask = test_mask + + data.edge_index = data.adj_t.to_symmetric() + + self.dataset = data + + def _load_articles(self): + + mapping_df = pd.read_csv( + self.cache_dir / 'ogbn_arxiv/mapping/nodeidx2paperid.csv.gz', + skiprows=1, names=['node_idx', 'paper_id'], compression='gzip') + title_abstract_df = pd.read_table( + self.dir_path, header=None, + names=['paper_id', 'title', 'abstract'], compression='gzip') + df = mapping_df.astype(dict(paper_id=str)).join( + title_abstract_df.set_index('paper_id'), on='paper_id') + self.articles = [] + for row in df.itertuples(index=False): + self.articles.append( + Article(paper_id=row.paper_id, title=row.title, + abstract=row.abstract)) + + def 
_load_class_label_mapping(self): + mapping_df = pd.read_csv( + self.cache_dir / + 'ogbn_arxiv/mapping/labelidx2arxivcategeory.csv.gz', skiprows=1, + names=['label_id', 'label'], compression='gzip') + class_id_to_label = {} + categories = OgbnArxivGraph.fetch_arxiv_category_taxonomy() + df = pd.DataFrame(categories) + for row in mapping_df.itertuples(index=False): + label = row.label.replace('arxiv cs ', + 'cs.').strip().upper().replace( + 'CS', 'cs') + class_id_to_label[row.label_id] = df[df.label == + label].iloc[0].to_dict() + return class_id_to_label + + @staticmethod + def fetch_arxiv_category_taxonomy( + category: str = 'cs') -> List[Dict[str, str]]: + text = requests.get( + 'https://r.jina.ai/https://arxiv.org/category_taxonomy').text + sections = re.split(r'#### ', text)[1:] + data_list = [] + for section in sections: + match = re.match(rf'({category}\.\w+) \(([^)]+)\)\n\n', section) + if match: + label = match.group(1) + category_name = match.group(2) + description = section[match.end():].strip() + data_list.append({ + 'label': label, + 'category': category_name, + 'description': description + }) + return data_list diff --git a/examples/llm/tape/tape/dataset/parser/pubmed.py b/examples/llm/tape/tape/dataset/parser/pubmed.py new file mode 100644 index 000000000000..33dc72a48747 --- /dev/null +++ b/examples/llm/tape/tape/dataset/parser/pubmed.py @@ -0,0 +1,232 @@ +import json +import zipfile +from pathlib import Path +from typing import Dict, List, Optional + +import gdown +import torch +from tape.dataset.parser.base import Article, Parser + +from torch_geometric.data import Data +from torch_geometric.datasets import Planetoid + + +class PubmedParser(Parser): + """Parser for [PubMed Diabetes](https://linqs.org/datasets/#pubmed-diabetes) dataset.""" # noqa + + urls = { + 'original': + 'https://drive.google.com/file/d/1sYZX-jP6H8OkopVa9cp8-KXdEti5ki_W/view?usp=sharing', # noqa + 'llm_responses': + 'https://drive.google.com/file/d/166waPAjUwu7EWEvMJ0heflfp0-4EvrZS/view?usp=sharing', # noqa + } + + def __init__(self, seed: int = 0, cache_dir: str = '.cache') -> None: + super().__init__(seed, cache_dir) + self._dtype_to_path = self.download_data() + self.graph = PubmedGraph(dir_path=self._dtype_to_path['original'], + cache_dir=self.cache_dir) + + def parse(self) -> None: + self.graph.load() + + def download_data(self) -> Dict[str, Path]: + dtype_to_path = {} + for dtype, url in PubmedParser.urls.items(): + save_dir = self.cache_dir / dtype + save_dir.mkdir(exist_ok=True, parents=True) + zip_file_path = save_dir / 'PubMed.zip' + dtype_to_path[dtype] = save_dir / ( + 'PubMed' + ('_orig' if dtype == 'original' else '')) + + if not dtype_to_path[dtype].exists(): + file_id = url.split('/d/')[1].split('/')[0] + download_url = f'https://drive.google.com/uc?export=download&id={file_id}' # noqa + gdown.download(download_url, str(zip_file_path), quiet=False) + + with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: + zip_ref.extractall(str(save_dir)) + zip_file_path.unlink() + + return dtype_to_path + + +class PubmedGraph: + def __init__(self, dir_path: Path, cache_dir: Path) -> None: + + self.dir_path = dir_path + self.cache_dir = cache_dir + + self.dataset: Optional[Data] = None + self.n_classes = 3 + self.class_id_to_label = { + 0: + dict( + label='Experimental Diabetes', + description=('Studies investigating diabetes in controlled ' + 'experimental settings.'), + ), + 1: + dict( + label='Type 1 Diabetes', + description=( + 'An autoimmune disease where the body attacks and ' + 'destroys 
insulin-producing cells in the pancreas.'), + ), + 2: + dict( + label='Type 2 Diabetes', description=( + 'A metabolic disorder characterized by high blood ' + "sugar levels due to the body's inability to " + 'effectively use insulin.')), + } + self.n_nodes = 19_717 + self.n_features = 500 + # PubMed Articles + self.articles: Optional[List[Article]] = None + self.pubmed_id_to_node_id = {} + # Nodes + self.node_features: Optional[torch.tensor] = None + self.node_labels: Optional[List] = None + self.node_feature_to_idx: Optional[Dict] = None + # Edges + self.edge_index: Optional[torch.tensor] = None + self.adj_matrix: Optional[torch.tensor] = None + # Split containing train/val/test node ids + self.split: Optional[Dict] = None + + def load(self) -> None: + + self._load_articles() + self._load_nodes() + self._load_edges() + self._load_pyg_dataset() + + def _load_pyg_dataset(self): + + print('Loading PyG dataset...') + + self.dataset = Planetoid(self.cache_dir, 'PubMed')[0] + # Replace dataset matrices with the PubMed-Diabetes data, + # for which we have the original PubMed IDs + self.dataset.x = self.node_features + self.dataset.y = self.node_labels + self.dataset.edge_index = self.edge_index + + # Split dataset nodes into train/val/test and update + # the train/val/test masks + n_nodes = self.dataset.num_nodes + node_ids = torch.randperm(n_nodes) + self.split = {} + for split_name in ('train', 'val', 'test'): + if split_name == 'train': + subset = slice(0, int(n_nodes * 0.6)) + elif split_name == 'val': + subset = slice(int(n_nodes * 0.6), int(n_nodes * 0.8)) + else: + subset = slice(int(n_nodes * 0.8), n_nodes) + + ids = node_ids[subset].sort()[0] + setattr(self.dataset, f'{split_name}_id', ids) + mask = torch.zeros(n_nodes, dtype=bool) + mask[ids] = True + setattr(self.dataset, f'{split_name}_mask', mask) + self.split[split_name] = ids.tolist() + + def _load_articles(self): + + print('Loading articles...') + + self.articles = [] + path = self.dir_path / 'pubmed.json' + data = json.loads(path.read_text()) + node_id = 0 + for article in data: + if (pubmed_id := article.get('PMID')) and ( + title := article.get('TI')) and (abstract := + article.get('AB')): + self.articles.append( + Article(paper_id=pubmed_id, title=title, + abstract=abstract)) + self.pubmed_id_to_node_id[pubmed_id] = node_id + node_id += 1 + else: + print(f'Ignoring PubMed article with node id "{node_id}" ' + 'due to missing PMID/Abstract/Title.') + + print('No. of PubMed articles with title and ' + f'abstract: {len(self.articles):,}') + print(f'Updating no. 
of nodes from {self.n_nodes:,} ' + f'to {len(self.articles):,}') + self.n_nodes = len(self.articles) + + def _load_nodes(self): + + print('Loading nodes...') + + self.node_features = torch.zeros((self.n_nodes, self.n_features), + dtype=torch.float32) + self.node_labels = torch.empty(self.n_nodes, dtype=torch.long) + self.node_feature_to_idx = {} + + with open(self.dir_path / 'data/Pubmed-Diabetes.NODE.paper.tab', + 'r') as node_file: + node_file.readline() # Ignore header + node_file.readline() # Ignore header + k = 0 + + for line in node_file.readlines(): + items = line.strip().split('\t') + pubmed_id = items[0] + if (node_id := + self.pubmed_id_to_node_id.get(pubmed_id)) is None: + print(f'Ignoring PubMed article "{pubmed_id}" due to ' + 'missing PMID/Abstract/Title.') + continue + + label = int(items[1].split('=')[-1]) - 1 + self.node_labels[node_id] = label + features = items[2:-1] + for feature in features: + parts = feature.split('=') + fname = parts[0] + fvalue = float(parts[1]) + if fname not in self.node_feature_to_idx: + self.node_feature_to_idx[fname] = k + k += 1 + self.node_features[ + node_id, self.node_feature_to_idx[fname]] = fvalue + + def _load_edges(self): + + print('Loading edges...') + + edges = [] + self.adj_matrix = torch.zeros((self.n_nodes, self.n_nodes), + dtype=torch.float32) + + with open(self.dir_path / 'data/Pubmed-Diabetes.DIRECTED.cites.tab', + 'r') as edge_file: + edge_file.readline() # Ignore header + edge_file.readline() # Ignore header + + for line in edge_file.readlines(): + items = line.strip().split('\t') + tail = items[1].split(':')[-1] + head = items[3].split(':')[-1] + if ((head_node_id := + self.pubmed_id_to_node_id.get(head)) is None + or (tail_node_id := + self.pubmed_id_to_node_id.get(tail)) is None): + print(f'Ignoring edge ({head}, {tail}) due to either of ' + 'the PubMed articles being discarded.') + continue + + self.adj_matrix[tail_node_id, head_node_id] = 1.0 + self.adj_matrix[head_node_id, tail_node_id] = 1.0 + if head != tail: + edges.append((head_node_id, tail_node_id)) + edges.append((tail_node_id, head_node_id)) + + edges = torch.tensor(edges, dtype=torch.long) + self.edge_index = torch.unique(edges, dim=0).T diff --git a/examples/llm/tape/tape/gnn_model.py b/examples/llm/tape/tape/gnn_model.py new file mode 100644 index 000000000000..45cdd27827da --- /dev/null +++ b/examples/llm/tape/tape/gnn_model.py @@ -0,0 +1,67 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch_geometric.nn import conv as conv_layers + + +@dataclass +class NodeClassifierArgs: + conv_layer: str + hidden_channels: int + num_layers: int + dropout: Optional[float] = 0.0 + in_channels: Optional[int] = None # Inferred from the dataset + out_channels: Optional[int] = None # Inferred from the dataset + use_predictions: Optional[bool] = None # Inferred from the dataset + + +class NodeClassifier(torch.nn.Module): + def __init__(self, args: NodeClassifierArgs) -> None: + + super().__init__() + self.use_predictions = args.use_predictions + if self.use_predictions: + # Embedding lookup for each class (out_channels == num_classes) + self.encoder = nn.Embedding(args.out_channels + 1, + args.hidden_channels) + + self.convs = nn.ModuleList() + self.batch_norm = nn.ModuleList() + assert (conv_cls := getattr(conv_layers, args.conv_layer, None)) + self.convs.append(conv_cls(args.in_channels, args.hidden_channels)) + for _ in range(args.num_layers - 2): + self.convs.append( + 
conv_cls(args.hidden_channels, args.hidden_channels)) + self.batch_norm.append(nn.BatchNorm1d(args.hidden_channels)) + self.convs.append(conv_cls(args.hidden_channels, args.out_channels)) + self.batch_norm.append(nn.BatchNorm1d(args.hidden_channels)) + + self.dropout = args.dropout + + def reset_parameters(self) -> None: + + for conv in self.convs: + conv.reset_parameters() + for bn in self.batch_norm: + bn.reset_parameters() + + def forward(self, x: torch.Tensor, + edge_index: torch.Tensor) -> torch.Tensor: + + if self.use_predictions: + x = self.encoder(x) + x = torch.flatten(x, start_dim=1) + + for i, conv in enumerate(self.convs[:-1]): + x = conv(x, edge_index) + x = self.batch_norm[i](x) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + x = self.convs[-1](x, edge_index) + + return x diff --git a/examples/llm/tape/tape/trainer/__init__.py b/examples/llm/tape/tape/trainer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/llm/tape/tape/trainer/gnn_trainer.py b/examples/llm/tape/tape/trainer/gnn_trainer.py new file mode 100644 index 000000000000..6de4a109cc58 --- /dev/null +++ b/examples/llm/tape/tape/trainer/gnn_trainer.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +from tape.dataset.dataset import GraphDataset +from tape.gnn_model import NodeClassifier, NodeClassifierArgs + +from torch_geometric.data import Data + + +@dataclass +class GnnTrainerArgs: + epochs: int + lr: float + weight_decay: float = 0.0 + early_stopping_patience: int = 50 + device: Optional[str] = None + + +@dataclass +class GnnTrainerOutput: + loss: float + accuracy: float + logits: torch.Tensor + + +class GnnTrainer: + def __init__(self, trainer_args: GnnTrainerArgs, + graph_dataset: GraphDataset, + model_args: NodeClassifierArgs) -> None: + + self.trainer_args = trainer_args + self.dataset: Data = graph_dataset.dataset + self.device = trainer_args.device or ( + 'cuda' if torch.cuda.is_available() else 'cpu') + + use_predictions = graph_dataset.feature_type == 'prediction' + if use_predictions: + # The node features will be the `topk` classes + # ➜ Shape: (num_nodes, topk) + # It will get passed to an embedding lookup layer + # ➜ Shape: (num_nodes, topk, hidden_dim) + # And the last two dims will be flattened + # ➜ Shape: (num_nodes, topk * hidden_dim) + model_args.in_channels = graph_dataset.topk * \ + model_args.hidden_channels + else: + model_args.in_channels = self.dataset.num_node_features + model_args.out_channels = graph_dataset.num_classes + model_args.use_predictions = use_predictions + self.model = NodeClassifier(model_args).to(self.device) + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=trainer_args.lr, + weight_decay=trainer_args.weight_decay) + self.criterion = torch.nn.CrossEntropyLoss() + + def train(self) -> GnnTrainerOutput: + patience = self.trainer_args.early_stopping_patience + best_val_loss = float('inf') + epochs_without_improvement = 0 + + for epoch in range(1, self.trainer_args.epochs + 1): + train_output = self._train_eval(self.dataset, stage='train') + val_output = self._train_eval(self.dataset, stage='val') + print(f'Epoch: {epoch:03d} | Train loss: {train_output.loss:.4f}, ' + f'Val loss: {val_output.loss:.4f}, ' + f'Train accuracy: {train_output.accuracy:.4f}, ' + f'Val accuracy: {val_output.accuracy:.4f}') + if val_output.loss < best_val_loss: + best_val_loss = val_output.loss + epochs_without_improvement = 0 + else: + 
epochs_without_improvement += 1 + + if epochs_without_improvement >= patience: + print(f'Early stopping on epoch {epoch} due to no improvement' + f' in validation loss for {patience} epochs.') + break + + output = self._train_eval(self.dataset, stage='test') + return output + + def _train_eval(self, data: Data, stage: Literal['train', 'val', 'test']): + if stage == 'train': + self.model.train() + else: + self.model.eval() + + data = data.to(self.device) + mask = getattr(data, f'{stage}_mask') + if stage == 'train': + self.optimizer.zero_grad() + logits = self.model(data.x, data.edge_index) + loss = self.criterion(logits[mask], data.y[mask].flatten()) + loss.backward() + self.optimizer.step() + else: + with torch.inference_mode(): + logits = self.model(data.x, data.edge_index) + loss = self.criterion(logits[mask], data.y[mask].flatten()) + + accuracy = GnnTrainer.compute_accuracy(logits, data.y, mask) + return GnnTrainerOutput(loss=float(loss), accuracy=accuracy, + logits=logits) + + @staticmethod + def compute_accuracy(logits: torch.Tensor, y_true: torch.Tensor, + mask: torch.Tensor) -> float: + y_pred = logits.argmax(dim=1) + correct = y_pred[mask] == y_true[mask] + return int(correct.sum()) / int(mask.sum()) diff --git a/examples/llm/tape/tape/utils.py b/examples/llm/tape/tape/utils.py new file mode 100644 index 000000000000..c2cfd169e1f6 --- /dev/null +++ b/examples/llm/tape/tape/utils.py @@ -0,0 +1,24 @@ +import hashlib +import time +from typing import Callable + + +def profile_execution(func: Callable) -> Callable: + def wrapper(*args, **kwargs): + start_time = time.perf_counter() + result = func(*args, **kwargs) + execution_time = time.perf_counter() - start_time + minutes = int(execution_time // 60) + seconds = execution_time % 60 + print(f"Function '{func.__name__}' executed in " + f'{minutes} minutes and {seconds:.2f} seconds.\n') + return result + + return wrapper + + +def generate_string_hash(input_string: str, algorithm: str = 'sha256'): + input_bytes = input_string.encode('utf-8') + hash_obj = hashlib.new(algorithm) + hash_obj.update(input_bytes) + return hash_obj.hexdigest() diff --git a/examples/llm/tape/train.py b/examples/llm/tape/train.py new file mode 100644 index 000000000000..0fa0e3dc3965 --- /dev/null +++ b/examples/llm/tape/train.py @@ -0,0 +1,130 @@ +import copy +from dataclasses import is_dataclass +from typing import Optional + +import numpy as np +import pandas as pd +import torch +from jsonargparse import ActionConfigFile, ArgumentParser +from tape.config import DatasetName, FeatureType +from tape.dataset.dataset import GraphDataset +from tape.dataset.llm.engine import LlmOfflineEngineArgs, LlmOnlineEngineArgs +from tape.dataset.lm_encoder import LmEncoderArgs +from tape.gnn_model import NodeClassifierArgs +from tape.trainer.gnn_trainer import GnnTrainer, GnnTrainerArgs +from tape.utils import profile_execution + + +def get_parser() -> ArgumentParser: + # `omegaconf` for variable interpolation + parser = ArgumentParser(parser_mode='omegaconf') + parser.add_argument('--config', action=ActionConfigFile) + parser.add_argument('--dataset', type=DatasetName) + parser.add_argument('--feature_type', type=FeatureType) + parser.add_argument('--cache_dir', type=str) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--seed_runs', type=Optional[int], default=None) + parser.add_argument('--device', type=str, required=False) + parser.add_argument('--lm_encoder', type=LmEncoderArgs) + parser.add_argument('--llm_online_engine', + 
type=Optional[LlmOnlineEngineArgs], default=None) + parser.add_argument('--llm_offline_engine', + type=Optional[LlmOfflineEngineArgs], default=None) + parser.add_argument('--gnn_model', type=NodeClassifierArgs) + parser.add_argument('--gnn_trainer', type=GnnTrainerArgs) + return parser + + +def update_feature_type(args, feature_type: FeatureType): + field_name = 'feature_type' + args_copy = copy.deepcopy(args) + for attr, attribute_value in vars(args_copy).items(): + if (is_dataclass(attribute_value) + and hasattr(attribute_value, field_name)): + field_value = getattr(attribute_value, field_name) + if (isinstance(field_value, FeatureType) + and (field_value == FeatureType.TAPE)): + setattr(attribute_value, field_name, feature_type) + elif attr == field_name: + if (isinstance(attribute_value, FeatureType) + and (attribute_value == FeatureType.TAPE)): + setattr(args_copy, attr, feature_type) + return args_copy + + +def _train(args): + graph_dataset = GraphDataset( + dataset_name=args.dataset, + feature_type=args.feature_type, + lm_encoder_args=args.lm_encoder, + llm_online_engine_args=args.llm_online_engine, + llm_offline_engine_args=args.llm_offline_engine, + device=args.device, + cache_dir=args.cache_dir, + seed=args.seed, + ) + trainer = GnnTrainer( + trainer_args=args.gnn_trainer, + graph_dataset=graph_dataset, + model_args=args.gnn_model, + ) + test_output = trainer.train() + return graph_dataset, test_output + + +@profile_execution +def _run(args): + if args.feature_type == FeatureType.TAPE: + logits = [] + pred_rows = [] + for value in ('TA', 'P', 'E'): + ftype = FeatureType._value2member_map_[value] + _args = update_feature_type(args, feature_type=ftype) + graph_dataset, test_output = _train(_args) + logits.append(test_output.logits) + ftype_str = f'{ftype.name} ({ftype.value})' + print(f'[Feature type: {ftype_str}] Test accuracy: ' + f'{test_output.accuracy:.4f}') + pred_rows.append( + dict(Feature_type=ftype_str, + Test_accuracy=test_output.accuracy)) + + # Fuse predictions of features (TA, P, E) by taking an average + logits = torch.stack(logits).mean(dim=0) + y_true = graph_dataset.dataset.y + mask = graph_dataset.dataset.test_mask + test_acc = GnnTrainer.compute_accuracy(logits=logits, y_true=y_true, + mask=mask) + ftype_str = f'{args.feature_type.name} ({args.feature_type.value})' + pred_rows.append(dict(Feature_type=ftype_str, + Test_accuracy=test_acc), ) + + print() + print(pd.DataFrame(pred_rows)) + else: + # Make sure the feature type used across config is consistent + _args = update_feature_type(args, feature_type=args.feature_type) + _, test_output = _train(_args) + test_acc = test_output.accuracy + ftype_str = f'{args.feature_type.name} ({args.feature_type.value})' + print(f'[Feature type: {ftype_str}] ' + f'Test accuracy: {test_acc:.4f}') + return test_acc + + +if __name__ == '__main__': + parser = get_parser() + args = parser.parse_args() + args = parser.instantiate_classes(args) + + if args.seed_runs is None: + _run(args) + else: + test_accs = [] + for seed in range(args.seed_runs): + args.seed = seed + test_acc = _run(args) + test_accs.append(test_acc) + ftype_str = f'{args.feature_type.name} ({args.feature_type.value})' + acc_str = f'{np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}' + print(f'[Feature type: {ftype_str}] Test accuracy: {acc_str}') diff --git a/examples/llm/tape/train_config.yaml b/examples/llm/tape/train_config.yaml new file mode 100644 index 000000000000..b48e868ef6d7 --- /dev/null +++ b/examples/llm/tape/train_config.yaml @@ -0,0 +1,49 @@ 
+dataset: PUBMED +feature_type: TAPE +cache_dir: .cache +seed: 42 + +lm_encoder: + dataset_name: ${dataset} + feature_type: ${feature_type} + model_name_or_path: avsolatorio/GIST-Embedding-v0 + model_library: sentence_transformer + sentence_transformer_encoder_args: + batch_size: 100 + show_progress_bar: true + precision: float32 + cache_dir: ${cache_dir} + +llm_online_engine: + cache_dir: ${cache_dir} + sampling_kwargs: + max_tokens: 500 # LLM completion tokens + # Pick any provider model from https://litellm.vercel.app/docs/providers + # model: anthropic/claude-3-haiku-20240307 + # rate_limit_per_minute: 5 # https://docs.anthropic.com/en/api/rate-limits#rate-limits + model: huggingface/meta-llama/Meta-Llama-3-8B-Instruct + +# llm_offline_engine: +# cache_dir: ${cache_dir} +# sampling_kwargs: +# max_tokens: 500 +# n: 1 +# temperature: 0.6 +# top_p: 0.9 +# model: meta-llama/Meta-Llama-3-8B-Instruct +# batch_size: 100 +# engine_kwargs: +# seed: ${seed} +# max_model_len: 8192 + +gnn_model: + conv_layer: SAGEConv # `torch_geometric.nn.conv` layer + hidden_channels: 64 + num_layers: 4 + dropout: 0.1 + +gnn_trainer: + epochs: 500 + early_stopping_patience: 50 + lr: 0.0031622776601683794 # 10**-2.5 + weight_decay: 0.00001 # 10**-5
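
As a companion to the `use_predictions` branch in `GnnTrainer` and the `encoder` call in `NodeClassifier.forward`, here is a minimal, self-contained sketch of the shape flow for the `prediction` feature type. The tensor sizes and the `nn.Embedding` construction below are illustrative assumptions; the actual layer is built from `NodeClassifierArgs` and the dataset's class count.

```python
import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the dataset and
# NodeClassifierArgs (hidden_channels) plus the LLM top-k setting.
num_nodes, topk, num_classes, hidden_dim = 8, 3, 5, 64

# Prediction features: each node carries the top-k class indices
# returned by the LLM -> shape (num_nodes, topk).
x = torch.randint(0, num_classes, (num_nodes, topk))

# Embedding lookup turns each class index into a learned vector
# -> shape (num_nodes, topk, hidden_dim).
encoder = nn.Embedding(num_classes, hidden_dim)
h = encoder(x)

# Flattening the last two dims gives the GNN input
# -> shape (num_nodes, topk * hidden_dim), which is why the trainer sets
# in_channels = topk * hidden_channels for the 'prediction' feature type.
h = torch.flatten(h, start_dim=1)
print(h.shape)  # torch.Size([8, 192])
```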
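The fusion rule in `_run` can also be applied after the fact to saved per-feature logits. The file names below are hypothetical; the sketch only assumes you stored `test_output.logits`, the label tensor, and the test mask from the three single-feature runs (TA, P, E).

```python
import torch

from tape.trainer.gnn_trainer import GnnTrainer

# Hypothetical paths: tensors saved with torch.save() after each
# single-feature training run.
logits_ta = torch.load('logits_ta.pt')
logits_p = torch.load('logits_p.pt')
logits_e = torch.load('logits_e.pt')
y_true = torch.load('y_true.pt')
test_mask = torch.load('test_mask.pt')

# Same fusion rule as `_run`: average the class logits across feature types,
# then score the ensemble on the test split.
fused = torch.stack([logits_ta, logits_p, logits_e]).mean(dim=0)
acc = GnnTrainer.compute_accuracy(logits=fused, y_true=y_true, mask=test_mask)
print(f'TAPE test accuracy: {acc:.4f}')
```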