From d842fa8c85402ff71b1287078d96d332fe7788ee Mon Sep 17 00:00:00 2001 From: Alexander Dokuchaev Date: Mon, 24 Jul 2023 12:55:21 +0300 Subject: [PATCH] Add torch.set_num_threads for ptq tests (#1995) ### Changes Set number of thread for torch backends for `torch.set_num_threads(int(cpu_threads_num))` by set `CPU_THREADS_NUM` env variable. Example of reducing quantization time for TORCH backend: Model | Before | After --- | --- | --- deit3_small_patch16_224 | 2:34:09 | 0:00:46 dla34 | 0:16:38 | 0:00:38 Build: manual/post_training_quantization/122 Renamed tests to `test_ptq_quantization` ### Reason for changes Quantization time is dramatically slow for TORCH backends on CI. --- tests/post_training/pipelines/base.py | 7 +++++++ tests/post_training/test_quantize_conformance.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/post_training/pipelines/base.py b/tests/post_training/pipelines/base.py index eb1f92bd5e5..1c377f98445 100644 --- a/tests/post_training/pipelines/base.py +++ b/tests/post_training/pipelines/base.py @@ -8,6 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import time from abc import ABC from abc import abstractmethod @@ -212,6 +213,12 @@ def quantize(self) -> None: Run quantization of the model and collect time and memory usage information. """ print("Quantization...") + + if self.backend in [BackendType.TORCH, BackendType.OLD_TORCH]: + cpu_threads_num = os.environ.get("CPU_THREADS_NUM") + if cpu_threads_num is not None: + torch.set_num_threads(int(cpu_threads_num)) + start_time = time.perf_counter() self.run_info.quant_memory_usage = memory_usage(self._quantize, max_usage=True) self.run_info.time_quantization = time.perf_counter() - start_time diff --git a/tests/post_training/test_quantize_conformance.py b/tests/post_training/test_quantize_conformance.py index 89e9bc8d52f..2f4620a2718 100644 --- a/tests/post_training/test_quantize_conformance.py +++ b/tests/post_training/test_quantize_conformance.py @@ -48,7 +48,7 @@ def read_reference_data(): @pytest.mark.parametrize("test_case_name", TEST_CASES.keys()) -def test_ptq_hf(test_case_name, data, output, result): +def test_ptq_quantization(test_case_name, data, output, result): pipeline = None err_msg = None test_model_param = None