Skip to content

Commit

Permalink
bring back functionalities that were lost in v2 during rebasing
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Dec 6, 2023
1 parent 39be9a0 commit dcab3f9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
31 changes: 10 additions & 21 deletions src/deepsparse/transformers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
from onnx import ModelProto

from deepsparse.log import get_main_logger
from deepsparse.utils.onnx import MODEL_ONNX_NAME, truncate_onnx_model
from sparsezoo import Model
from deepsparse.utils.onnx import MODEL_ONNX_NAME, model_to_path, truncate_onnx_model
from sparsezoo.utils import save_onnx


__all__ = [
"get_deployment_path",
"setup_transformers_pipeline",
"overwrite_transformer_onnx_model_inputs",
"fix_numpy_types",
Expand All @@ -62,12 +62,12 @@ def setup_transformers_pipeline(
:param sequence_length: The sequence length to use for the model
:param tokenizer_padding_side: The side to pad on for the tokenizer,
either "left" or "right"
:param engine_kwargs: The kwargs to pass to the engine
:param onnx_model_name: The name of the onnx model to be loaded.
If not specified, defaults are used (see setup_onnx_file_path)
:param engine_kwargs: The kwargs to pass to the engine
:return The model path, config, tokenizer, and engine kwargs
"""
model_path, config, tokenizer = setup_onnx_file_path(
model_path, config, tokenizer = fetch_onnx_file_path(
model_path, sequence_length, onnx_model_name
)

Expand All @@ -87,7 +87,7 @@ def setup_transformers_pipeline(
return model_path, config, tokenizer, engine_kwargs


def setup_onnx_file_path(
def fetch_onnx_file_path(
model_path: str,
sequence_length: int,
onnx_model_name: Optional[str] = None,
Expand All @@ -102,6 +102,7 @@ def setup_onnx_file_path(
:param onnx_model_name: optionally, the precise name of the ONNX model
of interest may be specified. If not specified, the default ONNX model
name will be used (refer to `get_deployment_path` for details)
:param task: task to use for the config. Defaults to None
    (NOTE(review): no `task` parameter exists in this function's signature —
    this docstring line appears stale; remove it or add the parameter)
:return: file path to the processed ONNX file for the engine to compile
"""
deployment_path, onnx_path = get_deployment_path(model_path, onnx_model_name)
Expand Down Expand Up @@ -148,6 +149,7 @@ def get_deployment_path(
the deployment directory
"""
onnx_model_name = onnx_model_name or MODEL_ONNX_NAME

if os.path.isfile(model_path):
# return the parent directory of the ONNX file
return os.path.dirname(model_path), model_path
Expand All @@ -163,22 +165,9 @@ def get_deployment_path(
)
return model_path, os.path.join(model_path, onnx_model_name)

elif model_path.startswith("zoo:"):
zoo_model = Model(model_path)
deployment_path = zoo_model.deployment_directory_path
return deployment_path, os.path.join(deployment_path, onnx_model_name)
elif model_path.startswith("hf:"):
from huggingface_hub import snapshot_download

deployment_path = snapshot_download(repo_id=model_path.replace("hf:", "", 1))
onnx_path = os.path.join(deployment_path, onnx_model_name)
if not os.path.isfile(onnx_path):
raise ValueError(
f"{onnx_model_name} not found in transformers model directory "
f"{deployment_path}. Be sure that an export of the model is written to "
f"{onnx_path}"
)
return deployment_path, onnx_path
elif model_path.startswith("zoo:") or model_path.startswith("hf:"):
onnx_model_path = model_to_path(model_path)
return os.path.dirname(onnx_model_path), onnx_model_path
else:
raise ValueError(
f"model_path {model_path} is not a valid file, directory, or zoo stub"
Expand Down
19 changes: 17 additions & 2 deletions src/deepsparse/utils/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def model_to_path(model: Union[str, Model, File]) -> str:

if Model is not object and isinstance(model, Model):
# trigger download and unzipping of deployment directory if not cached
model.deployment_directory_path
model.deployment.path

# default to the main onnx file for the model
model = model.deployment.get_file(MODEL_ONNX_NAME).path
Expand All @@ -138,6 +138,21 @@ def model_to_path(model: Union[str, Model, File]) -> str:
# get the downloaded_path -- will auto download if not on local system
model = model.path

if isinstance(model, str) and model.startswith("hf:"):
# load Hugging Face model from stub
from huggingface_hub import snapshot_download

deployment_path = snapshot_download(repo_id=model.replace("hf:", "", 1))
onnx_path = os.path.join(deployment_path, MODEL_ONNX_NAME)
if not os.path.isfile(onnx_path):
raise ValueError(
f"Could not find the ONNX model file '{MODEL_ONNX_NAME}' in the "
f"Hugging Face Hub repository located at {deployment_path}. Please "
f"ensure the model has been correctly exported to ONNX format and "
f"exists in the repository."
)
return onnx_path

if not isinstance(model, str):
raise ValueError("unsupported type for model: {}".format(type(model)))

Expand Down Expand Up @@ -549,7 +564,7 @@ def overwrite_onnx_model_inputs_for_kv_cache_models(
else:
raise ValueError(f"Unexpected external input name: {external_input.name}")

_LOGGER.info(
_LOGGER.debug(
"Overwriting in-place the input shapes "
f"of the transformer model at {onnx_file_path}"
)
Expand Down

0 comments on commit dcab3f9

Please sign in to comment.