Skip to content

Commit

Permalink
[CUBLAS][FP8] Enable R.matmul + R.multiply offloading (#16974)
Browse files Browse the repository at this point in the history
This commit enables offloading of the following pattern to cuBLAS:
  mm = R.linear(data, weights)
  scale = R.multiply(a_scale, w_scale)
  out = R.multiply(mm, scale)
  out = R.cast(out, dtype)
  • Loading branch information
ibsidorenko authored May 8, 2024
1 parent 02c4c55 commit c0a47ed
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 12 deletions.
11 changes: 10 additions & 1 deletion python/tvm/relax/backend/contrib/cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from tvm.relax.transform import PatternCheckContext

from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
from ..patterns import (
make_matmul_pattern,
make_matmul_dequantize_pattern,
make_matmul_multiply_pattern,
)
from ..utils import has_leaking_intermediate_variables


Expand Down Expand Up @@ -202,6 +206,11 @@ def _check_matmul(context: PatternCheckContext) -> bool:
*make_matmul_dequantize_pattern(transposed_rhs=True),
_check_matmul,
),
(
"cublas.matmul_transposed_multiply",
*make_matmul_multiply_pattern(transposed_rhs=True),
_check_matmul,
),
]
)

Expand Down
38 changes: 38 additions & 0 deletions python/tvm/relax/backend/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
return out, annotations


def make_matmul_multiply_pattern(
    transposed_rhs: bool = False,
) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
    """
    Build a pattern for a matmul whose result is multiplied by the product of
    two scale operands and then passed through ``relax.astype``.

    Parameters
    ----------
    transposed_rhs: bool
        When True, the right hand side of the matmul must be the result of a
        ``relax.permute_dims`` call.

    Returns
    -------
    pattern: DFPattern
        The pattern rooted at the trailing ``relax.astype`` call.

    annotations: Mapping[str, DFPattern]
        Named sub patterns used by the partition check function and codegen:
        "lhs", "rhs", "scaleA" and "scaleB" are the wildcard inputs, and
        "root" is the ``relax.matmul`` call.
    """
    lhs_input = wildcard()
    rhs_input = wildcard()
    scale_a = wildcard()
    scale_b = wildcard()

    # "rhs" stays bound to the raw wildcard, i.e. the operand that feeds
    # permute_dims, not the transposed value consumed by the matmul.
    matmul_rhs = is_op("relax.permute_dims")(rhs_input) if transposed_rhs else rhs_input
    matmul = is_op("relax.matmul")(lhs_input, matmul_rhs)

    # Both scale operands are constrained to shape (1,).
    combined_scale = is_op("relax.multiply")(scale_a.has_shape((1,)), scale_b.has_shape((1,)))
    scaled = is_op("relax.multiply")(matmul, combined_scale)
    casted = is_op("relax.astype")(scaled)

    annotations = {
        "lhs": lhs_input,
        "rhs": rhs_input,
        "scaleA": scale_a,
        "scaleB": scale_b,
        "root": matmul,
    }
    return casted, annotations


def make_attention_rewrite_pattern(
qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
):
Expand Down
5 changes: 4 additions & 1 deletion src/relax/backend/contrib/cublas/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,17 @@ class CublasJSONSerializer : public JSONSerializer {
inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
}

ICHECK(inputs_tmp.size() <= 3);
ICHECK(inputs_tmp.size() <= 4);
NodeEntries inputs(inputs_tmp.size());

auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
if (inputs_tmp.size() == 3) {
inputs[2] = inputs_tmp[arg_idx["bias"]->value];
} else if (inputs_tmp.size() == 4) {
inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
}

auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
Expand Down
14 changes: 12 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }

void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
size_t workspace_size, cublasLtEpilogue_t epilogue,
std::optional<float> dq_scale) {
ICHECK(TypeEqual(A->dtype, B->dtype));
// Reversed strides indicates an in-place transpose operation.
Expand Down Expand Up @@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
&bias->data, sizeof(float*)));
}

if (scaleA != nullptr && scaleB != nullptr) {
auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&scaleA_data, sizeof(float*)));
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&scaleB_data, sizeof(float*)));
}

if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
Expand Down
15 changes: 10 additions & 5 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
return dl_tensors[eid];
};

auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
const DLTensor* bias = nullptr;
auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool has_scale) {
const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
if (has_bias) {
bias = get_input(node, 2);
} else if (has_scale) {
scaleA = get_input(node, 2);
scaleB = get_input(node, 3);
}
return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, scaleA, scaleB);
};

for (size_t i = 0; i < nodes_.size(); ++i) {
Expand All @@ -127,15 +130,17 @@ class CublasJSONRuntime : public JSONRuntimeBase {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}

auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
bool has_scale = op_name.find("multiply") != std::string::npos;
auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);

