Skip to content

Commit

Permalink
[GraphQnt] some cleanup and renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
maltanar committed Feb 5, 2024
1 parent 2feab84 commit 3e132fe
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,32 +180,30 @@ def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, n
return model


class IntroduceQuantnode(Transformation):
"""This transformation can be used to introduce a Quant node for a specific type of node in the graph.
Users would be able to specify the location of the quant node by providing the input and output indexs
as the parameters.
class QuantizeGraph(Transformation):
"""This transformation can be used to introduce a Quant node for particular nodes in the graph,
determined based on either op_type or node name.
For the particular nodes identified, users can specify the location of the Quant nodes by providing
the input and output indices where Quant nodes are to be inserted.
Assumes the input model is cleaned-up with all intermediate shapes specified and nodes given
unique names already.
1) Expectations:
a) Onnx model in the modelwraper format.
b) Model must be cleaned using cleanup_model qonnx.util.cleanup.cleanup_model()
c) Batchsize to be set.
2) Steps to transform are
Step1: Finding the input for the quant node.
Step2: Finding the consumer of the quant node output.
Step3: Finding the shape for the output tensor of quant node.
Note: The output tensor of the quant node must have the same shape as the
consumer of the input to the quant node.
2) Steps to transform are
Step1: Finding the input for the quant node.
Step2: Finding the consumer of the quant node output.
Step3: Finding the shape for the output tensor of quant node.
Note: The output tensor of the quant node must have the same shape as the
consumer of the input to the quant node.
3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type".
3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type".
4) Assert:
a) The input is a dictionary representing the node names as keys and a list of quant positions
as values.
b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation.
4) Assert:
a) The input is a dictionary representing the node names as keys and a list of quant positions
as values.
b) The input dictionary must have atleast one mac node (Conv, gemm, matmul) for the transformation.
5) Return:
Returns a cleaned version of the model.
5) Return:
Returns a cleaned version of the model.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@
import urllib.request

from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.introduce_quantnode import IntroduceQuantnode, graph_util
from qonnx.transformation.quantize_graph import QuantizeGraph, graph_util
from qonnx.util.cleanup import cleanup
from qonnx.util.inference_cost import inference_cost

random.seed(42)

graph_util = graph_util()

a = "https://github.com/onnx/models/raw/main/validated/vision/"
b = "classification/resnet/model/resnet18-v1-7.onnx?download="
download_url = "https://github.com/onnx/models/raw/main/validated/vision/"
download_url += "classification/resnet/model/resnet18-v1-7.onnx?download="

model_details = {
"resnet18-v1-7": {
"description": "Resnet18 Opset version 7.",
"url": (a + b),
"url": download_url,
"test_input": {
"name": {
"Conv_0": [
Expand Down Expand Up @@ -124,7 +124,7 @@ def test_introduce_quantnode(test_model):
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
original_model_inf_cost = inference_cost(model, discount_sparsity=False)
nodes_pos = test_details["test_input"]
model = model.transform(IntroduceQuantnode(nodes_pos))
model = model.transform(QuantizeGraph(nodes_pos))
quantnodes_added = len(model.get_nodes_by_op_type("Quant"))
assert quantnodes_added == 10 # 10 positions are specified.
verification = to_verify(model, nodes_pos)
Expand Down

0 comments on commit 3e132fe

Please sign in to comment.