Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wavefield preconditioner #584

Draft
wants to merge 23 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 53 additions & 6 deletions ptypy/accelerate/base/engines/ML_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def _get_smooth_gradient(self, data, sigma):

def _replace_ob_grad(self):
new_ob_grad = self.ob_grad_new

# Wavefield preconditioner for the object
if self.p.wavefield_precond:
for name, s in new_ob_grad.storages.items():
s.data /= np.sqrt(self.ob_fln.storages[name].data
+ self.p.wavefield_delta_object)

# Smoothing preconditioner
if self.smooth_gradient:
self.smooth_gradient.sigma *= (1. - self.p.smooth_gradient_decay)
Expand All @@ -175,14 +182,21 @@ def _replace_ob_grad(self):

def _replace_pr_grad(self):
new_pr_grad = self.pr_grad_new
# probe support

# Probe support
if self.p.probe_update_start <= self.curiter:
# Apply probe support if needed
for name, s in new_pr_grad.storages.items():
self.support_constraint(s)
else:
new_pr_grad.fill(0.)

# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in new_pr_grad.storages.items():
s.data /= np.sqrt(self.pr_fln.storages[name].data
+ self.p.wavefield_delta_probe)

norm = Cnorm2(new_pr_grad)
dot = np.real(Cdot(new_pr_grad, self.pr_grad))
self.pr_grad << new_pr_grad
Expand Down Expand Up @@ -240,10 +254,20 @@ def engine_iterate(self, num=1):
self.cn2_pr_grad = cn2_new_pr_grad

dt = self.ptycho.FType

# 3. Next conjugate
self.ob_h *= dt(bt / self.tmin)

# Smoothing preconditioner
# Wavefield preconditioner for the object (with and without smoothing preconditioner)
if self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
if self.smooth_gradient:
s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data
/ np.sqrt(self.ob_fln.storages[name].data + self.p.wavefield_delta_object)
, self.smooth_gradient.sigma)
else:
s.data[:] -= (self.ob_grad.storages[name].data
/ np.sqrt(self.ob_fln.storages[name].data + self.p.wavefield_delta_object))
# Smoothing preconditioner for the object
if self.smooth_gradient:
for name, s in self.ob_h.storages.items():
s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma)
Expand All @@ -252,7 +276,13 @@ def engine_iterate(self, num=1):

self.pr_h *= dt(bt / self.tmin)
self.pr_grad *= dt(self.scale_p_o)
self.pr_h -= self.pr_grad
# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in self.pr_h.storages.items():
s.data[:] -= (self.pr_grad.storages[name].data
/ np.sqrt(self.pr_fln.storages[name].data + self.p.wavefield_delta_probe))
else:
self.pr_h -= self.pr_grad

# In principle, the way things are now programmed this part
# could be iterated over in a real Newton-Raphson style.
Expand Down Expand Up @@ -416,6 +446,11 @@ def new_grad(self):
pr_grad = self.engine.pr_grad_new
ob_grad << 0.
pr_grad << 0.
if self.engine.p.wavefield_precond:
ob_fln = self.engine.ob_fln
pr_fln = self.engine.pr_fln
ob_fln << 0.
pr_fln << 0.

# We need an array for MPI
LL = np.array([0.])
Expand Down Expand Up @@ -450,6 +485,11 @@ def new_grad(self):
prg = pr_grad.S[pID].data
I = prep.I

# local references for wavefield precond
if self.engine.p.wavefield_precond:
obf = ob_fln.S[oID].data
prf = pr_fln.S[pID].data

# make propagated exit (to buffer)
AWK.build_aux_no_ex(aux, addr, ob, pr, add=False)

Expand All @@ -465,8 +505,12 @@ def new_grad(self):
GDK.error_reduce(addr, err_phot)
aux[:] = BW(aux)

POK.ob_update_ML(addr, obg, pr, aux)
POK.pr_update_ML(addr, prg, ob, aux)
if self.engine.p.wavefield_precond:
POK.ob_update_ML_wavefield(addr, obg, obf, pr, aux)
POK.pr_update_ML_wavefield(addr, prg, prf, ob, aux)
else:
POK.ob_update_ML(addr, obg, pr, aux)
POK.pr_update_ML(addr, prg, ob, aux)

