Add offline inference with vllm backend (#143)
Commit d05ca66 (1 parent: 215bc3a)
Showing 6 changed files with 494 additions and 45 deletions.
@@ -0,0 +1,22 @@
defaults:
  - _self_
  - inference: inference_aquila_7b

experiment:
  exp_name: aquila2
  exp_dir: ./outputs
  task:
    type: inference
    backend: vllm
    entrypoint: ./flagscale/inference/inference_aquila.py
  runner:
    hostfile: /share/project/zhaoyingli/hostfile
    envs:
      CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
      CUDA_DEVICE_MAX_CONNECTIONS: 1

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
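
For readers unfamiliar with Hydra interpolation, the `${experiment.exp_dir}` reference in `hydra.run.dir` resolves against the same config tree. Below is a minimal sketch of inspecting this file with OmegaConf; the filename is a placeholder, and a plain load does not compose the `inference: inference_aquila_7b` default the way the real Hydra launcher does.

from omegaconf import OmegaConf

# Load the experiment config directly. Note that the `defaults` list
# (the inference_aquila_7b group) is NOT composed by a plain load.
cfg = OmegaConf.load("config_aquila.yaml")  # hypothetical filename

OmegaConf.resolve(cfg)               # resolves ${experiment.exp_dir} -> ./outputs
print(cfg.experiment.task.backend)   # vllm
print(cfg.hydra.run.dir)             # ./outputs/hydra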
@@ -0,0 +1,22 @@
engine:
  model: BAAI/Aquila-7B/
  tokenizer: BAAI/Aquila-7B/
  trust_remote_code: true
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  gpu_memory_utilization: 0.6
  dtype: bfloat16
  seed: 1234

data:
  prompts: [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
  ]
  # prompts_path: null
  top_p: 0.95
  top_k: 100
  max_tokens: 7
  temperature: 0.9
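
The `engine` and `data` blocks above map one-to-one onto vLLM's offline API. Here is a minimal standalone sketch using `vllm.LLM` and `SamplingParams`, independent of the FlagScale runner; the model path is assumed to be resolvable locally or from the hub.

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Engine settings mirror the `engine:` block above.
llm = LLM(model="BAAI/Aquila-7B",   # assumed model path
          trust_remote_code=True,
          tensor_parallel_size=1,
          gpu_memory_utilization=0.6,
          dtype="bfloat16",
          seed=1234)

# Sampling settings mirror the `data:` block above.
params = SamplingParams(temperature=0.9, top_p=0.95, top_k=100, max_tokens=7)

for out in llm.generate(prompts, params):
    print(out.prompt, "->", out.outputs[0].text)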
@@ -0,0 +1,133 @@
import argparse
from typing import List, Union

from vllm import EngineArgs


def parse_args(ignore_unknown_args=False):
    parser = argparse.ArgumentParser(description='vLLM Inference')
    _add_additional_args(parser)
    _add_vllm_engine_args(parser)
    _add_sampling_args(parser)

    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    return args


def _add_additional_args(parser):
    group = parser.add_argument_group(title='vLLM-additional-args')

    group.add_argument("--prompts-path",
                       type=str,
                       default=None,
                       help="The text file containing the prompts.")
    group.add_argument("--prompts",
                       nargs='*',
                       help="A list of prompts to generate completions for.")


def _add_vllm_engine_args(parser):
    group = parser.add_argument_group(title='vLLM-Engine')
    group = EngineArgs.add_cli_args(group)
    return parser


def _add_sampling_args(parser):
    group = parser.add_argument_group(title='vLLM-sampling-params')

    group.add_argument("--n",
                       type=int,
                       default=1,
                       help="Number of output sequences to return for the given prompt.")
    group.add_argument("--best_of",
                       type=int,
                       default=None,
                       help="Number of output sequences that are generated from the prompt.")
    group.add_argument("--presence-penalty",
                       type=float,
                       default=0.0,
                       help="Float that penalizes new tokens based on whether they appear in the generated text so far.")
    group.add_argument("--frequency-penalty",
                       type=float,
                       default=0.0,
                       help="Float that penalizes new tokens based on their frequency in the generated text so far.")
    group.add_argument("--repetition-penalty",
                       type=float,
                       default=1.0,
                       help="Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.")
    group.add_argument("--temperature",
                       type=float,
                       default=1.0,
                       help="Float that controls the randomness of the sampling.")
    group.add_argument("--top-p",
                       type=float,
                       default=1.0,
                       help="Float that controls the cumulative probability of the top tokens to consider.")
    group.add_argument("--top-k",
                       type=int,
                       default=-1,
                       help="Integer that controls the number of top tokens to consider.")
    group.add_argument("--min-p",
                       type=float,
                       default=0.0,
                       help="Float that represents the minimum probability for a token to be considered.")
    # NOTE: argparse passes the raw command-line string to `type`, so `type=bool`
    # and the typing constructs (Union[...], List[...]) used below cannot parse
    # values from the command line; these options are effectively default-only.
    group.add_argument("--use-beam-search",
                       type=bool,
                       default=False,
                       help="Whether to use beam search instead of sampling.")
    group.add_argument("--length-penalty",
                       type=float,
                       default=1.0,
                       help="Float that penalizes sequences based on their length.")
    group.add_argument("--early-stopping",
                       type=Union[bool, str],
                       default=False,
                       help="Controls the stopping condition for beam search.")
    group.add_argument("--stop",
                       type=Union[str, List[str]],
                       default=None,
                       help="List of strings that stop the generation when they are generated.")
    group.add_argument("--stop-token-ids",
                       type=List[int],
                       default=None,
                       help="List of tokens that stop the generation when they are generated.")
    group.add_argument("--include-stop-str-in-output",
                       type=bool,
                       default=False,
                       help="Whether to include the stop strings in output text.")
    group.add_argument("--ignore-eos",
                       type=bool,
                       default=False,
                       help="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.")
    group.add_argument("--max-tokens",
                       type=int,
                       default=16,
                       help="Maximum number of tokens to generate per output sequence.")
    group.add_argument("--min-tokens",
                       type=int,
                       default=0,
                       help="Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated.")
    group.add_argument("--logprobs",
                       type=int,
                       default=None,
                       help="Number of log probabilities to return per output token.")
    group.add_argument("--prompt-logprobs",
                       type=int,
                       default=None,
                       help="Number of log probabilities to return per prompt token.")
    group.add_argument("--detokenize",
                       type=bool,
                       default=True,
                       help="Whether to detokenize the output.")
    group.add_argument("--skip-special-tokens",
                       type=bool,
                       default=True,
                       help="Whether to skip special tokens in the output.")
    group.add_argument("--spaces-between-special-tokens",
                       type=bool,
                       default=True,
                       help="Whether to add spaces between special tokens in the output.")
@@ -0,0 +1,140 @@
import argparse
from typing import List, Union

import torch

from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams

from arguments import parse_args


def process_requests(prompts: List[str],
                     engine: LLMEngine,
                     sampling_params: SamplingParams):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    while prompts:
        prompt = prompts.pop(0)
        engine.add_request(str(request_id), prompt, sampling_params)
        request_id += 1

    outputs: List[Union[RequestOutput]] = []
    while engine.has_unfinished_requests():
        step_outputs = engine.step()
        for output in step_outputs:
            if output.finished:
                outputs.append(output)

    outputs = sorted(outputs, key=lambda x: int(x.request_id))
    return outputs


def inference(args: argparse.Namespace, prompts: List[str]):
    """Initialize the LLMEngine."""
    engine_args = EngineArgs.from_cli_args(args)
    llm_engine = LLMEngine.from_engine_args(engine_args)

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    llm_engine.tokenizer.tokenizer = tokenizer

    # Initialize the SamplingParams.
    sampling_params = SamplingParams(
        n=args.n,
        best_of=args.best_of,
        frequency_penalty=args.frequency_penalty,
        repetition_penalty=args.repetition_penalty,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        min_p=args.min_p,
        seed=args.seed,
        use_beam_search=args.use_beam_search,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        stop=args.stop,
        stop_token_ids=args.stop_token_ids,
        include_stop_str_in_output=args.include_stop_str_in_output,
        ignore_eos=args.ignore_eos,
        max_tokens=args.max_tokens,
        min_tokens=args.min_tokens,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        detokenize=args.detokenize,
        skip_special_tokens=args.skip_special_tokens,
        spaces_between_special_tokens=args.spaces_between_special_tokens,
        # logits_processors=,
        # truncate_prompt_tokens=,
    )

    outputs = process_requests(prompts, llm_engine, sampling_params)
    for output in outputs:
        print("\n")
        print("=" * 50)
        print("=> RequestOutput:", output)
        token_ids = output.outputs[0].token_ids
        print("=> generated text:", tokenizer.decode(token_ids))


def generate(args: argparse.Namespace, prompts: List[str]):
    model = LlamaForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    ).to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

    for prompt in prompts:
        print("\n")
        print("=" * 50)
        print("=> prompt:", prompt)
        tokens = tokenizer.encode_plus(prompt)["input_ids"]
        tokens = torch.tensor(tokens)[None,].to(model.device)
        input_length = len(tokens[0])
        generation_config = GenerationConfig(
            do_sample=True,
            eos_token_id=tokenizer.convert_tokens_to_ids('<|extra_204|>'),
            pad_token_id=tokenizer.convert_tokens_to_ids('<|endoftext|>'),
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        out = model.generate(
            tokens,
            generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ids = out["sequences"][0][input_length:].cpu().numpy()
        out_text = tokenizer.decode(out_ids.tolist())
        print("=> generated text:", out_text)


if __name__ == '__main__':
    args = parse_args()

    prompts = []
    if args.prompts_path is not None:
        with open(args.prompts_path, "r") as f:
            while True:
                prompt = f.readline()
                if not prompt:
                    break
                prompts.append(prompt[:-1])  # strip the trailing '\n' of each prompt
    elif args.prompts:  # at least one prompt passed via --prompts
        prompts = args.prompts
    else:
        raise ValueError("Please set a valid prompts_path or prompts list.")

    # vLLM inference
    inference(args, prompts)

    # transformers inference (disabled)
    # generate(args, prompts)
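
The `--prompts-path` branch above reads one prompt per line and drops the trailing newline. A minimal sketch of preparing such a file and pointing the entrypoint at it; the file location and flag values are placeholders.

# Write one prompt per line; the entrypoint's readline loop strips the '\n'.
with open("/tmp/prompts.txt", "w") as f:
    f.write("Hello, my name is\n")
    f.write("The future of AI is\n")

# Then launch the script with the engine flags plus --prompts-path, e.g.:
#   python flagscale/inference/inference_aquila.py \
#       --model BAAI/Aquila-7B --tokenizer BAAI/Aquila-7B --trust-remote-code \
#       --dtype bfloat16 --max-tokens 7 --temperature 0.9 \
#       --prompts-path /tmp/prompts.txt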