Skip to content

Commit

Permalink
[IFRT] Add IFRT IR program serialize options and proto to store seria…
Browse files Browse the repository at this point in the history
…lized IFRT IR programs.

PiperOrigin-RevId: 695576399
  • Loading branch information
ICGog authored and Google-ML-Automation committed Nov 12, 2024
1 parent 9e0fe15 commit 66d6f1d
Show file tree
Hide file tree
Showing 10 changed files with 331 additions and 19 deletions.
21 changes: 21 additions & 0 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ cc_library(
":ifrt_ir_compile_options_proto_cc",
"//xla/pjrt:pjrt_executable",
"//xla/python/ifrt",
"//xla/python/ifrt:serdes",
"//xla/python/pjrt_ifrt:xla_ifrt",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
Expand All @@ -199,6 +200,11 @@ cc_library(
],
)

# Proto library built from `ifrt_ir_program.proto`. The generated C++ target
# (`:ifrt_ir_program_proto_cc`) is listed in the deps of another library in
# this BUILD file.
tf_proto_library(
name = "ifrt_ir_program_proto",
srcs = ["ifrt_ir_program.proto"],
)

tf_proto_library(
name = "ifrt_ir_compile_options_proto",
srcs = ["ifrt_ir_compile_options.proto"],
Expand Down Expand Up @@ -229,12 +235,15 @@ cc_library(
visibility = ["//xla/python/ifrt:friends"],
deps = [
":ifrt_ir_program",
":ifrt_ir_program_proto_cc",
":version",
"//xla:status_macros",
"//xla/mlir/utils:error_util",
"//xla/python/ifrt:serdes",
"//xla/python/ifrt/support:module_parsing",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@llvm-project//llvm:Support",
Expand Down Expand Up @@ -278,3 +287,15 @@ cc_library(
"@com_google_absl//absl/status:statusor",
],
)

# Library built from `version.cc` / `version.h`. It depends only on LLVM
# Support and MLIR IR/Support.
cc_library(
name = "version",
srcs = ["version.cc"],
hdrs = ["version.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
4 changes: 4 additions & 0 deletions xla/python/ifrt/ir/ifrt_dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ def Ifrt_ArrayMappingAttrArrayAttr :
let constBuilderCall = "::mlir::ArrayAttr::get($_builder.getContext(), $0)";
}

// Array attribute whose elements are each a `DenseI32ArrayAttr` constrained to
// hold exactly two entries (one pair of aliased input/output indices).
def Ifrt_IoAliasesAttr : TypedArrayAttrBase<
ConfinedAttr<DenseI32ArrayAttr, [DenseArrayCount<2>]>,
"Array of pairs of aliased input/output indices">;

//===---------------------------------------------------------------------------
// Types
//===---------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/ifrt_ir_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace xla {
namespace ifrt {

char IfrtIRProgram::ID = 0;
char SerializeIfrtIRProgramOptions::ID = 0;
char IfrtIRCompileOptions::ID = 0;

absl::StatusOr<std::unique_ptr<IfrtIRCompileOptions>> GetIfrtIRCompileOptions(
Expand Down
18 changes: 18 additions & 0 deletions xla/python/ifrt/ir/ifrt_ir_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "xla/python/ifrt/executable.h"
#include "xla/python/ifrt/ir/ifrt_ir_compile_options.pb.h"
#include "xla/python/ifrt/program.h"
#include "xla/python/ifrt/serdes.h"

namespace xla {
namespace ifrt {
Expand All @@ -55,6 +56,23 @@ struct IfrtIRProgram : llvm::RTTIExtends<IfrtIRProgram, Program> {
mlir::OwningOpRef<mlir::ModuleOp> owning_mlir_module;
};

// Serialization options for IFRT IR programs. Holds the two version strings
// that are supplied when the options are constructed.
struct SerializeIfrtIRProgramOptions
    : llvm::RTTIExtends<SerializeIfrtIRProgramOptions, SerializeOptions> {
  // IFRT IR version, as a string of the form "major.minor.patch".
  std::string ifrt_version;
  // Atom program version (currently the VHLO version), as a string of the
  // form "major.minor.patch".
  std::string atom_program_version;

  static char ID;  // NOLINT

  // Takes both version strings by value and moves them into the members.
  explicit SerializeIfrtIRProgramOptions(std::string ifrt_ir_version_str,
                                         std::string atom_version_str)
      : ifrt_version(std::move(ifrt_ir_version_str)),
        atom_program_version(std::move(atom_version_str)) {}
};

// CompileOptions for an IFRT IR program.
struct IfrtIRCompileOptions
: llvm::RTTIExtends<IfrtIRCompileOptions, CompileOptions> {
Expand Down
42 changes: 42 additions & 0 deletions xla/python/ifrt/ir/ifrt_ir_program.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla.ifrt;

// Proto for one serialized atom program: a name, an optional version, and the
// serialized program bytes.
message IfrtIrAtomProgramProto {
// Name of the atom program.
// NOTE(review): presumably the name the IFRT IR program uses to refer to
// this atom program -- confirm against the code that fills this field.
string name = 1;
// String of the form "major.minor.patch", representing the atom program
// version. Currently, only used to denote VHLO version.
optional string version = 2;
// Serialized atom program. If version is set then the serialized program
// is in the VHLO dialect, otherwise it is in its original dialect.
bytes program = 3;
}

// Proto for storing a serialized IFRT IR program.
message IfrtIrProgramProto {
// Serialized IFRT IR program bytes.
bytes ifrt_program = 1;

// String of the form "major.minor.patch", representing the IFRT IR version.
// If ifrt version is not set, then the program is not versioned and the
// whole program will be serialized into the `ifrt_program` field.
optional string ifrt_version = 2;

// List of atom programs that are used by the IFRT IR program. It is empty
// if the IFRT IR program is not versioned (i.e., `ifrt_version` is not set).
repeated IfrtIrAtomProgramProto atom_programs = 3;
}
62 changes: 50 additions & 12 deletions xla/python/ifrt/ir/ifrt_ir_program_serdes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "llvm/Support/Casting.h"
Expand All @@ -30,6 +31,8 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "xla/mlir/utils/error_util.h"
#include "xla/python/ifrt/ir/ifrt_ir_program.h"
#include "xla/python/ifrt/ir/ifrt_ir_program.pb.h"
#include "xla/python/ifrt/ir/version.h"
#include "xla/python/ifrt/serdes.h"
#include "xla/python/ifrt/support/module_parsing.h"
#include "xla/status_macros.h"
Expand All @@ -49,31 +52,66 @@ class IfrtIRProgramSerDes
}

// Serializes an `IfrtIRProgram` into a serialized `IfrtIrProgramProto`.
//
// `serializable` must be an `IfrtIRProgram`. `options` is either null or a
// `SerializeIfrtIRProgramOptions`:
//   * null: the whole module is written as MLIR bytecode with the default
//     writer configuration and no version is recorded in the proto.
//   * non-null: the module is written as MLIR bytecode using the bytecode
//     version derived from `ifrt_version`, and `ifrt_version` is recorded in
//     the proto. `atom_program_version` is not consulted here.
//
// Returns an InvalidArgument error if the module is null, if `ifrt_version`
// is not a well-formed version string, if no bytecode version is available
// for that IFRT IR version, or if writing the bytecode fails.
absl::StatusOr<std::string> Serialize(
    Serializable& serializable,
    std::unique_ptr<SerializeOptions> options) override {
  const auto* serialize_options =
      llvm::cast_or_null<SerializeIfrtIRProgramOptions>(options.get());
  const auto& program = llvm::cast<IfrtIRProgram>(serializable);
  if (program.mlir_module == nullptr) {
    return absl::InvalidArgumentError("Unable to serialize null MLIR module");
  }

  // The bytecode is streamed straight into the proto's `ifrt_program` field.
  IfrtIrProgramProto program_proto;
  llvm::raw_string_ostream ifrt_ir_program_stream(
      *program_proto.mutable_ifrt_program());
  // Captures MLIR diagnostics so they can be attached to the returned status.
  mlir::BaseScopedDiagnosticHandler diagnostic_handler(
      program.mlir_module->getContext());

  if (serialize_options == nullptr) {
    // Serialize to bytecode the whole program if no options are provided.
    // This is a fast path for the case where the user does not care about
    // stable serialization.
    mlir::BytecodeWriterConfig writer_config;
    if (mlir::failed(mlir::writeBytecodeToFile(
            program.mlir_module, ifrt_ir_program_stream, writer_config))) {
      return absl::InvalidArgumentError(
          absl::StrFormat("Failed to serialize IFRT IR module string: %s",
                          diagnostic_handler.ConsumeStatus().message()));
    }
  } else {
    // `Version::fromString` fails on malformed input; check the result before
    // dereferencing it instead of chaining `->getBytecodeVersion()` blindly.
    const auto fail_or_version =
        Version::fromString(serialize_options->ifrt_version);
    if (mlir::failed(fail_or_version)) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Failed to parse IFRT IR version %s; expected major.minor.patch",
          serialize_options->ifrt_version));
    }
    const auto fail_or_bytecode_version =
        fail_or_version->getBytecodeVersion();
    if (mlir::failed(fail_or_bytecode_version)) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "Failed to get IFRT IR bytecode version for IR version %s",
          serialize_options->ifrt_version));
    }
    mlir::BytecodeWriterConfig writer_config(
        absl::StrCat("IFRT_v", serialize_options->ifrt_version));
    writer_config.setDesiredBytecodeVersion(*fail_or_bytecode_version);
    if (mlir::failed(mlir::writeBytecodeToFile(
            program.mlir_module, ifrt_ir_program_stream, writer_config))) {
      return absl::InvalidArgumentError(
          absl::StrFormat("Failed to serialize IFRT IR module string: %s",
                          diagnostic_handler.ConsumeStatus().message()));
    }
    program_proto.set_ifrt_version(serialize_options->ifrt_version);
  }
  return program_proto.SerializeAsString();
}

// Rebuilds an `IfrtIRProgram` from a serialized `IfrtIrProgramProto`.
//
// `serialized` is first parsed as an `IfrtIrProgramProto`; the MLIR module is
// then parsed from the proto's `ifrt_program` field (not from `serialized`
// itself, which holds proto wire bytes) into a freshly created MLIRContext
// that the returned program owns. The deserialize options are unused.
//
// Returns an InvalidArgument error if the proto cannot be parsed, and
// propagates any error from parsing the MLIR module.
absl::StatusOr<std::unique_ptr<Serializable>> Deserialize(
    const std::string& serialized,
    std::unique_ptr<DeserializeOptions>) override {
  IfrtIrProgramProto program_proto;
  if (!program_proto.ParseFromString(serialized)) {
    return absl::InvalidArgumentError("Failed to parse IfrtIrProgramProto");
  }
  auto context = std::make_unique<mlir::MLIRContext>();
  TF_ASSIGN_OR_RETURN(
      auto module,
      support::ParseMlirModuleString(program_proto.ifrt_program(), *context));
  return std::make_unique<IfrtIRProgram>(std::move(context),
                                         std::move(module));
}
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ module {

EXPECT_THAT(Deserialize<IfrtIRProgram>(serialized, /*options=*/nullptr),
StatusIs(Not(absl::StatusCode::kOk),
HasSubstr("Failed to parse IFRT IR module string")));
HasSubstr("Failed to parse IfrtIrProgramProto")));
}

} // namespace
Expand Down
8 changes: 2 additions & 6 deletions xla/python/ifrt/ir/ifrt_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,6 @@ def Ifrt_RemapArraysOp
let hasVerifier = 1;
}

