Skip to content

Commit

Permalink
Fixes & formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Jul 17, 2023
1 parent ef6879d commit e865638
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 22 deletions.
45 changes: 32 additions & 13 deletions include/interface/extension_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,25 +280,21 @@ typename sb_handle_t::event_t _transpose(sb_handle_t& sb_handle, index_t m,
}

/**
* \brief COPY batch of matrices from in_matrix to out_matrix with scaling
* factor of alpha
* \brief COPY batch of matrices inplace with scaling factor of alpha
*
* @tparam sb_handle_t SB_Handle type
* @tparam element_t Scaling factor type
* @tparam index_t Index type
* @tparam in_t Buffer Iterator
* @tparam out_t Buffer Iterator
* @tparam in_out_t input/output type
* @param sb_handle SB_Handle
* @param trans compute matrix transpose or not.
* @param trans compute matrix transpose or not
* @param m rows of matrix
* @param n cols of matrix
* @param alpha Scaling factor
* @param in_memory BufferIterator of input
* @param ld_in leading dimension of in_matrices
* @param stride_in stride distance between matrices inside batch
* @param matrix_out BufferIterator of output
* @param ld_out leading dimention of out_matrix
* @param stride_out stride distance between matrices inside batch
* @param memory container of input & output matrices
* @param ld_in leading dimension at input
* @param ld_out leading dimention at output
* @param stride stride distance between matrices inside batch
* @param batch_size number of matrices to compute
*/
template <typename sb_handle_t, typename element_t, typename index_t,
Expand All @@ -309,10 +305,33 @@ typename sb_handle_t::event_t _imatcopy_batch(sb_handle_t& sb_handle,
index_t ld_in, index_t ld_out,
index_t stride,
index_t batch_size) {
return internal::_matcopy_batch<true>(sb_handle, trans, m, n, alpha, memory, ld_in,
stride, memory, ld_out, stride, batch_size);
return internal::_matcopy_batch<true>(sb_handle, trans, m, n, alpha, memory,
ld_in, stride, memory, ld_out, stride,
batch_size);
}

/**
* \brief COPY batch of matrices outplace from in_memory to out_memory with
* scaling factor of alpha
*
* @tparam sb_handle_t SB_Handle type
* @tparam element_t Scaling factor type
* @tparam index_t Index type
* @tparam in_t container input type
* @tparam out_t container output type
* @param sb_handle SB_Handle
* @param trans compute matrix transpose or not
* @param m rows of matrix
* @param n cols of matrix
* @param alpha Scaling factor
* @param in_memory input matrix container
* @param ld_in leading dimension of input
* @param stride_in stride distance between matrices inside batch
* @param out_memory output matrix container
* @param ld_out leading dimention of output
* @param stride_out stride distance between matrices inside batch
* @param batch_size number of matrices to compute
*/
template <typename sb_handle_t, typename element_t, typename index_t,
typename in_t, typename out_t>
typename sb_handle_t::event_t _omatcopy_batch(
Expand Down
1 change: 0 additions & 1 deletion src/interface/extension/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#ifndef SYCL_BLAS_TRANSPOSE_DEFAULT_CPU_BACKEND_HPP
#define SYCL_BLAS_TRANSPOSE_DEFAULT_CPU_BACKEND_HPP
#include "interface/extension_interface.h"
#include "interface/transpose_launcher.h"

namespace blas {
namespace extension {
Expand Down
1 change: 1 addition & 0 deletions src/interface/extension/matcopy_batch.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "sb_handle/kernel_constructor.hpp"
#include "sb_handle/sycl_blas_handle.hpp"
#include "operations/extension/matcopy_batch.hpp"
#include "operations/extension/transpose.hpp"

namespace blas {
namespace extension {
Expand Down
10 changes: 4 additions & 6 deletions src/interface/extension_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,10 @@ typename sb_handle_t::event_t _matcopy_batch_impl(
in_t in_memory, index_t ld_in, index_t in_stride, out_t out_memory,
index_t ld_out, index_t out_stride, index_t batch_size) {
auto in_view = make_matrix_view<col_major>(in_memory, m, n, ld_in);
auto out_view =
make_matrix_view<col_major>(out_memory, m, n, ld_out);
auto copy_batch_tree =
make_matcopy_batch<false, TileSize, TilePerWG>(
out_view, in_view, in_view, alpha, 0, m, n, ld_out, ld_in, 1,
out_stride, in_stride, 1, batch_size);
auto out_view = make_matrix_view<col_major>(out_memory, m, n, ld_out);
auto copy_batch_tree = make_matcopy_batch<false, TileSize, TilePerWG>(
out_view, in_view, in_view, alpha, 0, m, n, ld_out, ld_in, 1, out_stride,
in_stride, 1, batch_size);
constexpr index_t local_size = TileSize * TilePerWG;
const index_t tile_per_matrix =
(((m - 1) / TileSize) + 1) * (((n - 1) / TileSize) + 1);
Expand Down
4 changes: 2 additions & 2 deletions test/unittest/extension/extension_reference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ namespace reference_blas {
*/
template <typename index_t, typename scalar_t>
void ext_omatcopy(char trans, const index_t m, const index_t n,
const scalar_t alpha, scalar_t* A,
const index_t lda, scalar_t* B, index_t ldb) {
const scalar_t alpha, scalar_t* A, const index_t lda,
scalar_t* B, index_t ldb) {
if (trans != 't') {
for (index_t j = 0; j < n; j++) {
for (index_t i = 0; i < m; i++) {
Expand Down

0 comments on commit e865638

Please sign in to comment.