Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored blas1 dot & sdsdot operators #471

Merged
Merged
11 changes: 7 additions & 4 deletions benchmark/portblas/blas1/dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, index_t size,
scalar_t vr_temp = 0;
{
auto vr_temp_gpu = blas::helper::allocate<mem_alloc, scalar_t>(1, q);
auto copyToD =
blas::helper::copy_to_device<scalar_t>(q, &vr_temp, vr_temp_gpu, 1);
auto dot_event = _dot(sb_handle, size, inx, static_cast<index_t>(1), iny,
static_cast<index_t>(1), vr_temp_gpu);
static_cast<index_t>(1), vr_temp_gpu, {copyToD});
sb_handle.wait(dot_event);
auto copy_output = blas::helper::copy_to_host(q, vr_temp_gpu, &vr_temp, 1);
sb_handle.wait(copy_output);
Expand Down Expand Up @@ -128,8 +130,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
};

benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
size, mem_type).c_str(),
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(size, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, size, success)
->UseRealTime();
}
Expand All @@ -141,7 +143,8 @@ void register_benchmark(blas_benchmark::Args& args,
auto dot_params = blas_benchmark::utils::get_blas1_params(args);

register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, dot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
dot_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, dot_params);
Expand Down
14 changes: 9 additions & 5 deletions benchmark/portblas/blas1/sdsdot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, index_t size,
scalar_t vr_temp = 0;
{
auto vr_temp_gpu = blas::helper::allocate<mem_alloc, scalar_t>(1, q);
auto copyToD =
blas::helper::copy_to_device<scalar_t>(q, &vr_temp, vr_temp_gpu, 1);
auto sdsdot_event =
_sdsdot(sb_handle, size, sb, inx, static_cast<index_t>(1), iny,
static_cast<index_t>(1), vr_temp_gpu);
static_cast<index_t>(1), vr_temp_gpu, {copyToD});
sb_handle.wait(sdsdot_event);
auto event = blas::helper::copy_to_host(q, vr_temp_gpu, &vr_temp, 1);
sb_handle.wait(event);
Expand Down Expand Up @@ -126,8 +128,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success,
run<scalar_t, mem_alloc>(st, sb_handle_ptr, size, success);
};
benchmark::RegisterBenchmark(
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(
size, mem_type).c_str(),
blas_benchmark::utils::get_name<benchmark_op, scalar_t>(size, mem_type)
.c_str(),
BM_lambda, sb_handle_ptr, size, success)
->UseRealTime();
}
Expand All @@ -139,10 +141,12 @@ void register_benchmark(blas_benchmark::Args& args,
auto sdsdot_params = blas_benchmark::utils::get_blas1_params(args);

register_benchmark<scalar_t, blas::helper::AllocType::buffer>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, sdsdot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER,
sdsdot_params);
#ifdef SB_ENABLE_USM
register_benchmark<scalar_t, blas::helper::AllocType::usm>(
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, sdsdot_params);
sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM,
sdsdot_params);
#endif
}

Expand Down
51 changes: 33 additions & 18 deletions include/interface/blas1_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ typename sb_handle_t::event_t _nrm2_impl(
container_1_t _rs, const index_t number_WG,
const typename sb_handle_t::event_t &_dependencies);

/*!
* \brief Prototype for the internal implementation of the Dot operator. See
* documentation in the blas1_interface.hpp file for details.
*/
template <int localSize, int localMemSize, typename sb_handle_t,
typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot_impl(
sb_handle_t &sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const index_t _number_wg,
const typename sb_handle_t::event_t &_dependencies);

/**
* @brief _rot constructor given plane rotation
* @param sb_handle SB_Handle
Expand Down Expand Up @@ -306,12 +319,12 @@ typename sb_handle_t::event_t _rotm(
* @tparam container_3_t Buffer Iterator or USM pointer
* @tparam container_4_t Buffer Iterator or USM pointer
* @param sb_handle SB_Handle
* @param _d1[in,out] On entry, memory object holding the scaling factor for the
* x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for the
* y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On exit,
* the re-scaled _x1
* @param _d1[in,out] On entry, memory object holding the scaling factor for
* the x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for
* the y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On
* exit, the re-scaled _x1
* @param _y1[in] Memory object holding the y-coordinate of the point.
* @param _param[out] Buffer with the following layout: [flag, h11, h21, h12,
* h22].
Expand Down Expand Up @@ -359,8 +372,10 @@ typename sb_handle_t::event_t _rotg(
* @tparam sb_handle_t SB_Handle type
* @tparam scalar_t Scalar type
* @param sb_handle SB_Handle
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar r.
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar
* z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar
* r.
* @param c[out] scalar representing the output c.
* @param s[out] scalar representing the output s.
* @param _dependencies Vector of events
Expand All @@ -377,7 +392,6 @@ void _rotg(sb_handle_t &sb_handle, scalar_t &a, scalar_t &b, scalar_t &c,
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM pointer
* @tparam container_1_t Buffer Iterator or USM pointer
* @tparam container_2_t Buffer Iterator or USM pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -404,7 +418,6 @@ typename ValueType<container_0_t>::type _dot(
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM pointer
* @tparam container_1_t Buffer Iterator or USM pointer
* @tparam container_2_t Buffer Iterator or USM pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand Down Expand Up @@ -754,12 +767,12 @@ typename sb_handle_t::event_t _rotm(
* @tparam container_3_t Buffer Iterator or USM pointer
* @tparam container_4_t Buffer Iterator or USM pointer
* @param sb_handle SB_Handle
* @param _d1[in,out] On entry, memory object holding the scaling factor for the
* x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for the
* y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On exit,
* the re-scaled _x1
* @param _d1[in,out] On entry, memory object holding the scaling factor for
* the x-coordinate. On exit, the re-scaled _d1.
* @param _d2[in,out] On entry, memory object holding the scaling factor for
* the y-coordinate. On exit, the re-scaled _d2.
* @param _x1[in,out] On entry, memory object holding the x-coordinate. On
* exit, the re-scaled _x1
* @param _y1[in] Memory object holding the y-coordinate of the point.
* @param _param[out] Buffer with the following layout: [flag, h11, h21, h12,
* h22].
Expand Down Expand Up @@ -811,8 +824,10 @@ typename sb_handle_t::event_t _rotg(
* @tparam sb_handle_t SB_Handle type
* @tparam scalar_t Scalar type
* @param sb_handle SB_Handle
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar r.
* @param a[in, out] On entry, x-coordinate of the point. On exit, the scalar
* z.
* @param b[in, out] On entry, y-coordinate of the point. On exit, the scalar
* r.
* @param c[out] scalar representing the output c.
* @param s[out] scalar representing the output s.
* @param _dependencies Vector of events
Expand Down
18 changes: 18 additions & 0 deletions include/operations/blas1_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,24 @@ struct BinaryOp {
void adjust_access_displacement();
};

/*! BinaryOpConst.
* @brief Implements a const Binary Operation (x OP z) with x and z vectors.
*/
template <typename operator_t, typename lhs_t, typename rhs_t>
struct BinaryOpConst {
using index_t = typename rhs_t::index_t;
using value_t = typename ResolveReturnType<operator_t, rhs_t>::type::value_t;
lhs_t lhs_;
rhs_t rhs_;
BinaryOpConst(lhs_t &_l, rhs_t &_r);
index_t get_size() const;
bool valid_thread(cl::sycl::nd_item<1> ndItem) const;
value_t eval(index_t i) const;
value_t eval(cl::sycl::nd_item<1> ndItem) const;
void bind(cl::sycl::handler &h);
void adjust_access_displacement();
};

