diff --git a/arkouda/sorting.py b/arkouda/sorting.py index 267e75fc18..0aa9ef6ae5 100644 --- a/arkouda/sorting.py +++ b/arkouda/sorting.py @@ -6,12 +6,15 @@ from arkouda.pdarraycreation import zeros from arkouda.strings import Strings from arkouda.dtypes import int64, float64 +from enum import Enum numeric_dtypes = {float64,int64} -__all__ = ["argsort", "coargsort", "sort"] +__all__ = ["argsort", "coargsort", "sort", "SortingAlgorithm"] -def argsort(pda : Union[pdarray,Strings,'Categorical']) -> pdarray: # type: ignore +SortingAlgorithm = Enum('SortingAlgorithm', ['RadixSortLSD', 'TwoArrayRadixSort']) + +def argsort(pda : Union[pdarray,Strings,'Categorical'], algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray: # type: ignore """ Return the permutation that sorts the array. @@ -57,11 +60,11 @@ def argsort(pda : Union[pdarray,Strings,'Categorical']) -> pdarray: # type: igno name = '{}+{}'.format(pda.offsets.name, pda.bytes.name) else: name = pda.name - repMsg = generic_msg(cmd="argsort", args="{} {}".format(pda.objtype, name)) + repMsg = generic_msg(cmd="argsort", args="{} {} {}".format(algorithm.name, pda.objtype, name)) return create_pdarray(cast(str,repMsg)) -def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']]) -> pdarray: # type: ignore +def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']], algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray: # type: ignore """ Return the permutation that groups the rows (left-to-right), if the input arrays are treated as columns. The permutation sorts numeric @@ -129,12 +132,14 @@ def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']]) -> pdarr raise ValueError("All pdarrays, Strings, or Categoricals must be of the same size") if size == 0: return zeros(0, dtype=int64) - repMsg = generic_msg(cmd="coargsort", args="{:n} {} {}".format(len(arrays), - ' '.join(anames), ' '.join(atypes))) + repMsg = generic_msg(cmd="coargsort", args="{} {:n} {} {}".format(algorithm.name, + len(arrays), + ' '.join(anames), + ' '.join(atypes))) return create_pdarray(cast(str, repMsg)) @typechecked -def sort(pda : pdarray) -> pdarray: +def sort(pda : pdarray, algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray: """ Return a sorted copy of the array. Only sorts numeric arrays; for Strings, use argsort. @@ -177,5 +182,5 @@ def sort(pda : pdarray) -> pdarray: return zeros(0, dtype=int64) if pda.dtype not in numeric_dtypes: raise ValueError("ak.sort supports float64 or int64, not {}".format(pda.dtype)) - repMsg = generic_msg(cmd="sort", args="{}".format(pda.name)) + repMsg = generic_msg(cmd="sort", args="{} {}".format(algorithm.name, pda.name)) return create_pdarray(cast(str,repMsg)) diff --git a/src/ArgSortMsg.chpl b/src/ArgSortMsg.chpl index ee9dd36445..5853743136 100644 --- a/src/ArgSortMsg.chpl +++ b/src/ArgSortMsg.chpl @@ -10,7 +10,7 @@ module ArgSortMsg use Time only; use Math only; - use Sort only; + private use Sort; use Reflection only; use PrivateDist; @@ -46,6 +46,74 @@ module ArgSortMsg var mBins = 2**25; var lBins = 2**25 * numLocales; + enum SortingAlgorithm { + RadixSortLSD, + TwoArrayRadixSort + }; + config const defaultSortAlgorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD; + + // proc DefaultComparator.keyPart(x: _tuple, i:int) where !isHomogeneousTuple(x) && + // (isInt(x(0)) || isUint(x(0)) || isReal(x(0))) { + + import Reflection.canResolveMethod; + record ContrivedComparator { + const dc = new DefaultComparator(); + proc keyPart(a, i: int) { + if canResolveMethod(dc, "keyPart", a, 0) { + return dc.keyPart(a, i); + } else if isTuple(a) { + return tupleKeyPart(a, i); + } else { + compilerError("No keyPart method for eltType ", a.type:string); + } + } + proc tupleKeyPart(x: _tuple, i:int) { + proc makePart(y): uint(64) { + var part: uint(64); + // get the part, ignore the section + const p = dc.keyPart(y, 0)(1); + // assuming result of keyPart is int or uint <= 64 bits + part = p:uint(64); + // If the number is signed, invert the top bit, so that + // the negative numbers sort below the positive numbers + if isInt(p) { + const one:uint(64) = 1; + part = part ^ (one << 63); + } + return part; + } + var part: uint(64); + if isTuple(x[0]) && (x.size == 2) { + for param j in 0.. x[0].size { + return (-1, 0:uint(64)); + } else { + return (0, part); + } + } else { + for param j in 0..= x.size { + return (-1, 0:uint(64)); + } else { + return (0, part); + } + } + } + } + + const myDefaultComparator = new ContrivedComparator(); + /* Perform one step in a multi-step argsort, starting with an initial permutation vector and further permuting it in the manner required to sort an array of keys. @@ -65,7 +133,7 @@ module ArgSortMsg agg.copy(newai, olda[idx]); } // Generate the next incremental permutation - deltaIV = radixSortLSD_ranks(newa); + deltaIV = argsortDefault(newa); } when DType.Float64 { var e = toSymEntry(g, real); @@ -74,7 +142,7 @@ module ArgSortMsg forall (newai, idx) in zip(newa, iv) with (var agg = newSrcAggregator(real)) { agg.copy(newai, olda[idx]); } - deltaIV = radixSortLSD_ranks(newa); + deltaIV = argsortDefault(newa); } otherwise { throw getErrorWithContext( msg="Unsupported DataType: %t".format(dtype2str(g.dtype)), @@ -100,7 +168,7 @@ module ArgSortMsg forall (nh, idx) in zip(newHashes, iv) with (var agg = newSrcAggregator((2*uint))) { agg.copy(nh, hashes[idx]); } - var deltaIV = radixSortLSD_ranks(newHashes); + var deltaIV = argsortDefault(newHashes); // var (newOffsets, newVals) = s[iv]; // var deltaIV = newStr.argGroup(); var newIV: [aD] int; @@ -130,7 +198,21 @@ module ArgSortMsg proc coargsortMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws { param pn = Reflection.getRoutineName(); var repMsg: string; - var (nstr, rest) = payload.splitMsgToTuple(2); + var (algoName, nstr, rest) = payload.splitMsgToTuple(3); + var algorithm: SortingAlgorithm = defaultSortAlgorithm; + if algoName != "" { + try { + algorithm = algoName: SortingAlgorithm; + } catch { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algoName), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } var n = nstr:int; // number of arrays to sort var fields = rest.split(); asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), @@ -263,7 +345,7 @@ module ArgSortMsg } } - var iv = argsortDefault(merged); + var iv = argsortDefault(merged, algorithm=algorithm); st.addEntry(ivname, new shared SymEntry(iv)); var repMsg = "created " + st.attrib(ivname); @@ -305,12 +387,28 @@ module ArgSortMsg return new MsgTuple(repMsg, MsgType.NORMAL); } - proc argsortDefault(A:[?D] ?t):[D] int throws { + proc argsortDefault(A:[?D] ?t, algorithm:SortingAlgorithm=defaultSortAlgorithm):[D] int throws { var t1 = Time.getCurrentTime(); - //var AI = [(a, i) in zip(A, D)] (a, i); - //Sort.TwoArrayRadixSort.twoArrayRadixSort(AI); - //var iv = [(a, i) in AI] i; - var iv = radixSortLSD_ranks(A); + var iv: [D] int; + select algorithm { + when SortingAlgorithm.TwoArrayRadixSort { + var AI = [(a, i) in zip(A, D)] (a, i); + Sort.TwoArrayRadixSort.twoArrayRadixSort(AI, comparator=myDefaultComparator); + iv = [(a, i) in AI] i; + } + when SortingAlgorithm.RadixSortLSD { + iv = radixSortLSD_ranks(A); + } + otherwise { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algorithm:string), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } try! asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), "argsort time = %i".format(Time.getCurrentTime() - t1)); return iv; @@ -321,8 +419,21 @@ module ArgSortMsg param pn = Reflection.getRoutineName(); var repMsg: string; // response message // split request into fields - var (objtype, name) = payload.splitMsgToTuple(2); - + var (algoName, objtype, name) = payload.splitMsgToTuple(3); + var algorithm: SortingAlgorithm = defaultSortAlgorithm; + if algoName != "" { + try { + algorithm = algoName: SortingAlgorithm; + } catch { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algoName), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } // get next symbol name var ivname = st.nextName(); asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), @@ -338,7 +449,7 @@ module ArgSortMsg select (gEnt.dtype) { when (DType.Int64) { var e = toSymEntry(gEnt,int); - var iv = argsortDefault(e.a); + var iv = argsortDefault(e.a, algorithm=algorithm); st.addEntry(ivname, new shared SymEntry(iv)); } when (DType.Float64) { diff --git a/src/SortMsg.chpl b/src/SortMsg.chpl index 496f017667..13c06ef110 100644 --- a/src/SortMsg.chpl +++ b/src/SortMsg.chpl @@ -14,6 +14,7 @@ module SortMsg use AryUtil; use Logging; use Message; + private use ArgSortMsg; private config const logLevel = ServerConfig.logLevel; const sortLogger = new Logger(logLevel); @@ -29,8 +30,21 @@ module SortMsg proc sortMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws { param pn = Reflection.getRoutineName(); var repMsg: string; // response message - var (name) = payload.splitMsgToTuple(1); - + var (algoName, name) = payload.splitMsgToTuple(2); + var algorithm: SortingAlgorithm = defaultSortAlgorithm; + if algoName != "" { + try { + algorithm = algoName: SortingAlgorithm; + } catch { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algoName), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } // get next symbol name var sortedName = st.nextName(); @@ -43,18 +57,39 @@ module SortMsg sortLogger.debug(getModuleName(),getRoutineName(),getLineNumber(), "cmd: %s name: %s sortedName: %s dtype: %t".format( cmd, name, sortedName, gEnt.dtype)); - + + proc doSort(a: [?D] ?t) throws { + select algorithm { + when SortingAlgorithm.TwoArrayRadixSort { + var b: [D] t = a; + Sort.TwoArrayRadixSort.twoArrayRadixSort(b, comparator=myDefaultComparator); + return b; + } + when SortingAlgorithm.RadixSortLSD { + return radixSortLSD_keys(a); + } + otherwise { + throw getErrorWithContext( + msg="Unrecognized sorting algorithm: %s".format(algorithm:string), + lineNumber=getLineNumber(), + routineName=getRoutineName(), + moduleName=getModuleName(), + errorClass="NotImplementedError" + ); + } + } + } // Sort the input pda and create a new symbol entry for // the sorted pda. select (gEnt.dtype) { when (DType.Int64) { var e = toSymEntry(gEnt, int); - var sorted = sort(e.a); + var sorted = doSort(e.a); st.addEntry(sortedName, new shared SymEntry(sorted)); }// end when(DType.Int64) when (DType.Float64) { var e = toSymEntry(gEnt, real); - var sorted = sort(e.a); + var sorted = doSort(e.a); st.addEntry(sortedName, new shared SymEntry(sorted)); }// end when(DType.Float64) otherwise { diff --git a/test/UnitTestArgSortMsg.chpl b/test/UnitTestArgSortMsg.chpl index 4f391fb7c7..48a3a145d8 100644 --- a/test/UnitTestArgSortMsg.chpl +++ b/test/UnitTestArgSortMsg.chpl @@ -22,7 +22,7 @@ prototype module UnitTestArgSort // sort it and return iv in symbol table var cmd = "argsort"; - reqMsg = try! "pdarray %s".format(aname); + reqMsg = try! "%s pdarray %s".format(ArgSortMsg.SortingAlgorithm.RadixSortLSD: string, aname); writeReq(reqMsg); var d: Diags; d.start(); @@ -57,7 +57,7 @@ prototype module UnitTestArgSort // cosort both int and real arrays and return iv in symbol table cmd = "coargsort"; - reqMsg = try! "%i %s %s pdarray pdarray".format(2, aname, fname); + reqMsg = try! "%s %i %s %s pdarray pdarray".format(ArgSortMsg.SortingAlgorithm.RadixSortLSD: string, 2, aname, fname); writeReq(reqMsg); d.start(); repMsg = coargsortMsg(cmd=cmd, payload=reqMsg, st).msg; diff --git a/tests/coargsort_test.py b/tests/coargsort_test.py index f1c78cea0e..d668a6d4dd 100755 --- a/tests/coargsort_test.py +++ b/tests/coargsort_test.py @@ -5,7 +5,7 @@ from context import arkouda as ak from base_test import ArkoudaTest -def check_int(N): +def check_int(N, algo): z = ak.zeros(N, dtype=ak.int64) a2 = ak.randint(0, 2**16, N) @@ -14,22 +14,22 @@ def check_int(N): d2 = ak.randint(0, 2**16, N) n2 = ak.randint(-(2**15), 2**15, N) - perm = ak.coargsort([a2]) + perm = ak.coargsort([a2], algo) assert ak.is_sorted(a2[perm]) - perm = ak.coargsort([n2]) + perm = ak.coargsort([n2], algo) assert ak.is_sorted(n2[perm]) - perm = ak.coargsort([a2, b2, c2, d2]) + perm = ak.coargsort([a2, b2, c2, d2], algo) assert ak.is_sorted(a2[perm]) - perm = ak.coargsort([z, b2, c2, d2]) + perm = ak.coargsort([z, b2, c2, d2], algo) assert ak.is_sorted(b2[perm]) - perm = ak.coargsort([z, z, c2, d2]) + perm = ak.coargsort([z, z, c2, d2], algo) assert ak.is_sorted(c2[perm]) - perm = ak.coargsort([z, z, z, d2]) + perm = ak.coargsort([z, z, z, d2], algo) assert ak.is_sorted(d2[perm]) @@ -37,16 +37,16 @@ def check_int(N): b4 = ak.randint(0, 2**32, N) n4 = ak.randint(-(2**31), 2**31, N) - perm = ak.coargsort([a4]) + perm = ak.coargsort([a4], algo) assert ak.is_sorted(a4[perm]) - perm = ak.coargsort([n4]) + perm = ak.coargsort([n4], algo) assert ak.is_sorted(n4[perm]) - perm = ak.coargsort([a4, b4]) + perm = ak.coargsort([a4, b4], algo) assert ak.is_sorted(a4[perm]) - perm = ak.coargsort([b4, a4]) + perm = ak.coargsort([b4, a4], algo) assert ak.is_sorted(b4[perm]) @@ -55,55 +55,55 @@ def check_int(N): n8 = ak.randint(-(2**63), 2**64, N) - perm = ak.coargsort([a8]) + perm = ak.coargsort([a8], algo) assert ak.is_sorted(a8[perm]) - perm = ak.coargsort([n8]) + perm = ak.coargsort([n8], algo) assert ak.is_sorted(n8[perm]) - perm = ak.coargsort([b8, a8]) + perm = ak.coargsort([b8, a8], algo) assert ak.is_sorted(b8[perm]) from itertools import permutations all_perm = permutations([a2, a4, a8]) for p in all_perm: - perm = ak.coargsort(p) + perm = ak.coargsort(p, algo) assert ak.is_sorted(p[0][perm]) -def check_float(N): +def check_float(N, algo): a = ak.randint(0, 1, N, dtype=ak.float64) n = ak.randint(-1, 1, N, dtype=ak.float64) z = ak.zeros(N, dtype=ak.float64) - perm = ak.coargsort([a]) + perm = ak.coargsort([a], algo) assert ak.is_sorted(a[perm]) - perm = ak.coargsort([a, n]) + perm = ak.coargsort([a, n], algo) assert ak.is_sorted(a[perm]) - perm = ak.coargsort([n, a]) + perm = ak.coargsort([n, a], algo) assert ak.is_sorted(n[perm]) - perm = ak.coargsort([z, a]) + perm = ak.coargsort([z, a], algo) assert ak.is_sorted(a[perm]) - perm = ak.coargsort([z, n]) + perm = ak.coargsort([z, n], algo) assert ak.is_sorted(n[perm]) -def check_int_float(N): +def check_int_float(N, algo): f = ak.randint(0, 2**63, N, dtype=ak.float64) i = ak.randint(0, 2**63, N, dtype=ak.int64) - perm = ak.coargsort([f, i]) + perm = ak.coargsort([f, i], algo) assert ak.is_sorted(f[perm]) - perm = ak.coargsort([i, f]) + perm = ak.coargsort([i, f], algo) assert ak.is_sorted(i[perm]) -def check_large(N): +def check_large(N, algo): l = [ak.randint(0, 2**63, N) for _ in range(10)] - perm = ak.coargsort(l) + perm = ak.coargsort(l, algo) assert ak.is_sorted(l[0][perm]) def check_coargsort(N_per_locale): @@ -112,63 +112,71 @@ def check_coargsort(N_per_locale): N = N_per_locale * cfg["numLocales"] print("numLocales = {}, N = {:,}".format(cfg["numLocales"], N)) - check_int(N) - check_float(N) - check_int_float(N) - check_large(N) + for algo in ak.SortingAlgorithm: + check_int(N, algo) + check_float(N, algo) + check_int_float(N, algo) + check_large(N, algo) class CoargsortTest(ArkoudaTest): def test_int(self): - check_int(10**3) + for algo in ak.SortingAlgorithm: + check_int(10**3, algo) def test_float(self): - check_float(10**3) + for algo in ak.SortingAlgorithm: + check_float(10**3, algo) def test_int_float(self): - check_int_float(10**3) + for algo in ak.SortingAlgorithm: + check_int_float(10**3, algo) def test_large(self): - check_large(10**3) + for algo in ak.SortingAlgorithm: + check_large(10**3, algo) def test_error_handling(self): ones = ak.ones(100) short_ones = ak.ones(10) - - with self.assertRaises(ValueError): - ak.coargsort([ones, short_ones]) - - with self.assertRaises(TypeError): - ak.coargsort([list(range(0,10)), [0]]) + + for algo in ak.SortingAlgorithm: + with self.assertRaises(ValueError): + ak.coargsort([ones, short_ones], algo) + + for algo in ak.SortingAlgorithm: + with self.assertRaises(TypeError): + ak.coargsort([list(range(0,10)), [0]], algo) def test_coargsort_categorical(self): string = ak.array(['a', 'b', 'a', 'b', 'c']) cat = ak.Categorical(string) cat_from_codes = ak.Categorical.from_codes(codes=ak.array([0, 1, 0, 1, 2]), categories=ak.array(['a', 'b', 'c'])) - str_perm = ak.coargsort([string]) - str_sorted = string[str_perm].to_ndarray() - - # coargsort on categorical - cat_perm = ak.coargsort([cat]) - cat_sorted = cat[cat_perm].to_ndarray() - self.assertTrue((str_sorted == cat_sorted).all()) - - # coargsort on categorical.from_codes - # coargsort sorts using codes, the order isn't guaranteed, only grouping - from_codes_perm = ak.coargsort([cat_from_codes]) - from_codes_sorted = cat_from_codes[from_codes_perm].to_ndarray() - self.assertTrue((['a', 'a', 'b', 'b', 'c'] == from_codes_sorted).all()) - - # coargsort on 2 categoricals (one from_codes) - cat_perm = ak.coargsort([cat, cat_from_codes]) - cat_sorted = cat[cat_perm].to_ndarray() - self.assertTrue((str_sorted == cat_sorted).all()) - - # coargsort on mixed strings and categoricals - mixed_perm = ak.coargsort([cat, string, cat_from_codes]) - mixed_sorted = cat_from_codes[mixed_perm].to_ndarray() - self.assertTrue((str_sorted == mixed_sorted).all()) + for algo in ak.SortingAlgorithm: + str_perm = ak.coargsort([string], algo) + str_sorted = string[str_perm].to_ndarray() + + # coargsort on categorical + cat_perm = ak.coargsort([cat], algo) + cat_sorted = cat[cat_perm].to_ndarray() + self.assertTrue((str_sorted == cat_sorted).all()) + + # coargsort on categorical.from_codes + # coargsort sorts using codes, the order isn't guaranteed, only grouping + from_codes_perm = ak.coargsort([cat_from_codes], algo) + from_codes_sorted = cat_from_codes[from_codes_perm].to_ndarray() + self.assertTrue((['a', 'a', 'b', 'b', 'c'] == from_codes_sorted).all()) + + # coargsort on 2 categoricals (one from_codes) + cat_perm = ak.coargsort([cat, cat_from_codes], algo) + cat_sorted = cat[cat_perm].to_ndarray() + self.assertTrue((str_sorted == cat_sorted).all()) + + # coargsort on mixed strings and categoricals + mixed_perm = ak.coargsort([cat, string, cat_from_codes], algo) + mixed_sorted = cat_from_codes[mixed_perm].to_ndarray() + self.assertTrue((str_sorted == mixed_sorted).all()) def create_parser(): diff --git a/tests/sort_test.py b/tests/sort_test.py index bf53ffea74..b508e0e966 100644 --- a/tests/sort_test.py +++ b/tests/sort_test.py @@ -9,9 +9,10 @@ class SortTest(ArkoudaTest): def testSort(self): pda = ak.randint(0,100,100) - spda = ak.sort(pda) - maxIndex = spda.argmax() - self.assertTrue(maxIndex > 0) + for algo in ak.SortingAlgorithm: + spda = ak.sort(pda, algo) + maxIndex = spda.argmax() + self.assertTrue(maxIndex > 0) def testBitBoundaryHardcode(self): @@ -19,17 +20,19 @@ def testBitBoundaryHardcode(self): a = ak.array([1, -1, 32767]) # 16 bit b = ak.array([1, 0, 32768]) # 16 bit c = ak.array([1, -1, 32768]) # 17 bit - assert ak.is_sorted(ak.sort(a)) - assert ak.is_sorted(ak.sort(b)) - assert ak.is_sorted(ak.sort(c)) + for algo in ak.SortingAlgorithm: + assert ak.is_sorted(ak.sort(a, algo)) + assert ak.is_sorted(ak.sort(b, algo)) + assert ak.is_sorted(ak.sort(c, algo)) # test hardcoded 64-bit boundaries with and without negative values d = ak.array([1, -1, 2**63-1]) e = ak.array([1, 0, 2**63-1]) f = ak.array([1, -2**63, 2**63-1]) - assert ak.is_sorted(ak.sort(d)) - assert ak.is_sorted(ak.sort(e)) - assert ak.is_sorted(ak.sort(f)) + for algo in ak.SortingAlgorithm: + assert ak.is_sorted(ak.sort(d, algo)) + assert ak.is_sorted(ak.sort(e, algo)) + assert ak.is_sorted(ak.sort(f, algo)) def testBitBoundary(self): @@ -37,28 +40,30 @@ def testBitBoundary(self): L = -2**15 U = 2**16 a = ak.randint(L, U, 100) - assert ak.is_sorted(ak.sort(a)) + for algo in ak.SortingAlgorithm: + assert ak.is_sorted(ak.sort(a, algo)) def testErrorHandling(self): # Test RuntimeError from bool NotImplementedError akbools = ak.randint(0, 1, 1000, dtype=ak.bool) bools = ak.randint(0, 1, 1000, dtype=bool) - - with self.assertRaises(ValueError) as cm: - ak.sort(akbools) - self.assertEqual('ak.sort supports float64 or int64, not bool', - cm.exception.args[0]) + + for algo in ak.SortingAlgorithm: + with self.assertRaises(ValueError) as cm: + ak.sort(akbools, algo) + self.assertEqual('ak.sort supports float64 or int64, not bool', + cm.exception.args[0]) - with self.assertRaises(ValueError) as cm: - ak.sort(bools) - self.assertEqual('ak.sort supports float64 or int64, not bool', - cm.exception.args[0]) + with self.assertRaises(ValueError) as cm: + ak.sort(bools, algo) + self.assertEqual('ak.sort supports float64 or int64, not bool', + cm.exception.args[0]) - # Test TypeError from sort attempt on non-pdarray - with self.assertRaises(TypeError): - ak.sort(list(range(0,10))) + # Test TypeError from sort attempt on non-pdarray + with self.assertRaises(TypeError): + ak.sort(list(range(0,10)), algo) - # Test attempt to sort Strings object, which is unsupported - with self.assertRaises(TypeError): - ak.sort(ak.array(['String {}'.format(i) for i in range(0,10)])) + # Test attempt to sort Strings object, which is unsupported + with self.assertRaises(TypeError): + ak.sort(ak.array(['String {}'.format(i) for i in range(0,10)]), algo)