This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Added petals models #104

Merged
merged 4 commits on Sep 20, 2023
Changes from all commits
13 changes: 13 additions & 0 deletions cht-petals/.dockerignore
@@ -0,0 +1,13 @@
.editorconfig
.gitattributes
.github
.gitignore
.gitlab-ci.yml
.idea
.pre-commit-config.yaml
.readthedocs.yml
.travis.yml
venv
.git
./ml/models/
.bin
10 changes: 10 additions & 0 deletions cht-petals/build.sh
@@ -0,0 +1,10 @@
#!/bin/bash
set -e
export VERSION=1.0.0
source "$(dirname "${BASH_SOURCE[0]}")/../utils.sh"

# TODO: support linux/amd64
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
build_cpu ghcr.io/premai-io/chat-stable-beluga-2-cpu petals-team/StableBeluga2 ${@:1}
BUILDX_PLATFORM=linux/arm64 TESTS_SKIP_CPU=1 \
build_cpu ghcr.io/premai-io/chat-codellama-34b-cpu premai-io/CodeLlama-34b-Instruct-hf ${@:1}
22 changes: 22 additions & 0 deletions cht-petals/docker/cpu/Dockerfile
@@ -0,0 +1,22 @@
FROM python:3.10-slim-bullseye

ARG MODEL_ID

RUN apt update && apt install -y libopenblas-dev ninja-build build-essential wget git
RUN python -m pip install --upgrade pip pytest cmake scikit-build setuptools

WORKDIR /usr/src/app/

COPY requirements.txt ./

RUN pip install --no-cache-dir -r ./requirements.txt

COPY download.py .

RUN python3 download.py --model $MODEL_ID

COPY . .

ENV MODEL_ID=$MODEL_ID

CMD python main.py
23 changes: 23 additions & 0 deletions cht-petals/download.py
@@ -0,0 +1,23 @@
import argparse

from petals import AutoDistributedModelForCausalLM
from tenacity import retry, stop_after_attempt, wait_fixed
from transformers import AutoTokenizer, LlamaTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model", help="Model to download")
args = parser.parse_args()

print(f"Downloading model {args.model}")


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_model() -> None:
if "llama" in args.model.lower():
_ = LlamaTokenizer.from_pretrained(args.model)
else:
_ = AutoTokenizer.from_pretrained(args.model)
_ = AutoDistributedModelForCausalLM.from_pretrained(args.model)


download_model()
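This script only warms the Hugging Face cache at image-build time (the Dockerfile runs it before copying the rest of the source), and the tenacity decorator retries the whole download up to three times with a fixed five-second pause between attempts. A minimal sketch of that retry behaviour, using a hypothetical flaky_download function that is not part of this PR:

from tenacity import retry, stop_after_attempt, wait_fixed

attempts = 0


@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def flaky_download() -> str:
    # Fails twice, then succeeds; stop_after_attempt(3) allows at most three calls.
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("transient network failure")
    return "ok"


print(flaky_download())  # prints "ok" after two retried failures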
45 changes: 45 additions & 0 deletions cht-petals/main.py
@@ -0,0 +1,45 @@
import logging

import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routes import router as api_router

load_dotenv()

logging.basicConfig(
format="%(asctime)s %(levelname)-8s %(message)s",
level=logging.INFO,
datefmt="%Y-%m-%d %H:%M:%S",
)


def create_start_app_handler(app: FastAPI):
    def start_app() -> None:
        from models import PetalsBasedModel

        PetalsBasedModel.get_model()

    return start_app


def get_application() -> FastAPI:
    application = FastAPI(title="prem-chat", debug=True, version="0.0.1")
    application.include_router(api_router, prefix="/v1")
    application.add_event_handler("startup", create_start_app_handler(application))
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return application


app = get_application()


if __name__ == "__main__":
uvicorn.run("main:app", host="0.0.0.0", port=8000)
67 changes: 67 additions & 0 deletions cht-petals/models.py
@@ -0,0 +1,67 @@
import os
from abc import ABC, abstractmethod
from typing import List

from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer, LlamaTokenizer, logging

logging.set_verbosity_error()


class ChatModel(ABC):
    @abstractmethod
    def get_model(cls):
        pass

    @abstractmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ):
        pass

    @abstractmethod
    def embeddings(cls, text):
        pass


class PetalsBasedModel(ChatModel):
    model = None
    tokenizer = None

    @classmethod
    def generate(
        cls,
        messages: list,
        temperature: float = 0.9,
        top_p: float = 0.9,
        n: int = 1,
        stream: bool = False,
        max_tokens: int = 128,
        stop: str = "",
        **kwargs,
    ) -> List:
        # Only the latest message is used as the prompt.
        message = messages[-1]["content"]
        inputs = cls.tokenizer(message, return_tensors="pt")["input_ids"]
        # Honour the caller's token budget instead of a hard-coded limit.
        outputs = cls.model.generate(inputs, max_new_tokens=max_tokens)
        return [cls.tokenizer.decode(outputs[0])]

    @classmethod
    def get_model(cls):
        if cls.model is None:
            if "llama" in os.getenv("MODEL_ID").lower():
                cls.tokenizer = LlamaTokenizer.from_pretrained(os.getenv("MODEL_ID"))
            else:
                cls.tokenizer = AutoTokenizer.from_pretrained(os.getenv("MODEL_ID"))
            cls.model = AutoDistributedModelForCausalLM.from_pretrained(
                os.getenv("MODEL_ID")
            )
        return cls.model
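PetalsBasedModel is used purely through its classmethods: get_model() lazily loads the tokenizer and joins the Petals swarm once, and generate() tokenizes only the last message. A minimal usage sketch, assuming MODEL_ID points at petals-team/StableBeluga2 and a public swarm serving it is reachable (not part of this diff):

import os

from models import PetalsBasedModel

os.environ.setdefault("MODEL_ID", "petals-team/StableBeluga2")  # assumed model id

# The first call is slow: it fetches the tokenizer and connects to the swarm.
PetalsBasedModel.get_model()
replies = PetalsBasedModel.generate(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=32,
)
print(replies[0])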
9 changes: 9 additions & 0 deletions cht-petals/requirements.txt
@@ -0,0 +1,9 @@
fastapi==0.95.0
uvicorn==0.21.1
pytest==7.2.2
requests==2.28.2
tqdm==4.65.0
httpx==0.23.3
python-dotenv==1.0.0
tenacity==8.2.2
petals==2.2.0
107 changes: 107 additions & 0 deletions cht-petals/routes.py
@@ -0,0 +1,107 @@
import json
import uuid
from datetime import datetime as dt
from typing import List, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from models import PetalsBasedModel as model
from pydantic import BaseModel


class ChatCompletionInput(BaseModel):
    model: str
    messages: List[dict]
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = ""
    max_tokens: int = 7
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    logit_bias: Optional[dict] = {}
    user: str = ""


class ChatCompletionResponse(BaseModel):
    id: str = str(uuid.uuid4())
    model: str
    object: str = "chat.completion"
    created: int = int(dt.now().timestamp())
    choices: List[dict]
    usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}


class HealthResponse(BaseModel):
    status: bool


router = APIRouter()


@router.get("/", response_model=HealthResponse)
async def health():
    return HealthResponse(status=True)


async def generate_chunk_based_response(body, text):
yield "event: completion\ndata: " + json.dumps(
{
"id": str(uuid.uuid4()),
"model": body.model,
"object": "chat.completion",
"choices": [
{
"role": "assistant",
"index": 1,
"delta": {"role": "assistant", "content": text},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
) + "\n\n"
yield "event: done\ndata: [DONE]\n\n"


@router.post("/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(body: ChatCompletionInput):
    try:
        predictions = model.generate(
            messages=body.messages,
            temperature=body.temperature,
            top_p=body.top_p,
            n=body.n,
            stream=body.stream,
            max_tokens=body.max_tokens,
            stop=body.stop,
            presence_penalty=body.presence_penalty,
            frequency_penalty=body.frequency_penalty,
            logit_bias=body.logit_bias,
        )
        if body.stream:
            return StreamingResponse(
                generate_chunk_based_response(body, predictions[0]),
                media_type="text/event-stream",
            )
        return ChatCompletionResponse(
            id=str(uuid.uuid4()),
            model=body.model,
            object="chat.completion",
            choices=[
                {
                    "role": "assistant",
                    "index": idx,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
                for idx, text in enumerate(predictions)
            ],
            usage={"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        )
    except ValueError as error:
        raise HTTPException(
            status_code=400,
            detail={"message": str(error)},
        )
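For context, a sketch of calling the endpoint from a client once the container is up, using httpx (already pinned in requirements.txt); localhost:8000 matches the port main.py binds to, and the payload mirrors ChatCompletionInput:

import httpx

payload = {
    "model": "stable-beluga",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 64,
}
response = httpx.post(
    "http://localhost:8000/v1/chat/completions", json=payload, timeout=120.0
)
response.raise_for_status()
# Non-streaming responses carry the text under choices[0]["message"]["content"].
print(response.json()["choices"][0]["message"]["content"])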
Empty file added cht-petals/tests/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions cht-petals/tests/test_views.py
@@ -0,0 +1,26 @@
from fastapi.testclient import TestClient
from main import get_application


def test_chat_petals() -> None:
    app = get_application()
    with TestClient(app) as client:
        response = client.post(
            "/v1/chat/completions",
            json={
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
                "n_threads": 10,
            },
        )
        assert response.status_code == 200

        response = client.post(
            "/v1/chat/completions",
            json={
                "stream": True,
                "model": "stable-beluga",
                "messages": [{"role": "user", "content": "Hello!"}],
            },
        )
        assert response.status_code == 200