Skip to content

Commit

Permalink
[TorchFX] SmoothQuant algorithm implementation (#2875)
Browse files Browse the repository at this point in the history
### Changes

TorchFX SmoothQuant backend implementation
*  module_insertion_transformation_builder is introduced
* Transformation requires names for new modules and nodes
* vit_b_16 is introduced in the conformance tests
### Reason for changes

To improve metrics of quantized models: swin_v2_s and vit_b_16
* To insert SQ multiply nodes to the graph
* To make node names human-readable and consistent
* To check sq algorithm E2E

### Related tickets

#2766

### Tests

* Smooth quant test template is implemented for TorchfX backed
* Conformance test: post_training_quantization/446/ is successfull
* Test models check SQ multiplies for swin_v2_s and vit_b_16 models
  • Loading branch information
daniil-lyakhov authored Aug 16, 2024
1 parent 9a0b5d2 commit 7744ebf
Show file tree
Hide file tree
Showing 20 changed files with 9,465 additions and 8,605 deletions.
13 changes: 7 additions & 6 deletions nncf/experimental/torch/fx/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph:
output_port_id=output_port_id,
dtype=Dtype.FLOAT,
)

return nncf_graph

@staticmethod
Expand All @@ -121,22 +120,24 @@ def get_edge_params(
edge tensor shape.
"""
output_port_id = 0
tensor_shape = None
if source_node.op in ("get_attr",):
tensor_shape = tuple(getattr(model, source_node.target).shape)
elif "val" in source_node.meta:
if source_nncf_node.metatype is om.PTBatchNormMetatype:
tensor = source_node.meta["val"][0]
elif source_nncf_node.metatype is om.PTSplitMetatype:
elif source_nncf_node.metatype in [om.PTSplitMetatype, om.PTMaxMetatype, om.PTMinMetatype]:
tensor = source_node.meta["val"][output_idx]
# Assume every split outputs corresponds to an unique output_port_id
# Assume every outputs corresponds to an unique output_port_id
output_port_id = output_idx
else:
tensor = source_node.meta["val"]
tensor_shape = tuple(tensor.shape)
else:
if isinstance(tensor, torch.Tensor):
tensor_shape = tuple(tensor.shape)

if tensor_shape is None:
# TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes.
nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.")
tensor_shape = None

input_port_id = dist_node.all_input_nodes.index(source_node)
return input_port_id, output_port_id, tensor_shape
23 changes: 22 additions & 1 deletion nncf/experimental/torch/fx/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.fx


# TODO(dlyakhov): Use torch.fx.graph.find_nodes method instead after
Expand All @@ -28,3 +28,24 @@ def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node:
if node.name == name:
return node
raise RuntimeError(f"Node with name {name} is not found")


def get_tensor_constant_from_node(constant_node: torch.fx.Node, model: torch.fx.GraphModule) -> torch.nn.Parameter:
"""
Retrieves tensor from the given constant node.
:param constant_node: Given constant node.
:param model: Given model.
:return: Torch tensor referenced by the given constant node.
"""
if constant_node is None:
return None
if constant_node.op != "get_attr":
raise RuntimeError(f"Given node op == {constant_node.op}, but get_attr is expected.")
target_atoms = constant_node.target.split(".")
attr_itr = model
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
24 changes: 23 additions & 1 deletion nncf/experimental/torch/fx/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder
from nncf.tensor import Tensor
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.return_types import maybe_get_values_from_torch_return_type

Expand Down Expand Up @@ -65,6 +66,24 @@ def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
return

@staticmethod
def _get_statistic_collector_name(tp: PTTargetPoint, module_to_insert: torch.nn.Module) -> str:
"""
Compouses unique statistic collector name according to given target point and module.
:param tp: Given target point.
:param module_to_insert: Given statistic collection module.
:return: Unique statistic collector name according to given target point and module.
"""
return "_".join(
[
tp.target_node_name,
str(tp.input_port_id),
str(tp.target_type.value),
str(id(module_to_insert)),
]
)

def _get_transformation_layout_extra_outputs(
self, statistic_points: StatisticPointsContainer
) -> TransformationLayout:
Expand All @@ -75,8 +94,11 @@ def _get_transformation_layout_extra_outputs(
for _statistic_point in _statistic_points:
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
tp = _statistic_point.target_point
module_to_insert = TensorCollectorModule(collector)
target_module_name = self._get_statistic_collector_name(tp, module_to_insert)
transformation = leaf_module_insertion_transformation_builder(
TensorCollectorModule(collector), [_statistic_point.target_point]
module_to_insert, [tp], target_module_name
)
transformation_commands.append(
FXApplyTransformationCommand(
Expand Down
162 changes: 122 additions & 40 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,108 @@
import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
from torch.ao.quantization.pt2e.utils import fold_bn_weights_into_conv_node
from torch.quantization.fake_quantize import FakeQuantize

import nncf
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.graph.transformations.commands import PTTargetPoint

TransformationFNType = Callable[[torch.fx.GraphModule], None]


def _set_new_node_meta(new_node: torch.fx.Node, prev_node: torch.fx.Node, target_module: torch.nn.Module):
"""
Sets correct meta \"val\" value to the new node.
:param new_node: The new node.
:param prev_node: Input node of the new node.
New node expected to have only one input node.
:param target_module: Module which is being called by the new node.
"""
val = prev_node.meta["val"]
val = val if isinstance(val, tuple) else (val,)
retval = []
for t in val:
retval.append(torch.ones(t.shape))

with torch.no_grad():
new_node.meta["val"] = target_module(*val)


def module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns transformation which inserts given module to a target model
and calls given module after each target points replacing inputs/outputs
of the target node.
:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which which inserts given module to a target model
and calls given module after each target points.
"""

def module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for idx, target_point in enumerate(target_points):
new_node = _insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")
target_node = get_graph_node_by_name(graph, target_point.target_node_name)

if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
with graph.inserting_after(target_node):
for user in target_node.users:
if user is new_node:
continue
user.replace_input_with(target_node, new_node)

else:
prev_node = target_node.args[target_point.input_port_id]
_set_new_node_meta(new_node, prev_node, module_to_insert)
target_node.replace_input_with(prev_node, new_node)

return module_insertion_transformation


def leaf_module_insertion_transformation_builder(
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint], target_module_name: str
) -> TransformationFNType:
"""
Returns transformation which inserts given module to a target model
and calls given module after each target points.
:param module_to_insert: Given torch.nn.Module to insert.
:param target_points: Target points to insert the target module.
:param target_module_name: Target model attribute name for the module_to_insert.
:returns: Transformation which which inserts given module to a target model
and calls given module after each target points.
"""

