Remove the deprecated quantization tool
Signed-off-by: Jeremy Fowers <[email protected]>
jeremyfowers committed Dec 4, 2023
1 parent 3ceda8a commit a6c7be7
Showing 7 changed files with 17 additions and 244 deletions.
1 change: 0 additions & 1 deletion docs/coverage.md
@@ -41,7 +41,6 @@ Name Stmts Miss Branch BrPart Cover Mi
--------------------------------------------------------------------------------------------------------
\turnkeyml\build\__init__.py 0 0 0 0 100%
\turnkeyml\build\onnx_helpers.py 70 34 28 2 45% 15-21, 28-87, 92, 95-100
\turnkeyml\build\quantization_helpers.py 29 20 18 0 19% 13-30, 35, 50-78
\turnkeyml\build\sequences.py 15 1 8 2 87% 62->61, 65
\turnkeyml\build\tensor_helpers.py 47 26 34 4 41% 17-44, 57, 61, 63-74, 78
\turnkeyml\build_api.py 31 9 8 3 64% 68-71, 120-125, 140-147
63 changes: 1 addition & 62 deletions src/turnkeyml/build/export.py
@@ -15,7 +15,6 @@
import turnkeyml.common.build as build
import turnkeyml.build.tensor_helpers as tensor_helpers
import turnkeyml.build.onnx_helpers as onnx_helpers
import turnkeyml.build.quantization_helpers as quant_helpers
import turnkeyml.common.filesystem as fs


@@ -77,13 +76,6 @@ def converted_onnx_file(state: build.State):
)


def quantized_onnx_file(state: build.State):
    return os.path.join(
        onnx_dir(state),
        f"{state.config.build_name}-op{state.config.onnx_opset}-opt-quantized_int8.onnx",
    )


class ExportPlaceholder(stage.Stage):
"""
Placeholder Stage that should be replaced by a framework-specific export stage,
@@ -571,9 +563,8 @@ def fire(self, state: build.State):
        inputs_file = state.original_inputs_file
        if os.path.isfile(inputs_file):
            inputs = np.load(inputs_file, allow_pickle=True)
            to_downcast = False if state.quantization_samples else True
            inputs_converted = tensor_helpers.save_inputs(
                inputs, inputs_file, downcast=to_downcast
                inputs, inputs_file, downcast=True
            )
        else:
            raise exp.StageError(
@@ -621,58 +612,6 @@ def fire(self, state: build.State):
return state


class QuantizeONNXModel(stage.Stage):
    """
    Stage that takes an ONNX model and a dataset of quantization samples as inputs,
    and performs static post-training quantization to the model to int8 precision.
    Expected inputs:
     - state.model is a path to the ONNX model
     - state.quantization_dataset is a dataset that is used for static quantization
    Outputs:
     - A *_quantized.onnx file => the quantized onnx model.
    """

    def __init__(self):
        super().__init__(
            unique_name="quantize_onnx",
            monitor_message="Quantizing ONNX model",
        )

    def fire(self, state: build.State):
        input_path = state.intermediate_results[0]
        output_path = quantized_onnx_file(state)

        quant_helpers.quantize(
            input_file=input_path,
            data=state.quantization_samples,
            output_file=output_path,
        )

        # Check that the converted model is still valid
        success_msg = "\tSuccess quantizing ONNX model to int8"
        fail_msg = "\tFailed quantizing ONNX model to int8"

        if check_model(output_path, success_msg, fail_msg):
            state.intermediate_results = [output_path]

            stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
            stats.add_build_stat(
                fs.Keys.ONNX_FILE,
                output_path,
            )
        else:
            msg = f"""
            Attempted to use {state.quantization_dataset} to statically quantize
            model to int8 datatype, however this operation was not successful.
            More information may be available in the log file at **{self.logfile_path}**
            """
            raise exp.StageError(msg)

        return state


class SuccessStage(stage.Stage):
"""
Stage that sets state.build_status = build.Status.SUCCESSFUL_BUILD,
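Note on the removed stage: the deleted quantization_helpers.py is not included in this diff, so the exact body of quant_helpers.quantize() is not visible here. As a minimal, hedged sketch only, a static int8 post-training quantization step equivalent to what QuantizeONNXModel.fire() invoked could be written with onnxruntime's quantization utilities roughly as follows. The SampleReader class, the quantize() name, and the assumption that each calibration sample is already a dict mapping ONNX input names to numpy arrays are illustrative choices, not the removed implementation.

from onnxruntime.quantization import CalibrationDataReader, QuantType, quantize_static


class SampleReader(CalibrationDataReader):
    """Feeds calibration samples to the quantizer, one dict of inputs per step."""

    def __init__(self, samples):
        # Assumption: each sample already maps graph input names to numpy arrays
        self._samples = iter(list(samples))

    def get_next(self):
        # Return the next calibration batch, or None when the samples are exhausted
        return next(self._samples, None)


def quantize(input_file: str, data, output_file: str) -> None:
    """Statically quantize an ONNX model to int8 using the provided samples (sketch)."""
    quantize_static(
        model_input=input_file,
        model_output=output_file,
        calibration_data_reader=SampleReader(data),
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
    )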
64 changes: 5 additions & 59 deletions src/turnkeyml/build/ignition.py
@@ -258,7 +258,6 @@ def load_or_make_state(
monitor: bool,
model: build.UnionValidModelInstanceTypes = None,
inputs: Optional[Dict[str, Any]] = None,
quantization_samples: Optional[Collection] = None,
state_type: Type = build.State,
cache_validation_func: Callable = validate_cached_model,
extra_state_args: Optional[Dict] = None,
@@ -280,7 +279,6 @@
"cache_dir": cache_dir,
"config": config,
"model_type": model_type,
"quantization_samples": quantization_samples,
}

# Ensure that `rebuild` has a valid value
@@ -306,50 +304,6 @@
state_type=state_type,
)

        # if the previous build is using quantization while the current is not
        # or vice versa
        if state.quantization_samples and quantization_samples is None:
            if rebuild == "never":
                msg = (
                    f"Model {config.build_name} was built in a previous call to "
                    "build_model() with post-training quantization sample enabled."
                    "However, post-training quantization is not enabled in the "
                    "current build. Rebuild is necessary but currently the rebuild"
                    "policy is set to 'never'. "
                )
                raise exp.CacheError(msg)

            msg = (
                f"Model {config.build_name} was built in a previous call to "
                "build_model() with post-training quantization sample enabled."
                "However, post-training quantization is not enabled in the "
                "current build. Starting a fresh build."
            )

            printing.log_info(msg)
            return _begin_fresh_build(state_args, state_type)

        if not state.quantization_samples and quantization_samples is not None:
            if rebuild == "never":
                msg = (
                    f"Model {config.build_name} was built in a previous call to "
                    "build_model() with post-training quantization sample disabled."
                    "However, post-training quantization is enabled in the "
                    "current build. Rebuild is necessary but currently the rebuild"
                    "policy is set to 'never'. "
                )
                raise exp.CacheError(msg)

            msg = (
                f"Model {config.build_name} was built in a previous call to "
                "build_model() with post-training quantization sample disabled."
                "However, post-training quantization is enabled in the "
                "current build. Starting a fresh build."
            )

            printing.log_info(msg)
            return _begin_fresh_build(state_args, state_type)

except exp.StateError as e:
problem = (
"- build_model() failed to load "
@@ -500,7 +454,6 @@ def model_intake(
user_model,
user_inputs,
user_sequence: Optional[stage.Sequence],
user_quantization_samples: Optional[Collection] = None,
) -> Tuple[Any, Any, stage.Sequence, build.ModelType, str]:
# Model intake structure options:
# user_model
@@ -550,18 +503,11 @@

    sequence = copy.deepcopy(user_sequence)
    if sequence is None:
        if user_quantization_samples:
            if model_type != build.ModelType.PYTORCH:
                raise exp.IntakeError(
                    "Currently, post training quantization only supports Pytorch models."
                )
            sequence = sequences.pytorch_with_quantization
        else:
            sequence = stage.Sequence(
                "top_level_sequence",
                "Top Level Sequence",
                [sequences.onnx_fp32],
            )
        sequence = stage.Sequence(
            "top_level_sequence",
            "Top Level Sequence",
            [sequences.onnx_fp32],
        )

# If there is an ExportPlaceholder Stage in the sequence, replace it with
# a framework-specific export Stage.
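With the quantization branch removed, model_intake() now always falls back to the plain onnx_fp32 sequence. A user who still wants an int8 step in the pipeline could, in principle, supply an explicit sequence instead of relying on a default. The sketch below mirrors the removed pytorch_with_quantization sequence; it assumes the import paths shown, that ExportPytorchModel, OptimizeOnnxModel, and SuccessStage remain available in export.py (only QuantizeONNXModel is deleted by this commit), and that a user-defined Stage with a fire(state) method is an accepted extension point. None of this is prescribed by the commit itself.

import turnkeyml.build.export as export
import turnkeyml.build.stage as stage  # module path assumed
import turnkeyml.common.build as build


class MyQuantizeStage(stage.Stage):
    """User-defined replacement for the removed QuantizeONNXModel stage."""

    def __init__(self):
        super().__init__(
            unique_name="my_quantize_onnx",
            monitor_message="Quantizing ONNX model (user-defined)",
        )

    def fire(self, state: build.State):
        onnx_path = state.intermediate_results[0]  # path to the optimized ONNX file
        # ... quantize onnx_path here, e.g. with the quantize() sketch above ...
        return state


pytorch_with_custom_quantization = stage.Sequence(
    "pytorch_export_sequence_with_custom_quantization",
    "Exporting PyTorch Model and Quantizing Exported ONNX (user-defined)",
    [
        export.ExportPytorchModel(),
        export.OptimizeOnnxModel(),
        MyQuantizeStage(),
        export.SuccessStage(),
    ],
    enable_model_validation=True,
)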
78 changes: 0 additions & 78 deletions src/turnkeyml/build/quantization_helpers.py

This file was deleted.

12 changes: 0 additions & 12 deletions src/turnkeyml/build/sequences.py
@@ -35,18 +35,6 @@
enable_model_validation=True,
)

pytorch_with_quantization = stage.Sequence(
    "pytorch_export_sequence_with_quantization",
    "Exporting PyTorch Model and Quantizing Exported ONNX",
    [
        export.ExportPytorchModel(),
        export.OptimizeOnnxModel(),
        export.QuantizeONNXModel(),
        export.SuccessStage(),
    ],
    enable_model_validation=True,
)

# Plugin interface for sequences
discovered_plugins = plugins.discover()

10 changes: 0 additions & 10 deletions src/turnkeyml/build_api.py
@@ -17,7 +17,6 @@ def build_model(
monitor: Optional[bool] = None,
rebuild: Optional[str] = None,
sequence: Optional[List[stage.Stage]] = None,
quantization_samples: Collection = None,
onnx_opset: Optional[int] = None,
device: Optional[str] = None,
) -> build.State:
@@ -48,11 +47,6 @@
- None: Falls back to default
sequence: Override the default sequence of build stages. Power
users only.
quantization_samples: If set, performs post-training quantization
on the ONNX model using the provided samples. If the previous build used samples
that are different from the samples used in the current build, the "rebuild"
argument needs to be manually set to "always" in the current build
in order to create a new ONNX file.
onnx_opset: ONNX opset to use during ONNX export.
device: Specific device target to take into account during the build sequence.
Use the format "device_family", "device_family::part", or
@@ -96,7 +90,6 @@
model,
inputs,
sequence,
user_quantization_samples=quantization_samples,
)

# Get the state of the model from the cache if a valid build is available
@@ -109,7 +102,6 @@
monitor=monitor_setting,
model=model_locked,
inputs=inputs_locked,
quantization_samples=quantization_samples,
)

# Return a cached build if possible, otherwise prepare the model State for
@@ -124,8 +116,6 @@

return state

state.quantization_samples = quantization_samples

sequence_locked.show_monitor(config, state.monitor)
state = sequence_locked.launch(state)

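Since build_model() no longer accepts quantization_samples, the natural migration path is to quantize after the build rather than inside it. The usage sketch below is hedged: it assumes build_model is importable from the turnkeyml package root, that the returned State still exposes the final ONNX path via state.intermediate_results[0] (as the removed stage did), and it reuses the illustrative quantize() helper sketched earlier on this page. The toy model, input names, and file-naming convention are made up for the example.

import torch
from turnkeyml import build_model  # top-level import path assumed


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyModel()
inputs = {"x": torch.randn(1, 10)}

# Build the fp32 ONNX model as usual, then quantize the artifact it produced
state = build_model(model, inputs)
fp32_onnx = state.intermediate_results[0]  # assumed to hold the final ONNX path

calibration_samples = [{"x": torch.randn(1, 10).numpy()} for _ in range(16)]
# quantize() refers to the illustrative helper sketched above, not a turnkeyml API
quantize(
    input_file=fp32_onnx,
    data=calibration_samples,
    output_file=fp32_onnx.replace(".onnx", "-quantized_int8.onnx"),
)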