Skip to content

Commit

Permalink
Do not download the onnx data file if it exists
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld committed Nov 1, 2024
1 parent d21f2db commit 1d8dd01
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions utils/python/transformers/run_gpt2_from_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,32 +44,38 @@
decoder_with_past_model_name = "decoder_with_past_model.onnx"
config_json_name = "config.json"
decoder_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_model_name}"
decoder_data_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_model_name}_data"
decoder_with_past_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_with_past_model_name}"
decoder_with_past_data_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{decoder_with_past_model_name}_data"
config_json_url = f"https://huggingface.co/openai-community/{model_name_or_path}/resolve/main/onnx/{config_json_name}"

# Local directories for caching the model.
cache_dir = "./"
decoder_model_path = f"{cache_dir}/{decoder_model_name}"
decoder_data_path = f"{cache_dir}/{decoder_model_name}_data"
decoder_with_past_model_path = f"{cache_dir}/{decoder_with_past_model_name}"
decoder_with_past_data_path = f"{cache_dir}/{decoder_with_past_model_name}_data"
config_json_path = f"{cache_dir}/{config_json_name}"

# Download the model to a local dir.
if not os.path.exists(decoder_model_path):
print(f"Downloading {decoder_url}")
urlretrieve(decoder_url, decoder_model_path)
print("Done")
if req.head(f"{decoder_url}_data", allow_redirects=True).status_code == 200:
print(f"Downloading {decoder_url}_data")
urlretrieve(decoder_url + "_data", decoder_model_path + "_data")
print("Done")
if not os.path.exists(decoder_data_path):
if req.head(decoder_data_url, allow_redirects=True).status_code == 200:
print(f"Downloading {decoder_data_url}")
urlretrieve(decoder_data_url, decoder_data_path)
print("Done")
if not os.path.exists(decoder_with_past_model_path):
print(f"Downloading {decoder_with_past_url}")
urlretrieve(decoder_with_past_url, decoder_with_past_model_path)
print("Done")
if req.head(f"{decoder_with_past_url}_data", allow_redirects=True).status_code == 200:
print(f"Downloading {decoder_with_past_url}_data")
urlretrieve(decoder_with_past_url + "_data", decoder_with_past_model_path + "_data")
print("Done")
if not os.path.exists(decoder_with_past_data_path):
if req.head(decoder_with_past_data_url, allow_redirects=True).status_code == 200:
print(f"Downloading {decoder_with_past_data_url}")
urlretrieve(decoder_with_past_data_url, decoder_with_past_data_path)
print("Done")
if not os.path.exists(config_json_path):
print(f"Downloading the config json file {config_json_url}")
urlretrieve(config_json_url, config_json_path)
Expand Down

0 comments on commit 1d8dd01

Please sign in to comment.