def cached_regex_patterns(self):
    """
    Return the regex patterns whose Strings.find_locations(pattern) results are still cached.

    Entries of self.regex_dict are kept only if the name of every pdarray in
    the entry is present in the result of list_symbol_table(); all other
    entries are dropped from the cache before the keys are returned.

    Returns
    -------
    dict_keys
        View over the keys (regex patterns) of the pruned self.regex_dict
    """
    sym_tab = list_symbol_table()
    # generator instead of a list so all() can stop at the first stale pdarray
    self.regex_dict = {
        cached_pattern: pdas
        for cached_pattern, pdas in self.regex_dict.items()
        if all(pda.name in sym_tab for pda in pdas)
    }
    return self.regex_dict.keys()


@typechecked
def find_locations(self, pattern: Union[bytes, str_scalars]) -> Tuple[pdarray, pdarray, pdarray]:
    """
    Finds pattern matches and returns pdarrays containing the number, start positions, and lengths of matches
    Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)

    Results are cached per pattern in self.regex_dict; a cached result is reused
    only while the names of all three pdarrays are still reported by
    list_symbol_table().

    Parameters
    ----------
    pattern: str_scalars
        The regex pattern used to find matches

    Returns
    -------
    pdarray, int64
        For each original string, the number of pattern matches
    pdarray, int64
        The start positions of pattern matches
    pdarray, int64
        The lengths of pattern matches

    Raises
    ------
    TypeError
        Raised if the pattern parameter is not bytes or str_scalars
    ValueError
        Raised if pattern is not a valid regex
    RuntimeError
        Raised if there is a server-side error thrown

    See Also
    --------
    Strings.findall, Strings.match

    Examples
    --------
    >>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
    >>> strings
    array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
    >>> num_matches, starts, lens = strings.find_locations('\\d')
    >>> num_matches
    array([2, 2, 2, 2, 2])
    >>> starts
    array([0, 9, 11, 20, 22, 31, 33, 42, 44, 53])
    >>> lens
    array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
    """
    if isinstance(pattern, bytes):
        pattern = pattern.decode()

    cached = self.regex_dict.get(pattern)
    if cached is not None:
        # Only ask for the symbol table when there is a cached entry to
        # validate; on a cache miss the lookup would be wasted work
        # (presumably a server round trip -- confirm against infoclass).
        sym_tab = list_symbol_table()
        if all(pda.name in sym_tab for pda in cached):
            return cached

    # Either nothing is cached for this pattern or at least one of the cached
    # references has gone stale, so (re)compute the match locations.
    try:
        # Client-side sanity check with Python's re before messaging the server.
        # NOTE(review): the server compiles with re2, so a pattern accepted here
        # may still be rejected server-side (e.g. lookaheads).
        re.compile(pattern)
    except Exception as e:
        # chain the original error so the traceback shows why compilation failed
        raise ValueError(e) from e

    cmd = "segmentedFindLoc"
    args = "{} {} {} {}".format(self.objtype,
                                self.offsets.name,
                                self.bytes.name,
                                json.dumps([pattern]))
    repMsg = cast(str, generic_msg(cmd=cmd, args=args))
    # reply holds three '+'-separated pdarray descriptions:
    # numMatches, matchStarts, matchLens
    arrays = repMsg.split('+', maxsplit=2)
    locations = (create_pdarray(arrays[0]),
                 create_pdarray(arrays[1]),
                 create_pdarray(arrays[2]))
    self.regex_dict[pattern] = locations
    return locations


@typechecked
def findall(self, pattern: Union[bytes, str_scalars], return_match_origins: bool = False) -> Union[Strings, Tuple]:
    """
    Return all non-overlapping matches of pattern in Strings as a new Strings object
    Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)

    Parameters
    ----------
    pattern: str_scalars
        The regex pattern used to find matches
    return_match_origins: bool
        If True, return a pdarray containing the index of the original string each pattern match is from

    Returns
    -------
    Strings
        Strings object containing only pattern matches
    pdarray, int64 (optional)
        The index of the original string each pattern match is from

    Raises
    ------
    TypeError
        Raised if the pattern parameter is not bytes or str_scalars
    ValueError
        Raised if pattern is not a valid regex
    RuntimeError
        Raised if there is a server-side error thrown

    See Also
    --------
    Strings.find_locations, Strings.match

    Examples
    --------
    >>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
    >>> strings
    array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
    >>> matches, match_origins = strings.findall('\\d', return_match_origins = True)
    >>> matches
    array(['1', '1', '2', '2', '3', '3', '4', '4', '5', '5'])
    >>> match_origins
    array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
    """
    # the match locations come from (and are cached by) find_locations
    num_matches, starts, lens = self.find_locations(pattern)
    cmd = "segmentedFindAll"
    args = "{} {} {} {} {} {} {}".format(self.objtype,
                                         self.offsets.name,
                                         self.bytes.name,
                                         num_matches.name,
                                         starts.name,
                                         lens.name,
                                         return_match_origins)
    repMsg = cast(str, generic_msg(cmd=cmd, args=args))
    if return_match_origins:
        # reply carries a third '+'-separated field: the match origins pdarray
        arrays = repMsg.split('+', maxsplit=2)
        return Strings(arrays[0], arrays[1]), create_pdarray(arrays[2])
    arrays = repMsg.split('+', maxsplit=1)
    return Strings(arrays[0], arrays[1])
