Skip to content

Commit

Permalink
Merge branch 'fix/search_space_scopes' into 'main'
Browse files Browse the repository at this point in the history
Vastly improve search space creation speed with temporary global dict

See merge request es/ai/hannah/hannah!388
  • Loading branch information
moreib committed May 28, 2024
2 parents 827d8d1 + b502a24 commit a0c9757
Show file tree
Hide file tree
Showing 19 changed files with 1,116 additions and 327 deletions.
26 changes: 20 additions & 6 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@

{
"name": "AgingEvolution",
"type": "python",
"type": "debugpy",
"request": "launch",
"module": "hannah.tools.train",
"justMyCode": false,
Expand All @@ -162,12 +162,26 @@
"trainer.overfit_batches=1",
"nas.n_jobs=1",
"nas.budget=100",
"nas.total_candidates=50",
"nas.num_selected_candidates=20",
"nas.sampler.population_size=20",
// "nas.predictor.model.input_feature_size=31",
"nas.total_candidates=5",
"nas.num_selected_candidates=2",
"nas.sampler.population_size=2",
// "nas.predictor.model.input_feature_size=35",
"module.num_workers=8",
"experiment_id=test_constraint_config",
"experiment_id=test_merge",
"fx_mac_summary=True",
// "~nas.predictor",
"~normalizer"
]
},
{
"name": "ChainedLinear",
"type": "python",
"request": "launch",
"module": "hannah.tools.train",
"justMyCode": false,
"cwd": "${workspaceFolder}/experiments/kws",
"args": [
"+experiment=ae_nas",
"fx_mac_summary=True",
// "~nas.predictor",
"~normalizer"
Expand Down
385 changes: 223 additions & 162 deletions doc/nas/search_spaces.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion hannah/conf/model/embedded_vision_net.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: hannah.models.embedded_vision_net.models.search_space
_target_: hannah.models.embedded_vision_net.models.embedded_vision_net
name: embedded_vision_net
num_classes: 10
max_channels: 256
Expand Down
5 changes: 3 additions & 2 deletions hannah/models/embedded_vision_net/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from hannah.nas.expressions.types import Int
from hannah.nas.expressions.utils import extract_parameter_from_expression
from hannah.nas.functional_operators.executor import BasicExecutor
from hannah.nas.functional_operators.op import Tensor, get_nodes, scope
from hannah.nas.functional_operators.op import Tensor, get_nodes, scope, search_space
from hannah.nas.functional_operators.operators import Conv2d

# from hannah.nas.functional_operators.visualizer import Visualizer
Expand Down Expand Up @@ -106,7 +106,8 @@ def backbone(input, num_classes=10, max_channels=512, max_blocks=9):
return out


def search_space(
@search_space
def embedded_vision_net(
name,
input,
num_classes: int,
Expand Down
61 changes: 34 additions & 27 deletions hannah/nas/constraints/random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import numpy as np

from hannah.nas.functional_operators.lazy import lazy
from hannah.nas.functional_operators.op import ChoiceOp
from hannah.nas.parameters.parameters import Parameter
from hannah.nas.parameters.parametrize import set_parametrization
from hannah.nas.search.utils import np_to_primitive

Expand Down Expand Up @@ -70,32 +72,36 @@ def hierarchical_parameter_dict(parameter, include_empty=False, flatten=False):
}


def get_active_parameter(params):
    """Return a ``{parameter name: value}`` dict of the currently active parameters.

    The parameters are first arranged hierarchically with
    ``hierarchical_parameter_dict``. The function then walks the selected
    number of blocks and, per block, the selected number of patterns, and
    records only the entries that belong to the pattern's chosen option.
    """
    # FIXME: this needs to be generalized (key names below are hard-coded)
    # Substring of the key -> value of the pattern's choice that activates it.
    gated_keys = (("expand_reduce", 1), ("reduce_expand", 2), ("pooling", 3))

    tree = hierarchical_parameter_dict(params)
    active_params = {}

    num_blocks = tree["ChoiceOp_0"]["num_blocks"][""].value
    active_params["num_blocks"] = num_blocks + 1

    for block_idx in range(num_blocks + 1):
        block_params = tree[f"block_{block_idx}"]
        depth_param = block_params["ChoiceOp_0"]["depth"]
        depth = depth_param.value
        active_params[depth_param.name] = depth + 1

        for pattern_idx in range(depth + 1):
            pattern_params = block_params[f"pattern_{pattern_idx}"]
            choice = pattern_params["ChoiceOp_0.choice"].value

            for key, param in pattern_params.items():
                is_conv = key.split(".")[0] == "Conv2d_0"
                is_selected = any(
                    tag in key and choice == wanted for tag, wanted in gated_keys
                )
                if is_conv or is_selected:
                    active_params[param.name] = param.value
                elif "ChoiceOp" in key:
                    # Choice parameters are stored by their symbolic name.
                    active_params[param.name] = CHOICES[param.value]

    return active_params
def get_active_parameter(net):
    """Collect the ids of all parameters that are active for the current choices.

    Starting at ``net``, the graph is walked backwards. At a ``ChoiceOp`` only
    the switch parameter is recorded and only the currently selected option
    (``options[lazy(switch)]``) is followed. At every other node, the ids of
    all ``Parameter`` instances in ``node._PARAMETERS`` are recorded and all
    operands are followed.

    Args:
        net: Node of the graph at which the traversal starts.

    Returns:
        List of parameter ids in the order in which they were encountered.
    """
    active_param_ids = []
    pending = [net]
    # A set keeps the "already seen" test O(1); it runs once per graph edge and
    # this function is called repeatedly by the constraint solver.
    # Assumes node ids are hashable (they look like strings) — TODO confirm.
    visited = {net.id}

    def extract_parameters(node):
        """Return the ids of the ``Parameter`` objects attached to ``node``."""
        return [p.id for p in node._PARAMETERS.values() if isinstance(p, Parameter)]

    def schedule(node):
        """Queue ``node`` for a visit unless it has been queued before."""
        if node.id not in visited:
            visited.add(node.id)
            pending.append(node)

    while pending:
        current = pending.pop()
        if isinstance(current, ChoiceOp):
            # Only the selected branch contributes active parameters.
            active_param_ids.append(current.switch.id)
            schedule(current.options[lazy(current.switch)])
        else:
            # All other operators & tensors: own parameters plus every operand.
            active_param_ids.extend(extract_parameters(current))
            for operand in current.operands:
                schedule(operand)

    return active_param_ids


class RandomWalkConstraintSolver:
Expand Down Expand Up @@ -176,7 +182,8 @@ def solve(self, module, parameters, fix_vars=[]):

ct = 0
while ct < self.max_iterations:
active_params = get_active_parameter(params)
# active_params = get_active_parameter(params)
active_params = get_active_parameter(mod)

param_keys = [p for p in all_param_keys if p in active_params]
current = con.lhs.evaluate()
Expand Down
72 changes: 42 additions & 30 deletions hannah/nas/functional_operators/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,49 +75,55 @@ def get_unique_id():
return _id


# Module-wide counter backing get_unique_id(); holds the last id handed out.
_id = 0


def get_unique_id():
    """Advance the module-level ``_id`` counter by one and return the new value."""
    global _id
    next_id = _id + 1
    _id = next_id
    return next_id

def get_highest_scope_counter(scope, scope_dict):
    """Return the next counter for ``scope`` and record it in ``scope_dict``.

    The first request for a given scope name yields 0; every further request
    for the same name yields the previous value plus one. ``scope_dict`` is
    updated in place.
    """
    # A name not seen yet is treated as -1 so that it starts at 0.
    counter = scope_dict.get(scope, -1) + 1
    scope_dict[scope] = counter
    return counter

# FIXME: Traverses nodes too often -> massively increases time when building
# search spaces
def get_highest_scope_counter(start_nodes, scope):
    """Find the highest counter already used for ``scope`` below ``start_nodes``.

    For every node reachable via ``get_nodes`` the outermost id component
    (the text before the first ".") is split into ``<name>_<counter>``.
    The largest counter whose name equals ``scope`` is returned, or -1 if the
    scope name does not occur at all.
    """
    highest = -1
    for root in start_nodes:
        for node in get_nodes(root):
            outermost = node.id.split(".", 1)[0]
            # "<name>_<counter>": everything before the last "_" is the name.
            prefix, _, suffix = outermost.rpartition("_")
            if prefix == scope:
                highest = max(int(suffix), highest)
    return highest


# TODO: Make scopes accessible, e.g., as a list
def scope(function):
"""Decorator defining a scope in a search space.

The id of every subcomponent (operators or lower-hierarchy scopes) enclosed in a
function decorated with this is prefixed with ``<function name>_<counter>.``,
creating a hierarchical scope.
"""
# NOTE(review): this span is rendered from a diff without +/- markers, so the
# removed and the added variant of several statements appear side by side
# (two calls to `function`, two calls to `get_highest_scope_counter` with
# different signatures) — confirm against the repository which lines are live.
@wraps(function)
def set_scope(*args, **kwargs):
out = function(*args, **kwargs)
name = function.__name__
# The scope counters live in the module-level `global_scope_stack`, which is
# only created while a function decorated with @search_space is running.
assert "global_scope_stack" in globals(), "No scope tracking found, did you wrap the search space with @search_space?"

# Op/Tensor arguments (positional and keyword) passed into the scope.
inputs = [a for a in args if isinstance(a, (Op, Tensor))] + [
a for k, a in kwargs.items() if isinstance(a, (Op, Tensor))
]
ct = get_highest_scope_counter(inputs, name) + 1
name = function.__name__
global global_scope_stack
# Counter for this scope name, taken from the enclosing scope's counter dict.
ct = get_highest_scope_counter(name, global_scope_stack[-1])
# Fresh counter dict for the names used inside this scope.
global_scope_stack.append({})
out = function(*args, **kwargs)
# Prefix the nodes reported by nodes_in_scope(out, inputs) with the scope name.
for n in nodes_in_scope(out, inputs):
n.setid(f"{name}_{ct}.{n.id}")
# n.id = f"{name}_{ct}.{n.id}"
# print(n.id)
# for k, p in n._PARAMETERS.items():
# if isinstance(p, Expression):
# p.id = f"{name}.{k}"
# Leave this scope: drop its counter dict again.
global_scope_stack.pop()
return out

return set_scope


def search_space(function):
    """Decorator to define a search space.

    For correct scoping, a search space containing functional ops must be
    enclosed by a function decorated with @search_space. While the decorated
    function runs, a module-level ``global_scope_stack`` (a list of per-scope
    counter dicts) exists; it is removed again when the call finishes, whether
    the function returns normally or raises.

    NOTE(review): calling one @search_space function from inside another is
    not supported — the inner call replaces and then removes the stack that
    the outer call is still using.
    """

    @wraps(function)
    def search_space_limits(*args, **kwargs):
        global global_scope_stack
        global_scope_stack = [{}]
        try:
            # The search space function itself forms the outermost scope.
            return scope(function)(*args, **kwargs)
        finally:
            # Always remove the tracking state. Otherwise a failed build would
            # leave a stale stack behind and later @scope calls outside any
            # search space would wrongly pass their "scope tracking" check.
            # pop() instead of `del` so cleanup never masks the real error.
            globals().pop("global_scope_stack", None)

    return search_space_limits


@parametrize
class Op:
def __init__(self, name, *args, **kwargs) -> None:
Expand All @@ -133,10 +139,14 @@ def __call__(self, *operands) -> Any:
new_op = self # FIXME: Just use self?
for operand in operands:
operand.connect(new_op)
ct = get_highest_scope_counter(operands, self.name) + 1

# Some Ops (ChoiceOp) can be called multiple times and already have a counter
if not len(self.id.split(".")[-1].split("_")) > 1:
assert "global_scope_stack" in globals(), "No scope tracking found, did you wrap the search space with @search_space?"
global global_scope_stack
ct = get_highest_scope_counter(self.name, global_scope_stack[-1])
self.id = f"{self.id}_{ct}"
# self.setid(f"{self.id}_{ct}")
return new_op

def connect(self, node):
Expand Down Expand Up @@ -318,7 +328,9 @@ def _connect_options(self, *operands):
self.options[i] = node_opt(*operands)
if is_parametrized(self.options[i]):
self._PARAMETERS[self.options[i].id] = self.options[i] # FIXME:
ct = get_highest_scope_counter(operands, self.name) + 1
assert "global_scope_stack" in globals(), "No scope tracking found, did you wrap the search space with @search_space?"
global global_scope_stack
ct = get_highest_scope_counter(self.name, global_scope_stack[-1])
self.id = f"{self.id}_{ct}"

def shape_fun(self):
Expand Down
1 change: 0 additions & 1 deletion hannah/nas/performance_prediction/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def predict(
...


@runtime_checkable
class FitablePredictor(Predictor):
def load(self, result_folder: str):
"""Load predefined model from a folder.
Expand Down
3 changes: 2 additions & 1 deletion hannah/nas/test/test_functional_executor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from hannah.nas.functional_operators.executor import BasicExecutor
from hannah.nas.functional_operators.lazy import lazy
from hannah.nas.functional_operators.operators import Conv2d, Linear, Relu
from hannah.nas.functional_operators.op import Tensor
from hannah.nas.functional_operators.op import Tensor, search_space
from hannah.nas.parameters.parameters import CategoricalParameter, IntScalarParameter

from torch.optim import SGD
Expand Down Expand Up @@ -42,6 +42,7 @@ def linear(input, out_features):
return out


@search_space
def network(input):
out = conv_relu(input,
out_channels=IntScalarParameter(32, 64, name='out_channels'),
Expand Down
Loading

0 comments on commit a0c9757

Please sign in to comment.