From 3a96b69965189876ff3bccceebb26d991e9bea72 Mon Sep 17 00:00:00 2001 From: Anna Date: Wed, 29 Nov 2023 10:29:07 -0800 Subject: [PATCH] Add script for doing bulk generation against an endpoint (#765) * Add script for doing bulk generation against an endpoint * more logging * warn * fix * format * asdfads * Add warning * updates * folder -> file * remove blank line * Support remote input * prompts -> inputs --- llmfoundry/utils/prompt_files.py | 58 +++++++ scripts/inference/endpoint_generate.py | 223 +++++++++++++++++++++++++ scripts/inference/hf_generate.py | 31 ++-- tests/test_prompt_files.py | 18 ++ 4 files changed, 309 insertions(+), 21 deletions(-) create mode 100644 llmfoundry/utils/prompt_files.py create mode 100644 scripts/inference/endpoint_generate.py create mode 100644 tests/test_prompt_files.py diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py new file mode 100644 index 0000000000..40de19907a --- /dev/null +++ b/llmfoundry/utils/prompt_files.py @@ -0,0 +1,58 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import List, Optional + +PROMPTFILE_PREFIX = 'file::' + + +def load_prompts(prompts: List[str], + prompt_delimiter: Optional[str] = None) -> List[str]: + """Loads a set of prompts, both free text and from file. + + Args: + prompts (List[str]): List of free text prompts and prompt files + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + prompt_strings = [] + for prompt in prompts: + if prompt.startswith(PROMPTFILE_PREFIX): + prompts = load_prompts_from_file(prompt, prompt_delimiter) + prompt_strings.extend(prompts) + else: + prompt_strings.append(prompt) + return prompt_strings + + +def load_prompts_from_file(prompt_path: str, + prompt_delimiter: Optional[str] = None) -> List[str]: + """Load a set of prompts from a text file. 
+ + Args: + prompt_path (str): Path for text file + prompt_delimiter (Optional str): Delimiter for text file + If not provided, assumes the prompt file is a single prompt (non-delimited) + + Returns: + List of prompt string(s) + """ + if not prompt_path.startswith(PROMPTFILE_PREFIX): + raise ValueError(f'prompt_path_str must start with {PROMPTFILE_PREFIX}') + + _, prompt_file_path = prompt_path.split(PROMPTFILE_PREFIX, maxsplit=1) + prompt_file_path = os.path.expanduser(prompt_file_path) + if not os.path.isfile(prompt_file_path): + raise FileNotFoundError( + f'{prompt_file_path=} does not match any existing files.') + + with open(prompt_file_path, 'r') as f: + prompt_string = f.read() + + if prompt_delimiter is None: + return [prompt_string] + return [i for i in prompt_string.split(prompt_delimiter) if i] diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py new file mode 100644 index 0000000000..e78fecf59b --- /dev/null +++ b/scripts/inference/endpoint_generate.py @@ -0,0 +1,223 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +"""Batch generate text completion results from an endpoint. 
+ +Warning: This script is experimental and could change or be removed at any time +""" + +import asyncio +import copy +import logging +import math +import os +import tempfile +import time +from argparse import ArgumentParser, Namespace + +import pandas as pd +import requests +from composer.utils import (get_file, maybe_create_object_store_from_uri, + parse_uri) + +from llmfoundry.utils import prompt_files as utils + +logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') +log = logging.getLogger(__name__) + +ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY' +ENDPOINT_URL_ENV: str = 'ENDPOINT_URL' + +PROMPT_DELIMITER = '\n' + + +def parse_args() -> Namespace: + """Parse commandline arguments.""" + parser = ArgumentParser( + description='Call prompts against a text completions endpoint') + + ##### + # Path Parameters + parser.add_argument( + '-i', + '--inputs', + nargs='+', + help=f'List of strings, local datafiles (starting with {utils.PROMPTFILE_PREFIX}),' +\ + ' and/or remote object stores' + ) + parser.add_argument( + '--prompt-delimiter', + default='\n', + help= + 'Prompt delimiter for txt files. By default, a file is a single prompt') + + parser.add_argument('-o', + '--output-folder', + required=True, + help='Remote folder to save the output') + + ##### + # Generation Parameters + parser.add_argument( + '--rate-limit', + type=int, + default=75, + help='Max number of calls to make to the endpoint in a second') + parser.add_argument( + '--batch-size', + type=int, + default=10, + help='Max number of calls to make to the endpoint in a single request') + + ##### + # Endpoint Parameters + parser.add_argument( + '-e', + '--endpoint', + type=str, + help= + f'OpenAI-compatible text completions endpoint to query on. 
If not set, will read from {ENDPOINT_URL_ENV}' + ) + + parser.add_argument('--max-tokens', type=int, default=100) + parser.add_argument('--temperature', type=float, default=1.0) + parser.add_argument('--top-k', type=int, default=50) + parser.add_argument('--top-p', type=float, default=1.0) + return parser.parse_args() + + +async def main(args: Namespace) -> None: + # This is mildly experimental, so for now imports are not added as part of llm-foundry + try: + import aiohttp + except ImportError as e: + raise ImportError('Please install aiohttp') from e + + try: + from ratelimit import limits, sleep_and_retry + except ImportError as e: + raise ImportError('Please install ratelimit') from e + + if args.batch_size > args.rate_limit: + raise ValueError( + f'Batch size is {args.batch_size} but rate limit is set to {args.rate_limit} / s' + ) + + url = args.endpoint if args.endpoint else os.environ.get(ENDPOINT_URL_ENV) + if not url: + raise ValueError( + f'URL must be provided via --endpoint or {ENDPOINT_URL_ENV}') + + log.info(f'Using endpoint {url}') + + api_key = os.environ.get(ENDPOINT_API_KEY_ENV, '') + if not api_key: + log.warning(f'API key not set in {ENDPOINT_API_KEY_ENV}') + + new_inputs = [] + for prompt in args.inputs: + if prompt.startswith(utils.PROMPTFILE_PREFIX): + new_inputs.append(prompt) + continue + + input_object_store = maybe_create_object_store_from_uri(prompt) + if input_object_store is not None: + local_output_path = tempfile.TemporaryDirectory().name + get_file(prompt, str(local_output_path)) + log.info(f'Downloaded {prompt} to {local_output_path}') + prompt = f'{utils.PROMPTFILE_PREFIX}{local_output_path}' + + new_inputs.append(prompt) + + prompt_strings = utils.load_prompts(new_inputs, args.prompt_delimiter) + + cols = ['batch', 'prompt', 'output'] + param_data = { + 'max_tokens': args.max_tokens, + 'temperature': args.temperature, + 'top_k': args.top_k, + 'top_p': args.top_p, + } + + total_batches = math.ceil(len(prompt_strings) / 
args.batch_size) + log.info( + f'Generating {len(prompt_strings)} prompts in {total_batches} batches') + + @sleep_and_retry + @limits(calls=total_batches, period=1) # type: ignore + async def generate(session: aiohttp.ClientSession, batch: int, + prompts: list): + data = copy.copy(param_data) + data['prompt'] = prompts + headers = {'Authorization': api_key, 'Content-Type': 'application/json'} + req_start = time.time() + async with session.post(url, headers=headers, json=data) as resp: + if resp.ok: + try: + response = await resp.json() + except requests.JSONDecodeError: + raise Exception( + f'Bad response: {resp.status} {resp.reason}') + else: + raise Exception(f'Bad response: {resp.status} {resp.reason}') + + req_end = time.time() + n_compl = response['usage']['completion_tokens'] + n_prompt = response['usage']['prompt_tokens'] + req_latency = (req_end - req_start) + log.info(f'Completed batch {batch}: {n_compl:,} completion' + + f' tokens using {n_prompt:,} prompt tokens in {req_latency}s') + + res = pd.DataFrame(columns=cols) + + for r in response['choices']: + index = r['index'] + res.loc[len(res)] = [batch, prompts[index], r['text']] + return res + + res = pd.DataFrame(columns=cols) + batch = 0 + + gen_start = time.time() + async with aiohttp.ClientSession() as session: + tasks = [] + + for i in range(total_batches): + prompts = prompt_strings[i * args.batch_size:min( + (i + 1) * args.batch_size, len(prompt_strings))] + + tasks.append(generate(session, batch, prompts)) + batch += 1 + + results = await asyncio.gather(*tasks) + res = pd.concat(results) + + res.reset_index(drop=True, inplace=True) + + gen_end = time.time() + gen_latency = (gen_end - gen_start) + log.info(f'Generated {len(res)} prompts in {gen_latency}s, example data:') + log.info(res.head()) + + with tempfile.TemporaryDirectory() as tmp_dir: + file = 'output.csv' + local_path = os.path.join(tmp_dir, file) + res.to_csv(local_path, index=False) + + output_object_store = 
maybe_create_object_store_from_uri( + args.output_folder) + if output_object_store is not None: + _, _, output_folder_prefix = parse_uri(args.output_folder) + remote_path = os.path.join(output_folder_prefix, file) + output_object_store.upload_object(remote_path, local_path) + output_object_store.download_object + log.info(f'Uploaded results to {args.output_folder}/{file}') + else: + output_dir, _ = os.path.split(args.output_folder) + os.makedirs(output_dir, exist_ok=True) + os.rename(local_path, args.output_folder) + log.info(f'Saved results to {args.output_folder}') + + +if __name__ == '__main__': + asyncio.run(main(parse_args())) diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index 45ddc6b63e..6ac645e5b7 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -1,7 +1,6 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 import itertools -import os import random import time import warnings @@ -13,6 +12,8 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from llmfoundry.utils import prompt_files as utils + def get_dtype(dtype: str): if dtype == 'fp32': @@ -62,9 +63,14 @@ def parse_args() -> Namespace: 'My name is', 'This is an explanation of deep learning to a five year old. Deep learning is', ], - help='Generation prompts. Use syntax "file::/path/to/prompt.txt" to load a ' +\ - 'prompt contained in a txt file.' + help='List of generation prompts or list of delimited files. Use syntax ' +\ + '"file::/path/to/prompt.txt" to load a prompt(s) contained in a txt file.' ) + parser.add_argument( + '--prompt-delimiter', + default=None, + help= + 'Prompt delimiter for txt files. 
By default, a file is a single prompt') parser.add_argument('--max_seq_len', type=int, default=None) parser.add_argument('--max_new_tokens', type=int, default=100) parser.add_argument('--max_batch_size', type=int, default=None) @@ -125,19 +131,6 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompt_string_from_file(prompt_path_str: str): - if not prompt_path_str.startswith('file::'): - raise ValueError('prompt_path_str must start with "file::".') - _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) - prompt_file_path = os.path.expanduser(prompt_file_path) - if not os.path.isfile(prompt_file_path): - raise FileNotFoundError( - f'{prompt_file_path=} does not match any existing files.') - with open(prompt_file_path, 'r') as f: - prompt_string = ''.join(f.readlines()) - return prompt_string - - def maybe_synchronize(): if torch.cuda.is_available(): torch.cuda.synchronize() @@ -163,11 +156,7 @@ def main(args: Namespace) -> None: print(f'Using {model_dtype=}') # Load prompts - prompt_strings = [] - for prompt in args.prompts: - if prompt.startswith('file::'): - prompt = load_prompt_string_from_file(prompt) - prompt_strings.append(prompt) + prompt_strings = utils.load_prompts(args.prompts, args.prompt_delimiter) # Grab config first print(f'Loading HF Config...') diff --git a/tests/test_prompt_files.py b/tests/test_prompt_files.py new file mode 100644 index 0000000000..12a5d02999 --- /dev/null +++ b/tests/test_prompt_files.py @@ -0,0 +1,18 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +from llmfoundry.utils import prompt_files as utils + + +def test_load_prompt_strings(tmp_path: Path): + assert utils.load_prompts(['hello', 'goodbye']) == ['hello', 'goodbye'] + + with open(tmp_path / 'prompts.txt', 'w') as f: + f.write('hello goodbye') + + temp = utils.PROMPTFILE_PREFIX + str(tmp_path / 'prompts.txt') + assert utils.load_prompts( + [temp, temp, 'why'], + ' ') == 
['hello', 'goodbye', 'hello', 'goodbye', 'why']