SwAV improvements #899

2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -6,10 +6,8 @@
from torchmetrics import Accuracy

from pl_bolts.models.self_supervised import SSLEvaluator
from pl_bolts.utils.stability import under_review


@under_review()
class SSLFineTuner(LightningModule):
"""Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP
with 1024 units.
2 changes: 2 additions & 0 deletions pl_bolts/models/self_supervised/swav/__init__.py
@@ -1,3 +1,4 @@
from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
@@ -13,4 +14,5 @@
"SwAVEvalDataTransform",
"SwAVFinetuneTransform",
"SwAVTrainDataTransform",
"SWAVLoss",
]
132 changes: 132 additions & 0 deletions pl_bolts/models/self_supervised/swav/loss.py
@@ -0,0 +1,132 @@
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist


class SWAVLoss(nn.Module):
def __init__(
self,
temperature: float,
crops_for_assign: tuple,
nmb_crops: tuple,
sinkhorn_iterations: int,
epsilon: float,
gpus: int,
num_nodes: int,
):
"""Implementation for SWAV loss function.

Args:
temperature: loss temperature
crops_for_assign: list of crop ids for computing assignment
nmb_crops: number of global and local crops, ex: [2, 6]
sinkhorn_iterations: iterations for sinkhorn normalization
epsilon: regularization parameter used in the Sinkhorn-Knopp assignment computation
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
num_nodes: number of nodes to train on
"""
super().__init__()
self.temperature = temperature
self.crops_for_assign = crops_for_assign
self.softmax = nn.Softmax(dim=1)
self.sinkhorn_iterations = sinkhorn_iterations
self.epsilon = epsilon
self.nmb_crops = nmb_crops
self.gpus = gpus
self.num_nodes = num_nodes
if self.gpus * self.num_nodes > 1:
self.assignment_fn = self.distributed_sinkhorn
else:
self.assignment_fn = self.sinkhorn

def forward(
self,
output: torch.Tensor,
embedding: torch.Tensor,
prototype_weights: torch.Tensor,
batch_size: int,
queue: Optional[torch.Tensor] = None,
use_queue: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]:
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[batch_size * crop_id : batch_size * (crop_id + 1)]

# Time to use the queue
if queue is not None:
if use_queue or not torch.all(queue[i, -1, :] == 0):
use_queue = True
out = torch.cat((torch.mm(queue[i], prototype_weights.t()), out))
# fill the queue
queue[i, batch_size:] = queue[i, :-batch_size].clone()
queue[i, :batch_size] = embedding[crop_id * batch_size : (crop_id + 1) * batch_size]

# get assignments
q = torch.exp(out / self.epsilon).t()
q = self.assignment_fn(q, self.sinkhorn_iterations)[-batch_size:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
p = self.softmax(output[batch_size * v : batch_size * (v + 1)] / self.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign)
return loss, queue, use_queue

def sinkhorn(self, Q, nmb_iters):
"""Implementation of Sinkhorn clustering."""
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q

K, B = Q.shape

if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B

for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)

Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

def distributed_sinkhorn(self, Q, nmb_iters):
"""Implementation of Distributed Sinkhorn."""
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)

for _ in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
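
For review purposes, here is a minimal sketch of exercising the new SWAVLoss module in isolation. Every size and hyperparameter value below is an illustrative assumption rather than a training default, and the random, L2-normalized features merely stand in for the projector output and prototype weights that shared_step passes in practice.

import torch
import torch.nn.functional as F

from pl_bolts.models.self_supervised.swav import SWAVLoss

# Illustrative sizes only (not the training defaults): 2 global + 4 local crops,
# 8 images per device, 128-dim projections, 16 prototypes.
batch_size, feat_dim, num_prototypes = 8, 128, 16
nmb_crops = (2, 4)
total_crops = sum(nmb_crops)

criterion = SWAVLoss(
    temperature=0.1,
    crops_for_assign=(0, 1),
    nmb_crops=nmb_crops,
    sinkhorn_iterations=3,
    epsilon=0.05,
    gpus=0,  # CPU here, so the plain (non-distributed) sinkhorn is selected
    num_nodes=1,
)

# L2-normalized dummy features and prototypes, mimicking the projector output.
embedding = F.normalize(torch.randn(total_crops * batch_size, feat_dim), dim=1)
prototype_weights = F.normalize(torch.randn(num_prototypes, feat_dim), dim=1)
output = embedding @ prototype_weights.t()  # prototype scores for every crop

loss, queue, use_queue = criterion(
    output=output,
    embedding=embedding,
    prototype_weights=prototype_weights,
    batch_size=batch_size,
    queue=None,  # no feature queue in this toy run, so queue comes back as None
    use_queue=False,
)
print(loss.item())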
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -7,10 +7,8 @@
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
from pl_bolts.utils.stability import under_review


@under_review()
def cli_main(): # pragma: no cover
from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

110 changes: 22 additions & 88 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -2,13 +2,12 @@
import os
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import distributed as dist
from torch import nn

from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.optimizers.lars import LARS
from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
@@ -17,10 +16,8 @@
imagenet_normalization,
stl10_normalization,
)
from pl_bolts.utils.stability import under_review


@under_review()
class SwAV(LightningModule):
def __init__(
self,
@@ -129,19 +126,23 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs

if self.gpus * self.num_nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn

self.model = self.init_model()
self.criterion = SWAVLoss(
gpus=self.gpus,
num_nodes=self.num_nodes,
temperature=self.temperature,
crops_for_assign=self.crops_for_assign,
nmb_crops=self.nmb_crops,
sinkhorn_iterations=self.sinkhorn_iterations,
epsilon=self.epsilon,
)

# compute iters per epoch
global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

self.queue = None
self.softmax = nn.Softmax(dim=1)
# self.softmax = nn.Softmax(dim=1)

def setup(self, stage):
if self.queue_length > 0:
@@ -216,33 +217,17 @@ def shared_step(self, batch):
embedding = embedding.detach()
bs = inputs[0].size(0)

# 3. swav loss computation
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id : bs * (crop_id + 1)]

# 4. time to use the queue
if self.queue is not None:
if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
self.use_the_queue = True
out = torch.cat((torch.mm(self.queue[i], self.model.prototypes.weight.t()), out))
# fill the queue
self.queue[i, bs:] = self.queue[i, :-bs].clone()
self.queue[i, :bs] = embedding[crop_id * bs : (crop_id + 1) * bs]

# 5. get assignments
q = torch.exp(out / self.epsilon).t()
q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
p = self.softmax(output[bs * v : bs * (v + 1)] / self.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign)

# SWAV loss computation
loss, queue, use_queue = self.criterion(
output=output,
embedding=embedding,
prototype_weights=self.model.prototypes.weight,
batch_size=bs,
queue=self.queue,
use_queue=self.use_the_queue,
)
self.queue = queue
self.use_the_queue = use_queue
return loss

def training_step(self, batch, batch_idx):
@@ -302,56 +287,6 @@ def configure_optimizers(self):

return [optimizer], [scheduler]

def sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q

K, B = Q.shape

if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B

for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)

Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

def distributed_sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)

for it in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
@@ -446,7 +381,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule