From cc1797aa42cf24a3018f2017a7fcfb3be7a0634f Mon Sep 17 00:00:00 2001 From: ajpotts Date: Tue, 3 Dec 2024 17:54:04 -0500 Subject: [PATCH] Closes #3870: bug in reshape for bigint type (#3907) Co-authored-by: Amanda Potts --- src/AryUtil.chpl | 54 +++++++++++++++++++++++++++++++------- tests/pdarrayclass_test.py | 39 ++++++++++++++++++--------- 2 files changed, 71 insertions(+), 22 deletions(-) diff --git a/src/AryUtil.chpl b/src/AryUtil.chpl index 929d9fc1cd..365d10cd3e 100644 --- a/src/AryUtil.chpl +++ b/src/AryUtil.chpl @@ -15,6 +15,7 @@ module AryUtil use List; use CommAggregation; use CommPrimitives; + use BigInteger; param bitsPerDigit = RSLSD_bitsPerDigit; @@ -905,7 +906,8 @@ module AryUtil /* unflatten a 1D array into a multi-dimensional array of the given shape */ - proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws { + proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws + where t!=bigint { var unflat = makeDistArray((...shape), t); if N == 1 { @@ -952,7 +954,6 @@ module AryUtil // flat region is spread across multiple locales, do a get for each source locale for locInID in locInStart..locInStop { const flatSubSlice = flatSlice[flatLocRanges[locInID]]; - get( c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)]), getAddr(a[flatSubSlice.low]), @@ -967,11 +968,30 @@ module AryUtil return unflat; } + proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws + where t==bigint { + var unflat = makeDistArray((...shape), t); + + if N == 1 { + unflat = a; + return unflat; + } + + coforall loc in Locales with (ref unflat) do on loc { + forall idx in a.localSubdomain() with (var agg = newDstAggregator(t)) { + agg.copy(unflat[unflat.domain.orderToIndex(idx)], a[idx]); + } + } + + return unflat; + } + /* flatten a multi-dimensional array into a 1D array */ - @arkouda.registerCommand - proc flatten(const ref a: [?d] ?t): [] t throws { + @arkouda.registerCommand(ignoreWhereClause=true) + proc flatten(const ref a: [?d] ?t): [] t throws + where t!=bigint { if a.rank == 1 then return a; var flat = makeDistArray(d.size, t); @@ -1030,6 +1050,22 @@ module AryUtil return flat; } + + proc flatten(const ref a: [?d] ?t): [] t throws + where t==bigint { + if a.rank == 1 then return a; + + var flat = makeDistArray(d.size, t); + + coforall loc in Locales with (ref flat) do on loc { + forall idx in flat.localSubdomain() with (var agg = newSrcAggregator(t)) { + agg.copy(flat[idx], a[a.domain.orderToIndex(idx)]); + } + } + + return flat; + } + // helper for computing an array element's index from its order record orderer { param rank: int; @@ -1044,10 +1080,10 @@ module AryUtil // index -> order for the input array's indices // e.g., order = k + (nz * j) + (nz * ny * i) inline proc indexToOrder(idx: rank*?t): t - where (t==int) || (t==uint(64)) { - var order : t = 0; - for param i in 0..