Skip to content

Commit

Permalink
[Mosaic GPU] Add Python bindings for the Mosaic GPU MLIR dialect.
Browse files Browse the repository at this point in the history
Also start moving the existing C++ tests to Python.

PiperOrigin-RevId: 691729887
  • Loading branch information
bchetioui authored and Google-ML-Automation committed Oct 31, 2024
1 parent 85662f6 commit c708a04
Show file tree
Hide file tree
Showing 16 changed files with 327 additions and 65 deletions.
1 change: 1 addition & 0 deletions jax/_src/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ py_library_providing_imports_info(
"//jax:version",
] + if_building_jaxlib([
"//jaxlib",
"//jaxlib/mosaic/python:gpu_dialect",
"//jaxlib/mosaic/python:tpu_dialect",
"//jaxlib:cpu_feature_guard",
"//jaxlib:utils",
Expand Down
9 changes: 8 additions & 1 deletion jax/_src/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,14 @@ def _xla_gc_callback(*args):
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error # noqa: F401
import jaxlib.gpu_triton as gpu_triton # pytype: disable=import-error # noqa: F401

import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401
try:
import jaxlib.mosaic.python.gpu as mosaic_gpu_dialect # pytype: disable=import-error
except ImportError:
# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36.
# Jaxlib doesn't contain Mosaic GPU dialect bindings.
mosaic_gpu_dialect = None # type: ignore

import jaxlib.mosaic.python.tpu as tpu # pytype: disable=import-error # noqa: F401

# Version number for MLIR:Python APIs, provided by jaxlib.
mlir_api_version = xla_client.mlir_api_version
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================

from jax import ShapeDtypeStruct as ShapeDtypeStruct
from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401

from .core import (
Barrier as Barrier,
ClusterBarrier as ClusterBarrier,
Expand Down
12 changes: 12 additions & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ py_extension(
],
)

py_extension(
name = "_mosaic_gpu_ext",
srcs = ["mosaic_gpu_ext.cc"],
copts = COPTS,
linkopts = LINKOPTS,
deps = [
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi",
"@llvm-project//mlir:CAPIIRHeaders",
"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
],
)

# This is here, instead of in jaxlib/mosaic/python, so it's in the same
# directory as libjaxlib_mlir_capi.so (produced by
# :jaxlib_mlir_capi_shared_library). This ensures that the RPATH works correctly
Expand Down
37 changes: 37 additions & 0 deletions jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// clang-format: off
// pybind11 must be included before mlir/Bindings/Python/PybindAdaptors.h,
// otherwise this code will not build on Windows.
#include "pybind11/pybind11.h"
// clang-format: on

#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/PybindAdaptors.h" // IWYU pragma: keep
#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"

PYBIND11_MODULE(_mosaic_gpu_ext, m, py::mod_gil_not_used()) {
m.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle dialect = mlirGetDialectHandle__mosaic_gpu__();
mlirDialectHandleRegisterDialect(dialect, context);
if (load) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context"), py::arg("load") = true);
}
52 changes: 51 additions & 1 deletion jaxlib/mosaic/dialect/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
load(
"@llvm-project//mlir:tblgen.bzl",
"gentbl_cc_library",
"gentbl_filegroup",
"td_library",
)

package(
default_applicable_licenses = [],
Expand Down Expand Up @@ -143,3 +148,48 @@ cc_test(
"@tsl//tsl/platform:errors",
],
)

gentbl_filegroup(
name = "mosaic_gpu_python_gen_raw",
tbl_outs = [
(
[
"-gen-python-op-bindings",
"-bind-dialect=mosaic_gpu",
],
"_mosaic_gpu_gen_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = ":mosaic_gpu.td",
deps = [
":mosaic_gpu_td_files",
"@llvm-project//mlir:OpBaseTdFiles",
],
)

genrule(
name = "mosaic_gpu_python_gen",
srcs = ["_mosaic_gpu_gen_raw.py"],
outs = ["_mosaic_gpu_gen.py"],
cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
)

DIALECT_CAPI_SOURCES = [
":integrations/c/gpu_dialect.cc",
]

DIALECT_CAPI_HEADERS = [
":integrations/c/gpu_dialect.h",
]

cc_library(
name = "gpu_dialect_capi",
srcs = DIALECT_CAPI_SOURCES,
hdrs = DIALECT_CAPI_HEADERS,
deps = [
":mosaic_gpu",
":mosaic_gpu_inc_gen",
"@llvm-project//mlir:CAPIIR",
],
)
25 changes: 25 additions & 0 deletions jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"

#include "mlir/CAPI/Registration.h"
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"

extern "C" {

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu,
mosaic_gpu::MosaicGPUDialect);
}
33 changes: 33 additions & 0 deletions jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_
#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_

#include <stddef.h>

#include "mlir/CAPI/Registration.h"

#ifdef __cplusplus
extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MosaicGPU, mosaic_gpu);

#ifdef __cplusplus
}
#endif

#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_GPU_DIALECT_H_
44 changes: 22 additions & 22 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,28 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/include/mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Dialect.h"
#include "mlir/include/mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/Location.h"
#include "mlir/include/mlir/IR/MLIRContext.h"
#include "mlir/include/mlir/IR/TypeRange.h"
#include "mlir/include/mlir/IR/Types.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/IR/ValueRange.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "tsl/platform/statusor.h"

// Generated definitions.
Expand Down
12 changes: 6 additions & 6 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/raw_ostream.h"
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"

// Generated definitions.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep
Expand Down
14 changes: 7 additions & 7 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ limitations under the License.
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AttrTypeBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonAttrConstraints.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/CommonTypeConstraints.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/EnumAttr.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/DialectBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"

def MosaicGPU_Dialect : Dialect {
let name = "mosaic_gpu";
Expand Down
28 changes: 0 additions & 28 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,34 +193,6 @@ TEST_F(MosaicGpuTest, RuntimeFunctionsAreRegistered) {
mosaic_gpu::kRuntimeMemcpyAsyncH2DName));
}

