Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]8205-Add qat support #8209

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Trainer(Workflow):

"""

def run(self) -> None: # type: ignore[override]
def run(self, *args) -> None: # type: ignore[override]
"""
Execute training based on Ignite Engine.
If call this function multiple times, it will continuously run from the previous state.
Expand Down
2 changes: 2 additions & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from .metrics_reloaded_handler import MetricsReloadedBinaryHandler, MetricsReloadedCategoricalHandler
from .metrics_saver import MetricsSaver
from .mlflow_handler import MLFlowHandler
from .model_calibrator import ModelCalibrater
from .model_quantizer import ModelQuantizer
from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
from .panoptic_quality import PanopticQuality
from .parameter_scheduler import ParamSchedulerHandler
Expand Down
62 changes: 62 additions & 0 deletions monai/handlers/model_calibrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

import modelopt.torch.quantization as mtq
import torch

from monai.utils import IgniteInfo, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class ModelCalibrater:
    """
    Post-training-quantization calibration handler. It wraps a model and, when the attached
    Ignite engine starts, runs the engine once as the calibration forward loop for
    ``modelopt.torch.quantization.quantize``, then saves the quantized model's ``state_dict``
    to ``export_path``.

    Args:
        model: the model to be calibrated and quantized.
        export_path: file path where the quantized model's ``state_dict`` is saved.
        config: the calibration/quantization config passed to ``mtq.quantize``.
            Defaults to ``mtq.INT8_SMOOTHQUANT_CFG``.

    """

    def __init__(self, model: torch.nn.Module, export_path: str, config: dict = mtq.INT8_SMOOTHQUANT_CFG) -> None:
        self.model = model
        self.export_path = export_path
        self.config = config

    def attach(self, engine: Engine) -> None:
        """
        Register this handler so that calibration runs when the engine starts.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        engine.add_event_handler(Events.STARTED, self)

    @staticmethod
    def _model_wrapper(engine: Engine, model: torch.nn.Module) -> None:
        # ``mtq.quantize`` invokes this as ``forward_loop(model)``; the engine run performs the
        # forward passes that collect calibration statistics, so ``model`` is intentionally unused.
        # NOTE(review): this re-enters ``engine.run()`` from inside an ``Events.STARTED`` handler
        # of the same engine — confirm the engine supports re-entrant ``run`` calls.
        engine.run()

    def __call__(self, engine: Engine) -> None:
        """Calibrate ``self.model`` using the engine as the forward loop and save the result."""
        # partial binds the engine so mtq.quantize can call the forward loop as fn(model)
        quant_fun = partial(self._model_wrapper, engine)
        model = mtq.quantize(self.model, self.config, quant_fun)
        torch.save(model.state_dict(), self.export_path)
71 changes: 71 additions & 0 deletions monai/handlers/model_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence
from types import MethodType
from typing import TYPE_CHECKING

import torch
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer, get_symmetric_quantization_config

from monai.utils import IgniteInfo, min_version, optional_import

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")


class ModelQuantizer:
"""
Model quantizer is for model quantization. It takes a model as input and convert it to a quantized
model.

Args:
model: the model to be quantized.
example_inputs: the example inputs for the model quantization. examples::
(torch.randn(256,256,256),)
quantizer: quantizer for the quantization job.

"""

def __init__(
self, model: torch.nn.Module, example_inputs: Sequence, export_path: str, quantizer: Quantizer | None = None
) -> None:
self.model = model
self.example_inputs = example_inputs
self.export_path = export_path
self.quantizer = (
XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) if quantizer is None else quantizer
)

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.STARTED, self.start)
engine.add_event_handler(Events.ITERATION_COMPLETED, self.epoch)

def start(self) -> None:
self.model = torch.export.export_for_training(self.model, self.example_inputs).module()
self.model = prepare_qat_pt2e(self.model, self.quantizer)
self.model.train = MethodType(torch.ao.quantization.move_exported_model_to_train, self.model)
self.model.eval = MethodType(torch.ao.quantization.move_exported_model_to_eval, self.model)

def epoch(self) -> None:
torch.save(self.model.state_dict(), self.export_path)
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ pyamg>=5.0.0
git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
onnx_graphsurgeon
polygraphy
nvidia-modelopt>=0.19.0
Loading