Merge pull request #7 from iaalm/dev
Add support for pyllama to embedding
iaalm authored Apr 16, 2023
2 parents 70622e8 + c51b8b2 commit 1130f76
Showing 5 changed files with 85 additions and 15 deletions.
README.md: 7 changes (6 additions, 1 deletion)
@@ -45,9 +45,14 @@ models:
ckpt_dir: /absolute/path/to/your/7B/
tokenizer_path: /absolute/path/to/your/tokenizer.model
# keep to 1 instance to speed up loading of model
embeddings:
text-embedding-davinci-002:
type: pyllama_quant
params:
path: /absolute/path/to/your/pyllama-7B4b.pt
min_instance: 1
max_instance: 1
embeddings:
idle_timeout: 3600
text-embedding-ada-002:
type: llama_cpp
params:
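With the new entry above, the server should expose text-embedding-davinci-002 through its OpenAI-compatible embeddings route. The sketch below shows one way a client might call it; the host, port, and /v1/embeddings path are assumptions about a typical deployment, not details taken from this commit.

import requests

# Hypothetical endpoint; point this at wherever llama-api-server is actually listening.
URL = "http://127.0.0.1:5000/v1/embeddings"

payload = {
    "model": "text-embedding-davinci-002",   # the model name configured above
    "input": "An example sentence to embed",
}

resp = requests.post(URL, json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()

# OpenAI-style response body: a "data" list plus token usage counts.
print(len(body["data"][0]["embedding"]), "dimensions")
print(body["usage"])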
llama_api_server/model_pool.py: 14 changes (9 additions, 5 deletions)
@@ -3,8 +3,8 @@
from functools import cache
from threading import Lock
from llama_api_server.models.llama_cpp import LlamaCppCompletion, LlamaCppEmbedding
from llama_api_server.models.pyllama import PyLlamaCompletion
from llama_api_server.models.pyllama_quant import PyLlamaQuantCompletion
from llama_api_server.models.pyllama import PyLlama
from llama_api_server.models.pyllama_quant import PyLlamaQuant
from .config import get_config

# Even though Python is not good at multi-threading, most work is done by the backend,
@@ -16,11 +16,15 @@
_lock = Lock()

MODEL_TYPE_MAPPING = {
"embeddings": {"llama_cpp": LlamaCppEmbedding},
"embeddings": {
"llama_cpp": LlamaCppEmbedding,
"pyllama": PyLlama,
"pyllama_quant": PyLlamaQuant,
},
"completions": {
"llama_cpp": LlamaCppCompletion,
"pyllama": PyLlamaCompletion,
"pyllama_quant": PyLlamaQuantCompletion,
"pyllama": PyLlama,
"pyllama_quant": PyLlamaQuant,
},
}

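The expanded mapping means the pool can now resolve an embeddings backend for all three model types from two config keys: the capability ("embeddings" or "completions") and the per-model type. The diff does not show the pool's construction code, so the look-up below is only an illustrative sketch of how such a mapping is typically consumed.

# Illustrative only; the real pool logic in model_pool.py is not shown in this diff.
def build_model(capability, name, config):
    entry = config["models"][capability][name]                    # e.g. capability="embeddings"
    backend_cls = MODEL_TYPE_MAPPING[capability][entry["type"]]   # e.g. type="pyllama_quant"
    return backend_cls(entry["params"])                           # instantiate with its params dict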
llama_api_server/models/llama_cpp.py: 11 changes (5 additions, 6 deletions)
@@ -39,6 +39,8 @@ def completions(self, args):
repeat_penalty = 1.3

prompt = args["prompt"]
if isinstance(prompt, list):
prompt = prompt[0]
prompt_tokens = self.model.str_to_token(prompt, True).tolist()
n_past = _eval_token(
self.model, prompt_tokens[:-1], 0, self.n_batch, self.n_thread
@@ -93,11 +95,8 @@ def __init__(self, params):

def embeddings(self, args):
inputs = args["input"]
if inputs is str:
if isinstance(inputs, str):
inputs = [inputs]
is_array = False
else:
is_array = True
embeds = []

for i in inputs:
@@ -109,13 +108,13 @@ def embeddings(self, args):
embed = unpack_cfloat_array(self.model.get_embeddings())
embeds.append(embed)

if not is_array:
if len(embeds) == 1:
embeds = embeds[0]

c_prompt_tokens = len(prompt_tokens)
return {
"object": "list",
"data": [{"object": "embedding", "embedding": embed, "index": 0}],
"data": [{"object": "embedding", "embedding": embeds, "index": 0}],
"model": args["model"],
"usage": {
"prompt_tokens": c_prompt_tokens,
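After this change, a single string input yields one flat vector in the "embedding" field rather than a singleton list, so two responses can be compared element-wise. The helper below is a generic cosine-similarity check, included only as a usage illustration and not part of the commit.

import numpy as np

def cosine_similarity(a, b):
    # a and b are the "embedding" lists taken from two embeddings responses.
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))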
llama_api_server/models/pyllama.py: 34 changes (32 additions, 2 deletions)
@@ -4,7 +4,7 @@
from llama_api_server.utils import get_uuid, get_timestamp


class PyLlamaCompletion:
class PyLlama:
def __init__(self, params):
try:
import llama
@@ -39,7 +39,7 @@ def __init__(self, params):
if device.startswith("cuda"):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
os.environ['KV_CAHCHE_IN_GPU'] = "0"
os.environ["KV_CAHCHE_IN_GPU"] = "0"
torch.set_default_tensor_type(torch.FloatTensor)
model = llama.Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
@@ -48,6 +48,8 @@

def completions(self, args):
prompt = args["prompt"]
if isinstance(prompt, list):
prompt = prompt[0]
top_p = args["top_p"]
suffix = args["suffix"]
echo = args["echo"]
@@ -78,4 +80,32 @@
},
}

def embeddings(self, args):
import torch

inputs = args["input"]
if isinstance(inputs, str):
inputs = [inputs]

input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.dev)

with torch.no_grad():
hidden_states = self.model(
input_ids, output_hidden_states=True
).hidden_states
# [0] for embedding layers
embeds = torch.squeeze(torch.mean(hidden_states[0], 1), 1).tolist()

if len(embeds) == 1:
embeds = embeds[0]

c_prompt_tokens = sum([len(i) for i in input_ids])
return {
"object": "list",
"data": [{"object": "embedding", "embedding": embeds, "index": 0}],
"model": args["model"],
"usage": {
"prompt_tokens": c_prompt_tokens,
"total_tokens": c_prompt_tokens,
},
}
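The new embeddings() method builds its vectors by mean-pooling the embedding-layer hidden states over the token dimension, collapsing a [batch, tokens, dim] tensor into one fixed-length vector per input. Here is a self-contained sketch of that pooling step on a dummy tensor; it stands in for hidden_states[0] and does not load a real pyllama model.

import torch

# Dummy stand-in for hidden_states[0]: 2 inputs, 5 tokens each, 8-dimensional states.
hidden = torch.randn(2, 5, 8)

# Average over the token dimension (dim=1), as in embeddings() above.
pooled = hidden.mean(dim=1)         # shape: [2, 8]
embeds = pooled.tolist()            # one 8-element list per input

print(len(embeds), len(embeds[0]))  # -> 2 8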
llama_api_server/models/pyllama_quant.py: 34 changes (33 additions, 1 deletion)
@@ -3,7 +3,7 @@
from llama_api_server.utils import get_uuid, get_timestamp


class PyLlamaQuantCompletion:
class PyLlamaQuant:
def __init__(self, params):
try:
import llama
@@ -31,6 +31,8 @@ def completions(self, args):
import torch

prompt = args["prompt"]
if isinstance(prompt, list):
prompt = prompt[0]
top_p = args["top_p"]
suffix = args["suffix"]
echo = args["echo"]
@@ -71,3 +73,33 @@
"total_tokens": c_prompt_tokens + c_completion_tokens,
},
}

def embeddings(self, args):
import torch

inputs = args["input"]
if isinstance(inputs, str):
inputs = [inputs]

input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to(self.dev)

with torch.no_grad():
hidden_states = self.model(
input_ids, output_hidden_states=True
).hidden_states
# [0] for embedding layers
embeds = torch.squeeze(torch.mean(hidden_states[0], 1), 1).tolist()

if len(embeds) == 1:
embeds = embeds[0]

c_prompt_tokens = sum([len(i) for i in input_ids])
return {
"object": "list",
"data": [{"object": "embedding", "embedding": embeds, "index": 0}],
"model": args["model"],
"usage": {
"prompt_tokens": c_prompt_tokens,
"total_tokens": c_prompt_tokens,
},
}
