diff --git a/cht-petals/download.py b/cht-petals/download.py
index bddad6b..64c07bf 100644
--- a/cht-petals/download.py
+++ b/cht-petals/download.py
@@ -1,5 +1,4 @@
 import argparse
-from platform import machine
 
 import torch
 from petals import AutoDistributedModelForCausalLM
@@ -8,20 +7,18 @@
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", help="Model to download")
+parser.add_argument("--dht-prefix", help="DHT prefix to use")
 args = parser.parse_args()
-
-print(f"Downloading model {args.model}")
+print(f"Downloading model {args.model} with DHT prefix {args.dht_prefix}")
 
 
 @retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
 def download_model() -> None:
     Tokenizer = LlamaTokenizer if "llama" in args.model.lower() else AutoTokenizer
     _ = Tokenizer.from_pretrained(args.model)
-
-    kwargs = {}
-    if "x86_64" in machine():
-        kwargs["torch_dtype"] = torch.float32
-    _ = AutoDistributedModelForCausalLM.from_pretrained(args.model, **kwargs)
+    _ = AutoDistributedModelForCausalLM.from_pretrained(
+        args.model, torch_dtype=torch.float32, dht_prefix=args.dht_prefix
+    )
 
 
 download_model()
diff --git a/cht-petals/main.py b/cht-petals/main.py
index 0b416bb..795ddd8 100644
--- a/cht-petals/main.py
+++ b/cht-petals/main.py
@@ -1,4 +1,6 @@
+import argparse
 import logging
+import os
 
 import uvicorn
 from dotenv import load_dotenv
@@ -8,6 +10,19 @@
 
 load_dotenv()
 
+MODEL_ID = os.getenv("MODEL_ID", "petals-team/StableBeluga2")
+MODEL_PATH = os.getenv("MODEL_PATH", "./models")
+DHT_PREFIX = os.getenv("DHT_PREFIX", "StableBeluga2")
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-id", help="HuggingFace Model", default=MODEL_ID)
+    parser.add_argument("--model-path", help="Path to Model files directory", default=MODEL_PATH)
+    parser.add_argument("--dht-prefix", help="DHT prefix to use", default=DHT_PREFIX)
+    parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
+    args = parser.parse_args()
+    MODEL_PATH = args.model_path
+    DHT_PREFIX = args.dht_prefix
 logging.basicConfig(
     format="%(asctime)s %(levelname)-8s %(message)s",
     level=logging.INFO,
@@ -19,7 +34,8 @@
 def create_start_app_handler(app: FastAPI):
     def start_app() -> None:
         from models import PetalsBasedModel
 
-        PetalsBasedModel.get_model()
+        print(f"Using {MODEL_PATH=} and {DHT_PREFIX=}")
+        PetalsBasedModel.get_model(MODEL_PATH, DHT_PREFIX)
 
     return start_app
@@ -40,6 +56,5 @@
 
 app = get_application()
 
-
 if __name__ == "__main__":
-    uvicorn.run("main:app", host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
diff --git a/cht-petals/models.py b/cht-petals/models.py
index 8e83a15..7b54bad 100644
--- a/cht-petals/models.py
+++ b/cht-petals/models.py
@@ -1,6 +1,4 @@
-import os
 from abc import ABC, abstractmethod
-from platform import machine
 from typing import List
 
 import torch
@@ -56,13 +54,16 @@ def generate(
         return [cls.tokenizer.decode(outputs[0])]
 
     @classmethod
-    def get_model(cls):
+    def get_model(
+        cls,
+        model_path: str = "./models",
+        dht_prefix: str = "StableBeluga2",
+        model_id: str = "petals-team/StableBeluga2",
+    ):
         if cls.model is None:
-            Tokenizer = LlamaTokenizer if "llama" in os.getenv("MODEL_ID").lower() else AutoTokenizer
-            cls.tokenizer = Tokenizer.from_pretrained(os.getenv("MODEL_ID"))
-
-            kwargs = {}
-            if "x86_64" in machine():
-                kwargs["torch_dtype"] = torch.float32
-            cls.model = AutoDistributedModelForCausalLM.from_pretrained(os.getenv("MODEL_ID"), **kwargs)
+            Tokenizer = LlamaTokenizer if "llama" in model_path.lower() else AutoTokenizer
+            cls.tokenizer = Tokenizer.from_pretrained(model_id)
+            cls.model = AutoDistributedModelForCausalLM.from_pretrained(
+                model_id, torch_dtype=torch.float32, dht_prefix=dht_prefix
+            )
         return cls.model
diff --git a/cht-petals/requirements.txt b/cht-petals/requirements.txt
index 205b6a1..0863be0 100644
--- a/cht-petals/requirements.txt
+++ b/cht-petals/requirements.txt
@@ -1,9 +1,9 @@
-fastapi==0.95.0
-uvicorn==0.21.1
-pytest==7.2.2
-requests==2.28.2
-tqdm==4.65.0
-httpx==0.23.3
-python-dotenv==1.0.0
-tenacity==8.2.2
-petals==2.2.0
+fastapi
+uvicorn
+pytest==7.*
+requests==2.*
+tqdm==4.*
+httpx
+python-dotenv==1.*
+tenacity==8.*
+petals==2.*
diff --git a/cht-petals/setup-petals.sh b/cht-petals/setup-petals.sh
new file mode 100755
index 0000000..f62aabb
--- /dev/null
+++ b/cht-petals/setup-petals.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Usage: setup-petals.sh [--model-path=<path>] [--dht-prefix=<prefix>] [--port=<port>]
+set -eEuo pipefail
+
+tmpdir="${PREM_APPDIR:-.}/petals-$(uuid)"
+
+cleanup(){
+  for i in $(jobs -p); do
+    kill -n 9 $i || :
+  done
+  rm -rf "$tmpdir"
+  exit 0
+}
+
+trap "cleanup" SIGTERM
+trap "cleanup" SIGINT
+trap "cleanup" ERR
+
+# clone source
+git clone -n --depth=1 --filter=tree:0 https://github.com/premAI-io/prem-services.git "$tmpdir"
+git -C "$tmpdir" sparse-checkout set --no-cone cht-petals
+git -C "$tmpdir" checkout
+# install deps
+"${PREM_PYTHON:-python}" -m pip install -r "$tmpdir/cht-petals/requirements.txt"
+# run server
+PYTHONPATH="$tmpdir/cht-petals" "${PREM_PYTHON:-python}" "$tmpdir/cht-petals/main.py" "$@" &
+
+wait
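
For reference, a minimal invocation sketch of how the flags introduced above fit together; the PREM_APPDIR/PREM_PYTHON values and the "existing checkout" path are illustrative assumptions, while the flag names and defaults come from the diff itself:

# one-shot bootstrap via the new helper script; extra flags are forwarded to main.py via "$@"
PREM_APPDIR=/tmp PREM_PYTHON=python3 ./setup-petals.sh --model-path=./models --dht-prefix=StableBeluga2 --port=8000

# or, from an existing cht-petals checkout: pre-fetch the weights, then start the server
python download.py --model petals-team/StableBeluga2 --dht-prefix StableBeluga2
python main.py --model-id petals-team/StableBeluga2 --dht-prefix StableBeluga2 --port 8000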