Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 5, 2023
1 parent 655a2ac commit 0e35a0c
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 40 deletions.
5 changes: 1 addition & 4 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
:return: Stacked Tensor.
"""
if isinstance(x, List):
unwrapped_x = [i.data for i in x]
# singledispatch cannot dispatch function by element in a list
res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
return Tensor(res)
return Tensor(_dispatch_list(stack, x, axis=axis))
raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")


Expand Down
3 changes: 0 additions & 3 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
retval[nncf_module_scope + relative_scope] = target_module
return retval

def update_model_ref(self, model: torch.nn.Module) -> None:
object.__setattr__(self, "__model_ref", model)

def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hook_addresses = self.insert_at_point(point, fn_list)
self._temprorary_hooks_adresses.append(hook_addresses)
Expand Down
33 changes: 0 additions & 33 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -27,41 +26,9 @@
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class ModelView:
def __init__(self, model: NNCFNetwork):
self.model = model
self.nncf_module_additions = self.model.nncf.save_nncf_module_additions()

def __enter__(self):
# Model ref removed to prevent copying
self.model.nncf.update_model_ref(None)

# nncf_replaced_models removed to prevent copying
replaced_modules = self.model.nncf._nncf_replaced_modules
self.model.nncf._nncf_replaced_modules = None

self.nncf_interface = deepcopy(self.model.nncf)

# Model ref is recovering
self.model.nncf.update_model_ref(self.model)
self.nncf_interface.update_model_ref(self.model)

# nncf_replaced_models is recovering
self.model.nncf._nncf_replaced_modules = replaced_modules
self.nncf_interface._nncf_replaced_modules = replaced_modules
return self.model

def __exit__(self, exc_type, exc_val, exc_tb):
self.model._nncf = self.nncf_interface
self.model.nncf.reset_nncf_modules()
self.model.nncf.load_nncf_module_additions(self.nncf_module_additions)


class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
# with ModelView(model) as intermediate_model:
# super().collect_statistics(intermediate_model, graph)
super().collect_statistics(model, graph)
model.nncf.remove_temporary_ops()

Expand Down

0 comments on commit 0e35a0c

Please sign in to comment.