for dID, prep in self.engine.diff_info.items():
err_phot = prep.err_phot
Expand All @@ -480,6 +524,9 @@ def new_grad(self):
# MPI reduction of gradients
ob_grad.allreduce()
pr_grad.allreduce()
if self.engine.p.wavefield_precond:
ob_fln.allreduce()
pr_fln.allreduce()
parallel.allreduce(LL)

# Object regularizer
Expand Down
26 changes: 26 additions & 0 deletions ptypy/accelerate/base/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,32 @@ def pr_update_ML(self, addr, pr, ob, ex, fac=2.0):
ex[exc[0], exc[1]:exc[1] + rows, exc[2]:exc[2] + cols] * fac
return

def ob_update_ML_wavefield(self, addr, ob, obf, pr, ex, fac=2.0):
    """
    Accumulate the ML object gradient and the probe fluence in one pass.

    For every address entry, ``conj(pr) * ex * fac`` is added in place to
    the addressed window of `ob`, and ``abs2(pr)`` of the same probe
    window is added in place to the same window of `obf`.

    Parameters
    ----------
    addr : array
        4-dimensional address array. Its first two axes are flattened and
        every entry is unpacked into five coordinate rows; only the first
        three (used to index `pr`, `ob`/`obf` and `ex`) are read here.
    ob : array
        Object gradient accumulator, modified in place.
    obf : array
        Fluence accumulator, indexed exactly like `ob`, modified in place.
    pr : array
        Probe array (read only).
    ex : array
        Buffer whose last two axes define the window size (rows, cols).
    fac : float
        Scale applied to the gradient contribution only; the fluence is
        accumulated unscaled.
    """
    shape = addr.shape
    entries = addr.reshape(shape[0] * shape[1], shape[2], shape[3])
    height, width = ex.shape[-2:]
    for prc, obc, exc, _mac, _dic in entries:
        # Index tuples for the three windows touched by this entry.
        ob_win = (obc[0],
                  slice(obc[1], obc[1] + height),
                  slice(obc[2], obc[2] + width))
        pr_win = (prc[0],
                  slice(prc[1], prc[1] + height),
                  slice(prc[2], prc[2] + width))
        ex_win = (exc[0],
                  slice(exc[1], exc[1] + height),
                  slice(exc[2], exc[2] + width))
        probe_view = pr[pr_win]
        ob[ob_win] += probe_view.conj() * ex[ex_win] * fac
        obf[ob_win] += abs2(probe_view)
    return

def pr_update_ML_wavefield(self, addr, pr, prf, ob, ex, fac=2.0):
    """
    Accumulate the ML probe gradient and the object fluence in one pass.

    For every address entry, ``conj(ob) * ex * fac`` is added in place to
    the addressed window of `pr`, and ``abs2(ob)`` of the same object
    window is added in place to the same window of `prf`.

    Parameters
    ----------
    addr : array
        4-dimensional address array. Its first two axes are flattened and
        every entry is unpacked into five coordinate rows; only the first
        three (used to index `pr`/`prf`, `ob` and `ex`) are read here.
    pr : array
        Probe gradient accumulator, modified in place.
    prf : array
        Fluence accumulator, indexed exactly like `pr`, modified in place.
    ob : array
        Object array (read only).
    ex : array
        Buffer whose last two axes define the window size (rows, cols).
    fac : float
        Scale applied to the gradient contribution only; the fluence is
        accumulated unscaled.
    """
    shape = addr.shape
    entries = addr.reshape(shape[0] * shape[1], shape[2], shape[3])
    height, width = ex.shape[-2:]
    for prc, obc, exc, _mac, _dic in entries:
        # Index tuples for the three windows touched by this entry.
        pr_win = (prc[0],
                  slice(prc[1], prc[1] + height),
                  slice(prc[2], prc[2] + width))
        ob_win = (obc[0],
                  slice(obc[1], obc[1] + height),
                  slice(obc[2], obc[2] + width))
        ex_win = (exc[0],
                  slice(exc[1], exc[1] + height),
                  slice(exc[2], exc[2] + width))
        object_view = ob[ob_win]
        pr[pr_win] += object_view.conj() * ex[ex_win] * fac
        prf[pr_win] += abs2(object_view)
    return

