Skip to content


Merge pull request #927 from pierce314159/911_regex_find_slice_Strings
Browse files Browse the repository at this point in the history
Closes #911: Implement `find_locations` and `findall` on regexes for `Strings`
  • Loading branch information
reuster986 authored Sep 28, 2021
2 parents a18b753 + 3902c23 commit dc41935
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 3 deletions.
144 changes: 142 additions & 2 deletions arkouda/
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import itertools
from typing import cast, Tuple, List, Optional, Union
from typing import cast, Tuple, List, Optional, Union, Dict
from typeguard import typechecked
from arkouda.client import generic_msg
from arkouda.pdarrayclass import pdarray, create_pdarray, parse_single_value, \
Expand All @@ -13,7 +13,7 @@
import json
import re
from arkouda.infoclass import information
from arkouda.infoclass import information, list_symbol_table

__all__ = ['Strings']

Expand All @@ -39,6 +39,10 @@ class Strings:
The sizes of each dimension of the array
dtype : dtype
The dtype is ak.str
regex_dict: Dict[str, Tuple[pdarray, pdarray, pdarray]]
Dictionary storing information on matches (cache of Strings.find_locations(pattern))
Keys - regex patterns
Values - tuples of pdarrays (numMatches, matchStarts, matchLens)
logger : ArkoudaLogger
Used for all logging operations
Expand Down Expand Up @@ -102,6 +106,7 @@ def __init__(self, offset_attrib : Union[pdarray,str],

self.dtype = npstr[str] = None
self.regex_dict: Dict = dict()
self.logger = getArkoudaLogger(name=__class__.__name__) # type: ignore

def __iter__(self):
Expand Down Expand Up @@ -255,6 +260,141 @@ def get_lengths(self) -> pdarray:
return create_pdarray(generic_msg(cmd=cmd,args=args))

def cached_regex_patterns(self):
Returns the regex patterns for which Strings.find_locations(pattern) have been cached
sym_tab = list_symbol_table()
self.regex_dict = {key: val for key, val in self.regex_dict.items() if
all([ in sym_tab for pda in val])}
return self.regex_dict.keys()

def find_locations(self, pattern: Union[bytes, str_scalars]) -> Tuple[pdarray, pdarray, pdarray]:
Finds pattern matches and returns pdarrays containing the number, start positions, and lengths of matches
Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)
pattern: str_scalars
The regex pattern used to find matches
pdarray, int64
For each original string, the number of pattern matches
pdarray, int64
The start positions of pattern matches
pdarray, int64
The lengths of pattern matches
Raised if the pattern parameter is not bytes or str_scalars
Raised if pattern is not a valid regex
Raised if there is a server-side error thrown
See Also
Strings.findall, Strings.match
>>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
>>> strings
array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
>>> num_matches, starts, lens = strings.find_locations('\\d')
>>> num_matches
array([2, 2, 2, 2, 2])
>>> starts
array([0, 9, 11, 20, 22, 31, 33, 42, 44, 53])
>>> lens
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
if isinstance(pattern, bytes):
pattern = pattern.decode()
sym_tab = list_symbol_table()
if pattern not in self.regex_dict or any([ not in sym_tab for pda in self.regex_dict[pattern]]):
# run find_locations if we don't have the result of find_locations(pattern) cached or any of the references have gone stale
except Exception as e:
raise ValueError(e)
cmd = "segmentedFindLoc"
args = "{} {} {} {}".format(self.objtype,,,
repMsg = cast(str, generic_msg(cmd=cmd, args=args))
arrays = repMsg.split('+', maxsplit=2)
self.regex_dict[pattern] = (create_pdarray(arrays[0]), create_pdarray(arrays[1]), create_pdarray(arrays[2]))
return self.regex_dict[pattern]

def findall(self, pattern: Union[bytes, str_scalars], return_match_origins: bool = False) -> Union[Strings, Tuple]:
Return all non-overlapping matches of pattern in Strings as a new Strings object
Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)
pattern: str_scalars
The regex pattern used to find matches
return_match_origins: bool
If True, return a pdarray containing the index of the original string each pattern match is from
Strings object containing only pattern matches
pdarray, int64 (optional)
The index of the original string each pattern match is from
Raised if the pattern parameter is not bytes or str_scalars
Raised if pattern is not a valid regex
Raised if there is a server-side error thrown
See Also
Strings.find_locations, Strings.match
>>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
>>> strings
array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
>>> matches, match_origins = strings.findall('\\d', return_match_origins = True)
>>> matches
array(['1', '1', '2', '2', '3', '3', '4', '4', '5', '5'])
>>> match_origins
array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
num_matches, starts, lens = self.find_locations(pattern)
cmd = "segmentedFindAll"
args = "{} {} {} {} {} {} {}".format(self.objtype,,,,,,
repMsg = cast(str, generic_msg(cmd=cmd, args=args))
if return_match_origins:
arrays = repMsg.split('+', maxsplit=2)
return Strings(arrays[0], arrays[1]), create_pdarray(arrays[2])
arrays = repMsg.split('+', maxsplit=1)
return Strings(arrays[0], arrays[1])

def contains(self, substr: Union[bytes, str_scalars], regex: bool = False) -> pdarray:
Expand Down
109 changes: 108 additions & 1 deletion src/SegmentedArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ module SegmentedArray {

Returns Regexp.compile if pattern can be compiled without an error
Returns Regexp.compile if pattern can be compiled without an error
proc checkCompile(const pattern: ?t) throws where t == bytes || t == string {
try {
Expand All @@ -483,6 +483,113 @@ module SegmentedArray {
return try! compile(pattern);

Given a SegString, finds pattern matches and returns pdarrays containing the number, start positions, and lengths of matches
Note: the regular expression engine used, re2, does not support lookahead/lookbehind
:arg pattern: The regex pattern used to find matches
:type pattern: string
:returns: int64 pdarray – For each original string, the number of pattern matches and int64 pdarray – The start positions of pattern matches and int64 pdarray – The lengths of pattern matches
// Finds every match of the regex `pattern` in each string of this segmented
// string and returns the tuple (numMatches, matchStarts, matchLens):
//   numMatches  - one int per original string (declared over this.offsets.aD)
//   matchStarts - one int per match: the index into this.values where it starts
//   matchLens   - one int per match: the length recorded for that match
// NOTE(review): closing braces are absent from this listing; the nesting
// described below is inferred from the statement order -- confirm against the
// real file.
proc findMatchLocations(const pattern: string) throws {
ref origOffsets = this.offsets.a;
ref origVals = this.values.a;
const lengths = this.getLengths();

// Memory check sized for the three arrays declared immediately below
// (one int per string, plus two per-byte arrays).
overMemLimit((this.offsets.size * numBytes(int)) + (2 * this.values.size * numBytes(int)));
var numMatches: [this.offsets.aD] int;
// Per-byte scratch arrays over the values domain: matchStartBool flags the
// positions where a match begins; sparseLens holds the match length at those
// same positions.
var matchStartBool: [this.values.aD] bool = false;
var sparseLens: [this.values.aD] int;

// One iteration per string. The `with (var ...)` clause gives each task its
// own compiled regex and its own aggregators.
forall (i, off, len) in zip(this.offsets.aD, origOffsets, lengths) with (var myRegex = _unsafeCompileRegex(pattern),
var lenAgg = newDstAggregator(int),
var startAgg = newDstAggregator(bool),
var matchAgg = newDstAggregator(int)) {
// Match against the len bytes of origVals starting at off.
var matches = myRegex.matches(interpretAsString(origVals[off..#len]));
for m in matches {
// m[0] is taken as the match record; presumably the whole-pattern match
// rather than a capture group -- TODO confirm.
var match: reMatch = m[0];
// match.offset is added to off, so both writes land at a position in the
// values array.
lenAgg.copy(sparseLens[off + match.offset:int], match.size);
startAgg.copy(matchStartBool[off + match.offset:int], true);
matchAgg.copy(numMatches[i], matches.size);
var totalMatches = + reduce numMatches;
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(int) * matchStartBool.size);
// the matchTransform starts at 0 and increment after hitting a matchStart
// when looping over the origVals domain, matchTransform acts as a function: origVals.domain -> makeDistDom(totalMatches)
var matchTransform = + scan matchStartBool - matchStartBool;

// Dense outputs, one slot per match.
var matchStarts: [makeDistDom(totalMatches)] int;
var matchLens: [makeDistDom(totalMatches)] int;
// Compact the sparse per-byte data: every flagged position i writes its index
// and length into slot matchTransform[i].
[i in this.values.aD] if (matchStartBool[i] == true) {
matchStarts[matchTransform[i]] = i;
matchLens[matchTransform[i]] = sparseLens[i];
return (numMatches, matchStarts, matchLens);

Given a SegString, return a new SegString only containing matches of the regex pattern,
If returnMatchOrig is set to True, return a pdarray containing the index of the original string each pattern match is from
Note: the regular expression engine used, re2, does not support lookahead/lookbehind
:arg numMatchesEntry: For each string in SegString, the number of pattern matches
:type numMatchesEntry: borrowed SymEntry(int)
:arg startsEntry: The starting positions of pattern matches
:type startsEntry: borrowed SymEntry(int)
:arg lensEntry: The lengths of pattern matches
:type lensEntry: borrowed SymEntry(int)
:arg returnMatchOrig: If True, return a pdarray containing the index of the original string each pattern match is from
:type returnMatchOrig: bool
:returns: Strings – Only the portions of Strings which match pattern and (optional) int64 pdarray – For each pattern match, the index of the original string it was in
// Builds a new segmented string made only of the matched substrings described
// by (startsEntry, lensEntry), and optionally a per-match array holding the
// index of the original string each match came from.
// Returns the tuple (matchesOffsets, matchesVals, matchOrigins); matchOrigins
// is declared over makeDistDom(0) when returnMatchOrig is false.
// NOTE(review): closing braces are absent from this listing; the nesting
// described below is inferred from the statement order -- confirm against the
// real file.
proc findAllMatches(const numMatchesEntry: borrowed SymEntry(int), const startsEntry: borrowed SymEntry(int), const lensEntry: borrowed SymEntry(int), const returnMatchOrig: bool) throws {
ref origVals = this.values.a;
ref numMatches = numMatchesEntry.a;
ref matchStarts = startsEntry.a;
ref matchLens = lensEntry.a;

// matchesValsSize is the total length of all matches + the number of matches (to account for null bytes)
var matchesValsSize = (+ reduce matchLens) + matchLens.size;
// check there's enough room to create a copy for scan and to allocate matchesVals/Offsets
overMemLimit((matchesValsSize * numBytes(uint(8))) + (2 * matchLens.size * numBytes(int)));
var matchesVals: [makeDistDom(matchesValsSize)] uint(8);
var matchesOffsets: [makeDistDom(matchLens.size)] int;
// + current index to account for null bytes
// i.e. match k starts at (sum of the lengths of matches before k) + k
var matchesIndicies = + scan matchLens - matchLens + lensEntry.aD;

// One iteration per match: copy its bytes, terminate it, and record offsets.
forall (i, start, len, matchesInd) in zip(lensEntry.aD, matchStarts, matchLens, matchesIndicies) with (var valAgg = newDstAggregator(uint(8)), var offAgg = newDstAggregator(int)) {
for j in 0..#len {
// copy in match
valAgg.copy(matchesVals[matchesInd + j], origVals[start + j]);
// write null byte after each match
valAgg.copy(matchesVals[matchesInd + len], 0:uint(8));
// The iteration with i == 0 writes offset 0 for the first match; every
// iteration except the last writes the offset of the NEXT match (i+1),
// which is the position just past this match's null byte.
if i == 0 {
offAgg.copy(matchesOffsets[i], 0);
if i != lensEntry.aD.high {
offAgg.copy(matchesOffsets[i+1], matchesInd + len + 1);

// build matchOrigins mapping from matchesStrings (pattern matches) to the original Strings they were found in
const matchOriginsDom = if returnMatchOrig then makeDistDom(matchesOffsets.size) else makeDistDom(0);
var matchOrigins: [matchOriginsDom] int;
if returnMatchOrig {
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(int) * numMatches.size);
// Reuses the name matchesIndicies for a different quantity: the running
// total of numMatches over all earlier strings, used below as each string's
// first slot in matchOrigins.
var matchesIndicies = (+ scan numMatches) - numMatches;
forall (stringInd, matchInd) in zip(this.offsets.aD, matchesIndicies) with (var originAgg = newDstAggregator(int)) {
for k in matchInd..#numMatches[stringInd] {
// Each string has numMatches[stringInd] number of pattern matches, so matchOrigins needs to repeat stringInd for numMatches[stringInd] times
originAgg.copy(matchOrigins[k], stringInd);
return (matchesOffsets, matchesVals, matchOrigins);

Returns list of bools where index i indicates whether the regular expression, pattern, matched string i of the SegString
Expand Down
84 changes: 84 additions & 0 deletions src/SegmentedMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,90 @@ module SegmentedMsg {
return new MsgTuple(repMsg, MsgType.NORMAL);

// Message handler for the find-locations request on a segmented string.
// The payload is split into four fields: objtype, segName, valName and a
// JSON-encoded pattern. For objtype "str" it calls findMatchLocations(pattern),
// adds the three result arrays to the symbol table under fresh names, and
// builds a reply that starts with three "created ..." parts joined by '+'.
// Any other objtype returns a notImplementedError MsgTuple of type ERROR.
// NOTE(review): closing braces are absent from this listing, and the
// `repMsg = ...format(` statement below is truncated here -- confirm against
// the real file.
proc segmentedFindLocMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
var (objtype, segName, valName, patternJson) = payload.splitMsgToTuple(4);

// check to make sure symbols defined

// patternJson is decoded with jsonToPdArray(patternJson, 1) and the pattern is
// read from the lowest index of the result.
const json = jsonToPdArray(patternJson, 1);
const pattern: string = json[json.domain.low];

smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
"cmd: %s objtype: %t".format(cmd, objtype));

select objtype {
when "str" {
// Names for the three result entries are taken from st.nextName() before
// the arrays are computed.
const rNumMatchesName = st.nextName();
const rStartsName = st.nextName();
const rLensName = st.nextName();
var strings = getSegString(segName, valName, st);
var (numMatches, matchStarts, matchLens) = strings.findMatchLocations(pattern);
st.addEntry(rNumMatchesName, new shared SymEntry(numMatches));
st.addEntry(rStartsName, new shared SymEntry(matchStarts));
st.addEntry(rLensName, new shared SymEntry(matchLens));
repMsg = "created %s+created %s+created %s".format(st.attrib(rNumMatchesName),
otherwise {
var errorMsg = "%s".format(objtype);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
return new MsgTuple(repMsg, MsgType.NORMAL);

// Message handler for the find-all request on a segmented string.
// The payload is split into seven fields: objtype, segName, valName, the
// symbol-table names of the numMatches / starts / lens arrays, and a
// return-match-origins flag. For objtype "str" it calls findAllMatches, adds
// the resulting offsets and values arrays to the symbol table, and replies
// "created <seg>+created <val>", appending "+created <origins>" when the flag
// is set. Any other objtype returns a notImplementedError MsgTuple of type
// ERROR.
// NOTE(review): closing braces are absent from this listing -- confirm the
// nesting against the real file.
proc segmentedFindAllMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
var (objtype, segName, valName, numMatchesName, startsName, lensName, returnMatchOrigStr) = payload.splitMsgToTuple(7);
// The flag arrives as text; any casing of "true" enables it, and every other
// value yields false.
const returnMatchOrig: bool = returnMatchOrigStr.toLower() == "true";

// check to make sure symbols defined

smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
"cmd: %s objtype: %t".format(cmd, objtype));

select objtype {
when "str" {
const rSegName = st.nextName();
const rValName = st.nextName();
// st.nextName() is called for the optional origins array only when it will
// be returned.
const optName: string = if returnMatchOrig then st.nextName() else "";
var strings = getSegString(segName, valName, st);
// Look up the three inputs by name and cast them to int SymEntry; presumably
// these are the arrays produced by a prior find-locations request -- TODO
// confirm with the client code.
var numMatches = st.lookup(numMatchesName): borrowed SymEntry(int);
var starts = st.lookup(startsName): borrowed SymEntry(int);
var lens = st.lookup(lensName): borrowed SymEntry(int);

var (off, val, matchOrigins) = strings.findAllMatches(numMatches, starts, lens, returnMatchOrig);
st.addEntry(rSegName, new shared SymEntry(off));
st.addEntry(rValName, new shared SymEntry(val));
repMsg = "created %s+created %s".format(st.attrib(rSegName), st.attrib(rValName));
if returnMatchOrig {
st.addEntry(optName, new shared SymEntry(matchOrigins));
repMsg += "+created %s".format(st.attrib(optName));
otherwise {
var errorMsg = "%s".format(objtype);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
return new MsgTuple(repMsg, MsgType.NORMAL);

proc createPeelSymEntries(loname, lo, lvname, lv, roname, ro, rvname, rv, st: borrowed SymTab) throws {
st.addEntry(loname, new shared SymEntry(lo));
st.addEntry(lvname, new shared SymEntry(lv));
Expand Down

0 comments on commit dc41935

Please sign in to comment.