diff --git a/src/operations/blas3/gemm_load_store_joint_matrix.hpp b/src/operations/blas3/gemm_load_store_joint_matrix.hpp index 35e43d1f8..c8e28f864 100644 --- a/src/operations/blas3/gemm_load_store_joint_matrix.hpp +++ b/src/operations/blas3/gemm_load_store_joint_matrix.hpp @@ -57,18 +57,16 @@ struct PacketizeJointMatrix { /*! @brief Performs a coalesced non-vectorized load when the current block is * not internal. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate) { - value_t val = in_range ? *(src) : value_t{0}; + value_t val = in_range ? *src : value_t{0}; using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, @@ -91,68 +89,51 @@ struct PacketizeJointMatrix { * block is internal. In the case where k < the * number of elements being loaded then edge loads will be element wise with * additional bounds checking. - * @tparam trans Whether the source matrix is transposed or not. * @tparam internal True if the current block is internal and no bounds * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template + */ + template static PORTBLAS_INLINE typename std::enable_if::type load( const bool in_range, SrcPointerType src, DestPointerType dest, EdgePredicate edge_in_range) { PacketType packet{}; + using address_t = cl::sycl::access::address_space; if (in_range) { - using address_t = cl::sycl::access::address_space; packet.template load( 0, cl::sycl::multi_ptr(src)); + store(packet, dest); } else { + // avoid writing to variable, instead directly write to + // shared local memory to avoid race condition experienced + // with release compiler. #pragma unroll - for (index_t i = 0; i < packet_size; i++) { - reinterpret_cast(&packet)[i] = - edge_in_range(i) ? *(src + i) : value_t{0}; - } - } - store(packet, dest); - } - - /*! @brief Store a vector packet into local memory when the source is - * transposed. This will untranspose the elements individually when storing so - * the data in local memory is always consistent. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { - using address_t = cl::sycl::access::address_space; -#pragma unroll - for (index_t i = 0; i < packet_size; i++) { - value_t val = reinterpret_cast(&packet)[i]; - if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::half; - *(dest + ld * i) = static_cast(val); - } else if constexpr (std::is_same, - DestPointerType>::value) { - using dtype = cl::sycl::ext::oneapi::bfloat16; - *(dest + ld * i) = static_cast(val); - } else { - using namespace cl::sycl::ext::oneapi::experimental::matrix; - *(dest + ld * i) = round_to_tf32(val); + for (index_t i = 0; i < packet_size; i++, dest++, src++) { + if constexpr (std::is_same, + DestPointerType>::value) { + using dtype = cl::sycl::half; + *dest = static_cast(edge_in_range(i) ? *src : 0); + } else if constexpr (std::is_same, + DestPointerType>::value) { + using dtype = cl::sycl::ext::oneapi::bfloat16; + *dest = static_cast(edge_in_range(i) ? *src : 0); + } else { + using namespace cl::sycl::ext::oneapi::experimental::matrix; + *dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f; + } } } } - /*! @brief Store a vector packet into local memory when the source is not - * transposed. This will use sycl::vec::store function. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE typename std::enable_if::type store( - PacketType &packet, DestPointerType dest) { + /*! @brief Store a vector packet into local memory. This will use + * sycl::vec::store function. + */ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { using address_t = cl::sycl::access::address_space; if constexpr (std::is_same, diff --git a/src/operations/blas3/gemm_local_joint_matrix.hpp b/src/operations/blas3/gemm_local_joint_matrix.hpp index 1a0b7524a..3298203f2 100644 --- a/src/operations/blas3/gemm_local_joint_matrix.hpp +++ b/src/operations/blas3/gemm_local_joint_matrix.hpp @@ -730,10 +730,9 @@ class Gemm(in_row((item_id % rows), 0)) && - do_check( - in_col((item_id / rows), col_ofs)); + do_check(in_col((item_id / rows), col_ofs)); - packetize_t::template load( + packetize_t::template load( in_range, ptr + col_ofs * ld, scratch + col_ofs * lds, [&](const index_t &ofs) { return in_row(item_id % rows, ofs) && @@ -759,7 +758,7 @@ class Gemm(in_row(item_id / cols, row_ofs)) && do_check(in_col(item_id % cols, 0)); - packetize_t::template load( + packetize_t::template load( in_range, ptr + row_ofs * ld, scratch + row_ofs * lds, [&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE { return in_col(item_id % cols, ofs) &&