diff --git a/src/cuda/kernels/complex.f90 b/src/cuda/kernels/complex.f90 index 9007625d..0992586d 100644 --- a/src/cuda/kernels/complex.f90 +++ b/src/cuda/kernels/complex.f90 @@ -208,4 +208,88 @@ attributes(global) subroutine reorder_cmplx_z2y_T(u_y, u_z, nx, nz) end subroutine reorder_cmplx_z2y_T + attributes(global) subroutine reshapeDSF(uout, uin) + implicit none + + real(dp), device, intent(out), dimension(:, :, :) :: uout + real(dp), device, intent(in), dimension(:, :, :) :: uin + + real(dp), shared :: tile(SZ + 1, SZ) + + integer :: i, j, b_i, b + + i = threadIdx%x; j = threadIdx%y + b_i = blockIdx%x; b = blockIdx%y + + tile(i, j) = uin(i, j + (b_i - 1)*SZ, b) + + call syncthreads() + + uout(i + (b_i - 1)*SZ, j, b) = tile(j, i) + + end subroutine reshapeDSF + + attributes(global) subroutine reshapeDSB(uout, uin) + implicit none + + real(dp), device, intent(out), dimension(:, :, :) :: uout + real(dp), device, intent(in), dimension(:, :, :) :: uin + + real(dp), shared :: tile(SZ + 1, SZ) + + integer :: i, j, b_i, b + + i = threadIdx%x; j = threadIdx%y + b_i = blockIdx%x; b = blockIdx%y + + tile(i, j) = uin(i + (b_i - 1)*SZ, j, b) + + call syncthreads() + + uout(i, j + (b_i - 1)*SZ, b) = tile(j, i) + + end subroutine reshapeDSB + + attributes(global) subroutine reshapeCDSF(uout, uin) + implicit none + + complex(dp), device, intent(out), dimension(:, :, :) :: uout + complex(dp), device, intent(in), dimension(:, :, :) :: uin + + complex(dp), shared :: tile(SZ + 1, SZ) + + integer :: i, j, b_i, b + + i = threadIdx%x; j = threadIdx%y + b_i = blockIdx%x; b = blockIdx%y + + tile(i, j) = uin(i, j + (b_i - 1)*SZ, b) + + call syncthreads() + + uout(i + (b_i - 1)*SZ, j, b) = tile(j, i) + + end subroutine reshapeCDSF + + attributes(global) subroutine reshapeCDSB(uout, uin) + implicit none + + complex(dp), device, intent(out), dimension(:, :, :) :: uout + complex(dp), device, intent(in), dimension(:, :, :) :: uin + + complex(dp), shared :: tile(SZ + 1, SZ) + + integer :: i, j, b_i, b + + i = threadIdx%x; j = threadIdx%y + b_i = blockIdx%x; b = blockIdx%y + + tile(i, j) = uin(i + (b_i - 1)*SZ, j, b) + + call syncthreads() + + uout(i, j + (b_i - 1)*SZ, b) = tile(j, i) + + end subroutine reshapeCDSB + end module m_cuda_complex