Skip to content

Commit

Permalink
Merge pull request #1055 from Bears-R-Us/sum-precision
Browse files Browse the repository at this point in the history
#964 Fix sum precision for real this time
  • Loading branch information
reuster986 authored Feb 1, 2022
2 parents 5736813 + a93e814 commit 4b59849
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 28 deletions.
119 changes: 91 additions & 28 deletions src/ReductionMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -411,37 +411,100 @@ module ReductionMsg
and then reduce over each chunk using the operator <Op>. The return array
of reduced values is the same size as <segments>.
*/
// NOTE(review): this span is rendered from a diff view, so lines of the
// previous cumulative-sum implementation (cumsum / rightvals) appear
// interleaved with the flag-based implementation (flagvalues / scanresult).
// Confirm against the actual source file before relying on statement order.
proc segSum(values:[] ?intype, segments:[?D] int, skipNan=false) throws {

proc segSum(values:[?vD] ?intype, segments:[?D] int, skipNan=false) throws {
// Boolean inputs are summed as int; every other input keeps its own type.
type t = if intype == bool then int else intype;
var res: [D] t;
// No segments: return the empty result.
if (D.size == 0) { return res; }
var cumsum;
if (isFloatType(t) && skipNan) {
var arrCopy = [elem in values] if isnan(elem) then 0.0 else elem;
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(t) * arrCopy.size);
cumsum = + scan arrCopy;
// Pair every value with a reset flag. All flags start out false here;
// they are set to true at each index listed in segments further down.
var flagvalues: [vD] (bool, t); // = [v in values] (false, v);
if isFloatType(t) && skipNan {
// When skipping NaNs, a NaN contributes 0.0 to the sum.
forall (fv, val) in zip(flagvalues, values) {
fv = if isnan(val) then (false, 0.0) else (false, val);
}
} else {
// Otherwise each value is cast to the result type t.
forall (fv, val) in zip(flagvalues, values) {
fv = (false, val:t);
}
}
else {
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(t) * values.size);
cumsum = + scan values;
// Raise the reset flag at every index listed in segments.
forall s in segments with (var agg = newDstAggregator(bool)) {
agg.copy(flagvalues[s][0], true);
}
// Iterate over segments
var rightvals: [D] t;
forall (i, r) in zip(D, rightvals) with (var agg = newSrcAggregator(t)) {
// Find the segment boundaries
if (i == D.high) {
agg.copy(r, cumsum[values.domain.high]);
} else {
agg.copy(r, cumsum[segments[i+1] - 1]);
}
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit((numBytes(t)+1) * flagvalues.size);
// Scan with custom operator, which restarts the running sum
// wherever the reset flag is true.
const scanresult = ResettingPlusScanOp scan flagvalues;
// Read the results from the last element of each segment: for every
// segment but the last, that is the element just before the next
// entry of segments.
forall (r, s) in zip(res[..D.high-1], segments[D.low+1..]) with (var agg = newSrcAggregator(t)) {
agg.copy(r, scanresult[s-1](1));
}
res[D.low] = rightvals[D.low];
res[D.low+1..] = rightvals[D.low+1..] - rightvals[..D.high-1];
// The last segment's result is read from the final element of the scan.
res[D.high] = scanresult[vD.high](1);
return res;
}

/* Performs a segmented sum scan, controlled by a reset flag. While
 * the reset flag is false, the accumulation of values proceeds as
 * normal. When a true is encountered, the running sum restarts from
 * the value of the flagged element. */
class ResettingPlusScanOp: ReduceScanOp {
  /* Type of the scanned elements: a (flag, value) tuple. */
  type eltType;
  /* value is a tuple comprising a flag and the actual result of
     segmented sum.
     The meaning of the flag depends on whether it belongs to an
     array element yet to be scanned or to an element that has
     already been scanned (including the internal state of a class
     instance doing the scanning). For elements yet to be scanned,
     the flag means "reset to the identity here". For elements that
     have already been scanned, or for internal state, the flag means
     "there has already been a reset in the computation of this value".

     The state is default-initialized from eltType, which yields
     (false, 0) expressed in the element's own numeric type. This keeps
     the accumulator the same type as the data for every numeric
     element type, rather than only for int and real.
  */
  var value: eltType;

  /* The identity element: no reset seen, zero sum, typed as eltType. */
  proc identity {
    var ident: eltType;
    return ident;
  }

  /* Fold the not-yet-scanned element x, which comes after the current
     state, into this instance's state. */
  proc accumulate(x) {
    const (reset, other) = x;
    const (hasReset, v) = value;
    // x's reset flag controls whether value gets replaced or combined
    // also update this instance's "hasReset" flag with x's reset flag
    value = (hasReset | reset, if reset then other else (v + other));
  }

  /* Fold x, an update from a previous boundary, into state, which is
     an element that has already been scanned. */
  proc accumulateOntoState(ref state, x) {
    const (prevReset, other) = x;
    const (hasReset, v) = state;
    // absorb reset history
    // If state has already encountered a reset, then it should
    // ignore x's value
    state = (hasReset | prevReset, if hasReset then v else (v + other));
  }

  /* Merge in x, an instance that scanned a prior chunk. */
  proc combine(x) {
    const (xHasReset, other) = x.value;
    const (hasReset, v) = value;
    // Since current instance is absorbing x's history,
    // xHasReset flag should be ORed in.
    // But if current instance has already encountered a reset,
    // then it should ignore x's value.
    value = (hasReset | xHasReset, if hasReset then v else (v + other));
  }

  /* Return the current (flag, sum) state. */
  proc generate() {
    return value;
  }

  /* Create a fresh operator for the same element type, starting from
     the default-initialized state. */
  proc clone() {
    return new unmanaged ResettingPlusScanOp(eltType=eltType);
  }
}

proc segProduct(values:[] ?t, segments:[?D] int, skipNan=false): [D] real throws {
/* Compute the product of values in each segment. The logic here
is to convert the product into a sum in the log-domain. To
Expand Down Expand Up @@ -714,12 +777,12 @@ module ReductionMsg
// Folds x into the already-scanned element state, combining values with
// bitwise OR.
// NOTE(review): this span is rendered from a diff view; the line binding `_`
// and the assignment that keeps only `hasReset` appear to be the superseded
// versions of the lines that immediately follow them. Confirm against the
// actual source file.
proc accumulateOntoState(ref state, x) {
// Assume state is an element that has already been scanned,
// and x is an update from a previous boundary.
const (_, other) = x;
const (prevReset, other) = x;
const (hasReset, v) = state;
// x's hasReset flag does not matter
// absorb reset history: the resulting flag is true if either state or x
// has seen a reset
// If state has already encountered a reset, then it should
// ignore x's value
state = (hasReset, if hasReset then v else (v | other));
state = (hasReset | prevReset, if hasReset then v else (v | other));
}

proc combine(x) {
Expand Down Expand Up @@ -796,12 +859,12 @@ module ReductionMsg
// Folds x into the already-scanned element state, combining values with
// bitwise AND.
// NOTE(review): this span is rendered from a diff view; the line binding `_`
// and the assignment that keeps only `hasReset` appear to be the superseded
// versions of the lines that immediately follow them. Confirm against the
// actual source file.
proc accumulateOntoState(ref state, x) {
// Assume state is an element that has already been scanned,
// and x is an update from a previous boundary.
const (_, other) = x;
const (prevReset, other) = x;
const (hasReset, v) = state;
// x's hasReset flag does not matter
// absorb reset history: the resulting flag is true if either state or x
// has seen a reset
// If state has already encountered a reset, then it should
// ignore x's value
state = (hasReset, if hasReset then v else (v & other));
state = (hasReset | prevReset, if hasReset then v else (v & other));
}

proc combine(x) {
Expand Down
68 changes: 68 additions & 0 deletions test/ScanOpTest.chpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* The purpose of this test was to reproduce a bug that only occurs
* with multiple locales and a large number of cores per locale, so
* it was not caught by the CI.
* To trigger those conditions on a single locale, chapel must be
* built in multi-locale mode (e.g. CHPL_COMM=gasnet) and this
* program must be run in an oversubscribed configuration, e.g.:
*
* CHPL_RT_OVERSUBSCRIBED=yes CHPL_RT_NUM_THREADS_PER_LOCALE=8 test-bin/ScanOpTest -nl 4 --SIZE=32
*/

use TestBase;
use CommAggregation;
use ReductionMsg;

// Number of values in the test arrays; defaults to the locale count times
// the current locale's maxTaskPar.
config const SIZE = numLocales * here.maxTaskPar;
// Number of segments (groups); at most 8, and never more than SIZE.
config const GROUPS = min(SIZE, 8);
// Shift subtracted from each index in makeArrays before its group is computed.
config const offset = 0;
// When true, main prints a per-group table via writeCols.
config const DEBUG = false;

/* Build the inputs for the segmented-sum check.
   Returns a tuple (keys, segs, ones, ans):
     keys - group id of every element, clamped to 0..GROUPS-1
     segs - for each group, the index recorded as its first element
     ones - an array of 1s over the same domain as keys
     ans  - for each group, how many elements carry that group id */
proc makeArrays() {
  const segDom = makeDistDom(GROUPS);
  const valDom = makeDistDom(SIZE);
  var keys: [valDom] int;
  var segs: [segDom] int;
  forall (idx, k) in zip(valDom, keys) {
    // Group length is computed per element, so nothing is divided when
    // the value domain is empty.
    const groupLen = SIZE / GROUPS;
    const shifted = idx - offset;
    const key = shifted / groupLen;
    if key < 0 {
      k = 0;
    } else if key >= GROUPS {
      k = GROUPS - 1;
    } else {
      k = key;
      // The first element of an in-range group records the segment start.
      if shifted % groupLen == 0 then
        segs[key] = idx;
    }
  }
  // The first segment always starts at index 0.
  segs[0] = 0;
  var ones: [valDom] int = 1;
  var ans: [segDom] int;
  for g in segDom do
    ans[g] = + reduce (keys == g);
  return (keys, segs, ones, ans);
}

/* Print the header line names, then one formatted row per index of D
   showing the index followed by the matching entries of a, b, c and d. */
proc writeCols(names: string, a:[?D] int, b: [D] int, c: [D] int, d: [D] int) {
  writeln(names);
  for row in D do
    writeln("%2i %3i %3i %3i %3i".format(row, a[row], b[row], c[row], d[row]));
}

/* Run segSum over an array of ones and compare each segment's sum with
   the expected group size; print a marker line on any mismatch. */
proc main() {
  const (keys, segments, values, answers) = makeArrays();
  var sums = segSum(values, segments);

  if DEBUG {
    // Per-group table: segment start, expected size, result, difference.
    var delta = sums - answers;
    writeCols("grp st size res diff", segments, answers, sums, delta);
  }

  const allMatch = && reduce (sums == answers);
  if !allMatch then
    writeln(">>> Incorrect result <<<");
}
Empty file added test/ScanOpTest.good
Empty file.
16 changes: 16 additions & 0 deletions tests/numeric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,19 @@ def test_isnan(self):
ark_s_int64 = ak.array(np.array([1, 2, 3, 4], dtype="int64"))
with self.assertRaises(RuntimeError, msg="Currently isnan on int64 is not supported"):
ak.isnan(ark_s_int64)

def testPrecision(self):
    """Grouped means of integers and of the same values cast to float64 must agree.

    Regression test for https://github.com/Bears-R-Us/arkouda/issues/964,
    where the grouped sum was exacerbating floating point errors.
    """
    num_values = 10**6
    num_groups = num_values // 10
    upper_bound = 2**63 // num_values
    group_ids = ak.randint(0, num_groups, num_values, seed=1)
    int_values = ak.randint(0, upper_bound, num_values, seed=2)
    float_values = ak.cast(int_values, ak.float64)
    grouping = ak.GroupBy(group_ids)
    _, int_means = grouping.mean(int_values)
    _, float_means = grouping.mean(float_values)
    # Mean squared difference between the two sets of group means
    mean_sq_err = ak.mean((int_means - float_means)**2)
    self.assertTrue(np.isclose(mean_sq_err, 0.0))

0 comments on commit 4b59849

Please sign in to comment.