// ---- src/SegmentedArray.chpl: methods added to SegString ----

/*
  Given a SegString, finds pattern matches and returns pdarrays containing the number, start positions, and lengths of matches
  Note: the regular expression engine used, re2, does not support lookahead/lookbehind

  Match start positions are indices into this.values.a (the byte array of the
  whole SegString), not offsets relative to the individual string.

  :arg pattern: The regex pattern used to find matches
  :type pattern: string

  :returns: int64 pdarray – For each original string, the number of pattern matches and int64 pdarray – The start positions of pattern matches and int64 pdarray – The lengths of pattern matches
*/
proc findMatchLocations(const pattern: string) throws {
  // throws if the pattern cannot be compiled; the result is discarded here
  checkCompile(pattern);
  ref origOffsets = this.offsets.a;
  ref origVals = this.values.a;
  const lengths = this.getLengths();

  // check there's enough room for numMatches, matchStartBool and sparseLens before allocating them
  overMemLimit((this.offsets.size * numBytes(int)) + (2 * this.values.size * numBytes(int)));
  // numMatches[i] is the number of matches found in string i
  var numMatches: [this.offsets.aD] int;
  // matchStartBool[b] is true when a match starts at byte b of the values array
  var matchStartBool: [this.values.aD] bool = false;
  // sparseLens[b] holds the length of the match starting at byte b (only written where a match starts)
  var sparseLens: [this.values.aD] int;

  // each task gets its own compiled regex and its own aggregators
  forall (i, off, len) in zip(this.offsets.aD, origOffsets, lengths) with (var myRegex = _unsafeCompileRegex(pattern),
                                                                           var lenAgg = newDstAggregator(int),
                                                                           var startAgg = newDstAggregator(bool),
                                                                           var matchAgg = newDstAggregator(int)) {
    var matches = myRegex.matches(interpretAsString(origVals[off..#len]));
    for m in matches {
      // m[0] is the whole-pattern match; its offset is relative to the start of string i
      var match: reMatch = m[0];
      lenAgg.copy(sparseLens[off + match.offset:int], match.size);
      startAgg.copy(matchStartBool[off + match.offset:int], true);
    }
    matchAgg.copy(numMatches[i], matches.size);
  }
  // NOTE(review): totalMatches is derived from numMatches while the compaction below is driven by
  // matchStartBool, so this relies on every match having a distinct start byte -- confirm for
  // patterns that can produce empty matches at a string boundary.
  var totalMatches = + reduce numMatches;
  // check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
  overMemLimit(numBytes(int) * matchStartBool.size);
  // the matchTransform starts at 0 and increments after hitting a matchStart
  // when looping over the origVals domain, matchTransform acts as a function: origVals.domain -> makeDistDom(totalMatches)
  var matchTransform = + scan matchStartBool - matchStartBool;

  var matchStarts: [makeDistDom(totalMatches)] int;
  var matchLens: [makeDistDom(totalMatches)] int;
  // compact the sparse per-byte information into dense per-match arrays
  [i in this.values.aD] if (matchStartBool[i] == true) {
    matchStarts[matchTransform[i]] = i;
    matchLens[matchTransform[i]] = sparseLens[i];
  }
  return (numMatches, matchStarts, matchLens);
}

/*
  Given a SegString, return a new SegString only containing matches of the regex pattern,
  If returnMatchOrig is set to True, return a pdarray containing the index of the original string each pattern match is from
  Note: the regular expression engine used, re2, does not support lookahead/lookbehind

  The three entries are expected to be the arrays produced by findMatchLocations for this SegString.

  :arg numMatchesEntry: For each string in SegString, the number of pattern matches
  :type numMatchesEntry: borrowed SymEntry(int)

  :arg startsEntry: The starting positions of pattern matches
  :type startsEntry: borrowed SymEntry(int)

  :arg lensEntry: The lengths of pattern matches
  :type lensEntry: borrowed SymEntry(int)

  :arg returnMatchOrig: If True, return a pdarray containing the index of the original string each pattern match is from
  :type returnMatchOrig: bool

  :returns: Strings – Only the portions of Strings which match pattern and (optional) int64 pdarray – For each pattern match, the index of the original string it was in
*/
proc findAllMatches(const numMatchesEntry: borrowed SymEntry(int), const startsEntry: borrowed SymEntry(int), const lensEntry: borrowed SymEntry(int), const returnMatchOrig: bool) throws {
  ref origVals = this.values.a;
  ref numMatches = numMatchesEntry.a;
  ref matchStarts = startsEntry.a;
  ref matchLens = lensEntry.a;

  // matchesValsSize is the total length of all matches + the number of matches (to account for null bytes)
  var matchesValsSize = (+ reduce matchLens) + matchLens.size;
  // check there's enough room to create a copy for scan and to allocate matchesVals/Offsets
  overMemLimit((matchesValsSize * numBytes(uint(8))) + (2 * matchLens.size * numBytes(int)));
  var matchesVals: [makeDistDom(matchesValsSize)] uint(8);
  var matchesOffsets: [makeDistDom(matchLens.size)] int;
  // exclusive scan of the match lengths + current index to account for null bytes:
  // matchesIndicies[i] is where match i begins inside matchesVals
  var matchesIndicies = + scan matchLens - matchLens + lensEntry.aD;

  forall (i, start, len, matchesInd) in zip(lensEntry.aD, matchStarts, matchLens, matchesIndicies) with (var valAgg = newDstAggregator(uint(8)), var offAgg = newDstAggregator(int)) {
    for j in 0..#len {
      // copy in match
      valAgg.copy(matchesVals[matchesInd + j], origVals[start + j]);
    }
    // write null byte after each match
    valAgg.copy(matchesVals[matchesInd + len], 0:uint(8));
    // the first match always starts at offset 0
    if i == 0 {
      offAgg.copy(matchesOffsets[i], 0);
    }
    // every match except the last records where the following match starts
    // (just past this match's null byte)
    if i != lensEntry.aD.high {
      offAgg.copy(matchesOffsets[i+1], matchesInd + len + 1);
    }
  }

  // build matchOrigins mapping from matchesStrings (pattern matches) to the original Strings they were found in
  // (allocated with an empty domain when the caller did not ask for it)
  const matchOriginsDom = if returnMatchOrig then makeDistDom(matchesOffsets.size) else makeDistDom(0);
  var matchOrigins: [matchOriginsDom] int;
  if returnMatchOrig {
    // check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
    overMemLimit(numBytes(int) * numMatches.size);
    // exclusive scan of numMatches: index of the first match belonging to each original string
    // (this local shadows the matchesIndicies declared above)
    var matchesIndicies = (+ scan numMatches) - numMatches;
    forall (stringInd, matchInd) in zip(this.offsets.aD, matchesIndicies) with (var originAgg = newDstAggregator(int)) {
      for k in matchInd..#numMatches[stringInd] {
        // Each string has numMatches[stringInd] number of pattern matches, so matchOrigins needs to repeat stringInd for numMatches[stringInd] times
        originAgg.copy(matchOrigins[k], stringInd);
      }
    }
  }
  return (matchesOffsets, matchesVals, matchOrigins);
}

// ---- src/SegmentedMsg.chpl: message handlers ----

/*
  Handler for the "segmentedFindLoc" command.

  payload is split into 4 fields: objtype, the names of the segments and values
  entries, and a JSON array whose first element is the regex pattern.
  For objtype "str" the three arrays returned by findMatchLocations are stored
  in the symbol table under fresh names and the reply is
  "created <numMatches>+created <matchStarts>+created <matchLens>".
  Any other objtype yields a notImplementedError reply.
*/
proc segmentedFindLocMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
  var pn = Reflection.getRoutineName();
  var repMsg: string;
  var (objtype, segName, valName, patternJson) = payload.splitMsgToTuple(4);

  // check to make sure symbols defined
  st.checkTable(segName);
  st.checkTable(valName);

  // the pattern arrives as a one-element JSON array; take its first element
  const json = jsonToPdArray(patternJson, 1);
  const pattern: string = json[json.domain.low];

  smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
                 "cmd: %s objtype: %t".format(cmd, objtype));

  select objtype {
    when "str" {
      const rNumMatchesName = st.nextName();
      const rStartsName = st.nextName();
      const rLensName = st.nextName();
      var strings = getSegString(segName, valName, st);
      var (numMatches, matchStarts, matchLens) = strings.findMatchLocations(pattern);
      st.addEntry(rNumMatchesName, new shared SymEntry(numMatches));
      st.addEntry(rStartsName, new shared SymEntry(matchStarts));
      st.addEntry(rLensName, new shared SymEntry(matchLens));
      // the client splits this reply on '+'
      repMsg = "created %s+created %s+created %s".format(st.attrib(rNumMatchesName),
                                                         st.attrib(rStartsName),
                                                         st.attrib(rLensName));
    }
    otherwise {
      var errorMsg = "%s".format(objtype);
      smLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
      return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
    }
  }
  smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
  return new MsgTuple(repMsg, MsgType.NORMAL);
}

