From 563867b5b5e31de818d62a1297ab3234f6431ddb Mon Sep 17 00:00:00 2001 From: pierce314159 <48131946+pierce314159@users.noreply.github.com> Date: Fri, 15 Apr 2022 09:49:41 -0400 Subject: [PATCH] Resolves #1279 - Add `uint64` support for `broadcast` (#1283) * Resolves #1279 - Add `uint64` support for `broadcast` This PR (Resolves #1279): - Fixes bug where `ak.broadcast` and `ak.GroupBy.broadcast` throw a TypeError with values of type `uint64` * Updated test to use `assertListEqual` * Update test to compare against signed int broadcast Co-authored-by: Pierce Hayes --- src/BroadcastMsg.chpl | 10 ++++++++++ tests/groupby_test.py | 36 ++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/BroadcastMsg.chpl b/src/BroadcastMsg.chpl index 41455645b8..696b2d08b7 100644 --- a/src/BroadcastMsg.chpl +++ b/src/BroadcastMsg.chpl @@ -61,6 +61,11 @@ module BroadcastMsg { var res = st.addEntry(rname, size, int); res.a = broadcast(perm.a, segs.a, vals.a); } + when DType.UInt64 { + const vals = toSymEntry(gv, uint); + var res = st.addEntry(rname, size, uint); + res.a = broadcast(perm.a, segs.a, vals.a); + } when DType.Float64 { const vals = toSymEntry(gv, real); var res = st.addEntry(rname, size, real); @@ -87,6 +92,11 @@ module BroadcastMsg { var res = st.addEntry(rname, size, int); res.a = broadcast(segs.a, vals.a, size); } + when DType.UInt64 { + const vals = toSymEntry(gv, uint); + var res = st.addEntry(rname, size, uint); + res.a = broadcast(segs.a, vals.a, size); + } when DType.Float64 { const vals = toSymEntry(gv, real); var res = st.addEntry(rname, size, real); diff --git a/tests/groupby_test.py b/tests/groupby_test.py index 794c603ea4..5e31c09e75 100755 --- a/tests/groupby_test.py +++ b/tests/groupby_test.py @@ -129,7 +129,9 @@ def setUp(self): self.bvalues = ak.randint(0,1,10,dtype=bool) self.fvalues = ak.randint(0,1,10,dtype=float) self.ivalues = ak.array([4, 1, 3, 2, 2, 2, 5, 5, 2, 3]) + self.uvalues = ak.cast(self.ivalues, ak.uint64) self.igb = ak.GroupBy(self.ivalues) + self.ugb = ak.GroupBy(self.uvalues) def test_groupby_on_one_level(self): ''' @@ -195,8 +197,38 @@ def test_broadcast_ints(self): self.assertTrue((np.array([0,0,0,0,0,1,1,0,1,1]),results.to_ndarray())) results = self.igb.broadcast(1*(counts < 4)) - self.assertTrue((np.array([1,0,0,0,0,1,1,1,1,1]),results.to_ndarray())) - + self.assertTrue((np.array([1,0,0,0,0,1,1,1,1,1]),results.to_ndarray())) + + def test_broadcast_uints(self): + keys, counts = self.ugb.count() + self.assertTrue((np.array([1, 4, 2, 1, 2]) == counts.to_ndarray()).all()) + self.assertTrue((np.array([1, 2, 3, 4, 5]) == keys.to_ndarray()).all()) + + u_results = self.ugb.broadcast(1 * (counts > 2)) + i_results = self.igb.broadcast(1 * (counts > 2)) + self.assertTrue((i_results == u_results).all()) + + u_results = self.ugb.broadcast(1 * (counts == 2)) + i_results = self.igb.broadcast(1 * (counts == 2)) + self.assertTrue((i_results == u_results).all()) + + u_results = self.ugb.broadcast(1 * (counts < 4)) + i_results = self.igb.broadcast(1 * (counts < 4)) + self.assertTrue((i_results == u_results).all()) + + # test uint Groupby.broadcast with and without permute + u_results = self.ugb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64), permute=False) + i_results = self.igb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64), permute=False) + self.assertTrue((i_results == u_results).all()) + u_results = self.ugb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64)) + i_results = self.igb.broadcast(ak.array([1, 2, 6, 8, 9], dtype=ak.uint64)) + self.assertTrue((i_results == u_results).all()) + + # test uint broadcast + u_results = ak.broadcast(ak.array([0]), ak.array([1], dtype=ak.uint64), 1) + i_results = ak.broadcast(ak.array([0]), ak.array([1]), 1) + self.assertTrue((i_results == u_results).all()) + def test_broadcast_booleans(self): keys,counts = self.igb.count()