diff --git a/hannah/callbacks/summaries.py b/hannah/callbacks/summaries.py
index 5aafb1b0..7e13a5e5 100644
--- a/hannah/callbacks/summaries.py
+++ b/hannah/callbacks/summaries.py
@@ -17,29 +17,24 @@
 # limitations under the License.
 #
 import logging
-from collections import OrderedDict
 import sys
 import traceback
+from collections import OrderedDict

 import pandas as pd
 import torch
+import torch.fx as fx
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from tabulate import tabulate
 from torch.fx.graph_module import GraphModule

-from hannah.models.ofa.submodules.elasticBase import ElasticBase1d
+from hannah.nas.functional_operators.operators import add, conv2d, linear
+from hannah.nas.graph_conversion import GraphConversionTracer

 from ..models.factory import qat
-from ..models.ofa import OFAModel
-from ..models.ofa.submodules.elastickernelconv import ConvBn1d, ConvBnReLu1d, ConvRelu1d
-from ..models.ofa.type_utils import elastic_conv_type, elastic_Linear_type
 from ..models.sinc import SincNet

-import torch.fx as fx
-from hannah.nas.graph_conversion import GraphConversionTracer
-from hannah.nas.functional_operators.operators import conv2d, linear, add
-
 msglogger = logging.getLogger(__name__)

@@ -142,11 +137,6 @@ def get_extra(module, volume_ofm, output):
     """

 classes = {
-    elastic_conv_type: get_elastic_conv,
-    elastic_Linear_type: get_elastic_linear,
-    ConvBn1d: get_conv,
-    ConvRelu1d: get_conv,
-    ConvBnReLu1d: get_conv,
     torch.nn.Conv1d: get_conv,
     torch.nn.Conv2d: get_conv,
     qat.Conv1d: get_conv,
@@ -330,16 +320,9 @@ def _do_summary(self, pl_module, input=None, print_log=True):
         total_weights = 0.0
         estimated_acts = 0.0
         model = pl_module.model
-        ofamodel = isinstance(model, OFAModel)
-        if ofamodel:
-            if model.validation_model is None:
-                model.build_validation_model()
-            model = model.validation_model

         try:
             df = walk_model(model, dummy_input)
-            if ofamodel:
-                pl_module.model.reset_validation_model()
             t = tabulate(df, headers="keys", tablefmt="psql", floatfmt=".5f")
             total_macs = df["MACs"].sum()
             total_acts = df["IFM volume"][0] + df["OFM volume"].sum()
@@ -354,8 +337,6 @@ def _do_summary(self, pl_module, input=None, print_log=True):
                 "Estimated Activations: " + "{:,}".format(estimated_acts)
             )
         except RuntimeError as e:
-            if ofamodel:
-                pl_module.model.reset_validation_model()
             msglogger.warning("Could not create performance summary: %s", str(e))
             return OrderedDict()

@@ -477,15 +458,13 @@ def get_conv(node, output, args, kwargs):
     out_channels = weight.shape[0]
     in_channels = weight.shape[1]
     kernel_size = weight.shape[2]
-    num_weights = out_channels * in_channels / kwargs['groups'] * kernel_size**2
-    macs = volume_ofm * in_channels / kwargs['groups'] * kernel_size
+    num_weights = out_channels * in_channels / kwargs["groups"] * kernel_size**2
+    macs = volume_ofm * in_channels / kwargs["groups"] * kernel_size
     attrs = "k=" + "(%d, %d)" % (kernel_size, kernel_size)
-    attrs += ", s=" + "(%d, %d)" % (kwargs['stride'], kwargs['stride'])
-    attrs += ", g=(%d)" % kwargs['groups']
-    attrs += ", dsc=(%s)" % str(
-        in_channels == out_channels == kwargs['groups']
-    )
-    attrs += ", d=" + "(%d, %d)" % (kwargs['dilation'], kwargs['dilation'])
+    attrs += ", s=" + "(%d, %d)" % (kwargs["stride"], kwargs["stride"])
+    attrs += ", g=(%d)" % kwargs["groups"]
+    attrs += ", dsc=(%s)" % str(in_channels == out_channels == kwargs["groups"])
+    attrs += ", d=" + "(%d, %d)" % (kwargs["dilation"], kwargs["dilation"])
     return num_weights, macs, attrs

@@ -500,7 +479,7 @@ def get_linear(node, output, args, kwargs):

 def get_type(node):
     try:
-        return node.name.split('_')[-2]
+        return node.name.split("_")[-2]
     except Exception as e:
         pass
     return node.name

@@ -531,24 +510,26 @@ def __init__(self, module: torch.nn.Module):
             "MACs": [],
         }

-    def run_node(self, n : torch.fx.Node):
+    def run_node(self, n: torch.fx.Node):
         try:
             out = super().run_node(n)
         except Exception as e:
             print(str(e))

-        if n.op == 'call_function':
+        if n.op == "call_function":
             try:
                 args, kwargs = self.fetch_args_kwargs_from_env(n)
-                num_weights, macs, attrs = self.count_function.get(n.target, get_zero_op)(n, out, args, kwargs)
-                self.data['Name'] += [n.name]
-                self.data['Type'] += [get_type(n)]
-                self.data['Attrs'] += [attrs]
-                self.data['IFM'] += [tuple(args[0].shape)]
-                self.data['IFM volume'] += [prod(args[0].shape)]
-                self.data['OFM'] += [tuple(out.shape)]
-                self.data['OFM volume'] += [prod(out.shape)]
-                self.data['Weights volume'] += [int(num_weights)]
-                self.data['MACs'] += [int(macs)]
+                num_weights, macs, attrs = self.count_function.get(
+                    n.target, get_zero_op
+                )(n, out, args, kwargs)
+                self.data["Name"] += [n.name]
+                self.data["Type"] += [get_type(n)]
+                self.data["Attrs"] += [attrs]
+                self.data["IFM"] += [tuple(args[0].shape)]
+                self.data["IFM volume"] += [prod(args[0].shape)]
+                self.data["OFM"] += [tuple(out.shape)]
+                self.data["OFM volume"] += [prod(out.shape)]
+                self.data["Weights volume"] += [int(num_weights)]
+                self.data["MACs"] += [int(macs)]
             except Exception as e:
                 msglogger.warning("Summary of node %s failed: %s", n.name, str(e))
         return out
diff --git a/hannah/conf/config_ofa.yaml b/hannah/conf/config_ofa.yaml
deleted file mode 100644
index 442cea2d..00000000
--- a/hannah/conf/config_ofa.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-##
-## Copyright (c) 2022 University of Tübingen.
-##
-## This file is part of hannah.
-## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-##
-## Licensed under the Apache License, Version 2.0 (the "License");
-## you may not use this file except in compliance with the License.
-## You may obtain a copy of the License at
-##
-##     http://www.apache.org/licenses/LICENSE-2.0
-##
-## Unless required by applicable law or agreed to in writing, software
-## distributed under the License is distributed on an "AS IS" BASIS,
-## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-## See the License for the specific language governing permissions and
-## limitations under the License.
-##
-defaults:
-  - config
-  - override dataset: kws
-  - override features: mfcc
-  - override model: ofa
-  - _self_
diff --git a/hannah/conf/model/ofa.yaml b/hannah/conf/model/ofa.yaml
deleted file mode 100644
index d7ebd501..00000000
--- a/hannah/conf/model/ofa.yaml
+++ /dev/null
@@ -1,191 +0,0 @@
-##
-## Copyright (c) 2022 University of Tübingen.
-##
-## This file is part of hannah.
-## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
-##
-## Licensed under the Apache License, Version 2.0 (the "License");
-## you may not use this file except in compliance with the License.
-## You may obtain a copy of the License at
-##
-##     http://www.apache.org/licenses/LICENSE-2.0
-##
-## Unless required by applicable law or agreed to in writing, software
-## distributed under the License is distributed on an "AS IS" BASIS,
-## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-## See the License for the specific language governing permissions and
-## limitations under the License.
-## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - act: false - norm: true - quant: false - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: false - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: false - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: false - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: false - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: false - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: false - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: false - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: false - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/model/ofa_dsc.yaml b/hannah/conf/model/ofa_dsc.yaml deleted file mode 100644 index e6fcc61d..00000000 --- a/hannah/conf/model/ofa_dsc.yaml +++ /dev/null @@ -1,233 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. 
-## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - grouping_sizes: 1 - act: false - norm: true - quant: false - dsc: [false] - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - dsc: [false, true] - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - dsc: [false, true] - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - dsc: [false, true] - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - dsc: [false, true] - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - dsc: [false, true] - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - dsc: [false, true] - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - dsc: [false, true] - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - dsc: [false, true] - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/model/ofa_dsc_quant.yaml b/hannah/conf/model/ofa_dsc_quant.yaml deleted file mode 100644 index b9b1d360..00000000 --- a/hannah/conf/model/ofa_dsc_quant.yaml +++ /dev/null @@ -1,241 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. 
-## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module -qconfig: - _target_: hannah.models.factory.qconfig.get_trax_qat_qconfig - config: - bw_b: 8 - bw_w: 6 - bw_f: 8 - power_of_2: false # Use power of two quantization for weights - noise_prob: 0.7 # Probability of quantizing a value during training - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - grouping_sizes: 1 - act: false - norm: true - quant: true - dsc: [false] - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - dsc: [false, true] - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - dsc: [false, true] - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - dsc: [false, true] - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - dsc: [false, true] - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - dsc: [false, true] - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - dsc: [false, true] - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - dsc: [false, true] - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - dsc: [false, true] - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/model/ofa_group.yaml b/hannah/conf/model/ofa_group.yaml deleted file mode 100644 index 
e0aafc49..00000000 --- a/hannah/conf/model/ofa_group.yaml +++ /dev/null @@ -1,224 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - grouping_sizes: 1 - act: false - norm: true - quant: false - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: false - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: false - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: false - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/model/ofa_group_quant.yaml b/hannah/conf/model/ofa_group_quant.yaml deleted file mode 100644 index 
90d59f67..00000000 --- a/hannah/conf/model/ofa_group_quant.yaml +++ /dev/null @@ -1,232 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module -qconfig: - _target_: hannah.models.factory.qconfig.get_trax_qat_qconfig - config: - bw_b: 8 - bw_w: 6 - bw_f: 8 - power_of_2: false # Use power of two quantization for weights - noise_prob: 0.7 # Probability of quantizing a value during training - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - grouping_sizes: 1 - act: false - norm: true - quant: true - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - grouping_sizes: - - 1 - - 2 - - 4 - act: true - norm: true - quant: true - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 
1 - grouping_sizes: - - 1 - - 2 - - 4 - act: false - norm: false - quant: true - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/model/ofa_kernel.yaml b/hannah/conf/model/ofa_kernel.yaml deleted file mode 100644 index 53c25482..00000000 --- a/hannah/conf/model/ofa_kernel.yaml +++ /dev/null @@ -1,129 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.models.ofa.models.create -name: ofa_kernel -skew_sampling_distribution: false -min_depth: 5 -norm_before_act: true -dropout: 0.5 -validate_on_extracted: false -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_size: 3 - act: false - norm: true - out_channels: 16 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 24 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 32 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 48 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 48 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 64 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 64 diff --git a/hannah/conf/model/ofa_no_channels.yaml b/hannah/conf/model/ofa_no_channels.yaml deleted file mode 100644 index d339d5ce..00000000 --- a/hannah/conf/model/ofa_no_channels.yaml +++ /dev/null @@ -1,128 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. 
-## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.models.ofa.models.create -name: ofa_no_channels -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_size: 3 - act: false - norm: true - out_channels: 16 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 24 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 32 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 48 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 48 - - target: residual1d - stride: 2 - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: true - norm: true - out_channels: - - 64 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - act: false - norm: false - out_channels: - - 64 diff --git a/hannah/conf/model/ofa_quant.yaml b/hannah/conf/model/ofa_quant.yaml deleted file mode 100644 index 3bf0aafd..00000000 --- a/hannah/conf/model/ofa_quant.yaml +++ /dev/null @@ -1,199 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. 
-## -_target_: hannah.models.ofa.models.create -name: ofa -skew_sampling_distribution: false -min_depth: 3 -norm_before_act: true -dropout: 0.5 -# to add elastic channel count following a module which is not elastic_conv1d -# an "elastic_channel_helper" module with only a channel_count list may be added after -# the maximum channel count of the elastic helper must match the out_channels of the previous module -qconfig: - _target_: hannah.models.factory.qconfig.get_trax_qat_qconfig - config: - bw_b: 8 - bw_w: 6 - bw_f: 8 - power_of_2: false # Use power of two quantization for weights - noise_prob: 0.7 # Probability of quantizing a value during training - -conv: - - target: forward - stride: 1 - blocks: - - target: conv1d - kernel_sizes: 3 - dilation_sizes: 1 - act: false - norm: true - quant: true - out_channels: 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: true - out_channels: - - 24 - - 16 - - 8 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: true - out_channels: - - 24 - - 16 - - 8 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: true - out_channels: - - 32 - - 24 - - 16 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: true - out_channels: - - 32 - - 24 - - 16 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: true - out_channels: - - 48 - - 32 - - 24 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: true - out_channels: - - 48 - - 32 - - 24 - - target: residual1d - stride: 2 - quant_skip: true - blocks: - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: true - norm: true - quant: true - out_channels: - - 64 - - 48 - - 32 - - target: elastic_conv1d - kernel_sizes: - - 9 - - 7 - - 5 - - 3 - dilation_sizes: - - 9 - - 3 - - 1 - act: false - norm: false - quant: true - out_channels: - - 64 - - 48 - - 32 diff --git a/hannah/conf/nas/ofa_nas.yaml b/hannah/conf/nas/ofa_nas.yaml deleted file mode 100644 index 15620370..00000000 --- a/hannah/conf/nas/ofa_nas.yaml +++ /dev/null @@ -1,26 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. 
-## - -defaults: - - model_trainer: progressive_shrinking_legacy - - sampler: random - -_target_: hannah.nas.search.search.WeightSharingNAS -budget: 2000 -n_jobs: 2 diff --git a/hannah/conf/nas/ofa_nas_all_dsc_off.yaml b/hannah/conf/nas/ofa_nas_all_dsc_off.yaml deleted file mode 100644 index c0b0e07d..00000000 --- a/hannah/conf/nas/ofa_nas_all_dsc_off.yaml +++ /dev/null @@ -1,37 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.nas.OFANasTrainer -epochs_warmup: 35 -epochs_kernel_step: 35 -epochs_depth_step: 35 -epochs_width_step: 35 -epochs_dilation_step: 35 -epochs_grouping_step: 35 -epochs_dsc_step: 35 -elastic_kernels_allowed: True -elastic_depth_allowed: True -elastic_dilation_allowed: False -# with width enabled, the epochs must be 65 for good results -elastic_width_allowed: True -elastic_grouping_allowed: True -#depthwise_separable_convolution -elastic_dsc_allowed: False -evaluate: True -random_evaluate: True -random_eval_number: 100 diff --git a/hannah/conf/nas/ofa_nas_dsc_all.yaml b/hannah/conf/nas/ofa_nas_dsc_all.yaml deleted file mode 100644 index b800c64d..00000000 --- a/hannah/conf/nas/ofa_nas_dsc_all.yaml +++ /dev/null @@ -1,19 +0,0 @@ -_target_: hannah.nas.OFANasTrainer -epochs_warmup: 65 -epochs_kernel_step: 65 -epochs_depth_step: 65 -epochs_width_step: 65 -epochs_dilation_step: 65 -epochs_grouping_step: 65 -epochs_dsc_step: 65 -elastic_kernels_allowed: True -elastic_depth_allowed: True -elastic_dilation_allowed: False -# with width enabled, the epochs must be 65 for good results -elastic_width_allowed: True -elastic_grouping_allowed: True -#depthwise_separable_convolution -elastic_dsc_allowed: True -evaluate: True -random_evaluate: True -random_eval_number: 100 diff --git a/hannah/conf/nas/ofa_nas_dsc_good_combination.yaml b/hannah/conf/nas/ofa_nas_dsc_good_combination.yaml deleted file mode 100644 index 58b18fdf..00000000 --- a/hannah/conf/nas/ofa_nas_dsc_good_combination.yaml +++ /dev/null @@ -1,36 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. 
-## -_target_: hannah.nas.OFANasTrainer -epochs_warmup: 35 -epochs_kernel_step: 35 -epochs_depth_step: 35 -epochs_width_step: 35 -epochs_dilation_step: 35 -epochs_grouping_step: 35 -epochs_dsc_step: 35 -elastic_kernels_allowed: True -elastic_depth_allowed: False -elastic_dilation_allowed: False -elastic_width_allowed: False -elastic_grouping_allowed: False -#depthwise_separable_convolution -elastic_dsc_allowed: True -evaluate: True -random_evaluate: True -random_eval_number: 100 diff --git a/hannah/conf/nas/ofa_nas_dsc_long.yaml b/hannah/conf/nas/ofa_nas_dsc_long.yaml deleted file mode 100644 index c72b3203..00000000 --- a/hannah/conf/nas/ofa_nas_dsc_long.yaml +++ /dev/null @@ -1,37 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -_target_: hannah.nas.OFANasTrainer -epochs_warmup: 35 -epochs_kernel_step: 35 -epochs_depth_step: 35 -epochs_width_step: 35 -epochs_dilation_step: 35 -epochs_grouping_step: 35 -epochs_dsc_step: 35 -elastic_kernels_allowed: True -elastic_depth_allowed: True -elastic_dilation_allowed: False -# with width enabled, the epochs must be 65 for good results -elastic_width_allowed: False -elastic_grouping_allowed: True -#depthwise_separable_convolution -elastic_dsc_allowed: True -evaluate: True -random_evaluate: True -random_eval_number: 100 diff --git a/hannah/conf/nas_ofa.yaml b/hannah/conf/nas_ofa.yaml deleted file mode 100644 index 892c8a0c..00000000 --- a/hannah/conf/nas_ofa.yaml +++ /dev/null @@ -1,34 +0,0 @@ -## -## Copyright (c) 2022 University of Tübingen. -## -## This file is part of hannah. -## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -## -## Licensed under the Apache License, Version 2.0 (the "License"); -## you may not use this file except in compliance with the License. -## You may obtain a copy of the License at -## -## http://www.apache.org/licenses/LICENSE-2.0 -## -## Unless required by applicable law or agreed to in writing, software -## distributed under the License is distributed on an "AS IS" BASIS, -## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -## See the License for the specific language governing permissions and -## limitations under the License. -## -defaults: - - base_config - - override dataset: kws - - override features: mfcc - - override model: ofa - - override scheduler: 1cycle - - override optimizer: sgd - - override normalizer: fixedpoint - - override module: stream_classifier - - override trainer: default - - override nas: ofa_nas - - _self_ - -experiment_id: nas_ofa -module: - shuffle_all_dataloaders: True diff --git a/hannah/models/ofa/__init__.py b/hannah/models/ofa/__init__.py deleted file mode 100644 index ec6e4559..00000000 --- a/hannah/models/ofa/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. 
-# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .models import * diff --git a/hannah/models/ofa/models.py b/hannah/models/ofa/models.py deleted file mode 100644 index 78a4c04f..00000000 --- a/hannah/models/ofa/models.py +++ /dev/null @@ -1,1405 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import logging -from typing import List, Tuple - -import numpy as np -import torch.nn as nn -import yaml -from hydra.utils import instantiate -from omegaconf import DictConfig, ListConfig, OmegaConf - -from .submodules.elasticchannelhelper import ElasticChannelHelper -from .submodules.elasticLinear import ElasticQuantWidthLinear, ElasticWidthLinear -from .submodules.resblock import ResBlock1d, ResBlockBase -from .type_utils import ( - elasic_conv_classes, - elastic_all_type, - elastic_conv_type, - elastic_Linear_type, -) -from .utilities import ( - call_function_from_deep_nested, - flatten_module_list, - get_instances_from_deep_nested, - module_list_to_module, -) - - -def create( - name: str, - labels: int, - input_shape, - conv=[], - min_depth: int = 1, - norm_before_act=True, - skew_sampling_distribution: bool = False, - dropout: int = 0.5, - validate_on_extracted=True, - qconfig=None, -) -> nn.Module: - """The function creates a ofa Model with the given name, - labels, input shape, convolutional layers, and other parameters - - Args: - name(str): The name of the model - labels(int): The number of classes in the dataset - input_shape: the shape of the input tensor - conv: a list of MajorBlockConfig objects (Default value = []) - min_depth(int (optional): The minimum depth of the model, defaults to 1 - norm_before_act: If True, the normalization is performed before the - activation function, defaults to True (optional) - skew_sampling_distribution(bool (optional): If True, the model will use a skewed sampling - distribution to sample the number of minor blocks in each major block, defaults - to False - dropout(int): float, default 0.5 - validate_on_extracted: If True, the model will be validated on the - extracted data, defaults to True (optional) - qconfig: the quantization configuration to use (Default value = None) - name: str: - labels: int: - min_depth: int: (Default value = 1) 
- skew_sampling_distribution: bool: (Default value = False) - dropout: int: (Default value = 0.5) - name: str: - labels: int: - min_depth: int: (Default value = 1) - skew_sampling_distribution: bool: (Default value = False) - dropout: int: (Default value = 0.5) - name: str: - labels: int: - min_depth: int: (Default value = 1) - skew_sampling_distribution: bool: (Default value = False) - dropout: int: (Default value = 0.5) - - Returns: - : A model object. - - """ - - # if no orders for the norm operator are specified, fall back to default - default_qconfig = instantiate(qconfig) if qconfig else None - flatten_n = input_shape[0] - in_channels = input_shape[1] - pool_n = input_shape[2] - # the final output channel count is given by the last minor block of the last major block - final_out_channels = conv[-1].blocks[-1].out_channels - if hasattr(final_out_channels, "__iter__"): - # if the output channel count is a list, get the highest value - final_out_channels = max(final_out_channels) - conv_layers = nn.ModuleList([]) - next_in_channels = in_channels - - for block_config in conv: - if block_config.target == "forward": - major_block = create_minor_block_sequence( - block_config.blocks, - next_in_channels, - stride=block_config.stride, - norm_before_act=norm_before_act, - qconfig=default_qconfig, - ) - - elif block_config.target == "residual1d": - major_block = create_residual_block_1d( - blocks=block_config.blocks, - in_channels=next_in_channels, - stride=block_config.stride, - norm_before_act=norm_before_act, - qconfig=default_qconfig, - quant_skip=block_config.quant_skip - # sources=previous_sources, - ) - - else: - raise Exception( - f"Undefined target selected for major block: {block_config.target}" - ) - # output channel count of the last minor block will be the input channel count of the next major block - next_in_channels = block_config.blocks[-1].out_channels - if hasattr(next_in_channels, "__iter__"): - # if the channel count is a list, get the highest value - next_in_channels = max(next_in_channels) - conv_layers.append(major_block) - - # get the max depth from the count of major blocks - model = OFAModel( - conv_layers=conv_layers, - max_depth=len(conv_layers), - labels=labels, - pool_kernel=pool_n, - flatten_dims=flatten_n, - out_channels=final_out_channels, - min_depth=min_depth, - block_config=conv, - skew_sampling_distribution=skew_sampling_distribution, - dropout=dropout, - validate_on_extracted=validate_on_extracted, - qconfig=default_qconfig, - ) - - # store the name onto the model - setattr(model, "creation_name", name) - - # acquire step counts for OFA progressive shrinking - ofa_steps_depth = len(model.linears) - ofa_steps_kernel = 1 - ofa_steps_width = 1 - ofa_steps_dilation = 1 - ofa_steps_grouping = 1 - ofa_steps_dsc = 1 - for major_block in conv: - for block in major_block.blocks: - if block.target == "elastic_conv1d": - this_block_kernel_steps = len(block.kernel_sizes) - this_block_dilation_steps = len(block.dilation_sizes) - this_block_grouping_steps = len(block.grouping_sizes) - this_block_dsc_steps = len(block.dsc) - this_block_width_steps = len(block.out_channels) - ofa_steps_width = max(ofa_steps_width, this_block_width_steps) - ofa_steps_kernel = max(ofa_steps_kernel, this_block_kernel_steps) - ofa_steps_dilation = max(ofa_steps_dilation, this_block_dilation_steps) - ofa_steps_grouping = max(ofa_steps_grouping, this_block_grouping_steps) - ofa_steps_dsc = max(ofa_steps_dsc, this_block_dsc_steps) - elif block.target == "elastic_channel_helper": - 
this_block_width_steps = len(block.out_channels) - ofa_steps_width = max(ofa_steps_width, this_block_width_steps) - logging.info( - f"OFA steps are {ofa_steps_kernel} kernel sizes, {ofa_steps_depth} depths, {ofa_steps_width} widths, {ofa_steps_grouping} groups, {ofa_steps_dsc} dscs." - ) - model.ofa_steps_kernel = ofa_steps_kernel - model.ofa_steps_depth = ofa_steps_depth - model.ofa_steps_width = ofa_steps_width - model.ofa_steps_dilation = ofa_steps_dilation - model.ofa_steps_grouping = ofa_steps_grouping - model.ofa_steps_dsc = ofa_steps_dsc - - model.perform_sequence_discovery() - - return model - - -# build a sequence from a list of minor block configurations -def create_minor_block_sequence( - blocks, - in_channels, - stride=1, - norm_before_act=True, - qconfig=None, - # sources: List[nn.Module] = [nn.ModuleList([])], -): - """ - - Args: - blocks: - in_channels: - stride: (Default value = 1) - norm_before_act: (Default value = True) - qconfig: (Default value = None) - # sources: List[nn.Module]: (Default value = [nn.ModuleList([])]) - # sources: List[nn.Module]: (Default value = [nn.ModuleList([])]) - # sources: List[nn.Module]: (Default value = [nn.ModuleList([])]) - - Returns: - - """ - next_in_channels = in_channels - minor_block_sequence = nn.ModuleList([]) - is_first_minor_block = True - - for block_config in blocks: - # set stride on the first minor block in the sequence - if is_first_minor_block: - next_stride = stride - is_first_minor_block = False - else: - next_stride = 1 - - # Set a Default Grouping Sizes (if it is not set in the config) - block_config.grouping_sizes = getattr(block_config, "grouping_sizes", [1]) - # Depthwise Separable Convolution can only be true or false - block_config.dsc = getattr(block_config, "dsc", [False]) - - minor_block, next_in_channels = create_minor_block( - block_config=block_config, - in_channels=next_in_channels, - stride=next_stride, - norm_before_act=norm_before_act, - qconfig=qconfig, - # sources=sources, - ) - - minor_block_sequence.append(minor_block) - - return module_list_to_module(minor_block_sequence) - - -# build a single minor block from its config. 
return the number of output channels with the block -def create_minor_block( - block_config, - in_channels: int, - stride: int = 1, - norm_before_act=True, - sources: List[nn.ModuleList] = [nn.ModuleList([])], - qconfig=None, -) -> Tuple[nn.Module, int]: - """ - - Args: - block_config: - in_channels: int: - stride: int: (Default value = 1) - norm_before_act: (Default value = True) - sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - qconfig: (Default value = None) - in_channels: int: - stride: int: (Default value = 1) - sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - in_channels: int: - stride: int: (Default value = 1) - sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - - Returns: - - """ - new_block = None - # the output channel count is usually stored in block_config.out_channels - # use it as the default value if available, otherwise it must be set by the specific code handling the target type - new_block_out_channels = getattr(block_config, "out_channels", 1) - - if "conv1d" in block_config.target: - out_channels = block_config.out_channels - if not isinstance(out_channels, ListConfig): - out_channels = [out_channels] - out_channels.sort(reverse=True) - # the maximum available width is the initial output channel count - out_channels_full = out_channels[0] - - kernel_sizes = block_config.kernel_sizes - if not isinstance(kernel_sizes, ListConfig): - kernel_sizes = [kernel_sizes] - - dilation_sizes = block_config.dilation_sizes - if not isinstance(dilation_sizes, ListConfig): - dilation_sizes = [dilation_sizes] - - # grouping_sizes = getattr(block_config, "grouping_sizes", 1) - grouping_sizes = block_config.grouping_sizes - if not isinstance(grouping_sizes, ListConfig): - grouping_sizes = [grouping_sizes] - - dsc = block_config.dsc - if not isinstance(dsc, ListConfig): - dsc = [dsc] - - minor_block_internal_sequence = nn.ModuleList([]) - key = "" - parameter = { - "kernel_sizes": kernel_sizes, - "in_channels": in_channels, - "out_channels": out_channels_full, - "stride": stride, - "dilation_sizes": dilation_sizes, - # new entry edit to group_sizes - "groups": grouping_sizes, - "dscs": dsc, - "out_channel_sizes": out_channels, - } - - if block_config.get("norm", False): - key += "norm" - if block_config.get("act", False): - key += "act" - if block_config.get("quant", False): - key += "quant" - parameter["qconfig"] = qconfig - if key == "": - key = "none" - - if key in elasic_conv_classes.keys(): - new_minor_block = elasic_conv_classes[key](**parameter) - else: - raise Exception( - f"Undefined target selected in minor block sequence: {block_config.target}" - ) - - minor_block_internal_sequence.append(new_minor_block) - - new_block = module_list_to_module( - flatten_module_list(minor_block_internal_sequence) - ) - # the input channel count of the next minor block is the output channel count of the previous block - # output channel count is specified by the elastic conv - new_block_out_channels = new_minor_block.out_channels - elif block_config.target == "elastic_channel_helper": - # if the module is a standalone elastic channel helper, pass the previous block as it's sources - out_channels_list = block_config.out_channels - out_channels_list.sort(reverse=True) - out_channels_full = out_channels_list[0] - new_block = ElasticChannelHelper(out_channels_list) - - if out_channels_full != in_channels: - logging.error( - f"standalone ElasticChannelHelper input width {in_channels} does not match max output channel width {out_channels_full} 
in list {out_channels_list}" - ) - new_block_out_channels = out_channels_full - # if an unknown target is selected for a minor block, throw an exception. - else: - raise Exception( - f"Undefined target selected in minor block sequence: {block_config.target}" - ) - - # return the new block and its output channel count - return new_block, new_block_out_channels - - -# build a residual major block -def create_residual_block_1d( - blocks, - in_channels, - stride=1, - norm_before_act=None, - qconfig=None, - quant_skip=None, - # sources: List[nn.ModuleList] = [nn.ModuleList([])], -) -> ResBlock1d: - """ - - Args: - blocks: - in_channels: - stride: (Default value = 1) - norm_before_act: (Default value = None) - qconfig: (Default value = None) - quant_skip: (Default value = None) - # sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - # sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - # sources: List[nn.ModuleList]: (Default value = [nn.ModuleList([])]) - - Returns: - - """ - minor_blocks = create_minor_block_sequence( - blocks, - in_channels, - stride=stride, - norm_before_act=norm_before_act, - qconfig=qconfig, - ) - # the output channel count of the residual major block is the output channel count of the last minor block - out_channels = blocks[-1].out_channels - if hasattr(out_channels, "__iter__"): - # if the out_channels count is a list, get the highest value - out_channels = max(out_channels) - residual_block = ResBlock1d( - in_channels=in_channels, - out_channels=out_channels, - minor_blocks=minor_blocks, - stride=stride, - norm_before_act=norm_before_act, - quant_skip=quant_skip, - qconfig=qconfig, - ) - return residual_block - - -class OFAModel(nn.Module): - """ """ - - def __init__( - self, - conv_layers: nn.ModuleList([]), - max_depth: int, - labels: int, - pool_kernel: int, - flatten_dims: int, - out_channels: int, - min_depth: int = 1, - block_config=[], - skew_sampling_distribution=False, - dropout=0.5, - validate_on_extracted=False, - qconfig=None, - ): - super().__init__() - self.validate_on_extracted = validate_on_extracted - self.conv_layers = conv_layers - self.max_depth = max_depth - self.active_depth = self.max_depth - self.labels = labels - self.pool_kernel = pool_kernel - self.flatten_dims = flatten_dims - self.out_channels = out_channels - self.block_config = block_config - self.min_depth = min_depth - self.current_step = 0 - self.current_kernel_step = 0 - self.current_channel_step = 0 - self.current_width_step = 0 - self.current_group_step = 0 - self.current_dsc_step = 0 - self.sampling_max_kernel_step = 0 - self.sampling_max_depth_step = 0 - self.sampling_max_width_step = 0 - self.sampling_max_dilation_step = 0 - self.sampling_max_grouping_step = 0 - self.sampling_max_dsc_step = 0 - self.eval_mode = False - self.last_input = None - self.skew_sampling_distribution = skew_sampling_distribution - self.validation_model = None - self.elastic_kernels_allowed = True - self.elastic_depth_allowed = True - self.elastic_width_allowed = True - self.elastic_dilation_allowed = True - self.elastic_grouping_allowed = True - self.elastic_dsc_allowed = True - - self.dropout = nn.Dropout(dropout) - self.pool = nn.AdaptiveAvgPool1d(1) - self.flatten = nn.Flatten(flatten_dims) - self.qconfig = qconfig - - # one linear exit layer for each possible depth level - self.linears = nn.ModuleList([]) - # for every possible depth level (from min_depth to including max_depth) - for i in range(self.min_depth, self.max_depth + 1): - self.active_depth = i - 
self.update_output_channel_count() - # create the linear output layer for this depth - if self.qconfig is None: - new_output_linear = ElasticWidthLinear(self.out_channels, self.labels) - else: - new_output_linear = ElasticQuantWidthLinear( - self.out_channels, self.labels, qconfig=self.qconfig - ) - self.linears.append(new_output_linear) - - # should now be redundant, as the loop will exit with the active depth being max_depth - self.active_depth = self.max_depth - # ofa step counts will be set by the create function. - self.ofa_steps_kernel = 1 - self.ofa_steps_depth = 1 - self.ofa_steps_width = 1 - self.ofa_steps_dilation = 1 - self.ofa_steps_grouping = 1 - self.ofa_steps_dsc = 1 - - # create a list of every elastic kernel conv, for sampling - all_elastic_kernel_convs = get_instances_from_deep_nested( - input=self.conv_layers, - type_selection=elastic_conv_type, - ) - self.elastic_kernel_convs = [] - for item in all_elastic_kernel_convs: - if item.get_available_kernel_steps() > 1: - # ignore convs with only one available kernel size, they do not need to be stored - self.elastic_kernel_convs.append(item) - logging.info( - f"OFA model accumulated {len(self.elastic_kernel_convs)} elastic kernel convolutions for sampling." - ) - - # create a list of every elastic width helper, for sampling - self.elastic_channel_helpers = get_instances_from_deep_nested( - input=self.conv_layers, type_selection=ElasticChannelHelper - ) - logging.info( - f"OFA model accumulated {len(self.elastic_channel_helpers)} elastic width connections for sampling." - ) - self.block_config = block_config - self.full_config = None - - def extract_conv(self, conv, general_config, parallel=True, bp=False): - """ - - Args: - conv: - general_config: - parallel: (Default value = True) - bp: (Default value = False) - - Returns: - - """ - deletkeys = ["dilation_sizes", "kernel_sizes", "quant"] - for key in deletkeys: - if general_config.get(key) is not None: - del general_config[key] - - general_config["kernel_size"] = conv.kernel_size - general_config["out_channels"] = conv.out_channel_filter.count(True) - general_config["act"] = conv.act - general_config["norm"] = conv.norm - general_config["dilation"] = conv.get_dilation_size() - if bp: - general_config["padding"] = False - general_config["bias"] = False - - if parallel: - general_config["parallel"] = True - - if ( - general_config.get("target") == "elastic_conv1d" - or general_config.get("target", None) is None - ): - general_config["target"] = "conv1d" - - return general_config - - def extract_config(self, conv, general_config, bp=False): - """set all conv attributes in config - - Args: - conv: - general_config: - bp: (Default value = False) - - Returns: - - """ - if isinstance(conv, ResBlockBase): - general_config["target"] = "residual" - if general_config.get("quant_skip") is not None: - del general_config["quant_skip"] - if not isinstance(conv.blocks, nn.Sequential): - config = general_config["blocks"][0] - else: - config = general_config["blocks"] - block = self.extract_config(conv.blocks, config) - block.append(self.extract_conv(conv.skip[0], dict())) - general_config["blocks"] = block - print(block) - elif isinstance(conv, nn.Sequential): - config = general_config - if ( - isinstance(general_config, dict) - and general_config.get("blocks", None) is not None - ): - config = general_config["blocks"] - for tmpconv, layer in zip( - get_instances_from_deep_nested(conv, elastic_conv_type), config - ): - layer = self.extract_config(tmpconv, layer, bp=bp) - elif ( - 
isinstance(conv, elastic_conv_type) - and general_config.get("blocks", None) is None - ): - general_config = self.extract_conv(conv, general_config, bp=bp) - elif ( - isinstance(conv, elastic_conv_type) - and general_config.get("blocks", None) is not None - ): - general_config["blocks"][0] = self.extract_conv( - conv, general_config["blocks"][0], bp=bp - ) - return general_config - - def print_config(self, filename): - """ - - Args: - filename: - - Returns: - - """ - cfg = copy.copy(self.full_config) - cfg = OmegaConf.to_container(cfg) - - removelist = [ - "_target_", - "name", - "skew_sampling_distribution", - "min_depth", - "norm_before_act", - "dropout", - ] - - for element in removelist: - cfg.pop(element) - first = True - for conv, gen in zip(self.conv_layers, cfg["conv"]): - if first: - gen = self.extract_config(conv, gen, bp=True) - first = False - gen = self.extract_config(conv, gen) - cfg["conv"] = cfg["conv"][0 : self.active_depth] - - d = OmegaConf.to_yaml(cfg) - with open(filename + ".yaml", "w") as f: - f.write("\n") - f.write("_target_: hannah.models.factory.factory.create_cnn\n") - f.write("name: conv_net_trax\n") - f.write("norm:\n") - f.write(" target: bn\n") - f.write("act:\n") - f.write(" target: relu\n") - f.write(d) - - def forward(self, x): - """ - - Args: - x: - - Returns: - - """ - self.last_input = x - self.current_step = self.current_step + 1 - - # in eval mode, run the forward on the extracted validation model. - if self.eval_mode and self.validate_on_extracted: - if self.validation_model is None: - self.build_validation_model() - return self.validation_model.forward(x) - - # if the network is currently being evaluated, don't sample a subnetwork! - if ( - self.sampling_max_depth_step > 0 - or self.sampling_max_kernel_step > 0 - or self.sampling_max_width_step > 0 - or self.sampling_max_grouping_step > 0 - or self.sampling_max_dsc_step > 0 - or self.sampling_max_dilation_step > 0 - ) and not self.eval_mode: - self.sample_subnetwork() - for layer in self.conv_layers[: self.active_depth]: - x = layer(x) - - result = x - result = self.pool(result) - result = self.flatten(result) - result = self.dropout(result) - result = self.get_output_linear_layer(self.active_depth)(result) - - return result - - def perform_sequence_discovery(self): - """ """ - logging.info("Performing model sequence discovery.") - # establisch outer channel Helper - for i in range(len(self.conv_layers) - 1): - pre_block = self.conv_layers[i] - post_block = self.conv_layers[i + 1] - pre_conv = self.get_pre_conv(pre_block) - post_conv = self.get_post_conv(post_block) - - if isinstance(pre_conv, elastic_conv_type): - tmpconv = pre_conv - else: - tmpconv = pre_conv[0] - - if len(tmpconv.out_channel_sizes) > 1: - ech = ElasticChannelHelper(tmpconv.out_channel_sizes) - ech.add_sources(pre_conv) - ech.add_targets(post_conv) - - if i in range(self.min_depth - 1, self.max_depth - 1): - idx = i - (self.min_depth - 1) - ech.add_targets(self.linears[idx]) - self.elastic_channel_helpers.append(ech) - - if isinstance(self.conv_layers[i], ResBlock1d): - chl = self.conv_layers[i].create_internal_channelhelper() - self.elastic_channel_helpers.append(chl) - - if len(self.conv_layers) > 0: - pre_conv = self.get_pre_conv(self.conv_layers[-1]) - - if hasattr(pre_conv, "__iter__"): - out_channels = pre_conv[0].out_channel_sizes - else: - out_channels = pre_conv.out_channel_sizes - - ech = ElasticChannelHelper(out_channels) - ech.add_sources(pre_conv) - ech.add_targets(self.linears[-1]) - 
self.elastic_channel_helpers.append(ech) - - self.elastic_channel_helpers = flatten_module_list(self.elastic_channel_helpers) - - def get_post_conv(self, post_block): - """ - - Args: - post_block: - - Returns: - - """ - post_conv = None - if isinstance(post_block, ResBlock1d): - post_conv = post_block.get_input_layer() - elif isinstance(post_block, nn.Sequential): - post_conv = flatten_module_list(post_block)[0] - elif isinstance(post_block, elastic_conv_type): - post_conv = post_block - return post_conv - - def get_pre_conv(self, pre_block): - """ - - Args: - pre_block: - - Returns: - - """ - pre_conv = None - if isinstance(pre_block, ResBlock1d): - pre_conv = pre_block.get_output_layer() - elif isinstance(pre_block, nn.Sequential): - pre_conv = flatten_module_list(pre_block)[-1] - elif isinstance(pre_block, elastic_conv_type): - pre_conv = pre_block - return pre_conv - - # pick a random subnetwork, return the settings used - def sample_subnetwork(self): - """ """ - state = { - "depth_step": 0, - "kernel_steps": [], - "dilation_steps": [], - "width_steps": [], - "grouping_steps": [], - "dsc_steps": [], - } - if self.elastic_depth_allowed: - new_depth_step = self.get_random_step(self.sampling_max_depth_step + 1) - self.active_depth = self.max_depth - new_depth_step - state["depth_step"] = new_depth_step - - if self.elastic_kernels_allowed: - for conv in self.elastic_kernel_convs: - # pick an available kernel index for every elastic kernel conv, independently. - max_available_sampling_step = min( - self.sampling_max_kernel_step + 1, conv.get_available_kernel_steps() - ) - new_kernel_step = self.get_random_step(max_available_sampling_step) - conv.pick_kernel_index(new_kernel_step) - state["kernel_steps"].append(new_kernel_step) - - if self.elastic_dilation_allowed: - for conv in self.elastic_kernel_convs: - # pick an available kernel index for every elastic kernel conv, independently. - max_available_sampling_step = min( - self.sampling_max_dilation_step + 1, - conv.get_available_dilation_steps(), - ) - new_dilation_step = self.get_random_step(max_available_sampling_step) - conv.pick_dilation_index(new_dilation_step) - state["dilation_steps"].append(new_dilation_step) - if self.elastic_grouping_allowed: - for conv in self.elastic_kernel_convs: - max_available_sampling_step = min( - self.sampling_max_grouping_step + 1, - conv.get_available_grouping_steps(), # zero index array - ) - new_grouping_step = self.get_random_step(max_available_sampling_step) - conv.pick_group_index(new_grouping_step) - state["grouping_steps"].append(new_grouping_step) - if self.elastic_dsc_allowed: - for conv in self.elastic_kernel_convs: - max_available_sampling_step = min( - self.sampling_max_dsc_step + 1, - conv.get_available_dsc_steps(), # zero index array - ) - new_dsc_step = self.get_random_step(max_available_sampling_step) - conv.pick_dsc_index(new_dsc_step) - state["dsc_steps"].append(new_dsc_step) - - if self.elastic_width_allowed: - for helper in self.elastic_channel_helpers: - # pick an available width step for every elastic channel helper, independently. - max_available_sampling_step = min( - self.sampling_max_width_step + 1, helper.get_available_width_steps() - ) - new_width_step = self.get_random_step(max_available_sampling_step) - helper.set_channel_step(new_width_step) - state["width_steps"].append(new_width_step) - - return state - - # get a step, with distribution biased towards taking less steps, if skew distribution is enabled. 
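# Worked example for the skewed step distribution implemented by get_random_step just
# below (this is a reading of its coin-flip loop, not extra behaviour): with skewing
# enabled and upper_bound = 3, step k is returned with probability
#   P(0) = 1/2,   P(1) = 1/4,   P(2) = 1/8 + 1/8 = 1/4,
# because the run of heads that would overshoot the bound is folded back onto the last
# step. A quick empirical check, assuming numpy is available:
import numpy as np

def skewed_step(upper_bound: int) -> int:
    acc = 0
    while np.random.randint(2) and acc < upper_bound:
        acc += 1
    return acc - 1 if acc == upper_bound else acc

counts = np.bincount([skewed_step(3) for _ in range(10_000)], minlength=3)
print(counts / counts.sum())  # roughly [0.50, 0.25, 0.25]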
- # currently a sort-of pseudo-geometric distribution, may be replaced with better RNG - def get_random_step(self, upper_bound: int) -> int: - """ - - Args: - upper_bound: int: - upper_bound: int: - upper_bound: int: - - Returns: - - """ - if upper_bound <= 0: - logging.warn("requested impossible random step <= 0. defaulting to 0.") - return 0 - if (not self.skew_sampling_distribution) or self.eval_mode: - # during random submodel evaluation, use uniform distribution - return np.random.randint(upper_bound) - else: - acc = 0 - while np.random.randint(2) and acc < upper_bound: - # continue incrementing with a 1/2 chance per additional increment - acc += 1 - if acc == upper_bound: - # if the bound was reached, go back below the bound. - # due to this, the distribution of probability toward the last element - # is not consistent with the distribution gradient across other elements - acc -= 1 - return acc - - # return max available step values - def get_max_submodel_steps(self): - """ """ - max_depth_step = self.sampling_max_depth_step - kernel_steps = [] - width_steps = [] - dilation_steps = [] - grouping_steps = [] - dsc_steps = [] - - for conv in self.elastic_kernel_convs: - kernel_steps.append(conv.get_available_kernel_steps()) - - for conv in self.elastic_kernel_convs: - # MR -> Any FIXME das kann nicht für dilation stimmen, dass alle convs an den kernel step appended werden, oder? - kernel_steps.append(conv.get_available_dilation_steps()) - - # for grouping - for conv in self.elastic_kernel_convs: - grouping_steps.append(conv.get_available_grouping_steps) - for conv in self.elastic_kernel_convs: - dsc_steps.append(conv.get_available_dsc_steps) - - for helper in self.elastic_channel_helpers: - width_steps.append(helper.get_available_width_steps()) - - state = { - "depth_step": max_depth_step, - "kernel_steps": kernel_steps, - "width_steps": width_steps, - "dilation_steps": dilation_steps, - "grouping_steps": grouping_steps, - "dsc_steps": dsc_steps, - } - return state - - # accept a state dict like the one returned in get_max_submodel_steps, return extracted submodel. - # also sets main model state to this submodel. - def get_submodel(self, state: dict): - """ - - Args: - state: dict: - state: dict: - state: dict: - - Returns: - - """ - if not self.set_submodel(state): - return None - else: - return self.extract_elastic_depth_sequence(self.active_depth) - - def on_warmup_end(self): - """ """ - for element in self.elastic_kernel_convs: - if hasattr(element, "on_warmup_end"): - element.on_warmup_end() - - # accept a state dict like the one returned in get_max_submodel_steps, sets model state. - def set_submodel(self, state: dict): - """ - - Args: - state: dict: - state: dict: - state: dict: - - Returns: - - """ - try: - depth_step = state["depth_step"] - kernel_steps = state["kernel_steps"] - width_steps = state["width_steps"] - dilation_steps = state["dilation_steps"] - grouping_steps = state["grouping_steps"] - dsc_steps = state["dsc_steps"] - except KeyError: - logging.error( - "Invalid state dict passed to get_submodel! Keys should be 'depth_step', 'kernel_steps', 'width_steps'!" - ) - return False - - if len(kernel_steps) != len(self.elastic_kernel_convs): - print( - f"State dict provides invalid amount of kernel steps: model has {len(self.elastic_kernel_convs)}, {len(kernel_steps)} provided." 
- ) - return False - if len(grouping_steps) != len(self.elastic_kernel_convs): - print( - f"State dict provides invalid amount of grouping steps: model has {len(self.elastic_kernel_convs)}, {len(grouping_steps)} provided." - ) - return False - if len(width_steps) != len(self.elastic_channel_helpers): - print( - f"State dict provides invalid amount of width steps: model has {len(self.elastic_channel_helpers)}, {len(width_steps)} provided." - ) - return False - - # FIXME channelhelper must be wrong for dilation step - if len(dilation_steps) != len(self.elastic_channel_helpers): - print( - f"State dict provides invalid amount of width steps: model has {len(self.elastic_channel_helpers)}, {len(dilation_steps)} provided." - ) - return False - - self.active_depth = self.max_depth - depth_step - for i in range(len(kernel_steps)): - self.elastic_kernel_convs[i].pick_kernel_index(kernel_steps[i]) - for i in range(len(dilation_steps)): - self.elastic_kernel_convs[i].pick_dilation_index(dilation_steps[i]) - # grouping - for i in range(len(grouping_steps)): - self.elastic_kernel_convs[i].pick_group_index(grouping_steps[i]) - for i in range(len(dsc_steps)): - self.elastic_kernel_convs[i].pick_dsc_index(dsc_steps[i]) - for i in range(len(width_steps)): - self.elastic_channel_helpers[i].set_channel_step(width_steps[i]) - - return True - - def build_validation_model(self): - """ """ - self.validation_model = self.extract_elastic_depth_sequence(self.active_depth) - return self.validation_model - - def reset_validation_model(self): - """ """ - self.validation_model = None - - def get_validation_model_weight_count(self): - """ """ - val_not_exist = self.validation_model is None - - if val_not_exist: - self.build_validation_model() - # create a dict of the pointer of each parameter to the item count within that parameter - # using a dict with pointers as keys ensures that no parameter is counted twice - parameter_pointers_dict = dict( - (p.data_ptr(), p.numel()) for p in self.validation_model.parameters() - ) - # sum up the values of each dict item, yielding the total element count across params - if val_not_exist: - self.reset_validation_model() - return sum(parameter_pointers_dict.values()) - - # return an extracted module sequence for a given depth - def extract_elastic_depth_sequence( - self, target_depth, quantized=False, clone_mode=False - ): - """ - - Args: - target_depth: - quantized: (Default value = False) - clone_mode: (Default value = False) - - Returns: - - """ - if target_depth < self.min_depth or target_depth > self.max_depth: - raise Exception( - f"attempted to extract submodel for depth {target_depth} where min: {self.min_depth} and max: {self.max_depth}" - ) - extracted_module_list = nn.ModuleList([]) - - if clone_mode: - for layer in self.conv_layers[:target_depth]: - extracted_module_list.append(layer) - else: - rebuild_output = rebuild_extracted_blocks(self.conv_layers[:target_depth]) - extracted_module_list.append(module_list_to_module(rebuild_output)) - - extracted_module_list.append(self.pool) - extracted_module_list.append(self.flatten) - extracted_module_list.append(self.dropout) - output_linear = self.get_output_linear_layer(target_depth) - if isinstance(output_linear, elastic_Linear_type): - output_linear = output_linear.assemble_basic_module() - extracted_module_list.append(output_linear) - - return copy.deepcopy(nn.Sequential(*extracted_module_list)) - - # return extracted module for a given progressive shrinking depth step - def extract_module_from_depth_step(self, depth_step) 
-> nn.Module: - """ - - Args: - depth_step: - - Returns: - - """ - torch_module = self.extract_elastic_depth_sequence(self.max_depth - depth_step) - return torch_module - - def get_elastic_depth_output(self, target_depth=None, quantized=False): - """ - - Args: - target_depth: (Default value = None) - quantized: (Default value = False) - - Returns: - - """ - if target_depth is None: - target_depth = self.max_depth - if self.last_input is None: - return None - submodel = self.extract_elastic_depth_sequence( - target_depth, quantized=quantized - ) - output = submodel(self.last_input) - return output - - # step all input widths within the model down by one, if possible - def step_down_all_channels(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="step_down_channels", - type_selection=ElasticChannelHelper, - ) - - def reset_active_depth(self): - """ """ - self.active_depth = self.max_depth - - # resume: return to the elastic values from before a reset - def resume_active_elastic_values(self): - """ """ - self.resume_kernel_sizes_from_step() - - # set the output channel count value based on the current active depth - def update_output_channel_count(self): - """ """ - # the new out channel count is given by the last minor block of the last active major block - last_active_major_block = self.block_config[: self.active_depth][-1].blocks[-1] - self.out_channels = last_active_major_block.out_channels - - # the code below this is probably no longer doing anything, TBD. - if hasattr(self.out_channels, "__iter__"): - # if the out_channels count is a list, get the highest value - self.out_channels = max(self.out_channels) - # get the very last module of the last active layer. It must be an elastic channel helper, as the channel count is a list. - last_active_item = self.conv_layers[: self.active_depth][-1] - if isinstance(last_active_item, ResBlock1d): - # incase of a residual layer being at the end, the helper will be at the end of its blocks - last_active_item = last_active_item.blocks - if hasattr(last_active_item, "__iter__"): - # flatten the iterable list of items to actually access the last module, and not some nested Sequential - last_active_item = flatten_module_list(last_active_item) - # a layer ususally contains multiple modules and is iterable. - # Pick the last module within the layer. 
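# The weight counting in get_validation_model_weight_count above deduplicates parameters
# by keying on each tensor's data_ptr(), so storage that is reachable through more than
# one parameter is only counted once. A standalone sketch of the same idea (plain torch,
# not the hannah API):
import torch.nn as nn

def count_unique_parameters(model: nn.Module) -> int:
    """Sum parameter element counts, counting shared underlying storage only once."""
    unique = {p.data_ptr(): p.numel() for p in model.parameters()}
    return sum(unique.values())

# e.g. count_unique_parameters(nn.Linear(8, 8))  -> 72  (64 weights + 8 biases)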
- last_active_item = last_active_item[-1] - - # return the linear layer which processes the output for the current elastic depth - def get_output_linear_layer(self, target_depth): - """ - - Args: - target_depth: - - Returns: - - """ - return self.linears[target_depth - self.min_depth] - - # step all elastic kernels within the model down by one, if possible - def step_down_all_kernels(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="step_down_kernel_size", - type_selection=elastic_conv_type, - ) - - # reset all kernel sizes to their max value - def reset_all_kernel_sizes(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="reset_kernel_size", - type_selection=elastic_conv_type, - ) - - # reset all kernel sizes to their max value - def reset_all_dilation_sizes(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="reset_dilation_size", - type_selection=elastic_conv_type, - ) - - # reset all group sizes to their max value - def reset_all_group_sizes(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="reset_group_size", - type_selection=elastic_conv_type, - ) - - def reset_all_dsc(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="reset_dscs", - type_selection=elastic_conv_type, - ) - - # step all elastic kernels within the model down by one, if possible - def step_down_all_dilations(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="step_down_dilation_size", - type_selection=elastic_conv_type, - ) - - # step all elastic groups within the model down by one, if possible - def step_down_all_groups(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="step_down_group_size", # In ChannelHelper implementieren - type_selection=elastic_conv_type, - ) - - def step_down_all_dsc(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="step_down_dsc", # In ChannelHelper implementieren - type_selection=elastic_conv_type, - ) - - # go to a specific kernel step - def go_to_kernel_step(self, step: int): - """ - - Args: - step: int: - step: int: - step: int: - - Returns: - - """ - self.current_kernel_step = step - self.resume_kernel_sizes_from_step() - - # go back to the kernel sizes specified by the current step - # call after reset_all_kernel_sizes to resume - def resume_kernel_sizes_from_step(self): - """ """ - # save the current step, resetting may also reset the value for some future implementations - step = self.current_kernel_step - # reset kernel sizes to start from a known point - self.reset_all_kernel_sizes() - self.current_kernel_step = step - for _ in range(self.current_kernel_step): - # perform one step down call for each current kernel step - if not self.step_down_all_kernels(): - # if this iteration of stepping down kernel size returned false, - # there were no kernels to step down. 
Further iterations are not necessary - break - - # reset all kernel sizes to their max value - def reset_all_widths(self): - """ """ - return call_function_from_deep_nested( - input=self.conv_layers, - function="reset_channel_step", - type_selection=ElasticChannelHelper, - ) - - def progressive_shrinking_add_kernel(self): - """ """ - self.sampling_max_kernel_step += 1 - if self.sampling_max_kernel_step >= self.ofa_steps_kernel: - self.sampling_max_kernel_step -= 1 - logging.warn( - f"excessive OFA kernel stepping! Attempting to add a kernel step when max ({self.ofa_steps_kernel}) already reached" - ) - - def progressive_shrinking_add_dilation(self): - """ """ - self.sampling_max_dilation_step += 1 - if self.sampling_max_dilation_step >= self.ofa_steps_dilation: - self.sampling_max_dilation_step -= 1 - logging.warn( - f"excessive OFA kernel stepping! Attempting to add a kernel step when max ({self.ofa_steps_dilation}) already reached" - ) - - def progressive_shrinking_add_depth(self): - """ """ - self.sampling_max_depth_step += 1 - if self.sampling_max_depth_step >= self.ofa_steps_depth: - self.sampling_max_depth_step -= 1 - logging.warn( - f"excessive OFA depth stepping! Attempting to add a depth step when max ({self.ofa_steps_depth}) already reached" - ) - - def progressive_shrinking_add_group(self): - """ """ - self.sampling_max_grouping_step += 1 - if self.sampling_max_grouping_step >= self.ofa_steps_grouping: - self.sampling_max_grouping_step -= 1 - logging.warn( - f"excessive OFA group stepping! Attempting to add a grouping step when max ({self.ofa_steps_grouping}) already reached" - ) - - def progressive_shrinking_add_dsc(self): - """ """ - self.sampling_max_dsc_step += 1 - if self.sampling_max_dsc_step >= self.ofa_steps_dsc: - self.sampling_max_dsc_step -= 1 - logging.warn( - f"excessive OFA group stepping! Attempting to add a dsc step when max ({self.ofa_steps_dsc}) already reached" - ) - - def progressive_shrinking_compute_channel_priorities(self): - """ """ - call_function_from_deep_nested( - input=self.conv_layers, - function="compute_channel_priorities", - type_selection=ElasticChannelHelper, - ) - - def progressive_shrinking_add_width(self): - """ """ - self.sampling_max_width_step += 1 - if self.sampling_max_width_step >= self.ofa_steps_width: - self.sampling_max_width_step -= 1 - logging.warn( - f"excessive OFA depth stepping! 
Attempting to add a width step when max ({self.ofa_steps_width}) already reached" - ) - - def progressive_shrinking_disable_sampling(self): - """ """ - self.sampling_max_kernel_step = 0 - self.sampling_max_depth_step = 0 - self.sampling_max_width_step = 0 - self.sampling_max_grouping_step = 0 - self.sampling_max_dsc_step = 0 - - def reset_shrinking(self): - """ """ - self.reset_validation_model() - self.reset_all_widths() - self.reset_all_kernel_sizes() - self.reset_all_dilation_sizes() - self.reset_all_group_sizes() - self.reset_all_dsc() - self.reset_active_depth() - - -def rebuild_extracted_blocks(blocks): - """ - - Args: - blocks: - - Returns: - - """ - out_modules = nn.ModuleList([]) - - if blocks is None: - raise ValueError("input blocks are None value") - - # if the input is not iterable, encase it in a moduleList - elif not hasattr(blocks, "__iter__"): - if not isinstance(blocks, nn.Module): - raise TypeError("Input blocks are neither iterable nor Module") - blocks = nn.ModuleList([blocks]) - - if isinstance(blocks, (nn.Sequential, nn.ModuleList)): - modules = nn.ModuleList([]) - for item in blocks: - modules.append(item) - - modules = flatten_module_list(modules) - - input_modules_flat_length = len(modules) - - # if the module is an elastic module, it is replaced by an equivalent basic module for its current state - for i in range(len(modules)): - module = modules[i] - reassembled_module = None - if isinstance(module, elastic_all_type): - reassembled_module = module.assemble_basic_module() - elif isinstance(module, ResBlockBase): - # reassemble both the subblocks and the skip layer separately, then put them into a new ResBlock - reassembled_subblocks = module_list_to_module( - rebuild_extracted_blocks(module.blocks) - ) - reassembled_skip = module_list_to_module( - rebuild_extracted_blocks(module.skip) - ) - reassembled_module = ResBlockBase( - module.in_channels, module.out_channels - ) - reassembled_module.blocks = reassembled_subblocks - reassembled_module.skip = reassembled_skip - act = module.act - if isinstance(act, elastic_all_type): - act = act.assemble_basic_module() - reassembled_module.do_act = module.do_act - reassembled_module.act = act - - elif isinstance(module, ElasticChannelHelper): - # elastic channel helper modules are not extracted in a rebuild. - # The active filter will be applied to each module. - # to ensure that the length validation still works, reduce input module count by one. - input_modules_flat_length -= 1 - else: - logging.warn( - f"unknown module found during extract/rebuild '{type(module)}'. Ignoring." - ) - - if reassembled_module is not None: - out_modules.append(reassembled_module) - - out_modules = flatten_module_list(out_modules) - if input_modules_flat_length != len(out_modules): - logging.info("Reassembly changed length of module list") - return out_modules diff --git a/hannah/models/ofa/submodules/__init__.py b/hannah/models/ofa/submodules/__init__.py deleted file mode 100644 index bac36f20..00000000 --- a/hannah/models/ofa/submodules/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
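# A compact sketch of the extract-and-rebuild idea in rebuild_extracted_blocks above
# (illustrative only; `assemble_basic_module` is the convention the elastic modules in
# this codebase follow, the class-name check stands in for the real isinstance test,
# and the recursive ResBlock handling is omitted):
import torch.nn as nn

def rebuild(modules):
    out = nn.ModuleList()
    for m in modules:
        if hasattr(m, "assemble_basic_module"):           # elastic module -> frozen static copy
            out.append(m.assemble_basic_module())
        elif type(m).__name__ == "ElasticChannelHelper":  # helpers leave no module behind
            continue
        else:
            out.append(m)                                  # plain modules pass through unchanged
    return out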
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/hannah/models/ofa/submodules/elasticBase.py b/hannah/models/ofa/submodules/elasticBase.py deleted file mode 100644 index ca03a340..00000000 --- a/hannah/models/ofa/submodules/elasticBase.py +++ /dev/null @@ -1,960 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import logging -from pyclbr import Function -from typing import List - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as nnf - -from hannah.models.factory import qat - -from ..utilities import ( - adjust_weights_for_grouping, - conv1d_get_padding, - filter_primary_module_weights, - filter_single_dimensional_weights, - get_kernel_for_dsc, - prepare_kernel_for_depthwise_separable_convolution, - prepare_kernel_for_pointwise_convolution, - sub_filter_start_end, -) - - -# It's a wrapper for a convolutional layer that allows for the number of input and -# output channels to be changed -class _Elastic: - """ """ - - def __init__(self, in_channel_filter, out_channel_filter, out_channel_sizes=None): - self.in_channel_filter: int = in_channel_filter - self.out_channel_filter: int = out_channel_filter - self.out_channel_sizes: List[int] = out_channel_sizes - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - return None - - def get_out_channel_sizes(self): - """ """ - return self.out_channel_sizes - - # return a safe copy of a conv1d equivalent to this module in the current state - def assemble_basic_module( - self, - ) -> nn.Module: # Module, so that Sequentials are possible, like DSC - return copy.deepcopy(self.get_basic_module()) - - def set_out_channel_filter(self, out_channel_filter): - """ - - Args: - ) -> nn.Module: # Module: - so that Sequentials are possible: - like DSCreturn copy.deepcopy(self.get_basic_module(): - out_channel_filter: - ) -> nn.Module: # Module: - like DSCreturn copy.deepcopy(self.get_basic_module(): - ) -> nn.Module: # Module: - like DSCreturn copy.deepcopy(self.get_basic_module(): - ) -> nn.Module: # Module: - like DSCreturn copy.deepcopy(self.get_basic_module())set_out_channel_filter(self: - - Returns: - - """ - if out_channel_filter is not None: - self.out_channel_filter = out_channel_filter - if hasattr(self, "bn") and hasattr(self.bn, "__iter__"): - for element in self.bn: - element.channel_filter = out_channel_filter - elif hasattr(self, "bn") and not hasattr(self.bn, "__iter__"): 
- self.bn.channel_filter = out_channel_filter - - def set_in_channel_filter(self, in_channel_filter): - """ - - Args: - in_channel_filter: - - Returns: - - """ - if in_channel_filter is not None: - self.in_channel_filter = in_channel_filter - - -# It's a 1D convolutional layer that can change its kernel size and dilation size -class ElasticBase1d(nn.Conv1d, _Elastic): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - groups: List[int], - dscs: List[bool], - stride: int = 1, - padding: int = 0, - bias: bool = False, - padding_mode: str = "zeros", - out_channel_sizes=None, - ): - # sort available kernel sizes from largest to smallest (descending order) - kernel_sizes.sort(reverse=True) - # make sure 0 is not set as kernel size. Must be at least 1 - if 0 in kernel_sizes: - kernel_sizes.remove(0) - self.kernel_sizes: List[int] = kernel_sizes - # after sorting kernel sizes, the maximum and minimum size available are the first and last element - self.max_kernel_size: int = kernel_sizes[0] - self.min_kernel_size: int = kernel_sizes[-1] - # initially, the target size is the full kernel - self.target_kernel_index: int = 0 - - # sort available dilation sizes from largest to smallest (descending order) - dilation_sizes.sort(reverse=False) - # make sure 0 is not set as dilation size. Must be at least 1 - if 0 in dilation_sizes: - dilation_sizes.remove(0) - self.dilation_sizes: List[int] = dilation_sizes - # after sorting dilation sizes, the maximum and minimum size available are the first and last element - self.max_dilation_size: int = dilation_sizes[-1] - self.min_dilation_size: int = dilation_sizes[0] - # initially, the target size is the smallest dilation (1) - self.target_dilation_index: int = 0 - - self.in_channels: int = in_channels - self.out_channels: int = out_channels - - # dynamic width changes the in and out_channels - # hence we save then here - self.initial_in_channels: int = in_channels - self.initial_out_channels: int = out_channels - - # sort available grouping sizes from largest to smallest (descending order) - groups.sort(reverse=False) - # make sure 0 is not set as grouping size. Must be at least 1 - if 0 in groups: - groups.remove(0) - - self.group_sizes: List[int] = groups - - self.max_group_size: int = self.group_sizes[-1] - self.min_group_size: int = self.group_sizes[0] - self.target_group_index: int = 0 - self.last_grouping_param = self.get_group_size() - - # set the groups value in the model - self.groups = self.get_group_size() - - # sort available grouping sizes from largest to smallest (descending order) - dscs.sort(reverse=False) - # make sure 0 is not set as grouping size. 
Must be at least 1 - self.dscs: List[bool] = dscs - - self.max_dsc: bool = self.dscs[-1] - self.min_dsc: bool = self.dscs[0] - self.target_dsc_index: int = 0 - - # store first grouping param - - # needed for speedup of check if weight needs to be adjusted for grouping - self.last_dsc_param = self.get_dsc() - - # set the groups value in the model - self.dsc_on = self.get_dsc() - - self.padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], - self.dilation_sizes[self.target_dilation_index], - ) - nn.Conv1d.__init__( - self, - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=self.max_kernel_size, - stride=stride, - padding=self.padding, - dilation=self.dilation_sizes[self.target_dilation_index], - groups=self.group_sizes[self.target_group_index], - bias=bias, - ) - - _Elastic.__init__( - self, [True] * in_channels, [True] * out_channels, out_channel_sizes - ) - - # the list of kernel transforms will have one element less than the list of kernel sizes. - # between every two sequential kernel sizes, there will be a kernel transform - # the subsequent kernel is determined by applying the same-size center of the previous kernel to the transform - self.kernel_transforms = nn.ModuleList([]) - for i in range(len(kernel_sizes) - 1): - # the target size of the kernel transform is the next kernel size in the sequence - new_kernel_size = kernel_sizes[i + 1] - # kernel transform is kept minimal by being shared between channels. - # It is simply a linear transformation from the center of the previous kernel to the new kernel - # directly applying the kernel to the transform is possible: nn.Linear accepts - # multi-dimensional input in a way where the last input dim is transformed - # from in_channels to out_channels for the last output dim - new_transform_module = nn.Linear( - new_kernel_size, new_kernel_size, bias=False - ) - # initialise the transform as the identity matrix to start training - # from the center of the larger kernel - new_transform_module.weight.data.copy_(torch.eye(new_kernel_size)) - # transform weights are initially frozen - new_transform_module.weight.requires_grad = True - self.kernel_transforms.append(new_transform_module) - self.set_kernel_size(self.max_kernel_size) - """ - self.dilation_transforms = nn.ModuleList() - for k in range(len(kernel_sizes)): - self.dilation_transforms.append(nn.ModuleList()) - for i in range(len(dilation_sizes) - 1): - new_transform_module = nn.Linear( - self.kernel_sizes[k], self.kernel_sizes[k], bias=False - ) - # initialise the transform as the identity matrix to start training - # from the center of the larger kernel - new_transform_module.weight.data.copy_(torch.eye(self.kernel_sizes[k])) - # transform weights are initially frozen - new_transform_module.weight.requires_grad = True - self.dilation_transforms[k].append(new_transform_module) - """ - self.update_padding() - - def set_bn_parameter(self, conv: nn.Conv1d, tmp_bn, num_tracked): - """Caller for BatchNorm Parameters - This unifies the call in the different methods, especially in dsc / not dsc forward - And assigns the attributes in tmp_bn to the param conv - - Args: - conv: nn.Conv1d: - tmp_bn: - num_tracked: - conv: nn.Conv1d: - conv: nn.Conv1d: - conv: nn.Conv1d: - - Returns: - - """ - conv.bn.num_features = tmp_bn.num_features - conv.bn.weight = tmp_bn.weight - conv.bn.bias = tmp_bn.bias - conv.bn.running_var = tmp_bn.running_var - conv.bn.running_mean = tmp_bn.running_mean - conv.bn.num_batches_tracked = num_tracked - return conv - - def 
prepare_dsc_for_validation_model( - self, - conv_class: nn.Module, - full_kernel, - full_bias, - in_channels, - out_channels, - grouping, - stride, - padding, - dilation, - # for quant - qconfig=None, - out_quant=None, - bn_eps=None, - bn_momentum=None, - bn_caller: tuple = None, - ): - """This method creates the necessary validation models for DSC. - It creates the validation model as torch.Sequence of standard pytorch convolution models. - The structure is analog to the DSC method do_dsc. - This method can also handle quantization models. - - Args: - conv_class: nn.Module: - full_kernel: - full_bias: - in_channels: - out_channels: - grouping: - stride: - padding: - dilation: - # for quantqconfig: (Default value = None) - out_quant: (Default value = None) - bn_eps: (Default value = None) - bn_momentum: (Default value = None) - bn_caller: tuple: (Default value = None) - conv_class: nn.Module: - bn_caller: tuple: (Default value = None) - conv_class: nn.Module: - bn_caller: tuple: (Default value = None) - conv_class: nn.Module: - bn_caller: tuple: (Default value = None) - - Returns: - - """ - - # use qconfig and out_quant parameters - is_quant = qconfig is not None and out_quant is not None - uses_batch_norm = bn_eps is not None and bn_momentum is not None - - # depthwise - ( - filtered_kernel_depth, - bias, - ) = prepare_kernel_for_depthwise_separable_convolution( - self, kernel=full_kernel, bias=full_bias, in_channels=in_channels - ) - in_channel_depth = in_channels - - param_depthwise_conv: dict = { - "in_channels": in_channel_depth, - "out_channels": filtered_kernel_depth.size(0), - "kernel_size": filtered_kernel_depth.size(2), - "bias": bias, - "groups": in_channel_depth, - "padding": padding, - } - - # is either quant Conv1d or normal Conv1d - depthwise_separable = nn.Conv1d(**param_depthwise_conv) - depthwise_separable.weight.data = filtered_kernel_depth - - # pointwise convolution - kernel, bias = self.get_kernel() - filtered_kernel_point = prepare_kernel_for_pointwise_convolution( - kernel=kernel, grouping=grouping - ) - - param_point_conv: dict = { - "in_channels": in_channels, - "out_channels": out_channels, - "kernel_size": filtered_kernel_point.size(2), - "bias": bias, - "groups": grouping, - "stride": stride, - "dilation": dilation, - } - if is_quant: - param_point_conv["qconfig"] = qconfig - param_point_conv["out_quant"] = out_quant - - if is_quant and uses_batch_norm: - param_point_conv["eps"] = bn_eps - param_point_conv["momentum"] = bn_momentum - - # create a convolution model, from the kwargs params above - pointwise = conv_class(**param_point_conv) - pointwise.weight.data = filtered_kernel_point - - if bn_caller: - # if batchnorm should be applied call bn_caller - bn_function, bn_norm, num_tracked = bn_caller - pointwise = bn_function(pointwise, bn_norm, num_tracked) - - depthwise_separable_conv = nn.Sequential(depthwise_separable, pointwise) - - return depthwise_separable_conv - - def do_dsc( - self, - input, - full_kernel, - full_bias, - grouping, - stride, - padding, - dilation, - quant_weight=None, - quant_bias=None, - quant_weight_function: Function = None, # for quantization - quant_bias_function: Function = None, # for quantization - ): - - """This method will perform the DSC(=Depthwise Separable Convolution). - This method can also handle quantized models. - DSC is done in two steps: - 1. Depthwise Separable: Set Group = In_Channels, Output = k*In_Channels - 2. 
Pointwise Convolution, with Grouping = Grouping-Param und Out_Channel = Out_Channel-Param - - The Params above are used for quantized models - - Args: - input: - full_kernel: - full_bias: - grouping: - stride: - padding: - dilation: - quant_weight: (Default value = None) - quant_bias: (Default value = None) - quant_weight_function: Function: (Default value = None) - # for quantizationquant_bias_function: Function: (Default value = None) - # for quantization: - quant_weight_function: Function: (Default value = None) - # for quantizationquant_bias_function: Function: (Default value = None) - quant_weight_function: Function: (Default value = None) - # for quantizationquant_bias_function: Function: (Default value = None) - quant_weight_function: Function: (Default value = None) - # for quantizationquant_bias_function: Function: (Default value = None) - - Returns: - - """ - use_fake_weight = quant_weight_function is not None - use_fake_bias = quant_bias_function is not None and full_bias is not None - - # get the **actual** count of in_channels - in_channels = self.in_channel_filter.count(True) - - filtered_kernel, bias = prepare_kernel_for_depthwise_separable_convolution( - self, kernel=full_kernel, bias=full_bias, in_channels=in_channels - ) - # params for depthwise separable convolution - param_depthwise_conv: dict = { - "input": input, - "weight": filtered_kernel - if use_fake_weight is False - else quant_weight_function(filtered_kernel), - "bias": bias if use_fake_bias is False else quant_bias_function(bias), - "groups": in_channels, - "padding": padding, - } - # do depthwise - res_depthwise = nnf.conv1d( - **param_depthwise_conv - # Important!! Kein Stride, keine Dilation, da sonst der Effekt von Depthwise daneben geht. - # Dies kann dann im nächsten Step nachgeholt werden. Sonst stimmt der Output nicht. - ) - use_quant = torch.is_tensor(quant_weight) and torch.is_tensor(quant_bias) - - if use_quant: - kernel, bias = quant_weight, quant_bias - filtered_kernel = get_kernel_for_dsc(kernel) - else: - kernel, bias = self.get_kernel() - filtered_kernel = prepare_kernel_for_pointwise_convolution( - kernel=kernel, grouping=grouping - ) - - # pointwise convolution - param_point_conv: dict = { - "input": res_depthwise, - "weight": filtered_kernel - if use_fake_weight is False - else quant_weight_function(filtered_kernel), - "bias": bias if use_fake_bias is False else quant_bias_function(bias), - "groups": grouping, - "stride": stride, - "dilation": dilation, - } - - res_pointwise = nnf.conv1d(**param_point_conv) - return res_pointwise - - def set_in_and_out_channel(self, kernel, filtered: bool = True): - """This method uses the kernel for setting the input and outputchannel - if dynamic width is activated (channelfilters), the amount of channels is reduced, - hence we can't use the initial values (self.(in/out)_channel) of the constructor - - This method sets the self.(in/out)_channel value to the right amount of channels - extracted from the kernel that will be used. - - if filtered is False, the self.initial_(in/out)_channels will be used. - - The previous values will be stored in the attribute prev_in_channels and prev_out_channels. 
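# A minimal functional sketch of the two-step depthwise separable convolution that the
# do_dsc docstring above describes (plain torch.nn.functional, no elastic filtering or
# quantization; shapes are illustrative assumptions):
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 32)            # (batch, in_channels=8, length)
depthwise_w = torch.randn(8, 1, 3)   # one 3-tap filter per input channel
pointwise_w = torch.randn(16, 8, 1)  # 1x1 convolution mixing channels into out_channels=16

# step 1: depthwise convolution with groups == in_channels; stride/dilation stay at 1 here
y = F.conv1d(x, depthwise_w, groups=8, padding=1)
# step 2: pointwise convolution carries the stride, dilation and grouping of the layer
y = F.conv1d(y, pointwise_w, stride=1, groups=1)
print(y.shape)                       # torch.Size([1, 16, 32])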
- - Args: - kernel: - filtered: bool: (Default value = True) - filtered: bool: (Default value = True) - filtered: bool: (Default value = True) - filtered: bool: (Default value = True) - - Returns: - - """ - self.prev_in_channels = self.in_channels - self.prev_out_channels = self.out_channels - self.in_channels = kernel.size(1) if filtered else self.initial_in_channels - self.out_channels = kernel.size(0) if filtered else self.initial_out_channels - - def reset_in_and_out_channel_to_previous(self): - """Analog to set_in_and_out_channels: - Resets the in and out_channels - - Args: - - Returns: - - """ - self.in_channels = self.prev_in_channels - self.out_channels = self.prev_out_channels - - def set_kernel_size(self, new_kernel_size): - """If the requested kernel size is outside of the min/max range, clamp it to - the min/max range. If the requested kernel size is not an available kernel - size, default to the max kernel size - - Args: - new_kernel_size: int): the size of the kernel you want to use - - Returns: - - """ - # previous_kernel_size = self.kernel_sizes[self.target_kernel_index] - if ( - new_kernel_size < self.min_kernel_size - or new_kernel_size > self.max_kernel_size - ): - logging.warn( - f"requested elastic kernel size ({new_kernel_size}) outside of min/max range: ({self.max_kernel_size}, {self.min_kernel_size}). clamping." - ) - if new_kernel_size <= self.min_kernel_size: - new_kernel_size = self.min_kernel_size - else: - new_kernel_size = self.max_kernel_size - - self.target_kernel_index = 0 - try: - index = self.kernel_sizes.index(new_kernel_size) - self.target_kernel_index = index - self.kernel_size = new_kernel_size - self.update_padding() - - except ValueError: - logging.warn( - f"requested elastic kernel size {new_kernel_size} is not an available kernel size. Defaulting to full size ({self.max_kernel_size})" - ) - - # if self.kernel_sizes[self.target_kernel_index] != previous_kernel_size: - # print(f"\nkernel size was changed: {previous_kernel_size} -> {self.kernel_sizes[self.target_kernel_index]}") - - # the initial kernel size is the first element of the list of available sizes - # set the kernel back to its initial size - def reset_kernel_size(self): - """ """ - self.set_kernel_size(self.kernel_sizes[0]) - - # step current kernel size down by one index, if possible. - # return True if the size limit was not reached - def step_down_kernel_size(self): - """ """ - next_kernel_index = self.target_kernel_index + 1 - if next_kernel_index < len(self.kernel_sizes): - self.set_kernel_size(self.kernel_sizes[next_kernel_index]) - # print(f"stepped down kernel size of a module! Index is now {self.target_kernel_index}") - return True - else: - logging.debug( - f"unable to step down kernel size, no available index after current: {self.target_kernel_index} with size: {self.kernel_sizes[self.target_kernel_index]}" - ) - return False - - def pick_kernel_index(self, target_kernel_index: int): - """ - - Args: - target_kernel_index: int: - target_kernel_index: int: - target_kernel_index: int: - target_kernel_index: int: - - Returns: - - """ - if (target_kernel_index < 0) or (target_kernel_index >= len(self.kernel_sizes)): - logging.warn( - f"selected kernel index {target_kernel_index} is out of range: 0 .. {len(self.kernel_sizes)}. Setting to last index." 
- ) - target_kernel_index = len(self.kernel_sizes) - 1 - self.set_kernel_size(self.kernel_sizes[target_kernel_index]) - - def get_available_kernel_steps(self): - """ """ - return len(self.kernel_sizes) - - def get_full_width_kernel(self): - """It applies the kernel transformations to the kernel until the target kernel - index is reached - :return: The found target kernel. - - Args: - - Returns: - - """ - current_kernel_index = 0 - current_dilation_index = 1 - current_kernel = self.weight - - logging.debug("Target kernel index: %s", str(self.target_kernel_index)) - - # step through kernels until the target index is reached. - while current_kernel_index < self.target_kernel_index: - if current_kernel_index >= len(self.kernel_sizes): - logging.warn( - f"kernel size index {current_kernel_index} is out of range. Elastic kernel acquisition stopping at last available kernel" - ) - break - # find start, end pos of the kernel center for the given next kernel size - start, end = sub_filter_start_end( - self.kernel_sizes[current_kernel_index], - self.kernel_sizes[current_kernel_index + 1], - ) - # extract the kernel center of the correct size - kernel_center = current_kernel[:, :, start:end] - # apply the kernel transformation to the next kernel. the n-th transformation - # is applied to the n-th kernel, yielding the (n+1)-th kernel - next_kernel = self.kernel_transforms[current_kernel_index](kernel_center) - # the kernel has now advanced through the available sizes by one - current_kernel = next_kernel - current_kernel_index += 1 - - # step through dilation until the target index is reached. - """ - while current_dilation_index < self.target_dilation_index: - if current_dilation_index >= len(self.dilation_sizes): - logging.warn( - f"kernel size index {current_kernel_index} is out of range. Elastic kernel acquisition stopping at last available kernel" - ) - break - # apply the kernel transformation to the next kernel. the n-th transformation - # is applied to the n-th kernel, yielding the (n+1)-th kernel - next_kernel = self.dilation_transforms[self.target_kernel_index][ - current_dilation_index - ](current_kernel) - # the kernel has now advanced through the available sizes by one - current_kernel = next_kernel - current_dilation_index += 1 - """ - - return current_kernel - - def get_kernel(self): - """If the input and output channels are not filtered, the full kernel is - - Args: - - Returns: - : return: The new kernel and bias. - - """ - full_kernel = self.get_full_width_kernel() - new_kernel = None - if all(self.in_channel_filter) and all(self.out_channel_filter): - # if no channel filtering is required, the full kernel can be kept - new_kernel = full_kernel - else: - # if channels need to be filtered, apply filters to the kernel - new_kernel = filter_primary_module_weights( - full_kernel, self.in_channel_filter, self.out_channel_filter - ) - # if the module has a bias parameter, also apply the output filtering to it. - if self.bias is None: - return new_kernel, None - else: - new_bias = filter_single_dimensional_weights( - self.bias, self.out_channel_filter - ) - return new_kernel, new_bias - - def set_dilation_size(self, new_dilation_size): - """ - - Args: - new_dilation_size: - - Returns: - - """ - if ( - new_dilation_size < self.min_dilation_size - or new_dilation_size > self.max_dilation_size - ): - logging.warn( - f"requested elastic dilation size ({new_dilation_size}) outside of min/max range: ({self.max_dilation_size}, {self.min_dilation_size}). clamping." 
- ) - if new_dilation_size < self.min_dilation_size: - new_dilation_size = self.min_dilation_size - else: - new_dilation_size = self.max_dilation_size - - self.target_dilation_index = 0 - try: - index = self.dilation_sizes.index(new_dilation_size) - self.target_dilation_index = index - self.dilation = self.dilation_sizes[self.target_dilation_index] - self.update_padding() - - except ValueError: - logging.warn( - f"requested elastic dilation size {new_dilation_size} is not an available dilation size. Defaulting to full size ({self.max_dilation_size})" - ) - - def update_padding(self): - """ """ - self.padding = conv1d_get_padding(self.kernel_size, self.dilation) - - # the initial dilation size is the first element of the list of available sizes - # set the dilation back to its initial size - def reset_dilation_size(self): - """ """ - self.set_dilation_size(self.dilation_sizes[0]) - - # step current kernel size down by one index, if possible. - # return True if the size limit was not reached - def step_down_dilation_size(self): - """ """ - next_dilation_index = self.target_dilation_index + 1 - if next_dilation_index < len(self.dilation_sizes): - self.set_dilation_size(self.dilation_sizes[next_dilation_index]) - return True - else: - logging.debug( - f"unable to step down dilation size, no available index after current: {self.target_dilation_index} with size: {self.dilation_sizes[self.target_dilation_index]}" - ) - return False - - def pick_dilation_index(self, target_dilation_index: int): - """ - - Args: - target_dilation_index: int: - target_dilation_index: int: - target_dilation_index: int: - target_dilation_index: int: - - Returns: - - """ - if (target_dilation_index < 0) or ( - target_dilation_index >= len(self.dilation_sizes) - ): - # MR-Optional Change (can be done in master) - logging.warn( - f"selected dilation index {target_dilation_index} is out of range: 0 .. {len(self.dilation_sizes)}. Setting to last index." - ) - target_dilation_index = len(self.dilation_sizes) - 1 - self.set_dilation_size(self.dilation_sizes[target_dilation_index]) - - def get_available_dilation_steps(self): - """ """ - return len(self.dilation_sizes) - - def get_available_grouping_steps(self): - """ """ - return len(self.group_sizes) - - def get_available_dsc_steps(self): - """ """ - return len(self.dscs) - - def get_dilation_size(self): - """ """ - return self.dilation_sizes[self.target_dilation_index] - - def pick_group_index(self, target_group_index: int): - """ - - Args: - target_group_index: int: - target_group_index: int: - target_group_index: int: - target_group_index: int: - - Returns: - - """ - if (target_group_index < 0) or (target_group_index >= len(self.group_sizes)): - logging.warn( - f"selected group index {target_group_index} is out of range: 0 .. {len(self.group_sizes)}. Setting to last index." - ) - target_group_index = len(self.group_sizes) - 1 - self.set_group_size(self.group_sizes[target_group_index]) - - def pick_dsc_index(self, target_dsc_index: int): - """ - - Args: - target_dsc_index: int: - target_dsc_index: int: - target_dsc_index: int: - target_dsc_index: int: - - Returns: - - """ - if (target_dsc_index < 0) or (target_dsc_index >= len(self.dscs)): - logging.warn( - f"selected dsc index {target_dsc_index} is out of range: 0 .. {len(self.dscs)}. Setting to last index." 
- ) - target_dsc_index = len(self.dscs) - 1 - self.set_dsc(self.dscs[target_dsc_index]) - - # the initial group size is the first element of the list of available sizes - # resets the group size back to its initial size - def reset_group_size(self): - """ """ - self.set_group_size(self.group_sizes[0]) - - def reset_dscs(self): - """ """ - self.set_dsc(self.dscs[0]) - - def get_group_size(self): - """ """ - return self.group_sizes[self.target_group_index] - - def get_dsc(self): - """ """ - return self.dscs[self.target_dsc_index] - - def set_group_size(self, new_group_size): - """ - - Args: - new_group_size: - - Returns: - - """ - if new_group_size < self.min_group_size or new_group_size > self.max_group_size: - logging.warn( - f"requested elastic group size ({new_group_size}) outside of min/max range: ({self.max_group_size}, {self.min_group_size}). clamping." - ) - if new_group_size < self.min_group_size: - new_group_size = self.min_group_size - else: - new_group_size = self.max_group_size - - self.target_group_index = 0 - try: - index = self.group_sizes.index(new_group_size) - self.target_group_index = index - # if hasattr(self, 'from_skipping') and self.from_skipping is True: - # logging.warn(f"setting groupsizes from skipping is: {self.from_skipping}") - # else: - # self.groups = self.group_sizes[index] - - except ValueError: - logging.warn( - f"requested elastic group size {new_group_size} is not an available group size. Defaulting to full size ({self.max_group_size})" - ) - - def set_dsc(self, new_dsc): - """ - - Args: - new_dsc: - - Returns: - - """ - if new_dsc < self.min_dsc or new_dsc > self.max_dsc: - logging.warn( - f"requested elastic dsc ({new_dsc}) outside of min/max range: ({self.max_dsc}, {self.min_dsc}). clamping." - ) - if new_dsc < self.min_dsc: - new_dsc = self.min_dsc - else: - new_dsc = self.max_dsc - - self.target_dsc_index = 0 - try: - index = self.dscs.index(new_dsc) - self.target_dsc_index = index - # if hasattr(self, 'from_skipping') and self.from_skipping is True: - # logging.warn(f"setting groupsizes from skipping is: {self.from_skipping}") - # else: - # self.groups = self.group_sizes[index] - - except ValueError: - logging.warn( - f"requested elastic dsc {new_dsc} is not an available group size. Defaulting to full size ({self.max_dsc})" - ) - - # step current kernel size down by one index, if possible. - # return True if the size limit was not reached - def step_down_group_size(self): - """ """ - next_group_index = self.target_group_index + 1 - if next_group_index < len(self.group_sizes): - self.set_group_size(self.group_sizes[next_group_index]) - # print(f"stepped down group size of a module! Index is now {self.target_group_index}") - return True - else: - logging.debug( - f"unable to step down group size, no available index after current: {self.target_group_index} with size: {self.group_sizes[self.target_group_index]}" - ) - return False - - def step_down_dsc(self): - """ """ - next_dsc_index = self.target_dsc_index + 1 - if next_dsc_index < len(self.dscs): - self.set_dsc(self.dscs[next_dsc_index]) - # print(f"stepped down group size of a module! 
Index is now {self.target_group_index}") - return True - else: - logging.debug( - f"unable to step down dsc, no available index after current: {self.target_dsc_index} with size: {self.dscs[self.target_dsc_index]}" - ) - return False - - # Wrapper Class - def adjust_weights_for_grouping(self, weights, input_divided_by): - """ - - Args: - weights: - input_divided_by: - - Returns: - - """ - return adjust_weights_for_grouping(weights, input_divided_by) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - pass - - def extra_repr(self): - """ """ - pass - # return super(ElasticBase1d, self).extra_repr() diff --git a/hannah/models/ofa/submodules/elasticBatchnorm.py b/hannah/models/ofa/submodules/elasticBatchnorm.py deleted file mode 100644 index 62c5adf2..00000000 --- a/hannah/models/ofa/submodules/elasticBatchnorm.py +++ /dev/null @@ -1,120 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import logging - -import torch -import torch.nn as nn -import torch.nn.functional as nnf - -from ..utilities import filter_single_dimensional_weights, make_parameter - - -class ElasticWidthBatchnorm1d(nn.BatchNorm1d): - """ """ - - def __init__( - self, - num_features, - track_running_stats=False, - affine=True, - momentum=0.1, - eps=1e-5, - ): - - super().__init__( - num_features=num_features, - eps=eps, - momentum=momentum, - affine=affine, - track_running_stats=track_running_stats, - ) - self.channel_filter = [True] * num_features - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """if self.track_running_stats: - logging.warn( - "ElasticWidthBatchnorm with tracked running stats currently not fully implemented!" - ) - # num_batches_tracked and exponential averaging are currently not implemented. - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - running_mean = filter_single_dimensional_weights( - self.running_mean, self.channel_filter - ) - running_var = filter_single_dimensional_weights( - self.running_var, self.channel_filter - ) - weight = filter_single_dimensional_weights(self.weight, self.channel_filter) - bias = filter_single_dimensional_weights(self.bias, self.channel_filter) - training = self.training - momentum = self.momentum - eps = self.eps - - return nnf.batch_norm( - input=input, - running_mean=running_mean, - running_var=running_var, - weight=weight, - bias=bias, - training=training or not self.track_running_stats, - momentum=momentum, - eps=eps, - ) - - def get_basic_batchnorm1d(self): - """ """ - # filter_single_dimensional_weights checks for None-input, no need to do it here. 
- running_mean = filter_single_dimensional_weights( - self.running_mean, self.channel_filter - ) - running_var = filter_single_dimensional_weights( - self.running_var, self.channel_filter - ) - weight = make_parameter( - filter_single_dimensional_weights(self.weight, self.channel_filter) - ) - bias = make_parameter( - filter_single_dimensional_weights(self.bias, self.channel_filter) - ) - new_bn = nn.BatchNorm1d( - num_features=self.num_features, - eps=self.eps, - momentum=self.momentum, - affine=self.affine, - track_running_stats=self.track_running_stats, - ) - new_bn.running_mean = running_mean - new_bn.running_var = running_var - new_bn.weight = weight - new_bn.bias = bias - new_bn.training = self.training - return new_bn - - def assemble_basic_module(self) -> nn.BatchNorm1d: - """ """ - return copy.deepcopy(self.get_basic_batchnorm1d()) diff --git a/hannah/models/ofa/submodules/elasticLinear.py b/hannah/models/ofa/submodules/elasticLinear.py deleted file mode 100644 index 7a1d732f..00000000 --- a/hannah/models/ofa/submodules/elasticLinear.py +++ /dev/null @@ -1,256 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import logging - -import torch -import torch.nn as nn -import torch.nn.functional as nnf - -from ...factory import qat -from ..utilities import ( - filter_primary_module_weights, - filter_single_dimensional_weights, - make_parameter, -) -from .elasticBase import _Elastic - - -class ElasticWidthLinear(nn.Linear, _Elastic): - """ """ - - def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: - nn.Linear.__init__(self, in_features, out_features, bias=bias) - _Elastic.__init__(self, [True] * in_features, [True] * out_features) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - if all(self.in_channel_filter) and all(self.out_channel_filter): - # if no channel filtering is required, simply use the full linear - return nnf.linear(input, self.weight, self.bias) - else: - # if channels need to be filtered, apply filters. - new_weight = filter_primary_module_weights( - self.weight, self.in_channel_filter, self.out_channel_filter - ) - # if the module has a bias parameter, also apply the output filtering to it. - # filter_single_dimensional_weights checks for None-input, so no check is done here. 
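# The filtering helpers imported from ..utilities are not shown in this diff.
# A minimal sketch of what they amount to, assuming plain boolean-mask
# indexing over the output and input dimensions of a weight tensor (an
# assumption for illustration, not the removed implementation):

import torch

def filter_single_dimensional_weights(tensor, mask):
    # keep only the entries whose mask value is True; tolerate None input
    if tensor is None:
        return None
    return tensor[torch.tensor(mask, dtype=torch.bool)]

def filter_primary_module_weights(weight, in_mask, out_mask):
    # rows follow the output-channel filter, columns the input-channel filter
    out_keep = torch.tensor(out_mask, dtype=torch.bool)
    in_keep = torch.tensor(in_mask, dtype=torch.bool)
    return weight[out_keep][:, in_keep]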
- new_bias = filter_single_dimensional_weights( - self.bias, self.out_channel_filter - ) - return nnf.linear(input, new_weight, new_bias) - - def get_basic_module(self): - """ """ - weight = self.weight - bias = self.bias - # weight and bias of this linear will be overwritten - new_linear = nn.Linear( - in_features=self.in_features, - out_features=self.out_features, - ) - if all(self.in_channel_filter) and all(self.out_channel_filter): - new_linear.weight = weight - new_linear.bias = bias - return new_linear - else: - new_weight = make_parameter( - filter_primary_module_weights( - self.weight, self.in_channel_filter, self.out_channel_filter - ) - ) - new_bias = make_parameter( - filter_single_dimensional_weights(self.bias, self.out_channel_filter) - ) - new_linear.weight = new_weight - new_linear.bias = new_bias - return new_linear - - -class ElasticQuantWidthLinear(nn.Linear, _Elastic): - """ """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - out_quant=True, - qconfig=None, - ) -> None: - - super().__init__(in_features, out_features, bias=bias) - _Elastic.__init__(self, [True] * in_features, [True] * out_features) - - assert qconfig, "qconfig must be provided for QAT module" - self.out_quant = out_quant - self.qconfig = qconfig - self.weight_fake_quant = qconfig.weight() - if hasattr(qconfig, "bias"): - self.bias_fake_quant = qconfig.bias() - else: - self.bias_fake_quant = qconfig.activation() - - self.activation_post_process = ( - qconfig.activation() if out_quant else nn.Identity() - ) - - @property - def filtered_weight(self): - """ """ - if all(self.in_channel_filter) and all(self.out_channel_filter): - return self.weight - else: - - return filter_primary_module_weights( - self.weight, self.in_channel_filter, self.out_channel_filter - ) - - @property - def filtered_bias(self): - """ """ - return filter_single_dimensional_weights(self.bias, self.out_channel_filter) - - @property - def scaled_weight(self): - """ """ - return self.weight_fake_quant(self.filtered_weight) - - @property - def scaled_bias(self): - """ """ - return ( - self.bias_fake_quant(self.filtered_bias) - if self.bias is not None - else self.bias - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - return self.activation_post_process( - nnf.linear(input, self.scaled_weight, self.scaled_bias) - ) - - def get_basic_module(self): - """ """ - weight = self.weight - bias = self.bias - # weight and bias of this linear will be overwritten - new_linear = qat.Linear( - in_features=self.in_features, - out_features=self.out_features, - bias=self.bias is not None, - out_quant=self.out_quant, - qconfig=self.qconfig, - ) - if all(self.in_channel_filter) and all(self.out_channel_filter): - new_linear.weight = weight - new_linear.bias = bias - return new_linear - else: - new_weight = make_parameter( - filter_primary_module_weights( - self.weight, self.in_channel_filter, self.out_channel_filter - ) - ) - new_bias = make_parameter( - filter_single_dimensional_weights(self.bias, self.out_channel_filter) - ) - new_linear.weight = new_weight - new_linear.bias = new_bias - return new_linear - - @classmethod - def from_float(cls, mod): - """Create a qat module from a float module or qparams_dict - - Args: `mod` a float module, either produced by torch.quantization utilities - or directly from user - - Args: - mod: - - Returns: - - """ - assert type(mod) == 
cls._FLOAT_MODULE, ( - " qat." - + cls.__name__ - + ".from_float only works for " - + cls._FLOAT_MODULE.__name__ - ) - assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" - assert mod.qconfig, "Input float module must have a valid qconfig" - # TODO kombi noch einfügen - # if type(mod) == LinearReLU: - # mod = mod[0] - - qconfig = mod.qconfig - qat_linear = cls( - mod.in_features, - mod.out_features, - bias=mod.bias is not None, - qconfig=qconfig, - ) - qat_linear.weight = mod.weight - qat_linear.bias = mod.bias - return qat_linear - - -# just a ReLu, which can forward a SequenceDiscovery -class ElasticPermissiveReLU(nn.ReLU): - """ """ - - def __init__(self): - super().__init__() - - def forward(self, x): - """ - - Args: - x: - - Returns: - - """ - return super().forward(x) - - def assemble_basic_module(self): - """ """ - return nn.ReLU() diff --git a/hannah/models/ofa/submodules/elasticchannelhelper.py b/hannah/models/ofa/submodules/elasticchannelhelper.py deleted file mode 100644 index ff4dc0f1..00000000 --- a/hannah/models/ofa/submodules/elasticchannelhelper.py +++ /dev/null @@ -1,465 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import logging -from typing import List - -import numpy as np -import torch -import torch.nn as nn - -from ..type_utils import elastic_forward_type -from ..utilities import flatten_module_list - -# imports are located at the bottom to circumvent circular dependency import issues -from .elasticBatchnorm import ElasticWidthBatchnorm1d - - -# helper module, deployed in an elastic width connection -# can zero out input channels to train elastic channels without weight modification -# must know the module it passes inputs to to compute channel priorities -# must know source modules to remove output channels in extraction -# This can be previous linears/convs, the skip layer of a previous residual block, -# and batchnorms placed in-between -class ElasticChannelHelper(nn.Module): - """ """ - - def __init__( - self, - channel_counts: List[int], - # sources: nn.ModuleList, - # target: nn.Module, - # additional_targets: nn.ModuleList = nn.ModuleList([]), - ): - super().__init__() - # sort channel counts: largest -> smallest - self.channel_counts = channel_counts - self.channel_counts.sort(reverse=True) - self.sources = nn.ModuleList([]) - self.target = None - # additional target modules will not be used to compute channel priorities, - # but will have to be known for reducing input channels - # this may contain additional exits, the skip layer of a following residual block - self.additional_targets = nn.ModuleList([]) - # initialize filter for channel reduction in training, channel priority list - self.channel_pass_filter: List[int] = [] - # the first channel index in this list is least important, the last channel index ist most important - self.channels_by_priority: List[int] = [] - # initially, all channels are used. - self.max_channels: int = self.channel_counts[0] - self.channel_step: int = 0 - self.current_channels: int = self.channel_counts[self.channel_step] - - # initialize the filter and the channel priority list - for i in range(self.max_channels): - self.channel_pass_filter.append(True) - # to init with technically valid values, simply set starting priority based on index - self.channels_by_priority.append(i) - - # compute channel priorities based on the l1 norm of the weights of whichever - # target module follows this elastic channel section - def compute_channel_priorities(self): - """ """ - target = self.target - channel_norms = [] - - # this will also include the elastic kernel convolutions - # for elastic kernel convolutions, the priorities will then also be - # computed on the base module (full kernel) - if isinstance(target, nn.Conv1d): - weights = target.weight.data - norms_per_kernel_index = torch.linalg.norm(weights, ord=1, dim=0) - channel_norms = torch.linalg.norm(norms_per_kernel_index, ord=1, dim=1) - # the channel priorities for linears need to also be computable: - # especially for the exit connections, a linear may follow after an elastic width - elif isinstance(target, nn.Linear): - weights = target.weight.data - channel_norms = torch.linalg.norm(weights, ord=1, dim=0) - else: - # the channel priorities will keep their previous / default value in - # this case. Reduction will probably occur by channel order - logging.warning( - f"Unable to compute channel priorities! 
Unsupported target module after elastic channels: {type(target)}" - ) - - # contains the indices of the channels, sorted from channel with smallest - # norm to channel with largest norm - # the least important channel index is at the beginning of the list, - # the most important channel index is at the end - self.channels_by_priority = np.argsort(channel_norms) - - # set the channel filter list based on the channel priorities and the current channel count - def set_channel_filter(self): - """ """ - # get the amount of channels to be removed from the max and current channel counts - channel_reduction_amount: int = self.max_channels - self.current_channels - # start with an empty filter, where every channel passes through, then remove channels by priority - self.channel_pass_filter = [True] * len(self.channel_pass_filter) - - # filter the least important n channels, specified by the reduction amount - for i in range(channel_reduction_amount): - # priority list of channels contains channel indices from least important to most important - # the first n channel indices specified in this list will be filtered out - filtered_channel_index = self.channels_by_priority[i] - self.channel_pass_filter[filtered_channel_index] = False - - if isinstance( - self.target, - elastic_forward_type, - ): - self.apply_filter_to_module(self.target, is_target=True) - else: - logging.warn( - f"Elastic channel helper has no defined behavior for primary target type: {type(self.target)}" - ) - for item in self.additional_targets: - self.apply_filter_to_module(item, is_target=True) - for item in self.sources: - self.apply_filter_to_module(item, is_target=False) - - # if is_target is set to true, the module is a target module (filter its input). - # false -> source module -> filter its output - def apply_filter_to_module(self, module, is_target: bool): - """ - - Args: - module: - is_target: bool: - is_target: bool: - is_target: bool: - is_target: bool: - - Returns: - - """ - if isinstance( - module, - elastic_forward_type, - ): - if is_target: - # target module -> set module input filter - if len(module.in_channel_filter) != len(self.channel_pass_filter): - logging.error( - f"Elastic channel helper filter length {len(self.channel_pass_filter)} does not match filter length {len(module.in_channel_filter)} of {type(module)}! " - ) - return - module.set_in_channel_filter(self.channel_pass_filter) - else: - # source module -> set module output filter - if len(module.out_channel_filter) != len(self.channel_pass_filter): - logging.error( - f"Elastic channel helper filter length {len(self.channel_pass_filter)} does not match filter length {len(module.out_channel_filter)} of {type(module)}! " - ) - return - module.set_out_channel_filter(self.channel_pass_filter) - - elif isinstance(module, ElasticWidthBatchnorm1d): - # this is normal for residual blocks with a norm after applying residual output to blocks output - # if is_target: - # logging.warn("Batchnorm found in Elastic channel helper targets, it should usually be located in-front of the helper module.") - if len(module.channel_filter) != len(self.channel_pass_filter): - logging.error( - f"Elastic channel helper filter length {len(self.channel_pass_filter)} does not match filter length {len(module.channel_filter)} of {type(module)}!" 
- ) - return - module.channel_filter = self.channel_pass_filter - else: - logging.error( - f"Elastic channel helper could not apply filter to module of unknown type: {type(module)}" - ) - - # step down channel count by one channel step - def step_down_channels(self): - """ """ - if self.channel_step + 1 in range(len(self.channel_counts)): - # if there is still channel steps available, step forward by one. Set new active channel count. - self.channel_step += 1 - self.current_channels = self.channel_counts[self.channel_step] - # after stepping down channels by one, set new channel filter. - self.set_channel_filter() - return True - else: - # if the last channel step is already reached, no additional step-down operation can be performed - return False - - def set_channel_step(self, step: int): - """ - - Args: - step: int: - step: int: - step: int: - step: int: - - Returns: - - """ - if step not in range(len(self.channel_counts)): - logging.warn( - f"Elastic channel helper step target {step} out of range for length {len(self.channel_counts)}. Defaulting to 0." - ) - step = 0 - if step == self.channel_step: - # only re-apply filters if there is actually going to be a change. - return - self.channel_step = step - self.current_channels = self.channel_counts[self.channel_step] - self.set_channel_filter() - - def reset_channel_step(self): - """ """ - self.set_channel_step(0) - - # set the primary target from an input module. For iterable inputs, extract additional secondary targets - def set_primary_target(self, target: nn.Module): - """ - - Args: - target: nn.Module: - target: nn.Module: - target: nn.Module: - target: nn.Module: - - Returns: - - """ - if hasattr(target, "__iter__"): - # first, flatten the target, if it is iterable - target = flatten_module_list(target) - # the primary target is the first linear/conv in the sequence - for item in target: - if self.is_valid_primary_target(item): - self.target = item - # if the primary target was found in the sequence, any trailing - # modules must be ignored, as they are unaffected. - break - else: - # if the module item is not a primary target, process it as a secondary target. - self.add_secondary_targets(item) - # this will check for other, invalid ElasticChannelHelper - # modules in targets and throw an error - else: - # if the input is not iterable, and is just a simple module, it is the target - if not self.is_valid_primary_target(target): - # if the standalone module is not actually a valid primary target, something went wrong! - logging.warn( - f"ElasticChannelHelper target module is an invalid module: '{type(target)}'. Target reset to None." 
- ) - self.target = None - # if the input is valid as a target module, set it as the target - self.target = target - - # check if a module is valid as a primary target (to compute channel priorities from) - def is_valid_primary_target(self, module: nn.Module) -> bool: - """ - - Args: - module: nn.Module: - module: nn.Module: - module: nn.Module: - module: nn.Module: - - Returns: - - """ - # legacy function - return ElasticChannelHelper.is_primary_target(module) - - # check if a module is valid as a primary target (to compute channel priorities from) - def is_primary_target(module: nn.Module) -> bool: - """ - - Args: - module: nn.Module: - module: nn.Module: - module: nn.Module: - module: nn.Module: - - Returns: - - """ - return isinstance( - module, - elastic_forward_type, - ) - - # add additional target(s) which must also have their inputs adjusted when - # stepping down channels - def add_secondary_targets(self, target: nn.Module): - """ - - Args: - target: nn.Module: - target: nn.Module: - target: nn.Module: - target: nn.Module: - - Returns: - - """ - if hasattr(target, "__iter__"): - # if the input target is iterable, check every item - target_flat = flatten_module_list(target) - for item in target_flat: - if isinstance(item, ElasticChannelHelper): - logging.error( - "ElasticChannelHelper target accumulation reached another ElasticChannelHelper, with no primary target in-between!" - ) - self.add_secondary_target_item(item) - if self.is_valid_primary_target(item): - # if a valid primary target is found reached, the modules - # trailing it must not be affected by width changes - # only modules before a trailing linear/conv will be affected - break - else: - self.add_secondary_target_item(target) - - # TODO: logic for adding secondary items to target/source is pretty much a copy - could be cleaned up - # check a module, add it as a secondary target if its weights would need modification when channel width changes - def add_secondary_target_item(self, target: nn.Module): - """ - - Args: - target: nn.Module: - target: nn.Module: - target: nn.Module: - target: nn.Module: - - Returns: - - """ - if self.is_valid_primary_target(target): - self.additional_targets.append(target) - elif isinstance(target, ElasticWidthBatchnorm1d): - # trailing batchnorms between the channel helper and the next 'real' - # module will also need to have their channels adjusted - # this is normal for residual blocks with a norm after applying residual output to blocks output - # logging.warn( - # "found loose BatchNorm1d module trailing an elastic channel helper. These should be located in-front of the helper" - # ) - self.additional_targets.append(target) - elif isinstance(target, nn.ReLU): - logging.warn( - "found loose ReLu module trailing an elastic channel helper. These should be located in-front of the helper" - ) - else: - logging.warn( - f"module with undefined behavior found in ElasticChannelHelper targets: '{type(target)}'. Ignoring." 
- ) - - # add additional source(s) which must have their outputs adjusted if the channel width changes - def add_sources(self, source: nn.Module): - """ - - Args: - source: nn.Module: - source: nn.Module: - source: nn.Module: - source: nn.Module: - - Returns: - - """ - if hasattr(source, "__iter__"): - # if the input source is iterable, check every item - source_flat = flatten_module_list(source) - for item in source_flat: - # ascend the list of sources from the back - if isinstance(item, ElasticChannelHelper): - logging.exception( - "ElasticChannelHelper source accumulation found another ElasticChannelHelper!" - ) - self.add_source_item(item) - else: - self.add_source_item(self, source) - - # add additional source(s) which must have their outputs adjusted if the channel width changes - def add_targets(self, target: nn.Module): - """ - - Args: - target: nn.Module: - target: nn.Module: - target: nn.Module: - target: nn.Module: - - Returns: - - """ - if hasattr(target, "__iter__"): - # if the input source is iterable, check every item - target_flat = flatten_module_list(target) - - for item in target_flat: - # ascend the list of sources from the back - if isinstance(item, ElasticChannelHelper): - logging.exception( - "ElasticChannelHelper source accumulation found another ElasticChannelHelper!" - ) - self.discover_target(item) - else: - self.discover_target(target) - - # check a module, add it as a source if its weights would need modification when channel width changes - def add_source_item(self, source: nn.Module): - """ - - Args: - source: nn.Module: - source: nn.Module: - source: nn.Module: - source: nn.Module: - - Returns: - - """ - if self.is_valid_primary_target(source): - # modules which are valid primary targets (Convs, Linears) are also valid sources - self.sources.append(source) - elif isinstance(source, ElasticWidthBatchnorm1d): - # batchnorms before the channel helper will need to be adjusted if channels are removed - self.sources.append(source) - elif isinstance(source, nn.ReLU): - # ReLu preceding the channel helper can be ignored. It does not need adjustment. - pass - else: - logging.warn( - f"module with undefined behavior found in ElasticChannelHelper sources: '{type(source)}'. Ignoring." - ) - - def discover_target(self, new_target: nn.Module): - """ - - Args: - new_target: nn.Module: - new_target: nn.Module: - new_target: nn.Module: - new_target: nn.Module: - - Returns: - - """ - # if no target is set yet, take this module as the primary target - if self.is_valid_primary_target(new_target) and self.target is None: - self.set_primary_target(new_target) - else: - self.add_secondary_targets(new_target) - - def get_available_width_steps(self): - """ """ - return len(self.channel_counts) diff --git a/hannah/models/ofa/submodules/elastickernelconv.py b/hannah/models/ofa/submodules/elastickernelconv.py deleted file mode 100644 index e26f6a84..00000000 --- a/hannah/models/ofa/submodules/elastickernelconv.py +++ /dev/null @@ -1,724 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import copy -import logging -import math -import random -from inspect import Parameter -from tokenize import group -from typing import List - -import torch -from torch import nn -from torch.nn import functional as nnf - -from ..utilities import adjust_weight_if_needed, conv1d_get_padding -from .elasticBase import ElasticBase1d -from .elasticBatchnorm import ElasticWidthBatchnorm1d -from .elasticLinear import ElasticPermissiveReLU - - -class ElasticConv1d(ElasticBase1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - groups: List[int], - dscs: List[bool], - stride: int = 1, - padding: int = 0, - bias: bool = False, - out_channel_sizes=None, - ): - ElasticBase1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - ) - self.norm = False - self.act = False - # TODO es wäre auch möglich das ganze als Flag einzubauen wie norm und act, aber hier wäre die Frage wie man es mit dem trainieren macht ? - # So wäre es statisch und nicht wirklich sinnvoll - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # get the kernel for the current index - kernel, bias = self.get_kernel() - # First get the correct count of in and outchannels - self.set_in_and_out_channel(kernel) - - dilation = self.get_dilation_size() - # get padding for the size of the kernel - padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], dilation - ) - grouping = self.get_group_size() - # Hier muss dann wenn dsc_on on ist, die Logik implementiert werden dass DSC komplett greift - dsc_on = self.get_dsc() - - if dsc_on is False: - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - output = nnf.conv1d( - input, kernel, bias, self.stride, padding, dilation, grouping - ) - else: - # we use the full kernel here, because if the input_channel_size is greater than the output_channel_size - # we have to increase the output_channel_size for dsc, hence we need the full kernel, because, the filtered kernel - # is in that particular case to small. 
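# "dsc" toggles a depthwise-separable execution path. As a generic sketch of
# that pattern only (MobileNet-style depthwise + pointwise factorization; not
# the removed do_dsc implementation, which additionally handles the filtered
# vs. full kernel distinction described above):

import torch
import torch.nn.functional as nnf

def depthwise_separable_conv1d(x, dw_kernel, pw_kernel, stride=1, padding=0, dilation=1):
    # per-channel depthwise convolution: one filter per input channel
    depthwise = nnf.conv1d(x, dw_kernel, stride=stride, padding=padding,
                           dilation=dilation, groups=x.shape[1])
    # pointwise 1x1 convolution mixes channels to the desired output width
    return nnf.conv1d(depthwise, pw_kernel)

# assumed shapes: dw_kernel is (in_channels, 1, kernel_size),
# pw_kernel is (out_channels, in_channels, 1)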
- kernel, bias = self.get_full_width_kernel(), self.bias - - # sanity check - if self.in_channels > kernel.size(0): - logging.warning( - f"In Channels vs Maximum of Outchannels : {self.in_channels} vs {kernel.size(0)}, full_kernel:({kernel.shape}), kernel:({self.get_kernel()[0].shape})" - ) - - output = self.do_dsc( - input, - full_kernel=kernel, - full_bias=bias, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - ) - - self.reset_in_and_out_channel_to_previous() - return output - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - kernel_size = self.kernel_sizes[self.target_kernel_index] - - self.set_in_and_out_channel(kernel) - - dilation = self.get_dilation_size() - grouping = self.get_group_size() - padding = conv1d_get_padding(kernel_size, dilation) - dsc_on = self.get_dsc() - - if dsc_on: - dsc_sequence = self.prepare_dsc_for_validation_model( - conv_class=nn.Conv1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = nn.Conv1d( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=kernel_size, - stride=self.stride, - padding=padding, - dilation=dilation, - bias=False, - groups=grouping, - ) - new_conv.last_grouping_param = self.groups - - # for ana purposes handy - set a unique id so we can track this specific convolution - if not hasattr(new_conv, "id"): - new_conv.id = "ElasticConv1d-" + str(random.randint(0, 1000) * 2000) - logging.debug( - f"Validation id created: {new_conv.id} ; g={grouping}, w_before={kernel.shape}, ic={self.in_channels}" - ) - else: - logging.debug("Validation id already present: {new_conv.id}") - - kernel, _ = adjust_weight_if_needed( - module=new_conv, kernel=kernel, groups=new_conv.groups - ) - new_conv.weight.data = kernel - if bias is not None: - new_conv.bias = bias - - logging.debug( - f"=====> id: {new_conv.id} ; g={grouping}, w_after={kernel.shape}, ic={self.in_channels}" - ) - - self.reset_in_and_out_channel_to_previous() - return new_conv - - -class ElasticConvReLu1d(ElasticBase1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - groups: List[int], - dscs: List[bool], - stride: int = 1, - padding: int = 0, - bias: bool = False, - out_channel_sizes=None, - ): - ElasticBase1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - ) - self.relu = ElasticPermissiveReLU() - self.norm = False - self.act = True - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # return self.get_basic_conv1d().forward(input) # for validaing assembled module - # get the kernel for the current index - kernel, bias = self.get_kernel() - # First get the correct count of in and outchannels - # given by the kernel (after setting the kernel correctly, with the help of input-/output_filters) - self.set_in_and_out_channel(kernel) - - 
dilation = self.get_dilation_size() - # get padding for the size of the kernel - padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], dilation - ) - - grouping = self.get_group_size() - dsc_on = self.get_dsc() - - if dsc_on is False: - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - output = nnf.conv1d( - input, kernel, bias, self.stride, padding, dilation, grouping - ) - else: - # we use the full kernel here, because if the input_channel_size is greater than the output_channel_size - # we have to increase the output_channel_size for dsc, hence we need the full kernel, because, the filtered kernel - # is in that particular case to small. - kernel, bias = ( - self.get_full_width_kernel(), - self.bias, - ) # if self.in_channels > self.out_channels else (kernel, bias) - if self.in_channels > kernel.size(0): - logging.warning( - f"In Channels vs Maximum of Outchannels : {self.in_channels} vs {kernel.size(0)}, full_kernel:({kernel.shape}), kernel:({self.get_kernel()[0].shape})" - ) - output = self.do_dsc( - input, - full_kernel=kernel, - full_bias=bias, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - ) - - self.reset_in_and_out_channel_to_previous() - return self.relu(output) - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - self.set_in_and_out_channel(kernel) - - kernel_size = self.kernel_sizes[self.target_kernel_index] - dilation = self.get_dilation_size() - grouping = self.get_group_size() - padding = conv1d_get_padding(kernel_size, dilation) - dsc_on = self.get_dsc() - - if dsc_on: - dsc_sequence = self.prepare_dsc_for_validation_model( - conv_class=ConvRelu1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = ConvRelu1d( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=kernel_size, - stride=self.stride, - padding=padding, - dilation_sizes=dilation, - bias=False, - groups=grouping, - ) - - # for ana purposes handy - set a unique id so we can track this specific convolution - new_conv.last_grouping_param = self.groups - if not hasattr(new_conv, "id"): - new_conv.id = "ConvRelu1d-" + str(random.randint(0, 1000) * 2000) - logging.debug( - f"Validation id created: {new_conv.id} ; g={grouping}, w_before={kernel.shape}, ic={self.in_channels}" - ) - else: - logging.debug("Validation id already present: {new_conv.id}") - - kernel, _ = adjust_weight_if_needed( - module=new_conv, kernel=kernel, groups=new_conv.groups - ) - logging.debug( - f"=====> id: {new_conv.id} ; g={grouping}, w_after={kernel.shape}, ic={self.in_channels}" - ) - new_conv.weight.data = kernel - if bias is not None: - new_conv.bias = bias - - # print("\nassembled a basic conv from elastic kernel!") - self.reset_in_and_out_channel_to_previous() - return new_conv - - -class ElasticConvBn1d(ElasticConv1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - groups: List[int], - dscs: List[bool], - stride: int = 1, - padding: int = 0, - bias: bool = False, - track_running_stats=False, - out_channel_sizes=None, - ): - ElasticBase1d.__init__( - self, - 
in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - ) - self.bn = ElasticWidthBatchnorm1d(out_channels, track_running_stats) - self.norm = True - self.act = False - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # return self.get_basic_conv1d().forward(input) # for validaing assembled module - dilation = self.get_dilation_size() - # get padding for the size of the kernel - self.padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], dilation - ) - - return self.bn(super(ElasticConvBn1d, self).forward(input)) - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - self.set_in_and_out_channel(kernel) - - kernel_size = self.kernel_sizes[self.target_kernel_index] - dilation = self.get_dilation_size() - grouping = self.get_group_size() - padding = conv1d_get_padding(kernel_size, dilation) - dsc_on = self.get_dsc() - - if dsc_on: - tmp_bn = self.bn.get_basic_batchnorm1d() - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=ConvBn1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - bn_caller=(self.set_bn_parameter, tmp_bn, self.bn.num_batches_tracked), - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = ConvBn1d( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=kernel_size, - stride=self.stride, - padding=padding, - dilation=dilation, - bias=False, - groups=grouping, - ) - tmp_bn = self.bn.get_basic_batchnorm1d() - - # for ana purposes handy - set a unique id so we can track this specific convolution - new_conv.last_grouping_param = self.groups - if not hasattr(new_conv, "id"): - new_conv.id = "ElasticConvBn1d-" + str(random.randint(0, 1000) * 2000) - logging.debug( - f"Validation id created: {new_conv.id} ; g={grouping}, w_before={kernel.shape}, ic={self.in_channels}" - ) - else: - logging.debug("id already present: {new_conv.id}") - kernel, _ = adjust_weight_if_needed( - module=new_conv, kernel=kernel, groups=new_conv.groups - ) - logging.debug( - f"=====> id: {new_conv.id} ; g={grouping}, w_after={kernel.shape}, ic={self.in_channels}" - ) - - new_conv.weight.data = kernel - new_conv.bias = bias - - new_conv = self.set_bn_parameter( - new_conv, tmp_bn=tmp_bn, num_tracked=self.bn.num_batches_tracked - ) - - # print("\nassembled a basic conv from elastic kernel!") - self.reset_in_and_out_channel_to_previous() - return new_conv - - -class ElasticConvBnReLu1d(ElasticConvBn1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - groups: List[int], - dscs: List[bool], - stride: int = 1, - padding: int = 0, - bias: bool = False, - track_running_stats=False, - out_channel_sizes=None, - from_skipping=False, - ): - ElasticConvBn1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - 
groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - ) - - self.relu = ElasticPermissiveReLU() - self.norm = True - self.act = True - self.from_skipping = from_skipping - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - return self.relu(super(ElasticConvBnReLu1d, self).forward(input)) - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - self.set_in_and_out_channel(kernel) - - kernel_size = self.kernel_sizes[self.target_kernel_index] - dilation = self.get_dilation_size() - grouping = self.get_group_size() - padding = conv1d_get_padding(kernel_size, dilation) - - dsc_on = self.get_dsc() - if dsc_on: - tmp_bn = self.bn.get_basic_batchnorm1d() - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=ConvBnReLu1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - bn_caller=(self.set_bn_parameter, tmp_bn, self.bn.num_batches_tracked), - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = ConvBnReLu1d( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=kernel_size, - stride=self.stride, - padding=padding, - dilation=dilation, - bias=False, - groups=grouping, - ) - tmp_bn = self.bn.get_basic_batchnorm1d() - - # for ana purposes handy - set a unique id so we can track this specific convolution - new_conv.last_grouping_param = self.groups - if not hasattr(new_conv, "id"): - new_conv.id = "ElasticConvBnReLu1d-" + str( - random.randint(0, 1000) * 2000 - ) - logging.debug( - f"Validation id created: {new_conv.id} ; g={grouping}, w_before={kernel.shape}, ic={self.in_channels}" - ) - else: - logging.debug("id already present: {new_conv.id}") - kernel, _ = adjust_weight_if_needed( - module=new_conv, kernel=kernel, groups=new_conv.groups - ) - logging.debug( - f"=====> id: {new_conv.id} ; g={grouping}, w_after={kernel.shape}, ic={self.in_channels}, fromSkipping={self.from_skipping}" - ) - - new_conv.weight.data = kernel - new_conv.bias = bias - - new_conv = self.set_bn_parameter( - new_conv, tmp_bn=tmp_bn, num_tracked=self.bn.num_batches_tracked - ) - # print("\nassembled a basic conv from elastic kernel!") - self.reset_in_and_out_channel_to_previous() - return new_conv - - -class ConvRelu1d(nn.Conv1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - dilation_sizes: List[int], - stride: int = 1, - padding: int = 0, - groups: int = 1, - bias: bool = False, - track_running_stats=False, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation_sizes, - groups=groups, - bias=bias, - ) - self.relu = nn.ReLU() - self.norm = False - self.act = True - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - return self.relu(super(ConvRelu1d, self).forward(input)) - - -class ConvBn1d(nn.Conv1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - 
stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = False, - track_running_stats=False, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - ) - self.bn = nn.BatchNorm1d(out_channels, track_running_stats=track_running_stats) - self.norm = True - self.act = False - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - return self.bn(super(ConvBn1d, self).forward(input)) - - -class ConvBnReLu1d(ConvBn1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - bias: bool = False, - track_running_stats=False, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - ) - self.bn = nn.BatchNorm1d(out_channels, track_running_stats=track_running_stats) - self.relu = nn.ReLU() - self.norm = True - self.act = True - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - return self.relu(super(ConvBnReLu1d, self).forward(input)) diff --git a/hannah/models/ofa/submodules/elasticquantkernelconv.py b/hannah/models/ofa/submodules/elasticquantkernelconv.py deleted file mode 100644 index 1fb3454d..00000000 --- a/hannah/models/ofa/submodules/elasticquantkernelconv.py +++ /dev/null @@ -1,1131 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import copy -import math -from typing import List - -import torch -import torch.nn as nn -from torch.nn import init - -from hannah.nn import qat - -from ..utilities import ( - adjust_weight_if_needed, - conv1d_get_padding, - filter_single_dimensional_weights, -) -from .elasticBase import ElasticBase1d -from .elasticBatchnorm import ElasticWidthBatchnorm1d -from .elasticLinear import ElasticPermissiveReLU - - -class QuadDataHelper: - """Data Container so that _forward and _dsc has the same data.""" - - bias_shape = None - kernelsize = None - dilation = None - grouping = None - padding = None - scale_factor = None - scaled_weight = None - zero_bias = None - - def __init__( - self, - bias_shape, - kernelsize, - dilation, - grouping, - padding, - scale_factor, - scaled_weight, - zero_bias, - ): - self.bias_shape = bias_shape - self.kernelsize = kernelsize - self.dilation = dilation - self.grouping = grouping - self.padding = padding - self.scale_factor = scale_factor - self.scaled_weight = scaled_weight - self.zero_bias = zero_bias - - -# Adapted base Class used for the Quantization -# pytype: enable=attribute-error -class _ElasticConvBnNd( - ElasticBase1d, qat._ConvForwardMixin -): # pytype: disable=module-attr - - _version = 2 - - def __init__( - self, - # ConvNd args - in_channels, - out_channels, - kernel_sizes, - dilation_sizes, - stride=1, - padding=0, - transposed=False, - output_padding=0, - groups: List[int] = [1], - dscs: List[bool] = [False], - bias=False, - padding_mode="zeros", - # BatchNormNd args - eps=1e-05, - momentum=0.1, - freeze_bn=False, - qconfig=None, - dim=1, - out_quant=True, - track_running_stats=True, - out_channel_sizes=None, - fuse_bn=True, - ): - ElasticBase1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - padding_mode=padding_mode, - out_channel_sizes=out_channel_sizes, - ) - assert qconfig, "qconfig must be provided for QAT module" - self.qconfig = qconfig - self.freeze_bn = freeze_bn if self.training else True - self.fuse_bn = fuse_bn - self.out_quant = out_quant - self.bn = nn.ModuleList() - self.bn.append( - ElasticWidthBatchnorm1d( - out_channels, - eps=eps, - momentum=momentum, - track_running_stats=track_running_stats, - ) - ) - - self.weight_fake_quant = self.qconfig.weight() - self.activation_post_process = ( - self.qconfig.activation() if out_quant else nn.Identity - ) - self.dim = dim - - if hasattr(self.qconfig, "bias"): - self.bias_fake_quant = self.qconfig.bias() - else: - self.bias_fake_quant = self.qconfig.activation() - - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.bias = None - - self.reset_bn_parameters() - - # this needs to be called after reset_bn_parameters, - # as they modify the same state - if self.training: - if freeze_bn: - self.freeze_bn_stats() - else: - self.update_bn_stats() - else: - self.freeze_bn_stats() - - def on_warmup_end(self): - """ """ - for i in range(len(self.kernel_sizes) - 1): - self.bn.append(copy.deepcopy(self.bn[0])) - - def reset_running_stats(self): - """ """ - for idx in range(len(self.bn)): - self.bn[idx].reset_running_stats() - - def reset_bn_parameters(self): - """ """ - for idx in range(len(self.bn)): - self.bn[idx].reset_running_stats() - init.uniform_(self.bn[idx].weight) - init.zeros_(self.bn[idx].bias) - # note: below is actully for conv, not BN - if self.bias is not None: - fan_in, _ = 
init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def reset_parameters(self): - """ """ - super(_ElasticConvBnNd, self).reset_parameters() - - def update_bn_stats(self): - """ """ - self.freeze_bn = False - for idx in range(len(self.bn) - 1): - self.bn[idx].training = True - return self - - def freeze_bn_stats(self): - """ """ - self.freeze_bn = True - for idx in range(len(self.bn) - 1): - self.bn[idx].training = False - return self - - @property - def scale_factor(self): - """ """ - if self.fuse_bn: - running_std = torch.sqrt( - self.bn[self.target_kernel_index].running_var - + self.bn[self.target_kernel_index].eps - ) - - scale_factor = self.bn[self.target_kernel_index].weight / running_std - else: - scale_factor = torch.ones( - (self.weight.shape[0],), device=self.weight.device - ) - - return filter_single_dimensional_weights(scale_factor, self.out_channel_filter) - - @property - def full_scale_factor(self): - """does the same as scale_factor but uses the whole kernel. Used for dsc""" - if self.fuse_bn: - running_std = torch.sqrt( - self.bn[self.target_kernel_index].running_var - + self.bn[self.target_kernel_index].eps - ) - - scale_factor = self.bn[self.target_kernel_index].weight / running_std - else: - scale_factor = torch.ones( - (self.weight.shape[0],), device=self.weight.device - ) - - return scale_factor - - @property - def scaled_weight(self): - """ """ - scale_factor = self.scale_factor - weight, bias = self.get_kernel() - weight_shape = [1] * len(weight.shape) - weight_shape[0] = -1 - bias_shape = [1] * len(weight.shape) - bias_shape[1] = -1 - - # if we get the scaled weight we need to shape it according to the grouping - grouping = self.get_group_size() - if grouping > 1: - weight, _ = adjust_weight_if_needed( - module=self, kernel=weight, groups=grouping - ) - - scaled_weight = self.weight_fake_quant( - weight * scale_factor.reshape(weight_shape) - ) - - return scaled_weight - - def get_full_kernel_bias(self): - """Gets the full kernel and bias. 
Used for dsc""" - scale_factor = self.full_scale_factor - weight = self.get_full_width_kernel() - weight_shape = [1] * len(weight.shape) - weight_shape[0] = -1 - bias_shape = [1] * len(weight.shape) - bias_shape[1] = -1 - - scaled_weight = self.weight_fake_quant( - weight * scale_factor.reshape(weight_shape) - ) - - if self.bias is not None: - full_bias = torch.zeros_like(self.bias) - else: - full_bias = torch.zeros(self.out_channels, device=scaled_weight.device) - - return scaled_weight, full_bias - - def _get_params(self) -> QuadDataHelper: - """unifies the param procedure for _forward and _dsc""" - bias_shape = [1] * len(self.weight.shape) - bias_shape[1] = -1 - kernelsize = self.kernel_sizes[self.target_kernel_index] - dilation = self.get_dilation_size() - grouping = self.get_group_size() - self.padding = conv1d_get_padding(kernelsize, dilation) - - scale_factor = self.scale_factor - # if scaled weight is called, the grouping adjusts the weights if needed - scaled_weight = self.scaled_weight - # using zero bias here since the bias for original conv - # will be added later - if self.bias is not None: - zero_bias = torch.zeros_like(self.bias) - else: - zero_bias = torch.zeros(self.out_channels, device=scaled_weight.device) - zero_bias = filter_single_dimensional_weights( - zero_bias, self.out_channel_filter - ) - return QuadDataHelper( - bias_shape, - kernelsize, - dilation, - grouping, - self.padding, - scale_factor, - scaled_weight, - zero_bias, - ) - - def _after_forward_function(self, conv, quad_params: QuadDataHelper): - """unifies the after forward procedure for _forward and _dsc - - Args: - conv: - quad_params: QuadDataHelper: - quad_params: QuadDataHelper: - quad_params: QuadDataHelper: - quad_params: QuadDataHelper: - - Returns: - - """ - scale_factor = quad_params.scale_factor - bias_shape = quad_params.bias_shape - zero_bias = quad_params.zero_bias - - if self.training or not self.fuse_bn: - conv_orig = conv / scale_factor.reshape(bias_shape) - - if self.bias is not None: - bias = filter_single_dimensional_weights( - self.bias, self.out_channel_filter - ) - conv_orig = conv_orig + bias.reshape(bias_shape) - - conv = self.bn[self.target_kernel_index](conv_orig) - # copied from previous _forward (commented code line): - # conv = conv - (self.bn.bias - self.bn.running_mean).reshape(bias_shape) - else: - bias = zero_bias - if self.bias is not None: - _, bias = self.get_kernel() - bias = filter_single_dimensional_weights(bias, self.out_channel_filter) - - bn_rmean = self.bn[self.target_kernel_index].running_mean - bn_bias = self.bn[self.target_kernel_index].bias - - bn_rmean = filter_single_dimensional_weights( - bn_rmean, self.out_channel_filter - ) - bn_bias = filter_single_dimensional_weights( - bn_bias, self.out_channel_filter - ) - - bias = self.bias_fake_quant( - (bias - bn_rmean) * scale_factor + bn_bias - ).reshape(bias_shape) - conv = conv + bias - - return conv - - def _dsc(self, input): - """this method is used for dsc. 
- it is called as an alternative of _forward - - Args: - input: - - Returns: - - """ - - tmp_quad_helper = self._get_params() - # expand to variables - dilation = tmp_quad_helper.dilation - grouping = tmp_quad_helper.grouping - padding = tmp_quad_helper.padding - scaled_weight = tmp_quad_helper.scaled_weight - zero_bias = tmp_quad_helper.zero_bias - - full_kernel, full_bias = self.get_full_kernel_bias() - dsc_sequence_output = self.do_dsc( - input=input, - full_kernel=full_kernel, - full_bias=full_bias, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - quant_weight=scaled_weight, - quant_bias=zero_bias, - ) - - conv_output = self._after_forward_function( - dsc_sequence_output, quad_params=tmp_quad_helper - ) - return conv_output - - def _forward(self, input): - """ - - Args: - input: - - Returns: - - """ - tmp_quad_helper: QuadDataHelper = self._get_params() - grouping = tmp_quad_helper.grouping - scaled_weight = tmp_quad_helper.scaled_weight - zero_bias = tmp_quad_helper.zero_bias - - conv = self._real_conv_forward(input, scaled_weight, zero_bias, grouping) - conv = self._after_forward_function(conv, quad_params=tmp_quad_helper) - - return conv - - def extra_repr(self): - """ """ - # TODO(jerryzh): extend - return super(_ElasticConvBnNd, self).extra_repr() - - def forward(self, input): - """ - - Args: - input: - - Returns: - - """ - dsc_on = self.get_dsc() - - if not dsc_on: - y = self._forward(input) - else: - y = self._dsc(input) - return y - - def train(self, mode=True): - """Batchnorm's training behavior is using the self.training flag. Prevent - changing it if BN is frozen. This makes sure that calling `model.train()` - on a model with a frozen BN will behave properly. - - Args: - mode: (Default value = True) - - Returns: - - """ - self.training = mode - if not self.freeze_bn: - for module in self.children(): - module.train(mode) - return self - - # ===== Serialization version history ===== - # - # Version 1/None - # self - # |--- weight : Tensor - # |--- bias : Tensor - # |--- gamma : Tensor - # |--- beta : Tensor - # |--- running_mean : Tensor - # |--- running_var : Tensor - # |--- num_batches_tracked : Tensor - # - # Version 2 - # self - # |--- weight : Tensor - # |--- bias : Tensor - # |--- bn : Module - # |--- weight : Tensor (moved from v1.self.gamma) - # |--- bias : Tensor (moved from v1.self.beta) - # |--- running_mean : Tensor (moved from v1.self.running_mean) - # |--- running_var : Tensor (moved from v1.self.running_var) - # |--- num_batches_tracked : Tensor (moved from v1.self.num_batches_tracked) - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - """ - - Args: - state_dict: - prefix: - local_metadata: - strict: - missing_keys: - unexpected_keys: - error_msgs: - - Returns: - - """ - version = local_metadata.get("version", None) - if version is None or version == 1: - # BN related parameters and buffers were moved into the BN module for v2 - v2_to_v1_names = { - "bn.weight": "gamma", - "bn.bias": "beta", - "bn.running_mean": "running_mean", - "bn.running_var": "running_var", - "bn.num_batches_tracked": "num_batches_tracked", - } - for v2_name, v1_name in v2_to_v1_names.items(): - if prefix + v1_name in state_dict: - state_dict[prefix + v2_name] = state_dict[prefix + v1_name] - state_dict.pop(prefix + v1_name) - elif prefix + v2_name in state_dict: - # there was a brief period where forward compatibility - # for this module was broken (between - # 
https://github.com/pytorch/pytorch/pull/38478 - # and https://github.com/pytorch/pytorch/pull/38820) - # and modules emitted the v2 state_dict format while - # specifying that version == 1. This patches the forward - # compatibility issue by allowing the v2 style entries to - # be used. - pass - elif strict: - missing_keys.append(prefix + v2_name) - - super(_ElasticConvBnNd, self)._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - - @classmethod - def from_float(cls, mod): - """Create a qat module from a float module or qparams_dict - Args: `mod` a float module, either produced by torch.quantization utilities - or directly from user - - Args: - mod: - - Returns: - - """ - assert type(mod) == cls._FLOAT_MODULE, ( - "qat." - + cls.__name__ - + ".from_float only works for " - + cls._FLOAT_MODULE.__name__ - ) - assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined" - assert mod.qconfig, "Input float module must have a valid qconfig" - qconfig = mod.qconfig - conv, bn = mod[0], mod[1] - qat_convbn = cls( - conv.in_channels, - conv.out_channels, - conv.kernel_size, - conv.stride, - conv.padding, - conv.dilation, - conv.groups, - conv.bias is not None, - conv.padding_mode, - bn.eps, - bn.momentum, - False, - qconfig, - ) - qat_convbn.weight = conv.weight - qat_convbn.bias = conv.bias - qat_convbn.bn.weight = bn.weight - qat_convbn.bn.bias = bn.bias - qat_convbn.bn.running_mean = bn.running_mean - qat_convbn.bn.running_var = bn.running_var - qat_convbn.bn.num_batches_tracked = bn.num_batches_tracked - return qat_convbn - - -class ElasticQuantConv1d(ElasticBase1d, qat._ConvForwardMixin): - """ """ - - _FLOAT_MODULE = nn.Conv1d - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - stride: int = 1, - padding: int = 0, - groups: List[int] = [1], - dscs: List[bool] = [False], - bias: bool = False, - padding_mode="zeros", - qconfig=None, - out_quant=True, - out_channel_sizes=None, - ): - ElasticBase1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - padding_mode=padding_mode, - ) - assert qconfig, "qconfig must be provided for QAT module" - self.qconfig = qconfig - self.out_quant = out_quant - self.weight_fake_quant = self.qconfig.weight() - self.activation_post_process = ( - self.qconfig.activation() if out_quant else nn.Identity() - ) - if hasattr(qconfig, "bias"): - self.bias_fake_quant = self.qconfig.bias() - else: - self.bias_fake_quant = self.qconfig.activation() - self.dim = 1 - self.norm = False - self.act = False - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # get the kernel for the current index - weight, bias = self.get_kernel() - grouping = self.get_group_size() - if grouping > 1: - weight, _ = adjust_weight_if_needed( - module=self, kernel=weight, groups=grouping - ) - - dsc_on = self.get_dsc() - - if not dsc_on: - y = self.activation_post_process( - self._real_conv_forward( - input, - self.weight_fake_quant(weight), - self.bias_fake_quant(bias) if self.bias is not None else None, - grouping, - ) - ) - else: - full_kernel, full_bias = self.get_full_width_kernel(), self.bias 
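The quantized forward here follows the usual QAT pattern: fake-quantize the kernel, run the convolution, then pass the result through the activation post-process. A minimal standalone sketch of that pattern, using torch's stock QAT qconfig rather than the hannah qat factory (module sizes and names below are illustrative only, not taken from the removed code):

import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig

qconfig = get_default_qat_qconfig("fbgemm")
weight_fake_quant = qconfig.weight()            # fake-quantizer for the kernel
activation_post_process = qconfig.activation()  # fake-quantizer for the outputs

conv = nn.Conv1d(8, 16, kernel_size=3, padding=1, bias=False)
x = torch.randn(2, 8, 32)

# fake-quantized weights -> convolution -> fake-quantized activations
y = activation_post_process(
    nn.functional.conv1d(x, weight_fake_quant(conv.weight), padding=1)
)
print(y.shape)  # torch.Size([2, 16, 32])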
- y = self.activation_post_process( - self.do_dsc( - input=input, - full_kernel=full_kernel, - full_bias=full_bias, - grouping=grouping, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - quant_weight_function=self.weight_fake_quant, - quant_bias_function=self.bias_fake_quant, - ) - ) - return y - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - self.set_in_and_out_channel(kernel) - - kernel_size = self.kernel_size - dilation = self.get_dilation_size() - grouping = self.get_group_size() - dsc_on = self.get_dsc() - padding = conv1d_get_padding(kernel_size, dilation) - - if dsc_on: - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=qat.Conv1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = qat.Conv1d( - self.in_channels, - self.out_channels, - kernel_size, - self.stride, - padding, - dilation, - grouping, - bias, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - new_conv.weight.data = kernel - if bias is not None: - new_conv.bias = bias - - self.reset_in_and_out_channel_to_previous() - # print("\nassembled a basic conv from elastic kernel!") - return new_conv - - -class ElasticQuantConvReLu1d(ElasticBase1d, qat._ConvForwardMixin): - """ """ - - _FLOAT_MODULE = nn.Conv1d - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - stride: int = 1, - padding: int = 0, - groups: List[int] = [1], - dscs: List[bool] = [False], - bias: bool = False, - padding_mode="zeros", - qconfig=None, - out_quant=True, - out_channel_sizes=None, - ): - - ElasticBase1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - out_channel_sizes=out_channel_sizes, - ) - - assert qconfig, "qconfig must be provided for QAT module" - self.relu = ElasticPermissiveReLU() - self.qconfig = qconfig - self.out_quant = out_quant - self.weight_fake_quant = self.qconfig.weight() - self.activation_post_process = ( - self.qconfig.activation() if out_quant else nn.Identity() - ) - if hasattr(qconfig, "bias"): - self.bias_fake_quant = self.qconfig.bias() - else: - self.bias_fake_quant = self.qconfig.activation() - self.dim = 1 - self.norm = False - self.act = True - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # get the kernel for the current index - weight, bias = self.get_kernel() - grouping = self.get_group_size() - if grouping > 1: - weight, _ = adjust_weight_if_needed( - module=self, kernel=weight, groups=grouping - ) - - dsc_on = self.get_dsc() - - if not dsc_on: - y = self.activation_post_process( - self.relu( - self._real_conv_forward( - input, - self.weight_fake_quant(weight), - self.bias_fake_quant(bias) if self.bias is not None else None, - grouping, - ) - ) - ) - else: - full_kernel, 
full_bias = self.get_full_width_kernel(), self.bias - y = self.activation_post_process( - self.do_dsc( - input=input, - full_kernel=full_kernel, - full_bias=full_bias, - grouping=grouping, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - quant_weight_function=self.weight_fake_quant, - quant_bias_function=self.bias_fake_quant, - ) - ) - # self.reset_in_and_out_channel_to_previous() - return y - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - kernel_size = self.kernel_sizes[self.target_kernel_index] - dilation = self.get_dilation_size() - padding = conv1d_get_padding(kernel_size, dilation) - grouping = self.get_group_size() - dsc_on = self.get_dsc() - self.set_in_and_out_channel(kernel) - - if dsc_on: - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=qat.ConvReLU1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=padding, - dilation=dilation, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = qat.ConvReLU1d( - self.in_channels, - self.out_channels, - kernel_size, - self.stride, - padding, - dilation, - grouping, - bias, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - new_conv.weight.data = kernel - if bias is not None: - new_conv.bias = bias - - self.reset_in_and_out_channel_to_previous() - # print("\nassembled a basic conv from elastic kernel!") - return new_conv - - -class ElasticQuantConvBn1d(_ElasticConvBnNd): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - stride: int = 1, - padding: int = 0, - groups: List[int] = [1], - dscs: List[bool] = [False], - bias: bool = False, - track_running_stats=True, - qconfig=None, - out_quant=True, - out_channel_sizes=None, - ): - _ElasticConvBnNd.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - qconfig=qconfig, - out_channel_sizes=out_channel_sizes, - ) - self.out_quant = out_quant - self.norm = True - self.act = False - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - # get padding for the size of the kernel - dilation = self.get_dilation_size() - self.padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], dilation - ) - y = super(ElasticQuantConvBn1d, self).forward(input) - # self.reset_in_and_out_channel_to_previous() - return self.activation_post_process(y) - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - grouping = self.get_group_size() - dsc_on = self.get_dsc() - self.set_in_and_out_channel(kernel) - - if dsc_on: - tmp_bn = self.bn[self.target_kernel_index].get_basic_batchnorm1d() - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=qat.ConvReLU1d, - full_kernel=self.get_full_width_kernel(), - 
full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - bn_eps=self.bn[self.target_kernel_index].eps, - bn_momentum=self.bn[self.target_kernel_index].momentum, - qconfig=self.qconfig, - out_quant=self.out_quant, - bn_caller=(self.set_bn_parameter, tmp_bn, tmp_bn.num_batches_tracked), - ) - self.reset_in_and_out_channel_to_previous() - return dsc_sequence - else: - new_conv = qat.ConvBn1d( - kernel.shape[1], - kernel.shape[0], - self.kernel_size, - self.stride, - self.padding, - self.dilation, - grouping, - bias, - eps=self.bn[self.target_kernel_index].eps, - momentum=self.bn[self.target_kernel_index].momentum, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - new_conv.weight.data = kernel - new_conv.bias = bias - tmp_bn = self.bn[self.target_kernel_index].get_basic_batchnorm1d() - - new_conv = self.set_bn_parameter( - new_conv, tmp_bn=tmp_bn, num_tracked=tmp_bn.num_batches_tracked - ) - # print("\nassembled a basic conv from elastic kernel!") - self.reset_in_and_out_channel_to_previous() - return new_conv - - -class ElasticQuantConvBnReLu1d(ElasticQuantConvBn1d): - """ """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: List[int], - dilation_sizes: List[int], - stride: int = 1, - padding: int = 0, - groups: List[int] = [1], - dscs: List[bool] = [False], - bias: bool = False, - track_running_stats=True, - qconfig=None, - out_quant=True, - out_channel_sizes=None, - ): - ElasticQuantConvBn1d.__init__( - self, - in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - stride=stride, - padding=padding, - dilation_sizes=dilation_sizes, - groups=groups, - dscs=dscs, - bias=bias, - qconfig=qconfig, - out_channel_sizes=out_channel_sizes, - out_quant=out_quant, - ) - - self.relu = ElasticPermissiveReLU() - self.norm = True - self.act = True - - def forward(self, input: torch.Tensor) -> torch.Tensor: - """ - - Args: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - input: torch.Tensor: - - Returns: - - """ - dilation = self.get_dilation_size() - self.padding = conv1d_get_padding( - self.kernel_sizes[self.target_kernel_index], dilation - ) - dsc_on = self.get_dsc() - - if not dsc_on: - y = super(ElasticQuantConvBnReLu1d, self)._forward(input) - else: - y = super(ElasticQuantConvBnReLu1d, self)._dsc(input) - return self.activation_post_process(self.relu(y)) - - # return a normal conv1d equivalent to this module in the current state - def get_basic_module(self) -> nn.Module: - """ """ - kernel, bias = self.get_kernel() - self.set_in_and_out_channel(kernel) - - grouping = self.get_group_size() - dsc_on = self.get_dsc() - - if dsc_on: - tmp_bn = self.bn[self.target_kernel_index].get_basic_batchnorm1d() - dsc_sequence: nn.Sequential = self.prepare_dsc_for_validation_model( - conv_class=qat.ConvBnReLU1d, - full_kernel=self.get_full_width_kernel(), - full_bias=self.bias, - in_channels=self.in_channels, - out_channels=self.out_channels, - grouping=grouping, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - bn_eps=self.bn[self.target_kernel_index].eps, - bn_momentum=self.bn[self.target_kernel_index].momentum, - qconfig=self.qconfig, - out_quant=self.out_quant, - bn_caller=(self.set_bn_parameter, tmp_bn, tmp_bn.num_batches_tracked), - ) - self.reset_in_and_out_channel_to_previous() - 
return dsc_sequence - else: - - new_conv = qat.ConvBnReLU1d( - kernel.shape[1], - kernel.shape[0], - self.kernel_size, - self.stride, - self.padding, - self.dilation, - grouping, - bias, - eps=self.bn[self.target_kernel_index].eps, - momentum=self.bn[self.target_kernel_index].momentum, - qconfig=self.qconfig, - out_quant=self.out_quant, - ) - kernel, _ = adjust_weight_if_needed( - module=self, kernel=kernel, groups=grouping - ) - new_conv.weight.data = kernel - new_conv.bias = bias - tmp_bn = self.bn[self.target_kernel_index].get_basic_batchnorm1d() - - new_conv = self.set_bn_parameter( - new_conv, tmp_bn=tmp_bn, num_tracked=tmp_bn.num_batches_tracked - ) - self.reset_in_and_out_channel_to_previous() - # print("\nassembled a basic conv from elastic kernel!") - return new_conv diff --git a/hannah/models/ofa/submodules/resblock.py b/hannah/models/ofa/submodules/resblock.py deleted file mode 100644 index 931271aa..00000000 --- a/hannah/models/ofa/submodules/resblock.py +++ /dev/null @@ -1,201 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import logging - -import torch.nn as nn - -from ..utilities import flatten_module_list - -# base construct of a residual block -from .elasticBase import ElasticBase1d -from .elasticBatchnorm import ElasticWidthBatchnorm1d -from .elasticchannelhelper import ElasticChannelHelper -from .elastickernelconv import ElasticConvBnReLu1d -from .elasticquantkernelconv import ElasticQuantConvBnReLu1d - - -class ResBlockBase(nn.Module): - """ """ - - def __init__( - self, - in_channels, - out_channels, - act_after_res=True, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.do_act = act_after_res - # if the input channel count does not match the output channel count, - # apply skip to residual values - self.apply_skip = self.in_channels != self.out_channels - # placeholders: - self.act = nn.Identity() - self.blocks = nn.Identity() - self.skip = nn.Identity() - - def forward(self, x): - """ - - Args: - x: - - Returns: - - """ - residual = x - # do not use self.apply_skip for this: a skip connection may still be - # added to support elastic width - # by default, skip is an Identity. 
It may be needlessly applied if the - # residual block implementation does not replace it with a skip or None - if self.skip is not None: - residual = self.skip(residual) - try: - x = self.blocks(x) - except RuntimeError as r: - logging.warn(r) - for _, actualModel in self.blocks._modules.items(): - logging.info(f"DEBUG Module List: {actualModel}") - logging.info( - f"DEBUG Settings: oc={actualModel.out_channels}, ic={actualModel.in_channels}, weights={actualModel.weight.shape}, k={actualModel.kernel_size}, s={actualModel.stride}, g={actualModel.groups}" - ) - x += residual - if self.do_act: - x = self.act(x) - return x - - def get_nested_modules(self): - """ """ - return nn.ModuleList([self.blocks, self.skip, self.act]) - - -# residual block with a 1d skip connection -class ResBlock1d(ResBlockBase): - """ """ - - def __init__( - self, - in_channels, - out_channels, - minor_blocks, - act_after_res=True, - stride=1, - norm_before_act=True, - quant_skip=False, - qconfig=None, - out_quant=True, - ): - super().__init__( - in_channels=in_channels, - out_channels=out_channels, - act_after_res=act_after_res, - ) - # set the minor block sequence if specified in construction - # if minor_blocks is not None: - self.blocks = minor_blocks - - self.norm = ElasticWidthBatchnorm1d(out_channels) - self.act = nn.ReLU() - self.qconfig = qconfig - self.quant_skip = quant_skip - # if applying skip to the residual values is required, create skip as a minimal conv1d - # stride is also applied to the skip layer (if specified, default is 1) - if not quant_skip: - self.skip = nn.Sequential( - ElasticConvBnReLu1d( - self.in_channels, - out_channels, - kernel_sizes=[1], - dilation_sizes=[1], - groups=[1], - dscs=[False], - stride=stride, - bias=False, - out_channel_sizes=flatten_module_list(self.blocks)[ - -1 - ].out_channel_sizes, - # TODO to delete after ana - from_skipping=True, - ), - ) - else: - self.skip = nn.Sequential( - ElasticQuantConvBnReLu1d( - self.in_channels, - out_channels, - kernel_sizes=[1], - dilation_sizes=[1], - stride=stride, - groups=[1], - dscs=[False], - bias=False, - qconfig=qconfig, - out_channel_sizes=flatten_module_list(self.blocks)[ - -1 - ].out_channel_sizes, - ), - ) # if self.apply_skip else None - if self.qconfig is not None: - self.activation_post_process = ( - self.qconfig.activation() if out_quant else nn.Identity() - ) - # as this does not know if an elastic width section may follow, - # the skip connection is required! 
it will be needed if the width is modified later - - def forward(self, x): - """ - - Args: - x: - - Returns: - - """ - output = super().forward(x) - if self.qconfig is not None: - return self.activation_post_process(output) - return output - - def get_input_layer(self): - """ """ - input = nn.ModuleList() - input.append(flatten_module_list(self.skip)[0]) - input.append(flatten_module_list(self.blocks)[0]) - return input - - def get_output_layer(self): - """ """ - output = nn.ModuleList() - output.append(flatten_module_list(self.skip)[-1]) - output.append(flatten_module_list(self.blocks)[-1]) - return output - - def create_internal_channelhelper(self): - """ """ - output = nn.ModuleList() - - for idx in range(len(self.blocks) - 1): - if len(self.blocks[idx].out_channel_sizes) > 1: - ech = ElasticChannelHelper(self.blocks[idx].out_channel_sizes) - ech.add_source_item(self.blocks[idx]) - ech.add_targets(self.blocks[idx + 1]) - output.append(ech) - - return output diff --git a/hannah/models/ofa/type_utils.py b/hannah/models/ofa/type_utils.py deleted file mode 100644 index 01d70613..00000000 --- a/hannah/models/ofa/type_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from hannah.models.ofa.submodules.elasticBatchnorm import ElasticWidthBatchnorm1d -from hannah.models.ofa.submodules.elastickernelconv import ( - ElasticConv1d, - ElasticConvBn1d, - ElasticConvBnReLu1d, - ElasticConvReLu1d, -) -from hannah.models.ofa.submodules.elasticLinear import ( - ElasticPermissiveReLU, - ElasticQuantWidthLinear, - ElasticWidthLinear, -) -from hannah.models.ofa.submodules.elasticquantkernelconv import ( - ElasticQuantConv1d, - ElasticQuantConvBn1d, - ElasticQuantConvBnReLu1d, - ElasticQuantConvReLu1d, -) - -# A dictionary that maps the combination string of the convolution type to the class that -# implements it. -elasic_conv_classes = { - "none": ElasticConv1d, - "quant": ElasticQuantConv1d, - "act": ElasticConvReLu1d, - "actquant": ElasticQuantConvReLu1d, - "norm": ElasticConvBn1d, - "normquant": ElasticQuantConvBn1d, - "normact": ElasticConvBnReLu1d, - "normactquant": ElasticQuantConvBnReLu1d, -} - - -# A tuple of all the classes that are subclasses of `ElasticBaseConv`. 
-elastic_conv_type = ( - ElasticConv1d, - ElasticConvReLu1d, - ElasticConvBn1d, - ElasticConvBnReLu1d, - ElasticQuantConv1d, - ElasticQuantConvReLu1d, - ElasticQuantConvBn1d, - ElasticQuantConvBnReLu1d, -) - -elastic_forward_type = ( - ElasticConv1d, - ElasticConvReLu1d, - ElasticConvBn1d, - ElasticConvBnReLu1d, - ElasticQuantConv1d, - ElasticQuantConvReLu1d, - ElasticQuantConvBn1d, - ElasticQuantConvBnReLu1d, - ElasticWidthLinear, - ElasticQuantWidthLinear, -) - -elastic_Linear_type = ( - ElasticWidthLinear, - ElasticQuantWidthLinear, -) - -elastic_all_type = ( - ElasticConv1d, - ElasticConvReLu1d, - ElasticConvBn1d, - ElasticConvBnReLu1d, - ElasticQuantConv1d, - ElasticQuantConvReLu1d, - ElasticQuantConvBn1d, - ElasticQuantConvBnReLu1d, - ElasticWidthBatchnorm1d, - ElasticWidthLinear, - ElasticQuantWidthLinear, - ElasticPermissiveReLU, -) diff --git a/hannah/models/ofa/utilities.py b/hannah/models/ofa/utilities.py deleted file mode 100644 index 7ca571bc..00000000 --- a/hannah/models/ofa/utilities.py +++ /dev/null @@ -1,629 +0,0 @@ -# -# Copyright (c) 2022 University of Tübingen. -# -# This file is part of hannah. -# See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# import logging -import logging - -import torch -import torch.nn as nn - - -# Conv1d with automatic padding for the set kernel size -def conv1d_auto_padding(conv1d: nn.Conv1d): - """ - - Args: - conv1d: nn.Conv1d: - conv1d: nn.Conv1d: - conv1d: nn.Conv1d: - conv1d: nn.Conv1d: - - Returns: - - """ - conv1d.padding = conv1d_get_padding(conv1d.kernel_size[0]) - return conv1d - - -def conv1d_get_padding(kernel_size, dilation=1): - """ - - Args: - kernel_size: - dilation: (Default value = 1) - - Returns: - - """ - # check type of kernel_size - if isinstance(kernel_size, tuple): - kernel_size = kernel_size[0] - - # check type of dilation - if isinstance(dilation, tuple): - dilation = dilation[0] - - dil = (kernel_size - 1) * (dilation - 1) - new_kernel_size = kernel_size + dil - padding = new_kernel_size // 2 - return padding - - -# from ofa/utils/common_tools -def sub_filter_start_end(kernel_size, sub_kernel_size): - """ - - Args: - kernel_size: - sub_kernel_size: - - Returns: - - """ - center = kernel_size // 2 - dev = sub_kernel_size // 2 - start, end = center - dev, center + dev + 1 - assert end - start == sub_kernel_size - return start, end - - -# flatten nested iterable modules, usually over a ModuleList. nn.Sequential is -# also an iterable module and a valid input. 
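The two helpers above encode the elastic-kernel geometry: conv1d_get_padding pads to half the effective (dilated) kernel size so the sequence length is preserved, and sub_filter_start_end picks a centered slice out of the full kernel. A small self-contained check of both (the helpers are restated here so the snippet runs on its own; tensor sizes are arbitrary):

import torch
import torch.nn as nn


def conv1d_get_padding(kernel_size, dilation=1):
    # same formula as the removed helper: half the effective kernel size
    effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    return effective // 2


def sub_filter_start_end(kernel_size, sub_kernel_size):
    center = kernel_size // 2
    dev = sub_kernel_size // 2
    return center - dev, center + dev + 1


x = torch.randn(1, 8, 101)
for k, d in [(3, 1), (5, 2), (9, 4)]:
    conv = nn.Conv1d(8, 8, k, dilation=d, padding=conv1d_get_padding(k, d))
    assert conv(x).shape[-1] == x.shape[-1]  # "same" length for odd kernel sizes

start, end = sub_filter_start_end(9, 5)
assert (start, end) == (2, 7)  # the 5-tap sub-kernel sits centered in the 9-tap kernel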
-def flatten_module_list(modules: nn.Module) -> nn.Module: - """ - - Args: - modules: nn.Module: - modules: nn.Module: - modules: nn.Module: - modules: nn.Module: - - Returns: - - """ - if not hasattr(modules, "__iter__"): - if isinstance(modules, nn.Module): - # if the input is non-iterable and is already a module, it can be returned as a list of one element - return nn.ModuleList([modules]) - - else: - # flatten any nested Sequential or ModuleList - contains_nested = (isinstance(x, nn.Sequential) for x in modules) or ( - isinstance(x, nn.ModuleList) for x in modules - ) - # repeat until the cycle no longer finds nested modules - while contains_nested: - # print(f"Nested? {type(modules)} {len(modules)}") - contains_nested = False - new_module_list = nn.ModuleList([]) - for old_item in modules: - if hasattr(old_item, "__iter__"): - contains_nested = True - for old_subitem in old_item: - new_module_list.append(old_subitem) - else: - new_module_list.append(old_item) - modules = new_module_list - - return modules - - -# return a single module from an input moduleList -def module_list_to_module(module_list): - """ - - Args: - module_list: - - Returns: - - """ - # if the input is a Sequential module it will be iterable, but can be returned as is. - if isinstance(module_list, nn.Sequential): - return module_list - # if the input is not already a module, it must be iterable - if not hasattr(module_list, "__iter__"): - if isinstance(module_list, nn.Module): - return module_list - raise TypeError("input is neither iterable nor module") - if len(module_list) == 1: - module = module_list[0] - assert isinstance( - module, nn.Module - ), "Iterable single-length input does not contain module" - return module - else: - return nn.Sequential(*module_list) - - -# recurse through any iterable (sub)structures. Attempt to call the specified -# function from any discovered objects if it is available. -# return true if any of the calls returned true -# for modules: both ModuleList and Sequential are iterable, so this should be -# able to descend into any module substructures -def call_function_from_deep_nested(input, function, type_selection: type = None): - """ - - Args: - input: - function: - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - - Returns: - - """ - if input is None: - return False - # print(".") - call_return_value = False - # if a type is specified, only check matching objects - if type_selection is None or isinstance(input, type_selection): - # print(type(input)) - maybe_function = getattr(input, function, None) - if callable(maybe_function): - call_return_value = maybe_function() - # print("deep call!") - - # if the input is iterable, recursively check any nested objects - if hasattr(input, "__iter__"): - for item in input: - new_return_value = call_function_from_deep_nested( - item, function, type_selection - ) - call_return_value = call_return_value or new_return_value - - # if the object has a function to return nested modules, also check them. 
- if callable(getattr(input, "get_nested_modules", None)): - nested_modules = getattr(input, "get_nested_modules", None)() - new_return_value = call_function_from_deep_nested( - nested_modules, function, type_selection - ) - call_return_value = call_return_value or new_return_value - - return call_return_value - - -# recurse like call_function_from_deep_nested; -# return a list of every found object of -def get_instances_from_deep_nested(input, type_selection: type = None): - """ - - Args: - input: - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - type_selection: type: (Default value = None) - - Returns: - - """ - results = [] - if input is None: - return results - if type_selection is None or isinstance(input, type_selection): - results.append(input) - # if the input is iterable, recursively check any nested objects - if hasattr(input, "__iter__"): - for item in input: - additional_results = get_instances_from_deep_nested(item, type_selection) - # concatenate the lists - results += additional_results - - # if the object has a function to return nested modules, also check them. - if callable(getattr(input, "get_nested_modules", None)): - nested_modules = getattr(input, "get_nested_modules", None)() - additional_results = get_instances_from_deep_nested( - nested_modules, type_selection - ) - results += additional_results - - return results - - -def filter_primary_module_weights(weights, in_channel_filter, out_channel_filter): - """ - - Args: - weights: - in_channel_filter: - out_channel_filter: - - Returns: - - """ - # out_channel count will be length in dim 0 - out_channel_count = len(weights) - # in_channel count will be length in second dim - in_channel_count = len(weights[0]) - if len(in_channel_filter) != in_channel_count: - logging.error( - f"Unable to filter primary module weights: in_channel count {in_channel_count} does not match filter length {len(in_channel_filter)}" - ) - if len(out_channel_filter) != out_channel_count: - logging.error( - f"Unable to filter primary module weights: out_channel count {out_channel_count} does not match filter length {len(out_channel_filter)}" - ) - - return (weights[out_channel_filter])[:, in_channel_filter] - - -def filter_single_dimensional_weights(weights, channel_filter): - """ - - Args: - weights: - channel_filter: - - Returns: - - """ - if weights is None: - return None - if all(channel_filter): - return weights - channel_count = len(weights) - if len(channel_filter) != channel_count: - logging.error( - f"Unable to filter weights: channel count {channel_count} does not match filter length {len(channel_filter)}" - ) - new_weights = None - # channels where the filter is true are kept. - for i in range(channel_count): - if channel_filter[i]: - if new_weights is None: - new_weights = weights[i : i + 1] - else: - new_weights = torch.cat((new_weights, weights[i : i + 1]), dim=0) - return new_weights - - -def make_parameter(t: torch.Tensor) -> nn.Parameter: - """ - - Args: - t: torch.Tensor: - t: torch.Tensor: - t: torch.Tensor: - t: torch.Tensor: - - Returns: - - """ - if t is None: - return t - if isinstance(t, nn.Parameter): - return t - elif isinstance(t, torch.Tensor): - return nn.parameter.Parameter(t) - else: - logging.error(f"Could not create parameter from input of type '{type(t)}'.") - return None - - -def adjust_weight_if_needed(module, kernel=None, groups=None): - """Adjust the weight if the adjustment is needded. 
This means, if the kernel does not have the size of - (out_channel, in_channel / group, kernel). - - Args: - kernel: the kernel that should be checked and adjusted if needed. If None module.weight.data will be used (Default value = None) - grouping: value of the conv, if None module.groups will be used - module: the conv - :throws: RuntimeError if there is no last_grouping_param for comporing current group value to past group value - returns (kernel, is adjusted) (adjusted if needed) otherwise throws a RuntimeError - groups: (Default value = None) - - Returns: - - """ - if kernel is None: - kernel = module.weigth.data - if groups is None: - groups = module.groups - - if not hasattr(module, "last_grouping_param"): - raise RuntimeError - - in_channels = kernel.size(1) - - is_adjusted = False - - grouping_changed = groups != module.last_grouping_param - if grouping_changed and groups > 1: - weight_adjustment_needed = is_weight_adjusting_needed( - kernel, in_channels, groups - ) - if weight_adjustment_needed: - is_adjusted = True - kernel = adjust_weights_for_grouping(kernel, groups) - else: - target = get_target_weight(kernel, in_channels, groups) - if hasattr(module, "id"): - logging.debug(f"ID: {module.id}") - - return (kernel, is_adjusted) - - -def is_weight_adjusting_needed(weights, input_channels, groups): - """Checks if a weight adjustment is needed - Requirement: weight.shape[1] must be input_channels/groups - true: weight adjustment is needed - - Args: - weights: the weights that needs to be checked - input_channels: Input Channels of the Convolution Module - groups: Grouping Param of the Convolution Module - - Returns: - - """ - current_weight_dimension = weights.shape[1] - target_weight_dimension = input_channels // groups - return target_weight_dimension != current_weight_dimension - - -def get_target_weight(weights, input_channels, groups): - """Gives the targeted weight shape (out_channel, in_channel // groups, kernel) - - Args: - weights: the weights that needs to be checked - input_channels: Input Channels of the Convolution Module - groups: Grouping Param of the Convolution Module - - Returns: - - """ - target_shape = list(weights.shape) - target_shape[1] = input_channels // groups - return target_shape - - -def prepare_kernel_for_depthwise_separable_convolution( - model, kernel, bias, in_channels -): - """Prepares the kernel for depthwise separable convolution (step 1 of DSC). - This means setting groups = inchannels and outchannels = k * inchannels. 
- - Args: - model: - kernel: - bias: - in_channels: - - Returns: - : kernel, bias) Tuple - - """ - # Create Filters for Depthwise Separable Convolution of input and output channels - depthwise_output_filter = create_channel_filter( - model, - kernel, - current_channel=kernel.size(0), - reduced_target_channel_size=in_channels, - is_output_filter=True, - ) - depthwise_input_filter = create_channel_filter( - model, - kernel, - current_channel=kernel.size(1), - reduced_target_channel_size=in_channels, - is_output_filter=False, - ) - - # outchannel is adapted - new_kernel = filter_primary_module_weights( - kernel, depthwise_input_filter, depthwise_output_filter - ) - # grouping = in_channel_count - new_kernel = adjust_weights_for_grouping(new_kernel, in_channels) - - if bias is None: - return new_kernel, None - else: - new_bias = filter_single_dimensional_weights(bias, depthwise_output_filter) - return new_kernel, new_bias - - -def prepare_kernel_for_pointwise_convolution(kernel, grouping): - """Prepares the kernel for pointwise convolution (step 2 of DSC). - This means setting the kernel window to 1x1. - So a kernel with output_channel, input_channel / groups, kernel will be set to (_,_,1) - - Args: - kernel: - grouping: - - Returns: - - """ - # use 1x1 kernel - new_kernel = kernel - if grouping > 1: - new_kernel = adjust_weights_for_grouping(kernel, grouping) - - new_kernel = get_kernel_for_dsc(new_kernel) - - return new_kernel - - -def adjust_weights_for_grouping(weights, input_divided_by): - """Adjusts the Weights for the Forward of the Convulution - Shape(outchannels, inchannels / group, kW) - weight – filters of shape (out_channels , in_channels / groups , kW) - input_divided_by - - Args: - weights: - input_divided_by: - - Returns: - - """ - channels_per_group = weights.shape[1] // input_divided_by - - splitted_weights = torch.tensor_split(weights, input_divided_by) - result_weights = [] - - # for current_group in range(groups): - for current_group, current_weight in enumerate(splitted_weights): - input_start = current_group * channels_per_group - input_end = input_start + channels_per_group - current_result_weight = current_weight[:, input_start:input_end, :] - result_weights.append(current_result_weight) - - full_kernel = torch.concat(result_weights) - - return full_kernel - - -def get_kernel_for_dsc(kernel): - """Part of DSC (Step 2, pointwise convolution) - kernel with output_channel, input_channel / groups, kernel will be set to (_,_,1) - - Args: - kernel: - - Returns: - - """ - return kernel[:, :, 0:1] - - -# copied and adapted from elasticchannelhelper.py -# set the channel filter list based on the channel priorities and the reduced_target_channel count -def get_channel_filter( - current_channel_size, reduced_target_channel_size, channel_priority_list -): - """ - - Args: - current_channel_size: - reduced_target_channel_size: - channel_priority_list: - - Returns: - - """ - # get the amount of channels to be removed from the max and current channel counts - channel_reduction_amount: int = current_channel_size - reduced_target_channel_size - # start with an empty filter, where every channel passes through, then remove channels by priority - channel_pass_filter = [True] * current_channel_size - - # filter the least important n channels, specified by the reduction amount - for i in range(channel_reduction_amount): - # priority list of channels contains channel indices from least important to most important - # the first n channel indices specified in this list will be filtered out - 
filtered_channel_index = channel_priority_list[i] - channel_pass_filter[filtered_channel_index] = False - - return channel_pass_filter - - -def create_channel_filter( - module: nn.Module, - kernel, - current_channel, - reduced_target_channel_size, - is_output_filter: bool = True, -): - """ - - Args: - module: nn.Module: - kernel: - current_channel: - reduced_target_channel_size: - is_output_filter: bool: (Default value = True) - module: nn.Module: - is_output_filter: bool: (Default value = True) - module: nn.Module: - is_output_filter: bool: (Default value = True) - module: nn.Module: - is_output_filter: bool: (Default value = True) - - Returns: - - """ - # create one channel filter - channel_index = 1 if is_output_filter else 0 - channel_filter_priorities = compute_channel_priorities( - module, kernel, channel_index - ) - return get_channel_filter( - current_channel, reduced_target_channel_size, channel_filter_priorities - ) - - -# copied and adapted from elasticchannelhelper.py -# compute channel priorities based on the l1 norm of the weights of whichever -# target module follows this elastic channel section -def compute_channel_priorities(module: nn.Module, kernel, channel_index: int = 0): - """ - - Args: - module: nn.Module: - kernel: - channel_index: int: (Default value = 0) - module: nn.Module: - channel_index: int: (Default value = 0) - module: nn.Module: - channel_index: int: (Default value = 0) - module: nn.Module: - channel_index: int: (Default value = 0) - - Returns: - - """ - channel_norms = [] - - if kernel is None: - logging.warning( - f"Unable to compute channel priorities! Kernel is None: {kernel}" - ) - return None - # this will also include the elastic kernel convolutions - # for elastic kernel convolutions, the priorities will then also be - # computed on the base module (full kernel) - if isinstance(module, nn.Conv1d): - weights = kernel - norms_per_kernel_index = torch.linalg.norm(weights, ord=1, dim=channel_index) - channel_norms = torch.linalg.norm(norms_per_kernel_index, ord=1, dim=1) - # the channel priorities for linears need to also be computable: - # especially for the exit connections, a linear may follow after an elastic width - elif isinstance(module, nn.Linear): - weights = kernel - channel_norms = torch.linalg.norm(weights, ord=1, dim=0) - else: - # the channel priorities will keep their previous / default value in - # this case. Reduction will probably occur by channel order - logging.warning( - f"Unable to compute channel priorities! 
Unsupported target module after elastic channels: {type(module)}" - ) - - # contains the indices of the channels, sorted from channel with smallest - # norm to channel with largest norm - # the least important channel index is at the beginning of the list, - # the most important channel index is at the end - # np -> torch.argsort() - channels_by_priority = torch.argsort(channel_norms) - - return channels_by_priority diff --git a/hannah/nas/search/model_trainer/progressive_shrinking.py b/hannah/nas/search/model_trainer/progressive_shrinking.py deleted file mode 100644 index b449ad63..00000000 --- a/hannah/nas/search/model_trainer/progressive_shrinking.py +++ /dev/null @@ -1,853 +0,0 @@ -import logging -from pytorch_lightning import seed_everything -from omegaconf import OmegaConf -import omegaconf -from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger -from hydra.utils import instantiate -import torch -from hannah.callbacks.optimization import HydraOptCallback -from hannah.utils.utils import common_callbacks -from pytorch_lightning import Trainer - - -msglogger = logging.getLogger(__name__) - - -class ProgressiveShrinkingModelTrainer: - def __init__(self, - parent_config=None, - epochs_warmup=10, - epochs_kernel_step=10, - epochs_depth_step=10, - epochs_width_step=10, - epochs_dilation_step=10, - epochs_grouping_step=10, - epochs_dsc_step=10, - epochs_tuning_step=0, - elastic_kernels_allowed=False, - elastic_depth_allowed=False, - elastic_width_allowed=False, - elastic_dilation_allowed=False, - elastic_grouping_allowed=False, - elastic_dsc_allowed=False, - evaluate=True, - random_evaluate=True, - random_eval_number=100, - extract_model_config=False, - warmup_model_path="",) -> None: - self.config = parent_config - self.epochs_warmup = epochs_warmup - self.epochs_kernel_step = epochs_kernel_step - self.epochs_depth_step = epochs_depth_step - self.epochs_width_step = epochs_width_step - self.epochs_dilation_step = epochs_dilation_step - self.epochs_grouping_step = epochs_grouping_step - self.epochs_dsc_step = epochs_dsc_step - self.epochs_tuning_step = epochs_tuning_step - self.elastic_kernels_allowed = elastic_kernels_allowed - self.elastic_depth_allowed = elastic_depth_allowed - self.elastic_width_allowed = elastic_width_allowed - self.elastic_dilation_allowed = elastic_dilation_allowed - self.elastic_grouping_allowed = elastic_grouping_allowed - self.elastic_dsc_allowed = elastic_dsc_allowed - - self.evaluate = evaluate - self.random_evaluate = random_evaluate - self.random_eval_number = random_eval_number - self.warmup_model_path = warmup_model_path - self.extract_model_config = extract_model_config - - def build_model(self): - config = OmegaConf.create(self.config) - # logger = TensorBoardLogger(".") - - seed = config.get("seed", 1234) - if isinstance(seed, list) or isinstance(seed, omegaconf.ListConfig): - seed = seed[0] - seed_everything(seed, workers=True) - - if not torch.cuda.is_available(): - config.trainer.gpus = None - - callbacks = common_callbacks(config) - opt_monitor = config.get("monitor", ["val_error"]) - opt_callback = HydraOptCallback(monitor=opt_monitor) - callbacks.append(opt_callback) - checkpoint_callback = instantiate(config.checkpoint) - callbacks.append(checkpoint_callback) - self.config = config - # trainer will be initialized by rebuild_trainer - self.trainer = None - model = instantiate( - config.module, - dataset=config.dataset, - model=config.model, - optimizer=config.optimizer, - features=config.features, - scheduler=config.get("scheduler", 
None), - normalizer=config.get("normalizer", None), - _recursive_=False, - ) - model.setup("fit") - return model - - def run_training(self, model): - ofa_model = model.model - - self.kernel_step_count = ofa_model.ofa_steps_kernel - self.depth_step_count = ofa_model.ofa_steps_depth - self.width_step_count = ofa_model.ofa_steps_width - self.dilation_step_count = ofa_model.ofa_steps_dilation - self.grouping_step_count = ofa_model.ofa_steps_grouping - self.dsc_step_count = ofa_model.ofa_steps_dsc - ofa_model.elastic_kernels_allowed = self.elastic_kernels_allowed - ofa_model.elastic_depth_allowed = self.elastic_depth_allowed - ofa_model.elastic_width_allowed = self.elastic_width_allowed - ofa_model.elastic_dilation_allowed = self.elastic_dilation_allowed - ofa_model.elastic_grouping_allowed = self.elastic_grouping_allowed - ofa_model.elastic_dsc_allowed = self.elastic_dsc_allowed - ofa_model.full_config = self.config["model"] - - logging.info("Kernel Steps: %d", self.kernel_step_count) - logging.info("Depth Steps: %d", self.depth_step_count) - logging.info("Width Steps: %d", self.width_step_count) - logging.info("Grouping Steps: %d", self.grouping_step_count) - logging.info("DSC Steps: %d", self.dsc_step_count) - # logging.info("dsc: %d", self.grouping_step_count) - - self.submodel_metrics_csv = "" - self.random_metrics_csv = "" - - if self.elastic_width_allowed: - self.submodel_metrics_csv += "width, " - self.random_metrics_csv += "width_steps, " - - if self.elastic_kernels_allowed: - self.submodel_metrics_csv += "kernel, " - self.random_metrics_csv += "kernel_steps, " - - if self.elastic_dilation_allowed: - self.submodel_metrics_csv += "dilation, " - self.random_metrics_csv += "dilation_steps, " - - if self.elastic_depth_allowed: - self.submodel_metrics_csv += "depth, " - self.random_metrics_csv += "depth, " - - if self.elastic_grouping_allowed: - self.submodel_metrics_csv += "grouping, " - self.random_metrics_csv += "group_steps, " - - if self.elastic_dsc_allowed: - self.submodel_metrics_csv += "dsc, " - self.random_metrics_csv += "dsc, " - - if ( - self.elastic_width_allowed - | self.elastic_kernels_allowed - | self.elastic_dilation_allowed - | self.elastic_depth_allowed - | self.elastic_grouping_allowed - | self.elastic_dsc_allowed - ): - self.submodel_metrics_csv += ( - "acc, total_macs, total_weights, torch_params\n" - ) - self.random_metrics_csv += "acc, total_macs, total_weights, torch_params\n" - - # self.random_metrics_csv = "width_steps, depth, kernel_steps, acc, total_macs, total_weights, torch_params\n" - - logging.info("Once for all Model:\n %s", str(ofa_model)) - # TODO Warmup DSC on or off? 
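For reference, when every elastic dimension is enabled, the metric headers assembled above resolve to the following fixed strings (column order follows the order of the flag checks):

submodel_metrics_header = (
    "width, kernel, dilation, depth, grouping, dsc, "
    "acc, total_macs, total_weights, torch_params\n"
)
random_metrics_header = (
    "width_steps, kernel_steps, dilation_steps, depth, group_steps, dsc, "
    "acc, total_macs, total_weights, torch_params\n"
)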
- self.warmup(model, ofa_model) - ofa_model.reset_shrinking() - - self.train_elastic_kernel(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_dilation(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_depth(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_width(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_grouping(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_dsc(model, ofa_model) - ofa_model.reset_shrinking() - - if self.evaluate: - self.eval_model(model, ofa_model) - - if self.random_evaluate: - # save random metrics - msglogger.info("\n%s", self.random_metrics_csv) - with open("OFA_random_sample_metrics.csv", "w") as f: - f.write(self.random_metrics_csv) - # save self.submodel_metrics_csv - msglogger.info("\n%s", str(self.submodel_metrics_csv)) - with open("OFA_elastic_metrics.csv", "w") as f: - f.write(self.submodel_metrics_csv) - - - - def warmup(self, model, ofa_model): - """ - > The function rebuilds the trainer with the warmup epochs, fits the model, - validates the model, and then calls the on_warmup_end() function to - change some internal variables - - :param model: the model to be trained - :param ofa_model: the model that we want to train - """ - # warm-up. - self.rebuild_trainer("warmup", self.epochs_warmup) - if self.epochs_warmup > 0 and self.warmup_model_path == "": - self.trainer.fit(model) - ckpt_path = "best" - elif self.warmup_model_path != "": - ckpt_path = self.warmup_model_path - self.trainer.validate(ckpt_path=ckpt_path, model=model, verbose=True) - ofa_model.on_warmup_end() - ofa_model.reset_validation_model() - msglogger.info("OFA completed warm-up.") - - def train_elastic_width(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a width - step, then trains the model for a number of epochs, then adds a width step, - and so on - - :param model: the model to train - :param ofa_model: the model that will be trained - """ - if self.elastic_width_allowed: - # train elastic width - # first, run channel priority computation - ofa_model.progressive_shrinking_compute_channel_priorities() - for current_width_step in range(1, self.width_step_count): - # add a width step - ofa_model.progressive_shrinking_add_width() - if self.epochs_width_step > 0: - self.rebuild_trainer( - f"width_{current_width_step}", self.epochs_width_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed width steps.") - - def train_elastic_depth(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then progressively - shrinks the depth of the model, and trains the model for a number of epochs - again - - :param model: the model to train - :param ofa_model: the model to be trained - """ - if self.elastic_depth_allowed: - # train elastic depth - for current_depth_step in range(1, self.depth_step_count): - # add a depth reduction step - ofa_model.progressive_shrinking_add_depth() - if self.epochs_depth_step > 0: - self.rebuild_trainer( - f"depth_{current_depth_step}", self.epochs_depth_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed depth steps.") - - def train_elastic_kernel(self, model, ofa_model): - """ - > The function trains the elastic kernels by progressively shrinking the - model and training the model for a number of epochs 
and repeats this process - until the number of kernel steps is reached - - :param model: the model to train - :param ofa_model: the model that will be trained - """ - if self.elastic_kernels_allowed: - # train elastic kernels - for current_kernel_step in range(1, self.kernel_step_count): - # add a kernel step - ofa_model.progressive_shrinking_add_kernel() - if self.epochs_kernel_step > 0: - self.rebuild_trainer( - f"kernel_{current_kernel_step}", self.epochs_kernel_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed kernel matrices.") - - def train_elastic_dilation(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a dilation - step, and trains the model for a number of epochs, and repeats this process - until the number of dilation steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_dilation_allowed: - # train elastic kernels - for current_dilation_step in range(1, self.dilation_step_count): - # add a kernel step - ofa_model.progressive_shrinking_add_dilation() - if self.epochs_dilation_step > 0: - self.rebuild_trainer( - f"kernel_{current_dilation_step}", self.epochs_dilation_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed dilation matrices.") - - def train_elastic_grouping(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a group - step, and trains the model for a number of epochs, and repeats this process - until the number of group steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_grouping_allowed: - # train elastic groups - for current_grouping_step in range(1, self.grouping_step_count): - # add a group step - ofa_model.progressive_shrinking_add_group() - if self.epochs_grouping_step > 0: - self.rebuild_trainer( - f"group_{current_grouping_step}", self.epochs_grouping_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed grouping matrices.") - - def train_elastic_dsc(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a dsc - step (turns Depthwise Separable Convolution on and off), and trains the model for a number of epochs, and repeats this process - until the number of dsc steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_dsc_allowed is True: - # train elastic groups - for current_dsc_step in range(1, self.dsc_step_count): - # add a group step - ofa_model.progressive_shrinking_add_dsc() - if self.epochs_dsc_step > 0: - self.rebuild_trainer( - f"dsc_{current_dsc_step}", self.epochs_dsc_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed dsc matrices.") - - def eval_elastic_width( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function steps down the width of the model, and then calls the next - method in the stack - - :param method_stack: a list of methods that will be called in order - :param method_index: The 
index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be trained - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the metrics for the current model - :return: The metrics_csv is being returned. - """ - model.reset_all_widths() - method = method_stack[method_index] - - for current_width_step in range(self.width_step_count): - if current_width_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_channels() - - trainer_path_tmp = trainer_path + f"W {current_width_step}, " - loginfo_output_tmp = loginfo_output + f"Width {current_width_step}, " - metrics_output_tmp = metrics_output + f"{current_width_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_kernel( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function steps down the kernel size of the model, and then calls the - next method in the stack - - :param method_stack: The list of methods to be called - :param method_index: The index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be trained - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: This is the string that will be printed to the - console - :param metrics_csv: a string that contains the metrics for the current model - :return: The metrics_csv is being returned. - """ - model.reset_all_kernel_sizes() - method = method_stack[method_index] - - for current_kernel_step in range(self.kernel_step_count): - if current_kernel_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_kernels() - - trainer_path_tmp = trainer_path + f"K {current_kernel_step}, " - loginfo_output_tmp = loginfo_output + f"Kernel {current_kernel_step}, " - metrics_output_tmp = metrics_output + f"{current_kernel_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_dilation( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different dilation size for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. 
- """ - model.reset_all_dilation_sizes() - method = method_stack[method_index] - - for current_dilation_step in range(self.dilation_step_count): - if current_dilation_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_dilations() - - trainer_path_tmp = trainer_path + f"K {current_dilation_step}, " - loginfo_output_tmp = loginfo_output + f"Dilation {current_dilation_step}, " - metrics_output_tmp = metrics_output + f"{current_dilation_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_depth( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function will run the next method in the stack for each depth step, - and then return the metrics_csv - - :param method_stack: The list of methods to be called - :param method_index: The index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: The model to be trained - :param trainer_path: The path to the trainer, which is used to save the - model - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: This is the string that will be printed to the - console - :param metrics_csv: This is the CSV file that we're writing to - :return: The metrics_csv is being returned. - """ - model.reset_active_depth() - method = method_stack[method_index] - - for current_depth_step in range(self.depth_step_count): - if current_depth_step > 0: - # iteration 0 is the full model with no stepping - model.active_depth -= 1 - - trainer_path_tmp = trainer_path + f"D {current_depth_step}, " - loginfo_output_tmp = loginfo_output + f"Depth {current_depth_step}, " - metrics_output_tmp = metrics_output + f"{current_depth_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_grouping( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different group size for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. 
- """ - model.reset_all_group_sizes() - method = method_stack[method_index] - for current_group_step in range(self.grouping_step_count): - if current_group_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_groups() - - trainer_path_tmp = trainer_path + f"G {current_group_step}, " - loginfo_output_tmp = loginfo_output + f"Group {current_group_step}, " - metrics_output_tmp = metrics_output + f"{current_group_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_dsc( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different dsc for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. - """ - model.reset_all_dsc() - method = method_stack[method_index] - for current_dsc_step in range(self.dsc_step_count): - if current_dsc_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_dsc() - - trainer_path_tmp = trainer_path + f"DSC {current_dsc_step}, " - loginfo_output_tmp = loginfo_output + f"DSC {current_dsc_step}, " - metrics_output_tmp = metrics_output + f"{current_dsc_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_single_model( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function takes in a model, a trainer, and a bunch of other stuff, - evaluates the model and tracks the results in der in the given strings and - returns a string of metrics - - :param method_stack: The list of methods that we're evaluating - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model that we want to evaluate - :param model: The model to be evaluated - :param trainer_path: the path to the trainer object - :param loginfo_output: This is the string that will be printed to the - console when the model is being evaluated - :param metrics_output: This is the string that will be written to the - metrics file. It contains the method name, the method index, and the method - stack - :param metrics_csv: a string that will be written to a csv file - :return: The metrics_csv is being returned. 
- """ - self.rebuild_trainer(trainer_path, self.epochs_tuning_step, tensorboard=False) - msglogger.info(loginfo_output) - - validation_model = model.build_validation_model() - - lightning_model.model = validation_model - assert model.eval_mode is True - - if self.epochs_tuning_step > 0: - self.trainer.fit(lightning_model) - - validation_results = self.trainer.validate( - lightning_model, ckpt_path=None, verbose=True - ) - - lightning_model.model = model - - metrics_csv += metrics_output - results = validation_results[0] - torch_params = model.get_validation_model_weight_count() - metrics_csv += f"{results['val_accuracy']}, {results['total_macs']}, {results['total_weights']}, {torch_params}" - metrics_csv += "\n" - return metrics_csv - - def eval_model(self, lightning_model, model): - """ - First the method stack for the evaluation ist build and then it is according to this evaluated - - :param lightning_model: the lightning model - :param model: the model to be evaluated - """ - # disable sampling in forward during evaluation. - model.eval_mode = True - - eval_methods = [] - - if self.elastic_width_allowed: - eval_methods.append(self.eval_elastic_width) - - if self.elastic_kernels_allowed: - eval_methods.append(self.eval_elastic_kernel) - - if self.elastic_dilation_allowed: - eval_methods.append(self.eval_elastic_dilation) - - if self.elastic_depth_allowed: - eval_methods.append(self.eval_elastic_depth) - - if self.elastic_grouping_allowed: - eval_methods.append(self.eval_elastic_grouping) - - if self.elastic_dsc_allowed: - eval_methods.append(self.eval_elastic_dsc) - - if len(eval_methods) > 0: - eval_methods.append(self.eval_single_model) - self.submodel_metrics_csv = eval_methods[0]( - eval_methods, - 1, - lightning_model, - model, - "Eval ", - "OFA validating ", - "", - self.submodel_metrics_csv, - ) - - if self.random_evaluate: - self.eval_random_combination(lightning_model, model) - - model.eval_mode = False - - def eval_random_combination(self, lightning_model, model): - # sample a few random combinations - - random_eval_number = self.random_eval_number - prev_max_kernel = model.sampling_max_kernel_step - prev_max_depth = model.sampling_max_depth_step - prev_max_width = model.sampling_max_width_step - prev_max_dilation = model.sampling_max_dilation_step - prev_max_grouping = model.sampling_max_grouping_step - prev_max_dsc = model.sampling_max_dsc_step - model.sampling_max_kernel_step = model.ofa_steps_kernel - 1 - model.sampling_max_dilation_step = model.ofa_steps_dilation - 1 - model.sampling_max_depth_step = model.ofa_steps_depth - 1 - model.sampling_max_width_step = model.ofa_steps_width - 1 - model.sampling_max_grouping_step = model.ofa_steps_grouping - 1 - model.sampling_max_dsc_step = model.ofa_steps_dsc - 1 - assert model.eval_mode is True - for i in range(random_eval_number): - model.reset_validation_model() - random_state = model.sample_subnetwork() - - loginfo_output = f"OFA validating random sample:\n{random_state}" - trainer_path = "Eval random sample: " - metrics_output = "" - - if self.elastic_width_allowed: - selected_widths = random_state["width_steps"] - selected_widths_string = str(selected_widths).replace(",", ";") - metrics_output += f"{selected_widths_string}, " - trainer_path += f"Ws {selected_widths}, " - - if self.elastic_kernels_allowed: - selected_kernels = random_state["kernel_steps"] - selected_kernels_string = str(selected_kernels).replace(",", ";") - metrics_output += f" {selected_kernels_string}, " - trainer_path += f"Ks {selected_kernels}, " - - if 
self.elastic_dilation_allowed: - selected_dilations = random_state["dilation_steps"] - selected_dilations_string = str(selected_dilations).replace(",", ";") - metrics_output += f" {selected_dilations_string}, " - trainer_path += f"Dils {selected_dilations}, " - - if self.elastic_grouping_allowed: - selected_groups = random_state["grouping_steps"] - selected_groups_string = str(selected_groups).replace(",", ";") - metrics_output += f" {selected_groups_string}, " - trainer_path += f"Gs {selected_groups_string}, " - - if self.elastic_dsc_allowed: - selected_dscs = random_state["dsc_steps"] - selected_dscs_string = str(selected_dscs).replace(",", ";") - metrics_output += f" {selected_dscs_string}, " - trainer_path += f"DSCs {selected_dscs_string}, " - - if self.elastic_depth_allowed: - selected_depth = random_state["depth_step"] - trainer_path += f"D {selected_depth}, " - metrics_output += f"{selected_depth}, " - if self.extract_model_config: - model.print_config("r" + str(i)) - - self.random_metrics_csv = self.eval_single_model( - None, - None, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - self.random_metrics_csv, - ) - - # revert to normal operation after eval. - model.sampling_max_kernel_step = prev_max_kernel - model.sampling_max_dilation_step = prev_max_dilation - model.sampling_max_depth_step = prev_max_depth - model.sampling_max_width_step = prev_max_width - model.sampling_max_grouping_step = prev_max_grouping - model.sampling_max_dsc_step = prev_max_dsc - - def rebuild_trainer( - self, step_name: str, epochs: int = 1, tensorboard: bool = True - ) -> Trainer: - if tensorboard: - logger = TensorBoardLogger(".", version=step_name) - else: - logger = CSVLogger(".", version=step_name) - callbacks = common_callbacks(self.config) - self.trainer = instantiate( - self.config.trainer, callbacks=callbacks, logger=logger, max_epochs=epochs - ) diff --git a/hannah/nas/search/search_old.py b/hannah/nas/search/search_old.py index c41410cd..58038f37 100644 --- a/hannah/nas/search/search_old.py +++ b/hannah/nas/search/search_old.py @@ -567,843 +567,3 @@ def run(self): metrics[k] = float(v) self.optimizer.tell_result(parameters, metrics) - - -class OFANasTrainer(NASTrainerBase): - def __init__( - self, - parent_config=None, - epochs_warmup=10, - epochs_kernel_step=10, - epochs_depth_step=10, - epochs_width_step=10, - epochs_dilation_step=10, - epochs_grouping_step=10, - epochs_dsc_step=10, - epochs_tuning_step=0, - elastic_kernels_allowed=False, - elastic_depth_allowed=False, - elastic_width_allowed=False, - elastic_dilation_allowed=False, - elastic_grouping_allowed=False, - elastic_dsc_allowed=False, - evaluate=True, - random_evaluate=True, - random_eval_number=100, - extract_model_config=False, - warmup_model_path="", - *args, - **kwargs, - ): - super().__init__(*args, parent_config=parent_config, **kwargs) - # currently no backend config for OFA - self.epochs_warmup = epochs_warmup - self.epochs_kernel_step = epochs_kernel_step - self.epochs_depth_step = epochs_depth_step - self.epochs_width_step = epochs_width_step - self.epochs_dilation_step = epochs_dilation_step - self.epochs_grouping_step = epochs_grouping_step - self.epochs_dsc_step = epochs_dsc_step - self.epochs_tuning_step = epochs_tuning_step - self.elastic_kernels_allowed = elastic_kernels_allowed - self.elastic_depth_allowed = elastic_depth_allowed - self.elastic_width_allowed = elastic_width_allowed - self.elastic_dilation_allowed = elastic_dilation_allowed - self.elastic_grouping_allowed = 
elastic_grouping_allowed - self.elastic_dsc_allowed = elastic_dsc_allowed - - self.evaluate = evaluate - self.random_evaluate = random_evaluate - self.random_eval_number = random_eval_number - self.warmup_model_path = warmup_model_path - self.extract_model_config = extract_model_config - - def run(self): - config = OmegaConf.create(self.config) - # logger = TensorBoardLogger(".") - - seed = config.get("seed", 1234) - if isinstance(seed, list) or isinstance(seed, omegaconf.ListConfig): - seed = seed[0] - seed_everything(seed, workers=True) - - if not torch.cuda.is_available(): - config.trainer.gpus = None - - callbacks = common_callbacks(config) - opt_monitor = config.get("monitor", ["val_error"]) - opt_callback = HydraOptCallback(monitor=opt_monitor) - callbacks.append(opt_callback) - checkpoint_callback = instantiate(config.checkpoint) - callbacks.append(checkpoint_callback) - self.config = config - # trainer will be initialized by rebuild_trainer - self.trainer = None - model = instantiate( - config.module, - dataset=config.dataset, - model=config.model, - optimizer=config.optimizer, - features=config.features, - scheduler=config.get("scheduler", None), - normalizer=config.get("normalizer", None), - _recursive_=False, - ) - model.setup("fit") - ofa_model = model.model - - self.kernel_step_count = ofa_model.ofa_steps_kernel - self.depth_step_count = ofa_model.ofa_steps_depth - self.width_step_count = ofa_model.ofa_steps_width - self.dilation_step_count = ofa_model.ofa_steps_dilation - self.grouping_step_count = ofa_model.ofa_steps_grouping - self.dsc_step_count = ofa_model.ofa_steps_dsc - ofa_model.elastic_kernels_allowed = self.elastic_kernels_allowed - ofa_model.elastic_depth_allowed = self.elastic_depth_allowed - ofa_model.elastic_width_allowed = self.elastic_width_allowed - ofa_model.elastic_dilation_allowed = self.elastic_dilation_allowed - ofa_model.elastic_grouping_allowed = self.elastic_grouping_allowed - ofa_model.elastic_dsc_allowed = self.elastic_dsc_allowed - ofa_model.full_config = self.config["model"] - - logging.info("Kernel Steps: %d", self.kernel_step_count) - logging.info("Depth Steps: %d", self.depth_step_count) - logging.info("Width Steps: %d", self.width_step_count) - logging.info("Grouping Steps: %d", self.grouping_step_count) - logging.info("DSC Steps: %d", self.dsc_step_count) - # logging.info("dsc: %d", self.grouping_step_count) - - self.submodel_metrics_csv = "" - self.random_metrics_csv = "" - - if self.elastic_width_allowed: - self.submodel_metrics_csv += "width, " - self.random_metrics_csv += "width_steps, " - - if self.elastic_kernels_allowed: - self.submodel_metrics_csv += "kernel, " - self.random_metrics_csv += "kernel_steps, " - - if self.elastic_dilation_allowed: - self.submodel_metrics_csv += "dilation, " - self.random_metrics_csv += "dilation_steps, " - - if self.elastic_depth_allowed: - self.submodel_metrics_csv += "depth, " - self.random_metrics_csv += "depth, " - - if self.elastic_grouping_allowed: - self.submodel_metrics_csv += "grouping, " - self.random_metrics_csv += "group_steps, " - - if self.elastic_dsc_allowed: - self.submodel_metrics_csv += "dsc, " - self.random_metrics_csv += "dsc, " - - if ( - self.elastic_width_allowed - | self.elastic_kernels_allowed - | self.elastic_dilation_allowed - | self.elastic_depth_allowed - | self.elastic_grouping_allowed - | self.elastic_dsc_allowed - ): - self.submodel_metrics_csv += ( - "acc, total_macs, total_weights, torch_params\n" - ) - self.random_metrics_csv += "acc, total_macs, total_weights, 
torch_params\n" - - # self.random_metrics_csv = "width_steps, depth, kernel_steps, acc, total_macs, total_weights, torch_params\n" - - logging.info("Once for all Model:\n %s", str(ofa_model)) - # TODO Warmup DSC on or off? - self.warmup(model, ofa_model) - ofa_model.reset_shrinking() - - self.train_elastic_kernel(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_dilation(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_depth(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_width(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_grouping(model, ofa_model) - ofa_model.reset_shrinking() - self.train_elastic_dsc(model, ofa_model) - ofa_model.reset_shrinking() - - if self.evaluate: - self.eval_model(model, ofa_model) - - if self.random_evaluate: - # save random metrics - msglogger.info("\n%s", self.random_metrics_csv) - with open("OFA_random_sample_metrics.csv", "w") as f: - f.write(self.random_metrics_csv) - # save self.submodel_metrics_csv - msglogger.info("\n%s", str(self.submodel_metrics_csv)) - with open("OFA_elastic_metrics.csv", "w") as f: - f.write(self.submodel_metrics_csv) - - def warmup(self, model, ofa_model): - """ - > The function rebuilds the trainer with the warmup epochs, fits the model, - validates the model, and then calls the on_warmup_end() function to - change some internal variables - - :param model: the model to be trained - :param ofa_model: the model that we want to train - """ - # warm-up. - self.rebuild_trainer("warmup", self.epochs_warmup) - if self.epochs_warmup > 0 and self.warmup_model_path == "": - self.trainer.fit(model) - ckpt_path = "best" - elif self.warmup_model_path != "": - ckpt_path = self.warmup_model_path - self.trainer.validate(ckpt_path=ckpt_path, model=model, verbose=True) - ofa_model.on_warmup_end() - ofa_model.reset_validation_model() - msglogger.info("OFA completed warm-up.") - - def train_elastic_width(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a width - step, then trains the model for a number of epochs, then adds a width step, - and so on - - :param model: the model to train - :param ofa_model: the model that will be trained - """ - if self.elastic_width_allowed: - # train elastic width - # first, run channel priority computation - ofa_model.progressive_shrinking_compute_channel_priorities() - for current_width_step in range(1, self.width_step_count): - # add a width step - ofa_model.progressive_shrinking_add_width() - if self.epochs_width_step > 0: - self.rebuild_trainer( - f"width_{current_width_step}", self.epochs_width_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed width steps.") - - def train_elastic_depth(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then progressively - shrinks the depth of the model, and trains the model for a number of epochs - again - - :param model: the model to train - :param ofa_model: the model to be trained - """ - if self.elastic_depth_allowed: - # train elastic depth - for current_depth_step in range(1, self.depth_step_count): - # add a depth reduction step - ofa_model.progressive_shrinking_add_depth() - if self.epochs_depth_step > 0: - self.rebuild_trainer( - f"depth_{current_depth_step}", self.epochs_depth_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - 
msglogger.info("OFA completed depth steps.") - - def train_elastic_kernel(self, model, ofa_model): - """ - > The function trains the elastic kernels by progressively shrinking the - model and training the model for a number of epochs and repeats this process - until the number of kernel steps is reached - - :param model: the model to train - :param ofa_model: the model that will be trained - """ - if self.elastic_kernels_allowed: - # train elastic kernels - for current_kernel_step in range(1, self.kernel_step_count): - # add a kernel step - ofa_model.progressive_shrinking_add_kernel() - if self.epochs_kernel_step > 0: - self.rebuild_trainer( - f"kernel_{current_kernel_step}", self.epochs_kernel_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed kernel matrices.") - - def train_elastic_dilation(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a dilation - step, and trains the model for a number of epochs, and repeats this process - until the number of dilation steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_dilation_allowed: - # train elastic kernels - for current_dilation_step in range(1, self.dilation_step_count): - # add a kernel step - ofa_model.progressive_shrinking_add_dilation() - if self.epochs_dilation_step > 0: - self.rebuild_trainer( - f"kernel_{current_dilation_step}", self.epochs_dilation_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed dilation matrices.") - - def train_elastic_grouping(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a group - step, and trains the model for a number of epochs, and repeats this process - until the number of group steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_grouping_allowed: - # train elastic groups - for current_grouping_step in range(1, self.grouping_step_count): - # add a group step - ofa_model.progressive_shrinking_add_group() - if self.epochs_grouping_step > 0: - self.rebuild_trainer( - f"group_{current_grouping_step}", self.epochs_grouping_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed grouping matrices.") - - def train_elastic_dsc(self, model, ofa_model): - """ - > The function trains the model for a number of epochs, then adds a dsc - step (turns Depthwise Separable Convolution on and off), and trains the model for a number of epochs, and repeats this process - until the number of dsc steps is reached - - :param model: the model to be trained - :param ofa_model: the model that will be trained - """ - if self.elastic_dsc_allowed is True: - # train elastic groups - for current_dsc_step in range(1, self.dsc_step_count): - # add a group step - ofa_model.progressive_shrinking_add_dsc() - if self.epochs_dsc_step > 0: - self.rebuild_trainer( - f"dsc_{current_dsc_step}", self.epochs_dsc_step - ) - self.trainer.fit(model) - ckpt_path = "best" - self.trainer.validate(ckpt_path=ckpt_path, verbose=True) - msglogger.info("OFA completed dsc matrices.") - - def eval_elastic_width( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - 
metrics_output, - metrics_csv, - ): - """ - > This function steps down the width of the model, and then calls the next - method in the stack - - :param method_stack: a list of methods that will be called in order - :param method_index: The index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be trained - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the metrics for the current model - :return: The metrics_csv is being returned. - """ - model.reset_all_widths() - method = method_stack[method_index] - - for current_width_step in range(self.width_step_count): - if current_width_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_channels() - - trainer_path_tmp = trainer_path + f"W {current_width_step}, " - loginfo_output_tmp = loginfo_output + f"Width {current_width_step}, " - metrics_output_tmp = metrics_output + f"{current_width_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_kernel( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function steps down the kernel size of the model, and then calls the - next method in the stack - - :param method_stack: The list of methods to be called - :param method_index: The index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be trained - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: This is the string that will be printed to the - console - :param metrics_csv: a string that contains the metrics for the current model - :return: The metrics_csv is being returned. 
- """ - model.reset_all_kernel_sizes() - method = method_stack[method_index] - - for current_kernel_step in range(self.kernel_step_count): - if current_kernel_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_kernels() - - trainer_path_tmp = trainer_path + f"K {current_kernel_step}, " - loginfo_output_tmp = loginfo_output + f"Kernel {current_kernel_step}, " - metrics_output_tmp = metrics_output + f"{current_kernel_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_dilation( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different dilation size for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. - """ - model.reset_all_dilation_sizes() - method = method_stack[method_index] - - for current_dilation_step in range(self.dilation_step_count): - if current_dilation_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_dilations() - - trainer_path_tmp = trainer_path + f"K {current_dilation_step}, " - loginfo_output_tmp = loginfo_output + f"Dilation {current_dilation_step}, " - metrics_output_tmp = metrics_output + f"{current_dilation_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_depth( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function will run the next method in the stack for each depth step, - and then return the metrics_csv - - :param method_stack: The list of methods to be called - :param method_index: The index of the current method in the method stack - :param lightning_model: the lightning model to be trained - :param model: The model to be trained - :param trainer_path: The path to the trainer, which is used to save the - model - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: This is the string that will be printed to the - console - :param metrics_csv: This is the CSV file that we're writing to - :return: The metrics_csv is being returned. 
- """ - model.reset_active_depth() - method = method_stack[method_index] - - for current_depth_step in range(self.depth_step_count): - if current_depth_step > 0: - # iteration 0 is the full model with no stepping - model.active_depth -= 1 - - trainer_path_tmp = trainer_path + f"D {current_depth_step}, " - loginfo_output_tmp = loginfo_output + f"Depth {current_depth_step}, " - metrics_output_tmp = metrics_output + f"{current_depth_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_grouping( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different group size for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. - """ - model.reset_all_group_sizes() - method = method_stack[method_index] - for current_group_step in range(self.grouping_step_count): - if current_group_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_groups() - - trainer_path_tmp = trainer_path + f"G {current_group_step}, " - loginfo_output_tmp = loginfo_output + f"Group {current_group_step}, " - metrics_output_tmp = metrics_output + f"{current_group_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_elastic_dsc( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function evaluates the model with a different dsc for each - layer - - :param method_stack: The list of methods to be called - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model to be trained - :param model: the model to be evaluated - :param trainer_path: The path to the trainer - :param loginfo_output: This is the string that will be printed to the - console - :param metrics_output: a string that will be written to the metrics csv file - :param metrics_csv: a string that contains the csv data for the metrics - :return: The metrics_csv is being returned. 
- """ - model.reset_all_dsc() - method = method_stack[method_index] - for current_dsc_step in range(self.dsc_step_count): - if current_dsc_step > 0: - # iteration 0 is the full model with no stepping - model.step_down_all_dsc() - - trainer_path_tmp = trainer_path + f"DSC {current_dsc_step}, " - loginfo_output_tmp = loginfo_output + f"DSC {current_dsc_step}, " - metrics_output_tmp = metrics_output + f"{current_dsc_step}, " - - metrics_csv = method( - method_stack, - method_index + 1, - lightning_model, - model, - trainer_path_tmp, - loginfo_output_tmp, - metrics_output_tmp, - metrics_csv, - ) - - return metrics_csv - - def eval_single_model( - self, - method_stack, - method_index, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - metrics_csv, - ): - """ - > This function takes in a model, a trainer, and the evaluation bookkeeping arguments, - evaluates the model and tracks the results in the given strings and - returns a string of metrics - - :param method_stack: The list of methods that we're evaluating - :param method_index: The index of the method in the method stack - :param lightning_model: the lightning model that we want to evaluate - :param model: The model to be evaluated - :param trainer_path: the path to the trainer object - :param loginfo_output: This is the string that will be printed to the - console when the model is being evaluated - :param metrics_output: This is the string that will be written to the - metrics file. It contains the method name, the method index, and the method - stack - :param metrics_csv: a string that will be written to a csv file - :return: The metrics_csv is being returned. - """ - self.rebuild_trainer(trainer_path, self.epochs_tuning_step, tensorboard=False) - msglogger.info(loginfo_output) - - validation_model = model.build_validation_model() - - lightning_model.model = validation_model - assert model.eval_mode is True - - if self.epochs_tuning_step > 0: - self.trainer.fit(lightning_model) - - validation_results = self.trainer.validate( - lightning_model, ckpt_path=None, verbose=True - ) - - lightning_model.model = model - - metrics_csv += metrics_output - results = validation_results[0] - torch_params = model.get_validation_model_weight_count() - metrics_csv += f"{results['val_accuracy']}, {results['total_macs']}, {results['total_weights']}, {torch_params}" - metrics_csv += "\n" - return metrics_csv - - def eval_model(self, lightning_model, model): - """ - First the method stack for the evaluation is built and then the model is evaluated according to it - - :param lightning_model: the lightning model - :param model: the model to be evaluated - """ - # disable sampling in forward during evaluation.
- model.eval_mode = True - - eval_methods = [] - - if self.elastic_width_allowed: - eval_methods.append(self.eval_elastic_width) - - if self.elastic_kernels_allowed: - eval_methods.append(self.eval_elastic_kernel) - - if self.elastic_dilation_allowed: - eval_methods.append(self.eval_elastic_dilation) - - if self.elastic_depth_allowed: - eval_methods.append(self.eval_elastic_depth) - - if self.elastic_grouping_allowed: - eval_methods.append(self.eval_elastic_grouping) - - if self.elastic_dsc_allowed: - eval_methods.append(self.eval_elastic_dsc) - - if len(eval_methods) > 0: - eval_methods.append(self.eval_single_model) - self.submodel_metrics_csv = eval_methods[0]( - eval_methods, - 1, - lightning_model, - model, - "Eval ", - "OFA validating ", - "", - self.submodel_metrics_csv, - ) - - if self.random_evaluate: - self.eval_random_combination(lightning_model, model) - - model.eval_mode = False - - def eval_random_combination(self, lightning_model, model): - # sample a few random combinations - - random_eval_number = self.random_eval_number - prev_max_kernel = model.sampling_max_kernel_step - prev_max_depth = model.sampling_max_depth_step - prev_max_width = model.sampling_max_width_step - prev_max_dilation = model.sampling_max_dilation_step - prev_max_grouping = model.sampling_max_grouping_step - prev_max_dsc = model.sampling_max_dsc_step - model.sampling_max_kernel_step = model.ofa_steps_kernel - 1 - model.sampling_max_dilation_step = model.ofa_steps_dilation - 1 - model.sampling_max_depth_step = model.ofa_steps_depth - 1 - model.sampling_max_width_step = model.ofa_steps_width - 1 - model.sampling_max_grouping_step = model.ofa_steps_grouping - 1 - model.sampling_max_dsc_step = model.ofa_steps_dsc - 1 - assert model.eval_mode is True - for i in range(random_eval_number): - model.reset_validation_model() - random_state = model.sample_subnetwork() - - loginfo_output = f"OFA validating random sample:\n{random_state}" - trainer_path = "Eval random sample: " - metrics_output = "" - - if self.elastic_width_allowed: - selected_widths = random_state["width_steps"] - selected_widths_string = str(selected_widths).replace(",", ";") - metrics_output += f"{selected_widths_string}, " - trainer_path += f"Ws {selected_widths}, " - - if self.elastic_kernels_allowed: - selected_kernels = random_state["kernel_steps"] - selected_kernels_string = str(selected_kernels).replace(",", ";") - metrics_output += f" {selected_kernels_string}, " - trainer_path += f"Ks {selected_kernels}, " - - if self.elastic_dilation_allowed: - selected_dilations = random_state["dilation_steps"] - selected_dilations_string = str(selected_dilations).replace(",", ";") - metrics_output += f" {selected_dilations_string}, " - trainer_path += f"Dils {selected_dilations}, " - - if self.elastic_grouping_allowed: - selected_groups = random_state["grouping_steps"] - selected_groups_string = str(selected_groups).replace(",", ";") - metrics_output += f" {selected_groups_string}, " - trainer_path += f"Gs {selected_groups_string}, " - - if self.elastic_dsc_allowed: - selected_dscs = random_state["dsc_steps"] - selected_dscs_string = str(selected_dscs).replace(",", ";") - metrics_output += f" {selected_dscs_string}, " - trainer_path += f"DSCs {selected_dscs_string}, " - - if self.elastic_depth_allowed: - selected_depth = random_state["depth_step"] - trainer_path += f"D {selected_depth}, " - metrics_output += f"{selected_depth}, " - if self.extract_model_config: - model.print_config("r" + str(i)) - - self.random_metrics_csv = self.eval_single_model( - 
None, - None, - lightning_model, - model, - trainer_path, - loginfo_output, - metrics_output, - self.random_metrics_csv, - ) - - # revert to normal operation after eval. - model.sampling_max_kernel_step = prev_max_kernel - model.sampling_max_dilation_step = prev_max_dilation - model.sampling_max_depth_step = prev_max_depth - model.sampling_max_width_step = prev_max_width - model.sampling_max_grouping_step = prev_max_grouping - model.sampling_max_dsc_step = prev_max_dsc - - def rebuild_trainer( - self, step_name: str, epochs: int = 1, tensorboard: bool = True - ) -> Trainer: - if tensorboard: - logger = TensorBoardLogger(".", version=step_name) - else: - logger = CSVLogger(".", version=step_name) - callbacks = common_callbacks(self.config) - self.trainer = instantiate( - self.config.trainer, callbacks=callbacks, logger=logger, max_epochs=epochs - ) diff --git a/test/test_dsc.py b/test/test_dsc.py deleted file mode 100644 index c121d594..00000000 --- a/test/test_dsc.py +++ /dev/null @@ -1,128 +0,0 @@ -from tokenize import group -import unittest -import torch -import pytest - -import torch.nn as nn -import torch.nn.functional as F - -from hannah.models.ofa.utilities import ( - create_channel_filter, - prepare_kernel_for_depthwise_separable_convolution, - prepare_kernel_for_pointwise_convolution, -) - - -class Test_DSC(unittest.TestCase): - - def test_dsc(self): - - input = torch.randn(10, 16, 100) # batch, input, c_out - - in_channel = 16 - out_channel = 32 - kernel_size = 3 - groups = 4 - compare = nn.Conv1d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, groups=groups) - - print(input.shape) - t = nn.Conv1d( - in_channels=in_channel, - out_channels=out_channel, - kernel_size=kernel_size, - groups=1, - ) - full_kernel = t.weight.data - new_kernel, bias = prepare_kernel_for_depthwise_separable_convolution( - t, - kernel=full_kernel, - bias=None, - in_channels=in_channel - ) - print( - f"kernel:{new_kernel.shape} bias: {bias.shape if bias is not None else bias}" - ) - # perform depthwise separable convolution - res_depthwise = F.conv1d(input, new_kernel, bias, groups=in_channel) - # point_conv = Conv2d(in_channels=10, out_channels=32, kernel_size=1) - - assert res_depthwise.shape[1] == in_channel - print(res_depthwise) - print(res_depthwise.shape) - # get new kernel size - # use full kernel - # grouping = in_channel_count - new_kernel = prepare_kernel_for_pointwise_convolution( - kernel=full_kernel, - grouping=groups, - ) - res_pointwise = F.conv1d(res_depthwise, new_kernel, bias, groups=groups) - assert res_pointwise.shape[1] == out_channel - assert new_kernel.shape[2] == 1 - print(res_pointwise) - print(res_pointwise.shape) - - print("Comparing with normal conv") - compare_output = compare(input) - print(f"dsc:{res_pointwise.shape}, compare:{compare_output.shape}") - assert(compare_output.shape == res_pointwise.shape) - - @pytest.mark.skip(reason="only for looking stuff up") - def test_kernel_dimensions(self): - t = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3) - a = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1) - print(t.weight.shape) - print(t.weight.size(0)) - print(t.weight.size(1)) - print(t.weight.size(2)) - print("=================") - print(a.weight.shape) - print(a.weight.size(0)) - print(a.weight.size(1)) - print(a.weight.size(2)) - print("=================") - print(t.weight.shape[0] * t.weight.shape[1] * t.weight.shape[2]) - print(a.weight.shape[0] * a.weight.shape[1] * a.weight.shape[2]) - print("=================") - 
print(t.weight) - print(a.weight) - print(a.weight[:, :, 0]) - print(t.weight[:, :, 0]) - # Means: when I do kernel[a][b], I take the first value of that bundle - # means: take all elements and, from the last dimension, always slice only the first one - # kernel[:,:,0].shape slices everything so that it keeps all dims and takes only the zeroth entry of the last one - - @pytest.mark.skip(reason="only for looking stuff up") - def test_kernel_reducement(self): - input = torch.randn(10, 2, 100) # batch, input, c_out - t = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3) - a = nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1) - t.weight.data = torch.ones(2, 2, 3) - a.weight.data = torch.ones(2, 2, 1) - kernel = t.weight.data - output_1 = F.conv1d(input, kernel) - kernel_reduced = kernel[:, :, 0:1] - print(f"kernel_size reduced: {kernel_reduced.shape} before: {kernel} target: {a.weight.shape}") - assert(kernel_reduced.shape == a.weight.shape) - output_2 = F.conv1d(input, kernel_reduced) - output_compare = F.conv1d(input, a.weight.data) - assert(output_2.shape == output_compare.shape) - assert(torch.equal(output_2, output_compare)) - print("finished testing") - - @pytest.mark.skip(reason="only for looking stuff up") - def test_check_weigths(self): - t = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, groups=1) - print(f"weight: {t.weight.data.shape}") - # weight is defined out_channels, in_channels, kernel_size - print([t.out_channels, t.in_channels, t.kernel_size[0]]) - - -def printValues(number: int, cnn): - print( - f"C{number} Parameters: in_channels={cnn.in_channels}, out_channels={cnn.out_channels}, k={cnn.kernel_size}, g={cnn.groups}" - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_ofa.py b/test/test_ofa.py deleted file mode 100644 index a5035170..00000000 --- a/test/test_ofa.py +++ /dev/null @@ -1,85 +0,0 @@ -# -# Copyright (c) 2023 Hannah contributors. -# -# This file is part of hannah. -# See https://github.com/ekut-es/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-# -import numpy as np -import torch -import torch.nn as nn -from flaky import flaky - -from hannah.models.ofa.submodules.elastickernelconv import ElasticConv1d - - -# FIXME: should be fixable by initializing seeds -@flaky(20, 3) -def test_elastic_conv1d_quant(): - kernel_sizes = [3, 5, 7] - input_length = 10 - input_channels = 2 - output_channels = 2 - batch_size = 2 - dilation_sizes = [1] - - input = torch.ones((batch_size, input_channels, input_length)) - output = torch.zeros((batch_size, output_channels, input_length)) - - conv = ElasticConv1d( - input_channels, - output_channels, - kernel_sizes, - dilation_sizes=dilation_sizes, - groups=[1], - dscs=[1], - ) - loss_func = nn.MSELoss() - optimizer = torch.optim.SGD(conv.parameters(), lr=0.01) - - res = conv(input) - orig_loss = loss_func(res, output) - print("orig_loss:", orig_loss) - - assert res.shape == output.shape - - for _i in range(10): - optimizer.zero_grad() - res = conv(input) - loss = loss_func(res, output) - loss.backward() - optimizer.step() - - # Sample convolution size - for _i in range(30): - kernel_size = np.random.choice(kernel_sizes) - conv.set_kernel_size(kernel_size) - optimizer.zero_grad() - res = conv(input) - loss = loss_func(res, output) - loss.backward() - optimizer.step() - - print("orig_loss:", orig_loss) - for kernel_size in kernel_sizes: - conv.set_kernel_size(kernel_size) - res = conv(input) - loss = loss_func(res, output) - print("kernel_size:", kernel_size, "loss:", loss) - - assert loss < orig_loss - - -if __name__ == "__main__": - test_elastic_conv1d_quant() diff --git a/test/test_ofa_groups.py b/test/test_ofa_groups.py deleted file mode 100644 index 5cc33d1f..00000000 --- a/test/test_ofa_groups.py +++ /dev/null @@ -1,107 +0,0 @@ -# -# Copyright (c) 2023 Hannah contributors. -# -# This file is part of hannah. -# See https://github.com/ekut-es/hannah for further info. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import unittest - -import numpy as np -import torch -import torch.nn as nn - -from hannah.models.ofa.submodules.elastickernelconv import ElasticConv1d - -# , ElasticConvReLu1d - - -class OFAGroupTestCase(unittest.TestCase): - def test_grouping(self): - kernel_sizes = [3] - input_length = 5 - input_channels = 8 - output_channels = 8 - batch_size = 2 - dilation_sizes = [1] - group_sizes = [1, 2, 4, 8] - - # check calls of set_group_size - - input = torch.ones((batch_size, input_channels, input_length)) - output = torch.zeros((batch_size, output_channels, input_length)) - - conv = ElasticConv1d( - input_channels, - output_channels, - kernel_sizes, - dilation_sizes=dilation_sizes, - groups=group_sizes, - dscs=[0, 1, 2, 3], - ) - loss_func = nn.MSELoss() - optimizer = torch.optim.SGD(conv.parameters(), lr=0.1) - - res = conv(input) - orig_loss = loss_func(res, output) - print("orig_loss:", orig_loss) - - assert res.shape == output.shape - - loss = 1 - # warmup - for i in range(5): - optimizer.zero_grad() - res = conv(input) - loss = loss_func(res, output) - loss.backward() - optimizer.step() - - print("after warmup:", loss) - group_val = {} - for group_size in group_sizes: - print("group_size:", group_size) - conv.set_group_size(group_size) - for i in range(5): - optimizer.zero_grad() - res = conv(input) - loss = loss_func(res, output) - loss.backward() - optimizer.step() - print("loss:", loss) - - # Validation - validation_loss = [] - for i in range(10): - res = conv(input) - val_loss = loss_func(res, output) - print("val_loss:", loss) - validation_loss.append(val_loss.item()) - mean = np.mean(validation_loss) - group_val[group_size] = mean - - print("Values:") - best_pair_g = 1 - best_pair_v = 1 - - for k, v in group_val.items(): - print(f"Groups {k} Accuracy {v}") - if v < best_pair_v: - best_pair_v = v - best_pair_g = k - print(f"Best: G {best_pair_g} Accuracy {best_pair_v}") - - -if __name__ == "__main__": - unittest.main()
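The deleted tests above exercised grouped and depthwise-separable 1d convolutions, the building blocks behind the elastic grouping and dsc steps trained and evaluated earlier in this diff. As a point of reference, the standalone sketch below (not part of the diff; padding=1 is an added assumption so every variant preserves the input length) reuses the channel sizes from test_dsc.py, 16 in / 32 out with kernel size 3 and groups=4, to show how each variant changes the weight volume:

import torch
import torch.nn as nn

x = torch.randn(10, 16, 100)  # (batch, in_channels, length)

standard = nn.Conv1d(16, 32, kernel_size=3, padding=1)              # groups=1 baseline
grouped = nn.Conv1d(16, 32, kernel_size=3, padding=1, groups=4)     # grouped variant, as in test_dsc.py
depthwise = nn.Conv1d(16, 16, kernel_size=3, padding=1, groups=16)  # one k-tap filter per input channel
pointwise = nn.Conv1d(16, 32, kernel_size=1)                        # 1x1 channel-mixing convolution

# All three variants map (10, 16, 100) to (10, 32, 100).
assert standard(x).shape == grouped(x).shape == pointwise(depthwise(x)).shape

# Weight volumes (biases ignored): grouping divides the input channels seen by each
# filter, and the depthwise/pointwise pair replaces a single k-tap convolution with
# a per-channel k-tap convolution followed by a 1-tap mixing convolution.
print(standard.weight.numel())                              # 32 * 16 * 3 = 1536
print(grouped.weight.numel())                               # 32 * (16 // 4) * 3 = 384
print(depthwise.weight.numel() + pointwise.weight.numel())  # 16 * 1 * 3 + 32 * 16 * 1 = 560

This weight reduction is the trade-off that the removed grouping and dsc evaluation loops measured against validation accuracy, MACs, and parameter counts.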