Run rustfmt.
pmfirestone committed Dec 18, 2024
1 parent 0d8c9cc commit 9e19f50
Showing 2 changed files with 88 additions and 87 deletions.
107 changes: 53 additions & 54 deletions src/dfa.rs
@@ -3,8 +3,8 @@ use regex_automata::{
util::{primitives::StateID, start},
Anchored,
};
use std::collections::{HashMap, VecDeque};
use std::hash::{Hash, Hasher};
use std::collections::{VecDeque, HashMap};
use std::rc::Rc;

type DFACache = HashMap<String, Rc<DFA>>;
@@ -22,31 +22,31 @@ pub struct DFAState {
pub state_id: StateID,
}



/// Construct DFAs with caching. Only one of these should be instantiated in
/// the lifetime of the program.
pub struct DFABuilder {
cache: DFACache
cache: DFACache,
}

impl DFABuilder {
/// Initialize with an empty cache.
pub fn new() -> DFABuilder {
DFABuilder{cache: HashMap::new()}
DFABuilder {
cache: HashMap::new(),
}
}

    /// Return a DFAState, either from the cache or by building a new one from scratch.
// FIXME: Remove the clones from this function to accelerate it further.
pub fn build_dfa(&mut self, regex: &str) -> DFAState {
match self.cache.get(regex) {
Some(dfa) => return DFAState::new(regex, dfa.clone()),
None => {
let new_dfa = Rc::new(DFA::new(regex).unwrap());
self.cache.insert(String::from(regex), new_dfa.clone());
return DFAState::new(regex, new_dfa);
}
}
match self.cache.get(regex) {
Some(dfa) => return DFAState::new(regex, dfa.clone()),
None => {
let new_dfa = Rc::new(DFA::new(regex).unwrap());
self.cache.insert(String::from(regex), new_dfa.clone());
return DFAState::new(regex, new_dfa);
}
}
}
}
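
A minimal sketch, not part of this diff, of how the cache above is expected to behave: building the same pattern twice should compile it only once, leaving a single entry in the builder's cache. It is written in the style of the tests module at the bottom of this file, where the private `cache` field is still reachable.

#[test]
fn test_build_dfa_caches_compiled_regex() {
    // Hypothetical test, not in this commit: the second build_dfa call for the
    // same regex should reuse the cached DFA instead of compiling a new one.
    let mut builder = DFABuilder::new();
    let _first = builder.build_dfa(r"[0-9]+");
    let _second = builder.build_dfa(r"[0-9]+");
    assert_eq!(builder.cache.len(), 1);
}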

@@ -66,10 +66,10 @@ impl DFAState {

    /// Convenience function to advance the DFA over an entire input string, returning the resulting state.
pub fn advance(&mut self, input: &str) -> StateID {
for c in input.chars() {
self.consume_character(c);
}
self.state_id
for c in input.chars() {
self.consume_character(c);
}
self.state_id
}

/// Consume a character, starting at the current state, setting and
@@ -79,23 +79,23 @@ pub fn consume_character(&mut self, c: char) -> StateID {
/// number of bytes long, and the underlying DFA has bytes as its input
/// alphabet.
pub fn consume_character(&mut self, c: char) -> StateID {
let char_len = c.len_utf8();
        // Buffer to store character as bytes. UTF-8 characters are at most 4
// bytes long, so allocate a buffer big enough to store the whole
// character regardless of how long it turns out to be.
let mut buf = [0; 4];
c.encode_utf8(&mut buf);
for (i, &b) in buf.iter().enumerate() {
// The number of bytes per character is variable: we only need to
// feed the number of bytes that the character actually is into the
// DFA; any more would be incorrect. Break the loop once we've gone
// past the end of the character.
if i >= char_len {
break;
}
self.state_id = self.dfa.next_state(self.state_id, b);
}
self.state_id
let char_len = c.len_utf8();
        // Buffer to store character as bytes. UTF-8 characters are at most 4
// bytes long, so allocate a buffer big enough to store the whole
// character regardless of how long it turns out to be.
let mut buf = [0; 4];
c.encode_utf8(&mut buf);
for (i, &b) in buf.iter().enumerate() {
// The number of bytes per character is variable: we only need to
// feed the number of bytes that the character actually is into the
// DFA; any more would be incorrect. Break the loop once we've gone
// past the end of the character.
if i >= char_len {
break;
}
self.state_id = self.dfa.next_state(self.state_id, b);
}
self.state_id
}
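
The byte loop above is what makes multi-byte characters work against a byte-alphabet DFA. A sketch of a test, assumed rather than taken from this commit and written in the style of the existing tests module below, checking that a two-byte character such as '¥' still lands in a match state after both of its UTF-8 bytes are fed through the DFA:

#[test]
fn test_consume_character_multibyte() {
    // Hypothetical test, not in this commit: '¥' is two bytes in UTF-8, so
    // consume_character performs two byte transitions before returning.
    let mut dfa_state = DFABuilder::new().build_dfa("[ab¥]*");
    let mut state = dfa_state.consume_character('¥');
    state = dfa_state.dfa.next_eoi_state(state);
    assert!(dfa_state.dfa.is_match_state(state));
}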

/// Return all states of a dfa by breadth-first search. There exists a private
@@ -167,47 +167,46 @@ pub fn all_dfa_states(terminals: &Vec<&str>) -> Vec<DFAState> {
res
}
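
A short usage sketch for all_dfa_states as declared in the hunk header above; it is not part of this commit, and the only assumption is that the breadth-first search yields at least the start state of each terminal's DFA.

#[test]
fn test_all_dfa_states_nonempty() {
    // Hypothetical test, not in this commit: enumerating states for a small
    // set of terminal regexes should produce a non-empty collection.
    let terminals = vec![r"[0-9]+", r"\(", r"\)"];
    let states = all_dfa_states(&terminals);
    assert!(!states.is_empty());
}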


#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_consume_character_match() {
let mut dfa_state = DFABuilder::new().build_dfa("a");
let mut state = dfa_state.consume_character('a');
state = dfa_state.dfa.next_eoi_state(state);
assert!(dfa_state.dfa.is_match_state(state));
let mut dfa_state = DFABuilder::new().build_dfa("a");
let mut state = dfa_state.consume_character('a');
state = dfa_state.dfa.next_eoi_state(state);
assert!(dfa_state.dfa.is_match_state(state));
}

#[test]
fn test_consume_character_fails_to_match() {
let mut dfa_state = DFABuilder::new().build_dfa("a");
let mut state = dfa_state.consume_character('b');
state = dfa_state.dfa.next_eoi_state(state);
assert!(!dfa_state.dfa.is_match_state(state));
let mut dfa_state = DFABuilder::new().build_dfa("a");
let mut state = dfa_state.consume_character('b');
state = dfa_state.dfa.next_eoi_state(state);
assert!(!dfa_state.dfa.is_match_state(state));
}

#[test]
fn test_advance_match() {
let mut dfa_state = DFABuilder::new().build_dfa("[ab¥]*");
let mut state = dfa_state.advance("aabb¥aab");
state = dfa_state.dfa.next_eoi_state(state);
assert!(dfa_state.dfa.is_match_state(state));
let mut dfa_state = DFABuilder::new().build_dfa("[ab¥]*");
let mut state = dfa_state.advance("aabb¥aab");
state = dfa_state.dfa.next_eoi_state(state);
assert!(dfa_state.dfa.is_match_state(state));
}

#[test]
fn test_advance_fails_to_match() {
let mut dfa_state = DFABuilder::new().build_dfa("[ab]*");
let mut state = dfa_state.advance("aabba¥ab");
state = dfa_state.dfa.next_eoi_state(state);
assert!(!dfa_state.dfa.is_match_state(state));
let mut dfa_state = DFABuilder::new().build_dfa("[ab]*");
let mut state = dfa_state.advance("aabba¥ab");
state = dfa_state.dfa.next_eoi_state(state);
assert!(!dfa_state.dfa.is_match_state(state));
}

#[test]
fn test_advance() {
let mut dfa_state = DFABuilder::new().build_dfa(r"[a-zA-Z_]*");
let state = dfa_state.advance("indeed");
assert!(dfa_state.dfa.is_match_state(state));
let mut dfa_state = DFABuilder::new().build_dfa(r"[a-zA-Z_]*");
let state = dfa_state.advance("indeed");
assert!(dfa_state.dfa.is_match_state(state));
}
}
68 changes: 35 additions & 33 deletions src/lib.rs
@@ -27,9 +27,9 @@ impl Masker {
starting_state: &mut DFAState,
sequence_of_terminals: Vec<&str>,
) -> bool {
// println!("{} {}", string, starting_state.regex);
// println!("{} {}", string, starting_state.regex);

// We'll need this later.
// We'll need this later.
let initial_state = starting_state.state_id.clone();
let mut state: StateID;

@@ -45,37 +45,37 @@ impl Masker {
// grammars respect the maximum munch principle, so w1 is the maximal
// matching prefix.
starting_state.state_id = initial_state; // Reset to initial state.
let mut index_reached: usize = 0;
let mut index_reached: usize = 0;
for (i, c) in string.char_indices() {
state = starting_state.consume_character(c);
if starting_state.dfa.is_dead_state(state) | starting_state.dfa.is_quit_state(state) {
break;
}
if starting_state.dfa.is_dead_state(state) | starting_state.dfa.is_quit_state(state) {
break;
}

if starting_state.dfa.is_match_state(state) {
index_reached = i;
}
}

if starting_state.dfa.is_match_state(state) {
index_reached = i;
}
if index_reached > 0 && sequence_of_terminals.is_empty() {
return true;
}

if index_reached > 0 && sequence_of_terminals.is_empty() {
return true;
}

// Case 3: A prefix of the string is successfully consumed by the DFA, and
// dmatch is true starting at the next member of sequence_of_terminals.
starting_state.state_id = initial_state;
for (i, c) in string.char_indices() {
state = starting_state.consume_character(c);

if !starting_state.dfa.is_dead_state(state) {
// Keep munching as long as we're alive.
continue;
}
if !starting_state.dfa.is_dead_state(state) {
// Keep munching as long as we're alive.
continue;
}

if starting_state.dfa.is_dead_state(state) && i == 0 {
// We failed on the first character.
break;
}
if starting_state.dfa.is_dead_state(state) && i == 0 {
// We failed on the first character.
break;
}

// Handle case where we consume one character too many by slicing
// the string before the character we just saw, but only if we
@@ -84,7 +84,7 @@ impl Masker {
{
let mut new_dfa = self.dfa_builder.build_dfa(sequence_of_terminals[0]);
return self.dmatch(
&string.chars().skip(i-1).collect::<String>(),
&string.chars().skip(i - 1).collect::<String>(),
&mut new_dfa,
sequence_of_terminals[1..].to_vec(),
);
@@ -108,9 +108,9 @@ impl Masker {
) -> Vec<bool> {
let mut mask: Vec<bool> = Vec::new();
for token in vocabulary {
// Since the state is mutated by dmatch (potentially bad API design
// on my part), make a new one each time we try to match a token.
let mut starting_state = state.clone();
// Since the state is mutated by dmatch (potentially bad API design
// on my part), make a new one each time we try to match a token.
let mut starting_state = state.clone();
mask.push(self.dmatch(token, &mut starting_state, terminal_sequence.clone()));
}
mask
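
A sketch of what the loop above computes, written directly against dmatch because the enclosing mask-building function's name and signature are folded out of this hunk. The vocabulary, regex, and terminal sequence are illustrative assumptions, not values from the crate, and the sketch is not part of this commit.

#[test]
fn test_mask_has_one_entry_per_token() {
    // Hypothetical test, not in this commit: each vocabulary token is checked
    // against its own clone of the starting state, giving one mask entry apiece.
    let mut masker = Masker::new();
    let state = masker.dfa_builder.build_dfa(r"[a-zA-Z_]*");
    let vocabulary = vec!["foo", "(", ")bar"];
    let mut mask: Vec<bool> = Vec::new();
    for token in &vocabulary {
        let mut starting_state = state.clone();
        mask.push(masker.dmatch(token, &mut starting_state, vec![r"\(", r"\)"]));
    }
    assert_eq!(mask.len(), vocabulary.len());
}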
@@ -161,7 +161,9 @@ impl Masker {
}

fn new() -> Masker {
Masker{dfa_builder: DFABuilder::new()}
Masker {
dfa_builder: DFABuilder::new(),
}
}

// /// Implement algorithm 2 from the paper.
@@ -269,10 +271,10 @@ mod tests {

#[test]
fn test_dmatch_ugly_unicode_thing() {
// This is a nasty token from an actual LLM. They've played us for fools.
let mut masker = Masker::new();
let mut starting_state = masker.dfa_builder.build_dfa(r"(?i:0|[1-9]\d*)");
assert!(!masker.dmatch("ĠĠ", &mut starting_state, vec![]));
// This is a nasty token from an actual LLM. They've played us for fools.
let mut masker = Masker::new();
let mut starting_state = masker.dfa_builder.build_dfa(r"(?i:0|[1-9]\d*)");
assert!(!masker.dmatch("ĠĠ", &mut starting_state, vec![]));
}

#[test]
@@ -314,13 +316,13 @@ mod tests {

#[test]
fn test_dmatch_accepts_matching_input() {
let candidate_string = "indeed";
let accept_sequence = vec![r"\(", r"\)"];
let candidate_string = "indeed";
let accept_sequence = vec![r"\(", r"\)"];
let mut matcher = Masker {
dfa_builder: DFABuilder::new(),
};
let mut starting_state = matcher.dfa_builder.build_dfa(r"[a-zA-Z_]*");
assert!(matcher.dmatch(candidate_string, &mut starting_state, accept_sequence));
assert!(matcher.dmatch(candidate_string, &mut starting_state, accept_sequence));
}

#[test]
