Our constrained neurosymbolic models outperform the Vanilla and Augmented Lagrangian methods, with guarantees on conformance to natural constraints, in three case studies:
- CARLA -- Conformance of a vehicle model to unicycle dynamics, with emphasis on no drift at rest.
- Artificial Pancreas (AP) -- Conformance of AP models to an ARMAX model that encodes the constraint that increasing insulin decreases glucose.
- PyBullet Drones -- Conformance of drone models to quadrotor dynamics, with emphasis on hover.
See our paper for the complete set of results.
The processed data and trained models are in each case study's directory in this drive folder.
The raw data is also available in each case study's directory in this drive folder.
Alternatively, instructions for collecting all the data, and for training the ARMAX constraint model used in the AP case study, are in README_for_Data_Collection.md.
Set up the conda environment:
conda create -n DL_env python=3.8
conda activate DL_env
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
pip install neupy scipy pybullet
Train the vanilla and constrained neural-network dynamics models as follows, where NAME_OF_ENV is one of {Carla, Drones, AP, Quadrupeds}:
cd scripts
bash run_Vanillas.sh
bash run_{NAME_OF_ENV}.sh
AP delta-monotonicity analysis:
bash run_delta_monotonicity.sh
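The AP constraint says that increasing insulin should not increase predicted glucose. A minimal sketch of a delta-monotonicity check, assuming the model is a callable from an (N, d) input array to predicted glucose and that one input column holds insulin; the function name, column layout, and toy linear models are hypothetical and not what run_delta_monotonicity.sh does internally.

```python
import numpy as np

def delta_monotonicity_violations(predict, inputs, insulin_idx, delta=0.1):
    """Count inputs where adding `delta` insulin raises predicted glucose.

    predict: hypothetical callable mapping an (N, d) array to (N,) glucose.
    insulin_idx: column of `inputs` assumed to hold insulin.
    """
    base = np.asarray(inputs, dtype=float)
    bumped = base.copy()
    bumped[:, insulin_idx] += delta
    increase = predict(bumped) - predict(base)
    # More insulin should never increase glucose; count the samples where it does.
    return int(np.sum(increase > 0))

# Toy usage with hypothetical linear models; columns assumed to be [glucose, insulin].
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(100, 2))
good = lambda x: 0.9 * x[:, 0] - 0.5 * x[:, 1]  # obeys the constraint
bad = lambda x: 0.9 * x[:, 0] + 0.5 * x[:, 1]   # violates it everywhere
print(delta_monotonicity_violations(good, X, insulin_idx=1))  # 0
print(delta_monotonicity_violations(bad, X, insulin_idx=1))   # 100
```

Reporting a count (rather than a pass/fail flag) makes it easy to compare how often vanilla and constrained models break the constraint.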
CARLA at-rest prediction drift analysis:
bash run_test_at_rest.sh
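Under unicycle dynamics, zero controls leave the state unchanged, so any position drift from a model rolled out with zero controls violates the constraint. A minimal sketch of such a check, assuming a state (x, y, heading), a control (speed, yaw rate), and a fixed time step; `unicycle_step` and `drift_at_rest` are hypothetical helpers, not the interface used by run_test_at_rest.sh, and the repo's actual state vector may differ.

```python
import numpy as np

def unicycle_step(state, control, dt=0.1):
    """One Euler step of unicycle dynamics (assumed state/control layout)."""
    x, y, theta = state
    v, omega = control
    return np.array([x + v * np.cos(theta) * dt,
                     y + v * np.sin(theta) * dt,
                     theta + omega * dt])

def drift_at_rest(step_fn, init_state, horizon=50):
    """Roll out step_fn with zero controls; return the max (x, y) drift."""
    state = np.asarray(init_state, dtype=float)
    start = state[:2].copy()
    max_drift = 0.0
    for _ in range(horizon):
        state = step_fn(state, np.zeros(2))
        max_drift = max(max_drift, float(np.linalg.norm(state[:2] - start)))
    return max_drift

# The exact dynamics never drift; a hypothetical learned model with a small
# constant bias in x drifts linearly with the horizon.
exact = drift_at_rest(unicycle_step, [0.0, 0.0, 0.3])
biased = drift_at_rest(
    lambda s, u: unicycle_step(s, u) + np.array([1e-3, 0.0, 0.0]),
    [0.0, 0.0, 0.3],
)
print(exact)   # 0.0
print(biased)  # about 0.05
```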
Some helpful information
- scripts/train.py is the main entry point; all the bash scripts above call it with different arguments. If you load the data and models from the drive folder above, set --only_eval.
- scripts/dataset.py contains the dataset classes that load the labelled D datasets and create the unlabelled Ω datasets. These classes also create memories in the input space (via neupy's neural gas) and compute the lower and upper bounds for each Voronoi cell in the input space, where each cell corresponds to one memory.
- scripts/model.py contains the models and trainers. The trainers contain the symbolic wrapper around the neural network. The loss function can be either an augmented Lagrangian or plain MSE.
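The pieces listed above (memories, per-cell bounds, a symbolic wrapper, and two losses) can be sketched together in NumPy. This is a minimal sketch under stated assumptions, not the repo's implementation: memories are given as a plain array (random here, standing in for neupy's neural gas), "bounds" means the elementwise min/max of some per-point quantity (e.g. constraint-model outputs) over the points in each Voronoi cell, and the wrapper clamps network outputs into those bounds. All function names are hypothetical. The loss uses the standard Powell-Hestenes-Rockafellar augmented Lagrangian for inequality constraints, which may differ from what scripts/model.py does.

```python
import numpy as np

def assign_cells(points, memories):
    """Index of the nearest memory (i.e. its Voronoi cell) for each point."""
    d2 = ((points[:, None, :] - memories[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)

def cell_bounds(points, values, memories):
    """Elementwise min/max of `values` over the points in each cell.

    Empty cells keep (-inf, +inf), i.e. no constraint.
    """
    cells = assign_cells(points, memories)
    k, m = len(memories), values.shape[1]
    lo = np.full((k, m), -np.inf)
    hi = np.full((k, m), np.inf)
    for c in range(k):
        mask = cells == c
        if mask.any():
            lo[c] = values[mask].min(axis=0)
            hi[c] = values[mask].max(axis=0)
    return lo, hi

def symbolic_wrapper(nn_out, inputs, memories, lo, hi):
    """Hard guarantee: clamp each output into its input's cell bounds."""
    cells = assign_cells(inputs, memories)
    return np.clip(nn_out, lo[cells], hi[cells])

def augmented_lagrangian_loss(pred, target, lo_b, hi_b, lam, mu):
    """MSE plus a PHR augmented-Lagrangian term for lo_b <= pred <= hi_b.

    lam holds one non-negative multiplier per inequality, stacked as
    [lower-bound constraints, upper-bound constraints]. Returns the loss
    and the updated multipliers.
    """
    g = np.concatenate([lo_b - pred, pred - hi_b])  # g <= 0 when satisfied
    shifted = np.maximum(0.0, lam + mu * g)
    penalty = ((shifted ** 2 - lam ** 2) / (2.0 * mu)).sum()
    mse = np.mean((pred - target) ** 2)
    return mse + penalty, shifted  # shifted is the standard multiplier update

# Toy usage with stand-in data.
rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
Y = X.sum(axis=1, keepdims=True)              # stand-in constraint-model outputs
memories = rng.uniform(-1.0, 1.0, size=(8, 2))  # stand-in neural-gas memories
lo, hi = cell_bounds(X, Y, memories)
raw = Y + rng.normal(0.0, 0.5, size=Y.shape)  # stand-in raw network outputs
safe = symbolic_wrapper(raw, X, memories, lo, hi)
cells = assign_cells(X, memories)
lam = np.zeros((2 * len(X), 1))
loss, lam = augmented_lagrangian_loss(raw, Y, lo[cells], hi[cells], lam, mu=10.0)
```

The contrast mirrors the comparison in the case studies: the wrapper enforces the bounds by construction, while the augmented Lagrangian only penalizes violations during training.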