Official implementation of 'FMix: Enhancing Mixed Sample Data Augmentation'


FMix

This repository contains the official implementation of the paper 'FMix: Enhancing Mixed Sample Data Augmentation'.


arXiv | Papers With Code | About | Experiments | Implementations | Pre-trained Models

Dive in with our example notebook in Colab!

About

FMix is a variant of MixUp, CutMix, etc., introduced in our paper 'FMix: Enhancing Mixed Sample Data Augmentation'. It mixes training examples using binary masks obtained by thresholding low-frequency images sampled from Fourier space. Take a look at our example notebook in Colab, which shows how you can generate masks in two dimensions and in three!
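The NumPy sketch below illustrates the idea of sampling a mask from Fourier space: draw random complex coefficients, attenuate high frequencies with a 1/f^decay_power scaling (3.0 is an assumed default here), invert the FFT to get a smooth grey-scale image, and then set the top `lam` fraction of entries to 1. It is a simplified, self-contained illustration, not the code in fmix.py; the function name `fourier_mask` and its parameters are hypothetical. Because it uses an N-dimensional FFT, the same function produces 2D and 3D masks.

```python
import numpy as np


def fourier_mask(shape, lam, decay_power=3.0, rng=None):
    """Hypothetical sketch of an FMix-style binary mask (not the fmix.py API).

    shape: spatial shape of the mask, e.g. (32, 32) or (8, 8, 8).
    lam:   fraction of entries that should be set to 1.
    """
    rng = np.random.default_rng() if rng is None else rng
    shape = tuple(shape)

    # Frequency magnitude at every point of the real-FFT grid
    # (the last axis only stores the non-negative frequencies).
    axis_freqs = [np.fft.fftfreq(n) for n in shape[:-1]] + [np.fft.rfftfreq(shape[-1])]
    grids = np.meshgrid(*axis_freqs, indexing="ij")
    freqs = np.sqrt(sum(g ** 2 for g in grids))

    # Low-pass 1/f^decay_power scaling; clamp the zero (DC) frequency
    # so we never divide by zero.
    scale = 1.0 / np.maximum(freqs, 1.0 / max(shape)) ** decay_power

    # Random complex spectrum, attenuated at high frequencies.
    spectrum = scale * (rng.standard_normal(freqs.shape) + 1j * rng.standard_normal(freqs.shape))

    # Back to the spatial domain: a smooth, real-valued grey-scale image.
    grey = np.fft.irfftn(spectrum, s=shape)

    # Binarise: the `lam` fraction of highest-valued entries become 1.
    num_ones = int(round(lam * grey.size))
    mask = np.zeros(grey.size, dtype=np.float32)
    mask[np.argsort(grey, axis=None)[::-1][:num_ones]] = 1.0
    return mask.reshape(shape)
```

For example, `fourier_mask((32, 32), lam=0.3)` returns a (32, 32) array containing 307 ones, and `fourier_mask((8, 8, 8), lam=0.5)` returns a 3D mask. In the paper, `lam` is itself drawn from a Beta(alpha, alpha) distribution, e.g. `rng.beta(1.0, 1.0)`.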

Experiments

Core Experiments

Shell scripts for our core experiments can be found in the experiments folder. For example,

bash cifar_experiment cifar10 resnet fmix ./data

will train a PreAct-ResNet18 on CIFAR-10 with FMix. More information can be found at the start of each of the shell files.

Additional Experiments

All additional classification experiments can be run via trainer.py.

Analyses

For Grad-CAM, take a look at the Grad-CAM notebook in Colab.

For the other analyses, have a look in the analysis folder.

Implementations

The core implementation of FMix uses NumPy and can be found in fmix.py. We provide bindings for it in PyTorch (with Torchbearer or PyTorch-Lightning) and TensorFlow.
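Whatever the framework, a mixed-sample augmentation step commonly pairs each example with a shuffled partner from the same batch and weights the loss by the mixing proportion. The NumPy sketch below shows that general pattern under our own assumptions; `mix_batch` and `mixed_cross_entropy` are hypothetical helpers, not functions from fmix.py or the bindings, and the example uses a hand-made mask so it stands alone.

```python
import numpy as np


def mix_batch(x, mask, rng=None):
    """Hypothetical helper: mix each sample with a random partner in the batch.

    x:    array of shape (batch, channels, height, width)
    mask: binary array of shape (height, width); 1 keeps the original pixel.
    """
    rng = np.random.default_rng() if rng is None else rng
    perm = rng.permutation(len(x))  # random partner index for every sample
    # The (height, width) mask broadcasts over the batch and channel axes.
    mixed = mask * x + (1.0 - mask) * x[perm]
    return mixed, perm


def mixed_cross_entropy(logits, y, perm, lam):
    """Hypothetical helper: lam * CE(y) + (1 - lam) * CE(y[perm]), batch mean.

    logits: (batch, classes) scores; y: (batch,) integer labels;
    lam:    fraction of each mixed input taken from the original sample.
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    idx = np.arange(len(y))
    ce_original = -log_probs[idx, y]
    ce_partner = -log_probs[idx, y[perm]]
    return float(np.mean(lam * ce_original + (1.0 - lam) * ce_partner))
```

In this sketch, `lam` would be set to `mask.mean()`, the actual fraction of each mixed input that comes from the original example.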

Torchbearer

The FMix callback in torchbearer_implementation.py can be added directly to your torchbearer code:

from torchbearer import Trial
from implementations.torchbearer_implementation import FMix

# model and optimiser are your own PyTorch model and optimiser
fmix = FMix()
trial = Trial(model, optimiser, fmix.loss(), callbacks=[fmix])

See an example in test_torchbearer.py.

PyTorch-Lightning

For PyTorch-Lightning, we provide a class, FMix, in lightning.py that can be used in your LightningModule:

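import pytorch_lightning as pl
from implementations.lightning import FMix

class CoolSystem(pl.LightningModule):
    def __init__(self):
        super().__init__()
        ...

        # FMix mixes each batch and provides the matching loss
        self.fmix = FMix()

    def training_step(self, batch, batch_nb):
        x, y = batch
        x = self.fmix(x)  # apply FMix to the input batch

        x = self.forward(x)

        loss = self.fmix.loss(x, y)  # loss for the mixed batch
        return {'loss': loss}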

See an example in test_lightning.py.

TensorFlow

For TensorFlow, we provide a class, FMix, in tensorflow_implementation.py that can be used in your TensorFlow code:

import tensorflow as tf
from implementations.tensorflow_implementation import FMix

fmix = FMix()

def loss(model, x, y, training=True):
    x = fmix(x)  # apply FMix to the input batch
    y_ = model(x, training=training)
    return tf.reduce_mean(fmix.loss(y_, y))

See an example in test_tensorflow.py.

Pre-trained Models

We provide pre-trained models via torch.hub (more coming soon). To use them, run

import torch

ARCHITECTURE = 'pyramidnet_cifar10_fmix'  # or any other name from the tables below
model = torch.hub.load('ecs-vlc/FMix:master', ARCHITECTURE, pretrained=True)

where ARCHITECTURE is one of the following:

CIFAR-10

PreAct-ResNet-18

| Configuration | ARCHITECTURE | Accuracy (%) |
|---|---|---|
| Baseline | 'preact_resnet18_cifar10_baseline' | -------- |
| + MixUp | 'preact_resnet18_cifar10_mixup' | -------- |
| + FMix | 'preact_resnet18_cifar10_fmix' | -------- |
| + MixUp + FMix | 'preact_resnet18_cifar10_fmixplusmixup' | -------- |

PyramidNet-200

| Configuration | ARCHITECTURE | Accuracy (%) |
|---|---|---|
| Baseline | 'pyramidnet_cifar10_baseline' | 98.31 |
| + MixUp | 'pyramidnet_cifar10_mixup' | 97.92 |
| + FMix | 'pyramidnet_cifar10_fmix' | 98.64 |

ImageNet

ResNet-101

| Configuration | ARCHITECTURE | Top-1 Accuracy (%) |
|---|---|---|
| Baseline | 'renset101_imagenet_baseline' | 76.51 |
| + MixUp | 'renset101_imagenet_mixup' | 76.27 |
| + FMix | 'renset101_imagenet_fmix' | 76.72 |