Add torch.set_num_threads for ptq tests (#1995)
### Changes

Set the number of CPU threads for the TORCH backends by calling
`torch.set_num_threads(int(cpu_threads_num))`, where the value is read from
the `CPU_THREADS_NUM` environment variable.
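
A minimal sketch of the environment-driven thread cap described above. The helper name `apply_cpu_threads` and its `set_num_threads` parameter are hypothetical, introduced only so the snippet runs without PyTorch installed. The check for values below 1 is also an assumption added here and is not part of this commit.

```python
import os
from typing import Callable, Mapping, Optional


def apply_cpu_threads(
    set_num_threads: Callable[[int], None],
    environ: Mapping[str, str] = os.environ,
) -> Optional[int]:
    """Cap the thread count if CPU_THREADS_NUM is set; otherwise do nothing."""
    raw = environ.get("CPU_THREADS_NUM")
    if raw is None:
        return None  # variable unset: leave the backend default untouched
    threads = int(raw)  # raises ValueError on non-numeric input
    if threads < 1:
        # Assumption for this sketch: reject non-positive values explicitly.
        raise ValueError(f"CPU_THREADS_NUM must be >= 1, got {threads}")
    set_num_threads(threads)
    return threads
```

With PyTorch available, calling `apply_cpu_threads(torch.set_num_threads)` would behave like the check added to `quantize()`: nothing happens unless `CPU_THREADS_NUM` is set.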

Example of reducing quantization time for TORCH backend:

Model | Before | After  
--- | --- | --- 
deit3_small_patch16_224 | 2:34:09  | 0:00:46
dla34 | 0:16:38  | 0:00:38
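
The speedup implied by the table can be worked out with a few lines of Python. This is only an illustrative calculation on the two rows above; the helper names `to_seconds` and `speedup` are hypothetical.

```python
def to_seconds(hms: str) -> int:
    """Convert an 'H:MM:SS' duration string to whole seconds."""
    hours, minutes, seconds = (int(part) for part in hms.strip().split(":"))
    return hours * 3600 + minutes * 60 + seconds


def speedup(before: str, after: str) -> float:
    """Ratio of the 'before' duration to the 'after' duration."""
    return to_seconds(before) / to_seconds(after)


# Timings copied from the table above.
timings = {
    "deit3_small_patch16_224": ("2:34:09", "0:00:46"),
    "dla34": ("0:16:38", "0:00:38"),
}
for model, (before, after) in timings.items():
    print(f"{model}: {speedup(before, after):.1f}x faster")
```

For these two rows the ratio is roughly 201x for `deit3_small_patch16_224` and roughly 26x for `dla34`.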

Build: manual/post_training_quantization/122

Renamed the test `test_ptq_hf` to `test_ptq_quantization`.

### Reason for changes

Quantization is extremely slow for TORCH backends on CI.
AlexanderDokuchaev authored Jul 24, 2023
1 parent b3f240d commit d842fa8
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions tests/post_training/pipelines/base.py
@@ -8,6 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import time
 from abc import ABC
 from abc import abstractmethod
@@ -212,6 +213,12 @@ def quantize(self) -> None:
         Run quantization of the model and collect time and memory usage information.
         """
         print("Quantization...")
+
+        if self.backend in [BackendType.TORCH, BackendType.OLD_TORCH]:
+            cpu_threads_num = os.environ.get("CPU_THREADS_NUM")
+            if cpu_threads_num is not None:
+                torch.set_num_threads(int(cpu_threads_num))
+
         start_time = time.perf_counter()
         self.run_info.quant_memory_usage = memory_usage(self._quantize, max_usage=True)
         self.run_info.time_quantization = time.perf_counter() - start_time
2 changes: 1 addition & 1 deletion tests/post_training/test_quantize_conformance.py
@@ -48,7 +48,7 @@ def read_reference_data():


 @pytest.mark.parametrize("test_case_name", TEST_CASES.keys())
-def test_ptq_hf(test_case_name, data, output, result):
+def test_ptq_quantization(test_case_name, data, output, result):
     pipeline = None
     err_msg = None
     test_model_param = None