def ob_update_local(self, addr, ob, pr, ex, aux, prn, a=0., b=1.):
sh = addr.shape
flat_addr = addr.reshape(sh[0] * sh[1], sh[2], sh[3])
Expand Down
123 changes: 123 additions & 0 deletions ptypy/accelerate/cuda_common/ob_update2_ML_wavefield.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/** ob_update2_ML_wavefield.
*
* Data types:
* - IN_TYPE: the data type for the inputs (float or double)
* - OUT_TYPE: the data type for the outputs (float or double)
* - MATH_TYPE: the data type used for computation
* - ACC_TYPE: accumulator for the ob field
*
* NOTE: This version of ob_update goes over all tiles that need to be accumulated
* in a single thread block. This avoids the global atomic additions that
* ob_update_ML_wavefield.cu relies on.
* It requires per-thread local arrays of NUM_MODES size to store the local updates.
* GPU registers per thread are limited (255 32bit registers on V100),
* and at some point the registers will spill into local memory (which resides
* in device memory) and the kernel will get considerably slower.
*/

#include "common.cuh"

#define pr_dlayer(k) addr[(k)]
#define ex_dlayer(k) addr[6 * num_pods + (k)]
#define obj_dlayer(k) addr[3 * num_pods + (k)]
#define obj_roi_row(k) addr[4 * num_pods + (k)]
#define obj_roi_column(k) addr[5 * num_pods + (k)]

/**
 * Accumulate the ML object gradient and the matching probe fluence,
 * one object pixel (y, z) per thread.
 *
 * Each thread keeps per-mode running sums for the gradient (ob) and the
 * fluence (of), loops over all num_pods address entries, and adds a
 * contribution whenever its pixel lies inside the pr_sh x pr_sh window of
 * that entry. The sums are written back once at the end, so no atomic
 * operations are used in this kernel.
 *
 * Parameters:
 * - pr_sh:       side length of the window; used for both rows and columns
 *                when indexing pr_g and ex_g
 * - ob_modes:    only used in the bounds assertions (must be <= NUM_MODES)
 * - num_pods:    number of address entries to process
 * - ob_sh_rows,
 *   ob_sh_cols:  rows / columns of one object layer
 * - pr_modes:    only used in the bounds assertion on the probe index
 * - ob_g:        object gradient, read at start and overwritten at the end
 * - pr_g:        probe input
 * - ex_g:        second complex input (exit-wave buffer, judging by the
 *                ex_dlayer macro name), indexed with the same pr_sh strides
 * - ob_f:        real-valued fluence accumulator, indexed exactly like ob_g
 * - addr:        address table, read through the *_dlayer / obj_roi_* macros
 *                defined above this kernel
 * - fac_:        scale factor applied to conj(probe) * ex
 */
