Simplify pipelines logic, requiring adapter weights to be merged in
katalinic-gc committed Sep 1, 2023
1 parent 60fa56c commit 7427191
optimum/graphcore/pipelines/__init__.py: 14 additions, 20 deletions (1 changed file)
@@ -212,19 +212,13 @@ def list_tasks() -> List[str]:
     return sorted([*{*SUPPORTED_TASKS, *TASK_ALIASES}])
 
 
-def is_generation_model(model):
-    if isinstance(model, PeftModel):
-        model = model.get_base_model()
-    return isinstance(model, IPUGenerationMixin) or isinstance(model, WhisperForConditionalGeneration)
-
-
 def get_poplar_executor(
     task: str,
-    model: Union[PreTrainedModel, PeftModel],
+    model: PreTrainedModel,
     ipu_config: Union[IPUConfig, str, dict] = None,
     fp16: bool = True,
     for_generation: bool = False,
-) -> Union[PreTrainedModel, PeftModel]:
+) -> PreTrainedModel:
     ipu_config_arg = ipu_config
 
     if isinstance(ipu_config, str):
@@ -247,7 +241,7 @@ def get_poplar_executor(
             ipu_config.enable_half_partials = False
     try:
         model = to_pipelined(model, ipu_config, force=False)
-        if model.config.is_encoder_decoder and is_generation_model(model):
+        if model.config.is_encoder_decoder and isinstance(model, IPUGenerationMixin):
             if "use_cache" not in parallelize_kwargs and model.__class__ in MODELS_SUPPORTING_KV_CACHE:
                 parallelize_kwargs["use_cache"] = True
             model.parallelize(for_generation=for_generation, **parallelize_kwargs)
@@ -290,13 +284,8 @@ def check_model_type(self, supported_models: Union[List[str], dict]):
                 supported_models_names.append(model.__name__)
         supported_models = supported_models_names
 
-    if isinstance(self.model, PeftModel):
-        model_class_name = self.model.get_base_model().__class__.__name__
-    elif isinstance(self.model, poptorch.PoplarExecutor):
-        model = self.model._user_model
-        if isinstance(model, PeftModel):
-            model = model.get_base_model()
-        model_class_name = model.__class__.__bases__[0].__name__
+    if isinstance(self.model, poptorch.PoplarExecutor):
+        model_class_name = self.model._user_model.__class__.__bases__[0].__name__
     elif isinstance(self.model, IPUGenerationMixin):
         model_class_name = self.model.__class__.__bases__[0].__name__
     else:
@@ -387,21 +376,26 @@ def pipeline(
                 break
             except ValueError:
                 continue
-    elif isinstance(model, (PreTrainedModel, PeftModel)):
+    elif isinstance(model, PeftModel):
+        raise TypeError(
+            "Instead of providing `model` as an instance of `PeftModel`, please call `merge_and_unload()` if LoRA "
+            "or equivalent to obtain the original `PreTrainedModel` back with adapter weights merged in."
+        )
+    elif isinstance(model, PreTrainedModel):
         if tokenizer is None and load_tokenizer:
             raise ValueError("If you pass a model as a PreTrainedModel, you must pass a tokenizer as well")
         if feature_extractor is None and load_feature_extractor:
             raise ValueError("If you pass a model as a PreTrainedModel, you must pass a feature extractor as well")
 
     for_generation = targeted_task in SUPPORTED_GENERATION_TASKS
-    if isinstance(model, (PreTrainedModel, PeftModel)):
+    if isinstance(model, PreTrainedModel):
         if ipu_config is None:
             ipu_config = SUPPORTED_TASKS[targeted_task]["default"]["ipu_config"]
 
         parallelize_kwargs = parallelize_kwargs or {}
         # Task of automatic speech recognition is a bit of an edge case where it separates into CTC (not generation) and seq2seq (generation).
         # This check will do for now.
-        for_generation |= is_generation_model(model)
+        for_generation |= isinstance(model, WhisperForConditionalGeneration)
         model = get_poplar_executor(
             targeted_task, model, ipu_config=ipu_config, fp16=fp16, for_generation=for_generation, **parallelize_kwargs
         )
@@ -453,7 +447,7 @@ def new_forward(self, model_inputs, *args, **kwargs):
            if compiled_bs != input_bs:
                poplar_executor.destroy()

-        if isinstance(self.model, poptorch.PoplarExecutor) or is_generation_model(self.model):
+        if isinstance(self.model, poptorch.PoplarExecutor) or isinstance(self.model, IPUGenerationMixin):
            if fp16:
                # Support fp16
                for key, input in model_inputs.items():
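
Taken together, a post-change call might look like this sketch (the task name is illustrative, and the import path is assumed from this repository's layout):

    from optimum.graphcore import pipeline

    pipe = pipeline(
        "text-generation",   # illustrative task
        model=merged_model,  # must be a plain PreTrainedModel; a PeftModel now raises TypeError
        fp16=True,
    )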
