
Commit

improve benchmark tput by moving prompt preparation outside of loop (r…
gracehonv authored Jun 19, 2024
1 parent afd137a commit 03872a4
Showing 2 changed files with 19 additions and 15 deletions.
6 changes: 2 additions & 4 deletions src/llmperf/utils.py
@@ -60,6 +60,8 @@ def randomly_sample_sonnet_lines_prompt(
     prompt_tokens_mean: int = 550,
     prompt_tokens_stddev: int = 250,
     expect_output_tokens: int = 150,
+    tokenizer = LlamaTokenizerFast.from_pretrained(
+        "hf-internal-testing/llama-tokenizer")
 ) -> Tuple[str, int]:
     """Generate a prompt that randomly samples lines from a the shakespeare sonnet at sonnet.txt.
@@ -80,10 +80,6 @@
     A tuple of the prompt and the length of the prompt.
     """
 
-    tokenizer = LlamaTokenizerFast.from_pretrained(
-        "hf-internal-testing/llama-tokenizer"
-    )
-
     get_token_length = lambda text: len(tokenizer.encode(text))
 
     prompt = (
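
The utils.py change moves the tokenizer load out of the function body and into a default argument. Because Python evaluates default-argument expressions once, when the function is defined, LlamaTokenizerFast.from_pretrained(...) now runs a single time at import rather than on every prompt generation, and callers such as the benchmark loop can pass in a tokenizer they already hold. A minimal sketch of the same idea outside the repo (the count_tokens helper and _DEFAULT_TOKENIZER name are hypothetical, not part of llmperf):

from transformers import LlamaTokenizerFast

# Loaded once, when this module is imported; reused by every later call.
_DEFAULT_TOKENIZER = LlamaTokenizerFast.from_pretrained(
    "hf-internal-testing/llama-tokenizer"
)

def count_tokens(text: str, tokenizer=_DEFAULT_TOKENIZER) -> int:
    # A caller that already has a tokenizer can pass it in and skip the default.
    return len(tokenizer.encode(text))

Either form avoids re-loading the tokenizer on each call, which is where the per-request overhead came from.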
28 changes: 17 additions & 11 deletions token_benchmark_ray.py
@@ -71,6 +71,21 @@ def get_token_throughput_latencies(
     req_launcher = RequestsLauncher(clients)
     completed_requests = []
     num_completed_requests = 0
+    # make up prompts outside of send loop for faster benchmarking loop
+    num_output_tokens_list = []
+    prompts = []
+    for i in range(max_num_completed_requests):
+        num_output_tokens = (sample_random_positive_int(
+            mean_output_tokens, stddev_output_tokens
+        ))
+        num_output_tokens_list.append(num_output_tokens)
+
+        prompts.append(randomly_sample_sonnet_lines_prompt(
+            prompt_tokens_mean=mean_input_tokens,
+            prompt_tokens_stddev=stddev_input_tokens,
+            expect_output_tokens=num_output_tokens,
+            tokenizer=tokenizer
+        ))
     start_time = time.monotonic()
     iter = 0
     pbar = tqdm(total=max_num_completed_requests)
@@ -79,21 +94,12 @@
         and len(completed_requests) < max_num_completed_requests
     ):
         iter += 1
-        num_output_tokens = sample_random_positive_int(
-            mean_output_tokens, stddev_output_tokens
-        )
-
-        prompt = randomly_sample_sonnet_lines_prompt(
-            prompt_tokens_mean=mean_input_tokens,
-            prompt_tokens_stddev=stddev_input_tokens,
-            expect_output_tokens=num_output_tokens,
-        )
 
-        default_sampling_params = {"max_tokens": num_output_tokens}
+        default_sampling_params = {"max_tokens": num_output_tokens_list.pop()}
         default_sampling_params.update(additional_sampling_params)
         request_config = RequestConfig(
             model=model,
-            prompt=prompt,
+            prompt=prompts.pop(),
             sampling_params=default_sampling_params,
             llm_api=llm_api,
         )
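
The token_benchmark_ray.py change implements what the commit title describes: every prompt and output-token target is generated before the timed send loop starts, so sonnet sampling and tokenization no longer count against the measured throughput, and the loop body only has to pop a prepared prompt and launch the request. A rough sketch of the precompute-then-pop structure (timed_send_loop and send_request are hypothetical stand-ins, not the repo's RequestsLauncher API):

import random
import time

def timed_send_loop(send_request, num_requests, mean_output_tokens=150, stddev=30):
    # Preparation phase, not timed: build every payload up front.
    payloads = []
    for _ in range(num_requests):
        max_tokens = max(1, int(random.gauss(mean_output_tokens, stddev)))
        payloads.append({"prompt": "...", "max_tokens": max_tokens})

    # Timed phase: the loop body only launches requests.
    start = time.monotonic()
    while payloads:
        send_request(payloads.pop())
    return time.monotonic() - start

list.pop() takes from the end in O(1), so prompts are consumed in reverse generation order; that is harmless for a throughput benchmark, and a collections.deque with popleft() would preserve order if it ever mattered.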