def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points)
module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_module_name)
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
_insert_call_module(graph, target_point, module_attr_name)
for idx, target_point in enumerate(target_points):
_insert_call_module(graph, target_point, module_attr_name, f"{module_attr_name}_{idx}")

return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return transformation which updates constant of the given bias node to the given value.
Return transformation which updates constant of the given node with bias to the given value.
:param node: Bias node which requires bias constant update.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:return: Transformation which updates constant of the given bias node to the given value.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

def bias_update_transformation(model: torch.fx.GraphModule):
Expand All @@ -67,18 +126,51 @@ def bias_update_transformation(model: torch.fx.GraphModule):
raise nncf.InternalError(f"Node with bias have {len(graph_node.users)} users, 1 expected.")

bias_node = next(iter(graph_node.users))
with graph.inserting_before(bias_node):
new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value)

args = list(bias_node.args)
# A bias node suppose to have constant on the second input port.
args[1] = new_constant
bias_node.args = tuple(args)
graph.eliminate_dead_code()
constant_update_fn(model, bias_node, value, input_port_id=1)

return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.
:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)

return constant_update_transformation


def constant_update_fn(model: torch.fx.GraphModule, node: torch.fx.Node, value: torch.Tensor, input_port_id: int = 1):
"""
Updates constant of given node on the given input port id with given value.
:param model: Target torch GraphModule.
:param node: Given graph node.
:param value: New value to use as the node constant.
:param input_port_id: Target constant input port id.
"""
graph = model.graph
with graph.inserting_before(node):
new_constant = create_getattr_from_value(model, graph, node.name + "_updated_constant", value)

