fix onnx_model
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 committed Jul 19, 2024
1 parent ccc12b0 commit 902b4fd
Showing 2 changed files with 54 additions and 46 deletions.
50 changes: 22 additions & 28 deletions onnx_neural_compressor/algorithms/layer_wise/core.py
@@ -99,7 +99,7 @@ def layer_wise_quant(
split_model = model_to_split.pop(0)
split_node = split_nodes.pop(0)
if require_data_reader:
current_data_reader = lwq_data_reader.pop(0)
complete_data_reader = lwq_data_reader.pop(0)

# if no remaining split nodes, it means this is the last split, and the two split models will be saved.
save_both_split_models = True if len(split_nodes) == 0 else False
@@ -114,17 +114,19 @@ def layer_wise_quant(
model_to_split.append(split_model_part_2)

logger.info("Quantize split model {}".format(split_idx))

if require_data_reader:
# process data_reader for current split and next split

current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_1.model, current_data_reader, data_reader
split_model_part_1.model, complete_data_reader
)

# next_data_reader contains split_model_part_1 output data
next_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path, current_data_reader, providers
complete_data_reader = _prepare_data_reader_for_next_split_model(
split_model_part_1.model_path, [i.name for i in split_model_part_2.model.graph.input], complete_data_reader, providers
)
lwq_data_reader.append(next_data_reader)

lwq_data_reader.append(complete_data_reader)

# perform quantization
split_model_part_1_quantized = quant_func(
@@ -142,7 +144,7 @@ def layer_wise_quant(

# check split model is valid
try:
ort.InferenceSession(split_model_part_1_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_1_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
@@ -167,7 +169,7 @@ def layer_wise_quant(
# process data_reader for current split
current_data_reader = lwq_data_reader.pop(0)
current_data_reader = _filter_data_reader_for_current_split_model(
split_model_part_2.model, current_data_reader, data_reader
split_model_part_2.model, complete_data_reader
)

# perform quantization
@@ -186,7 +188,7 @@ def layer_wise_quant(

# check split model is valid
try:
ort.InferenceSession(split_model_part_2_quantized.model.SerializeToString(), providers=providers)
ort.InferenceSession(split_model_part_2_quantized.model_path, providers=providers)
except Exception as e:
logger.error(
"Layer-wise quantized model {} can't be inferred correctly. "
@@ -225,22 +227,19 @@ def rewind(self):
def _filter_data_reader_for_current_split_model(
model: onnx.ModelProto,
current_data_reader: data_reader.CalibrationDataReader,
data_reader: data_reader.CalibrationDataReader,
):
"""Filter data reader to remove data that is not in model input.
Args:
model (onnx.ModelProto): onnx model.
current_data_reader (data_reader.CalibrationDataReader): data reader of current split model.
data_reader (data_reader.CalibrationDataReader): data reader of the original model.
Returns:
data_reader.CalibrationDataReader: filtered data reader.
"""
filter_inputs = []
input_names = [input.name for input in model.graph.input]
current_data_reader.rewind()
data_reader.rewind()

while True:
inputs = current_data_reader.get_next()
@@ -251,22 +250,12 @@ def _filter_data_reader_for_current_split_model(
}
filter_inputs.append(filter_input)

idx = 0
while True:
inputs = data_reader.get_next()
if not inputs:
break
filter_input = {
input_name: input_tensor for input_name, input_tensor in inputs.items() if input_name in input_names
}
if len(filter_input) > 0:
filter_inputs[idx].update(filter_input)
idx += 1
return DataReader(filter_inputs)


def _prepare_data_reader_for_next_split_model(
model_path: str,
next_model_input_names: list,
data_reader: data_reader.CalibrationDataReader,
providers: List[str] = ["CPUExecutionProvider"],
):
@@ -282,16 +271,21 @@ def _prepare_data_reader_for_next_split_model(
Returns:
data_reader.CalibrationDataReader: data reader for next split model.
"""
data_reader = copy.deepcopy(data_reader)

data_reader.rewind()
data_reader_for_next_split_model = []
session = ort.InferenceSession(model_path, providers=providers)
output_names = [output.name for output in session.get_outputs()]
input_names = [input.name for input in session.get_inputs()]
while True:
inputs = data_reader.get_next()
if not inputs:
break
out = session.run(None, inputs)
inputs.update({name: value for name, value in zip(output_names, out)})
data_reader_for_next_split_model.append(inputs)
out = session.run(None, {name: inputs[name] for name in input_names})
filter_input = {
name: value for name, value in zip(output_names, out)
}
for name, value in inputs.items():
if name in next_model_input_names and name not in filter_input:
filter_input[name] = value
data_reader_for_next_split_model.append(filter_input)
return DataReader(data_reader_for_next_split_model)
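
The revised flow above threads a single complete_data_reader through every split: it is filtered down to the inputs the current split actually declares, and the reader for the next split is rebuilt from the first split's outputs plus any original inputs the second split still consumes directly. The sketch below illustrates that hand-off for one calibration batch; it is a minimal illustration only, and build_next_split_feed is a hypothetical helper name rather than part of the library.

import onnxruntime as ort

def build_next_split_feed(split1_model_path, next_input_names, batch,
                          providers=("CPUExecutionProvider",)):
    """Run split 1 on one calibration batch and assemble the feed for split 2.

    Mirrors the idea of _prepare_data_reader_for_next_split_model: the next
    split is fed with split 1's outputs, plus any original inputs it still
    needs that split 1 does not produce.
    """
    session = ort.InferenceSession(split1_model_path, providers=list(providers))
    input_names = [i.name for i in session.get_inputs()]
    output_names = [o.name for o in session.get_outputs()]

    # Feed split 1 only with the inputs it actually declares.
    outputs = session.run(None, {name: batch[name] for name in input_names})
    feed = dict(zip(output_names, outputs))

    # Carry over original inputs that bypass split 1 entirely but are still
    # listed among the next split's graph inputs.
    for name, value in batch.items():
        if name in next_input_names and name not in feed:
            feed[name] = value
    return feed

Compared with the previous version, which deep-copied the reader and merged two readers inside _filter_data_reader_for_current_split_model, collapsing everything into one complete_data_reader avoids the deepcopy and the second filtering pass.
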
50 changes: 32 additions & 18 deletions onnx_neural_compressor/onnx_model.py
@@ -233,7 +233,7 @@ def is_graph_output(self, name):
def save(self, root):
"""Save ONNX model."""
if os.path.split(root)[0] != "" and not os.path.exists(os.path.split(root)[0]):
raise ValueError('"root" directory does not exists.')
os.mkdir(os.path.split(root)[0])
if self.is_large_model: # pragma: no cover
onnx.external_data_helper.load_external_data_for_model(self.model, os.path.split(self._model_path)[0])
onnx.save_model(
@@ -897,30 +897,44 @@ def split_model_with_node(self, split_node_name, path_of_model_to_split, save_bo
split_model_part_2.CopyFrom(self.model)
split_model_part_2.graph.ClearField("node")

split_node_output = None
part_idx = 1
split_node = None
nodes = []
for node in self.model.graph.node:
if part_idx == 1:
split_model_part_1.graph.node.append(node)
elif part_idx == 2:
split_model_part_2.graph.node.append(node)
nodes.append(node)

if node.name == split_node_name:
split_node_output = node.output
part_idx = 2
split_node = node
break

assert len(split_node_output) == 1, (
assert len(split_node.output) == 1, (
"Only support split at node with 1 output tensor, while "
"current split node {} has {} output tensors".format(split_node_name, len(split_node_output))
"current split node {} has {} output tensors".format(split_node_name, len(split_node.output))
)
split_tensor_name = split_node_output[0]
split_tensor_name = split_node.output[0]

split_tensor = self._build_input_output_tensor(split_tensor_name, value_info)

split_model_part_1.graph.node.extend(nodes)
split_model_part_1.graph.output.append(split_tensor)
split_model_part_2.graph.input.append(split_tensor)

split_model_part_1 = ONNXModel(split_model_part_1, ignore_warning=True)

# remove isolated graphs which are not related to the split_node
output_name_to_node = split_model_part_1.output_name_to_node()
valid_nodes = [split_node]
while len(valid_nodes) > 0:
node = valid_nodes.pop(0)
for inp in node.input:
if inp in output_name_to_node:
valid_nodes.append(output_name_to_node[inp])
if node in nodes:
nodes.remove(node)
split_model_part_1.remove_nodes(nodes)

for node in self.model.graph.node:
if node not in split_model_part_1.nodes():
split_model_part_2.graph.node.append(node)

split_model_part_2.graph.input.append(split_tensor)
split_model_part_2 = ONNXModel(split_model_part_2, ignore_warning=True)

# remove unused input & output
@@ -994,14 +1008,14 @@ def _remove_unused_input_output(self):
"""Remove unused input & output for split model."""
remove_outputs = []
remove_inputs = []
if len(self._input_name_to_nodes) == 0:
self._input_name_to_nodes = self.input_name_to_nodes()
input_name_to_nodes = self.input_name_to_nodes()
output_name_to_node = self.output_name_to_node()
for output in self.model.graph.output:
if output.name not in self._output_name_to_node.keys():
if output.name not in output_name_to_node.keys():
remove_outputs.append(output)

for input in self.model.graph.input:
if input.name not in self._input_name_to_nodes.keys():
if input.name not in input_name_to_nodes.keys():
remove_inputs.append(input)

for output in remove_outputs:
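
The main change to split_model_with_node above is that the first partition is no longer simply every node encountered before the split node: after collecting those nodes, the code walks producer edges backwards from the split node and drops any collected node that never feeds it, so isolated subgraphs land in the second partition instead. Below is a rough sketch of that backward-reachability idea, assuming a producer_of mapping from tensor name to producing node (analogous to output_name_to_node in the diff); it is an illustration, not the exact implementation.

from collections import deque

def nodes_feeding(split_node, producer_of):
    """Return the names of nodes that transitively produce the split node's inputs."""
    reached = set()
    queue = deque([split_node])
    while queue:
        node = queue.popleft()
        if node.name in reached:
            continue
        reached.add(node.name)
        for tensor_name in node.input:
            producer = producer_of.get(tensor_name)
            if producer is not None:
                queue.append(producer)
    return reached

# Every node collected before the split point but absent from
# nodes_feeding(split_node, producer_of) is removed from partition 1;
# split_model_with_node then assigns any node not kept in partition 1
# to partition 2.
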
