Skip to content

Commit

Permalink
model manager
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Nov 11, 2024
1 parent 058dce6 commit 19fce4f
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 0 deletions.
87 changes: 87 additions & 0 deletions nncf/experimental/torch2/function_hook/model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import nn

import nncf


def split_const_name(const_name: str) -> Tuple[str, str]:
"""
Splits the constant name into module and attribute names.
:param const_name: The full name of the constant, including module and attribute.
:return:
- module_name: The name of the module containing the constant.
- weight_attr_name: The name of the constant attribute within the module.
"""
index = const_name.rfind(".")
if index == -1:
return str(), const_name
module_name = const_name[:index]
weight_attr_name = const_name[index + 1 :]
return module_name, weight_attr_name


def get_module_by_name(module_name: str, model: torch.nn.Module) -> torch.nn.Module:
"""
Retrieves a module from a PyTorch model by its hierarchical name.
:param module_name: The name of the module to retrieve (e.g., "module1.submodule2").
:param model: The PyTorch model.
:return: The retrieved module.
"""
if not module_name:
return model
curr_module = model
for name in module_name.split("."):
for child_name, child_module in curr_module.named_children():
if child_name == name:
curr_module = child_module
break
else:
raise nncf.ModuleNotFoundError(f"Could not find the {module_name} module in the model.")
return curr_module


def get_const_data(model: nn.Module, const_name: str) -> torch.Tensor:
"""
Retrieves a constant tensor associated with a given node.
:param const_name: The name of const data.
:param model: The PyTorch model.
:return: A torch.Tensor object containing the constant value.
"""
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
data: torch.Tensor = getattr(module, const_attr_name)
if isinstance(data, torch.nn.Parameter):
return data.data
return data


def set_const_data(model: nn.Module, const_name: str, data: torch.Tensor) -> None:
"""
Sets the constant data associated with a specific name of a tensor in a PyTorch model.
:param model: The PyTorch model.
:param const_name: The name of tensor in the model.
:param data: The constant data tensor to be set.
"""
module_name, const_attr_name = split_const_name(const_name)
module = get_module_by_name(module_name, model)
const = getattr(module, const_attr_name)
if isinstance(const, torch.nn.Parameter):
const.data = data
else:
setattr(module, const_attr_name, data)
97 changes: 97 additions & 0 deletions tests/torch2/function_hook/test_model_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
import torch
from torch import nn

from nncf.experimental.torch2.function_hook.model_manager import get_const_data
from nncf.experimental.torch2.function_hook.model_manager import get_module_by_name
from nncf.experimental.torch2.function_hook.model_manager import set_const_data
from nncf.experimental.torch2.function_hook.model_manager import split_const_name


@pytest.mark.parametrize(
"const_name, ref",
(
("conv.weight", ("conv", "weight")),
("module.head.conv.bias", ("module.head.conv", "bias")),
("param", ("", "param")),
),
)
def test_split_const_name(const_name, ref):
assert split_const_name(const_name) == ref


class ModelToGetModule(nn.Module):
def __init__(self):
super().__init__()
self.bn = nn.BatchNorm1d(1)
self.seq = nn.Sequential(nn.Identity(), nn.ReLU())

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.bn(x)
x = self.seq(x)
return x


def test_get_module_by_name():
model = ModelToGetModule()
assert get_module_by_name("", model) is model
assert get_module_by_name("bn", model) is model.bn
assert get_module_by_name("seq.0", model) is model.seq[0]
assert get_module_by_name("seq.1", model) is model.seq[1]


class ModelGetSetConst(nn.Module):
param: torch.nn.Parameter
buffer: torch.Tensor

def __init__(self):
super().__init__()
self.register_parameter("param", nn.Parameter(torch.tensor([1.0])))
self.register_buffer("buffer", torch.tensor([2.0]))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.param + self.buffer


def test_get_const_data():
model = ModelGetSetConst()

data = get_const_data(model, "param")
assert isinstance(data, type(model.param.data))
assert data == model.param.data

data = get_const_data(model, "buffer")
assert isinstance(data, type(model.buffer))
assert data == model.buffer

with pytest.raises(AttributeError):
get_const_data(model, "not_exist")


def test_set_const_data():
model = ModelGetSetConst()

set_const_data(model, "param", torch.tensor([100.0]))
assert isinstance(model.param, torch.nn.Parameter)
assert model.param.data == torch.tensor([100.0])
assert list(model.parameters())[0].data == torch.tensor([100.0])

set_const_data(model, "buffer", torch.tensor([200.0]))
assert isinstance(model.buffer, torch.Tensor) and not isinstance(model.buffer, torch.nn.Parameter)
assert model.buffer == torch.tensor([200.0])
assert list(model.buffers())[0] == torch.tensor([200.0])

with pytest.raises(AttributeError):
set_const_data(model, "not_exist", None)

0 comments on commit 19fce4f

Please sign in to comment.