From 19fce4f2a777fe8176a76af61ce9efc444eceb80 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 11 Nov 2024 01:59:33 +0200 Subject: [PATCH] model manager --- .../torch2/function_hook/model_manager.py | 87 +++++++++++++++++ .../function_hook/test_model_manager.py | 97 +++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 nncf/experimental/torch2/function_hook/model_manager.py create mode 100644 tests/torch2/function_hook/test_model_manager.py diff --git a/nncf/experimental/torch2/function_hook/model_manager.py b/nncf/experimental/torch2/function_hook/model_manager.py new file mode 100644 index 00000000000..3b3810eb902 --- /dev/null +++ b/nncf/experimental/torch2/function_hook/model_manager.py @@ -0,0 +1,87 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch +from torch import nn + +import nncf + + +def split_const_name(const_name: str) -> Tuple[str, str]: + """ + Splits the constant name into module and attribute names. + + :param const_name: The full name of the constant, including module and attribute. + :return: + - module_name: The name of the module containing the constant. + - weight_attr_name: The name of the constant attribute within the module. + """ + index = const_name.rfind(".") + if index == -1: + return str(), const_name + module_name = const_name[:index] + weight_attr_name = const_name[index + 1 :] + return module_name, weight_attr_name + + +def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module: + """ + Retrieves a module from a PyTorch model by its hierarchical name. + + :param module_name: The name of the module to retrieve (e.g., "module1.submodule2"). + :param model: The PyTorch model. + :return: The retrieved module. + """ + if not module_name: + return model + curr_module = model + for name in module_name.split("."): + for child_name, child_module in curr_module.named_children(): + if child_name == name: + curr_module = child_module + break + else: + raise nncf.ModuleNotFoundError(f"Could not find the {module_name} module in the model.") + return curr_module + + +def get_const_data(model: nn.Module, const_name: str) -> torch.Tensor: + """ + Retrieves a constant tensor associated with a given node. + + :param const_name: The name of const data. + :param model: The PyTorch model. + :return: A torch.Tensor object containing the constant value. + """ + module_name, const_attr_name = split_const_name(const_name) + module = get_module_by_name(module_name, model) + data: torch.Tensor = getattr(module, const_attr_name) + if isinstance(data, torch.nn.Parameter): + return data.data + return data + + +def set_const_data(model: nn.Module, const_name: str, data: torch.Tensor) -> None: + """ + Sets the constant data associated with a specific name of a tensor in a PyTorch model. + + :param model: The PyTorch model. + :param const_name: The name of tensor in the model. + :param data: The constant data tensor to be set. + """ + module_name, const_attr_name = split_const_name(const_name) + module = get_module_by_name(module_name, model) + const = getattr(module, const_attr_name) + if isinstance(const, torch.nn.Parameter): + const.data = data + else: + setattr(module, const_attr_name, data) diff --git a/tests/torch2/function_hook/test_model_manager.py b/tests/torch2/function_hook/test_model_manager.py new file mode 100644 index 00000000000..81e7c6324ae --- /dev/null +++ b/tests/torch2/function_hook/test_model_manager.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch +from torch import nn + +from nncf.experimental.torch2.function_hook.model_manager import get_const_data +from nncf.experimental.torch2.function_hook.model_manager import get_module_by_name +from nncf.experimental.torch2.function_hook.model_manager import set_const_data +from nncf.experimental.torch2.function_hook.model_manager import split_const_name + + +@pytest.mark.parametrize( + "const_name, ref", + ( + ("conv.weight", ("conv", "weight")), + ("module.head.conv.bias", ("module.head.conv", "bias")), + ("param", ("", "param")), + ), +) +def test_split_const_name(const_name, ref): + assert split_const_name(const_name) == ref + + +class ModelToGetModule(nn.Module): + def __init__(self): + super().__init__() + self.bn = nn.BatchNorm1d(1) + self.seq = nn.Sequential(nn.Identity(), nn.ReLU()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.bn(x) + x = self.seq(x) + return x + + +def test_get_module_by_name(): + model = ModelToGetModule() + assert get_module_by_name("", model) is model + assert get_module_by_name("bn", model) is model.bn + assert get_module_by_name("seq.0", model) is model.seq[0] + assert get_module_by_name("seq.1", model) is model.seq[1] + + +class ModelGetSetConst(nn.Module): + param: torch.nn.Parameter + buffer: torch.Tensor + + def __init__(self): + super().__init__() + self.register_parameter("param", nn.Parameter(torch.tensor([1.0]))) + self.register_buffer("buffer", torch.tensor([2.0])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.param + self.buffer + + +def test_get_const_data(): + model = ModelGetSetConst() + + data = get_const_data(model, "param") + assert isinstance(data, type(model.param.data)) + assert data == model.param.data + + data = get_const_data(model, "buffer") + assert isinstance(data, type(model.buffer)) + assert data == model.buffer + + with pytest.raises(AttributeError): + get_const_data(model, "not_exist") + + +def test_set_const_data(): + model = ModelGetSetConst() + + set_const_data(model, "param", torch.tensor([100.0])) + assert isinstance(model.param, torch.nn.Parameter) + assert model.param.data == torch.tensor([100.0]) + assert list(model.parameters())[0].data == torch.tensor([100.0]) + + set_const_data(model, "buffer", torch.tensor([200.0])) + assert isinstance(model.buffer, torch.Tensor) and not isinstance(model.buffer, torch.nn.Parameter) + assert model.buffer == torch.tensor([200.0]) + assert list(model.buffers())[0] == torch.tensor([200.0]) + + with pytest.raises(AttributeError): + set_const_data(model, "not_exist", None)