In this paper, we propose Importance-aware Co-teaching for Offline Model-based Optimization (ICT), which consists of two components: pseudo-label-driven co-teaching and meta-learning-based sample reweighting.
The environment for ICT can be set up as follows:
conda create --name ICT --file requirements.txt
conda activate ICT
If you have never installed MuJoCo before, please also make sure that a .mujoco directory exists under your home directory and that it contains mujoco200.
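If you are unsure whether the layout is correct, the short snippet below (an illustrative check, not part of the repository) verifies that the expected directory exists:

import os

# Hedged sanity check: only verifies the ~/.mujoco/mujoco200 directory mentioned above.
mujoco_dir = os.path.expanduser("~/.mujoco/mujoco200")
print("mujoco200 found:", os.path.isdir(mujoco_dir))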
For the TF Bind 8 task, you can use the following command to pre-train several proxies:
python3 grad.py --task=TFBind8-Exact-v0 --store_path=new_proxies/ > tfbind8.txt
You can choose a different directory to store the pre-trained proxies and redirect the output wherever you like.
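For reference, the sketch below illustrates roughly what proxy pre-training looks like. It assumes the design-bench API (design_bench.make, task.x, task.y) and PyTorch; the architecture, optimizer settings, and file names are placeholders, not the actual implementation in grad.py.

import os
import design_bench
import torch
import torch.nn as nn

def pretrain_proxies(task_name="TFBind8-Exact-v0", store_path="new_proxies/",
                     num_proxies=3, epochs=200, lr=1e-3):
    # Load the offline dataset and flatten the designs into feature vectors.
    task = design_bench.make(task_name)
    x = torch.tensor(task.x, dtype=torch.float32).reshape(task.x.shape[0], -1)
    y = torch.tensor(task.y, dtype=torch.float32)

    os.makedirs(store_path, exist_ok=True)
    for i in range(num_proxies):
        # A simple MLP proxy f_i(x) -> y; each proxy gets a different random init.
        proxy = nn.Sequential(nn.Linear(x.shape[1], 256), nn.ReLU(),
                              nn.Linear(256, 256), nn.ReLU(),
                              nn.Linear(256, 1))
        opt = torch.optim.Adam(proxy.parameters(), lr=lr)
        for _ in range(epochs):          # 200 epochs, matching the note below
            opt.zero_grad()
            loss = nn.functional.mse_loss(proxy(x), y)
            loss.backward()
            opt.step()
        torch.save(proxy.state_dict(),
                   os.path.join(store_path, f"proxy_{i}_{task_name}.pt"))

if __name__ == "__main__":
    pretrain_proxies()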
For convenience, you can pre-train proxies for all seven tasks used in the paper by running:
./train_proxies.sh
Note that all proxies are pre-trained for 200 epochs.
For the TF Bind 8 task, we can run only the pseudo-label-driven co-teaching component with:
python3 ICT.py --alpha=1e-3 --if_coteach=True --if_reweight=False --task=TFBind8-Exact-v0 --Tmax=100 --interval=100 --num_coteaching=8 >> ./result/tfbind8_exp.txt
num_coteaching is the hyper-parameter that decides how many instances are selected in the co-teaching process; it is denoted as K in the paper.
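For intuition, here is a minimal sketch of the small-loss selection idea behind the pseudo-label-driven co-teaching step. It is written from the description above rather than taken from ICT.py, so the choice of the pseudo-labeling proxy, the peer update, and the optimizer settings are assumptions.

import torch
import torch.nn as nn

def coteach_step(labeler, proxy_a, proxy_b, x_aug, num_coteaching, lr=1e-3):
    """One co-teaching round on augmented designs x_aug near the current point."""
    with torch.no_grad():
        pseudo_y = labeler(x_aug)          # pseudo-labels from one proxy

    for student, teacher in ((proxy_a, proxy_b), (proxy_b, proxy_a)):
        # The teacher keeps the K = num_coteaching smallest-loss (most trustworthy)
        # pseudo-labeled samples and hands them to the student.
        with torch.no_grad():
            per_sample = ((teacher(x_aug) - pseudo_y) ** 2).reshape(-1)
            keep = torch.topk(-per_sample, k=num_coteaching).indices
        # A fresh optimizer keeps the sketch self-contained.
        opt = torch.optim.Adam(student.parameters(), lr=lr)
        opt.zero_grad()
        loss = nn.functional.mse_loss(student(x_aug[keep]), pseudo_y[keep])
        loss.backward()
        opt.step()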
The meta-learning-based sample reweighting can be run as:
python3 ICT.py --alpha=1e-3 --beta=3e-1 --if_coteach=False --if_reweight=True --task=TFBind8-Exact-v0 --Tmax=100 --interval=100 --reweight_mode=full >> ./result/tfbind8_exp.txt
or
python3 ICT.py --alpha=1e-3 --beta=3e-1 --if_coteach=True --if_reweight=True --task=TFBind8-Exact-v0 --Tmax=100 --interval=100 --reweight_mode=full --num_coteaching=128 >> ./result/tfbind8_exp.txt
Note that in the second command, the number of instances selected by the co-teaching process, num_coteaching (K), equals the number of samples we augment around the current optimization point, so the co-teaching process has no effect there.
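For intuition, the sketch below shows the general learning-to-reweight recipe that the meta-learning-based sample reweighting builds on: per-sample weights are scored by how a one-step virtual proxy update affects a meta loss on a supervision set drawn from the offline data. A single linear proxy is used for readability; this is an assumption-laden sketch, not the implementation in ICT.py.

import torch

def reweight_step(w, b, x_aug, pseudo_y, x_sup, y_sup, alpha=1e-3):
    """Score per-sample weights for (x_aug, pseudo_y) against a supervision set."""
    eps = torch.zeros(x_aug.shape[0], requires_grad=True)    # candidate weights

    # Virtual one-step update of the linear proxy on the eps-weighted pseudo-label loss.
    residual = (x_aug @ w + b).reshape(-1) - pseudo_y.reshape(-1)
    weighted_loss = (eps * residual ** 2).mean()
    grad_w, grad_b = torch.autograd.grad(weighted_loss, (w, b), create_graph=True)
    w_new, b_new = w - alpha * grad_w, b - alpha * grad_b

    # Meta loss of the virtually updated proxy on the supervision set.
    meta_residual = (x_sup @ w_new + b_new).reshape(-1) - y_sup.reshape(-1)
    meta_loss = (meta_residual ** 2).mean()

    # A sample gets a larger weight if up-weighting it would lower the meta loss.
    grad_eps, = torch.autograd.grad(meta_loss, eps)
    weights = torch.clamp(-grad_eps, min=0.0)
    return weights / (weights.sum() + 1e-8)

# Usage sketch: the proxy parameters must require gradients, e.g.
# w = torch.zeros(x_dim, 1, requires_grad=True)
# b = torch.zeros(1, requires_grad=True)
# weights = reweight_step(w, b, x_aug, pseudo_y, x_top, y_top)

The returned weights would then multiply the per-sample losses when the proxies are actually updated.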
Moreover, three choices are provided for the reweighting mode: 'top128', 'half', and 'full'. 'top128' means we use the top 128 designs in the offline dataset as the supervision signal for the meta-learning process; similarly, 'half' and 'full' use the top half of the designs or all designs in the offline dataset, respectively. Note that the results reported in the paper use only the 'full' mode.
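The hypothetical helper below (not taken from the repository) illustrates how the three modes could map to a supervision subset of the offline data:

import numpy as np

def supervision_split(x, y, reweight_mode="full"):
    # Sort designs from best to worst score and keep the slice implied by the mode.
    order = np.argsort(-y.reshape(-1))
    if reweight_mode == "top128":
        keep = order[:128]
    elif reweight_mode == "half":
        keep = order[: len(order) // 2]
    elif reweight_mode == "full":
        keep = order
    else:
        raise ValueError(f"unknown reweight_mode: {reweight_mode}")
    return x[keep], y[keep]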
Run our proposed ICT as:
python3 ICT.py --alpha=1e-3 --beta=3e-1 --if_coteach=True --if_reweight=True --task=TFBind8-Exact-v0 --Tmax=100 --interval=100 --reweight_mode=full --num_coteaching=8 >> ./result/tfbind8_exp.txt
For other tasks, you only need to change the task name and adjust the hyper-parameters according to the paper; a batch-runner sketch follows the task list below.
Superconductor: Superconductor-RandomForest-v0
Ant Morphology: AntMorphology-Exact-v0
D'Kitty Morphology: DKittyMorphology-Exact-v0
Hopper Controller: HopperController-Exact-v0
TF Bind 8: TFBind8-Exact-v0
TF Bind 10: TFBind10-Exact-v0
NAS: CIFARNAS-Exact-v0
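As a convenience, a hypothetical batch runner (not part of the repository) could loop over these task names; the per-task hyper-parameters such as --alpha, --beta, --Tmax, --interval, and --num_coteaching are omitted here and should be filled in from the paper.

import os
import subprocess

TASKS = [
    "Superconductor-RandomForest-v0",
    "AntMorphology-Exact-v0",
    "DKittyMorphology-Exact-v0",
    "HopperController-Exact-v0",
    "TFBind8-Exact-v0",
    "TFBind10-Exact-v0",
    "CIFARNAS-Exact-v0",
]

os.makedirs("result", exist_ok=True)
for task in TASKS:
    # Append each task's output to its own log file, mirroring the commands above.
    with open(f"./result/{task}_exp.txt", "a") as log:
        subprocess.run(
            ["python3", "ICT.py", f"--task={task}", "--if_coteach=True",
             "--if_reweight=True", "--reweight_mode=full"],
            stdout=log, check=True)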
@inproceedings{yuan2023importanceaware,
author = {Yuan, Ye and Chen, Can (Sam) and Liu, Zixuan and Neiswanger, Willie and Liu, Xue (Steve)},
booktitle = {Advances in Neural Information Processing Systems},
editor = {A. Oh and T. Naumann and A. Globerson and K. Saenko and M. Hardt and S. Levine},
pages = {55718--55733},
publisher = {Curran Associates, Inc.},
title = {Importance-aware Co-teaching for Offline Model-based Optimization},
url = {https://proceedings.neurips.cc/paper_files/paper/2023/file/ae8b0b5838ba510daff1198474e7b984-Paper-Conference.pdf},
volume = {36},
year = {2023}
}