This repo provides a ResNet example for CIFAR-10 using Google's JAX. I aim to provide simple baseline code for deep learning researchers who want to get started with JAX quickly. For those who are not familiar with JAX, it is Autograd + XLA: NumPy-style automatic differentiation combined with the XLA compiler for fast execution on GPUs and TPUs.
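As a minimal illustration of "Autograd + XLA" (the toy quadratic loss below is not taken from the repo), `jax.grad` turns a Python function into one that computes its gradient, and `jax.jit` compiles that gradient function with XLA:

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # A toy quadratic loss: sum of squared outputs of a linear map.
    return jnp.sum((x @ w) ** 2)


# Autograd part: jax.grad differentiates loss w.r.t. its first argument (w).
grad_loss = jax.grad(loss)

# XLA part: jax.jit traces the gradient function once and compiles it.
fast_grad = jax.jit(grad_loss)

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)
g = fast_grad(w, x)  # analytically 2 * x.T @ (x @ w) = [2.0, 4.0]
print(g)
```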
I built on DeepMind's Haiku (neural network modules) and Optax (optimizers) for the high-level neural network API, and used PyTorch and Torchvision for the data loading pipeline. My ResNet implementation is based on this repo.
Updates:
- Support for mixed precision training using JMP.
- Support for multi-GPU training: `train_multigpu.py`.
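A common JAX pattern for data-parallel multi-GPU training is `jax.pmap` with gradient averaging via `jax.lax.pmean`. The sketch below shows that pattern on a toy linear model; it is an assumption about the general approach, not a copy of `train_multigpu.py`. It also runs on a single CPU device, where `n_devices` is 1.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


def loss_fn(w, x, y):
    # Toy least-squares loss standing in for the ResNet loss.
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(w, x, y):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...then gradients are averaged across devices (an all-reduce), so every
    # replica applies the identical update and the copies stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return w - 0.1 * grads


# Replicate the parameters: a leading axis of size n_devices, one copy each.
w = jnp.zeros((3, 1))
w_rep = jnp.broadcast_to(w, (n_devices,) + w.shape)

# Shard a global batch of 4 * n_devices examples into (n_devices, 4, ...).
per_device = 4
x = jnp.ones((n_devices * per_device, 3)).reshape(n_devices, per_device, 3)
y = jnp.ones((n_devices * per_device, 1)).reshape(n_devices, per_device, 1)

w_rep = train_step(w_rep, x, y)
```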
Dependencies:
- JAX
- Haiku
- Optax
- dm-tree
- PyTorch
- Torchvision
Usage:
- Training: `python train.py`
- Mixed precision training: `python train_mp.py`
CIFAR-10 test accuracy:

| Model | Params | Test Acc |
|---|---|---|
| ResNet20 | 0.27 M | 91.5 % |
| ResNet32 | 0.46 M | 92.5 % |
| ResNet44 | 0.66 M | 93.1 % |
| ResNet56 | 0.85 M | 93.2 % |
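The model sizes in the table appear to be counts of trainable parameters. One way to sanity-check them is to count layer by layer. The hypothetical helper `cifar_resnet_params` below assumes the CIFAR ResNet layout of He et al. (depth 6n + 2, stage widths 16/32/64), bias-free 3x3 convolutions, BatchNorm scale and offset counted (running statistics not counted), parameter-free identity shortcuts, and a final linear layer with bias. Under these assumptions it gives 269,722, 464,154, 658,586 and 853,018 parameters, which round to 0.27 M, 0.46 M, 0.66 M and 0.85 M.

```python
def cifar_resnet_params(depth, num_classes=10, widths=(16, 32, 64)):
    """Count trainable parameters of a CIFAR-style ResNet of the given depth.

    Assumptions: 3x3 convolutions without bias, BatchNorm with a learnable
    scale and offset (running statistics are not counted), parameter-free
    identity shortcuts, and a final linear classifier with bias.
    """
    assert (depth - 2) % 6 == 0, "depth must be 6n + 2"
    n = (depth - 2) // 6  # residual blocks per stage

    def conv3x3(c_in, c_out):
        return 3 * 3 * c_in * c_out

    def bn(c):
        return 2 * c  # scale (gamma) and offset (beta)

    total = conv3x3(3, widths[0]) + bn(widths[0])  # stem conv + BN
    c_in = widths[0]
    for c_out in widths:
        for _ in range(n):
            # A basic block has two 3x3 conv + BN layers.
            total += conv3x3(c_in, c_out) + bn(c_out)
            total += conv3x3(c_out, c_out) + bn(c_out)
            c_in = c_out
    total += widths[-1] * num_classes + num_classes  # linear classifier
    return total


for depth in (20, 32, 44, 56):
    print(f"ResNet{depth}: {cifar_resnet_params(depth) / 1e6:.2f} M")
```

Almost all parameters sit in the last stage, because 3x3 convolution cost grows with the product of input and output channels.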