Skip to content

Commit

Permalink
Merge pull request #1129 from cuviper/extend
Browse files Browse the repository at this point in the history
Extend the ParallelExtend implementations
  • Loading branch information
cuviper authored Feb 9, 2024
2 parents a9f676b + 5eac9aa commit d530ebb
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 121 deletions.
218 changes: 109 additions & 109 deletions src/iter/extend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,83 @@ use super::noop::NoopConsumer;
use super::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer};
use super::{IntoParallelIterator, ParallelExtend, ParallelIterator};

use either::Either;
use std::borrow::Cow;
use std::collections::LinkedList;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::collections::{BinaryHeap, VecDeque};
use std::ffi::{OsStr, OsString};
use std::hash::{BuildHasher, Hash};

/// Performs a generic `par_extend` by collecting to a `LinkedList<Vec<_>>` in
/// parallel, then extending the collection sequentially.
macro_rules! extend {
($self:ident, $par_iter:ident, $extend:ident) => {
$extend($self, drive_list_vec($par_iter));
($self:ident, $par_iter:ident) => {
extend!($self <- fast_collect($par_iter))
};
($self:ident <- $vecs:expr) => {
match $vecs {
Either::Left(vec) => $self.extend(vec),
Either::Right(list) => {
for vec in list {
$self.extend(vec);
}
}
}
};
}
macro_rules! extend_reserved {
($self:ident, $par_iter:ident, $len:ident) => {
let vecs = fast_collect($par_iter);
$self.reserve($len(&vecs));
extend!($self <- vecs)
};
($self:ident, $par_iter:ident) => {
extend_reserved!($self, $par_iter, len)
};
}

/// Computes the total length of a `fast_collect` result.
fn len<T>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
match vecs {
Either::Left(vec) => vec.len(),
Either::Right(list) => list.iter().map(Vec::len).sum(),
}
}

/// Computes the total length of a `LinkedList<Vec<_>>`.
fn len<T>(list: &LinkedList<Vec<T>>) -> usize {
list.iter().map(Vec::len).sum()
/// Computes the total string length of a `fast_collect` result.
fn string_len<T: AsRef<str>>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
let strs = match vecs {
Either::Left(vec) => Either::Left(vec.iter()),
Either::Right(list) => Either::Right(list.iter().flatten()),
};
strs.map(AsRef::as_ref).map(str::len).sum()
}

pub(super) fn drive_list_vec<I, T>(pi: I) -> LinkedList<Vec<T>>
/// Computes the total OS-string length of a `fast_collect` result.
fn osstring_len<T: AsRef<OsStr>>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
let osstrs = match vecs {
Either::Left(vec) => Either::Left(vec.iter()),
Either::Right(list) => Either::Right(list.iter().flatten()),
};
osstrs.map(AsRef::as_ref).map(OsStr::len).sum()
}

pub(super) fn fast_collect<I, T>(pi: I) -> Either<Vec<T>, LinkedList<Vec<T>>>
where
I: IntoParallelIterator<Item = T>,
T: Send,
{
pi.into_par_iter().drive_unindexed(ListVecConsumer)
let par_iter = pi.into_par_iter();
match par_iter.opt_len() {
Some(len) => {
// Pseudo-specialization. See impl of ParallelExtend for Vec for more details.
let mut vec = Vec::new();
super::collect::special_extend(par_iter, len, &mut vec);
Either::Left(vec)
}
None => Either::Right(par_iter.drive_unindexed(ListVecConsumer)),
}
}

struct ListVecConsumer;
Expand Down Expand Up @@ -92,16 +144,6 @@ impl<T> Folder<T> for ListVecFolder<T> {
}
}

fn heap_extend<T, Item>(heap: &mut BinaryHeap<T>, list: LinkedList<Vec<Item>>)
where
BinaryHeap<T>: Extend<Item>,
{
heap.reserve(len(&list));
for vec in list {
heap.extend(vec);
}
}

/// Extends a binary heap with items from a parallel iterator.
impl<T> ParallelExtend<T> for BinaryHeap<T>
where
Expand All @@ -111,7 +153,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, heap_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -124,16 +166,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, heap_extend);
}
}

fn btree_map_extend<K, V, Item>(map: &mut BTreeMap<K, V>, list: LinkedList<Vec<Item>>)
where
BTreeMap<K, V>: Extend<Item>,
{
for vec in list {
map.extend(vec);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -147,7 +180,7 @@ where
where
I: IntoParallelIterator<Item = (K, V)>,
{
extend!(self, par_iter, btree_map_extend);
extend!(self, par_iter);
}
}

Expand All @@ -161,16 +194,7 @@ where
where
I: IntoParallelIterator<Item = (&'a K, &'a V)>,
{
extend!(self, par_iter, btree_map_extend);
}
}

