diff --git a/experiments/constrained_space_comparison/augmentation/cifar_augment.yaml b/experiments/constrained_space_comparison/augmentation/cifar_augment.yaml
new file mode 100644
index 00000000..24d4f70c
--- /dev/null
+++ b/experiments/constrained_space_comparison/augmentation/cifar_augment.yaml
@@ -0,0 +1,20 @@
+batch_augment:
+  pipeline: null
+  transforms:
+    #RandomVerticalFlip:
+    #  p: 0.5
+    RandomHorizontalFlip:
+      p: 0.5
+    RandomAffine:
+      degrees: [-15, 15]
+      translate: [0.1, 0.1]
+      scale: [0.9, 1.1]
+      shear: [-5, 5]
+      p: 0.5
+    RandomCrop:
+      size: [32, 32]
+      padding: 4
+    RandomErasing:
+      p: 0.5
+      #scale: [0.1, 0.3]
+      #value: [0.4914, 0.4822, 0.4465]
diff --git a/experiments/constrained_space_comparison/config.yaml b/experiments/constrained_space_comparison/config.yaml
index 646d63d9..692e87ba 100644
--- a/experiments/constrained_space_comparison/config.yaml
+++ b/experiments/constrained_space_comparison/config.yaml
@@ -23,7 +23,7 @@ defaults:
   - override features: identity # Feature extractor configuration name (use identity for vision datasets)
   #- override model: timm_mobilenetv3_small_075 # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
   - override scheduler: 1cycle # learning rate scheduler config name
-  - override optimizer: adamw # Optimizer config name
+  - override optimizer: sgd # Optimizer config name
   - override normalizer: null # Feature normalizer (used for quantized neural networks)
   - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
   - _self_
@@ -32,12 +32,7 @@ defaults:
 dataset:
   data_folder: ${oc.env:HANNAH_DATA_FOLDER,${hydra:runtime.cwd}/../../datasets/}
 
-module:
-  batch_size: 128
-  num_workers: 8
-
 trainer:
-  max_epochs: 10
+  max_epochs: 100
 
-scheduler:
-  max_lr: 0.001
+fx_mac_summary: True
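For reference (illustrative, not part of the patch), the augmentation config above corresponds roughly to the following torchvision sequence; hannah's batch_augment wiring and the exact probability semantics for the affine step are assumptions here:

```python
import torchvision.transforms as T

# Approximate single-image equivalent of cifar_augment.yaml
augment = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply(
        [T.RandomAffine(degrees=(-15, 15), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-5, 5))],
        p=0.5,
    ),
    T.RandomCrop(size=(32, 32), padding=4),
    T.ToTensor(),
    T.RandomErasing(p=0.5),  # operates on tensors, hence after ToTensor
])
```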
diff --git a/experiments/constrained_space_comparison/eval.yaml b/experiments/constrained_space_comparison/eval.yaml
new file mode 100644
index 00000000..a6a5131d
--- /dev/null
+++ b/experiments/constrained_space_comparison/eval.yaml
@@ -0,0 +1,77 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+##     http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+data:
+  # baseline: trained_models/ae_nas_cifar10_baseline/embedded_vision_net
+  # weights: trained_models/ae_nas_cifar10_weight/embedded_vision_net
+  # weights_and_macs: trained_models/ae_nas_cifar10_weight_and_macs/embedded_vision_net
+  # macs: trained_models/ae_nas_cifar10_macs/embedded_vision_net
+  # nopred: trained_models/ae_nas_cifar10_weight_and_macs_no_pred/embedded_vision_net
+  # sortbymacs: trained_models/ae_nas_cifar10_weight_and_macs_sortbymacs/embedded_vision_net
+  250k: trained_models/ae_nas_cifar10_weight_250k/embedded_vision_net
+  250k_20e: trained_models/ae_nas_cifar10_weight_250k_20epochs/embedded_vision_net
+  250k_50m: trained_models/ae_nas_cifar10_weight_250k_macs_50m/embedded_vision_net
+
+
+metrics:
+  total_act:
+    name: Activations
+  total_weights:
+    name: Weights
+  weights_m:
+    name: Weights [M]
+    derived: data["total_weights"] / 1000 / 1000
+  val_accuracy:
+    name: Accuracy [%]
+    derived: (1.0 - data["val_error"]) * 100.0
+  act_k:
+    name: Activations [k]
+    derived: data["total_act"] / 1000
+  macs_m:
+    name: MACs [M]
+    derived: data["total_macs"] / 1000 / 1000
+
+plots:
+  # Comparison plots of 2-3 metrics, using y, x, and point size as visual dimensions
+  - type: comparison
+    name: accuracy_memory
+    metrics:
+      - val_accuracy
+      - weights_m
+      - act_k
+
+  - type: comparison
+    name: accuracy_macs
+    metrics:
+      - val_accuracy
+      - macs_m
+
+extract:
+  random_nas_cifar10:
+    bounds:
+      val_error: 0.20
+      total_macs: 100000000
+      total_weights: 1000000
+
+
+experiment: presampling
+force: false
+
+hydra:
+  run:
+    dir: ./nas_results/${experiment}
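The `derived` entries above are plain Python expressions over the collected results table. A minimal sketch of how such expressions can be evaluated, assuming the results land in a pandas DataFrame (hannah's actual eval driver may differ):

```python
import pandas as pd

# Two hypothetical candidates from a NAS run
data = pd.DataFrame({"val_error": [0.12, 0.08], "total_weights": [240000, 180000]})

derived = {
    "val_accuracy": '(1.0 - data["val_error"]) * 100.0',
    "weights_m": 'data["total_weights"] / 1000 / 1000',
}
for name, expr in derived.items():
    # Each expression only sees the results table under the name `data`
    data[name] = eval(expr, {"data": data})

print(data)  # val_accuracy: 88.0 / 92.0, weights_m: 0.24 / 0.18
```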
diff --git a/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_150-250k_macs_50m.yaml b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_150-250k_macs_50m.yaml
new file mode 100644
index 00000000..55fc4665
--- /dev/null
+++ b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_150-250k_macs_50m.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+defaults:
+  - override /nas: aging_evolution_nas
+  - override /model: embedded_vision_net
+  - override /dataset: cifar10
+  - override /nas/constraint_model: random_walk
+
+model:
+  num_classes: 10
+  constraints:
+    - name: weights
+      upper: 250000
+      # lower: 150000
+    - name: macs
+      lower: 50000000
+
+nas:
+  budget: 600
+  n_jobs: 8
+
+seed: [1234]
+
+experiment_id: "ae_nas_cifar10_weight_250k_macs_50m"
diff --git a/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k.yaml b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k.yaml
new file mode 100644
index 00000000..b04c6a5d
--- /dev/null
+++ b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+defaults:
+  - override /nas: aging_evolution_nas
+  - override /model: embedded_vision_net
+  - override /dataset: cifar10
+  - override /nas/constraint_model: random_walk
+
+model:
+  num_classes: 10
+  constraints:
+    - name: weights
+      upper: 250000
+
+nas:
+  budget: 600
+  n_jobs: 8
+
+seed: [1234]
+
+experiment_id: "ae_nas_cifar10_weight_250k"
diff --git a/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k_20epochs.yaml b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k_20epochs.yaml
new file mode 100644
index 00000000..dd8ef650
--- /dev/null
+++ b/experiments/constrained_space_comparison/experiment/ae_nas_cifar10_weights_250k_20epochs.yaml
@@ -0,0 +1,29 @@
+# @package _global_
+defaults:
+  - override /nas: aging_evolution_nas
+  - override /model: embedded_vision_net
+  - override /dataset: cifar10
+  - override /nas/constraint_model: random_walk
+
+model:
+  num_classes: 10
+  constraints:
+    - name: weights
+      upper: 250000
+
+nas:
+  budget: 1000
+  n_jobs: 2
+  total_candidates: 100
+  num_selected_candidates: 25
+  bounds:
+    val_error: 0.1
+    total_macs: 200000000
+    total_weights: 250000
+
+trainer:
+  max_epochs: 20
+
+seed: [1234]
+
+experiment_id: "ae_nas_cifar10_weight_250k_20epochs"
diff --git a/experiments/constrained_space_comparison/experiment/train_best_model.yaml b/experiments/constrained_space_comparison/experiment/train_best_model.yaml
new file mode 100644
index 00000000..246322df
--- /dev/null
+++ b/experiments/constrained_space_comparison/experiment/train_best_model.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+defaults:
+  - override /model: embedded_vision_net_model
+  - override /augmentation: cifar_augment
+
+model:
+  param_path: /local/reiber/hannah/experiments/constrained_space_comparison/parameters.pkl
+  task_name: 250k_50m
+  index: 524
+  input_shape: [1, 3, 32, 32]
+  labels: 10
+
+module:
+  batch_size: 32
+  num_workers: 32
+
+trainer:
+  max_epochs: 100
+
+scheduler:
+  max_lr: 0.001
+
+experiment_id: best_model_augment_adam0001_b32_100e
\ No newline at end of file
diff --git a/experiments/ri_evn/config.yaml b/experiments/ri_evn/config.yaml
new file mode 100644
index 00000000..d6ba22d3
--- /dev/null
+++ b/experiments/ri_evn/config.yaml
@@ -0,0 +1,45 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+##     http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+defaults:
+  - base_config
+  - experiment: optional
+  - override dataset: ri_capsule # Dataset configuration name
+  - override features: identity # Feature extractor configuration name (use identity for vision datasets)
+  # - override model: embedded_vision_net_model # Neural network name (for now timm_resnet50 or timm_efficientnet_lite1)
+  - override scheduler: 1cycle # learning rate scheduler config name
+  - override optimizer: adamw # Optimizer config name
+  - override normalizer: null # Feature normalizer (used for quantized neural networks)
+  - override module: image_classifier # Lightning module config for the training loop (image classifier for image classification tasks)
+  - _self_
+
+
+dataset:
+  data_folder: /data
+
+module:
+  batch_size: 32
+  num_workers: 8
+
+trainer:
+  max_epochs: 20
+
+scheduler:
+  max_lr: 0.001
+
+fx_mac_summary: True
\ No newline at end of file
diff --git a/experiments/ri_evn/eval.yaml b/experiments/ri_evn/eval.yaml
new file mode 100644
index 00000000..eb06c4fd
--- /dev/null
+++ b/experiments/ri_evn/eval.yaml
@@ -0,0 +1,70 @@
+##
+## Copyright (c) 2022 University of Tübingen.
+##
+## This file is part of hannah.
+## See https://atreus.informatik.uni-tuebingen.de/ties/ai/hannah/hannah for further info.
+##
+## Licensed under the Apache License, Version 2.0 (the "License");
+## you may not use this file except in compliance with the License.
+## You may obtain a copy of the License at
+##
+##     http://www.apache.org/licenses/LICENSE-2.0
+##
+## Unless required by applicable law or agreed to in writing, software
+## distributed under the License is distributed on an "AS IS" BASIS,
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+## See the License for the specific language governing permissions and
+## limitations under the License.
+##
+data:
+  250k: trained_models/250k_weights/embedded_vision_net
+
+
+
+metrics:
+  total_act:
+    name: Activations
+  total_weights:
+    name: Weights
+  weights_m:
+    name: Weights [M]
+    derived: data["total_weights"] / 1000 / 1000
+  val_accuracy:
+    name: Accuracy [%]
+    derived: (1.0 - data["val_error"]) * 100.0
+  act_k:
+    name: Activations [k]
+    derived: data["total_act"] / 1000
+  macs_m:
+    name: MACs [M]
+    derived: data["total_macs"] / 1000 / 1000
+
+plots:
+  # Comparison plots of 2-3 metrics, using y, x, and point size as visual dimensions
+  - type: comparison
+    name: accuracy_memory
+    metrics:
+      - val_accuracy
+      - weights_m
+      - act_k
+
+  - type: comparison
+    name: accuracy_macs
+    metrics:
+      - val_accuracy
+      - macs_m
+
+extract:
+  random_nas_cifar10:
+    bounds:
+      val_error: 0.20
+      total_macs: 100000000
+      total_weights: 1000000
+
+
+experiment: presampling
+force: false
+
+hydra:
+  run:
+    dir: ./nas_results/${experiment}
diff --git a/experiments/ri_evn/experiment/250k_weights.yaml b/experiments/ri_evn/experiment/250k_weights.yaml
new file mode 100644
index 00000000..31d503a7
--- /dev/null
+++ b/experiments/ri_evn/experiment/250k_weights.yaml
@@ -0,0 +1,27 @@
+# @package _global_
+defaults:
+  - override /nas: aging_evolution_nas
+  - override /model: embedded_vision_net
+  - override /nas/constraint_model: random_walk
+
+model:
+  num_classes: 4
+  constraints:
+    - name: weights
+      upper: 250000
+    # - name: macs
+    #   upper: 128000000
+
+nas:
+  budget: 600
+  n_jobs: 1
+  num_selected_candidates: 20
+  total_candidates: 50
+  bounds:
+    val_error: 0.1
+    total_macs: 128000000
+    total_weights: 250000
+
+seed: [1234]
+
+experiment_id: 250k_weights
diff --git a/experiments/ri_evn/experiment/train_model.yaml b/experiments/ri_evn/experiment/train_model.yaml
new file mode 100644
index 00000000..bd57ff23
--- /dev/null
+++ b/experiments/ri_evn/experiment/train_model.yaml
@@ -0,0 +1,17 @@
+# @package _global_
+defaults:
+  - override /model: embedded_vision_net_model
+
+model:
+  param_path: /local/reiber/hannah/experiments/ri_evn/parameters.pkl
+  task_name: 250k
+  index: 5
+  input_shape: [1, 3, 320, 320]
+  labels: 4
+
+seed: [1234]
+
+trainer:
+  max_epochs: 50
+
+experiment_id: train_best_model
diff --git a/experiments/ri_evn/parameters.pkl b/experiments/ri_evn/parameters.pkl
new file mode 100644
index 00000000..d9c5fd3b
Binary files /dev/null and b/experiments/ri_evn/parameters.pkl differ
diff --git a/hannah/conf/model/embedded_vision_net.yaml b/hannah/conf/model/embedded_vision_net.yaml
index 75a78c46..39978a95 100644
--- a/hannah/conf/model/embedded_vision_net.yaml
+++ b/hannah/conf/model/embedded_vision_net.yaml
@@ -3,11 +3,11 @@ name: embedded_vision_net
 num_classes: 10
 max_channels: 256
 
-# constraints:
-#   - name: weights
-#     upper: 550000
-#     # lower: 250000
-#   - name: macs
-#     upper: 128000000
-#     # lower: 60000000
+constraints:
+  - name: weights
+    upper: 550000
+    # lower: 250000
+  - name: macs
+    upper: 128000000
+    # lower: 60000000
 
diff --git a/hannah/models/embedded_vision_net/blocks.py b/hannah/models/embedded_vision_net/blocks.py
index 8cda2605..ba5b81ae 100644
--- a/hannah/models/embedded_vision_net/blocks.py
+++ b/hannah/models/embedded_vision_net/blocks.py
@@ -4,7 +4,6 @@ from hannah.nas.expressions.types import Int
 from hannah.nas.functional_operators.op import scope
 from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv2d, conv_relu, depthwise_conv2d, dynamic_depth, grouped_conv2d, interleave_channels, pointwise_conv2d, linear, relu, batch_norm, choice, identity, max_pool, avg_pool
-# from hannah.nas.functional_operators.visualizer import Visualizer
 from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter
 
 
diff --git a/hannah/models/embedded_vision_net/expressions.py b/hannah/models/embedded_vision_net/expressions.py
index de8898f0..34f2b615 100644
--- a/hannah/models/embedded_vision_net/expressions.py
+++ b/hannah/models/embedded_vision_net/expressions.py
@@ -23,6 +23,17 @@ def expr_sum(expressions: list):
     return res
 
 
+def expr_and(expressions: list):
+    # Fold a list of symbolic conditions into a single conjunction.
+    res = None
+    for expr in expressions:
+        if res is not None:
+            res = res & expr
+        else:
+            res = expr
+    return res
+
+
 ADD = 1
 CHOICE = 2
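`expr_and` mirrors `expr_sum` above it but folds with a logical conjunction instead of a sum; it is meant to combine symbolic conditions (e.g. a weights bound and a MACs bound) into one expression for `arch.cond`. With plain booleans the fold reduces to an ordinary conjunction (illustrative, not part of the patch):

```python
from hannah.models.embedded_vision_net.expressions import expr_and

# Plain booleans stand in for symbolic conditions such as
# arch.weights < 250000 and arch.macs < 50000000.
assert expr_and([True, True])
assert not expr_and([True, False])
```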
diff --git a/hannah/models/embedded_vision_net/models.py b/hannah/models/embedded_vision_net/models.py
index 9c0ffad8..2ab8a08f 100644
--- a/hannah/models/embedded_vision_net/models.py
+++ b/hannah/models/embedded_vision_net/models.py
@@ -1,5 +1,6 @@
 from functools import partial
-from hannah.models.embedded_vision_net.expressions import expr_product, extract_macs_recursive, extract_weights_recursive
+from hannah.models.embedded_vision_net.expressions import expr_product, extract_macs_recursive, extract_weights_recursive, expr_and
+from hannah.nas.core.parametrized import is_parametrized
 from hannah.nas.expressions.arithmetic import Ceil
 from hannah.nas.expressions.logic import And
 from hannah.nas.expressions.types import Int
@@ -7,6 +8,7 @@
 from hannah.nas.functional_operators.executor import BasicExecutor
 from hannah.nas.functional_operators.op import Tensor, get_nodes, scope
 from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv2d, conv_relu, depthwise_conv2d, dynamic_depth, pointwise_conv2d, linear, relu, batch_norm, choice, identity
+from hannah.nas.functional_operators.operators import Conv2d
 # from hannah.nas.functional_operators.visualizer import Visualizer
 from hannah.nas.parameters.parameters import CategoricalParameter, FloatScalarParameter, IntScalarParameter
 from hannah.models.embedded_vision_net.blocks import block, cwm_block, classifier_head, stem
@@ -54,26 +56,34 @@ def search_space(name, input, num_classes: int, max_channels=512, constraints: l
             arch.weights = extract_weights_recursive(arch)
             weight_params = extract_parameter_from_expression(arch.weights)
             weight_params = [p for p in weight_params if 'stride' not in p.name and 'groups' not in p.name]
-            if "lower" in con and "upper" in con:
-                upper = arch.weights < con.upper
-                lower = arch.weights > con.lower
-                arch.cond(And(lower, upper), weight_params)
-            elif "upper" in con:
+            # weight_params = [p for p in weight_params if "depth" in p.name or "num_blocks" in p.name]
+            # if "lower" in con and "upper" in con:
+            #     upper = arch.weights < con.upper
+            #     lower = arch.weights > con.lower
+            #     arch.cond(And(lower, upper), weight_params)
+            if "upper" in con:
                 arch.cond(arch.weights < con.upper, weight_params)
-            elif "lower" in con:
+            if "lower" in con:
                 arch.cond(arch.weights > con.lower, weight_params)
         elif con.name == "macs":
             arch.macs = extract_macs_recursive(arch)
             mac_params = extract_parameter_from_expression(arch.macs)
             mac_params = [p for p in mac_params if 'stride' not in p.name and 'groups' not in p.name]
-            if "lower" in con and "upper" in con:
-                upper = arch.macs < con.upper
-                lower = arch.macs > con.lower
-                arch.cond(And(lower, upper), mac_params)
-            elif "upper" in con:
+            # if "lower" in con and "upper" in con:
+            #     upper = arch.macs < con.upper
+            #     lower = arch.macs > con.lower
+            #     arch.cond(And(lower, upper), mac_params)
+            if "upper" in con:
                 arch.cond(arch.macs < con.upper, mac_params)
-            elif "lower" in con:
+            if "lower" in con:
                 arch.cond(arch.macs > con.lower, mac_params)
+        # elif con.name == "ofm":
+        #     for node in get_nodes(arch):
+        #         if isinstance(node, Conv2d):
+        #             ofm_vol = node.shape()[1] * node.shape()[2] * node.shape()[3]
+        #             node.operands[1].axis[0]  # output channels
+        #             possible_params = [node.stride, node.dilation, node.operands[1].shape()[0]]
+        #             arch.cond(ofm_vol < con.upper, allowed_params=[p for p in possible_params if is_parametrized(p)])
         else:
             raise NotImplementedError(f"Constraint {con.name} not implemented")
     return arch
diff --git a/hannah/models/embedded_vision_net/parameters.py b/hannah/models/embedded_vision_net/parameters.py
index 84706eec..431cfbce 100644
--- a/hannah/models/embedded_vision_net/parameters.py
+++ b/hannah/models/embedded_vision_net/parameters.py
@@ -57,3 +57,5 @@ def set_current(self, x):
         if x not in possible_values:
-            diff = np.abs(np.array(possible_values)) - x
+            diff = np.abs(np.array(possible_values) - x)  # distance to each allowed value
             self.current_value = int(possible_values[np.argmin(diff)])
+        else:
+            self.current_value = int(x)
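With the parenthesis fix in `set_current` (the absolute value must be taken of the differences, not of the allowed values), out-of-range inputs snap to the nearest allowed value. A worked example of the arithmetic:

```python
import numpy as np

possible_values = [4, 8, 16]
x = 11
diff = np.abs(np.array(possible_values) - x)     # [7, 3, 5]
nearest = int(possible_values[np.argmin(diff)])  # index 1 -> 8, the closest allowed value
assert nearest == 8
```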
diff --git a/hannah/nas/constraints/random_walk.py b/hannah/nas/constraints/random_walk.py
index c1fa4c43..45e5dd5a 100644
--- a/hannah/nas/constraints/random_walk.py
+++ b/hannah/nas/constraints/random_walk.py
@@ -1,13 +1,79 @@
 from copy import deepcopy
+from dataclasses import dataclass
 import random
 from typing import Any
 
 import numpy as np
 
+from hannah.nas.functional_operators.lazy import lazy
 from hannah.nas.parameters.parametrize import set_parametrization
 from hannah.nas.search.utils import np_to_primitive
 
 
+@dataclass
+class Param:
+    name: str
+    value: Any
+
+
+def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False):
+    # Group a flat {dotted_key: value} mapping into a shallow tree: the first
+    # two key components index nested dicts, the remainder stays joined as the
+    # leaf key.
+    hierarchical_params = {}
+    for key, param in parameter.items():
+        key_list = key.split(".")
+        key_list = key_list[:2] + [".".join(key_list[2:])]
+        current_param_branch = hierarchical_params
+        for k in key_list:
+            try:
+                index = int(k)
+                if index not in current_param_branch:
+                    current_param_branch[index] = {}
+                # current_param_branch = current_param_branch[index]
+            except Exception:
+                index = k
+                if k not in current_param_branch:
+                    current_param_branch[k] = {}
+
+            if k == key_list[-1]:
+                current_param_branch[index] = Param(name=key, value=param)
+            else:
+                current_param_branch = current_param_branch[index]
+    return hierarchical_params
+
+
+CHOICES = {0: "conv", 1: "expand_reduce", 2: "reduce_expand", 3: "pooling"}
+
+
+def get_active_parameter(params):
+    # Collect only the parameters that are reachable with the currently
+    # selected number of blocks, block depths and pattern choices.
+    active_params = {}
+    params = hierarchical_parameter_dict(params)
+    num_blocks = params["ChoiceOp_0"]["num_blocks"][""].value
+    active_params["num_blocks"] = num_blocks + 1
+    for i in range(num_blocks + 1):
+        current_block = f"block_{i}"
+        depth = params[current_block]["ChoiceOp_0"]["depth"].value
+        active_params[params[current_block]["ChoiceOp_0"]["depth"].name] = depth + 1
+        for j in range(depth + 1):
+            current_pattern = f"pattern_{j}"
+            choice = params[current_block][current_pattern]["ChoiceOp_0.choice"].value
+            for k, v in params[current_block][current_pattern].items():
+                if k.split(".")[0] == "Conv2d_0":
+                    active_params[v.name] = v.value
+                elif "expand_reduce" in k and choice == 1:
+                    active_params[v.name] = v.value
+                elif "reduce_expand" in k and choice == 2:
+                    active_params[v.name] = v.value
+                elif "pooling" in k and choice == 3:
+                    active_params[v.name] = v.value
+                elif "ChoiceOp" in k:
+                    active_params[v.name] = CHOICES[v.value]
+    return active_params
+
+
 class RandomWalkConstraintSolver:
     def __init__(self, max_iterations=5000) -> None:
         self.max_iterations = max_iterations
@@ -50,8 +111,16 @@ def build_model(self, conditions, fixed_vars=[]):
                 #     ct += 1
                 #     print(f"Failed to solve constraint {i}.")
 
+    def right_direction(self, current, new, direction):
+        # Does `new` move the constrained value in the direction required by
+        # the constraint symbol (">" or "<"), relative to `current`?
+        if direction == ">":
+            return new > current
+        elif direction == "<":
+            return new < current
+
     def solve(self, module, parameters, fix_vars=[]):
-        mod = deepcopy(module)
+        mod = deepcopy(module)  # FIXME: copying the whole module is inefficient
         self.solution = deepcopy(parameters)
         params = deepcopy(parameters)
         set_parametrization(parameters, mod.parametrization(flatten=True))
@@ -62,32 +136,61 @@ def solve(self, module, parameters, fix_vars=[]):
         knobs = list(reversed(knobs))
 
         for i, con in enumerate(constraints):
-            param_keys = list(params.keys())
+            all_param_keys = list(params.keys())
             if knobs[i] is not None:
-                param_keys = [p.id for p in knobs[i]]
+                all_param_keys = [p.id for p in knobs[i]]
+
+            direction = con.symbol
 
             ct = 0
             while ct < self.max_iterations:
+                # Only mutate parameters that are active in the current candidate.
+                active_params = get_active_parameter(params)
+                param_keys = [p for p in all_param_keys if p in active_params]
+                current = con.lhs.evaluate()
                 if con.evaluate():
                     self.solution.update(params)
                     solved_conditions.append(con)
                     print(f"Solved constraint {i} with {ct} iterations.")
                     break
                 else:
                     key_to_change = random.choice(param_keys)
                     old_val = mod.parametrization(flatten=True)[key_to_change].current_value
                     new_val = mod.parametrization(flatten=True)[key_to_change].sample()
+                    new_target = lazy(con.lhs)
+                    # Resample, dropping knobs that did not help, until the
+                    # constrained value moves in the required direction.
+                    j = 0
+                    while not self.right_direction(current, new_target, direction):
+                        mod.parametrization(flatten=True)[key_to_change].set_current(old_val)
+                        param_keys.remove(key_to_change)
+                        key_to_change = random.choice(param_keys)
+                        old_val = mod.parametrization(flatten=True)[key_to_change].current_value
+                        new_val = mod.parametrization(flatten=True)[key_to_change].sample()
+                        new_target = lazy(con.lhs)
+                        j += 1
+                        if j > 1000:
+                            raise Exception("Timeout: Failed to find improvement.")
+
                     valid = True
-                    for c in solved_conditions:
-                        if not c.evaluate():
-                            print("Solution violated already satisfied constraint")
-                            # reverse modification to satisfy already solved constraints again
-                            mod.parametrization(flatten=True)[key_to_change].set_current(old_val)
-                            valid = False
+                    new_target = con.lhs.evaluate()
+                    if self.right_direction(current, new_target, direction):
+                        for c in solved_conditions:
+                            if not c.evaluate():
+                                print("Solution violated already satisfied constraint")
+                                # reverse modification to satisfy already solved constraints again
+                                param_keys.remove(key_to_change)
+                                mod.parametrization(flatten=True)[key_to_change].set_current(old_val)
+                                valid = False
+                    else:
+                        print("No improvement")
+                        mod.parametrization(flatten=True)[key_to_change].set_current(old_val)
+                        valid = False
 
                     if valid:
                         # update proposed solution for this constraint
                         params[key_to_change] = new_val
+                        print(f"Constraint lhs: {lazy(con.lhs)} - rhs: {lazy(con.rhs)}")
                 ct += 1
                 if ct == self.max_iterations-1:
                     print(f"Failed to solve constraint {i}.")
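`hierarchical_parameter_dict` groups flat dotted parameter keys into a two-level tree with the remainder joined back together as the leaf key, which is what `get_active_parameter` relies on when it walks blocks and patterns. The grouping itself, in isolation (the key name is hypothetical):

```python
key = "block_0.pattern_1.ChoiceOp_0.choice"
key_list = key.split(".")
key_list = key_list[:2] + [".".join(key_list[2:])]
# The first two components become dict levels, the rest stays one leaf key
assert key_list == ["block_0", "pattern_1", "ChoiceOp_0.choice"]
```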
diff --git a/hannah/nas/functional_operators/torch_conversion.py b/hannah/nas/functional_operators/torch_conversion.py
new file mode 100644
index 00000000..8a677dd0
--- /dev/null
+++ b/hannah/nas/functional_operators/torch_conversion.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+
+from hannah.nas.functional_operators.lazy import lazy
+from hannah.nas.functional_operators.op import Op, Tensor
+from hannah.nas.functional_operators.operators import Conv2d, Relu, Linear, BatchNorm, AdaptiveAvgPooling, Add
+from hannah.nas.functional_operators.executor import BasicExecutor
+from hannah.models.embedded_vision_net.models import search_space
+
+
+class Classifier(nn.Module):
+    def __init__(self, in_features, out_features) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+
+    def forward(self, x):
+        out = torch.flatten(x, start_dim=1)
+        out = self.linear(out)
+        return out
+
+
+class TorchConverter(BasicExecutor):
+    def __init__(self, search_space) -> None:
+        super().__init__(search_space)
+        self.conversions = {Conv2d: self.convert_conv2d,
+                            Relu: self.convert_relu,
+                            Linear: self.convert_linear,
+                            BatchNorm: self.convert_batch_norm,
+                            AdaptiveAvgPooling: self.convert_adaptive_avg_pool}
+        self.mods = nn.ModuleDict()
+
+    def convert_conv2d(self, node):
+        in_channels = lazy(node.in_channels)
+        out_channels = lazy(node.out_channels)
+        kernel_size = lazy(node.kernel_size)
+        stride = lazy(node.stride)
+        padding = lazy(node.padding)
+        dilation = lazy(node.dilation)
+        groups = lazy(node.groups)
+
+        return nn.Conv2d(in_channels=in_channels,
+                         out_channels=out_channels,
+                         kernel_size=kernel_size,
+                         stride=stride,
+                         padding=padding,
+                         dilation=dilation,
+                         groups=groups,
+                         bias=False)  # FIXME: Add bias to functional convs
+
+    def convert_relu(self, node):
+        return nn.ReLU()
+
+    def convert_linear(self, node):
+        in_features = lazy(node.in_features)
+        out_features = lazy(node.out_features)
+        return nn.Linear(in_features, out_features)
+
+    def convert_batch_norm(self, node):
+        num_features = lazy(node.operands[0].shape()[1])
+        return nn.BatchNorm2d(num_features=num_features)
+
+    def convert_max_pool(self, node):
+        pass
+
+    def convert_avg_pool(self, node):
+        pass
+
+    def convert_adaptive_avg_pool(self, node):
+        output_size = lazy(node.output_size)
+        return nn.AdaptiveAvgPool2d(output_size=output_size)
+
+    def convert(self):
+        self.find_execution_order()
+        to_remove = []
+        for node_name in self.nodes:
+            node = self.node_dict[node_name]
+            if type(node) in self.conversions:
+                op = self.conversions[type(node)](node)
+                self.mods[node_name.replace(".", "_")] = op
+                # Parameter tensors are baked into the torch modules, so drop
+                # them from the operand lists.
+                remove_operands = [operand for operand in self.forward_dict[node_name] if not isinstance(self.node_dict[operand], Op) and not operand == "input"]
+                for operand in remove_operands:
+                    self.forward_dict[node_name].remove(operand)
+            elif not isinstance(node, Op):
+                to_remove.append(node_name)
+
+        for node_name in to_remove:
+            self.nodes.remove(node_name)
+            del self.forward_dict[node_name]
+
+    def forward(self, x):
+        # FIXME: Remove obsolete entries in out
+        out = {'input': x}
+        for node in self.nodes:
+            node_name = node.replace(".", "_")
+            operands = [out[n] for n in self.forward_dict[node]]
+            if node_name in self.mods:
+                if isinstance(self.mods[node_name], nn.Linear):
+                    operands = [torch.flatten(operands[0], start_dim=1)]
+                out[node] = self.mods[node_name](*operands)
+            elif isinstance(self.node_dict[node], Add):
+                out[node] = torch.add(*operands)
+        return out[node]
+
+
+if __name__ == '__main__':
+    input = Tensor(name="input", shape=(32, 3, 32, 32), axis=("N", "C", "H", "W"))
+    space = search_space(name="evn", input=input, num_classes=10)
+    # space.sample()
+    converter = TorchConverter(space)
+    converter.convert()
+    x = torch.randn(input.shape())
+    converter.forward(x)
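For illustration (not part of the patch), a minimal sketch of driving the converter end to end, extending the `__main__` block above with one training step; the SGD step and random data are placeholders, and it assumes the search space's default parametrization yields concrete shapes:

```python
import torch
import torch.nn as nn

from hannah.nas.functional_operators.op import Tensor
from hannah.nas.functional_operators.torch_conversion import TorchConverter
from hannah.models.embedded_vision_net.models import search_space

# Build a candidate network and convert it to torch modules
inp = Tensor(name="input", shape=(8, 3, 32, 32), axis=("N", "C", "H", "W"))
space = search_space(name="evn", input=inp, num_classes=10)
converter = TorchConverter(space)
converter.convert()

# One illustrative optimization step on random data
x = torch.randn(inp.shape())
y = torch.randint(0, 10, (8,))
optimizer = torch.optim.SGD(converter.mods.parameters(), lr=0.01)

optimizer.zero_grad()
logits = converter.forward(x)
loss = nn.functional.cross_entropy(logits, y)
loss.backward()
optimizer.step()
```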
diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py
index 353c4e4e..8e174137 100644
--- a/hannah/nas/search/search.py
+++ b/hannah/nas/search/search.py
@@ -220,7 +220,7 @@ def after_search(self):
         pass
         # self.extract_best_model()
 
-    def sample_candidates(self, num_total, num_candidates=None, sort_key="ff", presample=False):
+    def sample_candidates(self, num_total, num_candidates=None, sort_key="val_error", presample=False):
         candidates = []
         skip_ct = 0
         while len(candidates) < num_total:
@@ -319,9 +319,9 @@ def sample(self):
                         parameters
                     )
                     break
-                except Exception:
-                    pass
-                print()
+                except Exception as e:
+                    print("Error occurred while sampling:")
+                    print(str(e))
         else:
             parameters, keys = self.sampler.next_parameters()
         return parameters