Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision of BYOL module and tests #901

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
70b2829
Remove review tag. Update docstring.
matsumotosan Aug 25, 2022
de54fb8
Test with SimCLR transforms instead of CPC.
matsumotosan Aug 25, 2022
e097b0c
Add method to calculate loss for given online and target view. Update…
matsumotosan Aug 25, 2022
213ba1f
Add initial value of tau as an argument
matsumotosan Aug 25, 2022
815def1
Add encode method to calculate representation without calculating pro…
matsumotosan Aug 25, 2022
d185785
Merge branch 'master' into byol_module
matsumotosan Aug 26, 2022
198fa95
Rename losses to clarify direction
matsumotosan Aug 26, 2022
c26a98a
Rename projector_hidden_size to projector_hidden_dim
matsumotosan Aug 28, 2022
193223e
Merge branch 'master' into byol_module
matsumotosan Sep 12, 2022
54f2630
Merge branch 'master' into byol_module
otaj Sep 15, 2022
ab638ef
Merge branch 'master' into byol_module
matsumotosan Sep 15, 2022
aac6376
Merge branch 'master' into byol_module
otaj Sep 16, 2022
1e3cc4f
Merge branch 'master' into byol_module
Borda Sep 19, 2022
1c013b8
Merge branch 'master' into byol_module
matsumotosan Sep 20, 2022
7b0682e
Merge branch 'master' into byol_module
matsumotosan Sep 21, 2022
3a443f9
Merge branch 'master' into byol_module
matsumotosan Sep 22, 2022
22ed1e5
Merge branch 'master' into byol_module
matsumotosan Sep 23, 2022
bb1ada9
Merge branch 'master' into byol_module
matsumotosan Sep 24, 2022
5414a06
Move logging to shared_step. Remove unused arguments from init.
matsumotosan Sep 20, 2022
2abd592
Use CIFAR10 datamodule add dataset specific args. Fix BYOL module docs.
matsumotosan Sep 23, 2022
1443878
Add typing to parser argument
matsumotosan Sep 24, 2022
cd10474
fix tests
Sep 27, 2022
7e58019
Remove byol test todo. Remove max_steps trainer arg from cli_main.
matsumotosan Sep 27, 2022
efeaec5
Fix max_epoch argparse conflict with trainer argparse method.
matsumotosan Sep 28, 2022
22f8762
Remove data_dir arg from docs and BYOL initialization
matsumotosan Oct 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 84 additions & 99 deletions pl_bolts/models/self_supervised/byol/byol_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,39 @@

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from torch import Tensor
from torch.nn import functional as F
from torch.optim import Adam

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.models.self_supervised.byol.models import MLP, SiameseArm
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.utils.stability import under_review


@under_review()
class BYOL(LightningModule):
"""PyTorch Lightning implementation of Bootstrap Your Own Latent (BYOL_)_

Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \
Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \
Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko.

Args:
learning_rate (float, optional): optimizer learning rate. Defaults to 0.2.
weight_decay (float, optional): optimizer weight decay. Defaults to 1.5e-6.
warmup_epochs (int, optional): number of epochs for scheduler warmup. Defaults to 10.
max_epochs (int, optional): maximum number of epochs for scheduler. Defaults to 1000.
base_encoder (Union[str, torch.nn.Module], optional): base encoder architecture. Defaults to "resnet50".
encoder_out_dim (int, optional): base encoder output dimension. Defaults to 2048.
projector_hidden_dim (int, optional): projector MLP hidden dimension. Defaults to 4096.
projector_out_dim (int, optional): projector MLP output dimension. Defaults to 256.
initial_tau (float, optional): initial value of target decay rate used. Defaults to 0.996.

Model implemented by:
- `Annika Brundyn <https://github.com/annikabrundyn>`_

.. warning:: Work in progress. This implementation is still being verified.

TODOs:
- verify on CIFAR-10
- verify on STL-10
- pre-train on imagenet

Example::

model = BYOL(num_classes=10)
model = BYOL()

dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
Expand All @@ -42,11 +45,6 @@ class BYOL(LightningModule):
trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)

Train::

trainer = Trainer()
trainer.fit(model)

CLI command::

# cifar10
Expand All @@ -56,7 +54,6 @@ class BYOL(LightningModule):
python byol_module.py
--gpus 8
--dataset imagenet2012
--data_dir /path/to/imagenet/
--meta_dir /path/to/folder/with/meta.bin/
--batch_size 32

Expand All @@ -65,87 +62,82 @@ class BYOL(LightningModule):

def __init__(
self,
num_classes,
learning_rate: float = 0.2,
weight_decay: float = 1.5e-6,
input_height: int = 32,
batch_size: int = 32,
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
base_encoder: Union[str, torch.nn.Module] = "resnet50",
encoder_out_dim: int = 2048,
projector_hidden_size: int = 4096,
projector_hidden_dim: int = 4096,
projector_out_dim: int = 256,
**kwargs
):
"""
Args:
datamodule: The datamodule
learning_rate: the learning rate
weight_decay: optimizer weight decay
input_height: image input height
batch_size: the batch size
num_workers: number of workers
warmup_epochs: num of epochs for scheduler warm up
max_epochs: max epochs for scheduler
base_encoder: the base encoder module or resnet name
encoder_out_dim: output dimension of base_encoder
projector_hidden_size: hidden layer size of projector MLP
projector_out_dim: output size of projector MLP
"""
initial_tau: float = 0.996,
**kwargs: Any,
) -> None:

super().__init__()
self.save_hyperparameters(ignore="base_encoder")

self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim)
self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_dim, projector_out_dim)
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()
self.predictor = MLP(projector_out_dim, projector_hidden_dim, projector_out_dim)

def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
# Add callback for user automatically since it's key to BYOL weight update
self.weight_callback = BYOLMAWeightUpdate(initial_tau=initial_tau)

def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
"""Add callback to perform exponential moving average weight update on target network."""
self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx)

def forward(self, x):
y, _, _ = self.online_network(x)
return y
def forward(self, x: Tensor) -> Tensor:
"""Returns the encoded representation of a view.

