-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2342 from kif/2313_medfilt_ng_opencl
Medfilt in OpenCL
- Loading branch information
Showing
16 changed files
with
1,638 additions
and
57 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
|
||
__author__ = "Jérôme Kieffer" | ||
__contact__ = "[email protected]" | ||
__date__ = "19/11/2024" | ||
__date__ = "05/12/2024" | ||
__status__ = "stable" | ||
__license__ = "MIT" | ||
|
||
|
@@ -859,7 +859,9 @@ cdef class CsrIntegrator(object): | |
for i in range(start, stop): | ||
former_element = element | ||
element = work[i] | ||
if (qmin<=former_element.s0) and (element.s0 <= qmax): | ||
if ((element.s3!=0) and | ||
(((qmin<=former_element.s0) and (element.s0 <= qmax)) or | ||
((qmin>=former_element.s0) and (element.s0 >= qmax)))): #specific case where qmin==qmax | ||
acc_sig = acc_sig + element.s1 | ||
acc_var = acc_var + element.s2 | ||
acc_norm = acc_norm + element.s3 | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
__contact__ = "[email protected]" | ||
__license__ = "MIT" | ||
__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France" | ||
__date__ = "12/11/2024" | ||
__date__ = "21/11/2024" | ||
|
||
import logging | ||
import numpy | ||
|
@@ -250,7 +250,7 @@ def test_sort(self): | |
data_d = pyopencl.array.to_device(self.queue, data) | ||
# print(ref.shape, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), positions) | ||
try: | ||
evt = self.program.test_combsort_float(self.queue, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), | ||
evt = self.program.test_combsort_float(self.queue, (min(wg, self.max_valid_wg), ref.shape[0]), (min(wg, self.max_valid_wg), 1), | ||
data_d.data, | ||
positions_d.data, | ||
pyopencl.LocalMemory(4*min(wg, self.max_valid_wg))) | ||
|
@@ -290,7 +290,7 @@ def test_sort4(self): | |
data_d = pyopencl.array.to_device(self.queue, data) | ||
# print(ref.shape, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), positions) | ||
try: | ||
evt = self.program.test_combsort_float4(self.queue, (ref.shape[0],min(wg, self.max_valid_wg)), (1, min(wg, self.max_valid_wg)), | ||
evt = self.program.test_combsort_float4(self.queue, (min(wg, self.max_valid_wg), ref.shape[0]), (min(wg, self.max_valid_wg),1), | ||
data_d.data, | ||
positions_d.data, | ||
pyopencl.LocalMemory(4*min(wg, self.max_valid_wg))) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ | |
__contact__ = "[email protected]" | ||
__license__ = "MIT" | ||
__copyright__ = "2019-2021 European Synchrotron Radiation Facility, Grenoble, France" | ||
__date__ = "28/06/2022" | ||
__date__ = "06/12/2024" | ||
|
||
import logging | ||
import numpy | ||
|
@@ -43,10 +43,9 @@ | |
if ocl: | ||
import pyopencl.array | ||
from ...test.utilstest import UtilsTest | ||
from silx.opencl.common import _measure_workgroup_size | ||
from ...integrator.azimuthal import AzimuthalIntegrator | ||
from ...method_registry import IntegrationMethod | ||
from scipy.ndimage import gaussian_filter1d | ||
from ...containers import ErrorModel | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
|
@@ -81,10 +80,12 @@ def tearDownClass(cls): | |
cls.ai = None | ||
|
||
@unittest.skipUnless(ocl, "pyopencl is missing") | ||
def integrate_ng(self, block_size=None): | ||
def integrate_ng(self, block_size=None, method_called="integrate_ng", extra=None): | ||
""" | ||
tests the 1d histogram kernel, with variable workgroup size | ||
""" | ||
if extra is None: | ||
extra={} | ||
from ..azim_csr import OCL_CSR_Integrator | ||
data = numpy.ones(self.ai.detector.shape) | ||
npt = 500 | ||
|
@@ -95,14 +96,14 @@ def integrate_ng(self, block_size=None): | |
dim=1, default=None, degradable=False) | ||
|
||
# Retrieve the CSR array | ||
cpu_integrate = self.ai._integrate1d_legacy(data, npt, unit=unit, method=csr_method) | ||
cpu_integrate = self.ai._integrate1d_ng(data, npt, unit=unit, method=csr_method) | ||
r_m = cpu_integrate[0] | ||
csr_engine = list(self.ai.engines.values())[0] | ||
csr = csr_engine.engine.lut | ||
ref = self.ai._integrate1d_ng(data, npt, unit=unit, method=method) | ||
integrator = OCL_CSR_Integrator(csr, data.size, block_size=block_size) | ||
ref = self.ai._integrate1d_ng(data, npt, unit=unit, method=method, error_model="poisson") | ||
integrator = OCL_CSR_Integrator(csr, data.size, block_size=block_size, empty=-1) | ||
solidangle = self.ai.solidAngleArray() | ||
res = integrator.integrate_ng(data, solidangle=solidangle) | ||
res = integrator.__getattribute__(method_called)(data, solidangle=solidangle, error_model=ErrorModel.POISSON, **extra) | ||
# for info, res contains: position intensity error signal variance normalization count | ||
|
||
# Start with smth easy: the position | ||
|
@@ -112,23 +113,32 @@ def integrate_ng(self, block_size=None): | |
if "AMD" in integrator.ctx.devices[0].platform.name: | ||
logger.warning("This test is known to be complicated for AMD-GPU, relax the constrains for them") | ||
else: | ||
self.assertLessEqual(delta.max(), 1, "counts are almost the same") | ||
self.assertEqual(delta.sum(), 0, "as much + and -") | ||
if method_called=="integrate_ng": | ||
self.assertLessEqual(delta.max(), 1, "counts are almost the same") | ||
self.assertEqual(delta.sum(), 0, "as much + and -") | ||
elif method_called=="medfilt": | ||
pix = csr[2][1:]-csr[2][:-1] | ||
self.assertTrue(numpy.allclose(res.count, pix), "all pixels have been counted") | ||
|
||
# Intensities are not that different: | ||
delta = ref.intensity - res.intensity | ||
self.assertLessEqual(abs(delta.max()), 1e-5, "intensity is almost the same") | ||
|
||
# histogram of normalization | ||
ref = self.ai._integrate1d_ng(solidangle, npt, unit=unit, method=method).sum_signal | ||
sig = res.normalization | ||
err = abs((sig - ref).max()) | ||
self.assertLess(err, 5e-4, "normalization content is the same: %s<5e-5" % (err)) | ||
# print(ref.sum_normalization) | ||
# print(res.normalization) | ||
err = abs((res.normalization - ref.sum_normalization)) | ||
# print(err) | ||
self.assertLess(err.max(), 5e-4, "normalization content is the same: %s<5e-5" % (err.max)) | ||
|
||
# histogram of signal | ||
ref = self.ai._integrate1d_ng(data, npt, unit=unit, method=method).sum_signal | ||
sig = res.signal | ||
self.assertLess(abs((sig - ref).sum()), 5e-5, "signal content is the same") | ||
self.assertLess(abs((res.signal - ref.sum_signal)).max(), 5e-5, "signal content is the same") | ||
|
||
# histogram of variance | ||
self.assertLess(abs((res.variance - ref.sum_variance)).max(), 5e-5, "signal content is the same") | ||
|
||
# Intensities are not that different: | ||
delta = ref.intensity - res.intensity | ||
# print(delta) | ||
self.assertLessEqual(abs(delta).max(), 1e-5, "intensity is almost the same") | ||
|
||
|
||
@unittest.skipUnless(ocl, "pyopencl is missing") | ||
def test_integrate_ng(self): | ||
|
@@ -144,6 +154,20 @@ def test_integrate_ng_single(self): | |
""" | ||
self.integrate_ng(block_size=1) | ||
|
||
@unittest.skipUnless(ocl, "pyopencl is missing") | ||
def test_sigma_clip(self): | ||
""" | ||
tests the sigma-clipping kernel, default block size | ||
""" | ||
self.integrate_ng(None, "sigma_clip",{"cutoff":100.0, "cycle":0,}) | ||
|
||
@unittest.skipUnless(ocl, "pyopencl is missing") | ||
def test_medfilt(self): | ||
""" | ||
tests the median filtering kernel, default block size | ||
""" | ||
self.integrate_ng(None, "medfilt", {"quant_min":0, "quant_max":1}) | ||
|
||
|
||
def suite(): | ||
loader = unittest.defaultTestLoader.loadTestsFromTestCase | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.