-
Notifications
You must be signed in to change notification settings - Fork 4
/
make_context_cache_binary.py
66 lines (57 loc) · 3.87 KB
/
make_context_cache_binary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from utils.htp_devices_config import htp_devices, dump_htp_config, dump_htp_link_config
import argparse, os
from pathlib import Path
def main():
    """Build QNN context-cache binaries from compiled RWKV model libraries.

    Reads QNN_SDK_ROOT from the environment, writes the per-model HTP
    config/link JSON files next to each model ``.so``, then invokes
    ``qnn-context-binary-generator`` once per model library — or once per
    chunk for chunked models named ``...chunk<i>of<n>.so``.

    Exits with status 1 if QNN_SDK_ROOT is unset or empty.
    """
    parser = argparse.ArgumentParser(description='Make context cache from model libs')
    parser.add_argument('model_lib', type=Path, help='Path to RWKV pth file')
    parser.add_argument('output_path', type=Path, help='Path to output folder')
    parser.add_argument('platform', type=str, choices=htp_devices.keys(), help='Platform name')
    parser.add_argument('--use_optrace', action='store_true', help='Use optrace profiling')
    parser.add_argument('--wkv_customop', action='store_true', help='Use wkv custom op')
    parser.add_argument('--output_name', type=str, default=None, help='Output name for the binary file')
    args = parser.parse_args()

    # BUG FIX: os.environ["QNN_SDK_ROOT"] raised KeyError before the
    # emptiness check below could run; .get() lets the friendly message
    # handle both the missing and the empty-string case.
    qnn_sdk_root = os.environ.get("QNN_SDK_ROOT")
    if not qnn_sdk_root:
        print("Please set QNN_SDK_ROOT environment variable to the root of the Qualcomm Neural Processing SDK")
        exit(1)

    # The SDK root's last path component is assumed to be a version string
    # like "2.21.0" — TODO confirm this holds for all SDK layouts.
    # Minor versions before 22 need the legacy HTP config format.
    QNN_VERSION_MINOR = int(qnn_sdk_root.split('/')[-1].split('.')[1])
    old_qnn = QNN_VERSION_MINOR < 22
    print(f"QNN_VERSION_MINOR: {QNN_VERSION_MINOR}")

    def _convert(model_path: str, binary_name: str) -> None:
        # One library -> one context binary: emit the HTP config/link JSON
        # beside the .so, then run the generator tool.
        model_name = model_path.split('/')[-1].replace('.so', '')
        dump_htp_config(args.platform, [model_name], model_path.replace('.so', '_htp_config.json'), old_qnn)
        dump_htp_link_config(model_path.replace('.so', '_htp_link.json'), qnn_sdk_root)
        convert_cmd = f"{qnn_sdk_root}/bin/x86_64-linux-clang/qnn-context-binary-generator"
        convert_cmd += f" --backend {qnn_sdk_root}/lib/x86_64-linux-clang/libQnnHtp.so"
        convert_cmd += f" --model {model_path}"
        convert_cmd += f" --output_dir {args.output_path}"
        convert_cmd += f" --binary_file {binary_name}"
        convert_cmd += f" --config_file {model_path.replace('.so', '_htp_link.json')}"
        if args.use_optrace:
            convert_cmd += " --profiling_level detailed --profiling_option optrace"
        if args.wkv_customop:
            convert_cmd += " --op_packages hexagon/HTP/RwkvWkvOpPackage/build/x86_64-linux-clang/libQnnRwkvWkvOpPackage.so:RwkvWkvOpPackageInterfaceProvider"
        # NOTE(review): shell-string invocation — paths containing spaces or
        # shell metacharacters would break; consider subprocess.run([...]).
        os.system(convert_cmd)

    model_lib = str(args.model_lib)
    if "chunk" in model_lib:
        # Chunked naming convention: <prefix>chunk<i>of<n>.so — recover n
        # from the path and process every chunk 1..n.
        print("Chunked model detected")
        num_chunks = int(model_lib.split('chunk')[-1].replace('.so', '').split('of')[-1])
        print(f"Number of chunks: {num_chunks}")
        for i in range(1, num_chunks + 1):
            model_path = model_lib.split('chunk')[0] + f"chunk{i}of{num_chunks}.so"
            print(f"Processing chunk {model_path}")
            model_name = model_path.split('/')[-1].replace('.so', '')
            if args.output_name is None:
                # NOTE(review): .replace strips every 'lib' substring, not
                # just a leading "lib" prefix — confirm intended.
                binary_name = model_name.replace('lib', '')
            else:
                binary_name = args.output_name + f'_chunk{i}of{num_chunks}'
            _convert(model_path, binary_name)
    else:
        model_name = model_lib.split('/')[-1].replace('.so', '')
        binary_name = model_name.replace('lib', '') if args.output_name is None else args.output_name
        _convert(model_lib, binary_name)
# Run the converter only when executed as a script, not on import.
if __name__ == '__main__':
    main()