Skip to content

Commit

Permalink
Merge pull request #927 from pierce314159/911_regex_find_slice_Strings
Browse files Browse the repository at this point in the history
Closes #911: Implement `find_locations` and `findall` on regexes for `Strings`
  • Loading branch information
reuster986 authored Sep 28, 2021
2 parents a18b753 + 3902c23 commit dc41935
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 3 deletions.
144 changes: 142 additions & 2 deletions arkouda/strings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import itertools
from typing import cast, Tuple, List, Optional, Union
from typing import cast, Tuple, List, Optional, Union, Dict
from typeguard import typechecked
from arkouda.client import generic_msg
from arkouda.pdarrayclass import pdarray, create_pdarray, parse_single_value, \
Expand All @@ -13,7 +13,7 @@
translate_np_dtype
import json
import re
from arkouda.infoclass import information
from arkouda.infoclass import information, list_symbol_table

__all__ = ['Strings']

Expand All @@ -39,6 +39,10 @@ class Strings:
The sizes of each dimension of the array
dtype : dtype
The dtype is ak.str
regex_dict: Dict[str, Tuple[pdarray, pdarray, pdarray]]
Dictionary storing information on matches (cache of Strings.find_locations(pattern))
Keys - regex patterns
Values - tuples of pdarrays (numMatches, matchStarts, matchLens)
logger : ArkoudaLogger
Used for all logging operations
Expand Down Expand Up @@ -102,6 +106,7 @@ def __init__(self, offset_attrib : Union[pdarray,str],

self.dtype = npstr
self.name:Optional[str] = None
self.regex_dict: Dict = dict()
self.logger = getArkoudaLogger(name=__class__.__name__) # type: ignore

def __iter__(self):
Expand Down Expand Up @@ -255,6 +260,141 @@ def get_lengths(self) -> pdarray:
format(self.objtype, self.offsets.name, self.bytes.name)
return create_pdarray(generic_msg(cmd=cmd,args=args))

def cached_regex_patterns(self):
"""
Returns the regex patterns for which Strings.find_locations(pattern) have been cached
"""
sym_tab = list_symbol_table()
self.regex_dict = {key: val for key, val in self.regex_dict.items() if
all([pda.name in sym_tab for pda in val])}
return self.regex_dict.keys()

@typechecked
def find_locations(self, pattern: Union[bytes, str_scalars]) -> Tuple[pdarray, pdarray, pdarray]:
"""
Finds pattern matches and returns pdarrays containing the number, start postitions, and lengths of matches
Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)
Parameters
----------
pattern: str_scalars
The regex pattern used to find matches
Returns
-------
pdarray, int64
For each original string, the number of pattern matches
pdarray, int64
The start positons of pattern matches
pdarray, int64
The lengths of pattern matches
Raises
------
TypeError
Raised if the pattern parameter is not bytes or str_scalars
ValueError
Rasied if pattern is not a valid regex
RuntimeError
Raised if there is a server-side error thrown
See Also
--------
Strings.findall, Strings.match
Examples
--------
>>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
>>> strings
array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
>>> num_matches, starts, lens = strings.find_locations('\\d')
>>> num_matches
array([2, 2, 2, 2, 2])
>>> starts
array([0, 9, 11, 20, 22, 31, 33, 42, 44, 53])
>>> lens
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))
"""
if isinstance(pattern, bytes):
pattern = pattern.decode()
sym_tab = list_symbol_table()
if pattern not in self.regex_dict or any([pda.name not in sym_tab for pda in self.regex_dict[pattern]]):
# run find_locations if we don't have the result of find_locations(pattern) cached or any of the references have gone stale
try:
re.compile(pattern)
except Exception as e:
raise ValueError(e)
cmd = "segmentedFindLoc"
args = "{} {} {} {}".format(self.objtype,
self.offsets.name,
self.bytes.name,
json.dumps([pattern]))
repMsg = cast(str, generic_msg(cmd=cmd, args=args))
arrays = repMsg.split('+', maxsplit=2)
self.regex_dict[pattern] = (create_pdarray(arrays[0]), create_pdarray(arrays[1]), create_pdarray(arrays[2]))
return self.regex_dict[pattern]

@typechecked
def findall(self, pattern: Union[bytes, str_scalars], return_match_origins: bool = False) -> Union[Strings, Tuple]:
"""
Return all non-overlapping matches of pattern in Strings as a new Strings object
Note: only handles regular expressions supported by re2 (does not support lookaheads/lookbehinds)
Parameters
----------
pattern: str_scalars
The regex pattern used to find matches
return_match_origins: bool
If True, return a pdarray containing the index of the original string each pattern match is from
Returns
-------
Strings
Strings object containing only pattern matches
pdarray, int64 (optional)
The index of the original string each pattern match is from
Raises
------
TypeError
Raised if the pattern parameter is not bytes or str_scalars
ValueError
Rasied if pattern is not a valid regex
RuntimeError
Raised if there is a server-side error thrown
See Also
--------
Strings.find_locations, Strings.match
Examples
--------
>>> strings = ak.array(['{} string {}'.format(i, i) for i in range(1, 6)])
>>> strings
array(['1 string 1', '2 string 2', '3 string 3', '4 string 4', '5 string 5'])
>>> matches, match_origins = strings.findall('\\d', return_match_origins = True)
>>> matches
array(['1', '1', '2', '2', '3', '3', '4', '4', '5', '5'])
>>> match_origins
array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
"""
num_matches, starts, lens = self.find_locations(pattern)
cmd = "segmentedFindAll"
args = "{} {} {} {} {} {} {}".format(self.objtype,
self.offsets.name,
self.bytes.name,
num_matches.name,
starts.name,
lens.name,
return_match_origins)
repMsg = cast(str, generic_msg(cmd=cmd, args=args))
if return_match_origins:
arrays = repMsg.split('+', maxsplit=2)
return Strings(arrays[0], arrays[1]), create_pdarray(arrays[2])
else:
arrays = repMsg.split('+', maxsplit=1)
return Strings(arrays[0], arrays[1])

@typechecked
def contains(self, substr: Union[bytes, str_scalars], regex: bool = False) -> pdarray:
"""
Expand Down
109 changes: 108 additions & 1 deletion src/SegmentedArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ module SegmentedArray {
}

/*
Returns Regexp.compile if pattern can be compiled without an error
Returns Regexp.compile if pattern can be compiled without an error
*/
proc checkCompile(const pattern: ?t) throws where t == bytes || t == string {
try {
Expand All @@ -483,6 +483,113 @@ module SegmentedArray {
return try! compile(pattern);
}

/*
Given a SegString, finds pattern matches and returns pdarrays containing the number, start postitions, and lengths of matches
Note: the regular expression engine used, re2, does not support lookahead/lookbehind
:arg pattern: The regex pattern used to find matches
:type pattern: string
:returns: int64 pdarray – For each original string, the number of pattern matches and int64 pdarray – The start positons of pattern matches and int64 pdarray – The lengths of pattern matches
*/
proc findMatchLocations(const pattern: string) throws {
checkCompile(pattern);
ref origOffsets = this.offsets.a;
ref origVals = this.values.a;
const lengths = this.getLengths();

overMemLimit((this.offsets.size * numBytes(int)) + (2 * this.values.size * numBytes(int)));
var numMatches: [this.offsets.aD] int;
var matchStartBool: [this.values.aD] bool = false;
var sparseLens: [this.values.aD] int;

forall (i, off, len) in zip(this.offsets.aD, origOffsets, lengths) with (var myRegex = _unsafeCompileRegex(pattern),
var lenAgg = newDstAggregator(int),
var startAgg = newDstAggregator(bool),
var matchAgg = newDstAggregator(int)) {
var matches = myRegex.matches(interpretAsString(origVals[off..#len]));
for m in matches {
var match: reMatch = m[0];
lenAgg.copy(sparseLens[off + match.offset:int], match.size);
startAgg.copy(matchStartBool[off + match.offset:int], true);
}
matchAgg.copy(numMatches[i], matches.size);
}
var totalMatches = + reduce numMatches;
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(int) * matchStartBool.size);
// the matchTransform starts at 0 and increment after hitting a matchStart
// when looping over the origVals domain, matchTransform acts as a function: origVals.domain -> makeDistDom(totalMatches)
var matchTransform = + scan matchStartBool - matchStartBool;

var matchStarts: [makeDistDom(totalMatches)] int;
var matchLens: [makeDistDom(totalMatches)] int;
[i in this.values.aD] if (matchStartBool[i] == true) {
matchStarts[matchTransform[i]] = i;
matchLens[matchTransform[i]] = sparseLens[i];
}
return (numMatches, matchStarts, matchLens);
}

/*
Given a SegString, return a new SegString only containing matches of the regex pattern,
If returnMatchOrig is set to True, return a pdarray containing the index of the original string each pattern match is from
Note: the regular expression engine used, re2, does not support lookahead/lookbehind
:arg numMatchesEntry: For each string in SegString, the number of pattern matches
:type numMatchesEntry: borrowed SymEntry(int)
:arg startsEntry: The starting postions of pattern matches
:type startsEntry: borrowed SymEntry(int)
:arg lensEntry: The lengths of pattern matches
:type lensEntry: borrowed SymEntry(int)
:arg returnMatchOrig: If True, return a pdarray containing the index of the original string each pattern match is from
:type returnMatchOrig: bool
:returns: Strings – Only the portions of Strings which match pattern and (optional) int64 pdarray – For each pattern match, the index of the original string it was in
*/
proc findAllMatches(const numMatchesEntry: borrowed SymEntry(int), const startsEntry: borrowed SymEntry(int), const lensEntry: borrowed SymEntry(int), const returnMatchOrig: bool) throws {
ref origVals = this.values.a;
ref numMatches = numMatchesEntry.a;
ref matchStarts = startsEntry.a;
ref matchLens = lensEntry.a;

// matchesValsSize is the total length of all matches + the number of matches (to account for null bytes)
var matchesValsSize = (+ reduce matchLens) + matchLens.size;
// check there's enough room to create a copy for scan and to allocate matchesVals/Offsets
overMemLimit((matchesValsSize * numBytes(uint(8))) + (2 * matchLens.size * numBytes(int)));
var matchesVals: [makeDistDom(matchesValsSize)] uint(8);
var matchesOffsets: [makeDistDom(matchLens.size)] int;
// + current index to account for null bytes
var matchesIndicies = + scan matchLens - matchLens + lensEntry.aD;

forall (i, start, len, matchesInd) in zip(lensEntry.aD, matchStarts, matchLens, matchesIndicies) with (var valAgg = newDstAggregator(uint(8)), var offAgg = newDstAggregator(int)) {
for j in 0..#len {
// copy in match
valAgg.copy(matchesVals[matchesInd + j], origVals[start + j]);
}
// write null byte after each match
valAgg.copy(matchesVals[matchesInd + len], 0:uint(8));
if i == 0 {
offAgg.copy(matchesOffsets[i], 0);
}
if i != lensEntry.aD.high {
offAgg.copy(matchesOffsets[i+1], matchesInd + len + 1);
}
}

// build matchOrigins mapping from matchesStrings (pattern matches) to the original Strings they were found in
const matchOriginsDom = if returnMatchOrig then makeDistDom(matchesOffsets.size) else makeDistDom(0);
var matchOrigins: [matchOriginsDom] int;
if returnMatchOrig {
// check there's enough room to create a copy for scan and throw if creating a copy would go over memory limit
overMemLimit(numBytes(int) * numMatches.size);
var matchesIndicies = (+ scan numMatches) - numMatches;
forall (stringInd, matchInd) in zip(this.offsets.aD, matchesIndicies) with (var originAgg = newDstAggregator(int)) {
for k in matchInd..#numMatches[stringInd] {
// Each string has numMatches[stringInd] number of pattern matches, so matchOrigins needs to repeat stringInd for numMatches[stringInd] times
originAgg.copy(matchOrigins[k], stringInd);
}
}
}
return (matchesOffsets, matchesVals, matchOrigins);
}

/*
Returns list of bools where index i indicates whether the regular expression, pattern, matched string i of the SegString
Expand Down
84 changes: 84 additions & 0 deletions src/SegmentedMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,90 @@ module SegmentedMsg {
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc segmentedFindLocMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
var (objtype, segName, valName, patternJson) = payload.splitMsgToTuple(4);

// check to make sure symbols defined
st.checkTable(segName);
st.checkTable(valName);

const json = jsonToPdArray(patternJson, 1);
const pattern: string = json[json.domain.low];

smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
"cmd: %s objtype: %t".format(cmd, objtype));

select objtype {
when "str" {
const rNumMatchesName = st.nextName();
const rStartsName = st.nextName();
const rLensName = st.nextName();
var strings = getSegString(segName, valName, st);
var (numMatches, matchStarts, matchLens) = strings.findMatchLocations(pattern);
st.addEntry(rNumMatchesName, new shared SymEntry(numMatches));
st.addEntry(rStartsName, new shared SymEntry(matchStarts));
st.addEntry(rLensName, new shared SymEntry(matchLens));
repMsg = "created %s+created %s+created %s".format(st.attrib(rNumMatchesName),
st.attrib(rStartsName),
st.attrib(rLensName));
}
otherwise {
var errorMsg = "%s".format(objtype);
smLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
}
}
smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc segmentedFindAllMsg(cmd: string, payload: string, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
var (objtype, segName, valName, numMatchesName, startsName, lensName, returnMatchOrigStr) = payload.splitMsgToTuple(7);
const returnMatchOrig: bool = returnMatchOrigStr.toLower() == "true";

// check to make sure symbols defined
st.checkTable(segName);
st.checkTable(valName);
st.checkTable(numMatchesName);
st.checkTable(startsName);
st.checkTable(lensName);

smLogger.debug(getModuleName(), getRoutineName(), getLineNumber(),
"cmd: %s objtype: %t".format(cmd, objtype));

select objtype {
when "str" {
const rSegName = st.nextName();
const rValName = st.nextName();
const optName: string = if returnMatchOrig then st.nextName() else "";
var strings = getSegString(segName, valName, st);
var numMatches = st.lookup(numMatchesName): borrowed SymEntry(int);
var starts = st.lookup(startsName): borrowed SymEntry(int);
var lens = st.lookup(lensName): borrowed SymEntry(int);

var (off, val, matchOrigins) = strings.findAllMatches(numMatches, starts, lens, returnMatchOrig);
st.addEntry(rSegName, new shared SymEntry(off));
st.addEntry(rValName, new shared SymEntry(val));
repMsg = "created %s+created %s".format(st.attrib(rSegName), st.attrib(rValName));
if returnMatchOrig {
st.addEntry(optName, new shared SymEntry(matchOrigins));
repMsg += "+created %s".format(st.attrib(optName));
}
}
otherwise {
var errorMsg = "%s".format(objtype);
smLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(notImplementedError(pn, errorMsg), MsgType.ERROR);
}
}
smLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc createPeelSymEntries(loname, lo, lvname, lv, roname, ro, rvname, rv, st: borrowed SymTab) throws {
st.addEntry(loname, new shared SymEntry(lo));
st.addEntry(lvname, new shared SymEntry(lv));
Expand Down
Loading

0 comments on commit dc41935

Please sign in to comment.