inference cost breakdown
Harish committed Feb 15, 2024
1 parent 39442cb commit 7608e7c
Showing 3 changed files with 160 additions and 29 deletions.
25 changes: 18 additions & 7 deletions src/qonnx/analysis/inference_cost.py
@@ -201,10 +201,10 @@ def inference_cost_upsample(model, node, discount_sparsity):
return ret


def inference_cost(model, discount_sparsity=True):
def inference_cost(model, discount_sparsity=True, cost_breakdown=False):
"Ensure all nodes have unique names prior to calling this analysis pass."

node_costs = {}
ret, node_costs, nodes_per_optype = {}, {}, {}
zero_cost_ops = [
"MaxPool",
"AveragePool",
@@ -240,13 +240,24 @@ def inference_cost(model, discount_sparsity=True):
if node.op_type in inference_cost_fxn_map.keys():
node_cost = inference_cost_fxn_map[node.op_type](model, node, discount_sparsity)
node_costs[node.name] = node_cost
if node.op_type not in nodes_per_optype.keys():
new_optype = {}
new_optype[node.name] = node_cost
nodes_per_optype[node.op_type] = new_optype
else:
nodes_per_optype[node.op_type][node.name] = node_cost
elif node.op_type in zero_cost_ops:
continue
else:
unsupported_ops.add(node.op_type)

ret = aggregate_dict_keys(node_costs)
ret["unsupported"] = unsupported_ops
ret["discount_sparsity"] = discount_sparsity

total = aggregate_dict_keys(node_costs)
total["unsupported"] = unsupported_ops
total["discount_sparsity"] = discount_sparsity
ret["total_cost"] = total
if cost_breakdown:
optype_cost = {}
for optype, resources in nodes_per_optype.items():
optype_cost[optype] = aggregate_dict_keys(resources)
ret["optype_cost"] = optype_cost
ret["node_cost"] = node_costs
return ret
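For context, a minimal sketch of how the extended analysis pass could be invoked directly; the model path is a placeholder, and "optype_cost" / "node_cost" only appear in the result when cost_breakdown=True:

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.analysis.inference_cost import inference_cost

# "model_clean.onnx" is a placeholder for any cleaned-up QONNX/ONNX model
model = ModelWrapper("model_clean.onnx")
costs = model.analysis(lambda m: inference_cost(m, discount_sparsity=True, cost_breakdown=True))
print(costs["total_cost"])    # aggregate costs plus "unsupported" and "discount_sparsity"
print(costs["optype_cost"])   # per-op-type aggregates (breakdown only)
print(costs["node_cost"])     # per-node costs (breakdown only)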
76 changes: 54 additions & 22 deletions src/qonnx/util/inference_cost.py
@@ -71,7 +71,13 @@ def compute_mem_bits_and_elems(inf_cost_dict, filter_string="mem_w"):


def inference_cost(
model_filename_or_wrapper, *, output_json=None, output_onnx=None, preprocess=True, discount_sparsity=True
model_filename_or_wrapper,
*,
output_json=None,
output_onnx=None,
preprocess=True,
discount_sparsity=True,
cost_breakdown=False
):
"""Return the inference cost estimate metric for given ONNX model.
Supports the Quant op for weight/activation quantization.
@@ -83,8 +89,8 @@ def inference_cost(
:param preprocess: If set, run preprocessing steps such as shape inference,
datatype inference and constant folding. Strongly recommended.
:param discount_sparsity: If set, will discount op cost of MAC ops with a
constant zero weight, and the mem cost of constant zero weights.
"""
constant zero weight, and the mem cost of constant zero weights."""
combined_results = {}
if isinstance(model_filename_or_wrapper, ModelWrapper):
model = model_filename_or_wrapper
else:
@@ -104,25 +110,51 @@
model = model.transform(GiveReadableTensorNames())
if output_onnx is not None:
model.save(output_onnx)
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity))
bops, macs = compute_bops_and_macs(ret)
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(ret, "mem_w")
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(ret, "mem_o")
ret["total_bops"] = bops
ret["total_macs"] = macs
ret["total_mem_w_bits"] = mem_w_bits
ret["total_mem_w_elems"] = mem_w_elems
ret["total_mem_o_bits"] = mem_o_bits
ret["total_mem_o_elems"] = mem_o_elems

if "unsupported" in ret:
ret["unsupported"] = str(ret["unsupported"])

if output_json is not None:
with open(output_json, "w") as f:
json.dump(ret, f, sort_keys=True, indent=2)

return ret
ret = model.analysis(lambda x: infca.inference_cost(x, discount_sparsity, cost_breakdown))
for i, res in ret.items():
if i == "total_cost":
bops, macs = compute_bops_and_macs(res)
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(res, "mem_w")
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(res, "mem_o")
res["total_bops"] = bops
res["total_macs"] = macs
res["total_mem_w_bits"] = mem_w_bits
res["total_mem_w_elems"] = mem_w_elems
res["total_mem_o_bits"] = mem_o_bits
res["total_mem_o_elems"] = mem_o_elems
if "unsupported" in res:
res["unsupported"] = str(res["unsupported"])
if output_json is not None:
with open(output_json, "w") as f:
json.dump(res, f, sort_keys=True, indent=2)
combined_results[i] = res
elif i == "optype_cost":
per_optype_breakdown = {}
for optype, op_res in res.items():
bops, macs = compute_bops_and_macs(op_res)
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(op_res, "mem_w")
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(op_res, "mem_o")
op_res["total_bops"] = bops
op_res["total_macs"] = macs
op_res["total_mem_w_bits"] = mem_w_bits
op_res["total_mem_w_elems"] = mem_w_elems
op_res["total_mem_o_bits"] = mem_o_bits
op_res["total_mem_o_elems"] = mem_o_elems
per_optype_breakdown[optype] = op_res
combined_results[i] = per_optype_breakdown
else:
per_node_breakdown = {}
for node_name in res.keys():
node_cost = res[node_name]
mem_w_bits, mem_w_elems = compute_mem_bits_and_elems(node_cost, "mem_w")
mem_o_bits, mem_o_elems = compute_mem_bits_and_elems(node_cost, "mem_o")
node_cost["total_mem_w_bits"] = mem_w_bits
node_cost["total_mem_w_elems"] = mem_w_elems
node_cost["total_mem_o_bits"] = mem_o_bits
node_cost["total_mem_o_elems"] = mem_o_elems
per_node_breakdown[node_name] = node_cost
combined_results[i] = per_node_breakdown
return combined_results


def main():
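A usage sketch of the wrapper-level API with the new flag, assuming placeholder file names:

from qonnx.util.inference_cost import inference_cost

# "model.onnx" and "cost.json" are placeholder paths
res = inference_cost(
    "model.onnx",
    output_json="cost.json",      # only the "total_cost" entry is dumped to JSON
    preprocess=True,
    discount_sparsity=True,
    cost_breakdown=True,
)
total = res["total_cost"]         # includes total_bops, total_macs and total_mem_* fields
per_optype = res["optype_cost"]   # the same totals aggregated per op type
per_node = res["node_cost"]       # per-node costs, with total_mem_* fields only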
88 changes: 88 additions & 0 deletions tests/analysis/test_inference_cost_breakdown.py
@@ -0,0 +1,88 @@
# Copyright (c) 2024 Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of qonnx nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

import os
import urllib.request

from qonnx.analysis.inference_cost import aggregate_dict_keys
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup
from qonnx.util.inference_cost import inference_cost as infca

download_url = "https://github.com/onnx/models/raw/main/validated/vision/"
download_url += "classification/resnet/model/resnet18-v1-7.onnx?download="

model_details = {
"resnet18-v1-7": {
"description": "Resnet18 Opset version 7.",
"url": download_url,
"enc": {
"a": "op_mac_FLOAT32_FLOAT32",
"b": "total_mem_w_bits",
"c": "total_mem_w_elems",
"d": "total_mem_o_bits",
"e": "total_mem_o_elems",
},
},
}


def download_model(test_model, do_cleanup=False, return_modelwrapper=False):
qonnx_url = model_details[test_model]["url"]
# download test data
dl_dir = "/tmp"
dl_file = dl_dir + f"/{test_model}.onnx"
ret = dl_file
if not os.path.isfile(dl_file):
urllib.request.urlretrieve(qonnx_url, dl_file)
if do_cleanup:
out_file = dl_dir + f"/{test_model}_clean.onnx"
cleanup(dl_file, out_file=out_file, override_inpsize=1)
ret = out_file
if return_modelwrapper:
ret = ModelWrapper(ret)
return ret


@pytest.mark.parametrize("test_model", model_details.keys())
def test_inference_cost_breakdown(test_model):
test_details = model_details[test_model]
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
inf_cost = infca(model, discount_sparsity=False, cost_breakdown=True)
print(inf_cost.keys())
t_cost = inf_cost["total_cost"] # total cost
op_cost = aggregate_dict_keys(inf_cost["optype_cost"]) # cost per optype
n_cost = aggregate_dict_keys(inf_cost["node_cost"]) # cost per node.
enc = test_details["enc"]
assert t_cost[enc["a"]] == op_cost[enc["a"]] == n_cost[enc["a"]], "inf discrepancy"
assert t_cost[enc["b"]] == op_cost[enc["b"]] == n_cost[enc["b"]], "inf discrepancy"
assert t_cost[enc["c"]] == op_cost[enc["c"]] == n_cost[enc["c"]], "inf discrepancy"
assert t_cost[enc["d"]] == op_cost[enc["d"]] == n_cost[enc["d"]], "inf discrepancy"
assert t_cost[enc["e"]] == op_cost[enc["e"]] == n_cost[enc["e"]], "inf discrepancy"
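The assertions above rely on aggregate_dict_keys reducing the per-node and per-op-type breakdowns to the same graph-level totals; a toy illustration of that behaviour, assuming it sums like-named numeric keys across the nested dicts:

from qonnx.analysis.inference_cost import aggregate_dict_keys

# hypothetical per-node cost dicts, not taken from the real model
fake_node_costs = {
    "Conv_0": {"mem_w_FLOAT32": 100.0, "mem_o_FLOAT32": 10.0},
    "Conv_1": {"mem_w_FLOAT32": 200.0, "mem_o_FLOAT32": 30.0},
}
# expected under the stated assumption: {"mem_w_FLOAT32": 300.0, "mem_o_FLOAT32": 40.0}
print(aggregate_dict_keys(fake_node_costs))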
