This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit
Merge pull request #119 from biswaroop1547/main
casperdcl authored Oct 24, 2023
2 parents f96cf9b + fc96cf8 commit a2d4661
Showing 5 changed files with 71 additions and 30 deletions.
13 changes: 5 additions & 8 deletions cht-petals/download.py
@@ -1,5 +1,4 @@
 import argparse
-from platform import machine
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -8,20 +7,18 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", help="Model to download")
+parser.add_argument("--dht-prefix", help="DHT prefix to use")
 args = parser.parse_args()
 
-print(f"Downloading model {args.model}")
+print(f"Downloading model {args.model} with DHT prefix {args.dht_prefix}")
 
 
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
 def download_model() -> None:
     Tokenizer = LlamaTokenizer if "llama" in args.model.lower() else AutoTokenizer
     _ = Tokenizer.from_pretrained(args.model)
-
-    kwargs = {}
-    if "x86_64" in machine():
-        kwargs["torch_dtype"] = torch.float32
-    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model, **kwargs)
+    _ = AutoDistributedModelForCausalLM.from_pretrained(
+        args.model, torch_dtype=torch.float32, dht_prefix=args.dht_prefix
+    )
-
 
 download_model()
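
The downloader now always loads weights as torch.float32 (previously only on x86_64) and passes the swarm's DHT prefix explicitly. A hypothetical invocation, using the StableBeluga2 defaults this commit adopts in main.py and models.py:

    $ python download.py --model petals-team/StableBeluga2 --dht-prefix StableBeluga2
    Downloading model petals-team/StableBeluga2 with DHT prefix StableBeluga2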
21 changes: 18 additions & 3 deletions cht-petals/main.py
@@ -1,4 +1,6 @@
+import argparse
 import logging
+import os
 
 import uvicorn
 from dotenv import load_dotenv
@@ -8,6 +10,19 @@
 
 load_dotenv()
 
+MODEL_ID = os.getenv("MODEL_ID", "petals-team/StableBeluga2")
+MODEL_PATH = os.getenv("MODEL_PATH", "./models")
+DHT_PREFIX = os.getenv("DHT_PREFIX", "StableBeluga2")
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-id", help="HuggingFace Model", default=MODEL_ID)
+    parser.add_argument("--model-path", help="Path to Model files directory", default=MODEL_PATH)
+    parser.add_argument("--dht-prefix", help="DHT prefix to use", default=DHT_PREFIX)
+    parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
+    args = parser.parse_args()
+    MODEL_PATH = args.model_path
+    DHT_PREFIX = args.dht_prefix
+
 logging.basicConfig(
     format="%(asctime)s %(levelname)-8s %(message)s",
     level=logging.INFO,
@@ -19,7 +34,8 @@ def create_start_app_handler(app: FastAPI):
     def start_app() -> None:
         from models import PetalsBasedModel
 
-        PetalsBasedModel.get_model()
+        print(f"Using {MODEL_PATH=} and {DHT_PREFIX=}")
+        PetalsBasedModel.get_model(MODEL_PATH, DHT_PREFIX)
 
     return start_app
@@ -40,6 +56,5 @@ def get_application() -> FastAPI:
 
 app = get_application()
 
-
 if __name__ == "__main__":
-    uvicorn.run("main:app", host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
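
The server still honours the MODEL_ID, MODEL_PATH and DHT_PREFIX environment variables, but each can now be overridden per run, and --port replaces the hardcoded 8000. A sketch of a launch, with illustrative values mirroring the defaults above:

    $ python main.py --model-path ./models --dht-prefix StableBeluga2 --port 8000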
21 changes: 11 additions & 10 deletions cht-petals/models.py
@@ -1,6 +1,4 @@
-import os
 from abc import ABC, abstractmethod
-from platform import machine
 from typing import List
 
 import torch
@@ -56,13 +54,16 @@ def generate(
         return [cls.tokenizer.decode(outputs[0])]
 
     @classmethod
-    def get_model(cls):
+    def get_model(
+        cls,
+        model_path: str = "./models",
+        dht_prefix: str = "StableBeluga2",
+        model_id: str = "petals-team/StableBeluga2",
+    ):
         if cls.model is None:
-            Tokenizer = LlamaTokenizer if "llama" in os.getenv("MODEL_ID").lower() else AutoTokenizer
-            cls.tokenizer = Tokenizer.from_pretrained(os.getenv("MODEL_ID"))
-
-            kwargs = {}
-            if "x86_64" in machine():
-                kwargs["torch_dtype"] = torch.float32
-            cls.model = AutoDistributedModelForCausalLM.from_pretrained(os.getenv("MODEL_ID"), **kwargs)
+            Tokenizer = LlamaTokenizer if "llama" in model_path.lower() else AutoTokenizer
+            cls.tokenizer = Tokenizer.from_pretrained(model_id)
+            cls.model = AutoDistributedModelForCausalLM.from_pretrained(
+                model_id, torch_dtype=torch.float32, dht_prefix=dht_prefix
+            )
         return cls.model
18 changes: 9 additions & 9 deletions cht-petals/requirements.txt
@@ -1,9 +1,9 @@
-fastapi==0.95.0
-uvicorn==0.21.1
-pytest==7.2.2
-requests==2.28.2
-tqdm==4.65.0
-httpx==0.23.3
-python-dotenv==1.0.0
-tenacity==8.2.2
-petals==2.2.0
+fastapi
+uvicorn
+pytest==7.*
+requests==2.*
+tqdm==4.*
+httpx
+python-dotenv==1.*
+tenacity==8.*
+petals==2.*
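
The exact pins are relaxed to PEP 440 wildcard specifiers: pip reads a constraint such as petals==2.* as "any release in the 2.x series" and resolves it to the newest match, so patch and minor updates arrive without editing this file. For example:

    $ pip install "petals==2.*"  # resolves to the latest available 2.x release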
28 changes: 28 additions & 0 deletions cht-petals/setup-petals.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Usage: setup-petals.sh [--model-path=<DIR>] [--dht-prefix=<PREFIX>] [--port=<INT>]
+set -eEuo pipefail
+
+tmpdir="${PREM_APPDIR:-.}/petals-$(uuid)"
+
+cleanup(){
+  for i in $(jobs -p); do
+    kill -n 9 $i || :
+  done
+  rm -rf "$tmpdir"
+  exit 0
+}
+
+trap "cleanup" SIGTERM
+trap "cleanup" SIGINT
+trap "cleanup" ERR
+
+# clone source
+git clone -n --depth=1 --filter=tree:0 https://github.com/premAI-io/prem-services.git "$tmpdir"
+git -C "$tmpdir" sparse-checkout set --no-cone cht-petals
+git -C "$tmpdir" checkout
+# install deps
+"${PREM_PYTHON:-python}" -m pip install -r "$tmpdir/cht-petals/requirements.txt"
+# run server
+PYTHONPATH="$tmpdir/cht-petals" "${PREM_PYTHON:-python}" "$tmpdir/cht-petals/main.py" "$@" &
+
+wait
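
The new script sparse-clones only the cht-petals directory, installs its requirements, and forwards all flags straight through to main.py; PREM_APPDIR and PREM_PYTHON default to the current directory and python. A hypothetical invocation:

    $ PREM_APPDIR=/tmp PREM_PYTHON=python3 ./setup-petals.sh --dht-prefix=StableBeluga2 --port=8001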
