From a341657c00b9d3372b9b47f446158b122a2a5b71 Mon Sep 17 00:00:00 2001 From: Jerome Kieffer Date: Wed, 6 Nov 2024 17:45:07 +0100 Subject: [PATCH] Provide test for the Blelloch algorithm --- src/pyFAI/opencl/test/test_collective.py | 7 +++---- src/pyFAI/resources/openCL/collective/scan.cl | 21 +++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/pyFAI/opencl/test/test_collective.py b/src/pyFAI/opencl/test/test_collective.py index d9b74a2be..688f48d7e 100644 --- a/src/pyFAI/opencl/test/test_collective.py +++ b/src/pyFAI/opencl/test/test_collective.py @@ -33,7 +33,7 @@ __contact__ = "jerome.kieffer@esrf.eu" __license__ = "MIT" __copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France" -__date__ = "30/10/2024" +__date__ = "06/11/2024" import logging import numpy @@ -172,7 +172,6 @@ def test_Hillis_Steele(self): self.assertTrue(good, "Cumsum calculation is correct for WG=%s" % wg) @unittest.skipUnless(ocl, "pyopencl is missing") - @unittest.skip("Fix me") def test_Blelloch(self): """ tests the Blelloch scan function @@ -192,8 +191,8 @@ def test_Blelloch(self): logger.error("Error %s on WG=%s: Hillis_Steele", error, wg) break else: - res = scan_d.get().reshape((-1, wg)) - ref = numpy.array([numpy.cumsum(i) for i in self.data.reshape((-1, wg))]) + res = scan_d.get().reshape((-1, 2*wg)) + ref = numpy.array([numpy.cumsum(i) for i in self.data.reshape((-1, 2*wg))]) good = numpy.allclose(res, ref) if not good: print(ref) diff --git a/src/pyFAI/resources/openCL/collective/scan.cl b/src/pyFAI/resources/openCL/collective/scan.cl index 74cb48343..74b626d7d 100644 --- a/src/pyFAI/resources/openCL/collective/scan.cl +++ b/src/pyFAI/resources/openCL/collective/scan.cl @@ -62,6 +62,7 @@ kernel void test_cumsum(global float* input, * Implements Blelloch algorithm * https://en.wikipedia.org/wiki/Prefix_sum#cite_ref-offman_10-0 * + * One workgroup calculates the cumsum in an array of twice its size! */ @@ -70,34 +71,32 @@ void static inline blelloch_scan_float(local float *shared) int ws = get_local_size(0); int lid = get_local_id(0); int dp = 1; + int w; - for(int s = ws>>1; s > 0; s >>= 1) + for(int s = ws; s > 0; s >>= 1) { barrier(CLK_LOCAL_MEM_FENCE); if(lid < s) { int i = dp*(2*lid+1)-1; - int j = dp*(2*lid+2)-1; + int j = i + dp; shared[j] += shared[i]; } dp <<= 1; } - if(lid == 0) - shared[ws-1] = 0; - - for(int s = 1; s < ws; s <<= 1) + dp >>= 1; + for(int s = 1; s < ws; s=((s+1)<<1)-1) { + w = dp; dp >>= 1; + barrier(CLK_LOCAL_MEM_FENCE); if(lid < s) { - int i = dp*(2*lid+1)-1; - int j = dp*(2*lid+2)-1; - - float t = shared[j]; + int i = (lid+1)*w - 1; + int j = i + dp; shared[j] += shared[i]; - shared[i] = t; } } barrier(CLK_LOCAL_MEM_FENCE);