diff --git a/hannah/conf/nas/aging_evolution_nas.yaml b/hannah/conf/nas/aging_evolution_nas.yaml index 25372389..c4bef1ab 100644 --- a/hannah/conf/nas/aging_evolution_nas.yaml +++ b/hannah/conf/nas/aging_evolution_nas.yaml @@ -31,6 +31,7 @@ n_jobs: 10 presample: False total_candidates: 50 num_selected_candidates: 20 +constrained_sampling_on_search: True bounds: val_error: 0.1 # total_macs: 128000000 diff --git a/hannah/nas/constraints/random_walk.py b/hannah/nas/constraints/random_walk.py index 47fad65f..7afbcad7 100644 --- a/hannah/nas/constraints/random_walk.py +++ b/hannah/nas/constraints/random_walk.py @@ -29,6 +29,7 @@ from hannah.nas.parameters.parameters import Parameter from hannah.nas.parameters.parametrize import set_parametrization from hannah.nas.search.utils import np_to_primitive +from hannah.nas.functional_operators.utils.visit import get_active_parameters logger = logging.getLogger(__name__) @@ -72,36 +73,36 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False): } -def get_active_parameter(net): - active_param_ids = [] - queue = [net] - visited = [net.id] - - def extract_parameters(node): - ids = [] - for k, p in node._PARAMETERS.items(): - if isinstance(p, Parameter): - ids.append(p.id) - return ids - - while queue: - current = queue.pop() - if isinstance(current, ChoiceOp): - # handle choices - active_param_ids.append(current.switch.id) - chosen_path = current.options[lazy(current.switch)] - if chosen_path.id not in visited: - queue.append(chosen_path) - visited.append(chosen_path.id) - else: - # handle all other operators & tensors - active_param_ids.extend(extract_parameters(current)) - for operand in current.operands: - if operand.id not in visited: - queue.append(operand) - visited.append(operand.id) - - return active_param_ids +# def get_active_parameter(net): +# active_param_ids = [] +# queue = [net] +# visited = [net.id] + +# def extract_parameters(node): +# ids = [] +# for k, p in node._PARAMETERS.items(): +# if isinstance(p, Parameter): +# ids.append(p.id) +# return ids + +# while queue: +# current = queue.pop() +# if isinstance(current, ChoiceOp): +# # handle choices +# active_param_ids.append(current.switch.id) +# chosen_path = current.options[lazy(current.switch)] +# if chosen_path.id not in visited: +# queue.append(chosen_path) +# visited.append(chosen_path.id) +# else: +# # handle all other operators & tensors +# active_param_ids.extend(extract_parameters(current)) +# for operand in current.operands: +# if operand.id not in visited: +# queue.append(operand) +# visited.append(operand.id) + +# return active_param_ids class RandomWalkConstraintSolver: @@ -193,7 +194,7 @@ def solve(self, module, parameters, fix_vars=[]): ct = 0 while ct < self.max_iterations: # active_params = get_active_parameter(params) - active_params = get_active_parameter(mod) + active_params = list(get_active_parameters(mod).keys()) param_keys = [p for p in all_param_keys if p in active_params] current = con.lhs.evaluate() diff --git a/hannah/nas/functional_operators/utils/visit.py b/hannah/nas/functional_operators/utils/visit.py index 35aa17a9..e6b59c74 100644 --- a/hannah/nas/functional_operators/utils/visit.py +++ b/hannah/nas/functional_operators/utils/visit.py @@ -17,6 +17,8 @@ # limitations under the License. # from ..op import BaseNode +from hannah.nas.functional_operators.op import ChoiceOp +from hannah.nas.parameters.parameters import Parameter def post_order(op: BaseNode): @@ -40,3 +42,30 @@ def post_order(op: BaseNode): def reverse_post_order(op: BaseNode): """Visits the operator graph in reverse post order""" return reversed(list(post_order(op))) + + +def get_active_parameters(space, parametrization=None): + if parametrization is None: + parametrization = space.parametrization() + + queue = [space] + visited = [space.id] + active_params = {} + + while queue: + node = queue.pop(0) + for k, p in node._PARAMETERS.items(): + if isinstance(p, Parameter): + active_params[p.id] = parametrization[p.id] + for operand in node.operands: + while isinstance(operand, ChoiceOp): + for k, p in operand._PARAMETERS.items(): + if isinstance(p, Parameter): + active_params[p.id] = parametrization[p.id] + active_op_index = operand.switch.evaluate() + operand = operand.operands[active_op_index] + if operand.id not in visited: + queue.append(operand) + visited.append(operand.id) + return active_params + diff --git a/hannah/nas/search/sampler/aging_evolution.py b/hannah/nas/search/sampler/aging_evolution.py index 70e97804..b2ffb47a 100644 --- a/hannah/nas/search/sampler/aging_evolution.py +++ b/hannah/nas/search/sampler/aging_evolution.py @@ -30,9 +30,9 @@ from hannah.nas.search.sampler.mutator import ParameterMutator from hannah.nas.search.utils import np_to_primitive -from ...parametrization import SearchSpace from ...utils import is_pareto from .base_sampler import Sampler, SearchResult +from hannah.nas.functional_operators.utils.visit import get_active_parameters class FitnessFunction: @@ -59,14 +59,16 @@ class AgingEvolutionSampler(Sampler): def __init__( self, parent_config, + search_space, parametrization: dict, population_size: int = 50, random_state = None, sample_size: int = 10, + mutation_rate: float = 0.01, eps: float = 0.1, output_folder=".", ): - super().__init__(parent_config, output_folder=output_folder) + super().__init__(parent_config, search_space=search_space, output_folder=output_folder) self.bounds = self.parent_config.nas.bounds self.parametrization = parametrization @@ -79,7 +81,7 @@ def __init__( self.population_size = population_size self.sample_size = sample_size self.eps = eps - self.mutator = ParameterMutator(0.1) + self.mutator = ParameterMutator(mutation_rate) self.history = [] self.population = [] @@ -118,8 +120,11 @@ def next_parameters(self): parent = sample[np.argmin(fitness)] parent_parametrization = set_parametrization(parent.parameters, self.parametrization) + parametrization = {key: param.current_value for key, param in parent_parametrization.items()} + active_parameters = get_active_parameters(self.search_space, parent_parametrization) - parametrization, mutated_keys = self.mutator.mutate(parent_parametrization) + mutated_parameters, mutated_keys = self.mutator.mutate(active_parameters) + parametrization.update(mutated_parameters) return parametrization, mutated_keys diff --git a/hannah/nas/search/sampler/base_sampler.py b/hannah/nas/search/sampler/base_sampler.py index 749598e4..b29b4c4b 100644 --- a/hannah/nas/search/sampler/base_sampler.py +++ b/hannah/nas/search/sampler/base_sampler.py @@ -27,8 +27,10 @@ def costs(self): class Sampler(ABC): def __init__(self, parent_config, + search_space, output_folder=".") -> None: self.history = [] + self.search_space = search_space self.output_folder = Path(output_folder) self.parent_config = parent_config diff --git a/hannah/nas/search/sampler/random_sampler.py b/hannah/nas/search/sampler/random_sampler.py index 11671e2a..1ea29d82 100644 --- a/hannah/nas/search/sampler/random_sampler.py +++ b/hannah/nas/search/sampler/random_sampler.py @@ -25,10 +25,11 @@ class RandomSampler(Sampler): def __init__( self, parent_config, + search_space, parametrization, output_folder=".", ) -> None: - super().__init__(parent_config=parent_config, output_folder=output_folder) + super().__init__(parent_config=parent_config, search_space=search_space, output_folder=output_folder) self.parametrization = parametrization if (self.output_folder / "history.yml").exists(): diff --git a/hannah/nas/search/search.py b/hannah/nas/search/search.py index a1f8c011..3605f2ee 100644 --- a/hannah/nas/search/search.py +++ b/hannah/nas/search/search.py @@ -129,6 +129,7 @@ def before_search(self): parametrization = self.search_space.parametrization(flatten=True) self.sampler = instantiate( self.config.nas.sampler, + search_space=self.search_space, parametrization=parametrization, parent_config=self.config, _recursive_=False, @@ -247,7 +248,7 @@ def sample_candidates( num_candidates=None, sort_key="val_error", presample=False, - constrain=False, + constrain=True, ): candidates = [] skip_ct = 0 diff --git a/hannah/nas/test/test_active_parameters.py b/hannah/nas/test/test_active_parameters.py new file mode 100644 index 00000000..7abbe9a1 --- /dev/null +++ b/hannah/nas/test/test_active_parameters.py @@ -0,0 +1,19 @@ +from hannah.nas.functional_operators.op import Tensor +from hannah.models.embedded_vision_net.models import embedded_vision_net +from hannah.nas.functional_operators.utils.visit import get_active_parameters + + +def test_active_parameters(): + input = Tensor(name="input", shape=(1, 3, 32, 32), axis=("N", "C", "H", "W")) + space = embedded_vision_net("space", input, num_classes=10) + space.parametrization()["embedded_vision_net_0.ChoiceOp_0.num_blocks"].set_current(1) + space.parametrization()["embedded_vision_net_0.block_0.pattern_0.ChoiceOp_0.choice"].set_current(4) + space.parametrization()["embedded_vision_net_0.block_0.pattern_0.sandglass_block_0.expansion_0.ChoiceOp_0.choice"].set_current(1) + active_params = get_active_parameters(space) + + space.parametrization() + print() + + +if __name__ == "__main__": + test_active_parameters() \ No newline at end of file diff --git a/hannah/nas/test/test_max78000_backend.py b/hannah/nas/test/test_max78000_backend.py index 060935e9..4d1028a8 100644 --- a/hannah/nas/test/test_max78000_backend.py +++ b/hannah/nas/test/test_max78000_backend.py @@ -96,7 +96,7 @@ def get_graph(seed): warnings.warn("remove this when seedable randomsampling works") print("Init sampler") - sampler = RandomSampler(None, graph.parametrization(flatten=True)) + sampler = RandomSampler(None, graph, graph.parametrization(flatten=True)) print("Init solver") solver = RandomWalkConstraintSolver() diff --git a/hannah/nas/test/test_nn_meter.py b/hannah/nas/test/test_nn_meter.py index aed6364d..cca3bf92 100644 --- a/hannah/nas/test/test_nn_meter.py +++ b/hannah/nas/test/test_nn_meter.py @@ -126,7 +126,7 @@ def test_nn_meter(hardware_name): predictor = NNMeterPredictor(hardware_name) print("Init sampler") - sampler = RandomSampler(None, net.parametrization(flatten=True)) + sampler = RandomSampler(None, net, net.parametrization(flatten=True)) print("Init solver") solver = RandomWalkConstraintSolver() diff --git a/hannah/nas/test/test_onnx_export.py b/hannah/nas/test/test_onnx_export.py index fa76cb22..5778bd55 100644 --- a/hannah/nas/test/test_onnx_export.py +++ b/hannah/nas/test/test_onnx_export.py @@ -106,7 +106,7 @@ def test_export_embedded_vision_net(): print(graph) print("Init sampler") - sampler = RandomSampler(None, graph.parametrization(flatten=True)) + sampler = RandomSampler(None, graph, graph.parametrization(flatten=True)) print("Init solver") solver = RandomWalkConstraintSolver() @@ -141,7 +141,7 @@ def test_export_ai8x_net(): print(graph) print("Init sampler") - sampler = RandomSampler(None, graph.parametrization(flatten=True)) + sampler = RandomSampler(None, graph, graph.parametrization(flatten=True)) print("Init solver") solver = RandomWalkConstraintSolver() diff --git a/hannah/nas/test/test_random_walk_constrainer.py b/hannah/nas/test/test_random_walk_constrainer.py index a4ecc619..a78af75c 100644 --- a/hannah/nas/test/test_random_walk_constrainer.py +++ b/hannah/nas/test/test_random_walk_constrainer.py @@ -1,5 +1,6 @@ from hannah.nas.functional_operators.op import Tensor, scope, search_space -from hannah.nas.constraints.random_walk import get_active_parameter, RandomWalkConstraintSolver +from hannah.nas.functional_operators.utils.visit import get_active_parameters +from hannah.nas.constraints.random_walk import RandomWalkConstraintSolver from hannah.models.embedded_vision_net.operators import adaptive_avg_pooling, add, conv_relu, dynamic_depth, linear from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter @@ -44,17 +45,17 @@ def space(input): def test_get_active_params(): input = Tensor(name='input', shape=(1, 3, 32, 32), axis=('N', 'C', 'H', 'W')) out = space(input) - active_params = get_active_parameter(out) + active_params = list(get_active_parameters(out).keys()) assert len(active_params) == 7 for p in active_params: assert "parallel_blocks_1" not in p and "parallel_blocks_2" not in p out.parametrization()['space_0.ChoiceOp_0.depth'].set_current(1) - active_params = get_active_parameter(out) + active_params = get_active_parameters(out) assert len(active_params) == 10 for p in active_params: assert "parallel_blocks_2" not in p out.parametrization()['space_0.ChoiceOp_0.depth'].set_current(2) - active_params = get_active_parameter(out) + active_params = get_active_parameters(out) assert len(active_params) == 13 diff --git a/hannah/nas/test/test_searchspace_to_graph.py b/hannah/nas/test/test_searchspace_to_graph.py index 1935d701..94a31125 100644 --- a/hannah/nas/test/test_searchspace_to_graph.py +++ b/hannah/nas/test/test_searchspace_to_graph.py @@ -46,7 +46,7 @@ def test_model_conversion(model): model.sample() print("Init sampler") - sampler = RandomSampler(None, model.parametrization(flatten=True)) + sampler = RandomSampler(None, model, model.parametrization(flatten=True)) print("Init solver") solver = RandomWalkConstraintSolver()