[Conformance] TorchFX/OV backends Alignment #2996

Merged

Changes from 7 commits
273 changes: 273 additions & 0 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -0,0 +1,273 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten


def _replace_node_with_constant(
gm: torch.fx.GraphModule,
node: torch.fx.Node,
constant: torch.Tensor,
name: Optional[str] = None,
) -> None:
g = gm.graph

if name:
qualname = name
else:
if not hasattr(gm, "_frozen_param_count"):
gm._frozen_param_count = 0 # type: ignore[assignment]
i = gm._frozen_param_count

while True:
qualname = f"_frozen_param{i}"
if not hasattr(gm, qualname):
break
i += 1

gm._frozen_param_count = i + 1

with g.inserting_before(node):
new_input_node = g.create_node("get_attr", qualname, (), {})
node.replace_all_uses_with(new_input_node)
new_input_node.meta.update(node.meta)
g.erase_node(node)

# needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
gm.register_buffer(qualname, constant)
setattr(gm, qualname, constant)


def _is_const_source(node: torch.fx.Node, lifted_constants: Optional[Dict[str, Any]]) -> bool:
return node.op == "get_attr" or (
node.op == "placeholder" and lifted_constants is not None and node.name in lifted_constants
)


class _ConstantFolder(torch.fx.Interpreter):
def __init__(
self,
gm: torch.fx.GraphModule,
skip_constructors: bool = False,
lifted_constants: Optional[Dict[str, torch.Tensor]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object()
self.skip_constructors: bool = skip_constructors

# overwrite this to deallocate env values if their only remaining use
# is the output
self.user_to_last_uses = self.node_to_last_non_output_use()
self.lifted_constants = lifted_constants

def _support_dynamic_shape(self) -> bool:
# ConstantFolder does not support dynamic shapes yet
return False

def _deduce_value(self, node: torch.fx.Node) -> Any:
return super().run_node(node)

def is_impure(self, node: torch.fx.node.Node) -> bool:
def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool:
return (
node.target == torch.ops.prims.convert_element_type.default # type: ignore[return-value]
and isinstance(node.args[0], torch.fx.Node)
and "val" in node.args[0].meta
and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr]
and node.args[1] == torch.bfloat16
)

if (
is_woq_int8_pattern(node)
or (
node.target == torch.ops.aten.permute.default
and len(node.users) == 1
and is_woq_int8_pattern(next(iter(node.users)))
)
) and _is_const_source(
node.args[0], self.lifted_constants # type: ignore[arg-type]
):
# Case 1: int8_weight -> dq -> bf16_weight
# Case 2: int8_weight -> permute -> dq -> bf16_weight
return True

quant_registered = getattr(torch.ops.quantized_decomposed, "dequantize_per_channel", None) is not None
if quant_registered and node.target in [
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
]:
# For the pattern fp32_weight -> q -> dq
# we only fold fp32_weight -> q into int8_weight
# and leave dq in the graph to be fused
return True
return False

def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
last_non_output_use = collections.defaultdict(list)
seen_uses = set()
output_node = next(iter(reversed(self.module.graph.nodes)))

for node in reversed(self.module.graph.nodes):
if node.target == "output":
continue

def add_use(inp: torch.fx.Node) -> None:
if inp in seen_uses:
return

seen_uses.add(inp)
last_non_output_use[node].append(inp)

# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, add_use, (node.args, node.kwargs))

# if this node is only used in output, we want to gc it right away
if len(node.users) == 1 and output_node in node.users:
last_non_output_use[node].append(node)

return last_non_output_use

def run_node(self, node: torch.fx.Node) -> Any:
if node.target == "output":
# Because we remove nodes from env on their last non-output use,
# re-define them now or we'll get an error in the interpreter
def set_env(arg: torch.fx.Node) -> None:
self.env[arg] = self.unknown_value

# In-place is fine since we don't mutate
pytree.tree_map_only_(torch.fx.Node, set_env, node.args)
return super().run_node(node)

args, kwargs = self.fetch_args_kwargs_from_env(node)
flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

# We need to do this weird thing because in cases where flattened_inputs
# contains a ScriptObject, equality checking results in a type error if
# the types are different.
if any(
type(self.unknown_value) is type(input_) and self.unknown_value == input_ for input_ in flattened_inputs
):
return self.unknown_value

# TODO - fix errors with this
if node.op == "call_function" and node.target == aten._efficientzerotensor.default:
return self.unknown_value

# TODO - constant folding triton kernel returns the inputs -- fix this
if node.op == "call_function" and node.name == "triton_kernel_wrapper_functional_proxy":
return self.unknown_value

# skip constructors, since inductor generates optimal code for them already
# and turning them into a tensor would result in an additional global memory read
# TODO - more complicated strategy
if (
self.skip_constructors
and not _is_const_source(node, self.lifted_constants)
and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
):
return self.unknown_value

# All mutations should either be removed or on inputs which we did not make constant
if isinstance(node.target, torch._ops.OpOverload) and torch.Tag.nondeterministic_seeded in node.target.tags:
return self.unknown_value

out = self._deduce_value(node)
if out == self.unknown_value:
return self.unknown_value

if not _is_const_source(node, self.lifted_constants) and isinstance(out, torch.Tensor):
if out.device.type == "meta":
return out

if not self.insertable_tensor_check(out):
return out

if self.is_impure(node):
return self.unknown_value

self.add_node_replacement(node, out)

flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

for n in flattened_node_inps:
if not isinstance(n, torch.fx.Node):
continue

self.replaced_uses[n] += 1

for to_delete in self.user_to_last_uses.get(node, []):
if self.replaced_uses[to_delete] == len(to_delete.users):
self.node_replacements.pop(to_delete, None)

return out

def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
return True

def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
self.node_replacements[node] = tensor

def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {}
self.insert_placeholder_values(env)
return super().run(initial_env=env)

def insert_placeholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"):
if self.lifted_constants is not None and n.name in self.lifted_constants:
env[n] = self.lifted_constants[n.name]
else:
env[n] = self.unknown_value # type: ignore[assignment]


def constant_fold(
gm: torch.fx.GraphModule,
constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
"""
Calculates the values of constant subgraphs and replaces them with constant nodes in place.

:param gm: Given graph module.
:param constraint_fn: Constraint function which takes a node and returns whether
the node should be constant folded or not.
"""
with torch.utils._python_dispatch._disable_current_modes():
cf = _ConstantFolder(gm, skip_constructors=True)
cf.run()

for node, constant in cf.node_replacements.items():
if constraint_fn is not None and not constraint_fn(node):
continue
_replace_node_with_constant(gm, node, constant)

erased_params = []
for node in gm.graph.find_nodes(op="get_attr"):
if len(node.users) == 0:
if hasattr(gm, node.target):
delattr(gm, node.target)
erased_params.append(node)

for node in erased_params:
gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
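For context, here is a minimal usage sketch of the new constant_fold helper (not part of the diff; the toy model, and the choice of matmul as the op excluded from folding, are assumptions made purely for illustration):

import torch

from nncf.experimental.torch.fx.constant_folding import constant_fold


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, 3))

    def forward(self, x):
        # (self.weight + 1) depends only on a constant, so it can be folded.
        return x @ (self.weight + 1)


# Assumes a recent torch.export API; .module() returns an unlifted torch.fx.GraphModule.
gm = torch.export.export(ToyModel(), (torch.ones(1, 3),)).module()

# Fold every node that is computable from constants, except matmul nodes
# (a hypothetical constraint, just to show the constraint_fn hook).
constant_fold(gm, constraint_fn=lambda node: node.target != torch.ops.aten.matmul.default)

# The folded addition is now a frozen get_attr constant in the graph.
print([n.op for n in gm.graph.nodes])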
26 changes: 22 additions & 4 deletions nncf/experimental/torch/fx/transformations.py
@@ -10,7 +10,6 @@
# limitations under the License.

from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -23,6 +22,7 @@
import nncf.torch
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.fx.constant_folding import constant_fold
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -669,13 +669,17 @@ def _compress_qdq_constant_transformation(model: torch.fx.GraphModule, matches)
for match in matches:
mul_node = match.replacements[0]
sub_node = match.replacements[1]
weight_node, scale_node, zp_node, axis = None, None, None, None
nodes_map = {node.name: match.nodes_map[node] for node in match.nodes_map}
get_const = partial(get_tensor_constant_from_node, model=model)

def get_const(arg: Union[torch.fx.Node, float, int]):
if isinstance(arg, torch.fx.Node):
return get_tensor_constant_from_node(arg, model)
return arg

weight_node = get_const(nodes_map["weight"])
scale_node = get_const(nodes_map["scale"])
zp_node = get_const(nodes_map["zero_point"])
axis = nodes_map["axis"]
axis = nodes_map.get("axis")
anzr299 (Contributor) commented on Oct 22, 2024:

I think we can use the same function for axis too.

Suggested change:
-    axis = nodes_map.get("axis")
+    axis = get_const(nodes_map.get("axis"))

daniil-lyakhov (Collaborator, Author) replied on Oct 22, 2024:

No: the per-tensor case has no "axis" key, so indexing with [] would raise a KeyError. The intended axis value for the per-tensor case is None, and .get returns None for a missing key (https://docs.python.org/3/library/stdtypes.html#dict.get).

anzr299 (Contributor):

Sorry about that, I have updated the suggestion. I meant to say that we can also pass this to the get_const function to keep it consistent with the others.

daniil-lyakhov (Collaborator, Author):

Done
port_id = 0
if axis is not None:
result = torch.ops.quantized_decomposed.quantize_per_channel.default(
@@ -788,12 +792,26 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
# to make it easier for algorithms to work
# with the target graph BatchNorm operations
# are being fused
fold_constant_except_qdq(model)
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)
shared_constants_unification_transformation(model)


def fold_constant_except_qdq(model: torch.fx.GraphModule):
"""
Performs constant folding while avoiding quantize-dequantize patterns.

:param model: Model to perform constant folding on.
"""

def constraint_fn(node: torch.fx.Node):
return node.op != "call_function" or node.target not in QUANTIZE_NODE_TARGETS + DEQUANTIZE_NODE_TARGETS

constant_fold(model, constraint_fn=constraint_fn)


def revert_quantization_transformations(model: torch.fx.GraphModule) -> None:
"""
Reverts quantization transformations from the model.
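To summarize the intent of fold_constant_except_qdq above, here is a conceptual before/after sketch of a quantized weight subgraph (node names are illustrative, not taken from the PR):

# Before folding (conceptual):
#   fp32_weight (get_attr) -> quantize_per_channel -> dequantize_per_channel -> linear
#
# After fold_constant_except_qdq (conceptual):
#   int8_weight (frozen get_attr) -----------------> dequantize_per_channel -> linear
#
# The quantize op is folded into a precomputed int8 constant, while the dequantize
# op is kept in the graph so that downstream passes can recognize and fuse the Q/DQ pattern.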
3 changes: 3 additions & 0 deletions tests/post_training/pipelines/base.py
@@ -418,6 +418,9 @@ def save_compressed_model(self) -> None:
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend in OV_BACKENDS:
self.path_compressed_ir = self.output_model_dir / "model.xml"
from openvino._offline_transformations import apply_moc_transformations

apply_moc_transformations(self.compressed_model, cf=True)
ov.serialize(self.compressed_model, str(self.path_compressed_ir))

def get_num_compressed(self) -> None:
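For reference, a standalone sketch of the serialization step added above (assuming compressed_model is an ov.Model produced earlier in the pipeline; the helper name is hypothetical):

import openvino as ov
from openvino._offline_transformations import apply_moc_transformations


def save_compressed_ir(compressed_model: ov.Model, path: str = "model.xml") -> None:
    # cf=True additionally applies constant folding as part of the MOC transformations,
    # so the serialized OV IR is compared against an equally folded TorchFX graph.
    apply_moc_transformations(compressed_model, cf=True)
    ov.serialize(compressed_model, path)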
@@ -37,6 +37,7 @@ def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch
class VisionModelParams:
weights: models.WeightsEnum
export_fn: Callable[[torch.nn.Module, Tuple[Any, ...]], torch.fx.GraphModule]
export_torch_before_ov_convert: bool = False


class ImageClassificationTorchvision(ImageClassificationBase):
@@ -47,8 +48,12 @@ class ImageClassificationTorchvision(ImageClassificationBase):
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
),
models.vit_b_16: VisionModelParams(models.ViT_B_16_Weights.DEFAULT, _export_graph_module),
models.swin_v2_s: VisionModelParams(models.Swin_V2_S_Weights.DEFAULT, _export_graph_module),
models.vit_b_16: VisionModelParams(
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
models.swin_v2_s: VisionModelParams(
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
),
}

def __init__(self, *args, **kwargs):
@@ -92,9 +97,10 @@ def prepare_model(self) -> None:

elif self.backend in [BackendType.OV, BackendType.FP32]:
with torch.no_grad():
with disable_patching():
m = torch.export.export(model, args=(self.dummy_tensor,))
self.model = ov.convert_model(m, example_input=self.dummy_tensor, input=self.input_size)
if self.model_params.export_torch_before_ov_convert:
with disable_patching():
model = torch.export.export(model, (self.dummy_tensor,))
self.model = ov.convert_model(model, example_input=self.dummy_tensor, input=self.input_size)
self.input_name = list(inp.get_any_name() for inp in self.model.inputs)[0]

self._dump_model_fp32()