Skip to content

Commit

Permalink
Remove StepwisePipeline class
Browse files Browse the repository at this point in the history
  • Loading branch information
andrey-churkin committed Sep 21, 2023
1 parent 1c8d511 commit 96e8d76
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 208 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.pipelines.stepwise_pipeline import StepwisePipeline
from nncf.quantization.pipelines.pipeline import Pipeline
from nncf.quantization.range_estimator import AggregatorType
from nncf.quantization.range_estimator import RangeEstimatorParameters
from nncf.quantization.range_estimator import StatisticsCollectorParameters
Expand Down Expand Up @@ -89,7 +89,7 @@ def _get_bias_correction_param_grid() -> ParamGrid:
return {"fast_bias_correction": [True, False]}


def get_quantization_param_grids(pipeline: StepwisePipeline) -> List[ParamGrid]:
def get_quantization_param_grids(pipeline: Pipeline) -> List[ParamGrid]:
"""
Returns params grid for post-training quantization algorithm.
"""
Expand Down
15 changes: 7 additions & 8 deletions nncf/quantization/pipelines/hyperparameter_tuner/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@
from nncf.quantization.algorithms.accuracy_control.rank_functions import create_normalized_mse_func
from nncf.quantization.algorithms.accuracy_control.subset_selection import select_subset
from nncf.quantization.pipelines.pipeline import Pipeline
from nncf.quantization.pipelines.stepwise_pipeline import StepwisePipeline
from nncf.quantization.pipelines.stepwise_pipeline import collect_statistics
from nncf.quantization.pipelines.stepwise_pipeline import get_statistic_points
from nncf.quantization.pipelines.stepwise_pipeline import run_pipeline_from_step
from nncf.quantization.pipelines.stepwise_pipeline import run_pipeline_step
from nncf.quantization.pipelines.pipeline import collect_statistics
from nncf.quantization.pipelines.pipeline import get_statistic_points
from nncf.quantization.pipelines.pipeline import run_pipeline_from_step
from nncf.quantization.pipelines.pipeline import run_pipeline_step

