From 2d1ebee425b12ad6a784f8ed3087355678c07350 Mon Sep 17 00:00:00 2001 From: Christoph Gerum Date: Tue, 28 Nov 2023 17:09:14 +0100 Subject: [PATCH] CIFAR10: extend experiment --- .../cifar10/augmentation/cifar_augment.yaml | 14 +++++++++++++- experiments/cifar10/config.yaml | 10 +++++----- experiments/cifar10/experiment/sweep_lr.yaml | 2 +- external/hannah-tvm | 2 +- hannah/models/timm.py | 9 +++++++-- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/experiments/cifar10/augmentation/cifar_augment.yaml b/experiments/cifar10/augmentation/cifar_augment.yaml index c6365cd1..24d4f70c 100644 --- a/experiments/cifar10/augmentation/cifar_augment.yaml +++ b/experiments/cifar10/augmentation/cifar_augment.yaml @@ -1,8 +1,20 @@ batch_augment: pipeline: null transforms: - RandomVerticalFlip: + #RandomVerticalFlip: + # p: 0.5 + RandomHorizontalFlip: + p: 0.5 + RandomAffine: + degrees: [-15, 15] + translate: [0.1, 0.1] + scale: [0.9, 1.1] + shear: [-5, 5] p: 0.5 RandomCrop: size: [32,32] padding: 4 + RandomErasing: + p: 0.5 + #scale: [0.!, 0.3] + #value: [0.4914, 0.4822, 0.4465] diff --git a/experiments/cifar10/config.yaml b/experiments/cifar10/config.yaml index c4749fa3..62640ef5 100644 --- a/experiments/cifar10/config.yaml +++ b/experiments/cifar10/config.yaml @@ -12,14 +12,14 @@ defaults: monitor: - metric: val_f1_micro + metric: val_accuracy direction: maximize module: - batch_size: 64 + batch_size: 512 trainer: - max_epochs: 50 + max_epochs: 30 -scheduler: - max_lr: 0.1 +optimizer: + lr: 0.3 diff --git a/experiments/cifar10/experiment/sweep_lr.yaml b/experiments/cifar10/experiment/sweep_lr.yaml index d3424177..f9995d03 100644 --- a/experiments/cifar10/experiment/sweep_lr.yaml +++ b/experiments/cifar10/experiment/sweep_lr.yaml @@ -6,4 +6,4 @@ hydra: subdir: lr=${scheduler.max_lr} sweeper: params: - scheduler.max_lr: 0.0001,0.001,0.01,0.1 + scheduler.max_lr: 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9 diff --git a/external/hannah-tvm b/external/hannah-tvm index 0ede7e36..d2410e8d 160000 --- a/external/hannah-tvm +++ b/external/hannah-tvm @@ -1 +1 @@ -Subproject commit 0ede7e36f0cb295e322031dd6e9adad860acca7f +Subproject commit d2410e8d098a3054bcabe269853188ab13bfb55f diff --git a/hannah/models/timm.py b/hannah/models/timm.py index 71b39b3f..280aa98a 100644 --- a/hannah/models/timm.py +++ b/hannah/models/timm.py @@ -228,12 +228,16 @@ def __init__( if hasattr(self.encoder, "conv1"): input_conv = self.encoder.conv1 out_channels = input_conv.out_channels - new_conv = torch.nn.Conv2d(input_channels, out_channels, 3, 1) + new_conv = torch.nn.Conv2d( + input_channels, out_channels, 3, 1, padding=1 + ) self.encoder.conv1 = new_conv elif hasattr(self.encoder, "conv_stem"): input_conv = self.encoder.conv_stem out_channels = input_conv.out_channels - new_conv = torch.nn.Conv2d(input_channels, out_channels, 3, 1) + new_conv = torch.nn.Conv2d( + input_channels, out_channels, 3, 1, padding=1 + ) self.encoder.conv_stem = new_conv else: logger.critical( @@ -242,6 +246,7 @@ def __init__( if hasattr(self.encoder, "maxpool"): self.encoder.maxpool = torch.nn.Identity() + elif stem == "default": logger.info("""Using default stem for pulp model""") else: