Skip to content

Commit

Permalink
Refactor BnCKFwdInference::GetSolution for NHWC (#3120)
Browse files · Browse the repository at this point in the history
  • Loading branch information
xinlipn authored Aug 5, 2024
1 parent a54b21c commit cfabfbb
Showing 1 changed file with 107 additions and 79 deletions.
186 changes: 107 additions & 79 deletions src/solver/batchnorm/forward_inference_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ using F32 = float;
using F64 = double;
using BF16 = ushort;

// Factory alias yielding the list of CK device-op instances for batch-norm
// forward inference, expressed as an elementwise op over rank-`Rank` tensors.
// Input tuple order is (x, mean, variance, scale, bias) — callers building
// argument pointers must supply buffers in this exact order.
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
using DeviceOpBnFwdInfPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
Normalize,
Rank>>;

struct CKArgsBNormFwd
{
CKArgsBNormFwd(const miopen::batchnorm::ProblemDescription& problem)
Expand Down Expand Up @@ -79,6 +91,25 @@ struct CKArgsBNormFwd
std::array<index_t, Rank> aligned_scaleBiasMeanVarStrides{3};

std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};

// Builds the CK argument pointer for one forward-inference invocation.
// `invoker_ptr` is a CK device-op instance (e.g. one element of
// DeviceOpBnFwdInfPtrs::GetInstances()); `data_ctx` carries the device
// buffers and epsilon (e.g. miopen::batchnorm::InfInvokeParams).
// NOTE: the input buffer order (x, mean, variance, scale, bias) must match
// the ck::Tuple element order declared for the device op.
template <typename InvokerPtr, typename InvokerParams>
auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const
{
return invoker_ptr->MakeArgumentPointer(xyLengths,
{xyStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides,
aligned_scaleBiasMeanVarStrides},
{xyStrides},
{data_ctx.x,
data_ctx.estimatedMean,
data_ctx.estimatedVariance,
data_ctx.bnScale,
data_ctx.bnBias},
{data_ctx.y},
Normalize{data_ctx.epsilon});
}
};

template <typename XDataType,
Expand All @@ -90,13 +121,9 @@ template <typename XDataType,
static int CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem)
{
const auto& args = CKArgsBNormFwd{problem};
using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
Normalize,
Rank>;
const auto bn_fwd_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
const auto bn_fwd_ptrs =
DeviceOpBnFwdInfPtrs<XDataType, YDataType, ScaleDataType, BiasDataType, MeanVarDataType>::
GetInstances();
assert(!bn_fwd_ptrs.empty());
int count = 0;
for(const auto& it : bn_fwd_ptrs)
Expand Down Expand Up @@ -126,52 +153,44 @@ template <typename XDataType,
typename ScaleDataType,
typename BiasDataType,
typename MeanVarDataType>
static void RunCKSolution(const Handle& handle,
const AnyInvokeParams& primitive_parameters,
const miopen::batchnorm::ProblemDescription& problem)
ConvSolution InvokerFactoryMakerNHWC(const miopen::batchnorm::ProblemDescription& bn_problem)
{
const auto& args = CKArgsBNormFwd{problem};
ConvSolution result;
const auto kernel_index = CheckCKApplicability<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType>(bn_problem);
auto bn_fwd_ptrs =
DeviceOpBnFwdInfPtrs<XDataType, YDataType, ScaleDataType, BiasDataType, MeanVarDataType>::
GetInstances();

using DeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<XDataType, MeanVarDataType, MeanVarDataType, ScaleDataType, BiasDataType>,
ck::Tuple<YDataType>,
Normalize,
Rank>;
const auto bn_fwd_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
assert(kernel_index >= 0 && !bn_fwd_ptrs.empty() && kernel_index < bn_fwd_ptrs.size());
auto bn_ptr = std::move(bn_fwd_ptrs.at(kernel_index));

int kernel_index = CheckCKApplicability<XDataType,
YDataType,
AccDataType,
ScaleDataType,
BiasDataType,
MeanVarDataType>(problem);
assert(kernel_index >= 0 && kernel_index < bn_fwd_ptrs.size());
auto& bn_ptr = bn_fwd_ptrs.at(kernel_index);
const auto& params = primitive_parameters.CastTo<miopen::batchnorm::InfInvokeParams>();
result.invoker_factory = [args = CKArgsBNormFwd{bn_problem},
sh_bn_ptr = std::shared_ptr{std::move(bn_ptr)}](
const std::vector<Kernel>& /*kernels*/) mutable {
return [args = std::move(args), sh_bn_ptr = std::move(sh_bn_ptr)](
const Handle& handle, const AnyInvokeParams& primitive_parameters) {
const auto& params = primitive_parameters.CastTo<miopen::batchnorm::InfInvokeParams>();

auto argument_ptr = bn_ptr->MakeArgumentPointer(
args.xyLengths,
{args.xyStrides,
args.aligned_scaleBiasMeanVarStrides,
args.aligned_scaleBiasMeanVarStrides,
args.aligned_scaleBiasMeanVarStrides,
args.aligned_scaleBiasMeanVarStrides},
{args.xyStrides},
{params.x, params.estimatedMean, params.estimatedVariance, params.bnScale, params.bnBias},
{params.y},
Normalize{params.epsilon});
auto argument_ptr = args.MakeArgPtr(sh_bn_ptr, params);

auto invoker_ptr = bn_ptr->MakeInvokerPointer();
const auto enable_profiling = handle.IsProfilingEnabled();
auto invoker_ptr = sh_bn_ptr->MakeInvokerPointer();
const auto enable_profiling = handle.IsProfilingEnabled();

float elapsed_time =
invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
if(enable_profiling)
{
handle.ResetKernelTime();
handle.AccumKernelTime(elapsed_time);
}
float elapsed_time =
invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling});
if(enable_profiling)
{
handle.ResetKernelTime();
handle.AccumKernelTime(elapsed_time);
}
};
};
return result;
}
#endif

