language_models.py
import os
import litellm

from config import TOGETHER_MODEL_NAMES, LITELLM_TEMPLATES, API_KEY_NAMES, Model
from loggers import logger
from common import get_api_key


class LanguageModel():
    def __init__(self, model_name):
        self.model_name = Model(model_name)

    def batched_generate(self, prompts_list: list, max_n_tokens: int, temperature: float):
        """
        Generates responses for a batch of prompts using a language model.
        """
        raise NotImplementedError


class APILiteLLM(LanguageModel):
    API_RETRY_SLEEP = 10
    API_ERROR_OUTPUT = "ERROR: API CALL FAILED."
    API_QUERY_SLEEP = 1
    API_MAX_RETRY = 5
    API_TIMEOUT = 20

    def __init__(self, model_name):
        super().__init__(model_name)
        self.api_key = get_api_key(self.model_name)
        self.litellm_model_name = self.get_litellm_model_name(self.model_name)
        litellm.drop_params = True
        self.set_eos_tokens(self.model_name)

    def get_litellm_model_name(self, model_name):
        if model_name in TOGETHER_MODEL_NAMES:
            litellm_name = TOGETHER_MODEL_NAMES[model_name]
            self.use_open_source_model = True
        else:
            self.use_open_source_model = False
            #if self.use_open_source_model:
                # Output warning, there should be a TogetherAI model name
                #logger.warning(f"Warning: No TogetherAI model name for {model_name}.")
            litellm_name = model_name.value
        return litellm_name

    def set_eos_tokens(self, model_name):
        if self.use_open_source_model:
            self.eos_tokens = LITELLM_TEMPLATES[model_name]["eos_tokens"]
        else:
            self.eos_tokens = []

    def _update_prompt_template(self):
        # We manually add the post_message later if we want to seed the model response
        if self.model_name in LITELLM_TEMPLATES:
            litellm.register_prompt_template(
                initial_prompt_value=LITELLM_TEMPLATES[self.model_name]["initial_prompt_value"],
                model=self.litellm_model_name,
                roles=LITELLM_TEMPLATES[self.model_name]["roles"]
            )
            self.post_message = LITELLM_TEMPLATES[self.model_name]["post_message"]
        else:
            self.post_message = ""

    def batched_generate(self, convs_list: list[list[dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float,
                         extra_eos_tokens: list[str] = None) -> list[str]:
        # Copy so extra EOS tokens do not accumulate on self.eos_tokens across calls
        eos_tokens = list(self.eos_tokens)
        if extra_eos_tokens:
            eos_tokens.extend(extra_eos_tokens)

        if self.use_open_source_model:
            self._update_prompt_template()

        outputs = litellm.batch_completion(
            model=self.litellm_model_name,
            messages=convs_list,
            api_key=self.api_key,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_n_tokens,
            num_retries=self.API_MAX_RETRY,
            seed=0,
            stop=eos_tokens,
        )
        responses = [output["choices"][0]["message"].content for output in outputs]
        return responses
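

# Example usage, as a minimal sketch: the model name string, the required API-key
# environment variable, and the valid Model enum values all depend on config.py and
# common.get_api_key, so treat the concrete values below as assumptions for illustration.
#
# lm = APILiteLLM("gpt-4o")  # hypothetical model name; must map to a config.Model member
# convs = [[{"role": "user", "content": "Say hello."}]]
# replies = lm.batched_generate(convs, max_n_tokens=32, temperature=0.0, top_p=1.0)
# print(replies[0])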


# class LocalvLLM(LanguageModel):

#     def __init__(self, model_name: str):
#         """Initializes the LLMHuggingFace with the specified model name."""
#         super().__init__(model_name)
#         if self.model_name not in MODEL_NAMES:
#             raise ValueError(f"Invalid model name: {model_name}")
#         self.hf_model_name = HF_MODEL_NAMES[Model(model_name)]
#         destroy_model_parallel()
#         self.model = vllm.LLM(model=self.hf_model_name)
#         if self.temperature > 0:
#             self.sampling_params = vllm.SamplingParams(
#                 temperature=self.temperature, top_p=self.top_p, max_tokens=self.max_n_tokens
#             )
#         else:
#             self.sampling_params = vllm.SamplingParams(temperature=0, max_tokens=self.max_n_tokens)

#     def _get_responses(self, prompts_list: list[str], max_new_tokens: int | None = None) -> list[str]:
#         """Generates responses from the model for the given prompts."""
#         full_prompt_list = self._prompt_to_conv(prompts_list)
#         outputs = self.model.generate(full_prompt_list, self.sampling_params)
#         # Get output from each input, but remove initial space
#         outputs_list = [output.outputs[0].text[1:] for output in outputs]
#         return outputs_list

#     def _prompt_to_conv(self, prompts_list):
#         batchsize = len(prompts_list)
#         convs_list = [self._init_conv_template() for _ in range(batchsize)]
#         full_prompts = []
#         for conv, prompt in zip(convs_list, prompts_list):
#             conv.append_message(conv.roles[0], prompt)
#             conv.append_message(conv.roles[1], None)
#             full_prompt = conv.get_prompt()
#             # Need this to avoid extraneous space in generation
#             if "llama-2-7b-chat-hf" in self.model_name:
#                 full_prompt += " "
#             full_prompts.append(full_prompt)
#         return full_prompts

#     def _init_conv_template(self):
#         template = get_conversation_template(self.hf_model_name)
#         if "llama" in self.hf_model_name:
#             # Add the system prompt for Llama as FastChat does not include it
#             template.system_message = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
#         return template