def shared_step(self, batch, batch_idx):
imgs, y = batch
img_1, img_2 = imgs[:2]
Args:
x (Tensor): sample to be encoded
"""
return self.online_network.encode(x)

# Image 1 to image 2 loss
y1, z1, h1 = self.online_network(img_1)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_2)
loss_a = -2 * F.cosine_similarity(h1, z2).mean()
def training_step(self, batch: Any, batch_idx: int) -> Tensor:
"""Complete training loop."""
return self._shared_step(batch, batch_idx, "train")

# Image 2 to image 1 loss
y1, z1, h1 = self.online_network(img_2)
with torch.no_grad():
y2, z2, h2 = self.target_network(img_1)
# L2 normalize
loss_b = -2 * F.cosine_similarity(h1, z2).mean()
def validation_step(self, batch: Any, batch_idx: int) -> Tensor:
"""Complete validation loop."""
return self._shared_step(batch, batch_idx, "val")

# Final loss
total_loss = loss_a + loss_b
def _shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor:
"""Shared evaluation step for training and validation loop."""
imgs, _ = batch
img1, img2 = imgs[:2]

return loss_a, loss_b, total_loss
# Calculate similarity loss in each direction
loss_12 = self.calculate_loss(img1, img2)
loss_21 = self.calculate_loss(img2, img1)

def training_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
# Calculate total loss
total_loss = loss_12 + loss_21

# log results
self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "train_loss": total_loss})
# Log losses
if step == "train":
self.log_dict({"train_loss_12": loss_12, "train_loss_21": loss_21, "train_loss": total_loss})
elif step == "val":
self.log_dict({"val_loss_12": loss_12, "val_loss_21": loss_21, "val_loss": total_loss})
else:
raise ValueError(f"Step '{step}' is invalid. Must be 'train' or 'val'.")

return total_loss

def validation_step(self, batch, batch_idx):
loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx)
def calculate_loss(self, v_online: Tensor, v_target: Tensor) -> Tensor:
"""Calculates similarity loss between the online network prediction of target network projection.

# log results
self.log_dict({"1_2_loss": loss_a, "2_1_loss": loss_b, "val_loss": total_loss})

return total_loss
Args:
v_online (Tensor): Online network view
v_target (Tensor): Target network view
"""
_, z1 = self.online_network(v_online)
h1 = self.predictor(z1)
with torch.no_grad():
_, z2 = self.target_network(v_target)
loss = -2 * F.cosine_similarity(h1, z2).mean()
return loss

def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
Expand All @@ -155,30 +147,23 @@ def configure_optimizers(self):
return [optimizer], [scheduler]

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--online_ft", action="store_true", help="run online finetuner")
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])

(args, _) = parser.parse_known_args()
args = parser.parse_args([])

# Data
parser.add_argument("--data_dir", type=str, default=".")
parser.add_argument("--num_workers", default=8, type=int)
if "max_epochs" in args:
parser.set_defaults(max_epochs=1000)
else:
parser.add_argument("--max_epochs", type=int, default=1000)

# optim
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--learning_rate", type=float, default=0.2)
parser.add_argument("--weight_decay", type=float, default=1.5e-6)
parser.add_argument("--warmup_epochs", type=float, default=10)

# Model
parser.add_argument("--warmup_epochs", type=int, default=10)
parser.add_argument("--meta_dir", default=".", type=str, help="path to meta.bin for imagenet")

return parser


@under_review()
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
Expand All @@ -188,23 +173,19 @@ def cli_main():

parser = ArgumentParser()

# trainer args
parser = Trainer.add_argparse_args(parser)

# model args
parser = BYOL.add_model_specific_args(parser)
args = parser.parse_args()
parser = CIFAR10DataModule.add_dataset_specific_args(parser)
parser.add_argument("--dataset", type=str, default="cifar10", choices=["cifar10", "imagenet2012", "stl10"])

# pick data
dm = None
args = parser.parse_args()

# init default datamodule
# Initialize datamodule
if args.dataset == "cifar10":
dm = CIFAR10DataModule.from_argparse_args(args)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
args.num_classes = dm.num_classes

elif args.dataset == "stl10":
dm = STL10DataModule.from_argparse_args(args)
dm.train_dataloader = dm.train_dataloader_mixed
Expand All @@ -214,20 +195,24 @@ def cli_main():
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes

elif args.dataset == "imagenet2012":
dm = ImagenetDataModule.from_argparse_args(args, image_size=196)
(c, h, w) = dm.dims
dm.train_transforms = SimCLRTrainDataTransform(h)
dm.val_transforms = SimCLREvalDataTransform(h)
args.num_classes = dm.num_classes
else:
raise ValueError(
f"{args.dataset} is not a valid dataset. Dataset must be 'cifar10', 'stl10', or 'imagenet2012'."
)

model = BYOL(**args.__dict__)
# Initialize BYOL module
model = BYOL(**vars(args))

# finetune in real-time
online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes)

trainer = Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])
trainer = Trainer.from_argparse_args(args, callbacks=[online_eval])

trainer.fit(model, datamodule=dm)

Expand Down
Loading