Skip to content

Commit

Permalink
Add sandglass block to embedded_vision_net
Browse files Browse the repository at this point in the history
  • Loading branch information
moreib committed Mar 27, 2024
1 parent e834af7 commit a3e3c70
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 17 deletions.
5 changes: 2 additions & 3 deletions hannah/callbacks/summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,6 @@ def __init__(self, module: torch.nn.Module):
def run_node(self, n: torch.fx.Node):
try:
out = super().run_node(n)
print(out.shape, n)
except Exception as e:
print(str(e))
if n.op == "call_function":
Expand All @@ -535,15 +534,15 @@ def run_node(self, n: torch.fx.Node):
self.data["MACs"] += [int(macs)]
except Exception as e:
msglogger.warning("Summary of node %s failed: %s", n.name, str(e))
print(traceback.format_exc())
print(traceback.format_exc())
return out


class FxMACSummaryCallback(MacSummaryCallback):
def _do_summary(self, pl_module, input=None, print_log=True):
interpreter = MACSummaryInterpreter(pl_module.model)
dummy_input = input

if dummy_input is None:
dummy_input = pl_module.example_feature_array
dummy_input = dummy_input.to(pl_module.device)
Expand Down
37 changes: 27 additions & 10 deletions hannah/models/embedded_vision_net/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def grouped_pointwise(input, out_channels):

@scope
def expansion(input, expanded_channels):
#pw = partial(pointwise_conv2d, out_channels=expanded_channels)
#grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels)
#return choice(input, pw, grouped_pw)
return pointwise_conv2d(input, out_channels=expanded_channels)
pw = partial(pointwise_conv2d, out_channels=expanded_channels)
grouped_pw = partial(grouped_pointwise, out_channels=expanded_channels)
return choice(input, pw, grouped_pw)
# return pointwise_conv2d(input, out_channels=expanded_channels)


@scope
Expand All @@ -40,10 +40,10 @@ def spatial_correlation(input, out_channels, kernel_size, stride=1):

@scope
def reduction(input, out_channels):
#pw = partial(pointwise_conv2d, out_channels=out_channels)
#grouped_pw = partial(grouped_pointwise, out_channels=out_channels)
#return choice(input, pw, grouped_pw)
return pointwise_conv2d(input, out_channels=out_channels)
pw = partial(pointwise_conv2d, out_channels=out_channels)
grouped_pw = partial(grouped_pointwise, out_channels=out_channels)
return choice(input, pw, grouped_pw)
# return pointwise_conv2d(input, out_channels=out_channels)


@scope
Expand Down Expand Up @@ -80,14 +80,31 @@ def expand_reduce(input, out_channels, expand_ratio, kernel_size, stride):
return out


# FIXME: integrate this into reduce_expand?
@scope
def sandglass_block(input, out_channels, reduce_ratio, kernel_size, stride):
in_channels = input.shape()[1]
reduced_channels = Int(in_channels / reduce_ratio)
out = depthwise_conv2d(input, out_channels=in_channels, kernel_size=kernel_size, stride=1)
out = batch_norm(out)
out = relu(out)
out = reduction(out, out_channels=reduced_channels)
out = expansion(out, expanded_channels=out_channels)
out = relu(out)
out = batch_norm(out)
out = depthwise_conv2d(out, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
return out


@scope
def pattern(input, stride, out_channels, kernel_size, expand_ratio, reduce_ratio):
convolution = partial(conv_relu, stride=stride, kernel_size=kernel_size, out_channels=out_channels)
red_exp = partial(reduce_expand, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride)
exp_red = partial(expand_reduce, out_channels=out_channels, expand_ratio=expand_ratio, kernel_size=kernel_size, stride=stride)
#pool = partial(pooling, kernel_size=kernel_size, stride=stride)
pool = partial(pooling, kernel_size=kernel_size, stride=stride)
sandglass = partial(sandglass_block, out_channels=out_channels, reduce_ratio=reduce_ratio, kernel_size=kernel_size, stride=stride)

out = choice(input, convolution, exp_red, red_exp)
out = choice(input, convolution, exp_red, red_exp, pool, sandglass)
return out


Expand Down
2 changes: 2 additions & 0 deletions hannah/models/embedded_vision_net/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def grouped_conv2d(input, out_channels, kernel_size, stride, dilation=1, padding
in_channels = input.shape()[1]
if groups is None:
groups = Groups(in_channels=in_channels, out_channels=out_channels, name="groups")
else:
groups = Int(groups)

grouped_channels = Int(in_channels / groups)
weight = Tensor(name='weight',
Expand Down
5 changes: 1 addition & 4 deletions hannah/nas/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self.random_state = np.random.RandomState()
else:
self.random_state = random_state

self.example_input_array = None
if input_shape is not None:
self.example_input_array = torch.rand([1] + list(input_shape))
Expand Down Expand Up @@ -256,12 +256,9 @@ def build_model(self, parameters):
return module

def build_search_space(self):


input = Tensor(
"input", shape=self.example_input_array.shape, axis=("N", "C", "H", "W")
)

search_space = instantiate(self.config.model, input=input, _recursive_=True)
return search_space

Expand Down

0 comments on commit a3e3c70

Please sign in to comment.