args = list(node.args)
# A bias node suppose to have constant on the second input port.
if args[input_port_id].op != "get_attr":
raise nncf.InternalError(
f"Constant on input port {input_port_id} for {node} is expected,"
f" but node {args[input_port_id]} is present."
)
args[input_port_id] = new_constant
node.args = tuple(args)
graph.eliminate_dead_code()


def qdq_insertion_transformation_builder(
quantizer: FakeQuantize, target_points: List[PTTargetPoint]
) -> TransformationFNType:
Expand Down Expand Up @@ -200,25 +292,23 @@ def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, qua
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")


def _insert_call_module(graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str):
def _insert_call_module(
graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str, graph_node_name: str
):
"""
Inserts module call node to the graph after the target node.
:param graph: Graph to insert module call node.
:param target_node: Target node, module call node is being iserted just after the target node.
:param module_attr_name: The name of the graph attribute which keeps the target module.
:param graph_node_name: Target name for module call node.
:return: Inserted module call node.
"""
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
input_node = get_input_node(target_point, target_node)
ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
return graph.create_node(
"call_module",
module_attr_name,
(input_node,),
{},
name=f"{module_attr_name}_{str(target_point.target_type)}_graph_node",
)
return graph.create_node("call_module", module_attr_name, (input_node,), {}, name=graph_node_name)


def get_input_node(target_point: PTTargetPoint, target_node: torch.fx.Node) -> torch.fx.Node:
Expand Down Expand Up @@ -264,26 +354,18 @@ def get_ctx_manager(graph: torch.fx.Graph, target_point: PTTargetPoint) -> Calla


def _set_module_to_the_graph_module(
model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint]
model: torch.fx.GraphModule,
module_to_insert: torch.nn.Module,
module_name_in_model: str,
) -> str:
"""
Sets given module to the given torch.fx.GraphModule with unique name.
:param graph: Target torch.fx.Graph.
:param module_to_insert: Module to insert to the target graph.
:param target_points: Target points which will be used to insert target module
to the graph.
:param module_name_in_model: Target model attribute name for the module_to_insert.
:return: A graph module attribute name which keep given module.
"""
module_to_insert = module_to_insert
# TODO(dlyakhov) Make module name human readable.
module_name_in_model = (
"__".join(
"_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points
)
+ "_"
+ str(id(module_to_insert))
)
assert not hasattr(model, module_name_in_model)
setattr(model, module_name_in_model, module_to_insert)
return module_name_in_model
Expand Down Expand Up @@ -397,7 +479,7 @@ def separate_linear_and_bias(model: torch.fx.GraphModule):
while linear_bias_node.op != "get_attr":
# Assume zero argument is on a path to the constant
linear_bias_node = linear_bias_node.args[0]
linear_bias_value = _get_tensor_constant_from_node(linear_bias_node, model)
linear_bias_value = get_tensor_constant_from_node(linear_bias_node, model)
args = list(n.args)
args[2] = None
linear_node.args = tuple(args)
Expand Down Expand Up @@ -436,9 +518,9 @@ def separate_conv_and_bias(model: torch.fx.GraphModule):
if len(n.args) < 3 or n.args[2] is None:
continue
conv_node = n
dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape)
dims = len(get_tensor_constant_from_node(conv_node.args[1], model).shape)
conv_bias_node = conv_node.args[2]
conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model)
conv_bias_value = get_tensor_constant_from_node(conv_bias_node, model)
args = list(n.args)
args[2] = None
conv_node.args = tuple(args)
Expand Down Expand Up @@ -502,7 +584,7 @@ def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[
const_node = node
break
assert const_node is not None
bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
bias_value = get_tensor_constant_from_node(const_node, model).squeeze()
with model.graph.inserting_before(conv_node):
new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
args = list(conv_node.args)
Expand Down
Loading

0 comments on commit 7744ebf

Please sign in to comment.