Skip to content

Commit

Permalink
Define ChainedChannel{Sender, Receiver} wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
ryoqun committed Jan 9, 2024
1 parent 3abe938 commit 36632b4
Showing 1 changed file with 136 additions and 62 deletions.
198 changes: 136 additions & 62 deletions unified-scheduler-pool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

use {
assert_matches::assert_matches,
crossbeam_channel::{select, unbounded, Receiver, Sender},
crossbeam_channel::{select, unbounded, Receiver, SendError, Sender},
log::*,
solana_ledger::blockstore_processor::{
execute_batch, TransactionBatchWithIndexes, TransactionStatusSender,
Expand Down Expand Up @@ -247,36 +247,132 @@ type ExecutedTaskPayload = SubchanneledPayload<Box<ExecutedTask>, ()>;
// minimum at the cost of a single heap allocation per switching for the sake of Box-ing the Self
// type to avoid infinite mem::size_of() due to the recursive type structure. Needless to say, such
// an allocation can be amortized to be negligible.
enum ChainedChannel<P1, P2> {
Payload(P1),
PayloadAndChannel(Box<dyn WithChannelAndPayload<P1, P2>>),
}
mod chained_channel {
use super::*;

trait WithChannelAndPayload<P1, P2>: Send + Sync {
fn payload_and_channel(self: Box<Self>) -> PayloadAndChannelInner<P1, P2>;
}
pub(super) enum ChainedChannel<P, C> {
Payload(P),
ContextAndChannel(Box<dyn WithContextAndPayload<P, C>>),
}

type PayloadAndChannelInner<P1, P2> = (P2, Receiver<ChainedChannel<P1, P2>>);
pub(super) trait WithContextAndPayload<P, C>: Send + Sync {
fn context_and_channel(self: Box<Self>) -> ContextAndChannelInner<P, C>;
}

struct PayloadAndChannelWrapper<P1, P2>(PayloadAndChannelInner<P1, P2>);
type ContextAndChannelInner<P, C> = (C, Receiver<ChainedChannel<P, C>>);

impl<P1, P2> WithChannelAndPayload<P1, P2> for PayloadAndChannelWrapper<P1, P2>
where
P1: Send + Sync,
P2: Send + Sync,
{
fn payload_and_channel(self: Box<Self>) -> PayloadAndChannelInner<P1, P2> {
self.0
struct ContextAndChannelWrapper<P, C>(ContextAndChannelInner<P, C>);

impl<P, C> WithContextAndPayload<P, C> for ContextAndChannelWrapper<P, C>
where
P: Send + Sync,
C: Send + Sync,
{
fn context_and_channel(self: Box<Self>) -> ContextAndChannelInner<P, C> {
self.0
}
}
}

impl<P1, P2> ChainedChannel<P1, P2>
where
P1: Send + Sync + 'static,
P2: Send + Sync + 'static,
{
fn chain_to_new_channel(payload: P2, receiver: Receiver<Self>) -> Self {
Self::PayloadAndChannel(Box::new(PayloadAndChannelWrapper((payload, receiver))))
impl<P, C> ChainedChannel<P, C>
where
P: Send + Sync + 'static,
C: Send + Sync + 'static,
{
fn chain_to_new_channel(context: C, receiver: Receiver<Self>) -> Self {
Self::ContextAndChannel(Box::new(ContextAndChannelWrapper((context, receiver))))
}
}

pub(super) struct ChainedChannelSender<P, C> {
sender: Sender<ChainedChannel<P, C>>,
}

impl<P, C> ChainedChannelSender<P, C>
where
P: Send + Sync + 'static,
C: Send + Sync + 'static + Clone,
{
fn new(sender: Sender<ChainedChannel<P, C>>) -> Self {
Self { sender }
}

pub(super) fn send_payload(
&self,
payload: P,
) -> std::result::Result<(), SendError<ChainedChannel<P, C>>> {
self.sender.send(ChainedChannel::Payload(payload))
}

pub(super) fn send_chained_channel(
&mut self,
context: C,
count: usize,
) -> std::result::Result<(), SendError<ChainedChannel<P, C>>> {
let (chained_sender, chained_receiver) = crossbeam_channel::unbounded();
for _ in 0..count {
self.sender.send(ChainedChannel::chain_to_new_channel(
context.clone(),
chained_receiver.clone(),
))?
}
self.sender = chained_sender;
Ok(())
}
}

pub(super) struct ChainedChannelReceiver<P, C: Clone> {
receiver: Receiver<ChainedChannel<P, C>>,
context: C,
}

impl<P, C: Clone> Clone for ChainedChannelReceiver<P, C> {
fn clone(&self) -> Self {
Self {
receiver: self.receiver.clone(),
context: self.context.clone(),
}
}
}

impl<P, C: Clone> ChainedChannelReceiver<P, C> {
fn new(receiver: Receiver<ChainedChannel<P, C>>, initial_context: C) -> Self {
Self {
receiver,
context: initial_context,
}
}

pub(super) fn context(&self) -> &C {
&self.context
}

pub(super) fn for_select(&self) -> &Receiver<ChainedChannel<P, C>> {
&self.receiver
}

pub(super) fn after_select(&mut self, message: ChainedChannel<P, C>) -> Option<P> {
match message {
ChainedChannel::Payload(payload) => Some(payload),
ChainedChannel::ContextAndChannel(new_context_and_channel) => {
(self.context, self.receiver) = new_context_and_channel.context_and_channel();
None
}
}
}
}

pub(super) fn unbounded<P, C>(
initial_context: C,
) -> (ChainedChannelSender<P, C>, ChainedChannelReceiver<P, C>)
where
P: Send + Sync + 'static,
C: Send + Sync + 'static + Clone,
{
let (sender, receiver) = crossbeam_channel::unbounded();
(
ChainedChannelSender::new(sender),
ChainedChannelReceiver::new(receiver, initial_context),
)
}
}

Expand Down Expand Up @@ -369,23 +465,6 @@ impl<S: SpawnableScheduler<TH>, TH: TaskHandler> ThreadManager<S, TH> {
);
}

fn propagate_context_to_handler_threads(
runnable_task_sender: &mut Sender<ChainedChannel<Task, SchedulingContext>>,
context: SchedulingContext,
handler_count: usize,
) {
let (next_sessioned_task_sender, runnable_task_receiver) = unbounded();
for _ in 0..handler_count {
runnable_task_sender
.send(ChainedChannel::chain_to_new_channel(
context.clone(),
runnable_task_receiver.clone(),
))
.unwrap();
}
*runnable_task_sender = next_sessioned_task_sender;
}

fn take_session_result_with_timings(&mut self) -> ResultWithTimings {
self.session_result_with_timings.take().unwrap()
}
Expand All @@ -399,8 +478,8 @@ impl<S: SpawnableScheduler<TH>, TH: TaskHandler> ThreadManager<S, TH> {
}

fn start_threads(&mut self, context: &SchedulingContext) {
let (runnable_task_sender, runnable_task_receiver) =
unbounded::<ChainedChannel<Task, SchedulingContext>>();
let (mut runnable_task_sender, runnable_task_receiver) =
chained_channel::unbounded::<Task, SchedulingContext>(context.clone());
let (executed_task_sender, executed_task_receiver) = unbounded::<ExecutedTaskPayload>();
let (finished_task_sender, finished_task_receiver) = unbounded::<Box<ExecutedTask>>();
let (accumulated_result_sender, accumulated_result_receiver) =
Expand All @@ -422,7 +501,6 @@ impl<S: SpawnableScheduler<TH>, TH: TaskHandler> ThreadManager<S, TH> {
let handler_count = self.handler_count;
let session_result_sender = self.session_result_sender.clone();
let new_task_receiver = self.new_task_receiver.clone();
let mut runnable_task_sender = runnable_task_sender.clone();

let mut session_ending = false;
let mut active_task_count: usize = 0;
Expand Down Expand Up @@ -481,17 +559,16 @@ impl<S: SpawnableScheduler<TH>, TH: TaskHandler> ThreadManager<S, TH> {
// be resolved in the case of single-threaded FIFO like this.
active_task_count = active_task_count.checked_add(1).unwrap();
runnable_task_sender
.send(ChainedChannel::Payload(task))
.send_payload(task)
.unwrap();
}
NewTaskPayload::OpenSubchannel(context) => {
// signal about new SchedulingContext to both handler and
// accumulator threads
Self::propagate_context_to_handler_threads(
&mut runnable_task_sender,
runnable_task_sender.send_chained_channel(
context,
handler_count
);
).unwrap();
executed_task_sender
.send(ExecutedTaskPayload::OpenSubchannel(()))
.unwrap();
Expand Down Expand Up @@ -528,28 +605,25 @@ impl<S: SpawnableScheduler<TH>, TH: TaskHandler> ThreadManager<S, TH> {

let handler_main_loop = || {
let pool = self.pool.clone();
let mut bank = context.bank().clone();
let mut runnable_task_receiver = runnable_task_receiver.clone();
let finished_task_sender = finished_task_sender.clone();

move || loop {
let (task, sender) = select! {
recv(runnable_task_receiver) -> message => {
match message.unwrap() {
ChainedChannel::Payload(task) => {
(task, &finished_task_sender)
}
ChainedChannel::PayloadAndChannel(new_channel) => {
let new_context;
(new_context, runnable_task_receiver) = new_channel.payload_and_channel();
bank = new_context.bank().clone();
continue;
}
recv(runnable_task_receiver.for_select()) -> message => {
if let Some(task) = runnable_task_receiver.after_select(message.unwrap()) {
(task, &finished_task_sender)
} else {
continue;
}
},
};
let mut task = ExecutedTask::new_boxed(task);
Self::execute_task_with_handler(&bank, &mut task, &pool.handler_context);
Self::execute_task_with_handler(
runnable_task_receiver.context().bank(),
&mut task,
&pool.handler_context,
);
sender.send(task).unwrap();
}
};
Expand Down

0 comments on commit 36632b4

Please sign in to comment.