/*
  Handler for the "segmentedFindAll" command.

  payload is split into 7 fields: objtype, the names of the segments and values
  entries, the names of the numMatches / starts / lens entries, and a boolean
  string selecting whether match origins are returned.
  For objtype "str" the offsets and values returned by findAllMatches are stored
  under fresh names and the reply is "created <offsets>+created <values>",
  followed by "+created <matchOrigins>" when match origins were requested.
  Any other objtype yields a notImplementedError reply.
*/
proc segmentedFindAllMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
  var pn = Reflection.getRoutineName();
  var repMsg: string;
  var (objtype, segName, valName, numMatchesName, startsName, lensName, returnMatchOrigStr) = payload.splitMsgToTuple(7);
  // case-insensitive comparison against "true"; anything else is treated as false
  const returnMatchOrig: bool = returnMatchOrigStr.toLower() == "true";

  // check to make sure symbols defined
  st.checkTable(segName);
  st.checkTable(valName);
  st.checkTable(numMatchesName);
  st.checkTable(startsName);
  st.checkTable(lensName);

  smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
                 "cmd: %s objtype: %t".format(cmd, objtype));

  select objtype {
    when "str" {
      const rSegName = st.nextName();
      const rValName = st.nextName();
      // a third name is only reserved when match origins will be returned
      const optName: string = if returnMatchOrig then st.nextName() else "";
      var strings = getSegString(segName, valName, st);
      var numMatches = st.lookup(numMatchesName): borrowed SymEntry(int);
      var starts = st.lookup(startsName): borrowed SymEntry(int);
      var lens = st.lookup(lensName): borrowed SymEntry(int);

      var (off, val, matchOrigins) = strings.findAllMatches(numMatches, starts, lens, returnMatchOrig);
      st.addEntry(rSegName, new shared SymEntry(off));
      st.addEntry(rValName, new shared SymEntry(val));
      repMsg = "created %s+created %s".format(st.attrib(rSegName), st.attrib(rValName));
      if returnMatchOrig {
        st.addEntry(optName, new shared SymEntry(matchOrigins));
        repMsg += "+created %s".format(st.attrib(optName));
      }
    }
    otherwise {
      var errorMsg = "%s".format(objtype);
      smLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
      return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
    }
  }
  smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
  return new MsgTuple(repMsg, MsgType.NORMAL);
}
def test_regex_find_locations(self):
    """find_locations reports per-string match counts plus global starts and lengths."""
    strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])

    # (pattern, expected num_matches, expected starts, expected lens)
    cases = (
        ('\\d',
         [2, 2, 2, 2, 2],
         [0, 9, 11, 20, 22, 31, 33, 42, 44, 53],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
        ('string \\d',
         [1, 1, 1, 1, 1],
         [2, 13, 24, 35, 46],
         [8, 8, 8, 8, 8]),
    )
    for regex, want_counts, want_starts, want_lens in cases:
        got_counts, got_starts, got_lens = strings.find_locations(regex)
        self.assertListEqual(want_counts, got_counts.to_ndarray().tolist())
        self.assertListEqual(want_starts, got_starts.to_ndarray().tolist())
        self.assertListEqual(want_lens, got_lens.to_ndarray().tolist())


def test_regex_findall(self):
    """findall returns the matched substrings and, on request, their origin indices."""
    strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])

    # single digits: two matches per original string
    digit_matches = ['1', '1', '2', '2', '3', '3', '4', '4', '5', '5']
    digit_origins = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    found, origins = strings.findall('\\d', return_match_origins=True)
    self.assertListEqual(digit_matches, found.to_ndarray().tolist())
    self.assertListEqual(digit_origins, origins.to_ndarray().tolist())
    # without return_match_origins only the Strings object comes back
    found = strings.findall('\\d')
    self.assertListEqual(digit_matches, found.to_ndarray().tolist())

    # multi-character matches: one match per original string
    found, origins = strings.findall('string \\d', return_match_origins=True)
    self.assertListEqual(['string 1', 'string 2', 'string 3', 'string 4', 'string 5'],
                         found.to_ndarray().tolist())
    self.assertListEqual([0, 1, 2, 3, 4], origins.to_ndarray().tolist())

    # variable-length matches, including an empty string and a string with no match
    under = ak.array(['', '____', '_1_2', '3___4___', '5'])
    found, origins = under.findall('_+', return_match_origins=True)
    self.assertListEqual(['____', '_', '_', '___', '___'], found.to_ndarray().tolist())
    self.assertListEqual([1, 2, 2, 3, 3], origins.to_ndarray().tolist())