diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 3e3e4ad6359..a257fdba0e2 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -263,7 +263,8 @@ struct gemm_impl { if(strided_batched) { - auto common_args = create_strided_batched_args_common(ctx, compute_type, input_args); + auto common_args = + create_strided_batched_args_common(ctx, compute_type, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex3, common_args, rocblas_gemm_algo_standard, @@ -285,7 +286,8 @@ struct gemm_impl { if(strided_batched) { - auto common_args = create_strided_batched_args_common(ctx, compute_type, input_args); + auto common_args = + create_strided_batched_args_common(ctx, compute_type, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex, common_args, rocblas_gemm_algo_solution_index, @@ -369,7 +371,9 @@ struct gemm_impl * A and args[0] as B in calling the rocblas_gemm. * */ - auto create_strided_batched_args_common(context& ctx, rb_compute_type rbcompute_type, const std::vector& args) const + auto create_strided_batched_args_common(context& ctx, + rb_compute_type rbcompute_type, + const std::vector& args) const { return pack(ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none, @@ -408,7 +412,9 @@ struct gemm_impl * A and args[0] as B in calling the rocblas_gemm. * * */ - auto create_gemm_ex_args_common(context& ctx, rb_compute_type rbcompute_type, const std::vector& args) const + auto create_gemm_ex_args_common(context& ctx, + rb_compute_type rbcompute_type, + const std::vector& args) const { return pack(ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none, @@ -438,7 +444,7 @@ struct gemm_impl * Find best rocBLAS solution: Get list of solutions and try them all, returning the index * of the fastest one. */ - int tune(context& ctx, const std::vector& input_shapes) const + int tune(context& ctx, const std::vector& input_shapes) const { // tuning meta parameters const int hot_calls = 40; @@ -456,8 +462,10 @@ struct gemm_impl rocblas_int list_size = 0; std::vector solution_indices; rb_compute_type rbcompute_type = compute_type; - // rocblas_gemm_get_solutions() API requires compute_type as rocblas_datatype. Convert manually for FP8 - if(arg_type == rocblas_datatype_f8_r) { + // rocblas_gemm_get_solutions() API requires compute_type as rocblas_datatype. Convert + // manually for FP8 + if(arg_type == rocblas_datatype_f8_r) + { rbcompute_type = rocblas_datatype_f32_r; } if(strided_batched) @@ -471,7 +479,8 @@ struct gemm_impl &list_size); solution_indices.resize(list_size); - auto common_sol_args = create_strided_batched_args_common(ctx, rbcompute_type, input_args); + auto common_sol_args = + create_strided_batched_args_common(ctx, rbcompute_type, input_args); rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions, common_sol_args, rocblas_gemm_algo_solution_index,