def IoAliasesAttr : TypedArrayAttrBase<
ConfinedAttr<DenseI32ArrayAttr, [DenseArrayCount<2>]>,
"Array of pairs of aliased input/output indices">;

def Ifrt_CallOp : Ifrt_Op<"Call",
[AttrSizedOperandSegments,
NestedInIfrtFunc,
Expand Down Expand Up @@ -194,7 +190,7 @@ def Ifrt_CallOp : Ifrt_Op<"Call",
Variadic<Ifrt_ControlType>:$control_inputs,
SymbolRefAttr:$callee,
Ifrt_DevicesAttr:$devices,
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<Ifrt_IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<DenseI32ArrayAttr, "{}">:$donated_input_indices);
let results = (outs
Variadic<Ifrt_ArrayType>:$outputs,
Expand Down Expand Up @@ -234,7 +230,7 @@ def Ifrt_CallLoadedExecutableOp : Ifrt_Op<"CallLoadedExecutable",
Variadic<Ifrt_ArrayType>:$inputs,
Variadic<Ifrt_ControlType>:$control_inputs,
SymbolRefAttr:$callee,
DefaultValuedAttr<IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<Ifrt_IoAliasesAttr, "{}">:$io_aliases,
DefaultValuedAttr<DenseI32ArrayAttr, "{}">:$donated_input_indices);
let results = (outs
Variadic<Ifrt_ArrayType>:$outputs,
Expand Down
93 changes: 93 additions & 0 deletions xla/python/ifrt/ir/version.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/ifrt/ir/version.h"

#include <cstdint>

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"

namespace xla {
namespace ifrt {

namespace {

/// Parses `num_ref` as a base-10 signed 64-bit integer.
/// Returns failure if the text is not a valid number or does not fit in
/// int64_t, so that malformed input is reported to the caller instead of
/// aborting the process.
static mlir::FailureOr<int64_t> parseNumber(llvm::StringRef num_ref) {
  int64_t num = 0;
  // StringRef::getAsInteger returns true on error.
  if (num_ref.getAsInteger(/*radix=*/10, num)) {
    return mlir::failure();
  }
  return num;
}

/// Validate version argument is `#.#.#` (ex: 0.9.0, 0.99.0, 1.2.3)
/// Returns the vector of 3 matches (major, minor, patch) if successful,
/// else returns failure. Failure is also returned when a component matches
/// the pattern but is too large to be represented as int64_t.
static mlir::FailureOr<llvm::SmallVector<int64_t, 3>> extractVersionNumbers(
    llvm::StringRef version_ref) {
  llvm::Regex versionRegex("^([0-9]+)\\.([0-9]+)\\.([0-9]+)$");
  llvm::SmallVector<llvm::StringRef> matches;
  if (!versionRegex.match(version_ref, &matches)) {
    return mlir::failure();
  }
  llvm::SmallVector<int64_t, 3> numbers;
  // matches[0] is the whole match; the capture groups are matches[1..3].
  for (unsigned group = 1; group <= 3; ++group) {
    auto fail_or_number = parseNumber(matches[group]);
    if (mlir::failed(fail_or_number)) {
      return mlir::failure();
    }
    numbers.push_back(*fail_or_number);
  }
  return numbers;
}

} // namespace

// Builds a Version from a "major.minor.patch" string. Returns failure when
// the string's numeric components cannot be extracted.
mlir::FailureOr<Version> Version::fromString(llvm::StringRef version_ref) {
  auto fail_or_components = extractVersionNumbers(version_ref);
  if (mlir::failed(fail_or_components)) {
    return mlir::failure();
  }
  const auto components = *fail_or_components;
  const auto major = components[0];
  const auto minor = components[1];
  const auto patch = components[2];
  return Version(major, minor, patch);
}

// Returns bytecode version 0 for every version up to and including the
// current one; fails for any version beyond the current one.
mlir::FailureOr<int64_t> Version::getBytecodeVersion() const {
  const bool not_newer_than_current = *this <= getCurrentVersion();
  if (!not_newer_than_current) {
    return mlir::failure();
  }
  return 0;
}

// Maps a compatibility requirement to the Version returned for it. A value
// that matches none of the enumerators is a fatal error.
Version Version::fromCompatibilityRequirement(
    CompatibilityRequirement requirement) {
  switch (requirement) {
    case CompatibilityRequirement::NONE:
      return Version::getCurrentVersion();
    // Both windows currently resolve to the same version.
    case CompatibilityRequirement::WEEK_4:
    case CompatibilityRequirement::WEEK_12:
      return Version(0, 1, 0);  // v0.1.0 - Nov 05, 2024
    case CompatibilityRequirement::MAX:
      return Version::getMinimumVersion();
  }
  llvm::report_fatal_error("Unsupported compatibility requirement");
}

// Appends the version's string form to an MLIR diagnostic.
mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Version& version) {
  return diag << version.toString();
}
// Writes the version's string form to an LLVM output stream.
llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Version& version) {
  return os << version.toString();
}

} // namespace ifrt
} // namespace xla
Loading

0 comments on commit 66d6f1d

Please sign in to comment.