Skip to content

Commit

Permalink
Add reshape subrotines for transposing pencil groups.
Browse files Browse the repository at this point in the history
  • Loading branch information
semi-h committed Mar 1, 2024
1 parent 8e1f39b commit 03a8a34
Showing 1 changed file with 84 additions and 0 deletions.
84 changes: 84 additions & 0 deletions src/cuda/kernels/complex.f90
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,88 @@ attributes(global) subroutine reorder_cmplx_z2y_T(u_y, u_z, nx, nz)

end subroutine reorder_cmplx_z2y_T

attributes(global) subroutine reshapeDSF(uout, uin)
implicit none

real(dp), device, intent(out), dimension(:, :, :) :: uout
real(dp), device, intent(in), dimension(:, :, :) :: uin

real(dp), shared :: tile(SZ + 1, SZ)

integer :: i, j, b_i, b

i = threadIdx%x; j = threadIdx%y
b_i = blockIdx%x; b = blockIdx%y

tile(i, j) = uin(i, j + (b_i - 1)*SZ, b)

call syncthreads()

uout(i + (b_i - 1)*SZ, j, b) = tile(j, i)

end subroutine reshapeDSF

attributes(global) subroutine reshapeDSB(uout, uin)
implicit none

real(dp), device, intent(out), dimension(:, :, :) :: uout
real(dp), device, intent(in), dimension(:, :, :) :: uin

real(dp), shared :: tile(SZ + 1, SZ)

integer :: i, j, b_i, b

i = threadIdx%x; j = threadIdx%y
b_i = blockIdx%x; b = blockIdx%y

tile(i, j) = uin(i + (b_i - 1)*SZ, j, b)

call syncthreads()

uout(i, j + (b_i - 1)*SZ, b) = tile(j, i)

end subroutine reshapeDSB

attributes(global) subroutine reshapeCDSF(uout, uin)
implicit none

complex(dp), device, intent(out), dimension(:, :, :) :: uout
complex(dp), device, intent(in), dimension(:, :, :) :: uin

complex(dp), shared :: tile(SZ + 1, SZ)

integer :: i, j, b_i, b

i = threadIdx%x; j = threadIdx%y
b_i = blockIdx%x; b = blockIdx%y

tile(i, j) = uin(i, j + (b_i - 1)*SZ, b)

call syncthreads()

uout(i + (b_i - 1)*SZ, j, b) = tile(j, i)

end subroutine reshapeCDSF

attributes(global) subroutine reshapeCDSB(uout, uin)
implicit none

complex(dp), device, intent(out), dimension(:, :, :) :: uout
complex(dp), device, intent(in), dimension(:, :, :) :: uin

complex(dp), shared :: tile(SZ + 1, SZ)

integer :: i, j, b_i, b

i = threadIdx%x; j = threadIdx%y
b_i = blockIdx%x; b = blockIdx%y

tile(i, j) = uin(i + (b_i - 1)*SZ, j, b)

call syncthreads()

uout(i, j + (b_i - 1)*SZ, b) = tile(j, i)

end subroutine reshapeCDSB

end module m_cuda_complex

0 comments on commit 03a8a34

Please sign in to comment.