From a3e3c70eca6369f91559ea945e5d883d27056709 Mon Sep 17 00:00:00 2001 From: Moritz Reiber Date: Wed, 27 Mar 2024 12:55:12 +0000 Subject: [PATCH] Add sandglass block to embedded_vision_net --- hannah/callbacks/summaries.py | 5 +-- hannah/models/embedded_vision_net/blocks.py | 37 ++++++++++++++----- .../models/embedded_vision_net/operators.py | 2 + hannah/nas/search/search.py | 5 +-- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py index 80c1ded8..d96ad209 100644 --- a/hannah/callbacks/summaries.py +++ b/hannah/callbacks/summaries.py @@ -515,7 +515,6 @@ def __init__(self, module: torch.nn.Module): def run_node(self, n: torch.fx.Node): try: out = super().run_node(n) - print(out.shape, n) except Exception as e: print(str(e)) if n.op == "call_function": @@ -535,7 +534,7 @@ def run_node(self, n: torch.fx.Node): self.data["MACs"] += [int(macs)] except Exception as e: msglogger.warning("Summary of node %s failed: %s", n.name, str(e)) - print(traceback.format_exc()) + print(traceback.format_exc()) return out @@ -543,7 +542,7 @@ class FxMACSummaryCallback(MacSummaryCallback): def _do_summary(self, pl_module, input=None, print_log=True): interpreter = MACSummaryInterpreter(pl_module.model) dummy_input = input - + if dummy_input is None: dummy_input = pl_module.example_feature_array dummy_input = dummy_input.to(pl_module.device) diff --git a/hannah/models/embedded_vision_net/blocks.py b/hannah/models/embedded_vision_net/blocks.py index ba5b81ae..280401a5 100644 --- a/hannah/models/embedded_vision_net/blocks.py +++ b/hannah/models/embedded_vision_net/blocks.py @@ -27,10 +27,10 @@ def grouped_pointwise(input, out_channels): @scope def expansion(input, expanded_channels): - #pw = partial(pointwise_conv2d, out_channels=expanded_channels) - #grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels) - #return choice(input, pw, grouped_pw) - return pointwise_conv2d(input, out_channels=expanded_channels) + pw = partial(pointwise_conv2d, out_channels=expanded_channels) + grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels) + return choice(input, pw, grouped_pw) + # return pointwise_conv2d(input, out_channels=expanded_channels) @scope @@ -40,10 +40,10 @@ def spatial_correlation(input, out_channels, kernel_size, stride=1): @scope def reduction(input, out_channels): - #pw = partial(pointwise_conv2d, out_channels=out_channels) - #grouped_pw = partial(grouped_pointwise, out_channels=out_channels) - #return choice(input, pw, grouped_pw) - return pointwise_conv2d(input, out_channels=out_channels) + pw = partial(pointwise_conv2d, out_channels=out_channels) + grouped_pw = partial(grouped_pointwise, out_channels=out_channels) + return choice(input, pw, grouped_pw) + # return pointwise_conv2d(input, out_channels=out_channels) @scope @@ -80,14 +80,31 @@ def expand_reduce(input, out_channels, expand_ratio, kernel_size, stride): return out +# FIXME: integrate this into reduce_expand? +@scope +def sandglass_block(input, out_channels, reduce_ratio, kernel_size, stride): + in_channels = input.shape()[1] + reduced_channels = Int(in_channels / reduce_ratio) + out = depthwise_conv2d(input, out_channels=in_channels, kernel_size=kernel_size, stride=1) + out = batch_norm(out) + out = relu(out) + out = reduction(out, out_channels=reduced_channels) + out = expansion(out, expanded_channels=out_channels) + out = relu(out) + out = batch_norm(out) + out = depthwise_conv2d(out, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + return out + + @scope def pattern(input, stride, out_channels, kernel_size, expand_ratio, reduce_ratio): convolution = partial(conv_relu, stride=stride, kernel_size=kernel_size, out_channels=out_channels) red_exp = partial(reduce_expand, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride) exp_red = partial(expand_reduce, out_channels=out_channels, expand_ratio=expand_ratio, kernel_size=kernel_size, stride=stride) - #pool = partial(pooling, kernel_size=kernel_size, stride=stride) + pool = partial(pooling, kernel_size=kernel_size, stride=stride) + sandglass = partial(sandglass_block, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride) - out = choice(input, convolution, exp_red, red_exp) + out = choice(input, convolution, exp_red, red_exp, pool, sandglass) return out diff --git a/hannah/models/embedded_vision_net/operators.py b/hannah/models/embedded_vision_net/operators.py index a67021c1..85894731 100644 --- a/hannah/models/embedded_vision_net/operators.py +++ b/hannah/models/embedded_vision_net/operators.py @@ -34,6 +34,8 @@ def grouped_conv2d(input, out_channels, kernel_size, stride, dilation=1, padding in_channels = input.shape()[1] if groups is None: groups = Groups(in_channels=in_channels, out_channels=out_channels, name="groups") + else: + groups = Int(groups) grouped_channels = Int(in_channels / groups) weight = Tensor(name='weight', diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py index 8e174137..c5e8dbcb 100644 --- a/hannah/nas/search/search.py +++ b/hannah/nas/search/search.py @@ -71,7 +71,7 @@ def __init__( self.random_state = np.random.RandomState() else: self.random_state = random_state - + self.example_input_array = None if input_shape is not None: self.example_input_array = torch.rand([1] + list(input_shape)) @@ -256,12 +256,9 @@ def build_model(self, parameters): return module def build_search_space(self): - - input = Tensor( "input", shape=self.example_input_array.shape, axis=("N", "C", "H", "W") ) - search_space = instantiate(self.config.model, input=input, _recursive_=True) return search_space