Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Commit

Permalink
fix & expose --model-id
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Oct 23, 2023
1 parent 9bec015 commit 003e668
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
4 changes: 3 additions & 1 deletion cht-petals/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

load_dotenv()

MODEL_ID = os.getenv("MODEL_ID", "petals-team/StableBeluga2")
MODEL_PATH = os.getenv("MODEL_PATH", "./models")
DHT_PREFIX = os.getenv("DHT_PREFIX", "StableBeluga2")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-id", help="HuggingFace Model", default=MODEL_ID)
parser.add_argument("--model-path", help="Path to Model files directory", default=MODEL_PATH)
parser.add_argument("--dht-prefix", help="DHT prefix to use")
parser.add_argument("--dht-prefix", help="DHT prefix to use", default=DHT_PREFIX)
parser.add_argument("--port", help="Port to run model server on", type=int, default=8000)
args = parser.parse_args()
MODEL_PATH = args.model_path
Expand Down
14 changes: 8 additions & 6 deletions cht-petals/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
from abc import ABC, abstractmethod
from typing import List

Expand Down Expand Up @@ -55,13 +54,16 @@ def generate(
return [cls.tokenizer.decode(outputs[0])]

@classmethod
def get_model(
    cls,
    model_path: str = "./models",
    dht_prefix: str = "StableBeluga2",
    model_id: str = "petals-team/StableBeluga2",
):
    """Lazily build and cache the distributed model (and its tokenizer) on the class.

    Parameters
    ----------
    model_path : str
        Local model files directory; only used to sniff whether a Llama
        tokenizer is required (``"llama"`` substring check).
    dht_prefix : str
        Petals DHT prefix passed through to ``from_pretrained``.
    model_id : str
        HuggingFace model identifier used to load both the tokenizer and
        the distributed causal-LM.

    Returns
    -------
    The cached ``cls.model`` instance (created on first call).
    """
    if cls.model is None:
        # Llama checkpoints need the dedicated tokenizer class; everything
        # else can go through AutoTokenizer.
        Tokenizer = LlamaTokenizer if "llama" in model_path.lower() else AutoTokenizer
        cls.tokenizer = Tokenizer.from_pretrained(model_id)
        # NOTE(review): float32 is forced here regardless of checkpoint
        # dtype — presumably for CPU compatibility; confirm before changing.
        cls.model = AutoDistributedModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float32, dht_prefix=dht_prefix
        )
    return cls.model

0 comments on commit 003e668

Please sign in to comment.