Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #14

Merged
merged 6 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 67 additions & 66 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,6 @@ pub enum State {
/// The responder sent a response.
Responded = 4,

#[doc(hidden)]
CancelingRequested = 10,
#[doc(hidden)]
CancelingBuildingResponse = 11,
/// The requester canceled the request. Responder needs to acknowledge to return to `Idle`
/// state.
Canceled = 12,
}

Expand All @@ -171,11 +165,7 @@ impl From<u8> for State {
2 => State::Requested,
3 => State::BuildingResponse,
4 => State::Responded,

10 => State::CancelingRequested,
11 => State::CancelingBuildingResponse,
12 => State::Canceled,

_ => State::Idle,
}
}
Expand Down Expand Up @@ -401,6 +391,12 @@ impl<Rq, Rp> Channel<Rq, Rp> {
pub fn split(&self) -> Option<(Requester<'_, Rq, Rp>, Responder<'_, Rq, Rp>)> {
Some((self.requester()?, self.responder()?))
}

fn transition(&self, from: State, to: State) -> bool {
self.state
.compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
}
}

impl<Rq, Rp> Default for Channel<Rq, Rp> {
Expand All @@ -426,12 +422,8 @@ impl<'i, Rq, Rp> Drop for Requester<'i, Rq, Rp> {
}

impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
#[inline]
fn transition(&self, from: State, to: State) -> bool {
pub fn channel(&self) -> &'i Channel<Rq, Rp> {
self.channel
.state
.compare_exchange(from as u8, to as u8, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}

#[cfg(not(loom))]
Expand Down Expand Up @@ -505,23 +497,17 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
///
/// In other cases (`Idle` or `Reponsed`) there is nothing to cancel and we fail.
pub fn cancel(&mut self) -> Result<Option<Rq>, Error> {
// we canceled before the responder was even aware of the request.
if self.transition(State::Requested, State::CancelingRequested) {
self.channel
.state
.store(State::Idle as u8, Ordering::Release);
return Ok(Some(unsafe { self.with_data_mut(|i| i.take_rq()) }));
if self
.channel
.transition(State::BuildingResponse, State::Canceled)
{
// we canceled after the responder took the request, but before they answered.
return Ok(None);
}

// we canceled after the responder took the request, but before they answered.
if self.transition(State::BuildingResponse, State::CancelingRequested) {
// this may not yet be None in case the responder switched state to
// BuildingResponse but did not take out the request yet.
// assert!(self.data.is_none());
self.channel
.state
.store(State::Canceled as u8, Ordering::Release);
return Ok(None);
if self.channel.transition(State::Requested, State::Idle) {
// we canceled before the responder was even aware of the request.
return Ok(Some(unsafe { self.with_data_mut(|i| i.take_rq()) }));
}

Err(Error)
Expand All @@ -534,7 +520,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
// this is likely correct
#[cfg(not(loom))]
pub fn response(&self) -> Result<&Rp, Error> {
if self.transition(State::Responded, State::Responded) {
if self.channel.transition(State::Responded, State::Responded) {
Ok(unsafe { self.data().rp_ref() })
} else {
Err(Error)
Expand All @@ -545,7 +531,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
///
/// This may be called multiple times.
pub fn with_response<R>(&self, f: impl FnOnce(&Rp) -> R) -> Result<R, Error> {
if self.transition(State::Responded, State::Responded) {
if self.channel.transition(State::Responded, State::Responded) {
Ok(unsafe { self.with_data(|i| f(i.rp_ref())) })
} else {
Err(Error)
Expand All @@ -560,7 +546,7 @@ impl<'i, Rq, Rp> Requester<'i, Rq, Rp> {
// It is a logic error to call this method if we're Idle or Canceled, but
// it seems unnecessary to model this.
pub fn take_response(&mut self) -> Option<Rp> {
if self.transition(State::Responded, State::Idle) {
if self.channel.transition(State::Responded, State::Idle) {
Some(unsafe { self.with_data_mut(|i| i.take_rp()) })
} else {
None
Expand All @@ -576,8 +562,10 @@ where
///
/// This is usefull to build large structures in-place
pub fn with_request_mut<R>(&mut self, f: impl FnOnce(&mut Rq) -> R) -> Result<R, Error> {
if self.transition(State::Idle, State::BuildingRequest)
|| self.transition(State::BuildingRequest, State::BuildingRequest)
if self.channel.transition(State::Idle, State::BuildingRequest)
|| self
.channel
.transition(State::BuildingRequest, State::BuildingRequest)
{
let res = unsafe {
self.with_data_mut(|i| {
Expand All @@ -600,8 +588,10 @@ where
// this is likely correct
#[cfg(not(loom))]
pub fn request_mut(&mut self) -> Result<&mut Rq, Error> {
if self.transition(State::Idle, State::BuildingRequest)
|| self.transition(State::BuildingRequest, State::BuildingRequest)
if self.channel.transition(State::Idle, State::BuildingRequest)
|| self
.channel
.transition(State::BuildingRequest, State::BuildingRequest)
{
unsafe {
self.with_data_mut(|i| {
Expand All @@ -620,7 +610,9 @@ where
/// `with_request_mut`.
pub fn send_request(&mut self) -> Result<(), Error> {
if State::BuildingRequest == self.channel.state.load(Ordering::Acquire)
&& self.transition(State::BuildingRequest, State::Requested)
&& self
.channel
.transition(State::BuildingRequest, State::Requested)
{
Ok(())
} else {
Expand All @@ -647,12 +639,8 @@ impl<'i, Rq, Rp> Drop for Responder<'i, Rq, Rp> {
}

impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
#[inline]
fn transition(&self, from: State, to: State) -> bool {
pub fn channel(&self) -> &'i Channel<Rq, Rp> {
self.channel
.state
.compare_exchange(from as u8, to as u8, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}

#[cfg(not(loom))]
Expand Down Expand Up @@ -701,7 +689,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
/// This may be called only once as it move the state to BuildingResponse.
/// If you need copies, use `take_request`
pub fn with_request<R>(&self, f: impl FnOnce(&Rq) -> R) -> Result<R, Error> {
if self.transition(State::Requested, State::BuildingResponse) {
if self
.channel
.transition(State::Requested, State::BuildingResponse)
{
Ok(unsafe { self.with_data(|i| f(i.rq_ref())) })
} else {
Err(Error)
Expand All @@ -715,7 +706,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
// this is likely correct
#[cfg(not(loom))]
pub fn request(&self) -> Result<&Rq, Error> {
if self.transition(State::Requested, State::BuildingResponse) {
if self
.channel
.transition(State::Requested, State::BuildingResponse)
{
Ok(unsafe { self.data().rq_ref() })
} else {
Err(Error)
Expand All @@ -727,7 +721,10 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
/// This may be called only once as it move the state to BuildingResponse.
/// If you need copies, clone the request.
pub fn take_request(&mut self) -> Option<Rq> {
if self.transition(State::Requested, State::BuildingResponse) {
if self
.channel
.transition(State::Requested, State::BuildingResponse)
{
Some(unsafe { self.with_data_mut(|i| i.take_rq()) })
} else {
None
Expand All @@ -743,17 +740,7 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
//
// It is a logic error to call this method if there is no pending cancellation.
pub fn acknowledge_cancel(&self) -> Result<(), Error> {
if self
.channel
.state
.compare_exchange(
State::Canceled as u8,
State::Idle as u8,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
if self.channel.transition(State::Canceled, State::Idle) {
Ok(())
} else {
Err(Error)
Expand All @@ -770,10 +757,14 @@ impl<'i, Rq, Rp> Responder<'i, Rq, Rp> {
unsafe {
self.with_data_mut(|i| *i = Message::from_rp(response));
}
self.channel
.state
.store(State::Responded as u8, Ordering::Release);
Ok(())
if self
.channel
.transition(State::BuildingResponse, State::Responded)
{
Ok(())
} else {
Err(Error)
}
} else {
Err(Error)
}
Expand All @@ -788,8 +779,12 @@ where
///
/// This is usefull to build large structures in-place
pub fn with_response_mut<R>(&mut self, f: impl FnOnce(&mut Rp) -> R) -> Result<R, Error> {
if self.transition(State::Requested, State::BuildingResponse)
|| self.transition(State::BuildingResponse, State::BuildingResponse)
if self
.channel
.transition(State::Requested, State::BuildingResponse)
|| self
.channel
.transition(State::BuildingResponse, State::BuildingResponse)
{
let res = unsafe {
self.with_data_mut(|i| {
Expand All @@ -812,8 +807,12 @@ where
// this is likely correct
#[cfg(not(loom))]
pub fn response_mut(&mut self) -> Result<&mut Rp, Error> {
if self.transition(State::Requested, State::BuildingResponse)
|| self.transition(State::BuildingResponse, State::BuildingResponse)
if self
.channel
.transition(State::Requested, State::BuildingResponse)
|| self
.channel
.transition(State::BuildingResponse, State::BuildingResponse)
{
unsafe {
self.with_data_mut(|i| {
Expand All @@ -832,7 +831,9 @@ where
/// `with_response_mut`.
pub fn send_response(&mut self) -> Result<(), Error> {
if State::BuildingResponse == self.channel.state.load(Ordering::Acquire)
&& self.transition(State::BuildingResponse, State::Responded)
&& self
.channel
.transition(State::BuildingResponse, State::Responded)
{
Ok(())
} else {
Expand Down
57 changes: 44 additions & 13 deletions tests/loom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use interchange::{Channel, Requester, Responder};
use std::sync::atomic::Ordering::Acquire;
use std::sync::atomic::{AtomicBool, Ordering::Release};

static BRANCHES_USED: [AtomicBool; 6] = {
static BRANCHES_USED: [AtomicBool; 14] = {
#[allow(clippy::declare_interior_mutable_const)]
const ATOMIC_BOOL_INIT: AtomicBool = AtomicBool::new(false);
[ATOMIC_BOOL_INIT; 6]
[ATOMIC_BOOL_INIT; 14]
};

#[cfg(loom)]
Expand Down Expand Up @@ -51,42 +51,73 @@ fn test_function() {

fn requester_thread(mut requester: Requester<'static, u64, u64>) -> Option<()> {
requester.request(53).unwrap();
requester.with_response(|r| assert_eq!(*r, 63)).ok()?;
requester.with_response(|r| assert_eq!(*r, 63)).ok()?;
match requester.cancel() {
Ok(Some(53) | None) => {
BRANCHES_USED[0].store(true, Release);
return None;
}
Ok(_) => panic!("Invalid state"),
Err(_) => {
BRANCHES_USED[1].store(true, Release);
}
}
requester
.with_response(|r| {
BRANCHES_USED[2].store(true, Release);
assert_eq!(*r, 63)
})
.ok()
.or_else(|| {
BRANCHES_USED[3].store(true, Release);
None
})?;
requester.with_response(|r| assert_eq!(*r, 63)).unwrap();
requester.take_response().unwrap();
requester.with_request_mut(|r| *r = 51).unwrap();
requester.send_request().unwrap();
thread::yield_now();
match requester.cancel() {
Ok(Some(51) | None) => BRANCHES_USED[0].store(true, Release),
Ok(Some(51) | None) => BRANCHES_USED[4].store(true, Release),
Ok(_) => panic!("Invalid state"),
Err(_) => {
BRANCHES_USED[1].store(true, Release);
assert_eq!(requester.take_response().unwrap(), 79);
BRANCHES_USED[5].store(true, Release);
match requester.take_response() {
Some(i) => {
assert_eq!(i, 79);
BRANCHES_USED[6].store(true, Release);
}
None => BRANCHES_USED[7].store(true, Release),
}
}
}
BRANCHES_USED[4].store(true, Release);
BRANCHES_USED[8].store(true, Release);
None
}

fn responder_thread(mut responder: Responder<'static, u64, u64>) -> Option<()> {
let req = responder.take_request()?;
let req = responder.take_request().or_else(|| {
BRANCHES_USED[9].store(true, Release);
None
})?;
assert_eq!(req, 53);
responder.respond(req + 10).unwrap();
responder.respond(req + 10).ok().or_else(|| {
BRANCHES_USED[10].store(true, Release);
None
})?;
thread::yield_now();
responder
.with_request(|r| {
BRANCHES_USED[2].store(true, Release);
BRANCHES_USED[11].store(true, Release);
assert_eq!(*r, 51)
})
.map(|_| assert!(responder.with_request(|_| {}).is_err()))
.or_else(|_| {
BRANCHES_USED[3].store(true, Release);
BRANCHES_USED[12].store(true, Release);
responder.acknowledge_cancel()
})
.ok()?;
responder.with_response_mut(|r| *r = 79).ok();
responder.send_response().ok();
BRANCHES_USED[5].store(true, Release);
BRANCHES_USED[13].store(true, Release);
None
}
Loading