Skip to content

Commit

Permalink
Merge pull request #113 from microsoft/dev/add-error-info
Browse files Browse the repository at this point in the history
Refine api and fix typo
  • Loading branch information
Lynazhang authored May 5, 2023
2 parents d0c9cef + d7bbd25 commit 8006ed6
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def load_config(self):
self.eps = self.config['EMP_ALPHA']

def test(self):
    """Check whether the latency gap between `false_case` and 'block' is significant.

    The gap `latency[false_case].avg - latency['block'].avg` is compared against
    `eps` times the `.avg` of the smallest latency among all remaining ops.

    Returns:
        bool: True when the gap exceeds the threshold.
    """
    # Exclude BOTH the fused 'block' and the false case; joining the two
    # inequalities with `or` would be always true and exclude nothing.
    secondary_op_lat = min(
        lat for op, lat in self.latency.items()
        if op not in ('block', self.false_case)
    )
    return self.latency[self.false_case].avg - self.latency['block'].avg > self.eps * secondary_op_lat.avg

def load_latency(self, testcase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"dwconv-bn-relu": ["DwConvBnRelu", "DwConvSampler"],
"dwconv-bn-relu6": ["DwConvBnRelu6", "DwConvSampler"],
"dwconv-block": ["DwConvBlock", "DwConvSampler"],
"dwconv-bn-hswish": ["ConvBnHswish", "DwConvSampler"],
"dwconv-bn-hswish": ["DwConvBnHswish", "DwConvSampler"],
# others
"maxpool": ["MaxPoolBlock", "PoolingSampler"],
"avgpool": ["AvgPoolBlock", "PoolingSampler"],
Expand All @@ -39,6 +39,7 @@
"bnrelu": ["BnRelu", "HwCinSampler"],
"bn": ["BnBlock", "HwCinSampler"],
"hswish": ["HswishBlock", "HwCinSampler"],
"swish": ["SwishBlock", "HwCinSampler"],
"relu": ["ReluBlock", "HwCinSampler"],
"addrelu": ["AddRelu", "HwCinSampler"],
"add": ["AddBlock", "HwCinSampler"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"conv-hswish": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"conv-block": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"conv-bn-hswish": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"conv-swish": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
# dwconv ("COUT" will always be the same as "CIN")
"dwconv-bn": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"dwconv-relu": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
Expand All @@ -28,6 +29,7 @@
"dwconv-bn-relu6": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"dwconv-block": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"dwconv-bn-hswish": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
"dwconv-swish": ["HW", "CIN", "COUT", "KERNEL_SIZE", "STRIDES"],
# pooling ("COUT" will always be the same as "CIN")
"maxpool": ["HW", "CIN", "COUT", "KERNEL_SIZE", "POOL_STRIDES"],
"avgpool": ["HW", "CIN", "COUT", "KERNEL_SIZE", "POOL_STRIDES"],
Expand All @@ -41,6 +43,7 @@
"bnrelu": ["HW", "CIN"],
"bn": ["HW", "CIN"],
"hswish": ["HW", "CIN"],
"swish": ["HW", "CIN"],
"relu": ["HW", "CIN"],
# In "addrelu" block and "add" block, the second feature "CIN" will always be the same as
# the third feature
Expand Down Expand Up @@ -166,7 +169,7 @@ def get_data_by_profiled_results(kernel_type, feature_parser, cfgs_path, labs_pa
paths, features, labs = [], [], []
for id in labs_dict.keys():
try:
path = cfgs_dict[id]["model"]
path = cfgs_dict[id]["converted_model"] if "converted_model" in cfgs_dict[id] else cfgs_dict[id]["model"]
configs = cfgs_dict[id]["config"]
feature = feature_parser.get_feature_by_config(configs)
if predict_label == "latency":
Expand Down
32 changes: 23 additions & 9 deletions nn_meter/builder/nn_meter_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import time
import signal
import logging
import subprocess
from . import builder_config
from .utils import save_profiled_results, merge_info, handle_timeout
from nn_meter.builder.backends import connect_backend
logging = logging.getLogger("nn-Meter")


def convert_models(backend, models, mode = 'predbuild', broken_point_mode = False):
def convert_models(backend, models, mode = 'predbuild', broken_point_mode = False, model_save_path = None):
""" convert the model to the needed format by backend, in order to increase efficiency when profiling on device.
@params:
Expand All @@ -24,6 +25,8 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
broken_point_mode (boolean): broken_point_mode will skip all models have attributes "converted_model"
model_save_path (str or None): path to save converted models, if not set, the converted model will be placed in
the same directory as the original model.
"""
if isinstance(models, str):
save_name = os.path.basename(models)
Expand All @@ -33,8 +36,6 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
save_name = "converted_results.json"

workspace_path = builder_config.get('WORKSPACE', mode)
model_save_path = os.path.join(workspace_path, 'testcases' if mode == 'ruletest' else 'kernels')
os.makedirs(model_save_path, exist_ok=True)
res_save_path = os.path.join(workspace_path, "results")
os.makedirs(res_save_path, exist_ok=True)

Expand All @@ -48,7 +49,7 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
continue
try:
model_path = model['model']
converted_model = backend.convert_model(model_path, model_save_path, input_shape=model['shapes'])
converted_model = backend.convert_model(model_path, model_save_path or os.path.dirname(model_path), input_shape=model['shapes'])
model['converted_model'] = converted_model
count += 1
except Exception as e:
Expand All @@ -69,8 +70,8 @@ def convert_models(backend, models, mode = 'predbuild', broken_point_mode = Fals
return models


def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], save_name = "profiled_results.json",
have_converted = False, log_frequency = 50, broken_point_mode = False, time_threshold = 300, **kwargs):
def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], save_name = "profiled_results.json", have_converted = False,
log_frequency = 50, broken_point_mode = False, time_threshold = 300, is_pixel6 = None, model_save_path = None, **kwargs):
""" run models with given backend and return latency of testcase models
@params:
Expand All @@ -95,6 +96,9 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
time_threshold (int): the time threshold for profiling one single model. If the total profiling time of a model is longer than
`time_threshold` (in seconds), nn-Meter will log a profiling timeout error for this model and move on to profile the next model.
model_save_path (str or None): path to save converted models, if not set, the converted model will be placed in
the same directory as the original model.
**kwargs: arguments for profiler, such as `taskset` and `close_xnnpack` in TFLite profiler
"""
signal.signal(signal.SIGALRM, handle_timeout)
Expand All @@ -103,8 +107,6 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
models = json.load(fp)

