Skip to content

Commit

Permalink
Resolves #1279 - Add uint64 support for broadcast (#1283)
Browse files Browse the repository at this point in the history
* Resolves #1279 - Add `uint64` support for `broadcast`

This PR (Resolves #1279):
- Fixes bug where `ak.broadcast` and `ak.GroupBy.broadcast` throw a TypeError with values of type `uint64`

* Updated test to use `assertListEqual`

* Update test to compare against signed int broadcast

Co-authored-by: Pierce Hayes <[email protected]>
  • Loading branch information
stress-tess and Pierce Hayes authored Apr 15, 2022
1 parent 61357c4 commit 563867b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
10 changes: 10 additions & 0 deletions src/BroadcastMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ module BroadcastMsg {
var res = st.addEntry(rname, size, int);
res.a = broadcast(perm.a, segs.a, vals.a);
}
when DType.UInt64 {
const vals = toSymEntry(gv, uint);
var res = st.addEntry(rname, size, uint);
res.a = broadcast(perm.a, segs.a, vals.a);
}
when DType.Float64 {
const vals = toSymEntry(gv, real);
var res = st.addEntry(rname, size, real);
Expand All @@ -87,6 +92,11 @@ module BroadcastMsg {
var res = st.addEntry(rname, size, int);
res.a = broadcast(segs.a, vals.a, size);
}
when DType.UInt64 {
const vals = toSymEntry(gv, uint);
var res = st.addEntry(rname, size, uint);
res.a = broadcast(segs.a, vals.a, size);
}
when DType.Float64 {
const vals = toSymEntry(gv, real);
var res = st.addEntry(rname, size, real);
Expand Down
36 changes: 34 additions & 2 deletions tests/groupby_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def setUp(self):
self.bvalues = ak.randint(0,1,10,dtype=bool)
self.fvalues = ak.randint(0,1,10,dtype=float)
self.ivalues = ak.array([4, 1, 3, 2, 2, 2, 5, 5, 2, 3])
self.uvalues = ak.cast(self.ivalues, ak.uint64)
self.igb = ak.GroupBy(self.ivalues)
self.ugb = ak.GroupBy(self.uvalues)

def test_groupby_on_one_level(self):
'''
Expand Down Expand Up @@ -195,8 +197,38 @@ def test_broadcast_ints(self):
self.assertTrue((np.array([0,0,0,0,0,1,1,0,1,1]),results.to_ndarray()))

results = self.igb.broadcast(1*(counts < 4))
self.assertTrue((np.array([1,0,0,0,0,1,1,1,1,1]),results.to_ndarray()))

self.assertTrue((np.array([1,0,0,0,0,1,1,1,1,1]),results.to_ndarray()))

def test_broadcast_uints(self):
keys, counts = self.ugb.count()
self.assertTrue((np.array([1, 4, 2, 1, 2]) == counts.to_ndarray()).all())
self.assertTrue((np.array([1, 2, 3, 4, 5]) == keys.to_ndarray()).all())

u_results = self.ugb.broadcast(1 * (counts > 2))
i_results = self.igb.broadcast(1 * (counts > 2))
self.assertTrue((i_results == u_results).all())

u_results = self.ugb.broadcast(1 * (counts == 2))
i_results = self.igb.broadcast(1 * (counts == 2))
self.assertTrue((i_results == u_results).all())

u_results = self.ugb.broadcast(1 * (counts < 4))
i_results = self.igb.broadcast(1 * (counts < 4))
self.assertTrue((i_results == u_results).all())

# test uint Groupby.broadcast with and without permute
u_results = self.ugb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64), permute=False)
i_results = self.igb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64), permute=False)
self.assertTrue((i_results == u_results).all())
u_results = self.ugb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64))
i_results = self.igb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64))
self.assertTrue((i_results == u_results).all())

# test uint broadcast
u_results = ak.broadcast(ak.array([0]), ak.array([1], dtype=ak.uint64), 1)
i_results = ak.broadcast(ak.array([0]), ak.array([1]), 1)
self.assertTrue((i_results == u_results).all())

def test_broadcast_booleans(self):
keys,counts = self.igb.count()

Expand Down

0 comments on commit 563867b

Please sign in to comment.