Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][ArgParse] Pass default values to target compiler(#13264) #17014

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/tvm/driver/tvmc/composite_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,34 +51,42 @@
REGISTERED_CODEGEN = {
"compute-library": {
"config_key": None,
"pass_default": False,
"pass_pipeline": partition_for_arm_compute_lib,
},
"cmsis-nn": {
"config_key": "relay.ext.cmsisnn.options",
"pass_default": False,
"pass_pipeline": partition_for_cmsisnn,
},
"ethos-n": {
"config_key": "relay.ext.ethos-n.options",
"pass_default": False,
"pass_pipeline": partition_for_ethosn,
},
"ethos-u": {
"config_key": "relay.ext.ethos-u.options",
"pass_default": False,
"pass_pipeline": partition_for_ethosu,
},
"bnns": {
"config_key": None,
"pass_default": False,
"pass_pipeline": partition_for_bnns,
},
"vitis-ai": {
"config_key": "relay.ext.vitis_ai.options",
"pass_default": False,
"pass_pipeline": partition_for_vitis_ai,
},
"clml": {
"config_key": None,
"pass_default": False,
"pass_pipeline": partition_for_clml,
},
"mrvl": {
"config_key": "relay.ext.mrvl.options",
"pass_default": True,
"pass_pipeline": partition_for_mrvl,
},
}
Expand Down
19 changes: 18 additions & 1 deletion python/tvm/driver/tvmc/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,28 @@ def _generate_codegen_args(parser, codegen_name):
for tvm_type, python_type in INTERNAL_TO_NATIVE_TYPE.items():
if field.type_info.startswith(tvm_type):
target_option = field.name
default_value = None

# Retrieve the default value string from attrs(field) of config node
# Eg: "default=target_cpu_name"
target_option_default_str = field.type_info.split("default=")[1]

# Extract the defalut value based on the tvm type
if target_option_default_str and tvm_type == "runtime.String":
default_value = target_option_default_str
elif target_option_default_str and tvm_type == "IntImm":
# Extract the numeric value from the python Int string, Eg: T.int64(8)
str_slice = target_option_default_str.split("(")[1]
default_value = str_slice.split(")")[0]

if codegen["pass_default"] is False:
default_value = None

target_group.add_argument(
f"--target-{codegen_name}-{target_option}",
type=python_type,
help=field.description,
default=default_value,
)


Expand Down Expand Up @@ -133,7 +151,6 @@ def reconstruct_target_args(args):
codegen_options = _reconstruct_codegen_args(args, codegen_name)
if codegen_options:
reconstructed[codegen_name] = codegen_options

return reconstructed


Expand Down
16 changes: 16 additions & 0 deletions tests/python/driver/tvmc/test_target_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,21 @@ def test_target_to_argparse_for_mrvl_hybrid():
assert parsed.target_mrvl_mcpu == "cnf10kb"


@tvm.testing.requires_mrvl
def test_default_arg_for_mrvl_hybrid():
parser = argparse.ArgumentParser()
generate_target_args(parser)
parsed, _ = parser.parse_known_args(
[
"--target=mrvl, llvm",
]
)
assert parsed.target == "mrvl, llvm"
assert parsed.target_mrvl_mcpu == "cn10ka"
assert parsed.target_mrvl_num_tiles == 8


@tvm.testing.requires_cmsisnn
def test_mapping_target_args():
parser = argparse.ArgumentParser()
generate_target_args(parser)
Expand Down Expand Up @@ -129,6 +144,7 @@ def test_ethosu_compiler_attrs():
}


@tvm.testing.requires_cmsisnn
def test_skip_target_from_codegen():
parser = argparse.ArgumentParser()
generate_target_args(parser)
Expand Down
Loading