TEST_F(MosaicGpuTest, InitializeBarrierOpEnforcesRelevantInvariants) {
auto loc = builder_.getUnknownLoc();
auto f32 = builder_.getF32Type();
auto barrier = BarrierType::get(&context_);

// InitializeBarrierOp requires a memref with type `BarrierType`.
auto initialize_op = builder_.create<InitializeBarrierOp>(
loc, mlir::MemRefType::get({1, 2}, f32), /*arrival_count=*/1);
EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_)));
ExpectLastErrorContains("must be memref of barrier values");
initialize_op->erase();

// InitializeBarrierOp requires a non-negative arrival count.
initialize_op = builder_.create<InitializeBarrierOp>(
loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/0);
EXPECT_FALSE(mlir::succeeded(mlir::verify(*module_)));
ExpectLastErrorContains("value is positive");
initialize_op->erase();

// Checks that InitializeBarrierOp prints nicely.
initialize_op = builder_.create<InitializeBarrierOp>(
loc, mlir::MemRefType::get({1, 2}, barrier), /*arrival_count=*/1);
EXPECT_TRUE(mlir::succeeded(mlir::verify(*module_)));
EXPECT_THAT(
MlirToString(initialize_op),
HasSubstr(
"mosaic_gpu.initialize_barrier 1 : memref<1x2x!mosaic_gpu.barrier>"));
}

} // anonymous namespace
} // namespace mosaic_gpu
13 changes: 13 additions & 0 deletions jaxlib/mosaic/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,19 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup")
load("@rules_python//python:defs.bzl", "py_library")

py_library(
name = "gpu_dialect",
srcs = [
"gpu.py",
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py",
],
visibility = ["//visibility:public"],
deps = [
"//jaxlib/mlir",
"//jaxlib/mlir/_mlir_libs:_mosaic_gpu_ext",
],
)

gentbl_filegroup(
name = "tpu_python_gen_raw",
tbl_outs = [
Expand Down
31 changes: 31 additions & 0 deletions jaxlib/mosaic/python/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Python bindings for the MLIR Mosaic GPU dialect."""

# ruff: noqa: F401
# ruff: noqa: F403


# pylint: disable=g-bad-import-order
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found]

try:
from jaxlib.mlir.dialects._ods_common import _cext
except ImportError:
from mlir.dialects._ods_common import _cext # type: ignore[import-not-found]


_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
Loading

0 comments on commit c708a04

Please sign in to comment.