Skip to content

Commit

Permalink
[IFRT] Ensure that VIFRT td file structure matches that of IFRT dialect.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696750835
  • Loading branch information
ICGog authored and Google-ML-Automation committed Nov 15, 2024
1 parent dd9a463 commit e73c1b4
Show file tree
Hide file tree
Showing 17 changed files with 428 additions and 604 deletions.
191 changes: 74 additions & 117 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,23 @@ cc_library(
],
)

td_library(
name = "vifrt_td",
srcs = [
"vifrt_dialect.td",
"vifrt_interfaces.td",
"vifrt_ops.td",
],
compatible_with = get_compatible_with_portable(),
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ShapeOpsTdFiles",
],
)

gentbl_cc_library(
name = "vifrt_attr_interfaces_inc_gen",
name = "vifrt_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
Expand All @@ -327,17 +342,61 @@ gentbl_cc_library(
["-gen-attr-interface-defs"],
"vifrt_attr_interfaces.cc.inc",
),
(
["-gen-type-interface-decls"],
"vifrt_type_interfaces.h.inc",
),
(
["-gen-type-interface-defs"],
"vifrt_type_interfaces.cc.inc",
),
(
["-gen-op-interface-decls"],
"vifrt_op_interfaces.h.inc",
),
(
["-gen-op-interface-defs"],
"vifrt_op_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_attrs.td",
td_file = "vifrt_interfaces.td",
test = True,
deps = [":vifrt_ops_td_files"],
deps = [":vifrt_td"],
)

gentbl_cc_library(
name = "vifrt_attrs_inc_gen",
name = "vifrt_dialect_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
"-gen-dialect-decls",
"-dialect=vifrt",
],
"vifrt_dialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=vifrt",
],
"vifrt_dialect.cc.inc",
),
(
[
"-gen-typedef-decls",
"--typedefs-dialect=vifrt",
],
"vifrt_types.h.inc",
),
(
[
"-gen-typedef-defs",
"--typedefs-dialect=vifrt",
],
"vifrt_types.cc.inc",
),
(
[
"-gen-attrdef-decls",
Expand All @@ -354,50 +413,9 @@ gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
td_file = "vifrt_dialect.td",
test = True,
deps = [":vifrt_ops_td_files"],
)

gentbl_cc_library(
name = "vifrt_op_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-op-interface-decls"],
"vifrt_op_interfaces.h.inc",
),
(
["-gen-op-interface-defs"],
"vifrt_op_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
deps = [":vifrt_ops_td_files"],
)

cc_library(
name = "vifrt_ops",
srcs = [
"vifrt_ops.cc",
],
hdrs = [
"vifrt_ops.h",
],
compatible_with = get_compatible_with_portable(),
deps = [
":version",
":vifrt_attr_interfaces_inc_gen",
":vifrt_attrs_inc_gen",
":vifrt_op_interfaces_inc_gen",
":vifrt_ops_inc_gen",
":vifrt_types",
":vifrt_types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
deps = [":vifrt_td"],
)

gentbl_cc_library(
Expand All @@ -416,85 +434,24 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
test = True,
deps = [":vifrt_ops_td_files"],
)

td_library(
name = "vifrt_ops_td_files",
srcs = [
"vifrt_attrs.td",
"vifrt_base.td",
"vifrt_dialect.td",
"vifrt_ops.td",
"vifrt_types.td",
],
compatible_with = get_compatible_with_portable(),
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ShapeOpsTdFiles",
],
deps = [":vifrt_td"],
)

cc_library(
name = "vifrt_types",
srcs = ["vifrt_types.cc"],
hdrs = ["vifrt_types.h"],
name = "vifrt",
srcs = ["vifrt_dialect.cc"],
hdrs = ["vifrt_dialect.h"],
compatible_with = get_compatible_with_portable(),
visibility = ["//xla/python/ifrt:friends"],
deps = [
":sharding_param",
":version",
":vifrt_type_interfaces_inc_gen",
":vifrt_types_inc_gen",
":vifrt_dialect_inc_gen",
":vifrt_interfaces_inc_gen",
":vifrt_ops_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@stablehlo//:stablehlo_assembly_format",
],
)

gentbl_cc_library(
name = "vifrt_type_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-type-interface-decls"],
"vifrt_type_interfaces.h.inc",
),
(
["-gen-type-interface-defs"],
"vifrt_type_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_types.td",
deps = [
":vifrt_ops_td_files",
],
)

