This repository has been archived by the owner on Jul 12, 2024. It is now read-only.

Commit

Add pipeline as kwargs to hf models
Akshaj000 committed Oct 22, 2023
1 parent: afad25a · commit: 18e1439
Showing 1 changed file with 11 additions and 3 deletions.
genai_stack/model/hf.py: 14 changes (11 additions & 3 deletions)
@@ -1,5 +1,6 @@
 from typing import Optional, Dict
 from langchain.llms import HuggingFacePipeline
+from transformers import pipeline
 
 from genai_stack.model.base import BaseModel, BaseModelConfig, BaseModelConfigModel
 
@@ -17,6 +18,8 @@ class HuggingFaceModelConfigModel(BaseModelConfigModel):
     """Key word arguments passed to the pipeline."""
     task: str = "text-generation"
     """Valid tasks: 'text2text-generation', 'text-generation', 'summarization'"""
+    pipeline: Optional[pipeline] = None
+    """If pipeline is passed, all other configs are ignored."""
 
 
 class HuggingFaceModelConfig(BaseModelConfig):
@@ -30,9 +33,14 @@ def _post_init(self, *args, **kwargs):
         self.model = self.load()
 
     def load(self):
-        model = HuggingFacePipeline.from_model_id(
-            model_id=self.config.model, task=self.config.task, model_kwargs=self.config.model_kwargs
-        )
+        if self.config.pipeline is not None:
+            model = self.config.pipeline
+        else:
+            model = HuggingFacePipeline.from_model_id(
+                model_id=self.config.model,
+                task=self.config.task,
+                model_kwargs=self.config.model_kwargs,
+            )
         return model
 
     def predict(self, prompt: str):
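Usage sketch (not part of this commit): with this change a caller can build a transformers pipeline up front and hand it to the config, instead of letting load() construct one from model, task, and model_kwargs. The "gpt2" model id and the direct instantiation of HuggingFaceModelConfigModel below are illustrative assumptions, not code from the repository.

    from transformers import pipeline

    from genai_stack.model.hf import HuggingFaceModelConfigModel

    # Build the pipeline yourself, e.g. to pick a specific model or task.
    hf_pipe = pipeline("text-generation", model="gpt2")

    # When `pipeline` is set, load() returns it unchanged; model, task, and
    # model_kwargs are ignored. (Direct instantiation shown here is assumed.)
    config = HuggingFaceModelConfigModel(pipeline=hf_pipe)

As the new docstring notes, a supplied pipeline short-circuits load(), so the other configuration fields have no effect.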
