diff --git a/fx2ait/fx2ait/ait_splitter.py b/fx2ait/fx2ait/ait_splitter.py index 9679b75fc..b0f1a02d9 100644 --- a/fx2ait/fx2ait/ait_splitter.py +++ b/fx2ait/fx2ait/ait_splitter.py @@ -218,6 +218,7 @@ def _lower_model_to_backend( torch.float16, torch.float, 1, # num_runtimes + False, ), interpreter_result, ) diff --git a/fx2ait/fx2ait/csrc/AITModel.cpp b/fx2ait/fx2ait/csrc/AITModel.cpp index 510bbb534..d4a49e8aa 100644 --- a/fx2ait/fx2ait/csrc/AITModel.cpp +++ b/fx2ait/fx2ait/csrc/AITModel.cpp @@ -64,7 +64,8 @@ static auto registerAITModel = std::vector, std::optional, std::optional, - int64_t>()) + int64_t, + bool>()) .def("forward", &AITModel::forward) .def("profile", &AITModel::profile) .def("get_library_path", &AITModel::libraryPath) diff --git a/fx2ait/fx2ait/csrc/AITModelImpl.cpp b/fx2ait/fx2ait/csrc/AITModelImpl.cpp index 74586cb6e..8530ae706 100644 --- a/fx2ait/fx2ait/csrc/AITModelImpl.cpp +++ b/fx2ait/fx2ait/csrc/AITModelImpl.cpp @@ -145,7 +145,9 @@ AITModelImpl::AITModelImpl( floating_point_input_dtype_(input_dtype), floating_point_output_dtype_(output_dtype), use_cuda_graph_(use_cuda_graph) { - LOG(INFO) << "Loading .so lib " << model_path; + LOG(INFO) << "AITModelImpl: loading .so lib " << model_path; + LOG(INFO) << "AITModelImpl: num_runtimes: " << num_runtimes + << ",use_cuda_graph: " << use_cuda_graph; TORCH_CHECK(handle_, "could not dlopen ", model_path, ": ", dlerror()); TORCH_CHECK(num_runtimes > 0, "num_runtimes must be positive"); diff --git a/fx2ait/fx2ait/lower/lower.py b/fx2ait/fx2ait/lower/lower.py index ecd43e5e1..2a7e4ab7b 100644 --- a/fx2ait/fx2ait/lower/lower.py +++ b/fx2ait/fx2ait/lower/lower.py @@ -129,6 +129,7 @@ def lower_pass( _precision_to_torch_type(lower_settings.precision), _precision_to_torch_type(lower_settings.output_precision), 1, # num_runtimes + False, ), interp_res, lower_settings.trace_ait_module, diff --git a/fx2ait/fx2ait/test/test_fx2ait.py b/fx2ait/fx2ait/test/test_fx2ait.py index 1c986412f..ec2474ce1 100644 --- a/fx2ait/fx2ait/test/test_fx2ait.py +++ b/fx2ait/fx2ait/test/test_fx2ait.py @@ -68,6 +68,7 @@ def _test_fx2ait_impl(self, test_serialization=False, test_cuda_graph=False): torch.float16, torch.float16, 1, # num_runtimes + False, ) ) ait_mod.engine.use_cuda_graph = test_cuda_graph @@ -140,6 +141,7 @@ def forward(self, a, b, c, d): torch.float16, torch.float16, 1, # num_runtimes + False, ), interp_result, ) diff --git a/fx2ait/fx2ait/tools/ait_minimizer.py b/fx2ait/fx2ait/tools/ait_minimizer.py index 9c914cd9c..8a22988f3 100644 --- a/fx2ait/fx2ait/tools/ait_minimizer.py +++ b/fx2ait/fx2ait/tools/ait_minimizer.py @@ -42,6 +42,7 @@ def lower_mod_default( torch.float16, torch.float16, 1, # num_runtimes + False, ), interpreter_result, ) diff --git a/fx2ait/fx2ait/tools/common_aten2ait.py b/fx2ait/fx2ait/tools/common_aten2ait.py index 60a4cf4f6..4d3d38d68 100644 --- a/fx2ait/fx2ait/tools/common_aten2ait.py +++ b/fx2ait/fx2ait/tools/common_aten2ait.py @@ -163,6 +163,7 @@ def run_test( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -256,6 +257,7 @@ def run_test_with_dynamic_shape( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -375,6 +377,7 @@ def benchmark(f, args): torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index bb44e2c3a..ac2e4be58 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -187,6 +187,7 @@ def run_test( torch_dtype, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -199,6 +200,7 @@ def run_test( torch_dtype, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -317,6 +319,7 @@ def run_test_with_dynamic_shape( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -329,6 +332,7 @@ def run_test_with_dynamic_shape( torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, ) @@ -467,6 +471,7 @@ def benchmark(f, args): torch.float16, torch.float, 1, # num_runtimes + False, ), interp_result, )