diff --git a/genai_stack/model/hf.py b/genai_stack/model/hf.py
index 95253494..48c4299a 100644
--- a/genai_stack/model/hf.py
+++ b/genai_stack/model/hf.py
@@ -1,5 +1,6 @@
 from typing import Optional, Dict
 from langchain.llms import HuggingFacePipeline
+from transformers import Pipeline
 
 from genai_stack.model.base import BaseModel, BaseModelConfig, BaseModelConfigModel
 
@@ -17,6 +18,8 @@ class HuggingFaceModelConfigModel(BaseModelConfigModel):
     """Key word arguments passed to the pipeline."""
     task: str = "text-generation"
     """Valid tasks: 'text2text-generation', 'text-generation', 'summarization'"""
+    pipeline: Optional[Pipeline] = None
+    """If pipeline is passed, all other configs are ignored."""
 
 
 class HuggingFaceModelConfig(BaseModelConfig):
@@ -30,9 +33,14 @@ def _post_init(self, *args, **kwargs):
         self.model = self.load()
 
     def load(self):
-        model = HuggingFacePipeline.from_model_id(
-            model_id=self.config.model, task=self.config.task, model_kwargs=self.config.model_kwargs
-        )
+        if self.config.pipeline is not None:
+            model = self.config.pipeline
+        else:
+            model = HuggingFacePipeline.from_model_id(
+                model_id=self.config.model,
+                task=self.config.task,
+                model_kwargs=self.config.model_kwargs,
+            )
         return model
 
     def predict(self, prompt: str):