Skip to content

Commit

Permalink
Restriced VectorSize to 1 for joint_matrix
Browse files Browse the repository at this point in the history
* Updated the load/store file with proper vectorized load/store for future use.
  • Loading branch information
muhammad-tanvir-1211 committed Jan 31, 2024
1 parent eb628b6 commit 781f7dc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
27 changes: 24 additions & 3 deletions src/operations/blas3/gemm_load_store_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ struct PacketizeJointMatrix {
*dest = round_to_tf32(val);
}
}

/*! @brief Performs a vectorised load using sycl::vec::load when the current
* block is internal. In the case where k < the
* number of elements being loaded then edge loads will be element wise with
Expand Down Expand Up @@ -114,6 +115,7 @@ struct PacketizeJointMatrix {
}
store<trans, ld>(packet, dest);
}

/*! @brief Store a vector packet into local memory when the source is
* transposed. This will untranspose the elements individually when storing so
* the data in local memory is always consistent.
Expand Down Expand Up @@ -156,16 +158,35 @@ struct PacketizeJointMatrix {
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::half;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec{};
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else if constexpr (std::is_same<cl::sycl::multi_ptr<
cl::sycl::ext::oneapi::bfloat16,
address_t::local_space>,
DestPointerType>::value) {
using dtype = cl::sycl::ext::oneapi::bfloat16;
*dest = static_cast<dtype>(packet[0]);
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
static_cast<dtype>(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
} else {
using namespace cl::sycl::ext::oneapi::experimental::matrix;
*dest = round_to_tf32(packet[0]);
using dtype = float;
cl::sycl::vec<dtype, vector_size> new_vec;
for (index_t i = 0; i < packet_size; i++) {
reinterpret_cast<dtype *>(&new_vec)[i] =
round_to_tf32(reinterpret_cast<value_t *>(&packet)[i]);
}
new_vec.template store<address_t::local_space>(
0, cl::sycl::multi_ptr<dtype, address_t::local_space>(dest));
}
}
};
Expand Down
41 changes: 18 additions & 23 deletions src/operations/blas3/gemm_local_joint_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
using value_t = element_t;
using index_t = typename std::make_signed<typename input_t::index_t>::type;
using packetize_t = PacketizeJointMatrix<VectorSize, value_t, index_t>;
using vector_t = typename packetize_t::PacketType;
using address_t = cl::sycl::access::address_space;

// enable easier access to tile dimensions
Expand Down Expand Up @@ -156,6 +155,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
static_assert(std::is_same<value_t, float>::value,
"This code is only supported for float data type.");

static_assert(VectorSize == 1,
"Vectorization not supported for joint_matrix.");

//! @brief leading dimension of block of A in local
static constexpr index_t ldsa =
(trans_a ? cl_elems : block_rows) +
Expand Down Expand Up @@ -366,7 +368,6 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
batch_size_);
}
} else {
using address_t = cl::sycl::access::address_space;
auto input_scratch = *reinterpret_cast<cl::sycl::multi_ptr<
typename tile_type::jmInpType, address_t::local_space> *>(&scratch);

Expand Down Expand Up @@ -721,25 +722,22 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
index_t item_id, InputPointerType ptr, index_t ld,
ScratchPointerType scratch, RowPredicate in_row, ColPredicate in_col) {
constexpr index_t bs = rows * cols;
constexpr index_t multiplier = internal ? packetize_t::packet_size : 1;
constexpr index_t loop_iterations = (bs - 1) / (wg_size * multiplier) + 1;
constexpr index_t loop_iterations = (bs - 1) / wg_size + 1;
#pragma unroll
for (index_t i = 0; i < loop_iterations; ++i) {
if (!do_check<((bs % (wg_size * multiplier)) != 0)>(
item_id + i * (wg_size * multiplier) < bs))
if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs))
continue;
const index_t col_ofs = i * ((wg_size * multiplier) / rows);
const index_t col_ofs = i * (wg_size / rows);
const bool in_range =
do_check<check_row_limit>(
in_row(((item_id * multiplier) % rows), multiplier - 1)) &&
do_check<check_row_limit>(in_row((item_id % rows), 0)) &&
do_check<check_col_limit>(
in_col((item_id * multiplier / rows), col_ofs));
in_col((item_id / rows), col_ofs));

packetize_t::template load<trans, internal, lds>(
in_range, ptr + col_ofs * ld, scratch + col_ofs * lds,
[&](const index_t &ofs) {
return in_row((item_id * multiplier) % rows, ofs) &&
in_col((item_id * multiplier) / rows, col_ofs);
return in_row(item_id % rows, ofs) &&
in_col(item_id / rows, col_ofs);
});
}
}
Expand All @@ -751,24 +749,21 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
index_t item_id, InputPointerType ptr, index_t ld,
ScratchPointerType scratch, RowPredicate in_row, ColPredicate in_col) {
constexpr index_t bs = rows * cols;
constexpr index_t multiplier = internal ? packetize_t::packet_size : 1;
constexpr index_t loop_iterations = (bs - 1) / (wg_size * multiplier) + 1;
constexpr index_t loop_iterations = (bs - 1) / wg_size + 1;
#pragma unroll
for (index_t i = 0; i < loop_iterations; ++i) {
if (!do_check<((bs % (wg_size * multiplier)) != 0)>(
item_id + i * (wg_size * multiplier) < bs))
if (!do_check<((bs % wg_size) != 0)>(item_id + i * wg_size < bs))
continue;
const index_t row_ofs = i * ((wg_size * multiplier) / cols);
const bool in_range = do_check<check_row_limit>(in_row(
(item_id * multiplier) / cols, row_ofs)) &&
do_check<check_col_limit>(in_col(
(item_id * multiplier) % cols, multiplier - 1));
const index_t row_ofs = i * (wg_size / cols);
const bool in_range =
do_check<check_row_limit>(in_row(item_id / cols, row_ofs)) &&
do_check<check_col_limit>(in_col(item_id % cols, 0));

packetize_t::template load<trans, internal, lds>(
in_range, ptr + row_ofs * ld, scratch + row_ofs * lds,
[&](const index_t &ofs) PORTBLAS_ALWAYS_INLINE {
return in_col((item_id * multiplier) % cols, ofs) &&
in_row((item_id * multiplier) / cols, row_ofs);
return in_col(item_id % cols, ofs) &&
in_row(item_id / cols, row_ofs);
});
}
}
Expand Down

0 comments on commit 781f7dc

Please sign in to comment.