diff --git a/src/main.rs b/src/main.rs index c120139..97c888d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,7 @@ use params::*; use vec::Vector; use hashbrown::{hash_set::Entry, HashMap, HashSet}; +use rayon::prelude::*; use seq_macro::seq; use std::time::Instant; @@ -167,6 +168,23 @@ fn find_binary_expressions_left( }); } +fn find_binary_expressions_left_multithread_dfs( + cache: &Cache, + hashset_cache: &HashSetCache, + n: usize, + k: usize, + er: &Expr, +) { + seq!(op_len in 1..=5 { + if n <= k + op_len { + return; + }; + cache[n - k - op_len].par_iter().for_each(|el| { + find_binary_operators(&mut CacheLevel::new(), cache, hashset_cache, n, el, er, op_len); + }); + }); +} + fn find_binary_expressions_right( cn: &mut CacheLevel, cache: &Cache, @@ -314,8 +332,6 @@ fn find_expressions_multithread( mut_hashset_cache: &mut HashSetCache, n: usize, ) { - use rayon::prelude::*; - let cache = &mut_cache; let hashset_cache = &mut_hashset_cache; @@ -324,26 +340,24 @@ fn find_expressions_multithread( .flat_map(|k| { cache[k].par_iter().map(move |r| { let mut cn = CacheLevel::new(); - find_binary_expressions_left(&mut cn, cache, hashset_cache, n, k, r); + if k == 1 && n > MAX_CACHE_LENGTH && n + 1 < MAX_LENGTH && r.var_mask != 0 { + find_binary_expressions_left_multithread_dfs(cache, hashset_cache, n, k, r); + } else { + find_binary_expressions_left(&mut cn, cache, hashset_cache, n, k, r); + } cn }) }) - .chain( - std::iter::once_with(|| { - let mut cn = CacheLevel::new(); - find_parens_expressions(&mut cn, cache, hashset_cache, n); - cn - }) - .par_bridge(), - ) - .chain( - std::iter::once_with(|| { - let mut cn = CacheLevel::new(); - find_unary_expressions(&mut cn, cache, hashset_cache, n); - cn - }) - .par_bridge(), - ) + .chain(rayon::iter::once(()).map(|()| { + let mut cn = CacheLevel::new(); + find_parens_expressions(&mut cn, cache, hashset_cache, n); + cn + })) + .chain(rayon::iter::once(()).map(|()| { + let mut cn = CacheLevel::new(); + find_unary_expressions(&mut cn, cache, hashset_cache, n); + cn + })) .flatten_iter() .collect();