Skip to content

Commit

Permalink
Introduce HLO graph bindings (pytorch#8551)
Browse files Browse the repository at this point in the history
  • Loading branch information
rpsilva-aws authored Jan 11, 2025
1 parent 6963e19 commit 3784b26
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,10 @@ void InitXlaModuleBindings(py::module m) {
[](const std::vector<at::Tensor>& tensors) -> std::string {
return GetTensorsHloGraph(tensors, EmitMode::kHloReadable);
});
m.def("_get_xla_tensors_hlo_proto",
[](const std::vector<at::Tensor>& tensors) -> py::bytes {
return py::bytes(GetTensorsHloGraph(tensors, EmitMode::kHloProto));
});
m.def("_get_xla_tensor_debug_info",
[](const at::Tensor& tensor) -> std::string {
return GetXLATensorDebugInfo(tensor);
Expand Down
3 changes: 3 additions & 0 deletions torch_xla/csrc/ir_dump_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
switch (mode) {
case EmitMode::kHloReadable:
return ConsumeValue(runtime::util::GetComputationHloText(computation));
case EmitMode::kHloProto:
return ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
computation.proto()));
case EmitMode::kStableHloReadable:
return hloToStablehlo(&computation.proto(),
/* emit_bytecode = */ false);
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/ir_dump_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace torch_xla {

enum class EmitMode {
kHloReadable,
kHloProto,
kStableHloReadable,
kStableHloBytecode,
};
Expand Down

0 comments on commit 3784b26

Please sign in to comment.