gentbl_cc_library(
name = "vifrt_types_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
"-gen-typedef-decls",
"--typedefs-dialect=vifrt",
],
"vifrt_type_defs.h.inc",
),
(
[
"-gen-typedef-defs",
"--typedefs-dialect=vifrt",
],
"vifrt_type_defs.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
deps = [
":vifrt_ops_td_files",
],
)
31 changes: 24 additions & 7 deletions xla/python/ifrt/ir/sharding_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ void PopulateDevices(llvm::ArrayRef<int> permutation,
}
}

void PrintInternalV1(llvm::raw_ostream& os, const ShardingParam& sharding) {
PrintDims(os, sharding.dim_shards());
os << " to [";
llvm::interleaveComma(
llvm::ArrayRef<int>(sharding.minor_to_major().permutation), os);
os << "] on ";
PrintDims<int>(os, sharding.minor_to_major().axis_sizes);
}

} // namespace

absl::Status ShardingParam::MinorToMajor::verify() const {
Expand Down Expand Up @@ -124,6 +133,12 @@ void ShardingParam::MinorToMajor::ToDeviceList(

mlir::FailureOr<ShardingParam> ShardingParam::Parse(
mlir::AsmParser& ods_parser) {
// V1 is the current ShardingParam format.
return ParseV1(ods_parser);
}

mlir::FailureOr<ShardingParam> ShardingParam::ParseV1(
mlir::AsmParser& ods_parser) {
MinorToMajor minor_to_major;

auto parseIntoPermutation = [&]() -> mlir::ParseResult {
Expand Down Expand Up @@ -159,6 +174,11 @@ mlir::FailureOr<ShardingParam> ShardingParam::Parse(
std::move(minor_to_major));
}

void ShardingParam::PrintV1(mlir::AsmPrinter& ods_printer,
const ShardingParam& sharding) {
PrintInternalV1(ods_printer.getStream(), sharding);
}

absl::Status ShardingParam::verify() const {
TF_RETURN_IF_ERROR(minor_to_major().verify());
int dim_index = 0;
Expand Down Expand Up @@ -275,17 +295,14 @@ llvm::hash_code hash_value(ShardingParam sharding) {
}

mlir::AsmPrinter& operator<<(mlir::AsmPrinter& os, ShardingParam sharding) {
os.getStream() << sharding;
// V1 if the current ShardingParam version.
PrintInternalV1(os.getStream(), sharding);
return os;
}

llvm::raw_ostream& operator<<(llvm::raw_ostream& os, ShardingParam sharding) {
PrintDims(os, sharding.dim_shards());
os << " to [";
llvm::interleaveComma(
llvm::ArrayRef<int>(sharding.minor_to_major().permutation), os);
os << "] on ";
PrintDims<int>(os, sharding.minor_to_major().axis_sizes);
// V1 if the current ShardingParam version.
PrintInternalV1(os, sharding);
return os;
}

Expand Down
10 changes: 10 additions & 0 deletions xla/python/ifrt/ir/sharding_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ class ShardingParam {
minor_to_major_(std::move(minor_to_major)) {}

static mlir::FailureOr<ShardingParam> Parse(mlir::AsmParser& ods_parser);

// Parses V1 of ShardingParam. This method is meant to be used in the VIFRT
// dialect to parse versioned ShardingParams.
static mlir::FailureOr<ShardingParam> ParseV1(mlir::AsmParser& ods_parser);

// Prints V1 of ShardingParam. This method is meant to be used in the VIFRT
// dialect to print versioned ShardingParams.
static void PrintV1(mlir::AsmPrinter& ods_printer,
const ShardingParam& sharding);

absl::Status verify() const;
mlir::LogicalResult verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emit_error) const;
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ xla_cc_binary(
"//xla/python/ifrt/ir:ifrt_ir_program",
"//xla/python/ifrt/ir:ifrt_ir_program_serdes", # build_cleaner: keep
"//xla/python/ifrt/ir:version",
"//xla/python/ifrt/ir:vifrt_ops",
"//xla/python/ifrt/ir:vifrt",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
Expand Down
2 changes: 1 addition & 1 deletion xla/python/ifrt/ir/tests/ifrt-translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ limitations under the License.
#include "xla/python/ifrt/ir/ifrt_dialect.h"
#include "xla/python/ifrt/ir/ifrt_ir_program.h"
#include "xla/python/ifrt/ir/version.h"
#include "xla/python/ifrt/ir/vifrt_ops.h"
#include "xla/python/ifrt/ir/vifrt_dialect.h"
#include "xla/python/ifrt/serdes.h"

namespace xla {
Expand Down
Loading

0 comments on commit e73c1b4

Please sign in to comment.