diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index c4cbd803092..04dcbf526ed 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1318,6 +1318,10 @@ void InitXlaModuleBindings(py::module m) { [](const std::vector& tensors) -> std::string { return GetTensorsHloGraph(tensors, EmitMode::kHloReadable); }); + m.def("_get_xla_tensors_hlo_proto", + [](const std::vector& tensors) -> py::bytes { + return py::bytes(GetTensorsHloGraph(tensors, EmitMode::kHloProto)); + }); m.def("_get_xla_tensor_debug_info", [](const at::Tensor& tensor) -> std::string { return GetXLATensorDebugInfo(tensor); diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 448cbf63d27..63a3f17e6cb 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -287,6 +287,9 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, switch (mode) { case EmitMode::kHloReadable: return ConsumeValue(runtime::util::GetComputationHloText(computation)); + case EmitMode::kHloProto: + return ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto( + computation.proto())); case EmitMode::kStableHloReadable: return hloToStablehlo(&computation.proto(), /* emit_bytecode = */ false); diff --git a/torch_xla/csrc/ir_dump_util.h b/torch_xla/csrc/ir_dump_util.h index 8c9124a3c44..f47c93f82a1 100644 --- a/torch_xla/csrc/ir_dump_util.h +++ b/torch_xla/csrc/ir_dump_util.h @@ -12,6 +12,7 @@ namespace torch_xla { enum class EmitMode { kHloReadable, + kHloProto, kStableHloReadable, kStableHloBytecode, };