Commit 1caf502
Merge branch 'main' into prompt-engine-api
Akshaj000 authored Oct 5, 2023
2 parents 852f754 + 9220660
Showing 18 changed files with 122 additions and 195 deletions.
2 changes: 1 addition & 1 deletion assets/config_custom_chromadb.json
@@ -7,7 +7,7 @@
   },
   "vectordb": {
     "name": "chromadb",
-    "class_name": "llm_stack",
+    "class_name": "genai_stack",
     "embedding": {
       "name": "HuggingFaceEmbeddings",
       "fields": {
2 changes: 1 addition & 1 deletion assets/etl.json
@@ -8,6 +8,6 @@
   },
   "vectordb": {
     "name": "chromadb",
-    "class_name": "llm_stack"
+    "class_name": "genai_stack"
   }
 }
4 changes: 2 additions & 2 deletions genai_stack/__init__.py
@@ -1,8 +1,8 @@
 """Top-level package for genai_stack."""
 
-__author__ = """AIM by DPhi"""
+__author__ = """AI Planet Tech Team"""
 __email__ = "[email protected]"
-__version__ = "0.2.2"
+__version__ = "0.2.3"
 
 import os
1 change: 1 addition & 0 deletions genai_stack/constant.py
@@ -7,3 +7,4 @@
 VECTORDB = "/vectordb"
 ETL = "/etl"
 PROMPT_ENGINE = "/prompt-engine"
+MODEL = "/model"
6 changes: 3 additions & 3 deletions genai_stack/etl/run.py
@@ -1,17 +1,17 @@
-from genai_stack.constants.etl.etl import PREBUILT_ETL_LOADERS, ETL_MODULE
+from genai_stack.constants.etl.etl import AVAILABLE_ETL_LOADERS, ETL_MODULE
 from genai_stack.utils.importing import import_class
 
 from genai_stack.core import ConfigLoader
 
 
 def list_etl_loaders():
-    return PREBUILT_ETL_LOADERS.keys()
+    return AVAILABLE_ETL_LOADERS.keys()
 
 
 def run_etl_loader(config_file: str, vectordb):
     config_cls = ConfigLoader(name="EtlLoader", config=config_file)
     etl_cls = import_class(
-        f"{ETL_MODULE}.{PREBUILT_ETL_LOADERS.get(config_cls.config.get('etl'))}".replace(
+        f"{ETL_MODULE}.{AVAILABLE_ETL_LOADERS.get(config_cls.config.get('etl'))}".replace(
             "/",
             ".",
         )
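
For context, a rough sketch of how these helpers are typically driven; the config path and the vectordb instance below are assumptions, not part of this diff:

from genai_stack.etl.run import list_etl_loaders, run_etl_loader

# Inspect which loader keys are registered in AVAILABLE_ETL_LOADERS.
print(list(list_etl_loaders()))

# Kick off a loader described by a JSON config such as assets/etl.json.
# `vectordb` is assumed to be an already-constructed vectordb component.
# run_etl_loader(config_file="assets/etl.json", vectordb=vectordb)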
13 changes: 13 additions & 0 deletions genai_stack/genai_server/models/model_models.py
@@ -0,0 +1,13 @@
+from pydantic import BaseModel
+
+
+class ModelBaseModel(BaseModel):
+    pass
+
+
+class ModelRequestModel(ModelBaseModel):
+    prompt: str
+
+
+class ModelResponseModel(ModelBaseModel):
+    output: str
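
These schemas pin down a minimal wire format for the new model endpoint. A sketch of the round trip — the values are illustrative, and `.json()` assumes the pydantic v1 API this codebase appears to use:

from genai_stack.genai_server.models.model_models import ModelRequestModel, ModelResponseModel

req = ModelRequestModel(prompt="What is GenAI Stack?")
res = ModelResponseModel(output="GenAI Stack is an end-to-end LLM framework.")
print(req.json())  # {"prompt": "What is GenAI Stack?"}
print(res.json())  # {"output": "GenAI Stack is an end-to-end LLM framework."}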
15 changes: 15 additions & 0 deletions genai_stack/genai_server/routers/model_routes.py
@@ -0,0 +1,15 @@
+from fastapi import APIRouter
+
+from genai_stack.constant import API, MODEL
+from genai_stack.genai_server.settings.settings import settings
+from genai_stack.genai_server.services.model_service import ModelService
+from genai_stack.genai_server.models.model_models import ModelResponseModel, ModelRequestModel
+
+service = ModelService(store=settings.STORE)
+
+router = APIRouter(prefix=API + MODEL, tags=["model"])
+
+
+@router.post("/predict")
+def predict(data: ModelRequestModel) -> ModelResponseModel:
+    return service.predict(data=data)
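
Once the router is mounted, the endpoint can be exercised from any HTTP client. A sketch, assuming the server listens on localhost:8000 and the API constant resolves to an /api prefix (neither is shown in this diff):

import requests

resp = requests.post(
    "http://localhost:8000/api/model/predict",  # host/port/prefix are assumptions
    json={"prompt": "Summarize GenAI Stack in one line."},
)
print(resp.json()["output"])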
11 changes: 9 additions & 2 deletions genai_stack/genai_server/server.py
@@ -1,6 +1,12 @@
 from fastapi import FastAPI
 
-from genai_stack.genai_server.routers import session_routes, retriever_routes, vectordb_routes, etl_routes, prompt_engine_routes
+from genai_stack.genai_server.routers import (
+    session_routes,
+    retriever_routes,
+    vectordb_routes,
+    etl_routes,
+    prompt_engine_routes,
+    model_routes,
+)
 
 
 def get_genai_server_app():
@@ -21,5 +27,6 @@ def get_genai_server_app():
     app.include_router(vectordb_routes.router)
     app.include_router(etl_routes.router)
     app.include_router(prompt_engine_routes.router)
+    app.include_router(model_routes.router)
 
     return app
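
For local testing, the app factory can be served with any ASGI server; uvicorn and the port below are assumptions, not part of this commit:

import uvicorn

from genai_stack.genai_server.server import get_genai_server_app

app = get_genai_server_app()
uvicorn.run(app, host="0.0.0.0", port=8000)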
13 changes: 13 additions & 0 deletions genai_stack/genai_server/services/model_service.py
@@ -0,0 +1,13 @@
+from genai_stack.genai_platform.services.base_service import BaseService
+from genai_stack.genai_server.models.model_models import ModelRequestModel, ModelResponseModel
+from genai_stack.genai_server.utils import get_current_stack
+from genai_stack.genai_server.settings.config import stack_config
+
+
+class ModelService(BaseService):
+    def predict(self, data: ModelRequestModel) -> ModelResponseModel:
+        stack = get_current_stack(config=stack_config)
+        response = stack.model.predict(data.prompt)
+        return ModelResponseModel(
+            output=response["output"],
+        )
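
The service can also be driven without going through FastAPI — a minimal sketch, assuming settings.STORE is configured the same way model_routes.py above assumes:

from genai_stack.genai_server.models.model_models import ModelRequestModel
from genai_stack.genai_server.services.model_service import ModelService
from genai_stack.genai_server.settings.settings import settings

service = ModelService(store=settings.STORE)
result = service.predict(data=ModelRequestModel(prompt="Hello"))
print(result.output)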
File renamed without changes.
1 change: 0 additions & 1 deletion genai_stack/model/__init__.py
@@ -1,4 +1,3 @@
-from .server import HttpServer
 from .base import BaseModel
 from .gpt3_5 import OpenAIGpt35Model
 from .run import list_supported_models, get_model_class, AVAILABLE_MODEL_MAPS, run_custom_model
54 changes: 42 additions & 12 deletions genai_stack/model/gpt3_5.py
@@ -7,38 +7,68 @@
 
 class OpenAIGpt35Parameters(BaseModelConfigModel):
     model_name: str = Field(default="gpt-3.5-turbo-16k", alias="model")
-    """Model name to use."""
+    """
+    Model name to use.
+    """
     temperature: float = 0
-    """What sampling temperature to use."""
+    """
+    What sampling temperature to use.
+    """
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    """Holds any model parameters valid for `create` call not explicitly specified."""
+    """
+    Holds any model parameters valid for `create` call not explicitly specified.
+    """
     openai_api_key: str
-    """Base URL path for API requests,
-    leave blank if not using a proxy or service emulator."""
+    """
+    Base URL path for API requests,
+    leave blank if not using a proxy or service emulator.
+    """
     openai_api_base: Optional[str] = None
     openai_organization: Optional[str] = None
     # to support explicit proxy for OpenAI
     openai_proxy: Optional[str] = None
     request_timeout: Optional[Union[float, Tuple[float, float]]] = None
-    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    """
+    Timeout for requests to OpenAI completion API. Default is 600 seconds.
+    """
     max_retries: int = 6
-    """Maximum number of retries to make when generating."""
+    """
+    Maximum number of retries to make when generating.
+    """
     streaming: bool = False
-    """Whether to stream the results or not."""
+    """
+    Whether to stream the results or not.
+    """
     n: int = 1
-    """Number of chat completions to generate for each prompt."""
+    """
+    Number of chat completions to generate for each prompt.
+    """
     max_tokens: Optional[int] = None
-    """Maximum number of tokens to generate."""
+    """
+    Maximum number of tokens to generate.
+    """
     tiktoken_model_name: Optional[str] = None
-    """The model name to pass to tiktoken when using this class.
+    """
+    The model name to pass to tiktoken when using this class.
     Tiktoken is used to count the number of tokens in documents to constrain
     them to be under a certain limit. By default, when set to None, this will
     be the same as the embedding model name. However, there are some cases
     where you may want to use this Embedding class with a model name not
     supported by tiktoken. This can include when using Azure embeddings or
     when using one of the many model providers that expose an OpenAI-like
     API but with different models. In those cases, in order to avoid erroring
-    when tiktoken is called, you can specify a model name to use here."""
+    when tiktoken is called, you can specify a model name to use here.
+    """
 
 
 class OpenAIGpt35ModelConfigModel(BaseModelConfigModel):
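
A hypothetical instantiation of these parameters; the field names come from OpenAIGpt35Parameters above, the values and key are placeholders:

from genai_stack.model.gpt3_5 import OpenAIGpt35Parameters

params = OpenAIGpt35Parameters(
    model="gpt-3.5-turbo-16k",  # populates model_name via its alias
    openai_api_key="sk-...",    # placeholder; the only required field
    temperature=0.2,
    max_tokens=256,
)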
14 changes: 4 additions & 10 deletions genai_stack/model/hf.py
@@ -1,6 +1,4 @@
 from typing import Optional, Dict
-
-import torch
 from langchain.llms import HuggingFacePipeline
 
 from genai_stack.model.base import BaseModel, BaseModelConfig, BaseModelConfigModel
@@ -31,18 +29,14 @@ class HuggingFaceModel(BaseModel):
     def _post_init(self, *args, **kwargs):
         self.model = self.load()
 
-    def get_device(self):
-        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
     def load(self):
         model = HuggingFacePipeline.from_model_id(
-            model_id=self.config.model,
-            task=self.config.task,
-            model_kwargs=self.config.model_kwargs,
-            device=self.get_device(),
+            model_id=self.config.model, task=self.config.task, model_kwargs=self.config.model_kwargs
         )
         return model
 
     def predict(self, prompt: str):
         response = self.model(prompt)
-        return {"output": response[0]["generated_text"]}
+        # Note: the Hugging Face response format differs from model to model,
+        # so the user should extract the info which is required.
+        return {"output": response}
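
Since predict() now returns the raw pipeline response, callers are expected to unpack it themselves. A sketch, assuming a text-generation task and a hypothetical hf_model instance (the list-of-dicts shape varies by task and model):

result = hf_model.predict("Tell me a joke")
raw = result["output"]
# For text-generation pipelines the response is typically a list of dicts.
text = raw[0]["generated_text"] if isinstance(raw, list) else raw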
86 changes: 0 additions & 86 deletions genai_stack/model/server.py

This file was deleted.

40 changes: 6 additions & 34 deletions genai_stack/retriever/base.py
@@ -6,6 +6,7 @@ class BaseRetrieverConfigModel(BaseModel):
     """
     Data Model for the configs
     """
+
     pass
 
 
@@ -16,55 +17,26 @@ class BaseRetrieverConfig(StackComponentConfig):
 class BaseRetriever(StackComponent):
     config_class = BaseRetrieverConfig
 
-    def get_prompt(self, query:str):
+    def get_prompt(self, query: str):
         """
         This method returns the prompt template from the prompt engine component
         """
         return self.mediator.get_prompt_template(query)
 
-    def retrieve(self, query:str) -> dict:
+    def retrieve(self, query: str) -> dict:
         """
         This method returns the model response for the prompt template.
         """
         raise NotImplementedError()
-    def get_context(self, query:str):
+
+    def get_context(self, query: str):
         """
         This method returns the relevant documents returned by the similarity search from a vectordb based on the query
         """
         raise NotImplementedError()
 
     def get_chat_history(self) -> str:
         """
         This method returns the chat conversation history
         """
         return self.mediator.get_chat_history()
-
-
-# from typing import Any
-
-# from genai_stack.core import BaseComponent
-# from genai_stack.constants.retriever import RETRIEVER_CONFIG_KEY
-# from genai_stack.vectordb.base import BaseVectordb
-
-# class BaseRetriever(BaseComponent):
-#     module_name = "BaseRetriever"
-#     config_key = RETRIEVER_CONFIG_KEY
-
-#     def __init__(self, config: str, vectordb: BaseVectordb = None):
-#         super().__init__(self.module_name, config)
-#         self.parse_config(self.config_key, self.required_fields)
-#         self.vectordb = vectordb
-
-#     def retrieve(self, query: Any):
-#         raise NotImplementedError()
-
-#     def get_langchain_retriever(self):
-#         return self.vectordb.get_langchain_client().as_retriever()
-
-#     def get_langchain_memory_retriever(self):
-#         return self.vectordb.get_langchain_memory_client().as_retriever()
-
-#     @classmethod
-#     def from_config(cls, config):
-#         raise NotImplementedError
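
For reference, a minimal subclass sketch against this interface; EchoRetriever and its method bodies are illustrative, only the signatures come from BaseRetriever:

from genai_stack.retriever.base import BaseRetriever


class EchoRetriever(BaseRetriever):
    def get_context(self, query: str):
        # Would normally run a similarity search against the stack's vectordb.
        return []

    def retrieve(self, query: str) -> dict:
        prompt = self.get_prompt(query)    # template from the prompt engine
        context = self.get_context(query)  # relevant documents
        return {"output": f"{prompt}\n\nContext: {context}"}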
(Diffs for the remaining changed files were not loaded in this view.)
