Skip to content

Commit

Permalink
Start serializing wavefield precond
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Nov 25, 2024
1 parent f6b3437 commit 4a74199
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions ptypy/accelerate/base/engines/ML_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ def _get_smooth_gradient(self, data, sigma):

def _replace_ob_grad(self):
new_ob_grad = self.ob_grad_new

# Wavefield preconditioner for the object
if self.p.wavefield_precond:
self.ob_fln += self.p.wavefield_delta_object
for name, s in new_ob_grad.storages.items():
new_ob_grad.storages[name].data /= np.sqrt(self.ob_fln.storages[name].data)

# Smoothing preconditioner
if self.smooth_gradient:
self.smooth_gradient.sigma *= (1. - self.p.smooth_gradient_decay)
Expand All @@ -175,14 +182,21 @@ def _replace_ob_grad(self):

def _replace_pr_grad(self):
new_pr_grad = self.pr_grad_new
# probe support

# Probe support
if self.p.probe_update_start <= self.curiter:
# Apply probe support if needed
for name, s in new_pr_grad.storages.items():
self.support_constraint(s)
else:
new_pr_grad.fill(0.)

# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
self.pr_fln += self.p.wavefield_delta_probe
for name, s in new_pr_grad.storages.items():
new_pr_grad.storages[name].data /= np.sqrt(self.pr_fln.storages[name].data)

norm = Cnorm2(new_pr_grad)
dot = np.real(Cdot(new_pr_grad, self.pr_grad))
self.pr_grad << new_pr_grad
Expand Down Expand Up @@ -240,19 +254,30 @@ def engine_iterate(self, num=1):
self.cn2_pr_grad = cn2_new_pr_grad

dt = self.ptycho.FType

# 3. Next conjugate
self.ob_h *= dt(bt / self.tmin)

# Smoothing preconditioner
if self.smooth_gradient:
# Smoothing and wavefield preconditioners for the object
if self.smooth_gradient and self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data), self.smooth_gradient.sigma)
elif self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data)
elif self.smooth_gradient:
for name, s in self.ob_h.storages.items():
s.data[:] -= self._get_smooth_gradient(self.ob_grad.storages[name].data, self.smooth_gradient.sigma)
else:
self.ob_h -= self.ob_grad

self.pr_h *= dt(bt / self.tmin)
self.pr_grad *= dt(self.scale_p_o)
self.pr_h -= self.pr_grad
# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in self.pr_h.storages.items():
s.data[:] -= self.pr_grad.storages[name].data / np.sqrt(self.pr_fln.storages[name].data)
else:
self.pr_h -= self.pr_grad

# In principle, the way things are now programmed this part
# could be iterated over in a real Newton-Raphson style.
Expand Down Expand Up @@ -416,6 +441,11 @@ def new_grad(self):
pr_grad = self.engine.pr_grad_new
ob_grad << 0.
pr_grad << 0.
if self.engine.p.wavefield_precond:
ob_fln = self.engine.ob_fln
pr_fln = self.engine.pr_fln
ob_fln << 0.
pr_fln << 0.

# We need an array for MPI
LL = np.array([0.])
Expand Down Expand Up @@ -465,6 +495,12 @@ def new_grad(self):
GDK.error_reduce(addr, err_phot)
aux[:] = BW(aux)

#FIXME: need modified POK kernel to compute fluence maps
# if self.engine.p.wavefield_precond:
# ob_fln[pod.ob_view] += u.abs2(pod.probe)
# pr_fln[pod.pr_view] += u.abs2(pod.object)
#END FIXME

POK.ob_update_ML(addr, obg, pr, aux)
POK.pr_update_ML(addr, prg, ob, aux)

Expand All @@ -480,6 +516,9 @@ def new_grad(self):
# MPI reduction of gradients
ob_grad.allreduce()
pr_grad.allreduce()
if self.engine.p.wavefield_precond:
ob_fln.allreduce()
pr_fln.allreduce()
parallel.allreduce(LL)

# Object regularizer
Expand Down

0 comments on commit 4a74199

Please sign in to comment.