Skip to content

Commit

Permalink
Update rotmg interface to handle issue in OpenCL CPU support (#532)
Browse files Browse the repository at this point in the history
* Update rotmg implementation

* Add conditional check on mem_type in use

* Add dependencies to copy_y1 operation

* Add new interface to library and add related tests.
  • Loading branch information
s-Nick authored Aug 23, 2024
1 parent 18a32a8 commit a9378fb
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 124 deletions.
11 changes: 11 additions & 0 deletions src/interface/blas1/rotmg.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,22 @@ template typename SB_Handle::event_t _rotmg(
BufferIterator<${DATA_TYPE}> _y1, BufferIterator<${DATA_TYPE}> _param,
const typename SB_Handle::event_t& dependencies);

// Explicit instantiation of _rotmg for the overload in which _d1, _d2, _x1 and
// _param are BufferIterator arguments while _y1 is passed by value as a scalar
// of the element type (instead of as a BufferIterator like the other operands).
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, BufferIterator<${DATA_TYPE}> _d1,
BufferIterator<${DATA_TYPE}> _d2, BufferIterator<${DATA_TYPE}> _x1,
${DATA_TYPE} _y1, BufferIterator<${DATA_TYPE}> _param,
const typename SB_Handle::event_t& dependencies);

// Pointer-based explicit instantiations of _rotmg; these are only compiled
// when SB_ENABLE_USM is defined.
#ifdef SB_ENABLE_USM
// Overload in which every operand, including _y1, is passed as a raw pointer.
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
${DATA_TYPE} * _x1, ${DATA_TYPE} * _y1, ${DATA_TYPE} * _param,
const typename SB_Handle::event_t& dependencies);

// Overload in which _d1, _d2, _x1 and _param are raw pointers while _y1 is
// passed by value as a scalar of the element type.
template typename SB_Handle::event_t _rotmg(
SB_Handle& sb_handle, ${DATA_TYPE} * _d1, ${DATA_TYPE} * _d2,
${DATA_TYPE} * _x1, ${DATA_TYPE} _y1, ${DATA_TYPE} * _param,
const typename SB_Handle::event_t& dependencies);
#endif

} // namespace internal
Expand Down
33 changes: 28 additions & 5 deletions src/interface/blas1_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,14 +811,37 @@ typename sb_handle_t::event_t _rotmg(
auto d1_view = make_vector_view(_d1, inc, vector_size);
auto d2_view = make_vector_view(_d2, inc, vector_size);
auto x1_view = make_vector_view(_x1, inc, vector_size);
auto y1_view = make_vector_view(_y1, inc, vector_size);
auto param_view = make_vector_view(_param, inc, param_size);

auto operation =
Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view, y1_view, param_view);
auto ret = sb_handle.execute(operation, _dependencies);
if constexpr (std::is_arithmetic_v<container_3_t>) {
constexpr helper::AllocType mem_type = std::is_pointer_v<container_0_t>
? helper::AllocType::usm
: helper::AllocType::buffer;
auto _y1_tmp = blas::helper::allocate<mem_type, container_3_t>(
1, sb_handle.get_queue());

return ret;
auto copy_y1 = blas::helper::copy_to_device(sb_handle.get_queue(), &_y1,
_y1_tmp, 1, _dependencies);

auto y1_view = make_vector_view(_y1_tmp, inc, vector_size);
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
y1_view, param_view);

auto operator_event =
sb_handle.execute(operation, typename sb_handle_t::event_t{copy_y1});
if constexpr (mem_type != helper::AllocType::buffer) {
// This wait is necessary to free the temporary memory created above and
// to avoid the host_task
operator_event[0].wait();
sycl::free(_y1_tmp, sb_handle.get_queue());
}
return operator_event;
} else {
auto y1_view = make_vector_view(_y1, inc, vector_size);
auto operation = Rotmg<decltype(d1_view)>(d1_view, d2_view, x1_view,
y1_view, param_view);
return sb_handle.execute(operation, _dependencies);
}
}

/**
Expand Down
Loading

0 comments on commit a9378fb

Please sign in to comment.