extern "C" __global__ void ob_update2_ML_wavefield(int pr_sh,
int ob_modes,
int num_pods,
int ob_sh_rows,
int ob_sh_cols,
int pr_modes,
complex<OUT_TYPE>* ob_g,
const complex<IN_TYPE>* __restrict__ pr_g,
const complex<IN_TYPE>* __restrict__ ex_g,
OUT_TYPE* ob_f,
const int* addr,
IN_TYPE fac_)
{
// (y, z) is the object pixel owned by this thread; (dy, dz) is the layer size.
int y = blockIdx.y * BDIM_Y + threadIdx.y;
int dy = ob_sh_rows;
int z = blockIdx.x * BDIM_X + threadIdx.x;
int dz = ob_sh_cols;
MATH_TYPE fac = fac_;
// Per-thread accumulators, one slot per object mode.
complex<ACC_TYPE> ob[NUM_MODES];
ACC_TYPE of[NUM_MODES];

// Linear thread index within the block, used to stage addresses below.
int txy = threadIdx.y * BDIM_X + threadIdx.x;
assert(ob_modes <= NUM_MODES);

// Seed the accumulators with the current contents of ob_g / ob_f.
if (y < dy && z < dz)
{
#pragma unroll
for (int i = 0; i < NUM_MODES; ++i)
{
auto idx = i * dy * dz + y * dz + z;
assert(idx < ob_modes * ob_sh_rows * ob_sh_cols);
ob[i] = ob_g[idx];
of[i] = ob_f[idx];
}
}

// Shared staging buffer: 5 ints per pod (probe layer, ex layer, object layer,
// object row offset, object column offset), one pod per thread of the block.
__shared__ int addresses[BDIM_X * BDIM_Y * 5];

// Process the pods in batches of BDIM_X * BDIM_Y.
for (int p = 0; p < num_pods; p += BDIM_X * BDIM_Y)
{
// mi = number of pods in this batch (the last batch may be shorter).
int mi = BDIM_X * BDIM_Y;
if (mi > num_pods - p)
mi = num_pods - p;

// Wait until every thread has finished reading the previous batch
// before the shared buffer is overwritten.
if (p > 0)
__syncthreads();

// Each thread copies the address fields of one pod into shared memory.
if (txy < mi)
{
assert(p + txy < num_pods);
assert(txy < BDIM_X * BDIM_Y);
addresses[txy * 5 + 0] = pr_dlayer(p + txy);
addresses[txy * 5 + 1] = ex_dlayer(p + txy);
addresses[txy * 5 + 2] = obj_dlayer(p + txy);
assert(obj_dlayer(p + txy) < NUM_MODES);
assert(addresses[txy * 5 + 2] < NUM_MODES);
addresses[txy * 5 + 3] = obj_roi_row(p + txy);
addresses[txy * 5 + 4] = obj_roi_column(p + txy);
}

// Make the staged addresses visible to the whole block.
__syncthreads();

// Threads outside the object have nothing to accumulate, but they only
// skip after the barrier above so they still take part in staging and
// in the barriers of the following batches.
if (y >= dy || z >= dz)
continue;

#pragma unroll 4
for (int i = 0; i < mi; ++i)
{
int* ad = addresses + i * 5;
// Position of this thread's pixel relative to the pod's object window.
int v1 = y - ad[3];
int v2 = z - ad[4];
// Only contribute if the pixel falls inside the pr_sh x pr_sh window.
if (v1 >= 0 && v1 < pr_sh && v2 >= 0 && v2 < pr_sh)
{
auto pridx = ad[0] * pr_sh * pr_sh + v1 * pr_sh + v2;
assert(pridx < pr_modes * pr_sh * pr_sh);
complex<MATH_TYPE> pr = pr_g[pridx];
// idx selects the object mode (accumulator slot) this pod updates.
int idx = ad[2];
assert(idx < NUM_MODES);
auto cpr = conj(pr);
auto exidx = ad[1] * pr_sh * pr_sh + v1 * pr_sh + v2;
complex<MATH_TYPE> ex_val = ex_g[exidx];
// Gradient contribution: conj(probe) * ex * fac.
complex<ACC_TYPE> add_val = cpr * ex_val * fac;
ob[idx] += add_val;
// Fluence contribution: real part of conj(probe) * probe, i.e. |probe|^2.
complex<MATH_TYPE> abs2_val = cpr * pr;
ACC_TYPE add_val2 = abs2_val.real();
of[idx] += add_val2;
}
}
}

// Write the accumulated gradient and fluence back for every mode.
if (y < dy && z < dz)
{
for (int i = 0; i < NUM_MODES; ++i)
{
ob_g[i * dy * dz + y * dz + z] = ob[i];
ob_f[i * dy * dz + y * dz + z] = of[i];
}
}
}
72 changes: 72 additions & 0 deletions ptypy/accelerate/cuda_common/ob_update_ML_wavefield.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/** ob_update_ML_wavefield.
*
* Data types:
* - IN_TYPE: the data type for the inputs (float or double)
* - OUT_TYPE: the data type for the outputs (float or double)
* - MATH_TYPE: the data type used for computation
*/

#include "common.cuh"

/**
 * Atomic addition for complex values, done as two independent scalar
 * atomic adds on the real and imaginary parts.
 *
 * The target is reinterpreted as a pair of consecutive T values
 * (real part first, imaginary part second). The two adds are individually
 * atomic; the complex update as a whole is not a single atomic operation.
 */
template <class T>
__device__ inline void atomicAdd(complex<T>* x, const complex<T>& y)
{
  T* const parts = reinterpret_cast<T*>(x);
  atomicAdd(&parts[0], y.real());
  atomicAdd(&parts[1], y.imag());
}

extern "C"
{
  /**
   * Accumulate the ML object gradient and the matching probe fluence using
   * atomic additions, one thread block per address record.
   *
   * For the record selected by blockIdx.x, every (b, c) position of the
   * B x C window adds conj(probe) * exit_wave * fac to the object gradient
   * and |probe|^2 to the fluence, both at the same object position.
   *
   * Parameters:
   * - exit_wave:  complex input; layer stride is B * C, row stride is C
   * - A:          not used in the kernel body
   * - B, C:       rows / columns of the window that is looped over
   * - probe:      complex input; layer stride is E * F, row stride is F
   * - D:          not used in the kernel body
   * - E, F:       rows / columns of one probe layer
   * - obj:        object gradient, updated atomically in place
   * - G, H, I:    layers / rows / columns of the object (G is only used in
   *               the bounds assertion)
   * - obj_fln:    real-valued fluence accumulator, indexed exactly like obj,
   *               updated atomically in place
   * - addr:       address table with 15 ints per record; the probe address
   *               starts at offset 0, the object address at offset 3 and the
   *               exit wave address at offset 6 of each record
   * - fac_:       scale factor applied to the gradient contribution only
   */
  __global__ void ob_update_ML_wavefield(const complex<IN_TYPE>* __restrict__ exit_wave,
                                         int A,
                                         int B,
                                         int C,
                                         const complex<IN_TYPE>* __restrict__ probe,
                                         int D,
                                         int E,
                                         int F,
                                         complex<OUT_TYPE>* obj,
                                         int G,
                                         int H,
                                         int I,
                                         OUT_TYPE* obj_fln,
                                         const int* __restrict__ addr,
                                         IN_TYPE fac_)
  {
    const int bid = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int addr_stride = 15;
    MATH_TYPE fac = fac_;

    const int* oa = addr + 3 + bid * addr_stride;
    const int* pa = addr + bid * addr_stride;
    const int* ea = addr + 6 + bid * addr_stride;

    // Offset of the addressed object window; shared by the gradient and the
    // fluence array because both use the object's indexing.
    const int ob_offset = oa[0] * H * I + oa[1] * I + oa[2];

    probe += pa[0] * E * F + pa[1] * F + pa[2];
    obj += ob_offset;
    // The fluence must be accumulated in the same window as the gradient;
    // without this offset every record would write to the array origin.
    obj_fln += ob_offset;

    assert(ob_offset + (B - 1) * I + C - 1 < G * H * I);

    exit_wave += ea[0] * B * C;

    // Threads of the block stride over the B x C window.
    for (int b = ty; b < B; b += blockDim.y)
    {
      for (int c = tx; c < C; c += blockDim.x)
      {
        complex<MATH_TYPE> probe_val = probe[b * F + c];
        complex<MATH_TYPE> exit_val = exit_wave[b * C + c];
        complex<MATH_TYPE> add_val_m = conj(probe_val) * exit_val * fac;
        complex<OUT_TYPE> add_val = add_val_m;

        // Windows of different records may overlap, hence the atomics.
        atomicAdd(&obj[b * I + c], add_val);

        // |probe|^2 as the real part of conj(probe) * probe.
        complex<MATH_TYPE> abs2_val = conj(probe_val) * probe_val;
        OUT_TYPE add_val2 = abs2_val.real();

        atomicAdd(&obj_fln[b * I + c], add_val2);
      }
    }
  }
}
Loading
Loading