Add a standalone evaluation entry point (auto_round_eval / auto-round-eval).

- auto_round/__main__.py: add run_eval(); run() now strips "--eval" from
  sys.argv and dispatches to run_eval() when that flag is present, and only
  imports the tuning parser on the tuning path.
- auto_round/script/llm.py: add EvalArgumentParser and setup_eval_parser();
  eval() no longer overwrites args.device with "auto" when
  CUDA_VISIBLE_DEVICES is set.
- setup.cfg: register the two new console scripts.

diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index 6bcebcde..537db971 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -13,12 +13,18 @@
 # limitations under the License.
 import sys
 
+def run_eval():
+    from auto_round.script.llm import setup_eval_parser, eval
+    args = setup_eval_parser()
+    eval(args)
+
 def run():
-    from auto_round.script.llm import setup_parser, tune, eval
-    args = setup_parser()
-    if args.eval:
-        eval(args)
+    if "--eval" in sys.argv:
+        sys.argv.remove("--eval")
+        run_eval()
     else:
+        from auto_round.script.llm import setup_parser, tune
+        args = setup_parser()
         tune(args)
 
 def run_best():
diff --git a/auto_round/script/llm.py b/auto_round/script/llm.py
index 8c0a37d1..d3991dfb 100644
--- a/auto_round/script/llm.py
+++ b/auto_round/script/llm.py
@@ -157,6 +157,26 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--disable_act_dynamic", action='store_true',
                           help="activation static quantization")
 
+class EvalArgumentParser(argparse.ArgumentParser):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.add_argument("--model", "--model_name", "--model_name_or_path", default="facebook/opt-125m",
+                          help="model name or path")
+        self.add_argument("--device", "--devices", default="auto", type=str,
+                          help="the device to be used for evaluation. "
+                               "Currently, device settings support CPU, GPU, and HPU. "
+                               "The default is set to auto, "
+                               "allowing for automatic detection and switch to HPU or CPU. "
+                               "Set --device 0,1,2 to use multiple cards.")
+        self.add_argument("--tasks",
+                          default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1,"
+                                  "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge",
+                          help="lm-eval tasks")
+        self.add_argument("--disable_trust_remote_code", action='store_true',
+                          help="whether to disable trust_remote_code")
+        self.add_argument("--eval_bs", default=None, type=int,
+                          help="batch size in evaluation")
+
 
 def setup_parser():
     parser = BasicArgumentParser()
@@ -227,6 +247,11 @@ def setup_fast_parser():
     return args
 
 
+def setup_eval_parser():
+    parser = EvalArgumentParser()
+    args = parser.parse_args()
+    return args
+
 def tune(args):
     tasks = args.tasks
     if args.format is None:
@@ -485,9 +510,6 @@ def eval(args):
     devices = args.device.replace(" ", "").split(',')
     parallelism = False
 
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        args.device = "auto"
-
     if all(s.isdigit() for s in devices):
         if "CUDA_VISIBLE_DEVICES" in os.environ:
             current_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
diff --git a/setup.cfg b/setup.cfg
index 1afe7e8c..504cf3b7 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -7,6 +7,8 @@ license_files =
 console_scripts =
     auto_round = auto_round.__main__:run
     auto-round = auto_round.__main__:run
+    auto_round_eval = auto_round.__main__:run_eval
+    auto-round-eval = auto_round.__main__:run_eval
     auto_round_mllm = auto_round.__main__:run_mllm
     auto-round-mllm = auto_round.__main__:run_mllm
     auto-round-fast = auto_round.__main__:run_fast