Simplifications in e2e matmul tests (iree-org#18889)
Two commits:
1. Stop inferring `acc_type`; require specifying it explicitly. Only a few
tests were relying on the inference.
2. Stop special-casing narrow float types (using f32 as the ABI type and
generating `arith.truncf` internally). This was only needed while these
narrow float types were not yet supported in the rest of IREE.

Signed-off-by: Benoit Jacob <[email protected]>
bjacob authored Oct 24, 2024
1 parent 225baf2 commit 8ce8bed
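
For context, the first commit deletes the inference logic in `generate_e2e_matmul_tests.py` (visible in the diff below) and makes `--acc_type` a required flag. A minimal sketch of the removed rule, using plain strings as a stand-in for the script's `MatrixElemTypeId` enum:

```python
# Sketch of the acc_type inference removed by this commit. Assumption:
# plain strings stand in for the script's MatrixElemTypeId enum values.
def infer_acc_type(lhs_rhs_type: str) -> str:
    if lhs_rhs_type in ("f8E5M2", "f8E4M3", "f8E5M2FNUZ", "f8E4M3FNUZ"):
        return "f32"  # narrow float inputs accumulated in f32
    if lhs_rhs_type == "i8":
        return "i32"  # i8 inputs accumulated in i32
    return lhs_rhs_type  # e.g. f32 -> f32, f16 -> f16

assert infer_acc_type("i8") == "i32"
assert infer_acc_type("f16") == "f16"
```

With the flag now required, the Bazel and CMake test definitions below spell out `(lhs_rhs_type, acc_type)` pairs such as `("i8", "i32")` instead of iterating over `lhs_rhs_type` alone.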
Showing 3 changed files with 49 additions and 67 deletions.
38 changes: 25 additions & 13 deletions tests/e2e/matmul/BUILD.bazel
@@ -360,16 +360,17 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
"--acc_type=%s" % acc_type,
"--shapes=small",
],
target_backends_and_drivers = [
("vmvx", "local-task"),
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
- ) for lhs_rhs_type in [
-     "i8",
-     "f32",
+ ) for (lhs_rhs_type, acc_type) in [
+     ("i8", "i32"),
+     ("f32", "f32"),
]]

###########################################################################
@@ -383,6 +384,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulSimt",
],
@@ -411,6 +413,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCore",
],
@@ -437,6 +440,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
"--acc_type=f32",
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -461,6 +465,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
"--acc_type=f32",
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -486,6 +491,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f32",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
],
@@ -513,6 +519,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCore",
],
@@ -540,6 +547,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync",
],
@@ -566,6 +574,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
"--acc_type=%s" % acc_type,
],
tags = [
# CUDA cuInit fails with sanitizer on.
@@ -580,8 +589,8 @@ iree_generated_e2e_runner_test(
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
- ) for lhs_rhs_type in [
-     "f32",
+ ) for (lhs_rhs_type, acc_type) in [
+     ("f32", "f32"),
]]

###########################################################################
@@ -598,6 +607,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
"--acc_type=%s" % acc_type,
"--shapes=easy_large_static",
"--compilation_info=SPIRVVectorizeMali",
],
@@ -611,10 +621,10 @@ iree_generated_e2e_runner_test(
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
- ) for lhs_rhs_type in [
-     "i8",
-     "f16",
-     "f32",
+ ) for (lhs_rhs_type, acc_type) in [
+     ("i8", "i32"),
+     ("f16", "f32"),
+     ("f32", "f32"),
]]

[iree_generated_e2e_runner_test(
@@ -625,6 +635,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
"--acc_type=%s" % acc_type,
"--shapes=easy_large_static",
"--compilation_info=SPIRVVectorizeNVIDIA",
],
@@ -637,10 +648,10 @@ iree_generated_e2e_runner_test(
],
test_runner = "//tools/testing/e2e:iree-e2e-matmul-test",
test_type = "matmul",
- ) for lhs_rhs_type in [
-     "i8",
-     "f16",
-     "f32",
+ ) for (lhs_rhs_type, acc_type) in [
+     ("i8", "i32"),
+     ("f16", "f32"),
+     ("f32", "f32"),
]]

iree_generated_e2e_runner_test(
@@ -651,6 +662,7 @@ iree_generated_e2e_runner_test(
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=f16",
"--acc_type=f32",
"--shapes=easy_large_static",
"--compilation_info=SPIRVCooperativeMatrixVectorize",
],
17 changes: 17 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -927,6 +927,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
"--acc_type=i32"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -948,6 +949,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
@@ -969,6 +971,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulSimt"
TEST_RUNNER
@@ -994,6 +997,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCore"
TEST_RUNNER
@@ -1021,6 +1025,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1046,6 +1051,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1071,6 +1077,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
TEST_RUNNER
@@ -1098,6 +1105,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCore"
TEST_RUNNER
@@ -1125,6 +1133,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=LLVMGPUMatmulTensorCoreMmaSync"
TEST_RUNNER
@@ -1152,6 +1161,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
@@ -1177,6 +1187,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
"--acc_type=i32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1201,6 +1212,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1225,6 +1237,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeMali"
TEST_RUNNER
@@ -1249,6 +1262,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=i8"
"--acc_type=i32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1273,6 +1287,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1297,6 +1312,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVVectorizeNVIDIA"
TEST_RUNNER
@@ -1321,6 +1337,7 @@ iree_generated_e2e_runner_test(
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f16"
"--acc_type=f32"
"--shapes=easy_large_static"
"--compilation_info=SPIRVCooperativeMatrixVectorize"
TEST_RUNNER
61 changes: 7 additions & 54 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -545,20 +545,6 @@ def int_or_DYN(s: DimSize):
return s.value or "DYN"


- # Gets friendlier form/type that we can use as arg types which we can cast into the target_type.
- def cast_argtype_if_required(target_type: MatrixElemTypeId):
-     if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-         return MatrixElemTypeId.F32
-     return target_type
-
-
- # Gets the op needed to cast/convert from the friendly form/type into the target_type.
- def get_castback_from_arg_op(target_type: MatrixElemTypeId):
-     if target_type == MatrixElemTypeId.F8E4M3FNUZ:
-         return "arith.truncf"
-     return ValueError(f"Unhandled castback type of {target_type}")
-
-
# Describes the fully resolved shape dimensions of all 3 input matrices,
# LHS, RHS, and Accumulator, in a testcase.
# Each value is a string, which may either represent a positive integer such as "123",
@@ -659,9 +645,8 @@ def generate_function(
acc_r = int_or_question_mark(shapes.acc_rows)
acc_c = int_or_question_mark(shapes.acc_cols)

- casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
- lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{casted_lhs_rhs_type.value}>"
- rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{casted_lhs_rhs_type.value}>"
+ lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
+ rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
acc_tensor_type = f"tensor<{acc_r}x{acc_c}x{acc_type.value}>"

if transpose_rhs:
@@ -680,15 +665,6 @@ def generate_function(
func_definition = func_definition + compilation_info_string
generate_function.compilation_index += 1
compute = f" %result = {op_name} {compilation_info_attr}ins(%lhs, %rhs: {lhs_tensor_type}, {rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}\n"
- if casted_lhs_rhs_type != lhs_rhs_type:
-     castback_op = get_castback_from_arg_op(lhs_rhs_type)
-     compute_lhs_tensor_type = f"tensor<{lhs_r}x{lhs_c}x{lhs_rhs_type.value}>"
-     compute_rhs_tensor_type = f"tensor<{rhs_r}x{rhs_c}x{lhs_rhs_type.value}>"
-     compute = (
-         f" %lhs_casted = {castback_op} %lhs: {lhs_tensor_type} to {compute_lhs_tensor_type}\n"
-         f" %rhs_casted = {castback_op} %rhs: {rhs_tensor_type} to {compute_rhs_tensor_type}\n"
-         f" %result = {op_name} {compilation_info_attr}ins(%lhs_casted, %rhs_casted: {compute_lhs_tensor_type}, {compute_rhs_tensor_type}) outs(%acc: {acc_tensor_type}) -> {acc_tensor_type}"
-     )
if shape.accumulate:
signature = f"({lhs_tensor_type}, {rhs_tensor_type}, {acc_tensor_type}) -> {acc_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%lhs: !hal.buffer_view, %rhs: !hal.buffer_view, %acc: !hal.buffer_view) -> !hal.buffer_view"
@@ -818,9 +794,8 @@ def generate_call(
rhs_shape = [shape.k, shape.n]
transpose_rhs = 0

- casted_lhs_rhs_type = cast_argtype_if_required(lhs_rhs_type)
- op = op + generate_random_matrix("lhs", lhs_shape, casted_lhs_rhs_type)
- op = op + generate_random_matrix("rhs", rhs_shape, casted_lhs_rhs_type)
+ op = op + generate_random_matrix("lhs", lhs_shape, lhs_rhs_type)
+ op = op + generate_random_matrix("rhs", rhs_shape, lhs_rhs_type)
if shape.accumulate:
op = op + generate_random_matrix("acc", [shape.m, shape.n], acc_type)
# TODO(#16168): there's a bug with in-place input->output aliasing and
@@ -919,16 +894,15 @@ def parse_arguments():
"f8E5M2FNUZ",
"f8E4M3FNUZ",
],
help="Numeric type of input matrices",
help="Numeric type of input LHS and RHS matrices",
required=True,
)
parser.add_argument(
"--acc_type",
type=str,
choices=["i32", "f32", "f16", "bf16"],
help="Numeric type of input matrices",
default="",
required=False,
help="Numeric type of the accumulator and result matrices",
required=True,
)
parser.add_argument(
"--shapes",
@@ -1005,30 +979,9 @@ def write_calls_file(functions, calls, filename, requirements):
file.write(module_definition)


- # For now, the accumulator type can always be inferred from the input LHS/RHS
- # type, so we do that. That is temporary: eventually there will be cases
- # where the same input types are used with different accumulator types, e.g.
- # f16 inputs with both f16 and f32 accumulator.
- def infer_acc_type(lhs_rhs_type: MatrixElemTypeId, acc_type: MatrixElemTypeId):
-     if acc_type != MatrixElemTypeId.NONE:
-         return acc_type
-     if lhs_rhs_type == MatrixElemTypeId.F8E5M2:
-         return MatrixElemTypeId.F32
-     if lhs_rhs_type == MatrixElemTypeId.F8E4M3:
-         return MatrixElemTypeId.F32
-     if lhs_rhs_type == MatrixElemTypeId.F8E5M2FNUZ:
-         return MatrixElemTypeId.F32
-     if lhs_rhs_type == MatrixElemTypeId.F8E4M3FNUZ:
-         return MatrixElemTypeId.F32
-     if lhs_rhs_type == MatrixElemTypeId.I8:
-         return MatrixElemTypeId.I32
-     return lhs_rhs_type
-
-
def main(args):
lhs_rhs_type = MatrixElemTypeId(args.lhs_rhs_type)
acc_type = MatrixElemTypeId(args.acc_type)
- acc_type = infer_acc_type(lhs_rhs_type, acc_type)
shapes_id = ShapesId(args.shapes)
compilation_info_id = CompilationInfoId(args.compilation_info)

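The effect of the second commit is visible in the deleted `generate_function` branch above: for `f8E4M3FNUZ`, the generated test function used to take `f32` arguments and truncate them to the narrow type before the matmul; now the narrow type appears directly in the function signature. A before/after sketch of the emitted IR, written as the Python strings the generator builds (the concrete 2x2 shapes and the `linalg.matmul` op name are illustrative assumptions, not taken from the diff):

```python
# Before: f32 at the ABI boundary, truncated internally (per the deleted
# get_castback_from_arg_op, the castback op was arith.truncf).
compute_before = (
    "  %lhs_casted = arith.truncf %lhs: tensor<2x2xf32> to tensor<2x2xf8E4M3FNUZ>\n"
    "  %rhs_casted = arith.truncf %rhs: tensor<2x2xf32> to tensor<2x2xf8E4M3FNUZ>\n"
    "  %result = linalg.matmul ins(%lhs_casted, %rhs_casted:"
    " tensor<2x2xf8E4M3FNUZ>, tensor<2x2xf8E4M3FNUZ>)"
    " outs(%acc: tensor<2x2xf32>) -> tensor<2x2xf32>"
)

# After: the narrow float type is used directly for the arguments, since the
# rest of IREE now supports it end to end.
compute_after = (
    "  %result = linalg.matmul ins(%lhs, %rhs:"
    " tensor<2x2xf8E4M3FNUZ>, tensor<2x2xf8E4M3FNUZ>)"
    " outs(%acc: tensor<2x2xf32>) -> tensor<2x2xf32>"
)
```

This removes one difference between how narrow-float and regular element types flow through the test harness, so all input types now share the same code path.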
