
PyTorch implementation of the Recurrent Interface Network (RIN), based on the original TensorFlow implementation from the pix2seq repository.


leon-w/rin-pytorch


rin-pytorch

PyTorch implementation of the Recurrent Interface Network (RIN). The codebase is largely a translation of the original authors' TensorFlow code (google-research/pix2seq) and should behave almost identically; it is even possible to transfer weights between the two. Some simplifications were made by removing unused code and options. In its current state, the code is optimized for class-guided image generation, but it should be easy to adapt to other tasks.
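This README does not say how class guidance is applied. One common approach in diffusion models is classifier-free guidance: the denoiser is evaluated with and without the class label, and the two predictions are mixed. The sketch below shows only that mixing rule on plain Python lists. The helper name `guide` and the `scale` convention are assumptions for illustration, and whether this codebase uses exactly this form is not confirmed here.

```python
def guide(pred_cond, pred_uncond, scale):
    """Classifier-free guidance mixing (hypothetical helper, not part of rin_pytorch).

    scale = 0 returns the unconditional prediction, scale = 1 the conditional one,
    and scale > 1 pushes further in the direction of the class condition.
    """
    return [u + scale * (c - u) for c, u in zip(pred_cond, pred_uncond)]


# Example: two-element predictions with a guidance scale of 2.
print(guide([1.0, 2.0], [0.5, 1.0], scale=2.0))  # [1.5, 3.0]
```

Other conventions write the same idea as `(1 + w) * cond - w * uncond`, which corresponds to `scale = 1 + w` here.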

The model is implemented with Keras Core using the PyTorch backend. Since the original TensorFlow code also used the Keras API, this made the translation relatively straightforward.
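Keras layers usually defer creating their weights until the first call, when the input shape is known. The comment "populate lazy model with weights" in the usage snippet suggests this is why `pass_dummy_data` is called before the model is wrapped and trained. The pure-Python sketch below illustrates the idea with a hypothetical `LazyDense` layer; it is not Keras Core's actual implementation.

```python
import random


class LazyDense:
    """Toy dense layer (hypothetical) whose weights are created on first use."""

    def __init__(self, units, seed=0):
        self.units = units
        self.seed = seed
        self.kernel = None  # shape unknown until the first input is seen

    def build(self, in_features):
        # Create an in_features x units weight matrix now that the input size is known.
        rng = random.Random(self.seed)
        self.kernel = [
            [rng.uniform(-0.1, 0.1) for _ in range(self.units)]
            for _ in range(in_features)
        ]

    def __call__(self, x):
        # x is a single example given as a list of floats.
        if self.kernel is None:
            self.build(len(x))
        return [
            sum(x[i] * self.kernel[i][j] for i in range(len(x)))
            for j in range(self.units)
        ]


layer = LazyDense(units=3)
print(layer.kernel is None)  # True: no weights yet
layer([0.0, 0.0])            # a "dummy" call triggers build()
print(len(layer.kernel), len(layer.kernel[0]))  # 2 3
```

Until that first call there are no weights to optimize, save, or copy into an EMA model.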

The training logic is adapted from lucidrains/recurrent-interface-network-pytorch.

Usage

import torchvision  # e.g. for building the dataset via torchvision.datasets

from rin_pytorch import Rin, RinDiffusionModel, Trainer

# Hyperparameter dicts; see train_cifar10.py for an example of their contents.
rin_config = ...
diffusion_config = ...
trainer_config = ...

rin = Rin(**rin_config).cuda()
rin.pass_dummy_data(num_classes=10)  # populate lazy model with weights (CIFAR-10 has 10 classes)
diffusion_model = RinDiffusionModel(rin=rin, **diffusion_config)

# also create an EMA model with the same architecture
rin_ema = Rin(**rin_config).cuda()
rin_ema.pass_dummy_data(num_classes=10)
ema_diffusion_model = RinDiffusionModel(rin=rin_ema, **diffusion_config)

dataset = ...

trainer = Trainer(
    diffusion_model,
    ema_diffusion_model,
    dataset,
    **trainer_config,
)

trainer.train()
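The snippet builds a second `Rin` instance to hold an exponential moving average (EMA) of the weights; keeping it updated is presumably handled by `Trainer`. The update rule itself is simple: after each optimizer step, every EMA parameter moves a small fraction toward the current parameter. The sketch below applies that rule to a dict of scalar parameters. The helper name `ema_update` and the decay value are assumptions; in PyTorch the same rule would be applied per tensor without tracking gradients.

```python
def ema_update(ema_params, params, decay):
    """Hypothetical helper: in-place EMA update, ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


ema = {"w": 0.0}     # EMA copy starts from some initial value
params = {"w": 1.0}  # current parameter, held fixed here for illustration
for _ in range(10):
    ema_update(ema, params, decay=0.9)
print(round(ema["w"], 4))  # 1 - 0.9**10, about 0.6513
```

A decay close to 1 (for example 0.999 or higher) makes the EMA weights change slowly, which smooths out step-to-step noise in the trained weights.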

Refer to train_cifar10.py for an example of how to train a model on CIFAR-10. It should make clear how to adapt the code to other datasets and configurations.

Examples

The following samples were generated using the same hyperparameters as the original authors used for CIFAR-10.

Samples from a PyTorch model trained with this codebase (DDIM, 100 steps):

[Image: example_pytorch]

Samples from a TensorFlow model trained with the pix2seq codebase (DDIM, 100 steps):

[Image: example_tensorflow]
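Both sets of samples use DDIM with 100 steps. To show what that sampler does, the sketch below implements the generic deterministic DDIM update (eta = 0) on scalars. The noise schedule `alpha_bar` (a clamped cosine schedule), the helper names, and the oracle noise predictor that stands in for the trained model are all assumptions for illustration; they are not taken from this repository or from pix2seq.

```python
import math


def alpha_bar(t):
    """Hypothetical cosine schedule: cumulative signal level at time t in [0, 1]."""
    # Clamp away from 0 so that dividing by sqrt(alpha_bar) stays numerically safe.
    return max(math.cos(0.5 * math.pi * t) ** 2, 1e-5)


def ddim_step(x_t, eps_pred, ab_t, ab_prev):
    """One deterministic DDIM update (eta = 0)."""
    # Estimate the clean sample from the noisy one and the predicted noise.
    x0_pred = (x_t - math.sqrt(1.0 - ab_t) * eps_pred) / math.sqrt(ab_t)
    # Move to the less noisy level, reusing the same noise estimate.
    return math.sqrt(ab_prev) * x0_pred + math.sqrt(1.0 - ab_prev) * eps_pred


def ddim_sample(x_T, predict_eps, num_steps=100):
    """Run num_steps DDIM updates from t = 1 down to t = 0."""
    x = x_T
    for i in range(num_steps):
        t = 1.0 - i / num_steps
        t_prev = 1.0 - (i + 1) / num_steps
        eps = predict_eps(x, t)
        x = ddim_step(x, eps, alpha_bar(t), alpha_bar(t_prev))
    return x


# Sanity check with an oracle that knows the true clean value x0 = 0.7.
x0, noise = 0.7, -1.3
x_T = math.sqrt(alpha_bar(1.0)) * x0 + math.sqrt(1.0 - alpha_bar(1.0)) * noise


def oracle_eps(x, t):
    return (x - math.sqrt(alpha_bar(t)) * x0) / math.sqrt(1.0 - alpha_bar(t))


print(round(ddim_sample(x_T, oracle_eps, num_steps=100), 6))  # 0.7
```

With a perfect noise predictor every step stays on the same trajectory, so the sampler returns `x0` exactly; with a learned model, more steps generally trade compute for sample quality.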
