Add offline inference with vllm backend (#143)
zhaoyinglia authored Jun 12, 2024
1 parent 215bc3a commit d05ca66
Showing 6 changed files with 494 additions and 45 deletions.
6 changes: 3 additions & 3 deletions examples/aquila/conf/config_auto_tuner.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - train: demo
+  - train: demo
   - _self_

 experiment:
@@ -36,8 +36,8 @@ experiment:
   train_iters: 5
   max_time: 600

-action: run
+action: auto_tune

 hydra:
   run:
-    dir: ${experiment.exp_dir}/hydra
+    dir: ${experiment.exp_dir}/hydra
22 changes: 22 additions & 0 deletions examples/aquila/conf/config_infer.yaml
@@ -0,0 +1,22 @@
defaults:
  - _self_
  - inference: inference_aquila_7b

experiment:
  exp_name: aquila2
  exp_dir: ./outputs
  task:
    type: inference
    backend: vllm
    entrypoint: ./flagscale/inference/inference_aquila.py
  runner:
    hostfile: /share/project/zhaoyingli/hostfile
  envs:
    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
    CUDA_DEVICE_MAX_CONNECTIONS: 1

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
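
For orientation, the sketch below is illustrative and not part of this commit: it shows how config_infer.yaml composes with the inference group via Hydra. The compose/initialize call, the version_base argument, and the assumption that the snippet is run from the repository root (initialize resolves config_path relative to the caller) are mine.

# Illustrative only: inspect the composed config (assumes Hydra >= 1.2,
# run from the repository root).
from hydra import compose, initialize

with initialize(version_base=None, config_path="examples/aquila/conf"):
    cfg = compose(config_name="config_infer")

print(cfg.experiment.task.backend)     # vllm
print(cfg.experiment.task.entrypoint)  # ./flagscale/inference/inference_aquila.py
print(cfg.inference.engine.model)      # BAAI/Aquila-7B/ (from inference_aquila_7b.yaml)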
22 changes: 22 additions & 0 deletions examples/aquila/conf/inference/inference_aquila_7b.yaml
@@ -0,0 +1,22 @@
engine:
  model: BAAI/Aquila-7B/
  tokenizer: BAAI/Aquila-7B/
  trust_remote_code: true
  tensor_parallel_size: 1
  pipeline_parallel_size: 1
  gpu_memory_utilization: 0.6
  dtype: bfloat16
  seed: 1234

data:
  prompts: [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
  ]
  # prompts_path: null
  top_p: 0.95
  top_k: 100
  max_tokens: 7
  temperature: 0.9
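
The engine block mirrors vLLM's engine arguments and the data block mirrors its sampling parameters. As a rough standalone equivalent (illustrative only; the commit itself drives LLMEngine directly, and local availability of the Aquila-7B weights is assumed), the same settings map onto vLLM's high-level LLM API:

# Illustrative only: the config above expressed with vLLM's high-level LLM API.
from vllm import LLM, SamplingParams

llm = LLM(
    model="BAAI/Aquila-7B/",
    tokenizer="BAAI/Aquila-7B/",
    trust_remote_code=True,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.6,
    dtype="bfloat16",
    seed=1234,
)
sampling_params = SamplingParams(top_p=0.95, top_k=100, max_tokens=7, temperature=0.9)

prompts = ["Hello, my name is", "The capital of France is"]
for output in llm.generate(prompts, sampling_params):
    print(output.prompt, "->", output.outputs[0].text)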
133 changes: 133 additions & 0 deletions flagscale/inference/arguments.py
@@ -0,0 +1,133 @@
import argparse
from typing import List, Union

from vllm import EngineArgs


def parse_args(ignore_unknown_args=False):
    parser = argparse.ArgumentParser(description='vLLM Inference')
    _add_additional_args(parser)
    _add_vllm_engine_args(parser)
    _add_sampling_args(parser)

    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    return args


def _add_additional_args(parser):
    group = parser.add_argument_group(title='vLLM-additional-args')

    group.add_argument("--prompts-path",
                       type=str,
                       default=None,
                       help="Path to a text file containing the prompts, one per line.")
    group.add_argument("--prompts",
                       nargs='*',
                       help="A list of prompts to generate completions for.")


def _add_vllm_engine_args(parser):
    group = parser.add_argument_group(title='vLLM-Engine')
    group = EngineArgs.add_cli_args(group)
    return parser


def _str_to_bool(value):
    # argparse's `type=bool` converts any non-empty string (including "False") to True,
    # so boolean flags are parsed explicitly here.
    if isinstance(value, bool):
        return value
    return value.lower() in ("true", "t", "yes", "y", "1")


def _add_sampling_args(parser):
    group = parser.add_argument_group(title='vLLM-sampling-params')

    group.add_argument("--n",
                       type=int,
                       default=1,
                       help="Number of output sequences to return for the given prompt.")
    group.add_argument("--best_of",
                       type=int,
                       default=None,
                       help="Number of output sequences that are generated from the prompt.")
    group.add_argument("--presence-penalty",
                       type=float,
                       default=0.0,
                       help="Float that penalizes new tokens based on whether they appear in the generated text so far.")
    group.add_argument("--frequency-penalty",
                       type=float,
                       default=0.0,
                       help="Float that penalizes new tokens based on their frequency in the generated text so far.")
    group.add_argument("--repetition-penalty",
                       type=float,
                       default=1.0,
                       help="Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.")
    group.add_argument("--temperature",
                       type=float,
                       default=1.0,
                       help="Float that controls the randomness of the sampling.")
    group.add_argument("--top-p",
                       type=float,
                       default=1.0,
                       help="Float that controls the cumulative probability of the top tokens to consider.")
    group.add_argument("--top-k",
                       type=int,
                       default=-1,
                       help="Integer that controls the number of top tokens to consider.")
    group.add_argument("--min-p",
                       type=float,
                       default=0.0,
                       help="Float that represents the minimum probability for a token to be considered.")
    group.add_argument("--use-beam-search",
                       type=_str_to_bool,
                       default=False,
                       help="Whether to use beam search instead of sampling.")
    group.add_argument("--length-penalty",
                       type=float,
                       default=1.0,
                       help="Float that penalizes sequences based on their length.")
    group.add_argument("--early-stopping",
                       type=lambda v: v if v == "never" else _str_to_bool(v),
                       default=False,
                       help="Controls the stopping condition for beam search: true, false, or 'never'.")
    group.add_argument("--stop",
                       type=str,
                       nargs='*',
                       default=None,
                       help="List of strings that stop the generation when they are generated.")
    group.add_argument("--stop-token-ids",
                       type=int,
                       nargs='*',
                       default=None,
                       help="List of tokens that stop the generation when they are generated.")
    group.add_argument("--include-stop-str-in-output",
                       type=_str_to_bool,
                       default=False,
                       help="Whether to include the stop strings in output text.")
    group.add_argument("--ignore-eos",
                       type=_str_to_bool,
                       default=False,
                       help="Whether to ignore the EOS token and continue generating tokens after the EOS token is generated.")
    group.add_argument("--max-tokens",
                       type=int,
                       default=16,
                       help="Maximum number of tokens to generate per output sequence.")
    group.add_argument("--min-tokens",
                       type=int,
                       default=0,
                       help="Minimum number of tokens to generate per output sequence before EOS or stop_token_ids can be generated.")
    group.add_argument("--logprobs",
                       type=int,
                       default=None,
                       help="Number of log probabilities to return per output token.")
    group.add_argument("--prompt-logprobs",
                       type=int,
                       default=None,
                       help="Number of log probabilities to return per prompt token.")
    group.add_argument("--detokenize",
                       type=_str_to_bool,
                       default=True,
                       help="Whether to detokenize the output.")
    group.add_argument("--skip-special-tokens",
                       type=_str_to_bool,
                       default=True,
                       help="Whether to skip special tokens in the output.")
    group.add_argument("--spaces-between-special-tokens",
                       type=_str_to_bool,
                       default=True,
                       help="Whether to add spaces between special tokens in the output.")
140 changes: 140 additions & 0 deletions flagscale/inference/inference_aquila.py
@@ -0,0 +1,140 @@
import argparse
from typing import List, Union

import torch

from transformers import AutoTokenizer, LlamaForCausalLM, GenerationConfig
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams

from arguments import parse_args


def process_requests(prompts: List[str],
                     engine: LLMEngine,
                     sampling_params: SamplingParams):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0
    while prompts:
        prompt = prompts.pop(0)
        engine.add_request(str(request_id), prompt, sampling_params)
        request_id += 1

    outputs: List[RequestOutput] = []
    while engine.has_unfinished_requests():
        step_outputs = engine.step()
        for output in step_outputs:
            if output.finished:
                outputs.append(output)

    outputs = sorted(outputs, key=lambda x: int(x.request_id))
    return outputs


def inference(args: argparse.Namespace, prompts: List[str]):
    """Run offline inference with the vLLM engine and print the outputs."""
    # Initialize the LLMEngine.
    engine_args = EngineArgs.from_cli_args(args)
    llm_engine = LLMEngine.from_engine_args(engine_args)

    # Override the engine's tokenizer with the HF Aquila tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    llm_engine.tokenizer.tokenizer = tokenizer

    # Initialize the SamplingParams.
    sampling_params = SamplingParams(
        n=args.n,
        best_of=args.best_of,
        frequency_penalty=args.frequency_penalty,
        repetition_penalty=args.repetition_penalty,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        min_p=args.min_p,
        seed=args.seed,
        use_beam_search=args.use_beam_search,
        length_penalty=args.length_penalty,
        early_stopping=args.early_stopping,
        stop=args.stop,
        stop_token_ids=args.stop_token_ids,
        include_stop_str_in_output=args.include_stop_str_in_output,
        ignore_eos=args.ignore_eos,
        max_tokens=args.max_tokens,
        min_tokens=args.min_tokens,
        logprobs=args.logprobs,
        prompt_logprobs=args.prompt_logprobs,
        detokenize=args.detokenize,
        skip_special_tokens=args.skip_special_tokens,
        spaces_between_special_tokens=args.spaces_between_special_tokens,
        # logits_processors=,
        # truncate_prompt_tokens=,
    )

    outputs = process_requests(prompts, llm_engine, sampling_params)
    for output in outputs:
        print("\n")
        print("="*50)
        print("=> RequestOutput:", output)
        token_ids = output.outputs[0].token_ids
        print("=> generated text:", tokenizer.decode(token_ids))


def generate(args: argparse.Namespace, prompts: List[str]):

    model = LlamaForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    ).to('cuda')
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)

    for prompt in prompts:
        print("\n")
        print("="*50)
        print("=> prompt:", prompt)
        tokens = tokenizer.encode_plus(prompt)["input_ids"]
        tokens = torch.tensor(tokens)[None,].to(model.device)
        input_length = len(tokens[0])
        generation_config = GenerationConfig(
            do_sample=True,
            eos_token_id=tokenizer.convert_tokens_to_ids('<|extra_204|>'),
            pad_token_id=tokenizer.convert_tokens_to_ids('<|endoftext|>'),
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        out = model.generate(
            tokens,
            generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
        out_ids = out["sequences"][0][input_length:].cpu().numpy()
        out_text = tokenizer.decode(out_ids.tolist())
        print("=> generated text:", out_text)


if __name__ == '__main__':
    args = parse_args()

    prompts = []
    if args.prompts_path is not None:
        with open(args.prompts_path, "r") as f:
            for line in f:
                prompts.append(line.rstrip('\n'))  # strip the trailing '\n' of each prompt
    elif args.prompts:
        prompts = args.prompts
    else:
        raise ValueError("Please set a valid prompts_path or prompts list.")

    # vLLM inference
    inference(args, prompts)

    # transformers inference
    # generate(args, prompts)
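
For a quick smoke test one could launch the entrypoint directly instead of going through the FlagScale runner. The sketch below is illustrative only: the working directory, GPU memory setting, and prompt values are assumptions taken from the example configs above.

# Illustrative only: launch the new entrypoint directly from the repository root.
import subprocess

cmd = [
    "python", "flagscale/inference/inference_aquila.py",
    "--model", "BAAI/Aquila-7B/",
    "--tokenizer", "BAAI/Aquila-7B/",
    "--trust-remote-code",
    "--gpu-memory-utilization", "0.6",
    "--dtype", "bfloat16",
    "--seed", "1234",
    "--prompts", "Hello, my name is", "The capital of France is",
    "--top-p", "0.95", "--top-k", "100", "--max-tokens", "7", "--temperature", "0.9",
]
subprocess.run(cmd, check=True)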
