diff --git a/src/lib.rs b/src/lib.rs index bae321c..4bbb0ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,9 +29,6 @@ pub use stack::Stack; pub mod tree_index; pub use tree_index::TreeIndex; -/// Re-exports the [`sdd`](https://crates.io/crates/sdd) crate for backward compatibility. -pub use sdd as ebr; - mod exit_guard; mod hash_table; mod wait_queue; @@ -39,7 +36,8 @@ mod wait_queue; #[cfg(not(feature = "equivalent"))] mod equivalent; -pub use equivalent::{Comparable, Equivalent}; +#[cfg(feature = "serde")] +mod serde; #[cfg(feature = "loom")] mod maybe_std { @@ -53,8 +51,32 @@ mod maybe_std { pub(crate) use std::thread::yield_now; } -#[cfg(feature = "serde")] -mod serde; +mod range_helper { + use crate::Comparable; + use std::ops::Bound::{Excluded, Included, Unbounded}; + use std::ops::RangeBounds; + + /// Emulates `RangeBounds::contains`. + pub(crate) fn contains>(range: &R, key: &K) -> bool + where + Q: Comparable + ?Sized, + { + (match range.start_bound() { + Included(start) => start.compare(key).is_le(), + Excluded(start) => start.compare(key).is_lt(), + Unbounded => true, + }) && (match range.end_bound() { + Included(end) => end.compare(key).is_ge(), + Excluded(end) => end.compare(key).is_gt(), + Unbounded => true, + }) + } +} + +/// Re-exports the [`sdd`](https://crates.io/crates/sdd) crate for backward compatibility. +pub use sdd as ebr; + +pub use equivalent::{Comparable, Equivalent}; #[cfg(test)] mod tests; diff --git a/src/tests/correctness.rs b/src/tests/correctness.rs index 68090e8..56bb8be 100644 --- a/src/tests/correctness.rs +++ b/src/tests/correctness.rs @@ -2158,11 +2158,10 @@ mod treeindex_test { } } - #[cfg_attr(miri, ignore)] #[test] fn mixed() { let range = if cfg!(miri) { 64 } else { 4096 }; - let num_threads = if cfg!(miri) { 4 } else { 16 }; + let num_threads = if cfg!(miri) { 2 } else { 16 }; let tree: Arc> = Arc::new(TreeIndex::new()); let barrier = Arc::new(Barrier::new(num_threads)); let mut thread_handles = Vec::with_capacity(num_threads); diff --git a/src/tree_index.rs b/src/tree_index.rs index 0f21d48..9816d73 100644 --- a/src/tree_index.rs +++ b/src/tree_index.rs @@ -508,7 +508,10 @@ where /// assert!(!treeindex.contains(&3)); /// ``` #[inline] - pub fn remove_range>(&self, range: R) { + pub fn remove_range>(&self, range: R) + where + Q: Comparable + ?Sized, + { let start_unbounded = matches!(range.start_bound(), Unbounded); let guard = Guard::new(); @@ -553,7 +556,10 @@ where /// let future_remove_range = treeindex.remove_range_async(3..8); /// ``` #[inline] - pub async fn remove_range_async>(&self, range: R) { + pub async fn remove_range_async>(&self, range: R) + where + Q: Comparable + ?Sized, + { let start_unbounded = matches!(range.start_bound(), Unbounded); loop { diff --git a/src/tree_index/internal_node.rs b/src/tree_index/internal_node.rs index 1b1fee2..5421350 100644 --- a/src/tree_index/internal_node.rs +++ b/src/tree_index/internal_node.rs @@ -471,7 +471,7 @@ where /// Returns the number of remaining children. #[allow(clippy::too_many_lines)] #[inline] - pub(super) fn remove_range<'g, R: RangeBounds, D: DeriveAsyncWait>( + pub(super) fn remove_range<'g, Q, R: RangeBounds, D: DeriveAsyncWait>( &self, range: &R, start_unbounded: bool, @@ -479,7 +479,10 @@ where valid_upper_min_node: Option<&'g Node>, async_wait: &mut D, guard: &'g Guard, - ) -> Result { + ) -> Result + where + Q: Comparable + ?Sized, + { debug_assert!(valid_lower_max_leaf.is_none() || start_unbounded); debug_assert!(valid_lower_max_leaf.is_none() || valid_upper_min_node.is_none()); diff --git a/src/tree_index/leaf.rs b/src/tree_index/leaf.rs index eacb448..c271e6d 100644 --- a/src/tree_index/leaf.rs +++ b/src/tree_index/leaf.rs @@ -1,7 +1,7 @@ use crate::ebr::{AtomicShared, Guard, Shared}; use crate::maybe_std::AtomicUsize; -use crate::Comparable; use crate::LinkedList; +use crate::{range_helper, Comparable}; use std::cell::UnsafeCell; use std::cmp::Ordering; use std::fmt::{self, Debug}; @@ -462,7 +462,10 @@ where /// /// Returns the number of remaining children. #[inline] - pub(super) fn remove_range>(&self, range: &R) { + pub(super) fn remove_range>(&self, range: &R) + where + Q: Comparable + ?Sized, + { let mut mutable_metadata = self.metadata.load(Acquire); for i in 0..DIMENSION.num_entries { if mutable_metadata == 0 { @@ -471,7 +474,7 @@ where let rank = mutable_metadata % (1_usize << DIMENSION.num_bits_per_entry); if rank != Dimension::uninit_rank() && rank != DIMENSION.removed_rank() { let k = self.key_at(i); - if range.contains(k) { + if range_helper::contains(range, k) { self.remove_if(k, &mut |_| true); } } diff --git a/src/tree_index/leaf_node.rs b/src/tree_index/leaf_node.rs index 9ddb4de..be3963c 100644 --- a/src/tree_index/leaf_node.rs +++ b/src/tree_index/leaf_node.rs @@ -5,8 +5,8 @@ use crate::ebr::{AtomicShared, Guard, Ptr, Shared, Tag}; use crate::exit_guard::ExitGuard; use crate::maybe_std::AtomicU8; use crate::wait_queue::{DeriveAsyncWait, WaitQueue}; -use crate::Comparable; use crate::LinkedList; +use crate::{range_helper, Comparable}; use std::borrow::Borrow; use std::cmp::Ordering::{Equal, Greater, Less}; use std::ops::{Bound, RangeBounds}; @@ -479,7 +479,7 @@ where /// /// Returns the number of remaining children. #[inline] - pub(super) fn remove_range<'g, R: RangeBounds, D: DeriveAsyncWait>( + pub(super) fn remove_range<'g, Q, R: RangeBounds, D: DeriveAsyncWait>( &self, range: &R, start_unbounded: bool, @@ -487,7 +487,10 @@ where valid_upper_min_node: Option<&'g Node>, async_wait: &mut D, guard: &'g Guard, - ) -> Result { + ) -> Result + where + Q: Comparable + ?Sized, + { debug_assert!(valid_lower_max_leaf.is_none() || start_unbounded); debug_assert!(valid_lower_max_leaf.is_none() || valid_upper_min_node.is_none()); @@ -1098,13 +1101,16 @@ impl<'n, K, V> Drop for Locker<'n, K, V> { impl RemoveRangeState { /// Returns the next state. - pub(super) fn next>( + pub(super) fn next>( self, key: &K, range: &R, start_unbounded: bool, - ) -> Self { - if range.contains(key) { + ) -> Self + where + Q: Comparable + ?Sized, + { + if range_helper::contains(range, key) { match self { RemoveRangeState::Below => { if start_unbounded { @@ -1121,14 +1127,13 @@ impl RemoveRangeState { } else { match self { RemoveRangeState::Below => match range.start_bound() { - Bound::Included(k) => match key.cmp(k) { - Less => RemoveRangeState::Below, - Equal => unreachable!(), - Greater => RemoveRangeState::MaybeAbove, + Bound::Included(k) => match k.compare(key) { + Less | Equal => RemoveRangeState::MaybeAbove, + Greater => RemoveRangeState::Below, }, - Bound::Excluded(k) => match key.cmp(k) { - Less | Equal => RemoveRangeState::Below, - Greater => RemoveRangeState::MaybeAbove, + Bound::Excluded(k) => match k.compare(key) { + Less => RemoveRangeState::MaybeAbove, + Greater | Equal => RemoveRangeState::Below, }, Bound::Unbounded => RemoveRangeState::MaybeAbove, }, diff --git a/src/tree_index/node.rs b/src/tree_index/node.rs index 8567ce6..718f695 100644 --- a/src/tree_index/node.rs +++ b/src/tree_index/node.rs @@ -158,7 +158,7 @@ where /// /// Returns the number of remaining children. #[inline] - pub(super) fn remove_range<'g, R: RangeBounds, D: DeriveAsyncWait>( + pub(super) fn remove_range<'g, Q, R: RangeBounds, D: DeriveAsyncWait>( &self, range: &R, start_unbounded: bool, @@ -166,7 +166,10 @@ where valid_upper_min_node: Option<&'g Node>, async_wait: &mut D, guard: &'g Guard, - ) -> Result { + ) -> Result + where + Q: Comparable + ?Sized, + { match &self { Self::Internal(internal_node) => internal_node.remove_range( range,