From 6c1ff8f361addcf85442fb2021dce3e27d35a6bf Mon Sep 17 00:00:00 2001
From: Scott Todd
Date: Mon, 22 Jul 2024 15:08:56 -0700
Subject: [PATCH] Parallelize remote file downloads within each iree_tests
 directory. (#297)

Progress on https://github.com/nod-ai/SHARK-TestSuite/issues/285

This is a simple improvement over serial processing, but it could still be
pushed further (see the batching ideas in the code comment).

Looks like this shaves ~10 seconds off runs in this repo:

* Before, 45s: https://github.com/nod-ai/SHARK-TestSuite/actions/runs/9996869262/job/27632145027#step:6:15
* After, 35s: https://github.com/nod-ai/SHARK-TestSuite/actions/runs/9997455742/job/27634060037?pr=297#step:6:15

I saw 2m+ runs in IREE; hopefully this helps there too. It should be possible
to get the total time down to 10-20 seconds.
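For the record, the pattern used here is nothing exotic: bind the arguments
that are constant across files with `functools.partial`, then fan the URL list
out across worker processes with `multiprocessing.Pool.map`. A minimal
standalone sketch of the same pattern (`fetch`, the example URLs, and
`/tmp/tests` are illustrative stand-ins, not this script's real helpers):

```python
from functools import partial
from multiprocessing import Pool


def fetch(remote_file: str, test_dir: str, cache_dir=None):
    # Stub standing in for the real per-URL download helper.
    print(f"downloading {remote_file} -> {test_dir} (cache: {cache_dir})")


if __name__ == "__main__":
    remote_files = ["https://example.com/a.bin", "https://example.com/b.bin"]
    # partial() pins the arguments that are the same for every file, so
    # map() only varies remote_file across the pool's worker processes.
    with Pool(processes=8) as pool:
        pool.map(partial(fetch, test_dir="/tmp/tests"), remote_files)
```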
---
 iree_tests/download_remote_files.py | 47 ++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 11 deletions(-)

diff --git a/iree_tests/download_remote_files.py b/iree_tests/download_remote_files.py
index ec0de6053..c78894baf 100644
--- a/iree_tests/download_remote_files.py
+++ b/iree_tests/download_remote_files.py
@@ -5,7 +5,9 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from azure.storage.blob import BlobClient, BlobProperties
+from functools import partial
 from huggingface_hub import hf_hub_download
+from multiprocessing import Pool
 from pathlib import Path
 from typing import Optional
 import argparse
@@ -211,29 +213,44 @@ def download_generic_remote_file(
     raise NotImplementedError("generic remote file downloads not implemented yet")
 
 
+def download_file(remote_file: str, test_dir: Path, cache_dir: Optional[Path]):
+    """
+    Downloads a file from URL into test_dir, if the URL schema is supported.
+
+    If cache_dir is set, downloads there instead, creating a symlink from
+    test_dir/file_name to cache_dir/file_name.
+    """
+    if "blob.core.windows.net" in remote_file:
+        download_azure_remote_file(remote_file, test_dir, cache_dir)
+    elif "huggingface" in remote_file:
+        download_huggingface_remote_file(remote_file, test_dir, cache_dir)
+    else:
+        download_generic_remote_file(remote_file, test_dir, cache_dir)
+
+
 def download_files_for_test_case(
-    test_case_json: dict, test_dir: Path, cache_dir: Optional[Path]
+    test_case_json: dict, test_dir: Path, jobs: int, cache_dir: Optional[Path]
 ):
     if "remote_files" not in test_case_json:
         return
 
-    # This is naive (greedy, serial) for now. We could batch downloads that
-    # share a source:
+    # This is naive for now. We could further optimize with batching:
     #   * Iterate over all files (across all included paths), building a list
     #     of files to download (checking hashes / local references before
     #     adding to the list)
     #   * (Optionally) Determine disk space needed/available and ask before
     #     continuing
     #   * Group files based on source (e.g. Azure container)
-    #   * Start batched/parallel downloads
-    for remote_file in test_case_json["remote_files"]:
-        if "blob.core.windows.net" in remote_file:
-            download_azure_remote_file(remote_file, test_dir, cache_dir)
-        elif "huggingface" in remote_file:
-            download_huggingface_remote_file(remote_file, test_dir, cache_dir)
-        else:
-            download_generic_remote_file(remote_file, test_dir, cache_dir)
+    with Pool(jobs) as pool:
+        pool.map(
+            partial(
+                download_file,
+                test_dir=test_dir,
+                cache_dir=cache_dir,
+            ),
+            test_case_json["remote_files"],
+        )
 
 
 if __name__ == "__main__":
@@ -249,6 +266,13 @@ def download_files_for_test_case(
         help="Local cache directory to download into. If set, symlinks will be created pointing to "
         "this location",
     )
+    parser.add_argument(
+        "-j",
+        "--jobs",
+        type=int,
+        default=8,
+        help="Number of parallel processes to use when downloading files",
+    )
     args = parser.parse_args()
 
     # Adjust logging levels.
@@ -287,5 +311,6 @@ def download_files_for_test_case(
             download_files_for_test_case(
                 test_case_json=test_case_json,
                 test_dir=test_dir,
+                jobs=args.jobs,
                 cache_dir=cache_dir_for_test,
             )
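
Usage: the new flag defaults to 8 worker processes. A typical invocation from
a repo checkout might look like the following (hypothetical example; any other
arguments the script takes keep their existing defaults):

```bash
python iree_tests/download_remote_files.py --jobs 16
```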