diff --git a/benches/json/Cargo.toml b/benches/json/Cargo.toml index c72e9d6..84b3803 100644 --- a/benches/json/Cargo.toml +++ b/benches/json/Cargo.toml @@ -17,7 +17,7 @@ lalrpop = "0.20.0" [dev-dependencies] criterion = { version = "0.4", features = ["html_reports"] } snmalloc-rs = { version = "0.3", features = ["build_cc"] } -pest = { version = "2.5.7", features = [ "std", "memchr" ] } +pest = { version = "2.5.7", features = ["std", "memchr"] } pest_derive = "2.5.7" lalrpop-util = { version = "0.20.0", features = ["lexer", "unicode"] } logos = "0.13.0" diff --git a/pag-lexer/src/lookahead.rs b/pag-lexer/src/lookahead.rs index 8f6c825..bc120d3 100644 --- a/pag-lexer/src/lookahead.rs +++ b/pag-lexer/src/lookahead.rs @@ -31,36 +31,27 @@ fn generate_lut_routine(index: usize) -> TokenStream { } } -fn byte_simd(byte: u8) -> TokenStream { - let byte = byte_char(byte); - quote! { data.simd_eq(u8x16::splat(#byte)) } -} - -fn range_simd(min: u8, max: u8) -> TokenStream { - let min = byte_char(min); - let max = byte_char(max); - quote! { data.simd_ge(u8x16::splat(#min)) & data.simd_le(u8x16::splat(#max)) } -} - +#[cfg(not(target_arch = "aarch64"))] fn generate_lookahead_routine(intervals: &Intervals, kind: Kind) -> TokenStream { - let count_act = match kind { - Kind::Positive => quote! { trailing_ones }, - Kind::Negative => quote! { trailing_zeros }, - }; - let idx_offset = intervals + let mask = intervals .iter() .map(|&Interval(l, r)| match l == r { - true => byte_simd(l), - false => range_simd(l, r), + true => { + let l = byte_char(l); + quote! { data.simd_eq(u8x16::splat(#l)) } + } + false => { + let l = byte_char(l); + let r = byte_char(r); + quote! { data.simd_ge(u8x16::splat(#l)) & data.simd_le(u8x16::splat(#r)) } + } }) .reduce(|acc, x| quote! { #acc | #x }) - .map(|x| { - if cfg!(target_arch = "aarch64") { - quote! { unsafe { core::mem::transmute::<_, u128>(#x).#count_act() / 8 } } - } else { - quote! { (#x).to_bitmask().#count_act() } - } - }); + .unwrap(); + let count_act = match kind { + Kind::Positive => quote! { trailing_ones }, + Kind::Negative => quote! { trailing_zeros }, + }; let tail_match = match kind { Kind::Positive => quote! { matches!(input.get(idx), Some(#intervals)) }, Kind::Negative => quote! { !matches!(input.get(idx), Some(#intervals) | None) }, @@ -70,7 +61,8 @@ fn generate_lookahead_routine(intervals: &Intervals, kind: Kind) -> TokenStream for chunk in input[idx..].chunks_exact(16) { use core::simd::*; let data = u8x16::from_slice(chunk); - let idx_offset = #idx_offset; + let mask = #mask; + let idx_offset = mask.to_bitmask().#count_act(); idx += idx_offset as usize; if idx_offset != 16 { break 'lookahead; @@ -83,10 +75,46 @@ fn generate_lookahead_routine(intervals: &Intervals, kind: Kind) -> TokenStream } } +#[cfg(target_arch = "aarch64")] +fn generate_lookahead_routine(intervals: &Intervals, kind: Kind) -> TokenStream { + let mask = intervals + .iter() + .map(|&Interval(l, r)| match l == r { + true => { + let l = byte_char(l); + quote! { data.simd_eq(u8x16::splat(#l)) } + } + false => { + let l = byte_char(l); + let r = byte_char(r); + quote! { data.simd_ge(u8x16::splat(#l)) & data.simd_le(u8x16::splat(#r)) } + } + }) + .reduce(|acc, x| quote! { #acc | #x }) + .unwrap(); + let count_act = match kind { + Kind::Positive => quote! { trailing_ones }, + Kind::Negative => quote! { trailing_zeros }, + }; + quote! { + for chunk in input[idx..].chunks_exact(16) { + use core::simd::*; + let data = u8x16::from_slice(chunk); + let mask = #mask; + let mask = unsafe { core::mem::transmute::<_, u128>(mask) }; + let idx_offset = mask.#count_act() / 8; + idx += idx_offset as usize; + if idx_offset != 16 { + break; + } + } + } +} + fn estimated_cost(intervals: &Intervals) -> u32 { intervals .iter() - .map(|Interval(l, r)| if l == r { 1 } else { 2 }) + .map(|Interval(l, r)| 1 + (l != r) as u32) .sum() } @@ -139,7 +167,7 @@ impl LoopOptimizer { } pub fn generate_lookahead(&mut self, dfa: &DfaTable, state: &DfaState) -> Option { - let limit = 8; + let limit = 4; let positives = direct_self_loops(dfa, state)?; let negatives = positives.complement()?; diff --git a/pag-lexer/src/vector.rs b/pag-lexer/src/vector.rs index 61488ae..e23ee50 100644 --- a/pag-lexer/src/vector.rs +++ b/pag-lexer/src/vector.rs @@ -159,6 +159,7 @@ impl Vector { return quote! { Some(#interval) => { cursor = idx + 1; #on_success }, }; } let target_id = dfa[target].state_id; + #[cfg(not(target_arch = "aarch64"))] if lookahead.is_some() && info.state_id == target_id { return quote! {}; }