[GPU] Enable eltwise primitive fusion to RMS (openvinotoolkit#28435)
### Details:
 - This PR enables fusion of simple primitives (activation, quantize, and eltwise) into the RMS primitive; a minimal sketch of the resulting pattern is shown below.
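
The pattern this enables is the one exercised by the new fusion tests: an `rms` primitive followed by a simple elementwise consumer, which `prepare_primitive_fusing` can now fold into the RMS node. The following is a minimal sketch modeled on the added `rms_fusion_test.cpp`; the shapes, the epsilon, and the `eltw_mem` memory object are illustrative placeholders, not part of this commit.

```cpp
// Sketch only: RMS followed by an elementwise add, built with the cldnn API
// used in rms_fusion_test.cpp. Shapes, epsilon, and "eltw_mem" are placeholders.
topology topo(
    input_layout("input", layout{ data_types::f16, format::bfyx, tensor{ 1, 16, 8, 8 } }),
    input_layout("gamma", layout{ data_types::f16, format::bfyx, tensor{ 1, 1, 1, 8 } }),
    rms("rms", input_info("input"), input_info("gamma"), 1e-10f),
    data("eltw_data", eltw_mem),  // constant second operand of the add
    eltwise("eltw", { input_info("rms"), input_info("eltw_data") }, eltwise_mode::sum),
    reorder("out", input_info("eltw"), format::bfyx, data_types::f32));
// After prepare_primitive_fusing runs, "eltw" is fused into the "rms" node and
// executed inside the RMS kernel via the FUSED_OPS hooks added in this commit.
```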
e-ddykim authored Jan 15, 2025
Parent: 28473b9 · Commit: 7260cc0
Showing 8 changed files with 262 additions and 7 deletions.
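
In the kernels, the change shows up as an optional `FUSED_OPS_DECLS` argument list plus a per-element hook applied to the normalized value just before it is stored. A simplified view of the hook shared by `rms_gpu_ref` and `rms_gpu_bfyx_opt` (the names `in`, `gamma_val`, and `idx` are placeholders, not the literal kernel code):

```c
// Simplified shape of the fusion hook; see the actual kernel diffs below.
OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * in * gamma_val);
#if HAS_FUSED_OPS
    FUSED_OPS;                    // jit-generated code of the fused primitives
    normalized = FUSED_OPS_RESULT;
#endif
output[idx] = normalized;
```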
@@ -16,6 +16,7 @@
#include "gemm_inst.h"
#include "lrn_inst.h"
#include "mvn_inst.h"
#include "rms_inst.h"
#include "pooling_inst.h"
#include "normalize_inst.h"
#include "permute_inst.h"
@@ -764,6 +765,8 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {

should_fuse |= input.is_type<mvn>();

should_fuse |= input.is_type<rms>();

should_fuse |= input.is_type<group_normalization>();

should_fuse |= input.is_type<normalize>() && data_type_traits::is_i8_u8(input.get_input_layout(0).data_type);
@@ -964,6 +967,7 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
(parents[i].first->is_type<mvn>() &&
mvn_supports_fusings(parents[i].first->as<mvn>())) ||
(parents[i].first->is_type<group_normalization>()) ||
(parents[i].first->is_type<rms>()) ||
(parents[i].first->is_type<deconvolution>()) ||
(parents[i].first->is_type<permute>()) ||
(parents[i].first->is_type<resample>()) ||
@@ -28,7 +28,11 @@ KERNEL(rms_gpu_bfyx_opt)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input,
const __global INPUT1_TYPE* gamma,
__global OUTPUT_TYPE* output)
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
, FUSED_OPS_DECLS
#endif
)
{
const uint data_idx = get_global_id(1);
const uint in_data_idx = get_global_id(0);
@@ -100,18 +104,53 @@ KERNEL(rms_gpu_bfyx_opt)(

rms = slm_buf[0];

#if HAS_FUSED_OPS
uint b, f, z, y, x;
#if INPUT_RANK == 1
f = z = y = x = 1;
#elif INPUT_RANK == 2
z = y = x = 1;
b = data_idx;
#elif INPUT_RANK == 3
x = 1;
f = data_idx % OUTPUT_FEATURE_NUM;
b = data_idx / OUTPUT_FEATURE_NUM;
#else
x = data_idx;
y = x % OUTPUT_SIZE_Y; x = x / OUTPUT_SIZE_Y;
z = x % OUTPUT_SIZE_Z; x = x / OUTPUT_SIZE_Z;
f = x % OUTPUT_FEATURE_NUM; x = x / OUTPUT_FEATURE_NUM;
b = x % OUTPUT_BATCH_NUM; x = x / OUTPUT_BATCH_NUM;
#endif
#endif

i = 0;
if ((workers_per_data > SUB_GROUP_SIZE) && USE_BLOCK_WRITE)
{
for (; i < items_num - (items_num % SUBGROUP_BLOCK_SIZE); i += SUBGROUP_BLOCK_SIZE)
{
ACC_TYPE vec_gamma = TO_ACC_TYPE(BLOCK_READ(gamma, subgroup_offset + i * get_sub_group_size()));
OUTPUT_VEC_TYPE vec_tmp;
#if HAS_FUSED_OPS
LAST_DIM = subgroup_offset + i * get_sub_group_size() + get_sub_group_local_id();
#endif
#if SUBGROUP_BLOCK_SIZE == 1
vec_tmp = TO_OUTPUT_TYPE(rms * data[i] * vec_gamma);
OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i] * vec_gamma);
#if HAS_FUSED_OPS
FUSED_OPS;
normalized = FUSED_OPS_RESULT;
#endif
vec_tmp = normalized;
#else
unroll_for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++)
vec_tmp[j] = TO_OUTPUT_TYPE(rms * data[i + j] * vec_gamma[j]);
unroll_for (int j = 0; j < SUBGROUP_BLOCK_SIZE; j++) {
OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i + j] * vec_gamma[j]);
#if HAS_FUSED_OPS
LAST_DIM += j * get_sub_group_size();
FUSED_OPS;
normalized = FUSED_OPS_RESULT;
#endif
vec_tmp[j] = normalized;
}
#endif
BLOCK_WRITE(output, data_offset + subgroup_offset + i * get_sub_group_size(), vec_tmp);
}
@@ -120,13 +159,25 @@ KERNEL(rms_gpu_bfyx_opt)(
for (; i < items_num; i++)
{
ACCUMULATOR_TYPE temp = TO_ACCUMULATOR_TYPE(gamma[subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()]);
output[data_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = TO_OUTPUT_TYPE(rms * data[i] * temp);
OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[i] * temp);
#if HAS_FUSED_OPS
LAST_DIM = subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size();
FUSED_OPS;
normalized = FUSED_OPS_RESULT;
#endif
output[data_offset + subgroup_offset + get_sub_group_local_id() + i * get_sub_group_size()] = normalized;
}

if (in_data_idx < leftovers)
{
ACCUMULATOR_TYPE temp = TO_ACCUMULATOR_TYPE(gamma[workers_per_data * items_num + in_data_idx]);
output[data_offset + workers_per_data * items_num + in_data_idx] = TO_OUTPUT_TYPE(rms * data[items_num] * temp);
OUTPUT_TYPE normalized = TO_OUTPUT_TYPE(rms * data[items_num] * temp);
#if HAS_FUSED_OPS
LAST_DIM = workers_per_data * items_num + in_data_idx;
FUSED_OPS;
normalized = FUSED_OPS_RESULT;
#endif
output[data_offset + workers_per_data * items_num + in_data_idx] = normalized;
}
}
#undef USE_BLOCK_WRITE
@@ -8,7 +8,11 @@ KERNEL(rms_gpu_ref)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input,
const __global INPUT1_TYPE* gamma,
__global OUTPUT_TYPE* output)
__global OUTPUT_TYPE* output
#if HAS_FUSED_OPS_DECLS
, FUSED_OPS_DECLS
#endif
)
{
const uint b = get_global_id(0);
const uint f = get_global_id(1);
@@ -38,6 +42,10 @@
const uint gamma_idx = z;
#endif
OUTPUT_TYPE result = TO_OUTPUT_TYPE(rms) * TO_OUTPUT_TYPE(input[input_idx]) * TO_OUTPUT_TYPE(gamma[gamma_idx]);
#if HAS_FUSED_OPS
FUSED_OPS;
result = FUSED_OPS_RESULT;
#endif
output[output_idx] = result;
}
}
@@ -97,6 +97,35 @@ JitConstants RMSKernelBfyxOpt::GetJitConstants(const rms_params& params, DispatchData dispatchData) const {
}
jit.AddConstant(MakeJitConstant("SUB_GROUP_SIZE", subgroup_size));
jit.AddConstant(MakeJitConstant("SUBGROUP_BLOCK_SIZE", dispatchData.subgroupBlockSize));
if (!params.fused_ops.empty()) {
jit.AddConstant(MakeJitConstant("INPUT_RANK", params.ov_input_rank));
switch (params.ov_input_rank) {
case 1 :
jit.AddConstant(MakeJitConstant("LAST_DIM", "b"));
break;
case 2 :
jit.AddConstant(MakeJitConstant("LAST_DIM", "f"));
break;
case 3 :
jit.AddConstant(MakeJitConstant("LAST_DIM", "y"));
break;
default:
jit.AddConstant(MakeJitConstant("LAST_DIM", "x"));
break;
}

std::vector<std::string> idx_order;
if (params.inputs[0].GetDims().size() == 5) {
idx_order = { "(b)", "(f)", "(z)", "(y)", "(x)" };
} else if (params.inputs[0].GetDims().size() <= 4) {
idx_order = { "(b)", "(f)", "(y)", "(x)" };
} else {
OPENVINO_THROW("rms_bfyx_opt doesn't support 5D or higher dims.");
}

auto conf = FusedOpsConfiguration("", idx_order, "normalized", params.outputs[0].GetDType(), 1);
jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
}

return jit;
}
@@ -18,6 +18,13 @@ class RMSKernelBfyxOpt : public RMSKernelBase {
ParamsKey GetSupportedKey() const override;

protected:
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return {
FusedOpType::ACTIVATION,
FusedOpType::QUANTIZE,
FusedOpType::ELTWISE
};
}
bool Validate(const Params&) const override;
DispatchData SetDefault(const rms_params& params) const override;
JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override;
@@ -25,6 +25,26 @@ ParamsKey RMSKernelRef::GetSupportedKey() const {
return k;
}

JitConstants RMSKernelRef::GetJitConstants(const rms_params& params, DispatchData dispatchData) const {
auto jit = Parent::GetJitConstants(params, dispatchData);

if (!params.fused_ops.empty()) {
std::vector<std::string> idx_order;
if (params.inputs[0].GetDims().size() == 5) {
idx_order = { "(b)", "(f)", "(z)", "(y)", "(x)" };
} else if (params.inputs[0].GetDims().size() <= 4) {
idx_order = { "(b)", "(f)", "(y)", "(x)" };
} else {
OPENVINO_THROW("rms_ref doesn't support 5D or higher dims.");
}

auto conf = FusedOpsConfiguration("", idx_order, "result", params.outputs[0].GetDType(), 1);
jit.Merge(MakeFusedOpsJitConstants(params, { conf }));
}

return jit;
}

KernelsData RMSKernelRef::GetKernelsData(const Params& params) const {
return GetCommonKernelsData(params);
}
@@ -16,5 +16,15 @@ class RMSKernelRef : public RMSKernelBase {
KernelsData GetKernelsData(const Params& params) const override;
KernelsPriority GetKernelsPriority(const Params& params) const override;
ParamsKey GetSupportedKey() const override;

protected:
std::vector<FusedOpType> GetSupportedFusedOps() const override {
return {
FusedOpType::ACTIVATION,
FusedOpType::QUANTIZE,
FusedOpType::ELTWISE
};
}
JitConstants GetJitConstants(const rms_params& params, DispatchData dispatchData) const override;
};
} // namespace kernel_selector
126 changes: 126 additions & 0 deletions src/plugins/intel_gpu/tests/unit/fusions/rms_fusion_test.cpp
@@ -0,0 +1,126 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils.h"
#include "fusion_test_common.hpp"

#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/quantize.hpp>
#include <intel_gpu/primitives/eltwise.hpp>
#include <intel_gpu/primitives/data.hpp>
#include <intel_gpu/primitives/rms.hpp>
#include <intel_gpu/primitives/reorder.hpp>

#include <cmath>

using namespace cldnn;
using namespace ::tests;

namespace {
struct rms_test_params {
tensor input_size;
tensor gamma_size;
tensor elwise_size;
data_types input_type;
format input_format;
size_t expected_fused_primitives;
size_t expected_fused_primitives_onednn;
size_t expected_not_fused_primitives;
};

class RMSFusingTest : public ::BaseFusingTest<rms_test_params> {
public:
void execute(rms_test_params& p) {
if (engine.get_device_info().supports_immad)
p.expected_fused_primitives = p.expected_fused_primitives_onednn;
auto input_prim = get_mem(get_input_layout(p));
auto gamma_prim = get_mem(get_gamma_layout(p));

network network_not_fused(this->engine, this->topology_non_fused, cfg_not_fused);
network network_fused(this->engine, this->topology_fused, cfg_fused);

network_fused.set_input_data("input", input_prim);
network_fused.set_input_data("gamma", gamma_prim);
network_not_fused.set_input_data("input", input_prim);
network_not_fused.set_input_data("gamma", gamma_prim);

compare(network_not_fused, network_fused, p);
}

layout get_input_layout(rms_test_params& p) {
return layout{ p.input_type, p.input_format, p.input_size };
}

layout get_gamma_layout(rms_test_params& p) {
return layout{ p.input_type, p.input_format, p.gamma_size };
}
};
} // namespace


/* ----------------------------------------------------------------------------------------------------- */
/* --------------------------------------- RMS cases --------------------------------------------------- */
/* ----------------------------------------------------------------------------------------------------- */

#define CASE_RMS_F32_1 { 1, 16, 8, 8 }, { 1, 1, 1, 8 }, { 1, 16, 8, 8 }, data_types::f32, format::bfyx
#define CASE_RMS_F32_2 { 2, 16, 8, 8 }, { 1, 1, 1, 8 }, { 2, 16, 8, 8 }, data_types::f32, format::bfyx
#define CASE_RMS_3D_F32_1 { 1, 16, 8, 8, 8 }, { 1, 1, 1, 1, 8 }, { 1, 16, 8, 8, 8 }, data_types::f32, format::bfzyx
#define CASE_RMS_3D_F32_2 { 2, 16, 8, 8, 8 }, { 1, 1, 1, 1, 8 }, { 2, 16, 8, 8, 8 }, data_types::f32, format::bfzyx
#define CASE_RMS_F16_1 { 1, 16, 8, 8 }, { 1, 1, 1, 8 }, { 1, 16, 8, 8 }, data_types::f16, format::bfyx
#define CASE_RMS_F16_2 { 2, 16, 8, 8 }, { 1, 1, 1, 8 }, { 2, 16, 8, 8 }, data_types::f16, format::bfyx
#define CASE_RMS_3D_F16_1 { 1, 16, 8, 8, 8 }, { 1, 1, 1, 1, 8 }, { 1, 16, 8, 8, 8 }, data_types::f16, format::bfzyx
#define CASE_RMS_3D_F16_2 { 2, 16, 8, 8, 8 }, { 1, 1, 1, 1, 8 }, { 2, 16, 8, 8, 8 }, data_types::f16, format::bfzyx

class rms_activation : public RMSFusingTest {};
TEST_P(rms_activation, basic) {
auto p = GetParam();
create_topologies(
input_layout("input", get_input_layout(p)),
input_layout("gamma", get_gamma_layout(p)),
rms("rms", input_info("input"), input_info("gamma"), 1e-10f),
activation("act", input_info("rms"), activation_func::relu),
reorder("reorder_bfyx", input_info("act"), format::bfyx, data_types::f32)
);

tolerance = (p.input_type == data_types::f32) ? 1e-5f : 0.1f;
execute(p);
}

INSTANTIATE_TEST_SUITE_P(fusings_gpu, rms_activation, ::testing::ValuesIn(std::vector<rms_test_params>{
rms_test_params{ CASE_RMS_F32_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_F32_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F32_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F32_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_F16_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_F16_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F16_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F16_2, 3, 3, 4 },
}));

class rms_eltwise : public RMSFusingTest {};
TEST_P(rms_eltwise, basic) {
auto p = GetParam();
create_topologies(
input_layout("input", layout{ p.input_type, p.input_format, p.input_size }),
input_layout("gamma", layout{ p.input_type, p.input_format, p.gamma_size }),
rms("rms", input_info("input"), input_info("gamma"), 1e-10f),
data("eltw_data", get_mem(layout{ p.input_type, p.input_format, p.elwise_size })),
eltwise("eltw", { input_info("rms"), input_info("eltw_data") }, eltwise_mode::sum, p.input_type),
reorder("reorder_bfyx", input_info("eltw"), p.input_format, data_types::f32)
);

tolerance = (p.input_type == data_types::f32) ? 1e-5f : 0.1f;
execute(p);
}

INSTANTIATE_TEST_SUITE_P(fusings_gpu, rms_eltwise, ::testing::ValuesIn(std::vector<rms_test_params>{
rms_test_params{ CASE_RMS_F32_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_F32_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F32_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F32_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_F16_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_F16_2, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F16_1, 3, 3, 4 },
rms_test_params{ CASE_RMS_3D_F16_2, 3, 3, 4 },
}));
