Fix constant folding in CPU constant case
daniil-lyakhov committed Nov 27, 2024
1 parent 6ac12b2 commit 92ebbf2
Showing 4 changed files with 65 additions and 0 deletions.
17 changes: 17 additions & 0 deletions nncf/experimental/torch/fx/constant_folding.py
@@ -222,6 +222,20 @@ def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
            env[n] = self.unknown_value  # type: ignore[assignment]


def get_model_device(model: torch.fx.GraphModule) -> torch.device:
    """
    Returns the device of the first model parameter, or torch.device("cpu")
    if the model has no parameters.

    :param model: GraphModule instance.
    :return: Device of the first model parameter, or torch.device("cpu") if the model has no parameters.
    """
    try:
        device = next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")
    return device


def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
@@ -233,13 +247,16 @@
    :param constraint_fn: Constraint function which takes a node and returns
        whether the node should be constant folded or not.
    """

    with torch.utils._python_dispatch._disable_current_modes():
        cf = ConstantFolder(gm)
        cf.run()

        device = get_model_device(gm)
        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            constant = constant.to(device)
            _replace_node_with_constant(gm, node, constant)

        erased_params = []
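
For illustration, a minimal standalone sketch of the device rule this change adds to the folding loop (not part of the commit; pick_device is a hypothetical stand-in for the new get_model_device helper):

import torch

def pick_device(model: torch.nn.Module) -> torch.device:
    # Same rule as get_model_device: the device of the first parameter,
    # with a CPU fallback for parameter-less modules.
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cpu")

print(pick_device(torch.nn.Linear(3, 3)))  # cpu (or cuda:0 after .cuda())
print(pick_device(torch.nn.Identity()))    # cpu: the StopIteration fallback

constant_fold then moves every folded tensor with constant.to(device), so a model on CUDA no longer ends up holding constants that were folded on CPU.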
20 changes: 20 additions & 0 deletions folded_scalar_clone_model.dot
@@ -0,0 +1,20 @@
strict digraph {
"0 linear_weight" [id=0, type=get_attr];
"1 linear_bias" [id=1, type=get_attr];
"2 lifted_tensor_0" [id=2, type=get_attr];
"3 x" [id=3, type=input];
"4 lift_fresh_copy" [id=4, type=lift_fresh_copy];
"5 detach_" [id=5, type=detach_];
"6 _frozen_param0" [id=6, type=get_attr];
"7 linear" [id=7, type=linear];
"8 add" [id=8, type=add];
"9 output" [id=9, type=output];
"0 linear_weight" -> "7 linear" [label="(3, 3)", style=solid];
"1 linear_bias" -> "7 linear" [label="(3,)", style=solid];
"2 lifted_tensor_0" -> "4 lift_fresh_copy" [label="()", style=solid];
"3 x" -> "7 linear" [label="(1, 3, 3, 3)", style=solid];
"4 lift_fresh_copy" -> "5 detach_" [label="()", style=solid];
"6 _frozen_param0" -> "8 add" [label="()", style=solid];
"7 linear" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"8 add" -> "9 output" [label="(1, 3, 3, 3)", style=solid];
}
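
In this reference graph, "6 _frozen_param0" is the constant that replaces the folded lift_fresh_copy/detach_ chain on lifted_tensor_0 as the input to add; with this fix it is moved to the model's device via constant.to(device) rather than being left wherever folding produced it.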
16 changes: 16 additions & 0 deletions tests/torch/fx/test_model_transformer.py
@@ -56,6 +56,7 @@
from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel
from tests.torch.test_models.synthetic import MultiBranchesConnectedModel
from tests.torch.test_models.synthetic import MultiBranchesConnectedModelWithConcat
from tests.torch.test_models.synthetic import ScalarCloneTestModel


@dataclass
@@ -509,6 +510,21 @@ def test_constant_folding():
check_graph(nncf_graph, "folded_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_constant_folding_scalar_clone():
    model = ScalarCloneTestModel().cuda()
    captured_model = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE))
    # torch.export keeps the lifted scalar constant on CPU even though the model is on CUDA.
    assert captured_model.lifted_tensor_0.device == torch.device("cpu")

    folded_model = deepcopy(captured_model)
    constant_fold(folded_model)
    # After folding, the constant lives on the model device, so outputs must match.
    ex_input = torch.ones(model.INPUT_SIZE).cuda()
    assert torch.allclose(captured_model(ex_input), folded_model(ex_input))

    nncf_graph = GraphConverter.create_nncf_graph(folded_model)
    check_graph(nncf_graph, "folded_scalar_clone_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True)


def test_constant_folding_with_constraints(is_per_channel):
    model = ConstantFoldingTestModel()
    model_with_correct_pattern = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE))
12 changes: 12 additions & 0 deletions tests/torch/test_models/synthetic.py
@@ -614,6 +614,18 @@ def forward(self, x):
        return x + y


class ScalarCloneTestModel(nn.Module):
    INPUT_SIZE = (1, 3, 3, 3)

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)

    def forward(self, x):
        # A CPU scalar constant; torch.export lifts it into the captured graph.
        y = torch.clone(torch.tensor(1).cpu())
        return self.linear(x) + y


class ShortTransformer(torch.nn.Module):
    def __init__(self, in_features, num_embeddings, share_weights=False):
        super().__init__()
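
A quick way to see how the CPU scalar in ScalarCloneTestModel.forward is captured (a sketch assuming a torch version that provides torch.export; not part of the commit):

import torch

from tests.torch.test_models.synthetic import ScalarCloneTestModel

model = ScalarCloneTestModel()
ep = torch.export.export(model, (torch.ones(model.INPUT_SIZE),))
# The CPU scalar shows up as lifted_tensor_0 feeding lift_fresh_copy/detach_,
# matching the reference graph above before the chain is folded into _frozen_param0.
print(ep.graph_module.graph)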
