Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav committed Jul 3, 2024
1 parent 22cebd8 commit 0df395e
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ struct gemm_impl
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, compute_type, input_args);
auto common_args =
create_strided_batched_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
Expand All @@ -285,7 +286,8 @@ struct gemm_impl
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, compute_type, input_args);
auto common_args =
create_strided_batched_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
Expand Down Expand Up @@ -369,7 +371,9 @@ struct gemm_impl
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto create_strided_batched_args_common(context& ctx, rb_compute_type rbcompute_type, const std::vector<argument>& args) const
auto create_strided_batched_args_common(context& ctx,
rb_compute_type rbcompute_type,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
Expand Down Expand Up @@ -408,7 +412,9 @@ struct gemm_impl
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto create_gemm_ex_args_common(context& ctx, rb_compute_type rbcompute_type, const std::vector<argument>& args) const
auto create_gemm_ex_args_common(context& ctx,
rb_compute_type rbcompute_type,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
Expand Down Expand Up @@ -438,7 +444,7 @@ struct gemm_impl
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int tune(context& ctx, const std::vector<shape>& input_shapes) const
int tune(context& ctx, const std::vector<shape>& input_shapes) const
{
// tuning meta parameters
const int hot_calls = 40;
Expand All @@ -456,8 +462,10 @@ struct gemm_impl
rocblas_int list_size = 0;
std::vector<rocblas_int> solution_indices;
rb_compute_type rbcompute_type = compute_type;
// rocblas_gemm_get_solutions() API requires compute_type as rocblas_datatype. Convert manually for FP8
if(arg_type == rocblas_datatype_f8_r) {
// rocblas_gemm_get_solutions() API requires compute_type as rocblas_datatype. Convert
// manually for FP8
if(arg_type == rocblas_datatype_f8_r)
{
rbcompute_type = rocblas_datatype_f32_r;
}
if(strided_batched)
Expand All @@ -471,7 +479,8 @@ struct gemm_impl
&list_size);
solution_indices.resize(list_size);

auto common_sol_args = create_strided_batched_args_common(ctx, rbcompute_type, input_args);
auto common_sol_args =
create_strided_batched_args_common(ctx, rbcompute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
Expand Down

0 comments on commit 0df395e

Please sign in to comment.