From 07be227aecbb5a8b5fd1f6faffa0d77143a692b2 Mon Sep 17 00:00:00 2001 From: Mikhael Djajapermana Date: Mon, 27 Nov 2023 14:38:10 +0000 Subject: [PATCH] Recreate ResNet search space --- hannah/conf/config_resnet.yaml | 61 +++++++ hannah/conf/model/resnet.yaml | 3 + hannah/models/resnet/blocks.py | 55 ++++++ hannah/models/resnet/expressions.py | 24 +++ hannah/models/resnet/models.py | 248 +++------------------------- hannah/models/resnet/models_lazy.py | 229 +++++++++++++++++++++++++ hannah/models/resnet/operators.py | 68 ++++++++ test/test_lazy_resnet.py | 2 +- 8 files changed, 468 insertions(+), 222 deletions(-) create mode 100644 hannah/conf/config_resnet.yaml create mode 100644 hannah/conf/model/resnet.yaml create mode 100644 hannah/models/resnet/blocks.py create mode 100644 hannah/models/resnet/expressions.py create mode 100644 hannah/models/resnet/models_lazy.py create mode 100644 hannah/models/resnet/operators.py diff --git a/hannah/conf/config_resnet.yaml b/hannah/conf/config_resnet.yaml new file mode 100644 index 00000000..611aaed9 --- /dev/null +++ b/hannah/conf/config_resnet.yaml @@ -0,0 +1,61 @@ +## +## Copyright (c) 2022 University of Tübingen. +## +## This file is part of hannah. +## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. +## +## Licensed under the Apache License, Version 2.0 (the "License"); +## you may not use this file except in compliance with the License. +## You may obtain a copy of the License at +## +## http://www.apache.org/licenses/LICENSE-2.0 +## +## Unless required by applicable law or agreed to in writing, software +## distributed under the License is distributed on an "AS IS" BASIS, +## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +## See the License for the specific language governing permissions and +## limitations under the License. +## +defaults: + - base_config + - override dataset: cifar10 # Dataset configuration name + - override features: identity # Feature extractor configuration name (use identity for vision datasets) + - override model: resnet # Neural network name + - override scheduler: 1cycle # learning rate scheduler config name + - override optimizer: adamw # Optimizer config name + - override normalizer: null # Feature normalizer (used for quantized neural networks) + - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks) + - override nas: aging_evolution_nas + - _self_ + + +# dataset: +# data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/} + +experiment_id: "resnet_nas" + +seed: [1234] + +model: + num_classes: 10 + +nas: + budget: 500 + n_jobs: 4 + total_candidates: 100 + num_selected_candidates: 10 + sampler: + population_size: 50 + sample_size: 10 + +module: + batch_size: 64 + num_workers: 4 + +trainer: + max_epochs: 10 + +scheduler: + max_lr: 0.001 + +fx_mac_summary: True diff --git a/hannah/conf/model/resnet.yaml b/hannah/conf/model/resnet.yaml new file mode 100644 index 00000000..088ea3ba --- /dev/null +++ b/hannah/conf/model/resnet.yaml @@ -0,0 +1,3 @@ +_target_: hannah.models.resnet.models.search_space +name: resnet +num_classes: 10 \ No newline at end of file diff --git a/hannah/models/resnet/blocks.py b/hannah/models/resnet/blocks.py new file mode 100644 index 00000000..78b6c184 --- /dev/null +++ b/hannah/models/resnet/blocks.py @@ -0,0 +1,55 @@ +from functools import partial +from hannah.models.embedded_vision_net.expressions import expr_product +from hannah.nas.expressions.arithmetic import Ceil +from hannah.nas.expressions.types import Int +from hannah.nas.functional_operators.op import scope +from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv2d, conv_relu, depthwise_conv2d, dynamic_depth, pointwise_conv2d, linear, relu, batch_norm, choice, identity +from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter + + +@scope +def conv_relu_bn(input, out_channels, kernel_size, stride): + out = conv2d(input, out_channels, kernel_size, stride) + out = batch_norm(out) + out = relu(out) + return out + + +@scope +def residual(input, main_branch_output_shape): + input_shape = input.shape() + in_fmap = input_shape[2] + out_channels = main_branch_output_shape[1] + out_fmap = main_branch_output_shape[2] + stride = Int(Ceil(in_fmap / out_fmap)) + + out = conv2d(input, out_channels=out_channels, kernel_size=1, stride=stride, padding=0) + out = batch_norm(out) + out = relu(out) + return out + + +@scope +def block(input, depth, out_channels, kernel_size, stride): + assert isinstance(depth, IntScalarParameter), "block depth must be of type IntScalarParameter" + out = input + exits = [] + for i in range(depth.max+1): + out = conv_relu_bn(out, + out_channels=out_channels.new(), + kernel_size=kernel_size.new(), + stride=stride.new() if i == 0 else 1) + exits.append(out) + + out = dynamic_depth(*exits, switch=depth) + res = residual(input, out.shape()) + out = add(out, res) + + return out + + +@scope +def classifier_head(input, num_classes): + out = choice(input, adaptive_avg_pooling) + out = linear(out, num_classes) + return out diff --git a/hannah/models/resnet/expressions.py b/hannah/models/resnet/expressions.py new file mode 100644 index 00000000..8364cc7b --- /dev/null +++ b/hannah/models/resnet/expressions.py @@ -0,0 +1,24 @@ +from hannah.nas.expressions.logic import And, If +from hannah.nas.expressions.arithmetic import Ceil + + +def padding_expression(kernel_size, stride, dilation=1): + """Symbolically calculate padding such that for a given kernel_size, stride and dilation + the padding is such that the output dimension is kept the same(stride=1) or halved(stride=2). + Note: If the input dimension is 1 and stride = 2, the calculated padding will result in + an output with also dimension 1. + + Parameters + ---------- + kernel_size : Union[int, Expression] + stride : Union[int, Expression] + dilation : Union[int, Expression], optional + _description_, by default 1 + + Returns + ------- + Expression + """ + # r = 1 - (kernel_size % 2) + p = (dilation * (kernel_size - 1) - stride + 1) / 2 + return Ceil(p) diff --git a/hannah/models/resnet/models.py b/hannah/models/resnet/models.py index 9b41f9fd..22e562a4 100644 --- a/hannah/models/resnet/models.py +++ b/hannah/models/resnet/models.py @@ -1,229 +1,35 @@ -import torch -import torch.nn as nn -from hannah.nas.expressions.shapes import conv2d_shape, identity_shape -from hannah.nas.parameters.lazy import Lazy -from hannah.nas.parameters.parametrize import parametrize -from hannah.nas.parameters.iterators import RangeIterator -from hannah.nas.parameters.parameters import IntScalarParameter, CategoricalParameter -from hannah.nas.expressions.arithmetic import Ceil -from hannah.nas.expressions.choice import SymbolicAttr, Choice -from hannah.nas.expressions.types import Int +from hannah.models.embedded_vision_net.expressions import expr_product +from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter +from hannah.models.resnet.operators import dynamic_depth +from hannah.models.resnet.blocks import block, conv_relu_bn, classifier_head -conv2d = Lazy(nn.Conv2d, shape_func=conv2d_shape) -linear = Lazy(nn.Linear) -batch_norm = Lazy(nn.BatchNorm2d, shape_func=identity_shape) -relu = Lazy(nn.ReLU) -tensor = Lazy(torch.Tensor, shape_func=identity_shape) +def search_space(name, input, num_classes=10): + out_channels = IntScalarParameter(16, 64, step_size=4, name='out_channels') + kernel_size = CategoricalParameter([3, 5, 7, 9], name='kernel_size') + stride = CategoricalParameter([1, 2], name='stride') -def padding_expression(kernel_size, stride, dilation = 1): - """Symbolically calculate padding such that for a given kernel_size, stride and dilation - the padding is such that the output dimension is kept the same(stride=1) or halved(stride=2). - Note: If the input dimension is 1 and stride = 2, the calculated padding will result in - an output with also dimension 1. + depth = IntScalarParameter(0, 2, name='depth') + num_blocks = IntScalarParameter(0, 6, name='num_blocks') - Parameters - ---------- - kernel_size : Union[int, Expression] - stride : Union[int, Expression] - dilation : Union[int, Expression], optional - _description_, by default 1 + stem_kernel_size = CategoricalParameter([3, 5], name="kernel_size") + stem_channels = IntScalarParameter(min=16, max=32, step_size=4, name="out_channels") + out = conv_relu_bn(input, stem_channels, stem_kernel_size, stride.new()) - Returns - ------- - Expression - """ - p = (dilation * (kernel_size - 1) - stride + 1) / 2 - return Ceil(p) + exits = [] + for i in range(num_blocks.max+1): + out = block(out, + depth=depth.new(), + out_channels=out_channels.new(), + kernel_size=kernel_size.new(), + stride=stride.new()) + exits.append(out) -def stride_product(expressions: list): - res = None - for expr in expressions: - if res: - res = res * expr - else: - res = expr - return res + out = dynamic_depth(*exits, switch=num_blocks) + out = classifier_head(out, num_classes=num_classes) + strides = [v for k, v in out.parametrization(flatten=True).items() if k.split('.')[-1] == 'stride'] + total_stride = expr_product(strides) + out.cond(input.shape()[2] / total_stride > 1) -@parametrize -class ConvReluBn(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, id, inputs) -> None: - super().__init__() - self.id = id - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - - self.conv = conv2d(self.id + ".conv", - inputs=inputs, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding_expression(kernel_size, stride)) - - self.shape = self.conv.shape - self.bn = batch_norm(self.id + ".bn", num_features=out_channels) - self.relu = relu(self.id + ".relu") - - def initialize(self): - self.tconv = self.conv.instantiate() - self.tbn = self.bn.instantiate() - self.trelu = self.relu.instantiate() - - def forward(self, x): - out = self.tconv(x) - out = self.tbn(out) - out = self.trelu(out) - return out - -@parametrize -class ResidualBlock(nn.Module): - def __init__(self, in_channels, out_channels, in_fmap_size, out_fmap_size, id, inputs) -> None: - super().__init__() - self.id = id - self.in_channels = in_channels - self.out_channels = out_channels - self.in_fmap = in_fmap_size - self.out_fmap = out_fmap_size - self.stride = Int(Ceil(in_fmap_size / out_fmap_size)) - self.conv = conv2d(id=self.id + ".residual_conv", - inputs=inputs, - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - stride=self.stride, - padding=0) - self.activation = relu(self.id + '.relu') - - def initialize(self): - self.tconv = self.conv.instantiate() - self.tact = self.activation.instantiate() - - def forward(self, x): - out = self.tconv(x) - out = self.tact(out) - return out - - -@parametrize -class ConvReluBlock(nn.Module): - def __init__(self, params, input_shape, id, depth) -> None: - super().__init__() - self.input_shape = input_shape - self.depth = self.add_param(f'{id}.depth', depth) - self.mods = nn.ModuleList() - self.id = id - self.depth = depth - self.params = params - - self.strides = [] - - previous = input_shape - for d in RangeIterator(self.depth, instance=False): - in_channels = self.input_shape[1] if d == 0 else previous.out_channels - out_channels = self.add_param(f'{self.id}.conv{d}.out_channels', IntScalarParameter(self.params.conv.out_channels.min, - self.params.conv.out_channels.max, - self.params.conv.out_channels.step)) - kernel_size = self.add_param(f'{self.id}.conv{d}.kernel_size', CategoricalParameter(self.params.conv.kernel_size.choices)) - stride = self.add_param(f'{self.id}.conv{d}.stride', CategoricalParameter(self.params.conv.stride.choices)) - - self.strides.append(stride) - - layer = ConvReluBn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, id=f'{self.id}.{d}', inputs=[previous]) - self.mods.append(layer) - previous = layer - - self.last_layer = Choice(self.mods, self.depth - 1) - self.cond(stride_product(self.strides) <= self.input_shape[2]) - - def initialize(self): - for d in RangeIterator(self.depth, instance=False): - self.mods[d].initialize() - - def forward(self, x): - out = x - for d in RangeIterator(self.depth, instance=True): - out = self.mods[d](out) - return out - -@parametrize -class ClassifierHead(nn.Module): - def __init__(self, input, labels) -> None: - super().__init__() - self.labels = labels - in_features = input.get('shape')[1] * input.get('shape')[2] * input.get('shape')[3] - self._linear = self.add_param('linear', - linear("linear", - inputs=[input], - in_features=in_features, - out_features=self.labels)) - - def initialize(self): - self.linear = self._linear.instantiate() - - def forward(self, x): - out = x.view(x.shape[0], -1) - out = self.linear(out) - return out - - -@parametrize -class ResNet(nn.Module): - def __init__(self, name, params, input_shape, labels) -> None: - super().__init__() - self.input_shape = input_shape - self.labels = labels - self.depth = IntScalarParameter(params.depth.min, params.depth.max) - self.num_blocks = IntScalarParameter(params.num_blocks.min, params.num_blocks.max) - self.conv_blocks = nn.ModuleList() - self.residual_blocks = nn.ModuleList() - - next_input = self.input_shape - for n in RangeIterator(self.num_blocks, instance=False): - block = self.add_param(f"conv_block_{n}", - ConvReluBlock(params, - next_input, - f"conv_block_{n}", - self.depth)) - # last = Choice(block.mods, self.depth - 1) - residual_block = self.add_param(f"residual_block_{n}", - ResidualBlock(next_input[1], - block.last_layer.get('shape')[1], - next_input[2], - block.last_layer.get('shape')[2], - f"residual_block_{n}", - next_input)) - next_input = [block.last_layer.get("shape")[0], block.last_layer.get("shape")[1], block.last_layer.get("shape")[2], block.last_layer.get("shape")[3]] - self.conv_blocks.append(block) - self.residual_blocks.append(residual_block) - - last_block = Choice(self.conv_blocks, self.num_blocks - 1) - self.classifier = ClassifierHead(last_block.get("last_layer"), self.labels) - - - def initialize(self): - for n in RangeIterator(self.num_blocks, instance=False): - self.conv_blocks[n].initialize() - self.residual_blocks[n].initialize() - self.classifier.initialize() - - def forward(self, x): - out = x - for n in RangeIterator(self.num_blocks, instance=True): - block_out = self.conv_blocks[n](out) - res_out = self.residual_blocks[n](out) - out = torch.add(block_out, res_out) - out = block_out - - out = self.classifier(out) - return out - - def get_hparams(self): - params = {} - for key, param in self.parametrization(flatten=True).items(): - params[key] = param.current_value.item() - - return params \ No newline at end of file + return out diff --git a/hannah/models/resnet/models_lazy.py b/hannah/models/resnet/models_lazy.py new file mode 100644 index 00000000..9b41f9fd --- /dev/null +++ b/hannah/models/resnet/models_lazy.py @@ -0,0 +1,229 @@ +import torch +import torch.nn as nn +from hannah.nas.expressions.shapes import conv2d_shape, identity_shape +from hannah.nas.parameters.lazy import Lazy +from hannah.nas.parameters.parametrize import parametrize +from hannah.nas.parameters.iterators import RangeIterator +from hannah.nas.parameters.parameters import IntScalarParameter, CategoricalParameter +from hannah.nas.expressions.arithmetic import Ceil +from hannah.nas.expressions.choice import SymbolicAttr, Choice +from hannah.nas.expressions.types import Int + +conv2d = Lazy(nn.Conv2d, shape_func=conv2d_shape) +linear = Lazy(nn.Linear) +batch_norm = Lazy(nn.BatchNorm2d, shape_func=identity_shape) +relu = Lazy(nn.ReLU) +tensor = Lazy(torch.Tensor, shape_func=identity_shape) + + +def padding_expression(kernel_size, stride, dilation = 1): + """Symbolically calculate padding such that for a given kernel_size, stride and dilation + the padding is such that the output dimension is kept the same(stride=1) or halved(stride=2). + Note: If the input dimension is 1 and stride = 2, the calculated padding will result in + an output with also dimension 1. + + Parameters + ---------- + kernel_size : Union[int, Expression] + stride : Union[int, Expression] + dilation : Union[int, Expression], optional + _description_, by default 1 + + Returns + ------- + Expression + """ + p = (dilation * (kernel_size - 1) - stride + 1) / 2 + return Ceil(p) + +def stride_product(expressions: list): + res = None + for expr in expressions: + if res: + res = res * expr + else: + res = expr + return res + + +@parametrize +class ConvReluBn(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, id, inputs) -> None: + super().__init__() + self.id = id + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + + self.conv = conv2d(self.id + ".conv", + inputs=inputs, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding_expression(kernel_size, stride)) + + self.shape = self.conv.shape + self.bn = batch_norm(self.id + ".bn", num_features=out_channels) + self.relu = relu(self.id + ".relu") + + def initialize(self): + self.tconv = self.conv.instantiate() + self.tbn = self.bn.instantiate() + self.trelu = self.relu.instantiate() + + def forward(self, x): + out = self.tconv(x) + out = self.tbn(out) + out = self.trelu(out) + return out + +@parametrize +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, in_fmap_size, out_fmap_size, id, inputs) -> None: + super().__init__() + self.id = id + self.in_channels = in_channels + self.out_channels = out_channels + self.in_fmap = in_fmap_size + self.out_fmap = out_fmap_size + self.stride = Int(Ceil(in_fmap_size / out_fmap_size)) + self.conv = conv2d(id=self.id + ".residual_conv", + inputs=inputs, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=self.stride, + padding=0) + self.activation = relu(self.id + '.relu') + + def initialize(self): + self.tconv = self.conv.instantiate() + self.tact = self.activation.instantiate() + + def forward(self, x): + out = self.tconv(x) + out = self.tact(out) + return out + + +@parametrize +class ConvReluBlock(nn.Module): + def __init__(self, params, input_shape, id, depth) -> None: + super().__init__() + self.input_shape = input_shape + self.depth = self.add_param(f'{id}.depth', depth) + self.mods = nn.ModuleList() + self.id = id + self.depth = depth + self.params = params + + self.strides = [] + + previous = input_shape + for d in RangeIterator(self.depth, instance=False): + in_channels = self.input_shape[1] if d == 0 else previous.out_channels + out_channels = self.add_param(f'{self.id}.conv{d}.out_channels', IntScalarParameter(self.params.conv.out_channels.min, + self.params.conv.out_channels.max, + self.params.conv.out_channels.step)) + kernel_size = self.add_param(f'{self.id}.conv{d}.kernel_size', CategoricalParameter(self.params.conv.kernel_size.choices)) + stride = self.add_param(f'{self.id}.conv{d}.stride', CategoricalParameter(self.params.conv.stride.choices)) + + self.strides.append(stride) + + layer = ConvReluBn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, id=f'{self.id}.{d}', inputs=[previous]) + self.mods.append(layer) + previous = layer + + self.last_layer = Choice(self.mods, self.depth - 1) + self.cond(stride_product(self.strides) <= self.input_shape[2]) + + def initialize(self): + for d in RangeIterator(self.depth, instance=False): + self.mods[d].initialize() + + def forward(self, x): + out = x + for d in RangeIterator(self.depth, instance=True): + out = self.mods[d](out) + return out + +@parametrize +class ClassifierHead(nn.Module): + def __init__(self, input, labels) -> None: + super().__init__() + self.labels = labels + in_features = input.get('shape')[1] * input.get('shape')[2] * input.get('shape')[3] + self._linear = self.add_param('linear', + linear("linear", + inputs=[input], + in_features=in_features, + out_features=self.labels)) + + def initialize(self): + self.linear = self._linear.instantiate() + + def forward(self, x): + out = x.view(x.shape[0], -1) + out = self.linear(out) + return out + + +@parametrize +class ResNet(nn.Module): + def __init__(self, name, params, input_shape, labels) -> None: + super().__init__() + self.input_shape = input_shape + self.labels = labels + self.depth = IntScalarParameter(params.depth.min, params.depth.max) + self.num_blocks = IntScalarParameter(params.num_blocks.min, params.num_blocks.max) + self.conv_blocks = nn.ModuleList() + self.residual_blocks = nn.ModuleList() + + next_input = self.input_shape + for n in RangeIterator(self.num_blocks, instance=False): + block = self.add_param(f"conv_block_{n}", + ConvReluBlock(params, + next_input, + f"conv_block_{n}", + self.depth)) + # last = Choice(block.mods, self.depth - 1) + residual_block = self.add_param(f"residual_block_{n}", + ResidualBlock(next_input[1], + block.last_layer.get('shape')[1], + next_input[2], + block.last_layer.get('shape')[2], + f"residual_block_{n}", + next_input)) + next_input = [block.last_layer.get("shape")[0], block.last_layer.get("shape")[1], block.last_layer.get("shape")[2], block.last_layer.get("shape")[3]] + self.conv_blocks.append(block) + self.residual_blocks.append(residual_block) + + last_block = Choice(self.conv_blocks, self.num_blocks - 1) + self.classifier = ClassifierHead(last_block.get("last_layer"), self.labels) + + + def initialize(self): + for n in RangeIterator(self.num_blocks, instance=False): + self.conv_blocks[n].initialize() + self.residual_blocks[n].initialize() + self.classifier.initialize() + + def forward(self, x): + out = x + for n in RangeIterator(self.num_blocks, instance=True): + block_out = self.conv_blocks[n](out) + res_out = self.residual_blocks[n](out) + out = torch.add(block_out, res_out) + out = block_out + + out = self.classifier(out) + return out + + def get_hparams(self): + params = {} + for key, param in self.parametrization(flatten=True).items(): + params[key] = param.current_value.item() + + return params \ No newline at end of file diff --git a/hannah/models/resnet/operators.py b/hannah/models/resnet/operators.py new file mode 100644 index 00000000..678403ce --- /dev/null +++ b/hannah/models/resnet/operators.py @@ -0,0 +1,68 @@ +from hannah.models.resnet.expressions import padding_expression +from hannah.nas.functional_operators.op import Tensor, scope, ChoiceOp +from hannah.nas.functional_operators.operators import AdaptiveAvgPooling, Add, BatchNorm, Conv2d, Linear, Relu, Identity +from hannah.nas.expressions.types import Int + + +def conv2d(input, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, padding=None): + in_channels = input.shape()[1] + weight = Tensor(name='weight', + shape=(out_channels, in_channels, kernel_size, kernel_size), + axis=('O', 'I', 'kH', 'kW'), + grad=True) + + conv = Conv2d(stride=stride, dilation=dilation, groups=groups, padding=padding)(input, weight) + return conv + + +def linear(input, out_features): + input_shape = input.shape() + in_features = input_shape[1] * input_shape[2] * input_shape[3] + weight = Tensor(name='weight', + shape=(in_features, out_features), + axis=('in_features', 'out_features'), + grad=True) + + out = Linear()(input, weight) + return out + + +def add(input, other): + return Add()(input, other) + + +def identity(input): + return Identity()(input) + + +def adaptive_avg_pooling(input): + return AdaptiveAvgPooling()(input) + + +@scope +def batch_norm(input): + # https://stackoverflow.com/questions/44887446/pytorch-nn-functional-batch-norm-for-2d-input + n_chans = input.shape()[1] + running_mu = Tensor(name='running_mean', shape=(n_chans,), axis=('c',)) + running_std = Tensor(name='running_std', shape=(n_chans,), axis=('c',)) + # running_mu.data = torch.zeros(n_chans) # zeros are fine for first training iter + # running_std = torch.ones(n_chans) # ones are fine for first training iter + return BatchNorm()(input, running_mu, running_std) + + +def relu(input): + return Relu()(input) + + +def conv_relu(input, out_channels, kernel_size, stride): + out = conv2d(input, out_channels=out_channels, stride=stride, kernel_size=kernel_size) + out = relu(out) + return out + + +def choice(input, *choices, switch=None): + return ChoiceOp(*choices, switch=switch)(input) + + +def dynamic_depth(*exits, switch): + return ChoiceOp(*exits, switch=switch)() diff --git a/test/test_lazy_resnet.py b/test/test_lazy_resnet.py index 341fbd20..e165f677 100644 --- a/test/test_lazy_resnet.py +++ b/test/test_lazy_resnet.py @@ -4,7 +4,7 @@ import yaml import os -from hannah.models.resnet.models import ResNet +from hannah.models.resnet.models_lazy import ResNet def test_lazy_resnet_init():