Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling
Raunaq Bhirangi, Chenyu Wang, Venkatesh Pattabiraman, Carmel Majidi, Abhinav Gupta, Tess Hellebrekers and Lerrel Pinto
Paper: https://arxiv.org/abs/2402.10211
Website: https://hiss-csp.github.io/
HiSS is a simple technique that stacks deep state space models such as S4 and Mamba to reason over continuous sequences of sensory data across multiple temporal hierarchies. We also release CSP-Bench, a benchmark for sequence-to-sequence prediction from sensory data.
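As a rough illustration of the stacking idea, the sketch splits a sensor sequence into fixed-size chunks, runs a low-level state space model over each chunk, keeps the last output of each chunk, and runs a high-level state space model over the resulting shorter sequence. A scalar linear recurrence with fixed, arbitrary coefficients stands in for S4 or Mamba. The names linear_recurrence and hiss_forward are hypothetical; this is a toy under those assumptions, not the repository's implementation.

```python
from typing import List, Sequence


def linear_recurrence(xs: Sequence[float], a: float, b: float, c: float) -> List[float]:
    """Scalar linear state space recurrence: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t.

    A toy stand-in for a deep SSM layer such as S4 or Mamba.
    """
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def hiss_forward(xs: Sequence[float], chunk_size: int) -> List[float]:
    """Two-level hierarchical pass over a 1-D sequence (illustrative only).

    1. Split the input into consecutive, non-overlapping chunks of chunk_size.
    2. Run a low-level SSM over each chunk and keep its final output,
       giving one feature per chunk.
    3. Run a high-level SSM over the shorter sequence of chunk features,
       producing one prediction per chunk.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if len(xs) % chunk_size != 0:
        raise ValueError("sequence length must be a multiple of chunk_size")
    chunk_features = []
    for start in range(0, len(xs), chunk_size):
        chunk = xs[start:start + chunk_size]
        # Low-level model: fixed illustrative coefficients a=0.5, b=1, c=1.
        low_out = linear_recurrence(chunk, a=0.5, b=1.0, c=1.0)
        chunk_features.append(low_out[-1])
    # High-level model: fixed illustrative coefficients a=0.9, b=1, c=1.
    return linear_recurrence(chunk_features, a=0.9, b=1.0, c=1.0)


if __name__ == "__main__":
    # Eight sensor samples with chunk_size=4 give two high-level outputs.
    print(hiss_forward([1.0] * 8, chunk_size=4))
```

The output sequence is shorter than the input by a factor of chunk_size, which is how the higher level reasons at a coarser temporal scale than the raw sensor stream.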
- Clone the repository.
- Create a conda environment from the provided env.yml file:
  conda env create -f env.yml
- Install Mamba following the official instructions.
  Note: If you run into CUDA issues while installing Mamba, run
  export CUDA_HOME=$CONDA_PREFIX
  and try again. If problems persist, install both causal_conv1d and mamba-ssm from source.
- Refer to data_processing/README to download and extract the required dataset.
- Set the DATA_DIR variable in the hiss/utils/__init__.py file. This is the path to the parent directory that contains one folder per dataset.
- Process the datasets into a format compatible with training. For Marker Writing, Intrinsic Slip and Joystick Control, replace <hiss/full> with either hiss or full:
  Marker Writing: python data_processing/process_reskin_data.py -dd marker_writing_<hiss/full>_dataset
  Intrinsic Slip: python data_processing/process_reskin_data.py -dd intrinsic_slip_<hiss/full>_dataset
  Joystick Control: python data_processing/process_xela_data.py -dd joystick_control_<hiss/full>_dataset
  RoNIN: python data_processing/process_ronin_data.py
  VECtor: python data_processing/process_vector_data.py
  TotalCapture: python data_processing/process_total_capture_data.py
- Run create_dataset.py for the respective dataset to preprocess the data and resample it at the desired frequencies. RoNIN and TotalCapture have separate train and test configs:
  Marker Writing: python create_dataset.py --config-name marker_writing_config
  Intrinsic Slip: python create_dataset.py --config-name intrinsic_slip_config
  Joystick Control: python create_dataset.py --config-name joystick_control_config
  RoNIN:
    python create_dataset.py --config-name ronin_train_config
    python create_dataset.py --config-name ronin_test_config
  VECtor: python create_dataset.py --config-name vector_config
  TotalCapture:
    python create_dataset.py --config-name total_capture_train_config
    python create_dataset.py --config-name total_capture_test_config
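create_dataset.py resamples the processed sensor streams at the desired frequencies. This README does not say how; linear interpolation onto a uniform time grid is one common choice, sketched here in plain Python as an assumption. The function name resample_linear is hypothetical and not part of the repository.

```python
from bisect import bisect_right
from typing import List, Sequence


def resample_linear(times: Sequence[float], values: Sequence[float],
                    target_hz: float) -> List[float]:
    """Resample an irregularly timestamped 1-D signal onto a uniform grid.

    Samples are taken every 1/target_hz seconds from times[0] up to times[-1],
    with values obtained by linear interpolation between neighbouring samples.
    times must be strictly increasing.
    """
    if len(times) != len(values) or len(times) < 2:
        raise ValueError("need at least two (time, value) pairs of equal length")
    if any(t_next <= t_prev for t_prev, t_next in zip(times, times[1:])):
        raise ValueError("times must be strictly increasing")
    if target_hz <= 0:
        raise ValueError("target_hz must be positive")
    period = 1.0 / target_hz
    n = int((times[-1] - times[0]) / period) + 1
    out = []
    for k in range(n):
        t = times[0] + k * period
        # Index of the first sample strictly after t, clamped to a valid segment.
        i = min(max(bisect_right(times, t), 1), len(times) - 1)
        t0, t1 = times[i - 1], times[i]
        w = (t - t0) / (t1 - t0)
        out.append(values[i - 1] + w * (values[i] - values[i - 1]))
    return out


if __name__ == "__main__":
    # Three samples at 1 Hz, resampled to 2 Hz.
    print(resample_linear([0.0, 1.0, 2.0], [0.0, 10.0, 20.0], target_hz=2.0))
```

Resampling each stream this way puts sensors recorded at different rates onto a shared clock, which is a prerequisite for stacking them into fixed-size chunks.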
To train HiSS models for sequential prediction, use train.py. For each dataset, we provide a <dataset_name>_hiss_config.yaml file in the conf/ directory containing the model parameters of the best-performing HiSS model for that dataset. To train the model, run:
python train.py --config-name <dataset_name>_hiss_config
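The processing, dataset-creation and training steps can be chained per dataset. The hypothetical driver here reuses the script and config names listed in the steps above. The mapping keys (marker_writing, ronin, vector, total_capture, ...) and the assumption that they match <dataset_name> in <dataset_name>_hiss_config are guesses, so check the files in conf/ before relying on them. By default it only prints the commands (dry_run=True).

```python
import shlex
import subprocess
from typing import Dict, List, Tuple

# Processing-script arguments and create_dataset config names, copied from the steps above.
# "{variant}" stands for the <hiss/full> placeholder and is ignored where absent.
PIPELINES: Dict[str, Tuple[List[str], List[str]]] = {
    "marker_writing": (["data_processing/process_reskin_data.py", "-dd",
                        "marker_writing_{variant}_dataset"], ["marker_writing_config"]),
    "intrinsic_slip": (["data_processing/process_reskin_data.py", "-dd",
                        "intrinsic_slip_{variant}_dataset"], ["intrinsic_slip_config"]),
    "joystick_control": (["data_processing/process_xela_data.py", "-dd",
                          "joystick_control_{variant}_dataset"], ["joystick_control_config"]),
    "ronin": (["data_processing/process_ronin_data.py"],
              ["ronin_train_config", "ronin_test_config"]),
    "vector": (["data_processing/process_vector_data.py"], ["vector_config"]),
    "total_capture": (["data_processing/process_total_capture_data.py"],
                      ["total_capture_train_config", "total_capture_test_config"]),
}


def pipeline_commands(dataset: str, variant: str = "hiss") -> List[List[str]]:
    """Return the process, create_dataset and train commands for one dataset, in order."""
    if variant not in ("hiss", "full"):
        raise ValueError("variant must be 'hiss' or 'full'")
    process_args, create_configs = PIPELINES[dataset]
    cmds = [["python"] + [arg.format(variant=variant) for arg in process_args]]
    cmds += [["python", "create_dataset.py", "--config-name", cfg] for cfg in create_configs]
    # ASSUMPTION: the dict key equals <dataset_name> in <dataset_name>_hiss_config.
    cmds.append(["python", "train.py", "--config-name", f"{dataset}_hiss_config"])
    return cmds


def run_pipeline(dataset: str, variant: str = "hiss", dry_run: bool = True) -> None:
    """Print each command; execute it too when dry_run is False."""
    for cmd in pipeline_commands(dataset, variant):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # abort on the first failing step


if __name__ == "__main__":
    run_pipeline("marker_writing")
```

Running the steps strictly in order with check=True means a failed processing step never feeds stale or missing data into create_dataset.py or train.py.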
New datasets can be added by creating a corresponding Task object, following the tasks defined in vt_state/tasks, and a config file in conf/data_env/<data_env_name>.
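The actual Task interface lives in vt_state/tasks and is not documented in this README. Purely as a hypothetical illustration of what a sequence-to-sequence task definition might pin down (a name plus input and output widths, checked against a sensor/target sequence pair), consider the dataclass here. Its name and fields are invented; when adding a real dataset, mirror an existing task in vt_state/tasks instead.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class HypotheticalTask:
    """Illustrative task description; NOT the real Task class from vt_state/tasks."""
    name: str
    input_dim: int   # sensor channels per time step
    output_dim: int  # predicted quantities per time step

    def check_pair(self, inputs: Sequence[Sequence[float]],
                   targets: Sequence[Sequence[float]]) -> None:
        """Raise ValueError if a (sensor sequence, target sequence) pair has the wrong widths."""
        if any(len(step) != self.input_dim for step in inputs):
            raise ValueError(f"{self.name}: every input step needs {self.input_dim} values")
        if any(len(step) != self.output_dim for step in targets):
            raise ValueError(f"{self.name}: every target step needs {self.output_dim} values")


if __name__ == "__main__":
    task = HypotheticalTask("my_sensor_task", input_dim=3, output_dim=2)
    task.check_pair([[0.0, 0.0, 0.0]], [[1.0, 1.0]])  # widths match, so no error
```

Usage: a width mismatch, such as passing one-channel inputs to the task above, raises ValueError before any training starts.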