fn btree_set_extend<T, Item>(set: &mut BTreeSet<T>, list: LinkedList<Vec<Item>>)
where
BTreeSet<T>: Extend<Item>,
{
for vec in list {
set.extend(vec);
extend!(self, par_iter);
}
}

Expand All @@ -183,7 +207,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, btree_set_extend);
extend!(self, par_iter);
}
}

Expand All @@ -196,19 +220,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, btree_set_extend);
}
}

fn hash_map_extend<K, V, S, Item>(map: &mut HashMap<K, V, S>, list: LinkedList<Vec<Item>>)
where
HashMap<K, V, S>: Extend<Item>,
K: Eq + Hash,
S: BuildHasher,
{
map.reserve(len(&list));
for vec in list {
map.extend(vec);
extend!(self, par_iter);
}
}

Expand All @@ -224,7 +236,7 @@ where
I: IntoParallelIterator<Item = (K, V)>,
{
// See the map_collect benchmarks in rayon-demo for different strategies.
extend!(self, par_iter, hash_map_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -239,19 +251,7 @@ where
where
I: IntoParallelIterator<Item = (&'a K, &'a V)>,
{
extend!(self, par_iter, hash_map_extend);
}
}

fn hash_set_extend<T, S, Item>(set: &mut HashSet<T, S>, list: LinkedList<Vec<Item>>)
where
HashSet<T, S>: Extend<Item>,
T: Eq + Hash,
S: BuildHasher,
{
set.reserve(len(&list));
for vec in list {
set.extend(vec);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -265,7 +265,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, hash_set_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -279,7 +279,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, hash_set_extend);
extend_reserved!(self, par_iter);
}
}

Expand Down Expand Up @@ -380,9 +380,34 @@ impl<T> Reducer<LinkedList<T>> for ListReducer {
}
}

fn flat_string_extend(string: &mut String, list: LinkedList<String>) {
string.reserve(list.iter().map(String::len).sum());
string.extend(list);
/// Extends an OS-string with string slices from a parallel iterator.
impl<'a> ParallelExtend<&'a OsStr> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = &'a OsStr>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends an OS-string with strings from a parallel iterator.
impl ParallelExtend<OsString> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = OsString>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends an OS-string with string slices from a parallel iterator.
impl<'a> ParallelExtend<Cow<'a, OsStr>> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = Cow<'a, OsStr>>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends a string with characters from a parallel iterator.
Expand All @@ -394,7 +419,8 @@ impl ParallelExtend<char> for String {
// This is like `extend`, but `Vec<char>` is less efficient to deal
// with than `String`, so instead collect to `LinkedList<String>`.
let list = par_iter.into_par_iter().drive_unindexed(ListStringConsumer);
flat_string_extend(self, list);
self.reserve(list.iter().map(String::len).sum());
self.extend(list);
}
}

Expand Down Expand Up @@ -473,25 +499,13 @@ impl Folder<char> for ListStringFolder {
}
}

fn string_extend<Item>(string: &mut String, list: LinkedList<Vec<Item>>)
where
String: Extend<Item>,
Item: AsRef<str>,
{
let len = list.iter().flatten().map(Item::as_ref).map(str::len).sum();
string.reserve(len);
for vec in list {
string.extend(vec);
}
}

/// Extends a string with string slices from a parallel iterator.
impl<'a> ParallelExtend<&'a str> for String {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = &'a str>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -501,7 +515,7 @@ impl ParallelExtend<String> for String {
where
I: IntoParallelIterator<Item = String>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -511,7 +525,7 @@ impl ParallelExtend<Box<str>> for String {
where
I: IntoParallelIterator<Item = Box<str>>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -521,17 +535,7 @@ impl<'a> ParallelExtend<Cow<'a, str>> for String {
where
I: IntoParallelIterator<Item = Cow<'a, str>>,
{
extend!(self, par_iter, string_extend);
}
}

fn deque_extend<T, Item>(deque: &mut VecDeque<T>, list: LinkedList<Vec<Item>>)
where
VecDeque<T>: Extend<Item>,
{
deque.reserve(len(&list));
for vec in list {
deque.extend(vec);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -544,7 +548,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, deque_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -557,14 +561,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, deque_extend);
}
}

fn vec_append<T>(vec: &mut Vec<T>, list: LinkedList<Vec<T>>) {
vec.reserve(len(&list));
for mut other in list {
vec.append(&mut other);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -589,7 +586,10 @@ where
None => {
// This works like `extend`, but `Vec::append` is more efficient.
let list = par_iter.drive_unindexed(ListVecConsumer);
vec_append(self, list);
self.reserve(list.iter().map(Vec::len).sum());
for mut other in list {
self.append(&mut other);
}
}
}
}
Expand Down
Loading

0 comments on commit d530ebb

Please sign in to comment.