Skip to content

Commit

Permalink
Provide test for the Blelloch algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
kif committed Nov 6, 2024
1 parent b47a563 commit a341657
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
7 changes: 3 additions & 4 deletions src/pyFAI/opencl/test/test_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
__contact__ = "[email protected]"
__license__ = "MIT"
__copyright__ = "2013 European Synchrotron Radiation Facility, Grenoble, France"
__date__ = "30/10/2024"
__date__ = "06/11/2024"

import logging
import numpy
Expand Down Expand Up @@ -172,7 +172,6 @@ def test_Hillis_Steele(self):
self.assertTrue(good, "Cumsum calculation is correct for WG=%s" % wg)

@unittest.skipUnless(ocl, "pyopencl is missing")
@unittest.skip("Fix me")
def test_Blelloch(self):
"""
tests the Blelloch scan function
Expand All @@ -192,8 +191,8 @@ def test_Blelloch(self):
logger.error("Error %s on WG=%s: Hillis_Steele", error, wg)
break
else:
res = scan_d.get().reshape((-1, wg))
ref = numpy.array([numpy.cumsum(i) for i in self.data.reshape((-1, wg))])
res = scan_d.get().reshape((-1, 2*wg))
ref = numpy.array([numpy.cumsum(i) for i in self.data.reshape((-1, 2*wg))])
good = numpy.allclose(res, ref)
if not good:
print(ref)
Expand Down
21 changes: 10 additions & 11 deletions src/pyFAI/resources/openCL/collective/scan.cl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ kernel void test_cumsum(global float* input,
* Implements the Blelloch (work-efficient) parallel prefix-sum algorithm
* https://en.wikipedia.org/wiki/Prefix_sum#cite_ref-offman_10-0
*
* Each workgroup computes the cumulative sum of an array twice its own size (2 * workgroup-size elements).
*/


Expand All @@ -70,34 +71,32 @@ void static inline blelloch_scan_float(local float *shared)
int ws = get_local_size(0);
int lid = get_local_id(0);
int dp = 1;
int w;

for(int s = ws>>1; s > 0; s >>= 1)
for(int s = ws; s > 0; s >>= 1)
{
barrier(CLK_LOCAL_MEM_FENCE);
if(lid < s)
{
int i = dp*(2*lid+1)-1;
int j = dp*(2*lid+2)-1;
int j = i + dp;
shared[j] += shared[i];
}
dp <<= 1;
}

if(lid == 0)
shared[ws-1] = 0;

for(int s = 1; s < ws; s <<= 1)
dp >>= 1;
for(int s = 1; s < ws; s=((s+1)<<1)-1)
{
w = dp;
dp >>= 1;

barrier(CLK_LOCAL_MEM_FENCE);

if(lid < s) {
int i = dp*(2*lid+1)-1;
int j = dp*(2*lid+2)-1;

float t = shared[j];
int i = (lid+1)*w - 1;
int j = i + dp;
shared[j] += shared[i];
shared[i] = t;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
Expand Down

0 comments on commit a341657

Please sign in to comment.