Skip to content

Commit

Permalink
Fix pycuda.cumath import
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Jan 8, 2025
1 parent 2f34caa commit 5186fd9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ptypy/accelerate/cuda_pycuda/engines/ML_pycuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np
from pycuda import gpuarray
import pycuda.driver as cuda
import pycuda.cumath
import pycuda.cumath as cm

from ptypy.engines import register
from ptypy.accelerate.base.engines.ML_serial import ML_serial, BaseModelSerial
Expand Down Expand Up @@ -279,7 +279,7 @@ def _replace_ob_grad(self):
# Wavefield preconditioner for the object
if self.p.wavefield_precond:
for name, s in new_ob_grad.storages.items():
s.gpu /= sqrt(self.ob_fln.storages[name].gpu
s.gpu /= cm.sqrt(self.ob_fln.storages[name].gpu
+ self.p.wavefield_delta_object)

# Smoothing preconditioner
Expand All @@ -304,7 +304,7 @@ def _replace_pr_grad(self):
# Wavefield preconditioner for the probe
if self.p.wavefield_precond:
for name, s in new_pr_grad.storages.items():
s.gpu /= sqrt(self.pr_fln.storages[name].gpu
s.gpu /= cm.sqrt(self.pr_fln.storages[name].gpu
+ self.p.wavefield_delta_probe)

return self._replace_grad(self.pr_grad , new_pr_grad)
Expand Down

0 comments on commit 5186fd9

Please sign in to comment.