From 9af8be6491755ffcb21aadc412d224e7582fbcc3 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Thu, 11 Jul 2024 21:51:06 -0700 Subject: [PATCH] add cuda_graph for mts_gpu_benchmark (#1012) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/1012 Add cuda_graph enablement for AIT. Also leave the hook for AOTI GPU trace after enabling cudagraph: https://fburl.com/perfdoctor/yybf0z60 log information for verification I0710 17:58:44.425707 2013974 AITModelImpl.cpp:148] AITModelImpl: loading .so lib /tmp/benchmark_529602.1720659401/_run_on_acc_0/_run_on_acc_0-ait_engine.so I0710 17:58:44.425731 2013974 AITModelImpl.cpp:149] AITModelImpl: num_runtimes: 1,use_cuda_graph: 1 Reviewed By: guowentian Differential Revision: D59617284 --- fx2ait/fx2ait/ait_splitter.py | 1 + fx2ait/fx2ait/csrc/AITModel.cpp | 3 ++- fx2ait/fx2ait/csrc/AITModelImpl.cpp | 4 +++- fx2ait/fx2ait/test/test_fx2ait.py | 2 ++ fx2ait/fx2ait/tools/ait_minimizer.py | 1 + fx2ait/fx2ait/tools/common_aten2ait.py | 3 +++ fx2ait/fx2ait/tools/common_fx2ait.py | 3 +++ 7 files changed, 15 insertions(+), 2 deletions(-) diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py index 9679b75fc..b0f1a02d9 100644 --- a/fx2ait/fx2ait/ait_splitter.py +++ b/fx2ait/fx2ait/ait_splitter.py @@ -218,6 +218,7 @@ def _lower_model_to_backend( torch.float16, torch.float, 1, # num_runtimes + False, ), interpreter_result, ) diff --git a/fx2ait/fx2ait/csrc/AITModel.cpp b/fx2ait/fx2ait/csrc/AITModel.cpp index 510bbb534..d4a49e8aa 100644 --- a/fx2ait/fx2ait/csrc/AITModel.cpp +++ b/fx2ait/fx2ait/csrc/AITModel.cpp @@ -64,7 +64,8 @@ static auto registerAITModel = std::vector, std::optional, std::optional, - int64_t>()) + int64_t, + bool>()) .def("forward", &AITModel::forward) .def("profile", &AITModel::profile) .def("get_library_path", &AITModel::libraryPath) diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp index 74586cb6e..8530ae706 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp +++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp @@ -145,7 +145,9 @@ AITModelImpl::AITModelImpl( floating_point_input_dtype_(input_dtype), floating_point_output_dtype_(output_dtype), use_cuda_graph_(use_cuda_graph) { - LOG(INFO) << "Loading .so lib " << model_path; + LOG(INFO) << "AITModelImpl: loading .so lib " << model_path; + LOG(INFO) << "AITModelImpl: num_runtimes: " << num_runtimes + << ",use_cuda_graph: " << use_cuda_graph; TORCH_CHECK(handle_, "could not dlopen ", model_path, ": ", dlerror()); TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive"); diff --git a/fx2ait/fx2ait/test/test_fx2ait.py b/fx2ait/fx2ait/test/test_fx2ait.py index 1c986412f..ec2474ce1 100644 --- a/fx2ait/fx2ait/test/test_fx2ait.py +++ b/fx2ait/fx2ait/test/test_fx2ait.py @@ -68,6 +68,7 @@ def _test_fx2ait_impl(self, test_serialization=False, test_cuda_graph=False): torch.float16, torch.float16, 1, # num_runtimes + False, ) ) ait_mod.engine.use_cuda_graph = test_cuda_graph @@ -140,6 +141,7 @@ def forward(self, a, b, c, d): torch.float16, torch.float16, 1, # num_runtimes + False, ), interp_result, ) diff --git a/fx2ait/fx2ait/tools/ait_minimizer.py b/fx2ait/fx2ait/tools/ait_minimizer.py index 9c914cd9c..8a22988f3 100644 --- a/fx2ait/fx2ait/tools/ait_minimizer.py +++ b/fx2ait/fx2ait/tools/ait_minimizer.py @@ -42,6 +42,7 @@ def lower_mod_default( torch.float16, torch.float16, 1, # num_runtimes + False, ), interpreter_result, ) diff --git a/fx2ait/fx2ait/tools/common_aten2ait.py b/fx2ait/fx2ait/tools/common_aten2ait.py index 60a4cf4f6..4d3d38d68 100644 --- a/fx2ait/fx2ait/tools/common_aten2ait.py +++ b/fx2ait/fx2ait/tools/common_aten2ait.py @@ -163,6 +163,7 @@ def run_test( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -256,6 +257,7 @@ def run_test_with_dynamic_shape( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -375,6 +377,7 @@ def benchmark(f, args): torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index bb44e2c3a..f06b65298 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -199,6 +199,7 @@ def run_test( torch_dtype, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -329,6 +330,7 @@ def run_test_with_dynamic_shape( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -467,6 +469,7 @@ def benchmark(f, args): torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, )