workspace_path = builder_config.get('WORKSPACE', mode)
model_save_path = os.path.join(workspace_path, 'testcases' if mode == 'ruletest' else 'kernels')
os.makedirs(model_save_path, exist_ok=True)
res_save_path = os.path.join(workspace_path, "results")
os.makedirs(res_save_path, exist_ok=True)
info_save_path = os.path.join(res_save_path, save_name)
Expand Down Expand Up @@ -148,7 +150,7 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa
try:
model_path = model['model']
signal.alarm(time_threshold)
profiled_res = backend.profile_model_file(model_path, model_save_path, input_shape=model['shapes'], metrics=metrics, **kwargs)
profiled_res = backend.profile_model_file(model_path, model_save_path or os.path.dirname(model_path), input_shape=model['shapes'], metrics=metrics, **kwargs)
signal.alarm(0)
for metric in metrics:
model[metric] = profiled_res[metric]
Expand All @@ -159,8 +161,20 @@ def profile_models(backend, models, mode = 'ruletest', metrics = ["latency"], sa

# save information to json file for per 50 models
if count > 0 and count % log_frequency == 0:
freq = None
save_profiled_results(models, info_save_path, detail, metrics)
logging.keyinfo(f"{count} models complete. Still profiling... Save the intermediate results to {info_save_path} ")
if is_pixel6 != None:
freq = subprocess.check_output(
["adb", "-s", "1B261FDF6009KS", "shell", "cat", "/sys/devices/system/cpu/cpu6/cpufreq/scaling_cur_freq"])
# import pdb; pdb.set_trace()
loop = 0
while freq != is_pixel6 and loop < 100:
time.sleep(2)
freq = subprocess.check_output(
["adb", "-s", "1B261FDF6009KS", "shell", "cat", "/sys/devices/system/cpu/cpu6/cpufreq/scaling_cur_freq"])
loop += 1
print(f"[freq: {freq}] loop {loop}")

# save information to json file
save_profiled_results(models, info_save_path, detail, metrics)
Expand Down
21 changes: 21 additions & 0 deletions nn_meter/builder/nn_modules/tf_networks/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,27 @@ def call(self, inputs):
return model


class SwishBlock(TFBlock):
    """TF block made of a single Swish activation operator."""

    def __init__(self, config, batch_size = 1):
        super().__init__(config, batch_size)

        # Keep only the callable produced by the operator wrapper.
        self.swish_op = Swish(self.input_shape, config).get_model()

    def get_model(self):
        """Wrap the Swish op in a `keras.Model`, call it once on generated inputs, and return it."""
        class Model(keras.Model):
            def __init__(self, activation):
                super().__init__()
                self.swish = activation

            def call(self, inputs):
                return self.swish(inputs)

        model = Model(self.swish_op)
        # One forward call on inputs generated from the block's tensor shapes.
        sample_inputs = get_inputs_by_shapes(self.input_tensor_shape, self.batch_size)
        model(sample_inputs)
        return model


class ReluBlock(TFBlock):
def __init__(self, config, batch_size = 1):
super().__init__(config, batch_size)
Expand Down
7 changes: 7 additions & 0 deletions nn_meter/builder/nn_modules/tf_networks/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def func(inputs):
return inputs * relu6(inputs + 3.) * (1. / 6.)
return func


class Swish(BaseOperator):
    """Operator wrapper around the Keras Swish activation."""

    def get_model(self):
        """Return a callable that applies `tf.keras.activations.swish` to its input."""
        def func(inputs):
            activated = tf.keras.activations.swish(inputs)
            return activated
        return func

#---------------------- basic operation ----------------------#

class Reshape(BaseOperator):
Expand Down
57 changes: 56 additions & 1 deletion nn_meter/builder/nn_modules/torch_networks/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,39 @@ def get_model(self):
return self.build_model([self.conv_op, self.bn_op, self.relu_op])



class ConvSwish(TorchBlock):
    """Convolution followed by BN and a Swish activation.

    NOTE: despite the class name, a BN operator sits between the convolution
    and the Swish activation (see `get_model`).
    """

    def __init__(self, config, batch_size = 1):
        # `batch_size` defaults to 1, consistent with the other blocks in this module.
        super().__init__(config, batch_size)
        conv_op = Conv(self.input_shape, config)
        self.conv_op, out_shape = conv_op.get_model(), conv_op.get_output_shape()

        # Each operator is constructed with the output shape of the previous one.
        bn_op = BN(out_shape, config)
        self.bn_op, out_shape = bn_op.get_model(), bn_op.get_output_shape()

        swish_op = Swish(out_shape, config)
        self.swish_op = swish_op.get_model()

    def get_model(self):
        """Return the result of `build_model` applied to [conv, bn, swish] in that order."""
        return self.build_model([self.conv_op, self.bn_op, self.swish_op])


class DWConvSwish(TorchBlock):
    """Depthwise convolution followed by BN and a Swish activation.

    The block derives from `TorchBlock` (not `nn.Module`): it forwards
    `config` and `batch_size` to the base constructor and relies on
    `self.input_shape` and `self.build_model`, exactly like the sibling
    blocks in this module.
    """

    def __init__(self, config, batch_size = 1):
        # `batch_size` defaults to 1, consistent with the other blocks in this module.
        super().__init__(config, batch_size)
        dwconv_op = DwConv(self.input_shape, config)
        self.dwconv_op, out_shape = dwconv_op.get_model(), dwconv_op.get_output_shape()

        # Each operator is constructed with the output shape of the previous one.
        bn_op = BN(out_shape, config)
        self.bn_op, out_shape = bn_op.get_model(), bn_op.get_output_shape()

        swish_op = Swish(out_shape, config)
        self.swish_op = swish_op.get_model()

    def get_model(self):
        """Return the result of `build_model` applied to [dwconv, bn, swish] in that order."""
        return self.build_model([self.dwconv_op, self.bn_op, self.swish_op])


class ConvBnRelu6(TorchBlock):
def __init__(self, config, batch_size = 1):
super().__init__(config, batch_size)
Expand Down Expand Up @@ -281,7 +314,7 @@ def get_model(self):
return self.build_model([self.dwconv_op])


class ConvBnHswish(TorchBlock):
class DwConvBnHswish(TorchBlock):
def __init__(self, config, batch_size = 1):
super().__init__(config, batch_size)

Expand Down Expand Up @@ -407,6 +440,17 @@ def get_model(self):
return self.build_model([self.se_op])


class SESwishBlock(TorchBlock):
    """Block wrapping a single `SE_swish` operator."""

    def __init__(self, config, batch_size = 1):
        super().__init__(config, batch_size)

        # Only the built module is stored; the operator wrapper is discarded.
        self.se_op = SE_swish(self.input_shape, config).get_model()

    def get_model(self):
        """Pass the single SE op to `build_model` and return its result."""
        ops = [self.se_op]
        return self.build_model(ops)


class GlobalAvgPoolBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
self.config = config
Expand Down Expand Up @@ -457,6 +501,17 @@ def get_model(self):
return self.build_model([self.hswish_op])


class SwishBlock(TorchBlock):
    """Block wrapping a single `Swish` activation operator."""

    def __init__(self, config, batch_size = 1):
        super().__init__(config, batch_size)

        # Only the built module is stored; the operator wrapper is discarded.
        self.swish_op = Swish(self.input_shape, config).get_model()

    def get_model(self):
        """Pass the single Swish op to `build_model` and return its result."""
        ops = [self.swish_op]
        return self.build_model(ops)


class ReluBlock(TorchBlock):
def __init__(self, config, batch_size = 1):
super().__init__(config, batch_size)
Expand Down
34 changes: 34 additions & 0 deletions nn_meter/builder/nn_modules/torch_networks/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ def forward(self, x):
return SE(self.input_shape[0])


class SE_swish(BaseOperator):
    """Operator building a squeeze-and-excite module gated by SiLU (Swish)."""

    def get_model(self):
        """Return an `nn.Module` constructed with `self.input_shape[0]` as its channel count."""
        class SE_swish(nn.Module):
            def __init__(self, num_channels, se_ratio=0.25):
                super().__init__()
                # Bottleneck width: a fraction of the channel count, truncated to int.
                mid_channels = int(num_channels * se_ratio)
                self.squeeze = nn.Conv2d(num_channels, mid_channels, kernel_size=1, padding=0)
                self.relu = nn.ReLU()
                self.excite = nn.Conv2d(mid_channels, num_channels, kernel_size=1, padding=0)
                self.swish = nn.SiLU()

            def _scale(self, x):
                # Average over dim 3 and then dim 2, keeping both dims for broadcasting.
                pooled = x.mean(3, keepdim=True).mean(2, keepdim=True)
                squeezed = self.relu(self.squeeze(pooled))
                return self.swish(self.excite(squeezed))

            def forward(self, x):
                # Rescale the input by the computed gate.
                return self._scale(x) * x

        num_channels = self.input_shape[0]
        return SE_swish(num_channels)


class FC(BaseOperator):
def get_model(self):
cin = self.input_shape[-1]
Expand Down Expand Up @@ -157,6 +182,15 @@ class Hswish(BaseOperator):
def get_model(self):
return nn.Hardswish()


class Swish(BaseOperator):
    """Operator producing a Swish activation module, i.e. x * sigmoid(x)."""

    def get_model(self):
        """Return a fresh `nn.Module` whose forward computes `x * torch.sigmoid(x)`."""
        class Swish(nn.Module):
            def forward(self, x):
                gate = torch.sigmoid(x)
                return x * gate

        module = Swish()
        return module


#---------------------- basic operation ----------------------#

class Reshape(BaseOperator):
Expand Down
4 changes: 2 additions & 2 deletions nn_meter/builder/nn_modules/torch_networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def get_inputs_by_shapes(shapes, batch_size = 1):
def get_padding(ks, s, hw):
""" choose padding value to make sure:
if s = 1, out_hw = in_hw;
if s = 2, out_hw = in_hw // 2;
if s = 4, out_hw = in_hw // 4;
if s = 2, out_hw = ceil(in_hw / 2);
if s = 4, out_hw = ceil(in_hw / 4);
"""
if hw % s == 0:
pad = max(ks - s, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def get_padding_shape(input_shape, cout, k_size, strides, padding):
(ph.get_h(input_shape) - ph.get_h(k_size) + 1) / ph.get_h(strides)
)
outw = math.ceil(
(ph.get_h(input_shape) - ph.get_h(k_size) + 1) / ph.get_w(strides)
(ph.get_w(input_shape) - ph.get_w(k_size) + 1) / ph.get_w(strides)
)

pad_size = [0, 0, 0, 0]
Expand Down
4 changes: 0 additions & 4 deletions tests/integration_test/test_latency_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ def integration_test(model_type, url, ppath, output_name = "tests/integration_te
# logging.error(f"Meets ERROR when checking --{model_type} {ppath} --predictor {pred_name} --predictor-version {pred_version}")

latency_list = parse_latency_info(result.decode('utf-8'))
print(model_type)
print(latency_list)
print("-----")
os.system("cat tests/integration_test/test_result.txt")
for model, latency in latency_list:
item = f'{model}, {model_type}, {pred_name}, {pred_version}, {round(float(latency), 4)}\n'
# print(item)
Expand Down

0 comments on commit 8006ed6

Please sign in to comment.