diff --git a/src/iter/extend.rs b/src/iter/extend.rs index 3e5a51efa..3d19d9d80 100644 --- a/src/iter/extend.rs +++ b/src/iter/extend.rs @@ -2,31 +2,83 @@ use super::noop::NoopConsumer; use super::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer}; use super::{IntoParallelIterator, ParallelExtend, ParallelIterator}; +use either::Either; use std::borrow::Cow; use std::collections::LinkedList; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::collections::{BinaryHeap, VecDeque}; +use std::ffi::{OsStr, OsString}; use std::hash::{BuildHasher, Hash}; /// Performs a generic `par_extend` by collecting to a `LinkedList>` in /// parallel, then extending the collection sequentially. macro_rules! extend { - ($self:ident, $par_iter:ident, $extend:ident) => { - $extend($self, drive_list_vec($par_iter)); + ($self:ident, $par_iter:ident) => { + extend!($self <- fast_collect($par_iter)) }; + ($self:ident <- $vecs:expr) => { + match $vecs { + Either::Left(vec) => $self.extend(vec), + Either::Right(list) => { + for vec in list { + $self.extend(vec); + } + } + } + }; +} +macro_rules! extend_reserved { + ($self:ident, $par_iter:ident, $len:ident) => { + let vecs = fast_collect($par_iter); + $self.reserve($len(&vecs)); + extend!($self <- vecs) + }; + ($self:ident, $par_iter:ident) => { + extend_reserved!($self, $par_iter, len) + }; +} + +/// Computes the total length of a `fast_collect` result. +fn len(vecs: &Either, LinkedList>>) -> usize { + match vecs { + Either::Left(vec) => vec.len(), + Either::Right(list) => list.iter().map(Vec::len).sum(), + } } -/// Computes the total length of a `LinkedList>`. -fn len(list: &LinkedList>) -> usize { - list.iter().map(Vec::len).sum() +/// Computes the total string length of a `fast_collect` result. +fn string_len>(vecs: &Either, LinkedList>>) -> usize { + let strs = match vecs { + Either::Left(vec) => Either::Left(vec.iter()), + Either::Right(list) => Either::Right(list.iter().flatten()), + }; + strs.map(AsRef::as_ref).map(str::len).sum() } -pub(super) fn drive_list_vec(pi: I) -> LinkedList> +/// Computes the total OS-string length of a `fast_collect` result. +fn osstring_len>(vecs: &Either, LinkedList>>) -> usize { + let osstrs = match vecs { + Either::Left(vec) => Either::Left(vec.iter()), + Either::Right(list) => Either::Right(list.iter().flatten()), + }; + osstrs.map(AsRef::as_ref).map(OsStr::len).sum() +} + +pub(super) fn fast_collect(pi: I) -> Either, LinkedList>> where I: IntoParallelIterator, T: Send, { - pi.into_par_iter().drive_unindexed(ListVecConsumer) + let par_iter = pi.into_par_iter(); + match par_iter.opt_len() { + Some(len) => { + // Pseudo-specialization. See impl of ParallelExtend for Vec for more details. + let mut vec = Vec::new(); + super::collect::special_extend(par_iter, len, &mut vec); + Either::Left(vec) + } + None => Either::Right(par_iter.drive_unindexed(ListVecConsumer)), + } } struct ListVecConsumer; @@ -92,16 +144,6 @@ impl Folder for ListVecFolder { } } -fn heap_extend(heap: &mut BinaryHeap, list: LinkedList>) -where - BinaryHeap: Extend, -{ - heap.reserve(len(&list)); - for vec in list { - heap.extend(vec); - } -} - /// Extends a binary heap with items from a parallel iterator. impl ParallelExtend for BinaryHeap where @@ -111,7 +153,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, heap_extend); + extend_reserved!(self, par_iter); } } @@ -124,16 +166,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, heap_extend); - } -} - -fn btree_map_extend(map: &mut BTreeMap, list: LinkedList>) -where - BTreeMap: Extend, -{ - for vec in list { - map.extend(vec); + extend_reserved!(self, par_iter); } } @@ -147,7 +180,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, btree_map_extend); + extend!(self, par_iter); } } @@ -161,16 +194,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, btree_map_extend); - } -} - -fn btree_set_extend(set: &mut BTreeSet, list: LinkedList>) -where - BTreeSet: Extend, -{ - for vec in list { - set.extend(vec); + extend!(self, par_iter); } } @@ -183,7 +207,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, btree_set_extend); + extend!(self, par_iter); } } @@ -196,19 +220,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, btree_set_extend); - } -} - -fn hash_map_extend(map: &mut HashMap, list: LinkedList>) -where - HashMap: Extend, - K: Eq + Hash, - S: BuildHasher, -{ - map.reserve(len(&list)); - for vec in list { - map.extend(vec); + extend!(self, par_iter); } } @@ -224,7 +236,7 @@ where I: IntoParallelIterator, { // See the map_collect benchmarks in rayon-demo for different strategies. - extend!(self, par_iter, hash_map_extend); + extend_reserved!(self, par_iter); } } @@ -239,19 +251,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, hash_map_extend); - } -} - -fn hash_set_extend(set: &mut HashSet, list: LinkedList>) -where - HashSet: Extend, - T: Eq + Hash, - S: BuildHasher, -{ - set.reserve(len(&list)); - for vec in list { - set.extend(vec); + extend_reserved!(self, par_iter); } } @@ -265,7 +265,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, hash_set_extend); + extend_reserved!(self, par_iter); } } @@ -279,7 +279,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, hash_set_extend); + extend_reserved!(self, par_iter); } } @@ -380,9 +380,34 @@ impl Reducer> for ListReducer { } } -fn flat_string_extend(string: &mut String, list: LinkedList) { - string.reserve(list.iter().map(String::len).sum()); - string.extend(list); +/// Extends an OS-string with string slices from a parallel iterator. +impl<'a> ParallelExtend<&'a OsStr> for OsString { + fn par_extend(&mut self, par_iter: I) + where + I: IntoParallelIterator, + { + extend_reserved!(self, par_iter, osstring_len); + } +} + +/// Extends an OS-string with strings from a parallel iterator. +impl ParallelExtend for OsString { + fn par_extend(&mut self, par_iter: I) + where + I: IntoParallelIterator, + { + extend_reserved!(self, par_iter, osstring_len); + } +} + +/// Extends an OS-string with string slices from a parallel iterator. +impl<'a> ParallelExtend> for OsString { + fn par_extend(&mut self, par_iter: I) + where + I: IntoParallelIterator>, + { + extend_reserved!(self, par_iter, osstring_len); + } } /// Extends a string with characters from a parallel iterator. @@ -394,7 +419,8 @@ impl ParallelExtend for String { // This is like `extend`, but `Vec` is less efficient to deal // with than `String`, so instead collect to `LinkedList`. let list = par_iter.into_par_iter().drive_unindexed(ListStringConsumer); - flat_string_extend(self, list); + self.reserve(list.iter().map(String::len).sum()); + self.extend(list); } } @@ -473,25 +499,13 @@ impl Folder for ListStringFolder { } } -fn string_extend(string: &mut String, list: LinkedList>) -where - String: Extend, - Item: AsRef, -{ - let len = list.iter().flatten().map(Item::as_ref).map(str::len).sum(); - string.reserve(len); - for vec in list { - string.extend(vec); - } -} - /// Extends a string with string slices from a parallel iterator. impl<'a> ParallelExtend<&'a str> for String { fn par_extend(&mut self, par_iter: I) where I: IntoParallelIterator, { - extend!(self, par_iter, string_extend); + extend_reserved!(self, par_iter, string_len); } } @@ -501,7 +515,7 @@ impl ParallelExtend for String { where I: IntoParallelIterator, { - extend!(self, par_iter, string_extend); + extend_reserved!(self, par_iter, string_len); } } @@ -511,7 +525,7 @@ impl ParallelExtend> for String { where I: IntoParallelIterator>, { - extend!(self, par_iter, string_extend); + extend_reserved!(self, par_iter, string_len); } } @@ -521,17 +535,7 @@ impl<'a> ParallelExtend> for String { where I: IntoParallelIterator>, { - extend!(self, par_iter, string_extend); - } -} - -fn deque_extend(deque: &mut VecDeque, list: LinkedList>) -where - VecDeque: Extend, -{ - deque.reserve(len(&list)); - for vec in list { - deque.extend(vec); + extend_reserved!(self, par_iter, string_len); } } @@ -544,7 +548,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, deque_extend); + extend_reserved!(self, par_iter); } } @@ -557,14 +561,7 @@ where where I: IntoParallelIterator, { - extend!(self, par_iter, deque_extend); - } -} - -fn vec_append(vec: &mut Vec, list: LinkedList>) { - vec.reserve(len(&list)); - for mut other in list { - vec.append(&mut other); + extend_reserved!(self, par_iter); } } @@ -589,7 +586,10 @@ where None => { // This works like `extend`, but `Vec::append` is more efficient. let list = par_iter.drive_unindexed(ListVecConsumer); - vec_append(self, list); + self.reserve(list.iter().map(Vec::len).sum()); + for mut other in list { + self.append(&mut other); + } } } } diff --git a/src/iter/from_par_iter.rs b/src/iter/from_par_iter.rs index 49afd6cb8..993899b7d 100644 --- a/src/iter/from_par_iter.rs +++ b/src/iter/from_par_iter.rs @@ -5,6 +5,7 @@ use std::borrow::Cow; use std::collections::LinkedList; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::collections::{BinaryHeap, VecDeque}; +use std::ffi::{OsStr, OsString}; use std::hash::{BuildHasher, Hash}; use std::rc::Rc; use std::sync::Arc; @@ -235,6 +236,36 @@ impl<'a> FromParallelIterator> for String { } } +/// Collects OS-string slices from a parallel iterator into an OS-string. +impl<'a> FromParallelIterator<&'a OsStr> for OsString { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator, + { + collect_extended(par_iter) + } +} + +/// Collects OS-strings from a parallel iterator into one large OS-string. +impl FromParallelIterator for OsString { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator, + { + collect_extended(par_iter) + } +} + +/// Collects OS-string slices from a parallel iterator into an OS-string. +impl<'a> FromParallelIterator> for OsString { + fn from_par_iter(par_iter: I) -> Self + where + I: IntoParallelIterator>, + { + collect_extended(par_iter) + } +} + /// Collects an arbitrary `Cow` collection. /// /// Note, the standard library only has `FromIterator` for `Cow<'a, str>` and diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 4b7289190..fb44b87f5 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -2376,15 +2376,15 @@ pub trait ParallelIterator: Sized + Send { /// assert_eq!(total_len, 2550); /// ``` fn collect_vec_list(self) -> LinkedList> { - match self.opt_len() { - Some(0) => LinkedList::new(), - Some(len) => { - // Pseudo-specialization. See impl of ParallelExtend for Vec for more details. - let mut v = Vec::new(); - collect::special_extend(self, len, &mut v); - LinkedList::from([v]) + match extend::fast_collect(self) { + Either::Left(vec) => { + let mut list = LinkedList::new(); + if !vec.is_empty() { + list.push_back(vec); + } + list } - None => extend::drive_list_vec(self), + Either::Right(list) => list, } } diff --git a/src/iter/test.rs b/src/iter/test.rs index 09bee1d86..3f9445025 100644 --- a/src/iter/test.rs +++ b/src/iter/test.rs @@ -11,6 +11,7 @@ use std::collections::LinkedList; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::collections::{BinaryHeap, VecDeque}; use std::f64; +use std::ffi::OsStr; use std::fmt::Debug; use std::sync::mpsc; use std::usize; @@ -1568,13 +1569,27 @@ fn par_iter_collect_cows() { assert_eq!(a, b); // Collects `str` into a `String` - let a: Cow<'_, str> = s.split_whitespace().collect(); - let b: Cow<'_, str> = s.par_split_whitespace().collect(); + let sw = s.split_whitespace(); + let psw = s.par_split_whitespace(); + let a: Cow<'_, str> = sw.clone().collect(); + let b: Cow<'_, str> = psw.clone().collect(); assert_eq!(a, b); // Collects `String` into a `String` - let a: Cow<'_, str> = s.split_whitespace().map(str::to_owned).collect(); - let b: Cow<'_, str> = s.par_split_whitespace().map(str::to_owned).collect(); + let a: Cow<'_, str> = sw.map(str::to_owned).collect(); + let b: Cow<'_, str> = psw.map(str::to_owned).collect(); + assert_eq!(a, b); + + // Collects `OsStr` into a `OsString` + let sw = s.split_whitespace().map(OsStr::new); + let psw = s.par_split_whitespace().map(OsStr::new); + let a: Cow<'_, OsStr> = Cow::Owned(sw.clone().collect()); + let b: Cow<'_, OsStr> = psw.clone().collect(); + assert_eq!(a, b); + + // Collects `OsString` into a `OsString` + let a: Cow<'_, OsStr> = Cow::Owned(sw.map(OsStr::to_owned).collect()); + let b: Cow<'_, OsStr> = psw.map(OsStr::to_owned).collect(); assert_eq!(a, b); }