Skip to content

Commit

Permalink
Follow up the diffusers task refactoring (#1999)
Browse files Browse the repository at this point in the history
* fix

* fix style
  • Loading branch information
JingyaHuang authored Aug 30, 2024
1 parent 23f8574 commit 3b55875
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1937,12 +1937,6 @@ def standardize_model_attributes(cls, model: Union["PreTrainedModel", "TFPreTrai
if inferred_model_type is not None:
break

if inferred_model_type is None:
raise ValueError(
f"The export of a DiffusionPipeline model with the class name {model.__class__.__name__} is currently not supported in Optimum. "
"Please open an issue or submit a PR to add the support."
)

# `model_type` is a class attribute in Transformers, let's avoid modifying it.
model.config.export_model_type = inferred_model_type

Expand Down Expand Up @@ -2068,9 +2062,16 @@ def get_model_from_task(
if original_task == "auto" and config.architectures is not None:
model_class_name = config.architectures[0]

model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)
if library_name == "diffusers":
config = DiffusionPipeline.load_config(model_name_or_path, **kwargs)
class_name = config.get("_class_name", None)
loaded_library = importlib.import_module(library_name)
model_class = getattr(loaded_library, class_name)
else:
model_class = TasksManager.get_model_class_for_task(
task, framework, model_type=model_type, model_class_name=model_class_name, library=library_name
)

if library_name == "timm":
model = model_class(f"hf_hub:{model_name_or_path}", pretrained=True, exportable=True)
model = model.to(torch_dtype).to(device)
Expand Down

0 comments on commit 3b55875

Please sign in to comment.