TModel = TypeVar("TModel")
TTensor = TypeVar("TTensor")
Expand Down Expand Up @@ -180,7 +179,7 @@ def find_best_combination(
return best_combination_key


class HyperparameterTuner(Pipeline):
class HyperparameterTuner:
"""
This algorithm is used to find a best combination of parameters from `param_grid`.
Expand Down Expand Up @@ -219,7 +218,7 @@ class HyperparameterTuner(Pipeline):

def __init__(
self,
pipeline_cls: Type[StepwisePipeline],
pipeline_cls: Type[Pipeline],
init_params: Dict[str, Any],
param_grids: List[Dict[str, List[Any]]],
calibration_dataset: Dataset,
Expand Down Expand Up @@ -255,7 +254,7 @@ def __init__(
self._error_fn = None

# Will be initialized inside `_prepare_pipeline_step()` method
self._pipelines: Dict[CombinationKey, StepwisePipeline] = {}
self._pipelines: Dict[CombinationKey, Pipeline] = {}
self._step_index_to_statistics: Dict[int, StatisticPointsContainer] = {}

self._calculated_scores: Dict[CombinationKey, float] = {}
Expand Down
169 changes: 158 additions & 11 deletions nncf/quantization/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,177 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from typing import TypeVar
from typing import Dict, List, Optional, TypeVar, Union

from nncf.common.factory import NNCFGraphFactory
from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.data.dataset import Dataset
from nncf.quantization.algorithms.algorithm import Algorithm

TModel = TypeVar("TModel")
PipelineStep = List[Algorithm]


class Pipeline(ABC):
def get_statistic_points(pipeline_step: PipelineStep, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer:
"""
A base class for creating pipelines that apply algorithms to a model.
TODO
This abstract class serves as an interface for creating custom model
processing pipelines that encapsulate a series of algorithms to be
applied to a model using a provided dataset.
:param pipeline_step:
:param model:
:param graph:
:return:
"""
container = StatisticPointsContainer()
for algorithm in pipeline_step:
for statistic_points in algorithm.get_statistic_points(model, graph).values():
for statistic_point in statistic_points:
container.add_statistic_point(statistic_point)

return container


def collect_statistics(
containers: Union[StatisticPointsContainer, List[StatisticPointsContainer]],
model: TModel,
graph: NNCFGraph,
dataset: Dataset,
) -> StatisticPointsContainer:
"""
TODO:
:param statistic_points:
:param model:
:param graph:
:param dataset:
:return:
"""
if not isinstance(containers, list):
containers = [containers]

statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset)
for container in containers:
statistics_aggregator.register_statistic_points(container)
statistics_aggregator.collect_statistics(model, graph)

return statistics_aggregator.statistic_points


class Pipeline:
"""
A class for creating pipelines that apply algorithms to a model.
This class is used for creating custom model processing pipelines
that encapsulate a series of algorithms to be applied to a model
using a provided dataset.
A pipeline consists of pipeline steps. Each pipeline step is a
sequence of Algorithm class instances whose statistic points are
combined and collected using the model obtained after the previous
pipeline step. The collected statistic points are used for all
algorithms in this step.
"""

def __init__(self, pipeline_steps: List[PipelineStep]):
"""
:param pipeline_steps: A sequence of pipeline steps to be executed in order.
"""
self._pipeline_steps = pipeline_steps

@property
def pipeline_steps(self) -> List[PipelineStep]:
"""
Property that defines the sequence of distinct pipeline steps to
be executed in order.
:return: A sequence of pipeline steps to be executed in order.
"""
return self._pipeline_steps

@abstractmethod
def run(self, model: TModel, dataset: Dataset) -> TModel:
"""
Abstract method that defines the sequence of algorithms to be
applied to the provided model using the provided dataset.
Executes the pipeline on the provided model.
:param model: A model to which pipeline will be applied.
:param dataset: A dataset that holds the data items for algorithms.
:return: The updated model after executing the entire pipeline.
"""
return run_pipeline_from_step(self, model, dataset)


def run_pipeline_step(
pipeline_step: PipelineStep,
pipeline_step_statistics: StatisticPointsContainer,
model: TModel,
graph: NNCFGraph,
) -> TModel:
"""
Executes a provided pipeline step on the provided model.
:param pipeline_step: A sequence of algorithms representing a pipeline step.
:param pipeline_step_statistics: Statistics required to execute a pipeline step.
:param model: A model to which a pipeline step will be applied.
:param graph: A graph assosiated with a model.
:return: The updated model after executing the pipeline step.
"""
current_model = model
current_graph = graph

for algorithm in pipeline_step[:-1]:
current_model = algorithm.apply(current_model, current_graph, pipeline_step_statistics)
current_graph = NNCFGraphFactory.create(current_model)
current_model = pipeline_step[-1].apply(current_model, current_graph, pipeline_step_statistics)

return current_model


def run_pipeline_from_step(
pipeline: Pipeline,
model: TModel,
dataset: Dataset,
graph: Optional[NNCFGraph] = None,
start_step_index: int = 0,
step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None,
) -> TModel:
"""
Execute the pipeline from the specified pipeline step to the end.
:param pipeline: A pipeline part of which should be executed.
:param model: This is the model after the (start_step_index - 1)-th pipeline
step, or the initial model if start_step_index is 0.
:param dataset: A dataset that holds the data items for pipeline steps.
:param graph: A graph assosiated with a model.
:param start_step_index: Zero-based pipeline step index from which the pipeline
should be executed.
:param step_index_to_statistics: A mapping from pipeline step index to statistics
required to execute pipeline step.
:return: The updated model after executing the pipeline from the specified pipeline
step to the end.
"""
if step_index_to_statistics is None:
step_index_to_statistics = {}

# The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step
step_model = model
step_graph = graph
step_index = start_step_index

for pipeline_step in pipeline.pipeline_steps[start_step_index:]:
# Create graph required to run current pipeline step
if step_graph is None:
step_graph = NNCFGraphFactory.create(step_model)

# Collect statistics required to run current pipeline step
step_statistics = step_index_to_statistics.get(step_index)
if step_statistics is None:
statistic_points = get_statistic_points(pipeline_step, step_model, step_graph)
step_statistics = collect_statistics(statistic_points, step_model, step_graph, dataset)

# Run current pipeline step
step_model = run_pipeline_step(pipeline_step, step_statistics, step_model, step_graph)

step_graph = None # We should rebuild the graph for the next pipeline step
step_index += 1

return step_model
6 changes: 3 additions & 3 deletions nncf/quantization/pipelines/post_training/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization
from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant
from nncf.quantization.pipelines.stepwise_pipeline import StepwisePipeline
from nncf.quantization.pipelines.pipeline import Pipeline
from nncf.scopes import IgnoredScope

TModel = TypeVar("TModel")
Expand All @@ -37,7 +37,7 @@ def create_ptq_pipeline(
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
) -> StepwisePipeline:
) -> Pipeline:
"""
Creates a post-training quantization pipeline.
Expand Down Expand Up @@ -136,4 +136,4 @@ def create_ptq_pipeline(
)
)

return StepwisePipeline(pipeline_steps)
return Pipeline(pipeline_steps)
Loading

0 comments on commit 96e8d76

Please sign in to comment.