Fused solver for Fwd Convolution with Residual add, Bias add and then activation function (#2517)

amberhassaan authored Dec 20, 2023
1 parent 45991db commit 7a7d288
Showing 12 changed files with 859 additions and 14 deletions.
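
Editorial note, for orientation (not part of the commit): together, the changes below add a composable_kernel (CK) based fused solver for forward convolution followed by a scaled residual add, a bias add, and an activation. Judging from the fusion plan assembled in src/fusion.cpp, the per-element semantics are roughly the sketch below; the function name and float types are illustrative assumptions.

// Sketch only: what the fused plan conv -> z-add -> bias -> activation computes.
// alpha1/alpha2 default to 1.0f, and only miopenActivationRELU is accepted.
float fused_conv_bias_res_add_activ(
    float conv_xw, float z_val, float bias, float alpha1, float alpha2)
{
    const float acc = alpha1 * conv_xw + alpha2 * z_val + bias;
    return acc > 0.0f ? acc : 0.0f; // ReLU
}
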
5 changes: 4 additions & 1 deletion requirements.txt
@@ -7,4 +7,7 @@ nlohmann/[email protected] -DJSON_MultipleHeaders=ON -DJSON_BuildTests=Off
ROCmSoftwarePlatform/[email protected]
ROCmSoftwarePlatform/[email protected]
ROCmSoftwarePlatform/frugally-deep@9683d557eb672ee2304f80f6682c51242d748a50
ROCmSoftwarePlatform/composable_kernel@0dacd895d5ba9c9eeb99588ec7f7df1da82f7fa9 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON
ROCmSoftwarePlatform/composable_kernel@55a89c746eb6cf7973c47fb9b2635e0f73bd2fc2 -DCMAKE_BUILD_TYPE=Release -DINSTANCES_ONLY=ON



1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -188,6 +188,7 @@ set( MIOpen_Source
solver/conv_bin_winoRxS_fused.cpp
solver/conv_ck_igemm_fwd_v6r1_dlops_nchw.cpp
solver/conv_ck_igemm_fwd_bias_activ_fused.cpp
solver/conv_ck_igemm_fwd_bias_res_add_activ_fused.cpp
solver/conv_direct_naive_conv.cpp
solver/conv_direct_naive_conv_bwd.cpp
solver/conv_direct_naive_conv_fwd.cpp
87 changes: 75 additions & 12 deletions src/fusion.cpp
@@ -80,6 +80,9 @@ miopenStatus_t ConvBiasActivFusion(Handle& handle,
assert(workspaceSizeInBytes == 0);
std::ignore = workspace;
std::ignore = workspaceSizeInBytes;
/// \todo: add workspace support in fusion

/*
if(alpha1 != nullptr)
{
const auto falpha1 = *(static_cast<const float*>(alpha1));
@@ -92,29 +95,46 @@
if(falpha2 != 1.0f)
MIOPEN_THROW(miopenStatusNotImplemented, "alpha2 can only be 1.0");
}
if(z != nullptr || zDesc.GetSize() != 0)
MIOPEN_THROW(miopenStatusNotImplemented, "The addition of z vector is not yet supported");
*/

// TODO: The type of these pointers depends on the ConvolutionDescriptor's data
// type
float falpha1 = alpha1 != nullptr ? *(static_cast<const float*>(alpha1)) : 1.0f;
float falpha2 = alpha2 != nullptr ? *(static_cast<const float*>(alpha2)) : 1.0f;

// if(z != nullptr || zDesc.GetSize() != 0)
// MIOPEN_THROW(miopenStatusNotImplemented, "The addition of z vector is not yet supported");
FusionPlanDescriptor fusePlanDesc{miopenVerticalFusion, xDesc};
OperatorArgs fusionArgs;
auto convoOp = std::make_shared<ConvForwardOpDescriptor>(conv_desc, wDesc);
auto convOp = std::make_shared<ConvForwardOpDescriptor>(conv_desc, wDesc);
auto zOp = std::make_shared<TensorScaleAddOpDescriptor>(zDesc);
auto biasOp = std::make_shared<BiasFusionOpDescriptor>(biasDesc);
auto activOp = std::make_shared<ActivFwdFusionOpDescriptor>(activationDesc.GetMode());
MIOPEN_CHECK(fusePlanDesc.AddOp(convoOp));

if(activationDesc.GetMode() != miopenActivationRELU)
{
MIOPEN_THROW(miopenStatusNotImplemented,
"only Activation Mode == miopenActivationRELU is supported");
}

MIOPEN_CHECK(fusePlanDesc.AddOp(convOp));
MIOPEN_CHECK(fusePlanDesc.SetConvAlgo(algo));
MIOPEN_CHECK(fusePlanDesc.AddOp(zOp));
MIOPEN_CHECK(fusePlanDesc.AddOp(biasOp));
MIOPEN_CHECK(fusePlanDesc.AddOp(activOp));

MIOPEN_CHECK(fusePlanDesc.Compile(handle));
float alpha = static_cast<float>(1.0);
float beta = static_cast<float>(0);
float alpha = 1.0f;
float beta = 0.0f;
float activ_alpha = activationDesc.GetAlpha();
float activ_beta = activationDesc.GetBeta();
float activ_gamma = activationDesc.GetGamma();

// Set the Args
MIOPEN_CHECK(convoOp->SetArgs(fusionArgs, &alpha, &beta, w));
MIOPEN_CHECK(activOp->SetArgs(fusionArgs, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
MIOPEN_CHECK(convOp->SetArgs(fusionArgs, &falpha1, &beta, w));
MIOPEN_CHECK(zOp->SetArgs(fusionArgs, falpha2, z));
MIOPEN_CHECK(biasOp->SetArgs(fusionArgs, &alpha, &beta, bias));
MIOPEN_CHECK(activOp->SetArgs(fusionArgs, &alpha, &beta, activ_alpha, activ_beta, activ_gamma));
MIOPEN_CHECK(fusePlanDesc.Execute(handle, xDesc, x, yDesc, y, fusionArgs));
return miopenStatusSuccess;
}
@@ -140,6 +160,8 @@ AllocateBuffersAndMakeFusionInvokeParams(Handle& handle,
const auto bn_inf_id = solver::fusion::GetOpIdx(plan.op_map, miopenFusionOpBatchNormInference);
const auto bn_fwd_id = solver::fusion::GetOpIdx(plan.op_map, miopenFusionOpBatchNormFwdTrain);
const auto bn_bwd_id = solver::fusion::GetOpIdx(plan.op_map, miopenFusionOpBatchNormBwdTrain);
const auto tensor_add_op_id =
solver::fusion::GetOpIdx(plan.op_map, miopenFusionOpTensorScaleAdd);

const auto any_activ = activ_fwd_id != -1 || activ_bwd_id != -1;
const auto any_bn = bn_inf_id != -1 || bn_fwd_id != -1 || bn_bwd_id != -1;
@@ -198,6 +220,20 @@
}
}

if(tensor_add_op_id != -1)
{
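// Allocate a buffer of the z tensor's size and register it, with alpha = 1.0f,
// as this op's invoke params, so the plan can be executed (e.g. during solver
// search) without user-provided data.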
const auto& tensor_add_op =
dynamic_cast<const TensorScaleAddOpDescriptor&>(*plan.op_map[tensor_add_op_id]);
assert(&tensor_add_op);

float alpha = 1.0f;
const auto space = tensor_add_op.tensor_desc.GetNumBytes();
auto ptr = allocate_buffer(space);

params.SetArg(tensor_add_op_id,
std::make_unique<miopen::fusion::TensorScaleAddOpInvokeParam>(alpha, ptr));
}

if(any_bn)
{
const auto epsilon = 0.00001;
@@ -512,12 +548,24 @@ miopenStatus_t ConvForwardOpDescriptor::GetOutputDesc(TensorDescriptor& output_d
[&]() { output_desc = base_desc.GetForwardOutputTensor(input_desc, filter_desc); });
}

/*
miopenStatus_t
ConvForwardOpDescriptor::SetArgs(OperatorArgs& args, float alpha, float beta, ConstData_t w)
{
auto op_args = std::make_unique<fusion::ConvolutionOpInvokeParam>(alpha, beta, w);
args.SetArg(GetIdx(), std::move(op_args));
return miopenStatusSuccess;
}
*/

miopenStatus_t ConvForwardOpDescriptor::SetArgs(OperatorArgs& args,
const void* /*alpha*/,
const void* /*beta*/,
const void* alpha,
const void* beta,
ConstData_t w)
{
auto op_args = std::make_unique<fusion::ConvolutionOpInvokeParam>(w);
float falpha = alpha != nullptr ? *reinterpret_cast<const float*>(alpha) : 1.0f;
float fbeta = beta != nullptr ? *reinterpret_cast<const float*>(beta) : 0.0f;
auto op_args = std::make_unique<fusion::ConvolutionOpInvokeParam>(falpha, fbeta, w);
args.SetArg(GetIdx(), std::move(op_args));
return miopenStatusSuccess;
}
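
Editorial sketch (not part of the diff): the alpha/beta pair forwarded above follows the BLAS-style convention spelled out by the field comments added in fusion_invoke_params.hpp below; a hypothetical helper makes the blend explicit:

// Illustrative only: alpha scales the fresh conv result, beta the existing output.
inline float blend_conv_output(float conv_result, float old_out, float alpha, float beta)
{
    return alpha * conv_result + beta * old_out;
}
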
@@ -672,6 +720,20 @@ miopenStatus_t BiasFusionOpDescriptor::SetArgs(OperatorArgs& args,
return miopenStatusSuccess;
}

miopenStatus_t TensorScaleAddOpDescriptor::GetOutputDesc(TensorDescriptor& output_desc) const
{
output_desc = this->tensor_desc;
return miopenStatusSuccess;
}

miopenStatus_t
TensorScaleAddOpDescriptor::SetArgs(OperatorArgs& args, float alpha, ConstData_t tensor_ptr)
{
auto op_args = std::make_unique<fusion::TensorScaleAddOpInvokeParam>(alpha, tensor_ptr);
args.SetArg(GetIdx(), std::move(op_args));
return miopenStatusSuccess;
}

std::string FusionPlanDescriptor::GetAlgorithmName(const Handle& /*handle*/)
{
if(conv_fwd_algo)
@@ -698,7 +760,8 @@ static auto GetFusedDirectSolvers()

static auto GetFusedIGemmSolvers()
{
return solver::SolverContainer<solver::fusion::ConvCKIgemmFwdBiasActivFused>{};
return solver::SolverContainer<solver::fusion::ConvCKIgemmFwdBiasActivFused,
solver::fusion::ConvCKIgemmFwdBiasResAddActivFused>{};
}

static auto GetFusedWinogradSolvers()
11 changes: 11 additions & 0 deletions src/include/miopen/fusion.hpp
@@ -81,6 +81,16 @@ struct BiasFusionOpDescriptor : FusionOpDescriptor
TensorDescriptor base_desc;
};

struct TensorScaleAddOpDescriptor : public FusionOpDescriptor
{
TensorScaleAddOpDescriptor(const TensorDescriptor& desc) : tensor_desc(desc) {}
miopenStatus_t GetOutputDesc(TensorDescriptor& output_desc) const override;
miopenStatus_t GetNetworkConfig(std::ostringstream& network_config) override;
miopenStatus_t SetArgs(OperatorArgs& args, float alpha, ConstData_t tensor_ptr);
miopenFusionOp_t kind() const override { return miopenFusionOpTensorScaleAdd; };
TensorDescriptor tensor_desc;
};
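
A minimal usage sketch for the new op, mirroring the calls made in ConvBiasActivFusion in src/fusion.cpp above (zDesc, z, falpha2, fusePlanDesc, and fusionArgs are assumed to exist in the caller):

auto zOp = std::make_shared<TensorScaleAddOpDescriptor>(zDesc);
MIOPEN_CHECK(fusePlanDesc.AddOp(zOp));              // placed right after the conv op
MIOPEN_CHECK(zOp->SetArgs(fusionArgs, falpha2, z)); // falpha2 scales the z tensor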

struct ActivFwdFusionOpDescriptor : FusionOpDescriptor
{
ActivFwdFusionOpDescriptor(miopenActivationMode_t mode) : activMode(mode) {}
Expand Down Expand Up @@ -215,6 +225,7 @@ struct ConvForwardOpDescriptor : FusionOpDescriptor
conv_compiler_options(""){};
miopenStatus_t GetOutputDesc(TensorDescriptor& output_desc) const override;
miopenStatus_t SetArgs(OperatorArgs& args, const void* alpha, const void* beta, ConstData_t w);
// miopenStatus_t SetArgs(OperatorArgs& args, float alpha, float beta, ConstData_t w);
miopenStatus_t GetNetworkConfig(std::ostringstream& network_config) override;
bool isASMApplicable(Handle& handle);
miopenFusionOp_t kind() const override { return miopenFusionOpConvForward; };
13 changes: 13 additions & 0 deletions src/include/miopen/fusion/fusion_invoke_params.hpp
@@ -41,6 +41,12 @@ struct FusionOpInvokeParamBase
struct ConvolutionOpInvokeParam : FusionOpInvokeParamBase
{
ConvolutionOpInvokeParam(ConstData_t w) : weights(w) {}
ConvolutionOpInvokeParam(float _alpha, float _beta, ConstData_t w)
: alpha(_alpha), beta(_beta), weights(w)
{
}
float alpha = 1.0f; // scales new result of convolution
float beta = 0.0f; // scales old val of convolution output tensor
ConstData_t weights = nullptr;
};

Expand All @@ -50,6 +56,13 @@ struct BiasOpInvokeParam : FusionOpInvokeParamBase
ConstData_t bdata = nullptr;
};

struct TensorScaleAddOpInvokeParam : public FusionOpInvokeParamBase
{
TensorScaleAddOpInvokeParam(float a, ConstData_t tp) : alpha(a), tensor_ptr(tp) {}
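// alpha scales the tensor being added (the z tensor in ConvBiasActivFusion)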
float alpha = 1.0f;
ConstData_t tensor_ptr = nullptr;
};

struct ActivationOpInvokeParam : FusionOpInvokeParamBase
{
ActivationOpInvokeParam(double alpha, double beta, double gamma)
67 changes: 67 additions & 0 deletions src/include/miopen/fusion/solvers.hpp
@@ -253,6 +253,73 @@ struct ConvCKIgemmFwdBiasActivFused final
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};

struct PerfConfigConvCKIgemmFwdBiasResAddActivFused
: PerfConfigBase<PerfConfigConvCKIgemmFwdBiasResAddActivFused>
{
int index;
std::string kernel_id;
std::vector<std::string> valid_kernels;
PerfConfigConvCKIgemmFwdBiasResAddActivFused(int idx, std::string kernel_id_)
: index(idx), kernel_id(kernel_id_)
{
}
PerfConfigConvCKIgemmFwdBiasResAddActivFused()
: PerfConfigConvCKIgemmFwdBiasResAddActivFused(0, "")
{
}
PerfConfigConvCKIgemmFwdBiasResAddActivFused(bool)
: PerfConfigConvCKIgemmFwdBiasResAddActivFused(0, "")
{
}
void HeuristicInit(const FusionDescription& fdesc_problem);
bool SetNextValue(const FusionDescription& fdesc_problem);
bool IsValidValue() const;
bool IsValid(const FusionContext&, const FusionDescription& fdesc_problem) const;

template <typename Self, typename F>
static void Visit(Self&& s, F f)
{
f(s.kernel_id, "kernel_id");
}
bool operator==(const PerfConfigConvCKIgemmFwdBiasResAddActivFused& other) const;

private:
template <typename DataType, typename AccumDataType = DataType>
void Init(const miopen::conv::ProblemDescription&);
template <typename DataType, typename AccumDataType = DataType>
bool CheckIsSupportCKArgs(const miopen::conv::ProblemDescription&) const;
};

struct ConvCKIgemmFwdBiasResAddActivFused final
: FusionTunableSolver<PerfConfigConvCKIgemmFwdBiasResAddActivFused>
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvCKIgemmFwdBiasResAddActivFused>();
}

PerfConfigConvCKIgemmFwdBiasResAddActivFused
GetDefaultPerformanceConfig(const FusionContext& ctx,
const FusionDescription& fdesc_problem) const override;
bool IsValidPerformanceConfig(
const FusionContext& ctx,
const FusionDescription& fdesc_problem,
const PerfConfigConvCKIgemmFwdBiasResAddActivFused& config) const override;
PerfConfigConvCKIgemmFwdBiasResAddActivFused
Search(const FusionContext& ctx,
const FusionDescription& fdesc_problem,
const AnyInvokeParams& invoke_ctx) const override;
bool IsApplicable(const FusionContext& ctx,
const FusionDescription& fdesc_problem) const override;
ConvSolution
GetSolution(const FusionContext& ctx,
const FusionDescription& fdesc_problem,
const PerfConfigConvCKIgemmFwdBiasResAddActivFused& config) const override;

private:
template <typename DataType, typename AccumDataType = DataType>
bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const;
};
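
For readers new to MIOpen's tunable fusion solvers, a hedged sketch of how a PerfConfig like the one above typically advances during tuning; the real method bodies live in the solver's .cpp (not shown in this diff), so the details below are assumptions modeled on the existing ConvCKIgemmFwdBiasActivFused solver, under a hypothetical name:

// Sketch only: step through the CK kernel instances enumerated into
// valid_kernels, one tuning iteration at a time.
bool PerfConfigSketch::SetNextValue(const FusionDescription& fdesc_problem)
{
    if(valid_kernels.empty())
    {
        HeuristicInit(fdesc_problem); // enumerate instances; select index 0
        return true;
    }
    if(index + 1 < static_cast<int>(valid_kernels.size()))
    {
        ++index;
        kernel_id = valid_kernels[index];
        return true;
    }
    return false; // search space exhausted
}
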
struct ConvBinWinogradRxSFused final : FusionSolverBase
{
const std::string& SolverDbId() const override
1 change: 1 addition & 0 deletions src/include/miopen/fusion_ops.hpp
@@ -44,6 +44,7 @@ enum miopenFusionOp_t
miopenFusionOpBatchNormFwdTrain = 4,
miopenFusionOpBatchNormBwdTrain = 5,
miopenFusionOpActivBackward = 6,
miopenFusionOpTensorScaleAdd = 7,
};

enum MDGraph_op_t
6 changes: 6 additions & 0 deletions src/ocl/fusionopbiasbnactivocl.cpp
@@ -52,6 +52,12 @@ miopenStatus_t BiasFusionOpDescriptor::GetNetworkConfig(std::ostringstream& netw
return miopenStatusSuccess;
}

miopenStatus_t TensorScaleAddOpDescriptor::GetNetworkConfig(std::ostringstream& network_config)
{
network_config << "tensorScaleAdd"; // for bias
return miopenStatusSuccess;
}

miopenStatus_t ActivFwdFusionOpDescriptor::GetNetworkConfig(std::ostringstream& network_config)
{
network_config << "ActivFwd" << std::to_string(activMode);
5 changes: 5 additions & 0 deletions src/solver.cpp
@@ -623,6 +623,11 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry)
++id,
conv::ConvHipImplicitGemmF16F8F16WrwXdlops{},
miopenConvolutionAlgoImplicitGEMM);
Register(registry,
++id,
Primitive::Fusion,
fusion::ConvCKIgemmFwdBiasResAddActivFused{}.SolverDbId(),
miopenConvolutionAlgoImplicitGEMM);

// IMPORTANT: New solvers should be added to the end of the function!
}