Unofficial implementation of Denoising Diffusion Probabilistic Models (DDPM) in JAX and Flax.
Sampling with Denoising Diffusion Implicit Models (DDIM) is also supported.
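The sketch below shows what the two method names refer to. It covers the DDPM forward-noising process with its simplified noise-prediction loss, and the deterministic DDIM update (η = 0) that samples in fewer steps. It is a minimal illustration in plain JAX, not this repository's code. The linear β schedule with T = 1000 and the helper names `q_sample`, `ddpm_loss`, `ddim_step` and `ddim_sample` are assumptions. `eps_fn(x_t, t)` is a hypothetical stand-in for the Flax noise-prediction network.

```python
import jax
import jax.numpy as jnp

# Assumed settings (hypothetical): linear beta schedule with T = 1000 steps.
# The repository's configs may use a different schedule or step count.
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # cumulative product of alphas up to step t


def q_sample(x0, t, noise):
    """Forward process: sample x_t ~ q(x_t | x_0) in closed form.

    x0 has shape (B, H, W, C) and t holds one integer timestep per example.
    """
    a_bar = alpha_bars[t][:, None, None, None]  # broadcast over H, W, C
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise


def ddpm_loss(eps_fn, x0, key):
    """Simplified DDPM objective: MSE between the true and predicted noise."""
    t_key, n_key = jax.random.split(key)
    t = jax.random.randint(t_key, (x0.shape[0],), 0, T)
    noise = jax.random.normal(n_key, x0.shape)
    x_t = q_sample(x0, t, noise)
    return jnp.mean((eps_fn(x_t, t) - noise) ** 2)


def ddim_step(eps_fn, x_t, t, t_prev):
    """Deterministic DDIM update (eta = 0) from step t to earlier step t_prev.

    t and t_prev are Python ints; t_prev = -1 means "step to clean data".
    """
    eps = eps_fn(x_t, jnp.full((x_t.shape[0],), t))
    a_t = alpha_bars[t]
    a_prev = alpha_bars[t_prev] if t_prev >= 0 else 1.0
    # Estimate x_0 from the predicted noise, then re-noise it to level t_prev.
    x0_pred = (x_t - jnp.sqrt(1.0 - a_t) * eps) / jnp.sqrt(a_t)
    return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1.0 - a_prev) * eps


def ddim_sample(eps_fn, key, shape, num_steps=50):
    """Run DDIM from pure noise over an evenly spaced subset of timesteps."""
    timesteps = list(range(T - 1, -1, -(T // num_steps)))
    x = jax.random.normal(key, shape)
    for i, t in enumerate(timesteps):
        t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
        x = ddim_step(eps_fn, x, t, t_prev)
    return x
```

With a trained network in place of `eps_fn`, a call such as `ddim_sample(eps_fn, key, (16, 28, 28, 1))` would draw 16 MNIST-sized samples using 50 network evaluations instead of the 1000 a full DDPM reverse chain needs.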
[Figure: MNIST, real vs. generated samples]
The model has 5.46M parameters and was trained on Colab (T4 GPU) for 100K steps with a batch size of 128, which took 8.5 hours.
The full hyperparameters are in `configs/mnist.py`.
[Figure: Fashion-MNIST, real vs. generated samples]
The model has 9.70M parameters and was trained on Kaggle (TPU v3-8) for 40K steps with a batch size of 128, which took 2.5 hours.
The full hyperparameters are in `configs/fashion_mnist.py`.
[Figure: CelebA (64 × 64), real vs. generated samples]
Due to compute constraints, the model is trained only at 64 × 64 resolution.
The model has 72.70M parameters and was trained on Kaggle (P100 GPU) for 60K steps with a batch size of 64, which took 22 hours.
The full hyperparameters are in `configs/celeb_a64.py`.
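The training figures above can be turned into average throughput to make the three runs easier to compare. The sketch below uses only the step counts, batch sizes and wall-clock times stated in this README. The function name `throughput` is hypothetical. The averages include any evaluation, sampling or checkpointing time inside each run, and the runs used different hardware and model sizes. Treat them as rough wall-clock rates, not hardware benchmarks.

```python
# Training runs as reported in this README: (name, steps, batch size, hours).
RUNS = [
    ("MNIST (Colab T4)", 100_000, 128, 8.5),
    ("Fashion-MNIST (Kaggle TPU v3-8)", 40_000, 128, 2.5),
    ("CelebA 64x64 (Kaggle P100)", 60_000, 64, 22.0),
]


def throughput(steps, batch_size, hours):
    """Return (steps per second, images per second) averaged over a run."""
    seconds = hours * 3600
    return steps / seconds, steps * batch_size / seconds


for name, steps, batch, hours in RUNS:
    sps, ips = throughput(steps, batch, hours)
    # Total images processed, average steps/s and average images/s.
    print(f"{name}: {steps * batch / 1e6:.2f}M images seen, "
          f"{sps:.2f} steps/s, {ips:.0f} images/s")
```

This works out to roughly 418 images/s for MNIST, 569 images/s for Fashion-MNIST and 48 images/s for CelebA. The much lower CelebA rate is consistent with its larger model and higher resolution.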