This is the official implementation of the paper Tree Cross Attention.
Tree Cross Attention (TCA) is a module based on Cross Attention that only retrieves information from a logarithmic O(log(N)) number of tokens for performing inference. TCA organizes the data in a tree structure and performs a tree search at inference time to retrieve the relevant tokens for prediction. Leveraging TCA, we introduce ReTreever, a flexible architecture for token-efficient inference.
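The retrieval idea described above can be illustrated with a small sketch. It is not the code from this repository: the mean-pooling aggregator, the greedy dot-product descent rule, and the names build_tree, retrieve, and cross_attend are all assumptions made for illustration (the paper's own aggregation and search policy may differ). It only shows how descending a balanced binary tree and keeping the sibling at each level yields roughly log2(N) + 1 vectors to attend over instead of N.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def build_tree(tokens):
    """Build a binary tree bottom-up over a non-empty list of token vectors.

    Assumption: an internal node stores the average of its two children.
    """
    level = [{"vec": list(t), "children": None} for t in tokens]
    while len(level) > 1:
        parents = []
        for i in range(0, len(level), 2):
            pair = level[i:i + 2]
            if len(pair) == 1:  # odd node out is promoted unchanged
                parents.append(pair[0])
                continue
            vec = [(a + b) / 2 for a, b in zip(pair[0]["vec"], pair[1]["vec"])]
            parents.append({"vec": vec, "children": pair})
        level = parents
    return level[0]


def retrieve(root, query):
    """Descend from root to a leaf, keeping the sibling skipped at each level.

    Assumption: the child with the larger dot product with the query is
    followed; this greedy rule is a hypothetical stand-in for a search policy.
    """
    selected = []
    node = root
    while node["children"] is not None:
        left, right = node["children"]
        if dot(query, left["vec"]) >= dot(query, right["vec"]):
            follow, skip = left, right
        else:
            follow, skip = right, left
        selected.append(skip["vec"])  # coarse summary of the unexplored subtree
        node = follow
    selected.append(node["vec"])  # the leaf token that was reached
    return selected


def cross_attend(query, keys):
    """Softmax-weighted average of the retrieved vectors (keys double as values)."""
    scores = [dot(query, k) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * k[d] for w, k in zip(weights, keys)) / total
            for d in range(len(query))]


tokens = [[float(i), 1.0] for i in range(16)]
query = [1.0, 0.0]
retrieved = retrieve(build_tree(tokens), query)
output = cross_attend(query, retrieved)
print(len(retrieved), "vectors retrieved out of", len(tokens))
```

With 16 tokens the tree has depth 4, so the search keeps 4 sibling summaries plus 1 leaf, and the final attention runs over 5 vectors rather than 16.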
Create and activate a conda environment, then install the dependencies listed in requirements.txt:
conda create --name tca python=3.9
conda activate tca
pip install -r requirements.txt
The default hyperparameters are saved in configs/. Model weights and logs are saved in results/{setting}/{expid}. Note that the {expid} must match between training and evaluation, since evaluation loads the model weights from results/{setting}/{expid}. When training for the first time, evaluation data is generated and saved in evalsets/{setting}.
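As a rough illustration of the directory convention, the sketch below builds the two paths from a setting name and an experiment id. The helper names results_dir and evalset_dir are hypothetical, and the setting value "gp" is only an assumed example; the actual setting names used by the scripts are not stated here.

```python
from pathlib import Path


def results_dir(setting, expid, root="results"):
    """Where weights and logs for one run are expected: results/{setting}/{expid}."""
    return Path(root) / setting / expid


def evalset_dir(setting, root="evalsets"):
    """Where generated evaluation data is expected: evalsets/{setting}."""
    return Path(root) / setting


# "gp" is an assumed setting name used only for this example.
train_run = results_dir("gp", "tca")
eval_run = results_dir("gp", "tca")

# Evaluation can only find the weights if both runs used the same expid.
print(train_run == eval_run, train_run.as_posix(), evalset_dir("gp").as_posix())
```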
To evaluate Transformer + Cross Attention, add --decoder_type ca to the script's command line.
Training (copy task):
python copy_task.py --mode train --expid tca --sequence_length 256
Evaluation:
python copy_task.py --mode eval --expid tca --sequence_length 256
Training (GP regression):
python gp.py --mode train --expid tca
Evaluation:
python gp.py --mode eval --expid tca --eval_kernel rbf
python gp.py --mode eval --expid tca --eval_kernel matern
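When several evaluations share one experiment id, it can help to generate the command lines from a single place so the id cannot drift between training and evaluation. The helper below is a hypothetical convenience, not part of this repository; it only assembles argument lists from the flags shown in this README (--mode, --expid, --eval_kernel, --decoder_type) and does not run anything.

```python
def build_command(script, mode, expid, **flags):
    """Assemble an argument list such as
    ["python", "gp.py", "--mode", "eval", "--expid", "tca", ...].

    Extra keyword arguments become "--name value" pairs, in the order given.
    """
    command = ["python", script, "--mode", mode, "--expid", expid]
    for name, value in flags.items():
        command += [f"--{name}", str(value)]
    return command


expid = "tca"
commands = [build_command("gp.py", "train", expid)]
for kernel in ("rbf", "matern"):
    commands.append(build_command("gp.py", "eval", expid, eval_kernel=kernel))

for command in commands:
    print(" ".join(command))
```

Each list can be passed to subprocess.run(command, check=True) from the repository root to execute it.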
Download the files from the official CelebA Google Drive. Specifically, download list_eval_partitions.txt, identity_CelebA.txt, and img_align_celeba.zip, unzip the archive, and place the files in the datasets/celeba folder.
Alternatively, follow the instructions from the official website to download the aforementioned necessary files.
After downloading the data, preprocess it by running python data/celeba.py.
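Before running the preprocessing step, a quick check that the downloaded files are where they are expected can save a failed run. The sketch below is a hypothetical helper, not part of this repository. It checks only the two text files named above, because the name of the folder produced by unzipping img_align_celeba.zip is not stated here.

```python
from pathlib import Path

# File names as listed in this README.
REQUIRED_FILES = ["list_eval_partitions.txt", "identity_CelebA.txt"]


def missing_celeba_files(root="datasets/celeba"):
    """Return the required file names that are not present under root."""
    root = Path(root)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


missing = missing_celeba_files()
if missing:
    print("Missing from datasets/celeba:", ", ".join(missing))
else:
    print("All required text files found.")
```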
Training (CelebA):
python celeba.py --mode train --expid tca
Evaluation:
python celeba.py --mode eval --expid tca
Training (EMNIST):
When training for the first time, the EMNIST training data is downloaded automatically and saved in datasets/emnist.
python emnist.py --mode train --expid tca
Evaluation:
python emnist.py --mode eval --expid tca --class_range 0 10
python emnist.py --mode eval --expid tca --class_range 10 47
For technical details, please check the conference version of our paper.
@inproceedings{
feng2024tree,
title={Tree Cross Attention},
author={Leo Feng and Frederick Tung and Hossein Hajimirsadeghi and Yoshua Bengio and Mohamed Osama Ahmed},
booktitle={International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=Vw24wtSddM}
}
This code uses parts of the codebases of Transformer Neural Processes, Perceiver, and PyTorch.