diff --git a/python/tvm/driver/tvmc/target.py b/python/tvm/driver/tvmc/target.py index ec8215184ee3..368a8e986fd1 100644 --- a/python/tvm/driver/tvmc/target.py +++ b/python/tvm/driver/tvmc/target.py @@ -69,10 +69,25 @@ def _generate_codegen_args(parser, codegen_name): for tvm_type, python_type in INTERNAL_TO_NATIVE_TYPE.items(): if field.type_info.startswith(tvm_type): target_option = field.name + default_value = None + + # Retrieve the default value string from attrs(field) of config node + # Eg: "default=target_cpu_name" + target_option_default_str = field.type_info.split("default=")[1] + + # Extract the defalut value based on the tvm type + if target_option_default_str and tvm_type == "runtime.String": + default_value = target_option_default_str + elif target_option_default_str and tvm_type == "IntImm": + # Extract the numeric value from the python Int string, Eg: T.int64(8) + str_slice = target_option_default_str.split("(")[1] + default_value = str_slice.split(")")[0] + target_group.add_argument( f"--target-{codegen_name}-{target_option}", type=python_type, help=field.description, + default=default_value, ) @@ -133,7 +148,6 @@ def reconstruct_target_args(args): codegen_options = _reconstruct_codegen_args(args, codegen_name) if codegen_options: reconstructed[codegen_name] = codegen_options - return reconstructed diff --git a/tests/python/driver/tvmc/test_target_options.py b/tests/python/driver/tvmc/test_target_options.py index 194047e7a628..d98a8d588e22 100644 --- a/tests/python/driver/tvmc/test_target_options.py +++ b/tests/python/driver/tvmc/test_target_options.py @@ -72,6 +72,21 @@ def test_target_to_argparse_for_mrvl_hybrid(): assert parsed.target_mrvl_mcpu == "cnf10kb" +@tvm.testing.requires_mrvl +def test_default_arg_for_mrvl_hybrid(): + parser = argparse.ArgumentParser() + generate_target_args(parser) + parsed, _ = parser.parse_known_args( + [ + "--target=mrvl, llvm", + ] + ) + assert parsed.target == "mrvl, llvm" + assert parsed.target_mrvl_mcpu == "cn10ka" + assert parsed.target_mrvl_num_tiles == 8 + + +@tvm.testing.requires_cmsisnn def test_mapping_target_args(): parser = argparse.ArgumentParser() generate_target_args(parser) @@ -129,6 +144,7 @@ def test_ethosu_compiler_attrs(): } +@tvm.testing.requires_cmsisnn def test_skip_target_from_codegen(): parser = argparse.ArgumentParser() generate_target_args(parser)