std::optional<float> dq_scale = std::nullopt;
if (op_name.find("dequantize") != std::string::npos) {
dq_scale = std::stof(node.GetAttr<std::vector<std::string>>("dq_scale")[0]);
}

tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
b_ptr, bias_ptr, out_ptr, transa, transb,
b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, out_ptr, transa, transb,
entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue,
dq_scale);
}
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/contrib/cublas/cublas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
/*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
void* workspace_ptr, size_t workspace_size,
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
std::optional<float> dq_scale = std::nullopt);

} // namespace contrib
Expand Down
79 changes: 79 additions & 0 deletions tests/python/relax/test_codegen_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
return tvm.IRModule({"main": func})


def get_relax_matmul_multiply_module(
    x_shape,
    y_shape,
    z_shape,
    in_dtype,
    acc_dtype,
    out_dtype,
    transposed_y=False,
):
    """Create a matmul op followed by multiply operations.

    The generated "main" function takes arguments "x", "y" (both of
    ``in_dtype``) and "scaleA", "scaleB" (both of shape ``z_shape`` and dtype
    ``acc_dtype``). Inside a dataflow block it emits
    ``matmul(x, y) * (scaleA * scaleB)``, optionally preceded by a
    ``permute_dims`` on "y" and followed by an ``astype`` to ``out_dtype``
    when that differs from ``acc_dtype``.
    """
    with IRBuilder() as ir_builder:
        with relax_builder.function():
            R.func_name("main")
            lhs = R.arg("x", R.Tensor(x_shape, in_dtype))
            rhs = R.arg("y", R.Tensor(y_shape, in_dtype))
            scale_a = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
            scale_b = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))

            with R.dataflow() as df_frame:
                if transposed_y:
                    # Keep the leading axes in place and swap the last two.
                    perm = [*range(len(y_shape) - 2), -1, -2]
                    rhs = R.emit(R.permute_dims(rhs, axes=perm))
                product = R.emit(R.matmul(lhs, rhs, out_dtype=acc_dtype))
                scale = R.emit(R.multiply(scale_a, scale_b))
                out = R.emit(R.multiply(product, scale))
                if out_dtype != acc_dtype:
                    out = R.emit(R.astype(out, out_dtype))
                R.output(out)
            R.func_ret_value(df_frame.output_vars[0])

    return tvm.IRModule({"main": ir_builder.get()})


@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue",
[
Expand Down Expand Up @@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
def test_matmul_fp8_multiply_offload():
    """Compare the cuBLAS-offloaded FP8 matmul + multiply against the "llvm" build."""
    x_shape, y_shape, z_shape = (10, 32), (64, 32), (1,)
    in_dtype = "e4m3_float8"
    acc_dtype = "float32"

    mod = get_relax_matmul_multiply_module(
        x_shape, y_shape, z_shape, in_dtype, acc_dtype, "float16", transposed_y=True
    )

    def random_input(shape, dtype):
        # Uniform samples in [0, 5) converted to the requested dtype.
        return np.random.uniform(low=0, high=5, size=shape).astype(dtype)

    numpytype = "float8_e4m3fn"
    # Draw order (x, y, scaleA, scaleB) is kept so RNG consumption is unchanged.
    x = random_input(x_shape, numpytype)
    y = random_input(y_shape, numpytype)
    scaleA = random_input(z_shape, acc_dtype)
    scaleB = random_input(z_shape, acc_dtype)
    args = (x, y, scaleA, scaleB)

    offloaded = get_result_with_relax_cublas_offload(mod, args)
    reference = build_and_run(mod, args, "llvm", legalize=True)
    tvm.testing.assert_allclose(offloaded, reference, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize(
"M, N, K, out_dtype, transposed_y, partition_done",
[
Expand Down Expand Up @@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
assert len(mod["main"].body.blocks[0].bindings) == num_bindings


def test_cublas_partition_fp8_matmul_multiply():
    """After cuBLAS partitioning, the first block of "main" must hold a single binding."""
    M, N, K = 32, 64, 128
    unpartitioned = get_relax_matmul_multiply_module(
        x_shape=(M, K),
        y_shape=(N, K),
        z_shape=(1,),
        in_dtype="e4m3_float8",
        acc_dtype="float32",
        out_dtype="float16",
        transposed_y=True,
    )
    partitioned = partition_for_cublas(unpartitioned)
    bindings = partitioned["main"].body.blocks[0].bindings
    assert len(bindings) == 1


def test_cublas_partition_matmul_without_bias():
# cuBLAS does not handle 2D bias (residual input)
mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
Expand Down

0 comments on commit c0a47ed

Please sign in to comment.