diff --git a/nncf/experimental/torch/fx/constant_folding.py b/nncf/experimental/torch/fx/constant_folding.py index 945f79b497c..2d312c38077 100644 --- a/nncf/experimental/torch/fx/constant_folding.py +++ b/nncf/experimental/torch/fx/constant_folding.py @@ -222,6 +222,20 @@ def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: env[n] = self.unknown_value # type: ignore[assignment] +def get_model_device(model: torch.fx.GraphModule) -> torch.device: + """ + Returns device of the first model parameter or torch.device("cpu") if the model has no parameters. + + :param model: GraphModule instance. + :return: Device of the first model parameter or torch.device("cpu") if the model has no parameters. + """ + try: + device = next(model.parameters()).device + except StopIteration: + return torch.device("cpu") + return device + + def constant_fold( gm: torch.fx.GraphModule, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None, @@ -233,13 +247,16 @@ def constant_fold( :param constraint_fn: Constraint function which takes a node and returs the constraint: should the node be constant folded or not. 
""" + with torch.utils._python_dispatch._disable_current_modes(): cf = ConstantFolder(gm) cf.run() + device = get_model_device(gm) for node, constant in cf.node_replacements.items(): if constraint_fn is not None and not constraint_fn(node): continue + constant = constant.to(device) _replace_node_with_constant(gm, node, constant) erased_params = [] diff --git a/tests/torch/data/reference_graphs/fx/transformed/folded_scalar_clone_model.dot b/tests/torch/data/reference_graphs/fx/transformed/folded_scalar_clone_model.dot new file mode 100644 index 00000000000..a0e60be77c6 --- /dev/null +++ b/tests/torch/data/reference_graphs/fx/transformed/folded_scalar_clone_model.dot @@ -0,0 +1,20 @@ +strict digraph { +"0 linear_weight" [id=0, type=get_attr]; +"1 linear_bias" [id=1, type=get_attr]; +"2 lifted_tensor_0" [id=2, type=get_attr]; +"3 x" [id=3, type=input]; +"4 lift_fresh_copy" [id=4, type=lift_fresh_copy]; +"5 detach_" [id=5, type=detach_]; +"6 _frozen_param0" [id=6, type=get_attr]; +"7 linear" [id=7, type=linear]; +"8 add" [id=8, type=add]; +"9 output" [id=9, type=output]; +"0 linear_weight" -> "7 linear" [label="(3, 3)", style=solid]; +"1 linear_bias" -> "7 linear" [label="(3,)", style=solid]; +"2 lifted_tensor_0" -> "4 lift_fresh_copy" [label="()", style=solid]; +"3 x" -> "7 linear" [label="(1, 3, 3, 3)", style=solid]; +"4 lift_fresh_copy" -> "5 detach_" [label="()", style=solid]; +"6 _frozen_param0" -> "8 add" [label="()", style=solid]; +"7 linear" -> "8 add" [label="(1, 3, 3, 3)", style=solid]; +"8 add" -> "9 output" [label="(1, 3, 3, 3)", style=solid]; +} diff --git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py index 23039ee99ef..b458fce0b74 100644 --- a/tests/torch/fx/test_model_transformer.py +++ b/tests/torch/fx/test_model_transformer.py @@ -56,6 +56,7 @@ from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel from tests.torch.test_models.synthetic import MultiBranchesConnectedModel from 
tests.torch.test_models.synthetic import MultiBranchesConnectedModelWithConcat +from tests.torch.test_models.synthetic import ScalarCloneTestModel @dataclass @@ -509,6 +510,21 @@ def test_constant_folding(): check_graph(nncf_graph, "folded_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") +def test_constant_folding_scalar_clone(): + model = ScalarCloneTestModel().cuda() + captured_model = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE)) + assert captured_model.lifted_tensor_0.device == torch.device("cpu") + + folded_model = deepcopy(captured_model) + constant_fold(folded_model) + ex_input = torch.ones(model.INPUT_SIZE).cuda() + assert torch.allclose(captured_model(ex_input), folded_model(ex_input)) + + nncf_graph = GraphConverter.create_nncf_graph(folded_model) + check_graph(nncf_graph, "folded_scalar_clone_model.dot", TRANSFORMED_GRAPH_DIR_NAME, extended=True) + + def test_constant_folding_with_constraints(is_per_channel): model = ConstantFoldingTestModel() model_with_correct_pattern = get_torch_fx_model(model, torch.ones(model.INPUT_SIZE)) diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py index c999deedfce..472743698b4 100644 --- a/tests/torch/test_models/synthetic.py +++ b/tests/torch/test_models/synthetic.py @@ -614,6 +614,18 @@ def forward(self, x): return x + y +class ScalarCloneTestModel(nn.Module): + INPUT_SIZE = (1, 3, 3, 3) + + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + + def forward(self, x): + y = torch.clone(torch.tensor(1).cpu()) + return self.linear(x) + y + + class ShortTransformer(torch.nn.Module): def __init__(self, in_features, num_embeddings, share_weights=False): super().__init__()