Skip to content

Commit

Permalink
Cleanup and style fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
semi-h committed Mar 1, 2024
1 parent b982e52 commit 8e1f39b
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions src/cuda/kernels/complex.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ attributes(global) subroutine processfftdiv( &
div_c = aimag(div(i, j, b))/(nx*ny*nz)

! get the indices for x, y, z directions
ix = j; iy = i + (b-1)/(nz/2+1)*SZ; iz = mod(b-1, nz/2+1) + 1
ix = j; iy = i + (b - 1)/(nz/2 + 1)*SZ; iz = mod(b - 1, nz/2 + 1) + 1

! post-process forward
! post-process in z
Expand All @@ -46,16 +46,16 @@ attributes(global) subroutine processfftdiv( &
tmp_c = div_c
div_r = tmp_r*by(iy) + tmp_c*ay(iy)
div_c = tmp_c*by(iy) - tmp_r*ay(iy)
if ( iy > ny/2 + 1 ) div_r = -div_r
if ( iy > ny/2 + 1 ) div_c = -div_c
if (iy > ny/2 + 1) div_r = -div_r
if (iy > ny/2 + 1) div_c = -div_c

! post-process in x
tmp_r = div_r
tmp_c = div_c
div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
div_c = tmp_c*bx(ix) - tmp_r*ax(ix)
if ( ix > nx/2 + 1 ) div_r = -div_r
if ( ix > nx/2 + 1 ) div_c = -div_c
if (ix > nx/2 + 1) div_r = -div_r
if (ix > nx/2 + 1) div_c = -div_c

! Solve Poisson
tmp_r = real(waves(i, j, b), kind=dp)
Expand All @@ -79,16 +79,16 @@ attributes(global) subroutine processfftdiv( &
tmp_c = div_c
div_r = tmp_r*by(iy) + tmp_c*ay(iy)
div_c = tmp_c*by(iy) - tmp_r*ay(iy)
if ( iy > ny/2 + 1 ) div_r = -div_r
if ( iy > ny/2 + 1 ) div_c = -div_c
if (iy > ny/2 + 1) div_r = -div_r
if (iy > ny/2 + 1) div_c = -div_c

! post-process in x
tmp_r = div_r
tmp_c = div_c
div_r = tmp_r*bx(ix) + tmp_c*ax(ix)
div_c =-tmp_c*bx(ix) + tmp_r*ax(ix)
if ( ix > nx/2 + 1 ) div_r = -div_r
if ( ix > nx/2 + 1 ) div_c = -div_c
if (ix > nx/2 + 1) div_r = -div_r
if (ix > nx/2 + 1) div_c = -div_c

! update the entry
div(i, j, b) = cmplx(div_r, div_c, kind=dp)
Expand All @@ -111,7 +111,7 @@ attributes(global) subroutine reorder_cmplx_x2y_T(u_y, u_x, nz)
b_i = blockIdx%x; b_j = blockIdx%y; b_k = blockIdx%z

! copy into shared
tile(i, j) = u_x((b_i - 1)*SZ + j, i, b_k + nz*(b_j-1))
tile(i, j) = u_x((b_i - 1)*SZ + j, i, b_k + nz*(b_j - 1))

call syncthreads()

Expand Down Expand Up @@ -166,7 +166,7 @@ attributes(global) subroutine reorder_cmplx_y2z_T(u_z, u_y, nx, nz)
! copy into shared
if ( j + (b_z - 1)*SZ <= nz ) &
tile(i, j) = u_y(i + (b_y - 1)*SZ, mod(b_x - 1, SZ) + 1, &
j + (b_z - 1)*SZ + ((b_x-1)/SZ)*nz)
j + (b_z - 1)*SZ + ((b_x - 1)/SZ)*nz)

call syncthreads()

Expand All @@ -185,14 +185,11 @@ attributes(global) subroutine reorder_cmplx_z2y_T(u_y, u_z, nx, nz)

complex(dp), shared :: tile(SZ, SZ)

integer :: i, j, k, b_i, b_j, b_k, b_x, b_y, b_z
integer :: i, j, k, b_x, b_y, b_z

i = threadIdx%x
j = threadIdx%y
k = threadIdx%z
b_i = blockIdx%x
b_j = blockIdx%y
b_k = blockIdx%z

b_x = blockIdx%z
b_y = blockIdx%y
Expand All @@ -206,7 +203,7 @@ attributes(global) subroutine reorder_cmplx_z2y_T(u_y, u_z, nx, nz)

! copy into output array from shared
if ( j + (b_z - 1)*SZ <= nz ) &
u_y(i + (b_y - 1)*SZ, mod(b_x - 1, SZ)+1, &
u_y(i + (b_y - 1)*SZ, mod(b_x - 1, SZ) + 1, &
j + (b_z - 1)*SZ + ((b_x - 1)/SZ)*nz) = tile(j, i)

end subroutine reorder_cmplx_z2y_T
Expand Down

0 comments on commit 8e1f39b

Please sign in to comment.