Expand Down Expand Up @@ -209,43 +228,52 @@ bool BnCKFwdInference::IsApplicable(
return false;
}

/// Dispatches batch-norm forward-inference solution construction by layout and
/// data type. Only NHWC is implemented; any other layout is an internal error.
///
/// @param problem batch-norm problem descriptor (provides layout and x-tensor type)
/// @param invoker_factory_maker_nhwc callable invoked with a value of the resolved
///        element type (F32{}, F64{}, F16{}, or BF16{}) that returns a ConvSolution
/// @return the ConvSolution produced by the maker; an empty solution when built
///         without HIP/ComposableKernel support
/// @throws miopen::Exception (miopenStatusInternalError) on an unsupported data
///         type or layout
///
/// Parameters are [[maybe_unused]] for the non-HIP/CK build, where the body is
/// compiled out — consistent with GetSolution's signature below.
template <typename InvokerFactoryMakerNHWC>
ConvSolution
MakeAnyInvokerFactory([[maybe_unused]] const miopen::batchnorm::ProblemDescription& problem,
                      [[maybe_unused]] InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc)
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
    if(problem.IsLayoutNHWC())
    {
        // Map the runtime element type to a compile-time type tag for the maker.
        switch(problem.GetXDesc().GetType())
        {
        case miopenFloat: return invoker_factory_maker_nhwc(F32{});
        case miopenDouble: return invoker_factory_maker_nhwc(F64{});
        case miopenHalf: return invoker_factory_maker_nhwc(F16{});
        case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
        default:
            MIOPEN_THROW(miopenStatusInternalError,
                         "BnCKFwdInference operation does not support this data type");
        }
    }
    // Todo: problem.IsLayoutDefault()
    else
    {
        MIOPEN_THROW(miopenStatusInternalError,
                     "BnCKFwdInference operation does not support this data layout");
    }
#else
    return {};
#endif
}

ConvSolution BnCKFwdInference::GetSolution(
[[maybe_unused]] const ExecutionContext& context,
[[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
ConvSolution result;
result.invoker_factory = [=](const std::vector<Kernel>& kernels) {
std::ignore = kernels;
return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) {
switch(bn_problem.GetXDesc().GetType())
{
case miopenHalf:
RunCKSolution<F16, F16, F32, F16, F16, F32>(
handle, primitive_parameters, bn_problem);
break;
case miopenFloat:
RunCKSolution<F32, F32, F32, F32, F32, F32>(
handle, primitive_parameters, bn_problem);
break;
case miopenDouble:
RunCKSolution<F64, F64, F64, F64, F64, F64>(
handle, primitive_parameters, bn_problem);
break;
case miopenBFloat16:
RunCKSolution<BF16, BF16, F32, BF16, BF16, F32>(
handle, primitive_parameters, bn_problem);
break;
case miopenInt8:
case miopenInt32:
case miopenInt64:
case miopenFloat8:
case miopenBFloat8:
default: MIOPEN_THROW("Unsupported datatype");
}
};
};
return result;
return MakeAnyInvokerFactory(
bn_problem,
[&](auto data_type_val) {
using T = decltype(data_type_val);

using AccTy = std::conditional_t<std::is_same_v<T, F64>,
T, // T==F64
F32>; // T==F32
return InvokerFactoryMakerNHWC<T, T, AccTy, T, T, AccTy>(bn_problem);
}
// Todo: InvokerFactoryMakerNCHW
);
#else
return {};
#endif
Expand Down

0 comments on commit cfabfbb

Please sign in to comment.