Skip to content

Commit

Permalink
Fixed compilation error with bfloat16 type
Browse files Browse the repository at this point in the history
  • Loading branch information
muhammad-tanvir-1211 committed Jan 31, 2024
1 parent 6cc4c33 commit e7f1a43
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
21 changes: 10 additions & 11 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ struct PacketizeJointMatrix {
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(val);
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(val);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(val);
Expand Down Expand Up @@ -119,8 +119,8 @@ struct PacketizeJointMatrix {
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(edge_in_range(i) ? *src : 0);
using namespace cl::sycl::ext::oneapi;
*dest = bfloat16(edge_in_range(i) ? *src : 0.f);
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = edge_in_range(i) ? round_to_tf32(*src) : 0.f;
Expand Down Expand Up @@ -150,14 +150,13 @@ struct PacketizeJointMatrix {
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
// sycl::vec doesn't accept bfloat16 as a valid input type
// so we need to write the packet elements individually to
// the shared memory.
using namespace cl::sycl::ext::oneapi;
for (index_t i = 0; i < packet_size; i++, dest++) {
*dest = bfloat16(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
using dtype = float;
Expand Down
3 changes: 2 additions & 1 deletion test/blas_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ static inline void fill_trsm_matrix(std::vector<scalar_t> &A, size_t k,
* @param val input/output float value.
* @param nbits number of last bit set to zero. It is set by default to 13 since
* this is the difference of the number of bits of the mantissa between floats
* (23) and FP16 / NVIDIA TF32 (10).
* (23) and FP16 / NVIDIA TF32 (10). For bfloat16, this value needs to be set to
* 16 to get correct result.
*/
static inline void set_to_zero_last_nbits(float &val, int32_t nbits = 13) {
int32_t *int_pntr = reinterpret_cast<int32_t *>(&val);
Expand Down

0 comments on commit e7f1a43

Please sign in to comment.