Activate Triangular Solve to XLA's FFI
PiperOrigin-RevId: 705029286
Paweł Paruzel authored and Google-ML-Automation committed Dec 11, 2024
1 parent 3d9c720 commit 1256153
Showing 9 changed files with 352 additions and 47 deletions.
1 change: 1 addition & 0 deletions jax/_src/export/_export.py
@@ -1017,6 +1017,7 @@ def _check_lowering(lowering) -> None:
"lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
"lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
"lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
"lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
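This list is what _check_lowering consults at export time: an exported module may only contain custom calls whose target names are on the guaranteed-stable set, which this change extends with the four lapack_*trsm_ffi targets. A minimal sketch of exercising that check for a CPU triangular solve (jax.export API; illustrative shapes, assuming jax/jaxlib 0.4.37 or newer):

import numpy as np
import jax
import jax.numpy as jnp
from jax import export

def solve(a, b):
  return jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=True)

a = jnp.eye(3, dtype=np.float32) * 2.0
b = jnp.ones((3, 2), np.float32)

# Export succeeds whether the lowering emitted the legacy blas_strsm target
# (inside the forward-compatibility window) or the new lapack_strsm_ffi target,
# since both names are treated as stable.
exp = export.export(jax.jit(solve))(a, b)
print(exp.platforms)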

Large diffs are not rendered by default.

11 changes: 7 additions & 4 deletions jax/_src/lax/linalg.py
@@ -1329,7 +1329,6 @@ def _triangular_solve_lowering(
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]

mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)

def _triangular_solve_cpu_lower(
ctx, a, b, *, left_side, lower, transpose_a,
@@ -1342,10 +1341,12 @@ def _triangular_solve_cpu_lower(
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
# TODO(b/344892332): Remove the conditional after the compatibility period.
ctx_args = (ctx,) if jaxlib_version >= (0, 4, 37) else ()
return lapack.trsm_hlo(
a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)
*ctx_args, a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)
else:
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
@@ -1358,6 +1359,8 @@ def _triangular_solve_cpu_lower(
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]


mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')

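Note the registration order: the platform-independent _triangular_solve_lowering is registered first, then the CPU-specific lowering overrides it for platform='cpu' and, on a new enough jaxlib, threads ctx through to lapack.trsm_hlo. A quick way to see which kernel name the CPU lowering actually emits (a sketch; which name shows up depends on the installed jaxlib and on forward-compatibility mode):

import numpy as np
import jax
import jax.numpy as jnp

a = jnp.eye(3, dtype=np.float32) * 2.0
b = jnp.ones((3, 2), np.float32)
text = jax.jit(
    lambda a, b: jax.lax.linalg.triangular_solve(a, b, left_side=True, lower=True)
).lower(a, b).as_text()
# Expect "lapack_strsm_ffi" with jaxlib >= 0.4.37, "blas_strsm" with older wheels.
print([name for name in ("lapack_strsm_ffi", "blas_strsm") if name in text])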
8 changes: 4 additions & 4 deletions jaxlib/cpu/cpu_kernels.cc
@@ -117,10 +117,10 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(

// FFI Kernels

JAX_CPU_REGISTER_HANDLER(blas_strsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_dtrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ctrsm_ffi);
JAX_CPU_REGISTER_HANDLER(blas_ztrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_strsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dtrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ctrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_ztrsm_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_sgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_dgetrf_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgetrf_ffi);
8 changes: 4 additions & 4 deletions jaxlib/cpu/lapack.cc
@@ -234,10 +234,10 @@ nb::dict Registrations() {
dict["lapack_zhetrd"] =
EncapsulateFunction(Sytrd<std::complex<double>>::Kernel);

dict["blas_strsm_ffi"] = EncapsulateFunction(blas_strsm_ffi);
dict["blas_dtrsm_ffi"] = EncapsulateFunction(blas_dtrsm_ffi);
dict["blas_ctrsm_ffi"] = EncapsulateFunction(blas_ctrsm_ffi);
dict["blas_ztrsm_ffi"] = EncapsulateFunction(blas_ztrsm_ffi);
dict["lapack_strsm_ffi"] = EncapsulateFunction(lapack_strsm_ffi);
dict["lapack_dtrsm_ffi"] = EncapsulateFunction(lapack_dtrsm_ffi);
dict["lapack_ctrsm_ffi"] = EncapsulateFunction(lapack_ctrsm_ffi);
dict["lapack_ztrsm_ffi"] = EncapsulateFunction(lapack_ztrsm_ffi);
dict["lapack_sgetrf_ffi"] = EncapsulateFunction(lapack_sgetrf_ffi);
dict["lapack_dgetrf_ffi"] = EncapsulateFunction(lapack_dgetrf_ffi);
dict["lapack_cgetrf_ffi"] = EncapsulateFunction(lapack_cgetrf_ffi);
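The rename also changes the keys that jaxlib exposes to Python, so any code that looks these kernels up by name now needs the lapack_ prefix instead of blas_. A quick check of the registrations dict (module path assumed from jaxlib's CPU extension layout):

from jaxlib.cpu import _lapack  # nanobind extension shipped in jaxlib wheels
regs = _lapack.registrations()
print(sorted(k for k in regs if k.endswith("trsm_ffi")))
# Expected: ['lapack_ctrsm_ffi', 'lapack_dtrsm_ffi', 'lapack_strsm_ffi', 'lapack_ztrsm_ffi']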
8 changes: 4 additions & 4 deletions jaxlib/cpu/lapack_kernels.cc
@@ -2128,10 +2128,10 @@ template struct TridiagonalReduction<ffi::DataType::C128>;

// FFI Handlers

JAX_CPU_DEFINE_TRSM(blas_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(blas_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(blas_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(blas_ztrsm_ffi, ::xla::ffi::DataType::C128);
JAX_CPU_DEFINE_TRSM(lapack_strsm_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_TRSM(lapack_dtrsm_ffi, ::xla::ffi::DataType::F64);
JAX_CPU_DEFINE_TRSM(lapack_ctrsm_ffi, ::xla::ffi::DataType::C64);
JAX_CPU_DEFINE_TRSM(lapack_ztrsm_ffi, ::xla::ffi::DataType::C128);

JAX_CPU_DEFINE_GETRF(lapack_sgetrf_ffi, ::xla::ffi::DataType::F32);
JAX_CPU_DEFINE_GETRF(lapack_dgetrf_ffi, ::xla::ffi::DataType::F64);
8 changes: 4 additions & 4 deletions jaxlib/cpu/lapack_kernels.h
@@ -741,10 +741,10 @@ struct TridiagonalReduction {
};

// Declare all the handler symbols
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_strsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_dtrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ctrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(blas_ztrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_strsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dtrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ctrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_ztrsm_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgetrf_ffi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgetrf_ffi);
68 changes: 41 additions & 27 deletions jaxlib/lapack.py
@@ -118,52 +118,66 @@ def build_lapack_fn_target(fn_base: str, dtype) -> str:

# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_hlo(dtype, alpha, a, b,
def trsm_hlo(ctx, dtype, alpha, a, b,
left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False, *,
b_shape_vals: tuple[DimensionSize, ...]):
_lapack.initialize()
if conj_a and not trans_a:
raise NotImplementedError("Conjugation without transposition not supported")
fn_base = prepare_lapack_call(fn_base="trsm", dtype=dtype)
b_type = ir.RankedTensorType(b.type)

m, n = b_shape_vals[-2:]
batch_dims_vals = b_shape_vals[:-2]
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

if dtype == np.float32:
fn = "blas_strsm"
elif dtype == np.float64:
fn = "blas_dtrsm"
elif dtype == np.complex64:
fn = "blas_ctrsm"
elif dtype == np.complex128:
fn = "blas_ztrsm"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

if conj_a and not trans_a:
raise NotImplementedError("Conjugation without transposition not supported")
scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
result_types, result_shapes = mk_result_types_and_shapes(
[(b_shape_vals, b_type.element_type)])

if ctx.is_forward_compat():
# The old TRSM kernel name is prefixed with "blas"
fn = fn_base.replace("lapack", "blas", 1)
m, n = b_shape_vals[-2:]
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
result_types, result_shapes = mk_result_types_and_shapes(
[(b_shape_vals, b_type.element_type)]
)
return custom_call(
fn,
result_types=result_types,
operands=[hlo_s32(int(left_side)), hlo_s32(int(lower)),
hlo_s32((2 if conj_a else 1) if trans_a else 0), hlo_s32(int(diag)),
ensure_hlo_s32(m), ensure_hlo_s32(n), batch_size_val,
alpha, a, b],
operand_layouts=[scalar_layout] * 8 + [layout] * 2,
result_layouts=[layout],
operand_output_aliases={9: 0},
result_shapes=result_shapes,
).results

fn = fn_base + "_ffi"
return custom_call(
fn,
result_types=result_types,
operands=[hlo_s32(int(left_side)), hlo_s32(int(lower)),
hlo_s32((2 if conj_a else 1) if trans_a else 0), hlo_s32(int(diag)),
ensure_hlo_s32(m), ensure_hlo_s32(n), batch_size_val,
alpha, a, b],
operand_layouts=[scalar_layout] * 8 + [layout] * 2,
operands=[a, b, alpha],
operand_layouts=[layout] * 2 + [scalar_layout],
result_layouts=[layout],
operand_output_aliases={9: 0},
operand_output_aliases={1: 0},
result_shapes=result_shapes,
backend_config={
"side": _matrix_side_attr(left_side=left_side),
"uplo": _matrix_uplo_attr(lower=lower),
"trans_x": _matrix_transpose_attr(
transpose=trans_a, conjugate=conj_a
),
"diag": _matrix_diagonal_attr(unit_diag=diag),
},
api_version=4,
).results



# ?potrf: Cholesky decomposition

def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False,
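The two branches of trsm_hlo use different calling conventions: the legacy blas_*trsm custom call passes the side/uplo/transpose/diag flags plus m, n, and the batch size as i32 scalar operands and aliases operand 9 (b) to the output, while the lapack_*trsm_ffi call takes only (a, b, alpha) as operands, aliases operand 1 (b), and carries the flags as typed backend_config attributes under api_version=4. A condensed sketch of the target selection only (names taken from this diff, not the actual implementation):

def pick_trsm_target(ctx, fn_base):
  # fn_base comes from prepare_lapack_call, e.g. "lapack_strsm" for float32.
  if ctx.is_forward_compat():
    # Forward-compatibility window: keep emitting the legacy BLAS-prefixed kernel.
    return fn_base.replace("lapack", "blas", 1)  # -> "blas_strsm"
  # Otherwise target the FFI handler registered in jaxlib/cpu.
  return fn_base + "_ffi"  # -> "lapack_strsm_ffi"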
9 changes: 9 additions & 0 deletions tests/export_back_compat_test.py
@@ -122,6 +122,7 @@ def test_custom_call_coverage(self):
cpu_eigh_lapack_syev.data_2024_08_19,
cpu_lu_lapack_getrf.data_2024_05_31,
cpu_schur_lapack_gees.data_2024_11_29,
cpu_triangular_solve_blas_trsm.data_2024_12_02,
cpu_svd_lapack_gesdd.data_2024_08_13,
cpu_hessenberg_lapack_gehrd.data_2024_08_31,
cpu_tridiagonal_lapack_sytrd_hetrd.data_2024_12_01,
@@ -741,6 +742,14 @@ def check_triangular_solve_results(res_run, res_expected, *, rtol, atol):

self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_triangular_solve_results)
# TODO(b/344892332): Remove the check after the compatibility period.
has_xla_ffi_support = jaxlib_version >= (0, 4, 37)
if has_xla_ffi_support:
with config.export_ignore_forward_compatibility(True):
# FFI Kernel test
data = self.load_testdata(cpu_triangular_solve_blas_trsm.data_2024_12_02[dtype_name])
self.run_one_test(func, data, rtol=rtol, atol=atol,
check_results=check_triangular_solve_results)

@parameterized.named_parameters(
dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name)
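The replayed test data guards serialization stability across jaxlib releases; for a standalone numerical sanity check of the solve itself, one can compare against SciPy (a sketch, independent of the back-compat harness):

import numpy as np
import scipy.linalg
import jax.numpy as jnp
from jax.lax.linalg import triangular_solve

rng = np.random.default_rng(0)
a = np.tril(rng.standard_normal((4, 4))).astype(np.float32) + 4 * np.eye(4, dtype=np.float32)
b = rng.standard_normal((4, 3)).astype(np.float32)
got = triangular_solve(jnp.asarray(a), jnp.asarray(b), left_side=True, lower=True)
want = scipy.linalg.solve_triangular(a, b, lower=True)
np.testing.assert_allclose(got, want, rtol=1e-5, atol=1e-5)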
