From f0d7dec85587e97964299187c8b77e05d3801430 Mon Sep 17 00:00:00 2001 From: Christoph Gerum Date: Thu, 30 Nov 2023 13:32:46 +0100 Subject: [PATCH] Update kakao resnet --- experiments/rhode_island/config.yaml | 3 ++- experiments/rhode_island/dvc.yaml | 7 ------- hannah/callbacks/backends.py | 16 +++++++++++++--- hannah/models/kakao_resnet.py | 23 +++++++++++++++++++++-- 4 files changed, 36 insertions(+), 13 deletions(-) delete mode 100644 experiments/rhode_island/dvc.yaml diff --git a/experiments/rhode_island/config.yaml b/experiments/rhode_island/config.yaml index 112e67c1..acb3d525 100644 --- a/experiments/rhode_island/config.yaml +++ b/experiments/rhode_island/config.yaml @@ -22,6 +22,7 @@ defaults: - override features: identity # Feature extractor configuration name (use identity for vision datasets) - override model: timm_mobilenetv3_small_075 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1) - override scheduler: 1cycle # learning rate scheduler config name + - override augmentation: ri_augment - override optimizer: sgd # Optimizer config name - override normalizer: null # Feature normalizer (used for quantized neural networks) - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks) @@ -31,7 +32,7 @@ dataset: data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/} module: - batch_size: 128 + batch_size: 512 trainer: max_epochs: 15 diff --git a/experiments/rhode_island/dvc.yaml b/experiments/rhode_island/dvc.yaml deleted file mode 100644 index 28a9d60a..00000000 --- a/experiments/rhode_island/dvc.yaml +++ /dev/null @@ -1,7 +0,0 @@ -stages: - train_baseline: - cmd: ./scripts/train_all_baselines.sh - deps: - - scripts/train_all_baselines.sh - outs: - - trained_models/baseline diff --git a/hannah/callbacks/backends.py b/hannah/callbacks/backends.py index 18ec6f50..91f5864a 100644 --- a/hannah/callbacks/backends.py +++ b/hannah/callbacks/backends.py @@ -42,7 +42,9 @@ except ModuleNotFoundError: onnxrt_backend = None -from ..models.factory.qat import QAT_MODULE_MAPPINGS +from typing import Mapping + +from ..nn.qat import QAT_MODULE_MAPPINGS logger = logging.getLogger(__name__) @@ -172,6 +174,8 @@ def on_test_epoch_start(self, trainer, pl_module): Returns: """ + logger.info("Exporting module") + pl_module = self.quantize(pl_module) self.prepare(pl_module) self.export() @@ -222,8 +226,14 @@ def on_test_batch_end( """ if batch_idx < self.test_batches: - result = self.run_batch(inputs=batch[0]) - target = pl_module(batch[0].to(pl_module.device)) + # decode batches from target device + if isinstance(batch, Mapping) or isinstance(batch, dict): + inputs = batch["data"] + else: + inputs = batch[0] + + result = self.run_batch(inputs=inputs) + target = pl_module(inputs.to(pl_module.device)) target = target[: result.shape[0]] mse = torch.nn.functional.mse_loss( diff --git a/hannah/models/kakao_resnet.py b/hannah/models/kakao_resnet.py index cb5e6bfe..81c2f512 100644 --- a/hannah/models/kakao_resnet.py +++ b/hannah/models/kakao_resnet.py @@ -79,9 +79,28 @@ def resnet8(*args, **kwargs): def resnet8_025(*args, **kwargs): - num_class = 10 + num_class = 4 + model = nn.Sequential( + conv_bn(3, 16, kernel_size=8, stride=8, padding=0), + conv_bn(16, 32, kernel_size=5, stride=2, padding=2), + Residual(nn.Sequential(conv_bn(32, 32), conv_bn(32, 32))), + conv_bn(32, 64, kernel_size=3, stride=1, padding=1), + nn.MaxPool2d(2), + Residual(nn.Sequential(conv_bn(64, 64), conv_bn(64, 64))), + conv_bn(64, 32, kernel_size=3, stride=1, padding=0), + nn.AdaptiveMaxPool2d((1, 1)), + Flatten(), + nn.Linear(32, num_class, bias=False), + Mul(0.2), + ) + + return model + + +def resnet8_012(*args, **kwargs): + num_class = 4 model = nn.Sequential( - conv_bn(3, 16, kernel_size=4, stride=2, padding=0), + conv_bn(3, 16, kernel_size=16, stride=16, padding=0), conv_bn(16, 32, kernel_size=5, stride=2, padding=2), Residual(nn.Sequential(conv_bn(32, 32), conv_bn(32, 32))), conv_bn(32, 64, kernel_size=3, stride=1, padding=1),