/*! TupleOp.
* @brief Implements a Tuple Operation (map (\x -> [i, x]) vector).
*/
Expand Down
23 changes: 23 additions & 0 deletions src/interface/blas1/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,29 @@ typename sb_handle_t::event_t _nrm2(
}
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 18)) {
constexpr index_t localSize = 1024;
const index_t number_WG = (_N + localSize - 1) / localSize;
return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 512;
return blas::internal::_dot_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
}
} // namespace backend
} // namespace dot
} // namespace blas

#endif
16 changes: 16 additions & 0 deletions src/interface/blas1/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ typename sb_handle_t::event_t _nrm2(
}
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
constexpr int localSize = 8;
constexpr index_t number_WG = 16;
return blas::internal::_dot_impl<localSize, 0>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
} // namespace backend
} // namespace dot
} // namespace blas

#endif
17 changes: 17 additions & 0 deletions src/interface/blas1/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ typename sb_handle_t::event_t _nrm2(
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
constexpr index_t localSize = 128;
const index_t number_WG =
std::min((_N + localSize - 1) / localSize, static_cast<index_t>(512));
return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
} // namespace backend
} // namespace dot

} // namespace blas

#endif
26 changes: 26 additions & 0 deletions src/interface/blas1/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,32 @@ typename sb_handle_t::event_t _nrm2(
} // namespace backend
} // namespace nrm2

namespace dot {
namespace backend {
template <typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename index_t, typename increment_t>
typename sb_handle_t::event_t _dot(
sb_handle_t& sb_handle, index_t _N, container_0_t _vx, increment_t _incx,
container_1_t _vy, increment_t _incy, container_2_t _rs,
const typename sb_handle_t::event_t& _dependencies) {
if (_N < (1 << 23)) {
constexpr index_t localSize = 512;
const index_t number_WG = (_N < (1 << 18))
? (_N + localSize - 1) / localSize
: static_cast<index_t>(256);

return blas::internal::_dot_impl<static_cast<int>(localSize), 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
} else {
constexpr int localSize = 512;
constexpr index_t number_WG = 1024;
return blas::internal::_dot_impl<localSize, 32>(
sb_handle, _N, _vx, _incx, _vy, _incy, _rs, number_WG, _dependencies);
}
}
} // namespace backend
} // namespace dot

} // namespace blas

#endif
6 changes: 5 additions & 1 deletion src/interface/blas1/dot.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace internal {
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM Pointer
* @tparam container_1_t Buffer Iterator or USM Pointer
* @tparam container_2_t Buffer Iterator or USM Pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -62,6 +61,11 @@ template typename SB_Handle::event_t _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);

template typename SB_Handle::event_t _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
6 changes: 5 additions & 1 deletion src/interface/blas1/dot_return.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace internal {
* @tparam sb_handle_t SB_Handle type
* @tparam container_0_t Buffer Iterator or USM Pointer
* @tparam container_1_t Buffer Iterator or USM Pointer
* @tparam container_2_t Buffer Iterator or USM Pointer
* @tparam index_t Index type
* @tparam increment_t Increment type
* @param sb_handle SB_Handle
Expand All @@ -61,6 +60,11 @@ template typename ValueType<${DATA_TYPE}>::type _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);

template typename ValueType<${DATA_TYPE}>::type _dot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
5 changes: 5 additions & 0 deletions src/interface/blas1/sdsdot.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ template typename SB_Handle::event_t _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);

template typename SB_Handle::event_t _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
${DATA_TYPE} * _rs, const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
5 changes: 5 additions & 0 deletions src/interface/blas1/sdsdot_return.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ template typename ValueType<${DATA_TYPE}>::type _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);

template typename ValueType<${DATA_TYPE}>::type _sdsdot(
SB_Handle& sb_handle, ${INDEX_TYPE} _N, float sb, const ${DATA_TYPE} * _vx,
${INCREMENT_TYPE} _incx, const ${DATA_TYPE} * _vy, ${INCREMENT_TYPE} _incy,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
Loading