Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: n1ck-guo <[email protected]>
  • Loading branch information
n1ck-guo committed Dec 12, 2024
1 parent 689c6eb commit f6dec89
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
14 changes: 10 additions & 4 deletions auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
# limitations under the License.
import sys

def run_eval():
    """Console entry point for evaluation.

    Parses the evaluation command-line options and hands the resulting
    namespace to the ``eval`` routine of ``auto_round.script.llm``.
    """
    # Imported lazily inside the function, and aliased so the builtin
    # ``eval`` is not shadowed in this scope.
    from auto_round.script.llm import eval as evaluate, setup_eval_parser

    evaluate(setup_eval_parser())

def run():
from auto_round.script.llm import setup_parser, tune, eval
args = setup_parser()
if args.eval:
eval(args)
if "--eval" in sys.argv:
sys.argv.remove("--eval")
run_eval()
else:
from auto_round.script.llm import setup_parser, tune
args = setup_parser()
tune(args)

def run_best():
Expand Down
28 changes: 25 additions & 3 deletions auto_round/script/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ def __init__(self, *args, **kwargs):
self.add_argument("--disable_act_dynamic", action='store_true',
help="activation static quantization")

class EvalArgumentParser(argparse.ArgumentParser):
    """Command-line parser for the standalone evaluation entry point.

    Registers the model, device, task-list, remote-code and batch-size
    options. Every option has a default, so an empty command line parses
    successfully.
    """

    def __init__(self, *args, **kwargs):
        """Create the parser; all arguments are forwarded to ``ArgumentParser``."""
        super().__init__(*args, **kwargs)

        self.add_argument(
            "--model", "--model_name", "--model_name_or_path",
            default="facebook/opt-125m",
            help="model name or path")

        # Adjacent literals are concatenated verbatim, so each fragment ends
        # with an explicit space; the text also states the real default.
        self.add_argument(
            "--device", "--devices", default="auto", type=str,
            help="the device to be used for evaluation. "
                 "Currently, device settings support CPU, GPU, and HPU. "
                 "The default is 'auto', "
                 "allowing for automatic detection and switch to HPU or CPU. "
                 "Set --device 0,1,2 to use multiple cards.")

        self.add_argument(
            "--tasks",
            default="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1,"
                    "truthfulqa_mc2,openbookqa,boolq,rte,arc_easy,arc_challenge",
            help="lm-eval tasks")

        self.add_argument(
            "--disable_trust_remote_code", action='store_true',
            help="whether to disable trust_remote_code")

        # None means no batch size was given on the command line.
        self.add_argument(
            "--eval_bs", default=None, type=int,
            help="batch size in evaluation")


def setup_parser():
parser = BasicArgumentParser()
Expand Down Expand Up @@ -227,6 +247,11 @@ def setup_fast_parser():
return args


def setup_eval_parser():
    """Parse the process command line with ``EvalArgumentParser``.

    Returns:
        The namespace produced by ``parse_args()``.
    """
    return EvalArgumentParser().parse_args()

def tune(args):
tasks = args.tasks
if args.format is None:
Expand Down Expand Up @@ -485,9 +510,6 @@ def eval(args):
devices = args.device.replace(" ", "").split(',')
parallelism = False

if "CUDA_VISIBLE_DEVICES" in os.environ:
args.device = "auto"

if all(s.isdigit() for s in devices):
if "CUDA_VISIBLE_DEVICES" in os.environ:
current_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"]
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ license_files =
console_scripts =
auto_round = auto_round.__main__:run
auto-round = auto_round.__main__:run
auto_round_eval = auto_round.__main__:run_eval
auto-round-eval = auto_round.__main__:run_eval
auto_round_mllm = auto_round.__main__:run_mllm
auto-round-mllm = auto_round.__main__:run_mllm
auto-round-fast = auto_round.__main__:run_fast
Expand Down

0 comments on commit f6dec89

Please sign in to comment.