Skip to content

Commit

Permalink
Fixed race condition
Browse files Browse the repository at this point in the history
* Removed extra store function
  • Loading branch information
muhammad-tanvir-1211 committed Jan 31, 2024
1 parent 781f7dc commit 6cc4c33
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 55 deletions.
83 changes: 32 additions & 51 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,16 @@ struct PacketizeJointMatrix {

/*! @brief Performs a coalesced non-vectorized load when the current block is
* not internal.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory.
*/

template <bool trans, bool internal, int ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<!internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate) {
value_t val = in_range ? *(src) : value_t{0};
value_t val = in_range ? *src : value_t{0};
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
Expand All @@ -91,68 +89,51 @@ struct PacketizeJointMatrix {
* block is internal. In the case where k < the
* number of elements being loaded then edge loads will be element wise with
* additional bounds checking.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam internal True if the current block is internal and no bounds
* checking is required.
* @tparam ld The leading dimension of the destination memory. */
template <bool trans, bool internal, index_t ld, typename SrcPointerType,
typename DestPointerType, typename EdgePredicate>
*/
template <bool internal, typename SrcPointerType, typename DestPointerType,
typename EdgePredicate>
static PORTBLAS_INLINE typename std::enable_if<internal>::type load(
const bool in_range, SrcPointerType src, DestPointerType dest,
EdgePredicate edge_in_range) {
PacketType packet{};

using address_t = cl::sycl::access::address_space;
if (in_range) {
using address_t = cl::sycl::access::address_space;
packet.template load<address_t::global_space>(
0, cl::sycl::multi_ptr<const value_t, address_t::global_space>(src));
store(packet, dest);
} else {
// avoid writing to variable, instead directly write to
// shared local memory to avoid race condition experienced
// with release compiler.
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<value_t *>(&packet)[i] =
edge_in_range(i) ? *(src + i) : value_t{0};
}
}
store<trans, ld>(packet, dest);
}

/*! @brief Store a vector packet into local memory when the source is
* transposed. This will untranspose the elements individually when storing so
* the data in local memory is always consistent.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, index_t ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<trans>::type store(
PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
#pragma unroll
for (index_t i = 0; i < packet_size; i++) {
value_t val = reinterpret_cast<value_t *>(&packet)[i];
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*(dest + ld * i) = static_cast<dtype>(val);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*(dest + ld * i) = static_cast<dtype>(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*(dest + ld * i) = round_to_tf32(val);
for (index_t i = 0; i < packet_size; i++, dest++, src++) {
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f;
}
}
}
}

/*! @brief Store a vector packet into local memory when the source is not
* transposed. This will use sycl::vec::store function.
* @tparam trans Whether the source matrix is transposed or not.
* @tparam ld The leading dimension of the destination memory.*/
template <bool trans, int ld, typename DestPointerType>
static PORTBLAS_INLINE typename std::enable_if<!trans>::type store(
PacketType &packet, DestPointerType dest) {
/*! @brief Store a vector packet into local memory. This will use
* sycl::vec::store function.
*/
template <typename DestPointerType>
static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
using address_t = cl::sycl::access::address_space;
if constexpr (std::is_same<cl::sycl::multi_ptr<cl::sycl::half,
address_t::local_space>,
Expand Down
7 changes: 3 additions & 4 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,10 +730,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
const index_t col_ofs = i * (wg_size / rows);
const bool in_range =
do_check<check_row_limit>(in_row((item_id % rows), 0)) &&
do_check<check_col_limit>(
in_col((item_id / rows), col_ofs));
do_check<check_col_limit>(in_col((item_id / rows), col_ofs));

packetize_t::template load<trans, internal, lds>(
packetize_t::template load<internal>(
in_range, ptr + col_ofs * ld, scratch + col_ofs * lds,
[&](const index_t &ofs) {
return in_row(item_id % rows, ofs) &&
Expand All @@ -759,7 +758,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
do_check<check_row_limit>(in_row(item_id / cols, row_ofs)) &&
do_check<check_col_limit>(in_col(item_id % cols, 0));

packetize_t::template load<trans, internal, lds>(
packetize_t::template load<internal>(
in_range, ptr + row_ofs * ld, scratch + row_ofs * lds,
[&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE {
return in_col(item_id % cols, ofs) &&
Expand Down

0 comments on commit 6cc4c33

Please sign in to comment.