This repo contains the original PyTorch implementation of:
- Making Pre-trained Language Models Great on Tabular Prediction (ICLR 2024 Spotlight)
The following key features are proposed in this paper:
- Relative magnitude tokenization (RMT): a distributed representation method for continuous values that enhances the LM's numerical perception capability (a brief illustrative sketch follows this list).
- Intra-feature attention (IFA): a mechanism to pre-fuse feature-wise information for reasonable tabular feature contexts and model acceleration.
- TP-BERTa: the resulting LM, pre-trained from RoBERTa with the above features for tabular prediction.
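To make the RMT idea more concrete, below is a minimal illustrative sketch, not the repository's actual implementation: a continuous value is mapped to a coarse magnitude token (one per learned bin, e.g., from the C4.5-style binner in lib/feature_encoder.py) plus a within-bin position that can scale the token embedding. The bin edges and token ids here are hypothetical placeholders.

```python
import bisect

# Hypothetical bin boundaries (in practice learned by a discretizer on training data)
BIN_EDGES = [0.0, 10.0, 50.0, 200.0]
# Hypothetical magnitude-token ids appended after the base LM vocabulary
MAGNITUDE_TOKEN_IDS = [50265, 50266, 50267, 50268, 50269]

def relative_magnitude_token(value: float):
    """Map a continuous value to (magnitude_token_id, within_bin_position).

    The token id identifies the coarse magnitude bin; the within-bin position
    (in [0, 1]) can be used to scale the token embedding so that nearby values
    receive nearby representations.
    """
    bin_idx = bisect.bisect_right(BIN_EDGES, value)   # 0 .. len(BIN_EDGES)
    token_id = MAGNITUDE_TOKEN_IDS[bin_idx]
    lo = BIN_EDGES[bin_idx - 1] if bin_idx > 0 else BIN_EDGES[0]
    hi = BIN_EDGES[bin_idx] if bin_idx < len(BIN_EDGES) else BIN_EDGES[-1]
    rel = 0.5 if hi == lo else min(max((value - lo) / (hi - lo), 0.0), 1.0)
    return token_id, rel

print(relative_magnitude_token(37.2))  # falls into the (10, 50] bin
```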
The repo structure and module functions are as follows:
├─bin ---- // Implementation of tabular models
│ ├─tpberta_modeling.py ---- // TP-BERTa base class
│ └─xxx.py ----------------- // Other non-LM DNN baselines
├─lib ---- // Utilities
│ ├─aux.py --------------- // Auxiliary Loss: Magnitude-aware Triplet Loss
│ ├─feature_encoder.py --- // Numerical Value Binner (C4.5 discretization)
│ ├─optim.py ------------- // Utilities for optimizer & trainer
│ ├─env.py --------------- // Environment Variables configs
│ ├─data.py -------------- // Dataset & Data Transformation class
│ ├─data_utils.py -------- // Data Config & Multi-task Loader class
│ └─xxx.py --------------- // Other standard utils
│
├─checkpoints --- // Pre-trained model weights & configs (RoBERTa, TP-BERTa)
├─configs --- // Model & Training configs for non-LM baselines
│ ├─default --- // default configs
│ └─tuned ----- // tuned configs (generated with hyperparameter tuning scripts)
├─scripts --- // Experiment codes
│ ├─data --- // feature_names.json and CSV file paths for pre-training & fine-tuning
│ │ ├─pretrain-bin
│ │ ├─pretrain-reg
│ │ ├─finetune-bin
│ │ ├─finetune-reg
│ │ └─finetune-mul
│ ├─examples --- // Example shell scripts for main experiments
│ ├─pretrain --- // Codes for TP-BERTa pre-training
│ ├─finetune --- // Codes for baseline fine-tuning & hyperparameter tuning
│ └─clean_feat_names.py --- // Text cleaning for table feature names (generates feature_names.json)
All necessary dependencies for TP-BERTa are included in requirement.txt. To run the packaged baselines as well, uncomment the corresponding lines in that file.
In our experiments we saved the weights and configs of RoBERTa-base in the local checkpoints/roberta-base folder (for environments without network access) and conducted pre-training with scripts/pretrain/pretrain_tpberta.py. Alternatively, you can use the online HuggingFace API by setting the argument --base_model_dir to "FacebookAI/roberta-base". A small sketch for populating the local folder is given below.
Pretraining Example
cd ~/tp-berta/scripts
python clean_feat_names.py
python pretrain/pretrain_tpberta.py --task "binclass" --batch_size 512 --max_epochs 30 --base_model_dir="FacebookAI/roberta-base"
TP-BERTa is designed for standard supervised tabular prediction and requires fine-tuning on downstream datasets. Compared with small tabular deep models (e.g., FT-Transformer), a larger training budget is preferable (in our experiments we uniformly used a maximum of 200 training epochs with an early-stopping patience of 50 epochs; code here), since lightly fine-tuning a BERT-sized model tends to underfit in our practice. A minimal early-stopping sketch follows.
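As a rough illustration of that budget, here is a minimal early-stopping loop sketch; `train_one_epoch` and `evaluate` are hypothetical stand-ins for the repo's actual fine-tuning and validation routines in scripts/finetune.

```python
# Minimal early-stopping sketch matching the budget described above
# (at most 200 epochs, stop after 50 epochs without validation improvement).
def finetune_with_early_stopping(train_one_epoch, evaluate,
                                 max_epochs: int = 200, patience: int = 50) -> float:
    best_score, stale = float("-inf"), 0
    for epoch in range(max_epochs):
        train_one_epoch(epoch)
        score = evaluate(epoch)           # higher is better (e.g., validation AUC)
        if score > best_score:
            best_score, stale = score, 0  # improvement: reset the patience counter
        else:
            stale += 1
            if stale >= patience:
                break                     # early stop
    return best_score
```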
- Upload pre-trained TP-BERTa checkpoints.
  - Download TP-BERTa checkpoints pre-trained on a single task type or on both task types.
  - Move the *.tar.gz file to the checkpoints folder (create one if it does not exist).
  - Unzip the file and run TP-BERTa according to the scripts in scripts/examples/finetune (a small extraction sketch follows this checklist).
- Sort and update experiment datasets.
  - We have acquired permission to distribute the data subset we used from the TabPertNet (currently OpenTabs) datasets.
  - Download the datasets for pre-training (202 datasets) and fine-tuning (145 datasets).
  - Unzip the *.tar.gz file to the data folder (create one if it does not exist).
- Integrate TP-BERTa into the HuggingFace🤗 community.
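For the archive steps in the checklist above, a small sketch using only the Python standard library (equivalent to a manual tar extraction; the archive file name is a placeholder) could look like:

```python
import tarfile
from pathlib import Path

# Placeholder archive name: substitute the checkpoint (or dataset) file you downloaded.
archive = Path("tp-berta-checkpoints.tar.gz")
target = Path("checkpoints")   # use Path("data") for the dataset archive
target.mkdir(exist_ok=True)

with tarfile.open(archive, "r:gz") as tar:
    tar.extractall(path=target)
```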
If you find this useful for your research, please cite the following paper:
@article{yan2024making,
title={Making Pre-trained Language Models Great on Tabular Prediction},
author={Yan, Jiahuan and Zheng, Bo and Xu, Hongxia and Zhu, Yiheng and Chen, Danny and Sun, Jimeng and Wu, Jian and Chen, Jintai},
journal={arXiv preprint arXiv:2403.01841},
year={2024}
}
Our code is influenced by the following repos: