Skip to content

Commit

Permalink
Take sqrt of the sum
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Sep 18, 2024
1 parent 8701960 commit f6b3437
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions ptypy/engines/ML.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ def engine_iterate(self, num=1):
if self.p.wavefield_precond:
self.ob_fln += self.p.wavefield_delta_object
self.pr_fln += self.p.wavefield_delta_probe
new_ob_grad /= self.ob_fln
new_pr_grad /= self.pr_fln
for name, s in new_ob_grad.storages.items():
new_ob_grad.storages[name].data /= np.sqrt(self.ob_fln.storages[name].data)
new_pr_grad.storages[name].data /= np.sqrt(self.pr_fln.storages[name].data)

# Smoothing preconditioner
if self.smooth_gradient:
Expand Down Expand Up @@ -314,10 +315,10 @@ def engine_iterate(self, num=1):
# Smoothing and wavefield preconditioners for the object
if self.smooth_gradient and self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data / self.ob_fln.storages[name].data)
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data))
elif self.p.wavefield_precond:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.ob_grad.storages[name].data / self.ob_fln.storages[name].data
s.data[:] -= self.ob_grad.storages[name].data / np.sqrt(self.ob_fln.storages[name].data)
elif self.smooth_gradient:
for name, s in self.ob_h.storages.items():
s.data[:] -= self.smooth_gradient(self.ob_grad.storages[name].data)
Expand All @@ -329,7 +330,7 @@ def engine_iterate(self, num=1):
# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in self.pr_h.storages.items():
s.data[:] -= self.pr_grad.storages[name].data / self.pr_fln.storages[name].data
s.data[:] -= self.pr_grad.storages[name].data / np.sqrt(self.pr_fln.storages[name].data)
else:
self.pr_h -= self.pr_grad

Expand Down Expand Up @@ -587,8 +588,8 @@ def new_grad(self):

# Compute fluence maps for object and probe
if self.p.wavefield_precond:
self.ob_fln[pod.ob_view] += np.sqrt(u.abs2(pod.probe))
self.pr_fln[pod.pr_view] += np.sqrt(u.abs2(pod.object))
self.ob_fln[pod.ob_view] += u.abs2(pod.probe)
self.pr_fln[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DI.shape), 0])
Expand Down Expand Up @@ -841,8 +842,8 @@ def new_grad(self):

# Compute fluence maps for object and probe
if self.p.wavefield_precond:
self.ob_fln[pod.ob_view] += np.sqrt(u.abs2(pod.probe))
self.pr_fln[pod.pr_view] += np.sqrt(u.abs2(pod.object))
self.ob_fln[pod.ob_view] += u.abs2(pod.probe)
self.pr_fln[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DI.shape), 0])
Expand Down Expand Up @@ -1108,8 +1109,8 @@ def new_grad(self):

# Compute fluence maps for object and probe
if self.p.wavefield_precond:
self.ob_fln[pod.ob_view] += np.sqrt(u.abs2(pod.probe))
self.pr_fln[pod.pr_view] += np.sqrt(u.abs2(pod.object))
self.ob_fln[pod.ob_view] += u.abs2(pod.probe)
self.pr_fln[pod.pr_view] += u.abs2(pod.object)

diff_view.error = LLL
error_dct[dname] = np.array([0, LLL / np.prod(DA.shape), 0])
Expand Down

0 comments on commit f6b3437

Please sign in to comment.