This repository provides code for training on BridgeData V2.
We provide implementations for the following subset of methods described in the paper:
- Goal-conditioned BC
- Goal-conditioned BC with a diffusion policy
- Goal-conditioned IQL
- Goal-conditioned contrastive RL
The code for the language-conditioned BC method may be released soon.
The official implementations and papers for all the methods can be found here:
- IDQL (IQL + diffusion policy) [Hansen-Estruch et al.] and Diffusion Policy [Chi et al.]
- IQL [Kostrikov et al.]
- Contrastive RL [Zheng et al., Eysenbach et al.]
- RT-1 [Brohan et al.]
- ACT [Zhao et al.]
Please open a GitHub issue if you encounter problems with this code.
The raw dataset (consisting of JPEGs, PNGs, and pkl files) can be downloaded from the website. For training, the raw data must be converted into a TFRecord format that is compatible with the data loader. First, use data_processing/bridgedata_raw_to_numpy.py to convert the raw data into numpy files. Then, use data_processing/bridgedata_numpy_to_tfrecord.py to convert the numpy files into TFRecord files.
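To sanity-check the output of the conversion step without installing TensorFlow, it can help to know how TFRecord files are framed: each record is an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The stdlib-only sketch below writes and reads that framing. The helper names (`write_records`, `read_records`) are hypothetical, the sketch assumes uncompressed TFRecord files, and it does not decode the payloads themselves (the conversion script's record contents are not described here).

```python
import io
import struct
from typing import BinaryIO, Iterable, Iterator, List


def _make_crc32c_table() -> List[int]:
    """Build the lookup table for CRC-32C (reflected polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    """CRC-32C (Castagnoli), the checksum TFRecord uses; zlib.crc32 is a different CRC."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord masking: rotate the CRC right by 15 bits, then add a constant (mod 2**32)."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_records(stream: BinaryIO, payloads: Iterable[bytes]) -> None:
    """Frame each payload as: length, masked CRC of length, payload, masked CRC of payload."""
    for payload in payloads:
        header = struct.pack("<Q", len(payload))
        stream.write(header + struct.pack("<I", masked_crc(header)))
        stream.write(payload + struct.pack("<I", masked_crc(payload)))


def _read_exact(stream: BinaryIO, n: int) -> bytes:
    """Read exactly n bytes or raise ValueError."""
    chunk = stream.read(n)
    if len(chunk) != n:
        raise ValueError("truncated record")
    return chunk


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    """Yield each payload, raising ValueError on truncation or checksum mismatch."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of stream
        if len(header) != 8:
            raise ValueError("truncated record")
        (header_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if header_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        (length,) = struct.unpack("<Q", header)
        payload = _read_exact(stream, length)
        (payload_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload


if __name__ == "__main__":
    buffer = io.BytesIO()
    write_records(buffer, [b"step-0", b"step-1"])
    buffer.seek(0)
    print(list(read_records(buffer)))  # [b'step-0', b'step-1']
```

For example, `with open(path, "rb") as f: count = sum(1 for _ in read_records(f))` counts the records in a converted file and raises `ValueError` if any record is truncated or fails its checksum.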
To start training, run the command below. Replace METHOD with one of gc_bc, gc_ddpm_bc, gc_iql, or contrastive_rl_td, and replace NAME with a name for the run.
python experiments/train.py \
--config experiments/configs/train_config.py:METHOD \
--bridgedata_config experiments/configs/data_config.py:all \
--name NAME
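When launching several runs (for example, one per method), it can be convenient to build the command programmatically. The sketch below only assembles the argument list shown above and validates METHOD against the four values this README lists; the pairing of paper names (GCBC, D-GCBC, GCIQL, CRL) with METHOD values is an assumption inferred from the method list, and `build_train_command` is a hypothetical helper, not part of the repository.

```python
import shlex
from typing import Dict, List

# Assumed mapping from paper method names to the METHOD values listed in this README.
METHODS: Dict[str, str] = {
    "GCBC": "gc_bc",
    "D-GCBC": "gc_ddpm_bc",
    "GCIQL": "gc_iql",
    "CRL": "contrastive_rl_td",
}


def build_train_command(method: str, name: str) -> List[str]:
    """Return the argv for experiments/train.py; nothing is executed here."""
    if method not in METHODS.values():
        raise ValueError(f"unknown METHOD {method!r}; expected one of {sorted(METHODS.values())}")
    if not name:
        raise ValueError("NAME must be non-empty")
    return [
        "python", "experiments/train.py",
        "--config", f"experiments/configs/train_config.py:{method}",
        "--bridgedata_config", "experiments/configs/data_config.py:all",
        "--name", name,
    ]


if __name__ == "__main__":
    for paper_name, method in METHODS.items():
        cmd = build_train_command(method, name=f"{method}_run")
        print(f"{paper_name}: {shlex.join(cmd)}")
        # To actually launch a run: subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell-quoting issues if the list is later passed to `subprocess.run`.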
Training hyperparameters can be modified in experiments/configs/train_config.py, and data parameters (e.g. subsets to include/exclude) can be modified in experiments/configs/data_config.py.
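Both config flags take a `file.py:ARG` spec. Assuming these files follow the `ml_collections` config-file convention (which the `path:ARG` syntax suggests), the text after the colon is passed to the file's `get_config` function. The stdlib-only sketch below mimics that resolution, which can be handy for inspecting a config before launching a run. `load_config_spec` is a hypothetical helper; it assumes POSIX-style paths and does not reproduce library features such as command-line overrides.

```python
import importlib.util
import os
import tempfile
from typing import Any


def load_config_spec(spec: str) -> Any:
    """Load 'path/to/file.py:ARG' by calling get_config('ARG') from that file."""
    path, sep, arg = spec.partition(":")
    if not path.endswith(".py"):
        raise ValueError(f"expected a .py config file, got {path!r}")
    module_name = os.path.splitext(os.path.basename(path))[0]
    module_spec = importlib.util.spec_from_file_location(module_name, path)
    if module_spec is None or module_spec.loader is None:
        raise ImportError(f"cannot load {path!r}")
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    # With no ':ARG' suffix, get_config() is called without arguments.
    return module.get_config(arg) if sep else module.get_config()


if __name__ == "__main__":
    # Demonstrate with a throwaway config file (hypothetical contents).
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy_config.py")
        with open(path, "w") as f:
            f.write("def get_config(name=None):\n    return {'method': name}\n")
        print(load_config_spec(f"{path}:gc_bc"))  # {'method': 'gc_bc'}
```

From the repository root, `load_config_spec("experiments/configs/train_config.py:gc_bc")` would then return whatever that file's `get_config` builds, assuming the file imports cleanly outside of train.py.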
First, set up the robot hardware according to our guide. Install our WidowX robot controller stack from this repo. Then, run the command:
python experiments/eval.py \
--num_timesteps NUM_TIMESTEPS \
--video_save_path VIDEO_DIR \
--checkpoint_path CHECKPOINT_PATH \
--wandb_run_name WANDB_RUN_NAME \
--blocking
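To evaluate several checkpoints one after another, a small helper can produce one invocation of the command above per checkpoint. In this sketch, `build_eval_commands`, the one-subdirectory-per-checkpoint video layout, and the timestep value in the demo are all hypothetical choices; only the flags themselves come from the command above.

```python
import os
import shlex
from typing import List, Sequence, Tuple


def build_eval_commands(
    runs: Sequence[Tuple[str, str]],  # (checkpoint_path, wandb_run_name) pairs
    num_timesteps: int,
    video_root: str,
) -> List[List[str]]:
    """Build one experiments/eval.py invocation per checkpoint; nothing is executed."""
    if num_timesteps <= 0:
        raise ValueError("num_timesteps must be positive")
    commands = []
    for index, (checkpoint_path, wandb_run_name) in enumerate(runs):
        # Hypothetical layout: a separate video directory for each checkpoint.
        checkpoint_name = os.path.basename(checkpoint_path.rstrip("/"))
        video_dir = os.path.join(video_root, f"{index:02d}_{checkpoint_name}")
        commands.append([
            "python", "experiments/eval.py",
            "--num_timesteps", str(num_timesteps),
            "--video_save_path", video_dir,
            "--checkpoint_path", checkpoint_path,
            "--wandb_run_name", wandb_run_name,
            "--blocking",
        ])
    return commands


if __name__ == "__main__":
    runs = [("CHECKPOINT_PATH_A", "WANDB_RUN_NAME_A"), ("CHECKPOINT_PATH_B", "WANDB_RUN_NAME_B")]
    for cmd in build_eval_commands(runs, num_timesteps=60, video_root="VIDEO_DIR"):
        print(shlex.join(cmd))
        # Run sequentially on the robot: subprocess.run(cmd, check=True)
```

Keeping the evaluations sequential matches the fact that a single robot can only run one policy at a time.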
The script loads some information about the checkpoint from the WandB run specified by --wandb_run_name.
Checkpoints for GCBC, D-GCBC, GCIQL, CRL, and RT-1 are available here. Each checkpoint (except RT-1) has an associated JSON file with its configuration information. To evaluate these checkpoints with the above evaluation script, modify the references to the WandB run configuration to use the dictionary provided in the JSON file instead.
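One way to make that substitution, sketched under the assumption that the evaluation code only needs the run configuration as a plain dictionary: the hypothetical helper `load_run_config` below reads the released JSON file when one is given and falls back to the WandB public API (`wandb.Api().run(...).config`) otherwise. How eval.py actually accesses the configuration is not shown here, and the keys in the demo are placeholders, not the released JSON schema.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


def load_run_config(
    config_json_path: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Return a checkpoint's configuration dictionary.

    Prefers the JSON file shipped with a released checkpoint; otherwise falls
    back to the WandB run, which requires the wandb package and API access.
    """
    if config_json_path is not None:
        with open(config_json_path) as f:
            return json.load(f)
    if wandb_run_name is not None:
        import wandb  # only needed when no JSON file is given

        return dict(wandb.Api().run(wandb_run_name).config)
    raise ValueError("provide either config_json_path or wandb_run_name")


if __name__ == "__main__":
    # Placeholder config written to a temporary file for demonstration.
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "config.json")
        with open(path, "w") as f:
            json.dump({"agent": "gc_bc"}, f)
        print(load_run_config(config_json_path=path))  # {'agent': 'gc_bc'}
```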
An evaluation script for the RT-1 checkpoint is available in this separate repo (TODO).
We don't currently have checkpoints for ACT or LCBC available but may release them soon.
The dependencies for this codebase can be installed in a conda environment:
conda create -n jaxrl python=3.10
conda activate jaxrl
pip install -e .
pip install -r requirements.txt
For GPU:
pip install --upgrade "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For TPU:
pip install --upgrade "jax[tpu]==0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
See the JAX GitHub page for more details on installing JAX.
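After installing, a quick check can confirm which JAX version is present and which backend it picked up; on a working GPU or TPU install the backend should not be `cpu`. This sketch uses `importlib.metadata` plus the public `jax.default_backend()` and `jax.devices()` calls; `describe_jax_install` is a hypothetical helper name.

```python
import importlib.metadata
from typing import Optional

PINNED_JAX_VERSION = "0.4.13"  # version pinned in the install commands above


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a package, or None if it is missing."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


def describe_jax_install() -> str:
    """Summarize the JAX install: version, whether it matches the pin, and the backend."""
    version = installed_version("jax")
    if version is None:
        return "jax is not installed"
    try:
        import jax  # imported lazily so this check works even without jax
    except ImportError as err:
        return f"jax {version} is installed but failed to import: {err}"
    pin_note = "matches pin" if version == PINNED_JAX_VERSION else f"differs from pin {PINNED_JAX_VERSION}"
    return f"jax {version} ({pin_note}); backend={jax.default_backend()}; devices={jax.devices()}"


if __name__ == "__main__":
    print(describe_jax_install())
```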
This code is based on dibyaghosh/jaxrl_m.
If you use this code and/or BridgeData V2 in your work, please cite the paper with:
@inproceedings{walke2023bridgedata,
title={BridgeData V2: A Dataset for Robot Learning at Scale},
author={Walke, Homer and Black, Kevin and Lee, Abraham and Kim, Moo Jin and Du, Max and Zheng, Chongyi and Zhao, Tony and Hansen-Estruch, Philippe and Vuong, Quan and He, Andre and Myers, Vivek and Fang, Kuan and Finn, Chelsea and Levine, Sergey},
booktitle={Conference on Robot Learning (CoRL)},
year={2023}
}