From c5e4e2239183c23aca2f65a84c993c05cc6fa417 Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Tue, 17 Dec 2024 05:14:52 +0200 Subject: [PATCH] f --- nncf/quantization/algorithms/pipeline.py | 14 +++++++------- .../algorithms/post_training/algorithm.py | 5 +++-- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/nncf/quantization/algorithms/pipeline.py b/nncf/quantization/algorithms/pipeline.py index 1f659e14125..0159955ae0c 100644 --- a/nncf/quantization/algorithms/pipeline.py +++ b/nncf/quantization/algorithms/pipeline.py @@ -115,7 +115,7 @@ def run_step( def run_from_step( self, - model: ModelWrapper, + model_wrapper: ModelWrapper, dataset: Dataset, start_step_index: int = 0, step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None, @@ -134,23 +134,23 @@ def run_from_step( :return: The updated model after executing the pipeline from the specified pipeline step to the end. """ - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model.model)) + pipeline_steps = self._remove_unsupported_algorithms(model_wrapper.backend) if step_index_to_statistics is None: step_index_to_statistics = {} # The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step - step_model = model + step_model_wrapper = model_wrapper for step_index in range(start_step_index, len(pipeline_steps)): # Collect statistics required to run current pipeline step step_statistics = step_index_to_statistics.get(step_index) if step_statistics is None: - statistic_points = self.get_statistic_points_for_step(step_index, step_model) - step_statistics = collect_statistics(statistic_points, step_model, dataset) + statistic_points = self.get_statistic_points_for_step(step_index, step_model_wrapper) + step_statistics = collect_statistics(statistic_points, step_model_wrapper, dataset) # Run current pipeline step - step_model = self.run_step(step_index, step_statistics, step_model) + step_model_wrapper = self.run_step(step_index, step_statistics, step_model_wrapper) - return step_model + return step_model_wrapper def get_statistic_points_for_step(self, step_index: int, model_wrapper: ModelWrapper) -> StatisticPointsContainer: """ diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 80879f361c5..dbfbf8cd80b 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -90,11 +90,12 @@ def available_backends(self) -> List[BackendType]: return list(backends) def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + assert False return self._pipeline.get_statistic_points_for_step(0, model, graph) def apply( self, - model: ModelWrapper, + model_wrapper: ModelWrapper, *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, @@ -109,4 +110,4 @@ def apply( if statistic_points: step_index_to_statistics = {0: statistic_points} - return self._pipeline.run_from_step(model, dataset, 0, step_index_to_statistics) + return self._pipeline.run_from_step(model_wrapper, dataset, 0, step_index_to_statistics)