Skip to content

Commit

Permalink
Merge pull request #988 from Bears-R-Us/chapel-sort
Browse files Browse the repository at this point in the history
Closes #984 Add `twoArrayRadixSort` as a runtime choice
  • Loading branch information
glitch authored Dec 2, 2021
2 parents 639faef + b7ca2f2 commit 20906e9
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 117 deletions.
21 changes: 13 additions & 8 deletions arkouda/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
from arkouda.pdarraycreation import zeros
from arkouda.strings import Strings
from arkouda.dtypes import int64, float64
from enum import Enum

numeric_dtypes = {float64,int64}

__all__ = ["argsort", "coargsort", "sort"]
__all__ = ["argsort", "coargsort", "sort", "SortingAlgorithm"]

def argsort(pda : Union[pdarray,Strings,'Categorical']) -> pdarray: # type: ignore
SortingAlgorithm = Enum('SortingAlgorithm', ['RadixSortLSD', 'TwoArrayRadixSort'])

def argsort(pda : Union[pdarray,Strings,'Categorical'], algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray: # type: ignore
"""
Return the permutation that sorts the array.
Expand Down Expand Up @@ -57,11 +60,11 @@ def argsort(pda : Union[pdarray,Strings,'Categorical']) -> pdarray: # type: igno
name = '{}+{}'.format(pda.offsets.name, pda.bytes.name)
else:
name = pda.name
repMsg = generic_msg(cmd="argsort", args="{} {}".format(pda.objtype, name))
repMsg = generic_msg(cmd="argsort", args="{} {} {}".format(algorithm.name, pda.objtype, name))
return create_pdarray(cast(str,repMsg))


def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']]) -> pdarray: # type: ignore
def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']], algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray: # type: ignore
"""
Return the permutation that groups the rows (left-to-right), if the
input arrays are treated as columns. The permutation sorts numeric
Expand Down Expand Up @@ -129,12 +132,14 @@ def coargsort(arrays: Sequence[Union[Strings, pdarray, 'Categorical']]) -> pdarr
raise ValueError("All pdarrays, Strings, or Categoricals must be of the same size")
if size == 0:
return zeros(0, dtype=int64)
repMsg = generic_msg(cmd="coargsort", args="{:n} {} {}".format(len(arrays),
' '.join(anames), ' '.join(atypes)))
repMsg = generic_msg(cmd="coargsort", args="{} {:n} {} {}".format(algorithm.name,
len(arrays),
' '.join(anames),
' '.join(atypes)))
return create_pdarray(cast(str, repMsg))

@typechecked
def sort(pda : pdarray) -> pdarray:
def sort(pda : pdarray, algorithm : SortingAlgorithm = SortingAlgorithm.RadixSortLSD) -> pdarray:
"""
Return a sorted copy of the array. Only sorts numeric arrays;
for Strings, use argsort.
Expand Down Expand Up @@ -177,5 +182,5 @@ def sort(pda : pdarray) -> pdarray:
return zeros(0, dtype=int64)
if pda.dtype not in numeric_dtypes:
raise ValueError("ak.sort supports float64 or int64, not {}".format(pda.dtype))
repMsg = generic_msg(cmd="sort", args="{}".format(pda.name))
repMsg = generic_msg(cmd="sort", args="{} {}".format(algorithm.name, pda.name))
return create_pdarray(cast(str,repMsg))
139 changes: 125 additions & 14 deletions src/ArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module ArgSortMsg

use Time only;
use Math only;
use Sort only;
private use Sort;
use Reflection only;

use PrivateDist;
Expand Down Expand Up @@ -46,6 +46,74 @@ module ArgSortMsg
var mBins = 2**25;
var lBins = 2**25 * numLocales;

enum SortingAlgorithm {
RadixSortLSD,
TwoArrayRadixSort
};
config const defaultSortAlgorithm: SortingAlgorithm = SortingAlgorithm.RadixSortLSD;

// proc DefaultComparator.keyPart(x: _tuple, i:int) where !isHomogeneousTuple(x) &&
// (isInt(x(0)) || isUint(x(0)) || isReal(x(0))) {

import Reflection.canResolveMethod;
record ContrivedComparator {
const dc = new DefaultComparator();
proc keyPart(a, i: int) {
if canResolveMethod(dc, "keyPart", a, 0) {
return dc.keyPart(a, i);
} else if isTuple(a) {
return tupleKeyPart(a, i);
} else {
compilerError("No keyPart method for eltType ", a.type:string);
}
}
proc tupleKeyPart(x: _tuple, i:int) {
proc makePart(y): uint(64) {
var part: uint(64);
// get the part, ignore the section
const p = dc.keyPart(y, 0)(1);
// assuming result of keyPart is int or uint <= 64 bits
part = p:uint(64);
// If the number is signed, invert the top bit, so that
// the negative numbers sort below the positive numbers
if isInt(p) {
const one:uint(64) = 1;
part = part ^ (one << 63);
}
return part;
}
var part: uint(64);
if isTuple(x[0]) && (x.size == 2) {
for param j in 0..<x[0].size {
if i == j {
part = makePart(x[0][j]);
}
}
if i == x[0].size {
part = makePart(x[1]);
}
if i > x[0].size {
return (-1, 0:uint(64));
} else {
return (0, part);
}
} else {
for param j in 0..<x.size {
if i == j {
part = makePart(x[j]);
}
}
if i >= x.size {
return (-1, 0:uint(64));
} else {
return (0, part);
}
}
}
}

const myDefaultComparator = new ContrivedComparator();

/* Perform one step in a multi-step argsort, starting with an initial
permutation vector and further permuting it in the manner required
to sort an array of keys.
Expand All @@ -65,7 +133,7 @@ module ArgSortMsg
agg.copy(newai, olda[idx]);
}
// Generate the next incremental permutation
deltaIV = radixSortLSD_ranks(newa);
deltaIV = argsortDefault(newa);
}
when DType.Float64 {
var e = toSymEntry(g, real);
Expand All @@ -74,7 +142,7 @@ module ArgSortMsg
forall (newai, idx) in zip(newa, iv) with (var agg = newSrcAggregator(real)) {
agg.copy(newai, olda[idx]);
}
deltaIV = radixSortLSD_ranks(newa);
deltaIV = argsortDefault(newa);
}
otherwise { throw getErrorWithContext(
msg="Unsupported DataType: %t".format(dtype2str(g.dtype)),
Expand All @@ -100,7 +168,7 @@ module ArgSortMsg
forall (nh, idx) in zip(newHashes, iv) with (var agg = newSrcAggregator((2*uint))) {
agg.copy(nh, hashes[idx]);
}
var deltaIV = radixSortLSD_ranks(newHashes);
var deltaIV = argsortDefault(newHashes);
// var (newOffsets, newVals) = s[iv];
// var deltaIV = newStr.argGroup();
var newIV: [aD] int;
Expand Down Expand Up @@ -130,7 +198,21 @@ module ArgSortMsg
proc coargsortMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
param pn = Reflection.getRoutineName();
var repMsg: string;
var (nstr, rest) = payload.splitMsgToTuple(2);
var (algoName, nstr, rest) = payload.splitMsgToTuple(3);
var algorithm: SortingAlgorithm = defaultSortAlgorithm;
if algoName != "" {
try {
algorithm = algoName: SortingAlgorithm;
} catch {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algoName),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
var n = nstr:int; // number of arrays to sort
var fields = rest.split();
asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
Expand Down Expand Up @@ -263,7 +345,7 @@ module ArgSortMsg
}
}

var iv = argsortDefault(merged);
var iv = argsortDefault(merged, algorithm=algorithm);
st.addEntry(ivname, new shared SymEntry(iv));

var repMsg = "created " + st.attrib(ivname);
Expand Down Expand Up @@ -305,12 +387,28 @@ module ArgSortMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc argsortDefault(A:[?D] ?t):[D] int throws {
proc argsortDefault(A:[?D] ?t, algorithm:SortingAlgorithm=defaultSortAlgorithm):[D] int throws {
var t1 = Time.getCurrentTime();
//var AI = [(a, i) in zip(A, D)] (a, i);
//Sort.TwoArrayRadixSort.twoArrayRadixSort(AI);
//var iv = [(a, i) in AI] i;
var iv = radixSortLSD_ranks(A);
var iv: [D] int;
select algorithm {
when SortingAlgorithm.TwoArrayRadixSort {
var AI = [(a, i) in zip(A, D)] (a, i);
Sort.TwoArrayRadixSort.twoArrayRadixSort(AI, comparator=myDefaultComparator);
iv = [(a, i) in AI] i;
}
when SortingAlgorithm.RadixSortLSD {
iv = radixSortLSD_ranks(A);
}
otherwise {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algorithm:string),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
try! asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"argsort time = %i".format(Time.getCurrentTime() - t1));
return iv;
Expand All @@ -321,8 +419,21 @@ module ArgSortMsg
param pn = Reflection.getRoutineName();
var repMsg: string; // response message
// split request into fields
var (objtype, name) = payload.splitMsgToTuple(2);

var (algoName, objtype, name) = payload.splitMsgToTuple(3);
var algorithm: SortingAlgorithm = defaultSortAlgorithm;
if algoName != "" {
try {
algorithm = algoName: SortingAlgorithm;
} catch {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algoName),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
// get next symbol name
var ivname = st.nextName();
asLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
Expand All @@ -338,7 +449,7 @@ module ArgSortMsg
select (gEnt.dtype) {
when (DType.Int64) {
var e = toSymEntry(gEnt,int);
var iv = argsortDefault(e.a);
var iv = argsortDefault(e.a, algorithm=algorithm);
st.addEntry(ivname, new shared SymEntry(iv));
}
when (DType.Float64) {
Expand Down
45 changes: 40 additions & 5 deletions src/SortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module SortMsg
use AryUtil;
use Logging;
use Message;
private use ArgSortMsg;

private config const logLevel = ServerConfig.logLevel;
const sortLogger = new Logger(logLevel);
Expand All @@ -29,8 +30,21 @@ module SortMsg
proc sortMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
param pn = Reflection.getRoutineName();
var repMsg: string; // response message
var (name) = payload.splitMsgToTuple(1);

var (algoName, name) = payload.splitMsgToTuple(2);
var algorithm: SortingAlgorithm = defaultSortAlgorithm;
if algoName != "" {
try {
algorithm = algoName: SortingAlgorithm;
} catch {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algoName),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
// get next symbol name
var sortedName = st.nextName();

Expand All @@ -43,18 +57,39 @@ module SortMsg
sortLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),
"cmd: %s name: %s sortedName: %s dtype: %t".format(
cmd, name, sortedName, gEnt.dtype));


proc doSort(a: [?D] ?t) throws {
select algorithm {
when SortingAlgorithm.TwoArrayRadixSort {
var b: [D] t = a;
Sort.TwoArrayRadixSort.twoArrayRadixSort(b, comparator=myDefaultComparator);
return b;
}
when SortingAlgorithm.RadixSortLSD {
return radixSortLSD_keys(a);
}
otherwise {
throw getErrorWithContext(
msg="Unrecognized sorting algorithm: %s".format(algorithm:string),
lineNumber=getLineNumber(),
routineName=getRoutineName(),
moduleName=getModuleName(),
errorClass="NotImplementedError"
);
}
}
}
// Sort the input pda and create a new symbol entry for
// the sorted pda.
select (gEnt.dtype) {
when (DType.Int64) {
var e = toSymEntry(gEnt, int);
var sorted = sort(e.a);
var sorted = doSort(e.a);
st.addEntry(sortedName, new shared SymEntry(sorted));
}// end when(DType.Int64)
when (DType.Float64) {
var e = toSymEntry(gEnt, real);
var sorted = sort(e.a);
var sorted = doSort(e.a);
st.addEntry(sortedName, new shared SymEntry(sorted));
}// end when(DType.Float64)
otherwise {
Expand Down
4 changes: 2 additions & 2 deletions test/UnitTestArgSortMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ prototype module UnitTestArgSort

// sort it and return iv in symbol table
var cmd = "argsort";
reqMsg = try! "pdarray %s".format(aname);
reqMsg = try! "%s pdarray %s".format(ArgSortMsg.SortingAlgorithm.RadixSortLSD: string, aname);
writeReq(reqMsg);
var d: Diags;
d.start();
Expand Down Expand Up @@ -57,7 +57,7 @@ prototype module UnitTestArgSort

// cosort both int and real arrays and return iv in symbol table
cmd = "coargsort";
reqMsg = try! "%i %s %s pdarray pdarray".format(2, aname, fname);
reqMsg = try! "%s %i %s %s pdarray pdarray".format(ArgSortMsg.SortingAlgorithm.RadixSortLSD: string, 2, aname, fname);
writeReq(reqMsg);
d.start();
repMsg = coargsortMsg(cmd=cmd, payload=reqMsg, st).msg;
Expand Down
Loading

0 comments on commit 20906e9

Please sign in to comment.