Add a retry mechanism to the model downloader (#5943)
oobabooga authored Apr 27, 2024
1 parent dfdb6fe · commit 5770e06
Showing 1 changed file with 59 additions and 42 deletions.

download-model.py: 101 changes (59 additions & 42 deletions)
@@ -15,10 +15,12 @@
 import re
 import sys
 from pathlib import Path
+from time import sleep

 import requests
 import tqdm
 from requests.adapters import HTTPAdapter
+from requests.exceptions import ConnectionError, RequestException, Timeout
 from tqdm.contrib.concurrent import thread_map

 base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@@ -177,50 +179,65 @@ def get_output_folder(self, model, branch, is_lora, is_llamacpp=False):
         return output_folder

     def get_single_file(self, url, output_folder, start_from_scratch=False):
-        session = self.get_session()
         filename = Path(url.rsplit('/', 1)[1])
         output_path = output_folder / filename
-        headers = {}
-        mode = 'wb'
-        if output_path.exists() and not start_from_scratch:
-
-            # Check if the file has already been downloaded completely
-            r = session.get(url, stream=True, timeout=10)
-            total_size = int(r.headers.get('content-length', 0))
-            if output_path.stat().st_size >= total_size:
-                return
-
-            # Otherwise, resume the download from where it left off
-            headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-            mode = 'ab'
-
-        with session.get(url, stream=True, headers=headers, timeout=10) as r:
-            r.raise_for_status()  # Do not continue the download if the request was unsuccessful
-            total_size = int(r.headers.get('content-length', 0))
-            block_size = 1024 * 1024  # 1MB
-
-            tqdm_kwargs = {
-                'total': total_size,
-                'unit': 'iB',
-                'unit_scale': True,
-                'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
-            }
-
-            if 'COLAB_GPU' in os.environ:
-                tqdm_kwargs.update({
-                    'position': 0,
-                    'leave': True
-                })
-
-            with open(output_path, mode) as f:
-                with tqdm.tqdm(**tqdm_kwargs) as t:
-                    count = 0
-                    for data in r.iter_content(block_size):
-                        t.update(len(data))
-                        f.write(data)
-                        if total_size != 0 and self.progress_bar is not None:
-                            count += len(data)
-                            self.progress_bar(float(count) / float(total_size), f"{filename}")
+
+        max_retries = 7
+        attempt = 0
+        while attempt < max_retries:
+            attempt += 1
+            session = self.get_session()
+            headers = {}
+            mode = 'wb'
+
+            if output_path.exists() and not start_from_scratch:
+                # Resume download
+                r = session.get(url, stream=True, timeout=20)
+                total_size = int(r.headers.get('content-length', 0))
+                if output_path.stat().st_size >= total_size:
+                    return
+
+                headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+                mode = 'ab'
+
+            try:
+                with session.get(url, stream=True, headers=headers, timeout=30) as r:
+                    r.raise_for_status()  # If status is not 2xx, raise an error
+                    total_size = int(r.headers.get('content-length', 0))
+                    block_size = 1024 * 1024  # 1MB
+
+                    tqdm_kwargs = {
+                        'total': total_size,
+                        'unit': 'iB',
+                        'unit_scale': True,
+                        'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
+                    }
+
+                    if 'COLAB_GPU' in os.environ:
+                        tqdm_kwargs.update({
+                            'position': 0,
+                            'leave': True
+                        })
+
+                    with open(output_path, mode) as f:
+                        with tqdm.tqdm(**tqdm_kwargs) as t:
+                            count = 0
+                            for data in r.iter_content(block_size):
+                                f.write(data)
+                                t.update(len(data))
+                                if total_size != 0 and self.progress_bar is not None:
+                                    count += len(data)
+                                    self.progress_bar(float(count) / float(total_size), f"{filename}")
+
+                    break  # Exit loop if successful
+            except (RequestException, ConnectionError, Timeout) as e:
+                print(f"Error downloading {filename}: {e}.")
+                print(f"That was attempt {attempt}/{max_retries}.", end=' ')
+                if attempt < max_retries:
+                    print(f"Retry begins in {2 ** attempt} seconds.")
+                    sleep(2 ** attempt)
+                else:
+                    print("Failed to download after the maximum number of attempts.")

     def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
         thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
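
For context, the change above wraps the whole download in a retry loop with exponential backoff: after a failed attempt the script sleeps 2 ** attempt seconds (2, 4, 8, ... up to 128 before the final try). A minimal, self-contained sketch of the same pattern, where the helper name fetch_with_retries is hypothetical and not part of the commit:

from time import sleep

import requests
from requests.exceptions import RequestException


def fetch_with_retries(url, max_retries=7, timeout=30):
    # Sketch of the commit's retry loop: attempt the request up to
    # max_retries times, sleeping 2 ** attempt seconds after each failure.
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            r = requests.get(url, timeout=timeout)
            r.raise_for_status()  # non-2xx responses count as failures
            return r.content
        except RequestException as e:  # base class of ConnectionError and Timeout
            print(f"Attempt {attempt}/{max_retries} failed: {e}")
            if attempt < max_retries:
                sleep(2 ** attempt)  # exponential backoff: 2, 4, 8, ... seconds
    raise RuntimeError(f"Failed to download {url} after {max_retries} attempts")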
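
The resume branch relies on a standard HTTP Range request: when a partial file already exists, the script asks the server for only the remaining bytes and opens the file in append mode. A bare-bones sketch of that mechanism, with a placeholder URL and filename (the server must support Range requests; otherwise it returns the whole file with status 200 instead of 206):

from pathlib import Path

import requests

url = "https://example.com/some-model.bin"  # placeholder, not from the commit
output_path = Path("some-model.bin")

headers = {}
mode = 'wb'
if output_path.exists():
    # Request only the bytes we do not have yet
    headers = {'Range': f'bytes={output_path.stat().st_size}-'}
    mode = 'ab'  # append to the partial file instead of overwriting it

with requests.get(url, stream=True, headers=headers, timeout=30) as r:
    r.raise_for_status()
    with open(output_path, mode) as f:
        for chunk in r.iter_content(1024 * 1024):  # 1 MB blocks, as in the script
            f.write(chunk)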
