UltraQuery release (#22)
Main UltraQuery code
migalkin authored Apr 24, 2024
1 parent c414f83 commit 561f2c7
Showing 17 changed files with 2,725 additions and 20 deletions.
191 changes: 183 additions & 8 deletions README.md
@@ -4,7 +4,8 @@

[![pytorch](https://img.shields.io/badge/PyTorch_2.1+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![pyg](https://img.shields.io/badge/PyG_2.4+-3C2179?logo=pyg&logoColor=#3C2179)](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
[![arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
[![ULTRA arxiv](http://img.shields.io/badge/arxiv-2310.04562-yellow.svg)](https://arxiv.org/abs/2310.04562)
[![UltraQuery arxiv](http://img.shields.io/badge/arxiv-2404.07198-yellow.svg)](https://arxiv.org/abs/2404.07198)
[![HuggingFace Hub](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-black)](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)
![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)

@@ -37,6 +38,7 @@ This repository is based on PyTorch 2.1 and PyTorch-Geometric 2.4.
* [Pre-train](#pretraining) ULTRA on your own mixture of graphs.
* Run [evaluation on many datasets](#run-on-many-datasets) sequentially.
* Use the pre-trained checkpoints to run inference and fine-tuning on [your own KGs](#adding-your-own-graph).
* (NEW) Execute complex logical queries on any KG with [UltraQuery](#ultraquery).

Table of contents:
* [Installation](#installation)
@@ -47,8 +49,10 @@ Table of contents:
* [Pretraining](#pretraining)
* [Datasets](#datasets)
* [Adding custom datasets](#adding-your-own-graph)
* [UltraQuery](#ultraquery)

## Updates
* **Apr 23rd, 2024**: Release of [UltraQuery](#ultraquery) for complex multi-hop logical query answering on _any_ KG (with a new checkpoint and 23 datasets).
* **Jan 15th, 2024**: Accepted at [ICLR 2024](https://openreview.net/forum?id=jVEoydFOl9)!
* **Dec 4th, 2023**: Added a new ULTRA checkpoint `ultra_50g` pre-trained on 50 graphs. Averaged over 16 larger transductive graphs, it delivers 0.389 MRR / 0.549 Hits@10 compared to 0.329 MRR / 0.479 Hits@10 of the `ultra_3g` checkpoint. The inductive performance is still as good! Use this checkpoint for inference on larger graphs.
* **Dec 4th, 2023**: Pre-trained ULTRA models (3g, 4g, 50g) are now also available on the [HuggingFace Hub](https://huggingface.co/collections/mgalkin/ultra-65699bb28369400a5827669d)!
@@ -340,17 +344,188 @@ class CustomDataset(InductiveDataset):
TSV / CSV files are supported by setting a delimiter (e.g., `delimiter = "\t"`) in the class definition.
After adding your own dataset, you can immediately run 0-shot inference or fine-tuning of any ULTRA checkpoint.
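
As a concrete example of the delimiter setting, a TSV-backed dataset might look like the hedged sketch below; the `ultra.datasets` import path and class layout follow the `CustomDataset(InductiveDataset)` pattern referenced above, but are assumptions rather than verbatim repo API.

```python
# Hypothetical TSV-backed dataset; `InductiveDataset` and its import path
# are assumed from the CustomDataset example in this README.
from ultra.datasets import InductiveDataset

class CustomTSVDataset(InductiveDataset):
    """Custom KG whose raw triple files are tab-separated."""
    delimiter = "\t"  # parse TSV files instead of the default separator
```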

## UltraQuery ##

You can now run complex logical queries on any KG with UltraQuery, an inductive query-answering approach that combines any ULTRA checkpoint with non-parametric fuzzy logic operators. Read more in the [new preprint](https://arxiv.org/abs/2404.07198).

Like ULTRA, UltraQuery transfers to any KG in a zero-shot fashion and sets several SOTA results on a variety of query answering benchmarks.
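
The non-parametric fuzzy logic operators mentioned above fit in a few lines; here is a minimal sketch of the product-logic variant (matching the `logic: product` option in the configs), assuming scores have already been mapped into [0, 1]:

```python
# Minimal sketch of non-parametric fuzzy logic operators (product logic).
# Inputs are fuzzy truth values in [0, 1], e.g., sigmoid-squashed scores.
import torch

def conjunction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * y  # product t-norm: fuzzy AND

def disjunction(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y - x * y  # probabilistic sum (product t-conorm): fuzzy OR

def negation(x: torch.Tensor) -> torch.Tensor:
    return 1.0 - x  # standard fuzzy NOT, used for negation queries
```

Because the operators have no learnable parameters, any pre-trained ULTRA checkpoint can be plugged in without retraining.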

### Checkpoint ###

Any existing ULTRA checkpoint is compatible with UltraQuery, but we also ship a newly trained `ultraquery.pth` checkpoint in the `ckpts` folder.

* A new `ultraquery.pth` checkpoint trained on complex queries from the `FB15k237LogicalQuery` dataset for 40,000 steps (config in `config/ultraquery/pretrain.yaml`). It uses the same ULTRA architecture but is tuned for the multi-source propagation needed in complex queries, so no score thresholding is required.
* You can use any existing ULTRA checkpoint (`3g` / `4g` / `50g`) for starters - just don't forget to set the `--threshold` argument to 0.8 or higher (depending on the dataset). Score thresholding is required because those models were trained on simple one-hop link prediction and suffer from the multi-source propagation issue (see Section 4.1 in the [new preprint](https://arxiv.org/abs/2404.07198) for details); a minimal sketch of this thresholding follows below.
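
A hedged sketch of that score thresholding, assuming raw link-prediction scores are first squashed into [0, 1]:

```python
# Sketch: zero out low-confidence fuzzy scores before they are propagated
# as sources in the next projection step, suppressing the spurious
# intermediate answers behind the multi-source propagation issue.
import torch

def threshold_scores(raw_scores: torch.Tensor, threshold: float = 0.8) -> torch.Tensor:
    probs = torch.sigmoid(raw_scores)  # map raw scores into [0, 1]
    return torch.where(probs >= threshold, probs, torch.zeros_like(probs))
```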

### Performance

The numbers reported in the preprint were obtained with a model trained in TorchDrug. With this PyG implementation and the `ultraquery.pth` checkpoint, we get even better performance across the board.

`EPFO` is the average performance over the 9 query types with relation projection, intersection, and union. `Neg` is the average performance over the 5 query types with negation.

<table>
<tr>
<th rowspan=2>Model</th>
<th colspan=4>Total Average (23 datasets)</th>
<th colspan=4>Transductive (3 datasets)</th>
<th colspan=4>Inductive (e) (9 datasets)</th>
<th colspan=4>Inductive (e,r) (11 datasets)</th>
</tr>
<tr>
<th>EPFO MRR</th>
<th>EPFO Hits@10</th>
<th>Neg MRR</th>
<th>Neg Hits@10</th>
<th>EPFO MRR</th>
<th>EPFO Hits@10</th>
<th>Neg MRR</th>
<th>Neg Hits@10</th>
<th>EPFO MRR</th>
<th>EPFO Hits@10</th>
<th>Neg MRR</th>
<th>Neg Hits@10</th>
<th>EPFO MRR</th>
<th>EPFO Hits@10</th>
<th>Neg MRR</th>
<th>Neg Hits@10</th>
</tr>
<tr>
<th>UltraQuery Paper</th>
<td align="center">0.301</td>
<td align="center">0.428</td>
<td align="center">0.152</td>
<td align="center">0.264</td>
<td align="center">0.335</td>
<td align="center">0.467</td>
<td align="center">0.132</td>
<td align="center">0.260</td>
<td align="center">0.321</td>
<td align="center">0.479</td>
<td align="center">0.156</td>
<td align="center">0.291</td>
<td align="center">0.275</td>
<td align="center">0.375</td>
<td align="center">0.153</td>
<td align="center">0.242</td>
</tr>
<tr>
<th>UltraQuery PyG</th>
<td align="center">0.309</td>
<td align="center">0.432</td>
<td align="center">0.178</td>
<td align="center">0.286</td>
<td align="center">0.411</td>
<td align="center">0.518</td>
<td align="center">0.240</td>
<td align="center">0.352</td>
<td align="center">0.312</td>
<td align="center">0.468</td>
<td align="center">0.139</td>
<td align="center">0.262</td>
<td align="center">0.280</td>
<td align="center">0.380</td>
<td align="center">0.193</td>
<td align="center">0.288</td>
</tr>
</table>

In particular, we reach SOTA on FB15k queries (0.764 MRR & 0.834 Hits@10 on EPFO; 0.567 MRR & 0.725 Hits@10 on negation) compared to much larger and heavier baselines (such as QTO).
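
For clarity, here is a hedged sketch of how the `EPFO` and `Neg` columns aggregate per-query-type results (the dict layout is illustrative, not the repo's API; the query-type names match the list in the Datasets section below):

```python
# Illustrative aggregation of per-query-type MRR into the table's columns.
EPFO_TYPES = ["1p", "2p", "3p", "2i", "3i", "ip", "pi", "2u-DNF", "up-DNF"]  # 9 types
NEG_TYPES = ["2in", "3in", "inp", "pin", "pni"]  # 5 types

def epfo_and_neg(mrr_by_type: dict) -> tuple:
    """Return (EPFO average, Neg average) from a {query_type: MRR} dict."""
    epfo = sum(mrr_by_type[t] for t in EPFO_TYPES) / len(EPFO_TYPES)
    neg = sum(mrr_by_type[t] for t in NEG_TYPES) / len(NEG_TYPES)
    return epfo, neg
```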

### Run Inference ###

The running format is similar to the KG completion pipeline - use `run_query.py` and `run_query_many.py` for running a single experiment on one dataset or on a sequence of datasets, respectively.
Due to the size of the datasets and query complexity, it is recommended to run inference on a GPU.

An example command for running transductive inference with UltraQuery on FB15k237 queries:

```bash
python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
```

An example command for running transductive inference with a vanilla ULTRA 4g checkpoint on FB15k237 queries with score thresholding:

```bash
python script/run_query.py -c config/ultraquery/transductive.yaml --dataset FB15k237LogicalQuery --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.8 --ultra_ckpt /path/to/ultra/ckpts/ultra_4g.pth --qe_ckpt null
```

An example command for running inductive inference with UltraQuery on `InductiveFB15k237Query:550` queries:

```bash
python script/run_query.py -c config/ultraquery/inductive.yaml --dataset InductiveFB15k237Query --version 550 --epochs 0 --bpe null --gpus [0] --bs 32 --threshold 0.0 --ultra_ckpt null --qe_ckpt /path/to/ultra/ckpts/ultraquery.pth
```

New arguments for the `_query` scripts:
* `--threshold`: set to 0.0 when using the main UltraQuery checkpoint `ultraquery.pth`, or to 0.8 and higher when using vanilla ULTRA checkpoints
* `--qe_ckpt`: path to the UltraQuery checkpoint; set to `null` if you want to run vanilla ULTRA checkpoints
* `--ultra_ckpt`: path to an original ULTRA checkpoint; set to `null` if you want to run the UltraQuery checkpoint

### Datasets ###

23 new datasets are available in `datasets_query.py`; they will be automatically downloaded upon the first launch.
All datasets include 14 standard query types (`1p`, `2p`, `3p`, `2i`, `3i`, `ip`, `pi`, `2u-DNF`, `up-DNF`, `2in`, `3in`, `inp`, `pin`, `pni`).

The standard protocol trains on 10 patterns (`1p`, `2p`, `3p`, `2i`, `3i`, `2in`, `3in`, `inp`, `pin`, `pni`), i.e., without unions and the `ip` / `pi` queries, and evaluates on all 14 patterns, including `2u`, `up`, `ip`, `pi`.
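
In code, the protocol boils down to a simple pattern split (constants shown for illustration only, not the repo's identifiers):

```python
# Train on 10 patterns (no unions, no ip/pi); evaluate on all 14.
TRAIN_PATTERNS = ["1p", "2p", "3p", "2i", "3i", "2in", "3in", "inp", "pin", "pni"]
EVAL_PATTERNS = TRAIN_PATTERNS + ["2u-DNF", "up-DNF", "ip", "pi"]
assert len(EVAL_PATTERNS) == 14
```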

<details>
<summary>Transductive query datasets (3)</summary>

All are the [BetaE](https://arxiv.org/abs/2010.11465) versions of the datasets, which include queries with negation and limit the maximum number of answers to 100:
* `FB15k237LogicalQuery`, `FB15kLogicalQuery`, `NELL995LogicalQuery`

</details>

<details>
<summary>Inductive (e) query datasets (9)</summary>

9 inductive datasets extracted from FB15k237, first proposed in [Inductive Logical Query Answering in Knowledge Graphs](https://openreview.net/forum?id=-vXEN5rIABY) (NeurIPS 2022).

`InductiveFB15k237Query` comes in 9 versions, where the number indicates the size of the inference graph relative to the training graph (as a percentage of nodes):
* `550`, `300`, `217`, `175`, `150`, `134`, `122`, `113`, `106`

In addition, we include the `InductiveFB15k237QueryExtendedEval` dataset with the same versions. These are inference-only datasets that measure the _faithfulness_ of complex query answering approaches: since validation and test graphs extend the training graph with more nodes and edges, training queries acquire more true answers reachable by simple edge traversal (no missing-link prediction required). The task is to measure how well CLQA models retrieve these new easy answers to training queries on larger unseen graphs.
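
A toy illustration of this faithfulness setup (hypothetical helper, not the dataset's actual loader): the inference graph extends the training graph, so a training query gains a new easy answer reachable by plain traversal.

```python
# Hypothetical 1p-query traversal over edge sets of (head, relation, tail).
def traverse_1p(edges, anchor, relation):
    """Answers of the one-hop query (anchor, relation, ?) by edge traversal."""
    return {t for (h, r, t) in edges if h == anchor and r == relation}

train_edges = {("a", "r", "b")}
test_edges = train_edges | {("a", "r", "c")}  # inference graph adds nodes/edges

new_easy = traverse_1p(test_edges, "a", "r") - traverse_1p(train_edges, "a", "r")
print(new_easy)  # {'c'}: the new easy answer a faithful model should retrieve
```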

</details>

<details>
<summary>Inductive (e,r) query datasets (11)</summary>

11 new inductive query datasets (WikiTopics-CLQA) that we built specifically for testing UltraQuery.
The queries were sampled from the WikiTopics splits proposed in [Double Equivariance for Inductive Link Prediction for Both New Nodes and New Relation Types](https://arxiv.org/abs/2302.01313).

`WikiTopicsQuery` comes in 11 versions:
* `art`, `award`, `edu`, `health`, `infra`, `loc`, `org`, `people`, `sci`, `sport`, `tax`

</details>

### Metrics

New metrics include `auroc`, `spearmanr`, and `mape`. Mean Rank (`mr`) is not supported for complex queries. If you ever see `nan` in one of these metrics, consider reducing the batch size: the metrics are computed with variadic functions that can be numerically unstable on large batches.
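
For intuition, here is a hedged sketch of a per-query AUROC over a variadic number of answers (not the repo's implementation); it also shows one way `nan` can arise, namely when a bucket of the variadic computation ends up empty.

```python
# Pairwise AUROC for one query: the fraction of (answer, non-answer) score
# pairs ranked correctly. Empty inputs produce nan (mean over zero pairs).
import torch

def auroc_per_query(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    correct = (pos_scores.unsqueeze(1) > neg_scores.unsqueeze(0)).float()
    return correct.mean()
```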

## Citation ##

If you find this codebase useful in your research, please cite the original paper.
If you find this codebase useful in your research, please cite the original papers.

The main ULTRA paper:

```bibtex
@inproceedings{galkin2023ultra,
title={Towards Foundation Models for Knowledge Graph Reasoning},
author={Mikhail Galkin and Xinyu Yuan and Hesham Mostafa and Jian Tang and Zhaocheng Zhu},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=jVEoydFOl9}
}
```

UltraQuery:

```bibtex
@article{galkin2024ultraquery,
title={Zero-shot Logical Query Reasoning on any Knowledge Graph},
author={Mikhail Galkin and Jincheng Zhou and Bruno Ribeiro and Jian Tang and Zhaocheng Zhu},
year={2024},
eprint={2404.07198},
archivePrefix={arXiv},
primaryClass={cs.AI}
}
```
Binary file added ckpts/ultraquery.pth
4 changes: 2 additions & 2 deletions config/inductive/inference.yaml
@@ -8,15 +8,15 @@ dataset:
model:
class: Ultra
relation_model:
class: NBFNet
class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
entity_model:
class: IndNBFNet
class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
4 changes: 2 additions & 2 deletions config/transductive/inference.yaml
@@ -7,15 +7,15 @@ dataset:
model:
class: Ultra
relation_model:
class: NBFNet
class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
entity_model:
class: IndNBFNet
class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
4 changes: 2 additions & 2 deletions config/transductive/pretrain_3g.yaml
@@ -8,15 +8,15 @@ dataset:
model:
class: Ultra
relation_model:
class: NBFNet
class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
entity_model:
class: IndNBFNet
class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
4 changes: 2 additions & 2 deletions config/transductive/pretrain_4g.yaml
@@ -8,15 +8,15 @@ dataset:
model:
class: Ultra
relation_model:
class: NBFNet
class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
entity_model:
class: IndNBFNet
class: EntityNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
53 changes: 53 additions & 0 deletions config/ultraquery/inductive.yaml
@@ -0,0 +1,53 @@
output_dir: ~/git/ULTRA/output

dataset:
class: {{ dataset }}
root: ~/git/ULTRA/query-datasets/
version: {{ version }} # specify dataset version here or when running the script

model:
class: UltraQuery
model:
class: Ultra
relation_model:
class: RelNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
entity_model:
class: QueryNBFNet
input_dim: 64
hidden_dims: [64, 64, 64, 64, 64, 64]
message_func: distmult
aggregate_func: sum
short_cut: yes
layer_norm: yes
logic: product
dropout_ratio: 0.5
threshold: {{ threshold }}
more_dropout: 0.0

task:
name: InductiveInference
strict_negative: yes
adversarial_temperature: 0.2
sample_weight: no
metric: [mrr, hits@1, hits@3, hits@10, auroc, spearmanr] # mape is supported as well

optimizer:
class: Adam
lr: 5.0e-4

train:
gpus: {{ gpus }}
batch_size: {{ bs }} # reduce if it doesn't fit on a GPU
num_epoch: {{ epochs }} # the total number of optimization steps is num_epoch * batch_per_epoch
batch_per_epoch: {{ bpe }} # number of batches to be considered as "one epoch"
log_interval: 100
fast_test: 1000 # UltraQuery is slower at inference; use this option to evaluate on a random subsample of validation data

ultra_ckpt: {{ ultra_ckpt }} # Ultra checkpoint pre-trained on simple link prediction
ultraquery_ckpt: {{ qe_ckpt }} # UltraQuery checkpoint trained on complex queries
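
The `{{ ... }}` placeholders in this config are filled from command-line flags; here is a hedged sketch of rendering such a template before parsing, assuming Jinja-style substitution (the repo's actual config loader may differ):

```python
# Minimal sketch: render the Jinja-style placeholders, then parse the YAML.
# Assumes `pip install jinja2 pyyaml`; not necessarily the repo's loader.
import yaml
from jinja2 import Template

def load_config(path, **context):
    with open(path) as f:
        rendered = Template(f.read()).render(**context)
    return yaml.safe_load(rendered)

cfg = load_config(
    "config/ultraquery/inductive.yaml",
    dataset="InductiveFB15k237Query", version=550,
    gpus=[0], bs=32, epochs=0, bpe="null", threshold=0.0,
    ultra_ckpt="null", qe_ckpt="/path/to/ultra/ckpts/ultraquery.pth",
)
print(cfg["model"]["class"])  # UltraQuery
```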
