From 748f6348360b6d69e2ecdfdbe378e087ebf3c444 Mon Sep 17 00:00:00 2001 From: Joe Caulfield Date: Wed, 9 Oct 2024 17:36:46 +0800 Subject: [PATCH 1/9] init separate prio scheduler crate --- Cargo.lock | 4 ++++ Cargo.toml | 2 ++ prio-graph-scheduler/Cargo.toml | 18 ++++++++++++++++++ prio-graph-scheduler/src/lib.rs | 1 + 4 files changed, 25 insertions(+) create mode 100644 prio-graph-scheduler/Cargo.toml create mode 100644 prio-graph-scheduler/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 62b8278cd7954b..00f39ba058bf60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7209,6 +7209,10 @@ dependencies = [ "thiserror", ] +[[package]] +name = "solana-prio-graph-scheduler" +version = "2.1.0" + [[package]] name = "solana-program" version = "2.1.0" diff --git a/Cargo.toml b/Cargo.toml index 8e5c541548c2d8..271b30eb263613 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ members = [ "poh", "poh-bench", "poseidon", + "prio-graph-scheduler", "program-runtime", "program-test", "programs/address-lookup-table", @@ -431,6 +432,7 @@ solana-package-metadata-macro = { path = "sdk/package-metadata-macro", version = solana-perf = { path = "perf", version = "=2.1.0" } solana-poh = { path = "poh", version = "=2.1.0" } solana-poseidon = { path = "poseidon", version = "=2.1.0" } +solana-prio-graph-scheduler = { path = "prio-graph-scheduler", version = "=2.1.0" } solana-program = { path = "sdk/program", version = "=2.1.0", default-features = false } solana-program-error = { path = "sdk/program-error", version = "=2.1.0" } solana-program-memory = { path = "sdk/program-memory", version = "=2.1.0" } diff --git a/prio-graph-scheduler/Cargo.toml b/prio-graph-scheduler/Cargo.toml new file mode 100644 index 00000000000000..09c7da9c326a23 --- /dev/null +++ b/prio-graph-scheduler/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "solana-prio-graph-scheduler" +description = "Solana Priority Graph Scheduler" +documentation = "https://docs.rs/solana-prio-graph-scheduler" +version.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +license.workspace = true +edition.workspace = true + +[dependencies] + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] + +[lints] +workspace = true diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs new file mode 100644 index 00000000000000..3554f7eb23845a --- /dev/null +++ b/prio-graph-scheduler/src/lib.rs @@ -0,0 +1 @@ +//! Solana Priority Graph Scheduler. 
From c7a7c7636b178f227c7c2adf45cfa2549272a49d Mon Sep 17 00:00:00 2001
From: lewis
Date: Fri, 11 Oct 2024 18:37:33 +0800
Subject: [PATCH 2/9] feat: extract new prio-graph-scheduler to reuse PrioGraphScheduler

---
 Cargo.lock                                    |    5 +
 core/src/banking_stage.rs                     |    2 +-
 prio-graph-scheduler/Cargo.toml               |    5 +
 prio-graph-scheduler/src/id_generator.rs      |   21 +
 prio-graph-scheduler/src/in_flight_tracker.rs |  124 +++
 prio-graph-scheduler/src/lib.rs               |    5 +
 .../src/scheduler_messages.rs                 |   78 ++
 .../src/thread_aware_account_locks.rs         |  742 ++++++++++++++++++
 prio-graph-scheduler/src/transaction_state.rs |  359 +++++++++
 9 files changed, 1340 insertions(+), 1 deletion(-)
 create mode 100644 prio-graph-scheduler/src/id_generator.rs
 create mode 100644 prio-graph-scheduler/src/in_flight_tracker.rs
 create mode 100644 prio-graph-scheduler/src/scheduler_messages.rs
 create mode 100644 prio-graph-scheduler/src/thread_aware_account_locks.rs
 create mode 100644 prio-graph-scheduler/src/transaction_state.rs

diff --git a/Cargo.lock b/Cargo.lock
index 00f39ba058bf60..c3e4bcbf4f0d07 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7212,6 +7212,11 @@
 [[package]]
 name = "solana-prio-graph-scheduler"
 version = "2.1.0"
+dependencies = [
+ "ahash 0.8.10",
+ "solana-core",
+ "solana-sdk",
+]
 
 [[package]]
 name = "solana-program"
diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs
index 32cc3fbe44dda1..78aec0b62b7e2d 100644
--- a/core/src/banking_stage.rs
+++ b/core/src/banking_stage.rs
@@ -66,7 +66,7 @@ mod consume_worker;
 mod decision_maker;
 mod forward_packet_batches_by_accounts;
 mod forward_worker;
-mod immutable_deserialized_packet;
+pub mod immutable_deserialized_packet;
 mod latest_unprocessed_votes;
 mod leader_slot_timing_metrics;
 mod multi_iterator_scanner;
diff --git a/prio-graph-scheduler/Cargo.toml b/prio-graph-scheduler/Cargo.toml
index 09c7da9c326a23..9bc4b097c82d19 100644
--- a/prio-graph-scheduler/Cargo.toml
+++ b/prio-graph-scheduler/Cargo.toml
@@ -10,6 +10,11 @@ license.workspace = true
 edition.workspace = true
 
 [dependencies]
+solana-core = { workspace = true }
+solana-sdk = { workspace = true }
+
+ahash = { workspace = true }
+
 
 [package.metadata.docs.rs]
 targets = ["x86_64-unknown-linux-gnu"]
diff --git a/prio-graph-scheduler/src/id_generator.rs b/prio-graph-scheduler/src/id_generator.rs
new file mode 100644
index 00000000000000..0d6e5ee2098174
--- /dev/null
+++ b/prio-graph-scheduler/src/id_generator.rs
@@ -0,0 +1,21 @@
+use crate::scheduler_messages::TransactionId;
+
+/// Simple reverse-sequential ID generator for `TransactionId`s.
+/// These IDs uniquely identify transactions during the scheduling process.
+pub struct IdGenerator {
+    next_id: u64,
+}
+
+impl Default for IdGenerator {
+    fn default() -> Self {
+        Self { next_id: u64::MAX }
+    }
+}
+
+impl IdGenerator {
+    pub fn next<T: From<u64>>(&mut self) -> T {
+        let id = self.next_id;
+        self.next_id = self.next_id.wrapping_sub(1);
+        T::from(id)
+    }
+}
diff --git a/prio-graph-scheduler/src/in_flight_tracker.rs b/prio-graph-scheduler/src/in_flight_tracker.rs
new file mode 100644
index 00000000000000..f23a7461cb5c3b
--- /dev/null
+++ b/prio-graph-scheduler/src/in_flight_tracker.rs
@@ -0,0 +1,124 @@
+use {
+    crate::id_generator::IdGenerator,
+    crate::thread_aware_account_locks::ThreadId,
+    crate::scheduler_messages::TransactionBatchId,
+    std::collections::HashMap,
+};
+
+/// Tracks the number of transactions that are in flight for each thread.
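+/// A minimal usage sketch, assuming a two-thread setup (values illustrative):
+/// ```ignore
+/// let mut tracker = InFlightTracker::new(2);
+/// // 3 transactions totalling 30k CUs go in flight on thread 0.
+/// let batch_id = tracker.track_batch(3, 30_000, 0);
+/// assert_eq!(tracker.num_in_flight_per_thread(), &[3, 0]);
+/// // Once the worker reports the batch done, the accounting is released.
+/// tracker.complete_batch(batch_id);
+/// ```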
+pub struct InFlightTracker {
+    num_in_flight_per_thread: Vec<usize>,
+    cus_in_flight_per_thread: Vec<u64>,
+    batches: HashMap<TransactionBatchId, BatchEntry>,
+    batch_id_generator: IdGenerator,
+}
+
+struct BatchEntry {
+    thread_id: ThreadId,
+    num_transactions: usize,
+    total_cus: u64,
+}
+
+impl InFlightTracker {
+    pub fn new(num_threads: usize) -> Self {
+        Self {
+            num_in_flight_per_thread: vec![0; num_threads],
+            cus_in_flight_per_thread: vec![0; num_threads],
+            batches: HashMap::new(),
+            batch_id_generator: IdGenerator::default(),
+        }
+    }
+
+    /// Returns the number of transactions that are in flight for each thread.
+    pub fn num_in_flight_per_thread(&self) -> &[usize] {
+        &self.num_in_flight_per_thread
+    }
+
+    /// Returns the number of CUs that are in flight for each thread.
+    pub fn cus_in_flight_per_thread(&self) -> &[u64] {
+        &self.cus_in_flight_per_thread
+    }
+
+    /// Tracks number of transactions and CUs in-flight for the `thread_id`.
+    /// Returns a `TransactionBatchId` that can be used to stop tracking the batch
+    /// when it is complete.
+    pub fn track_batch(
+        &mut self,
+        num_transactions: usize,
+        total_cus: u64,
+        thread_id: ThreadId,
+    ) -> TransactionBatchId {
+        let batch_id = self.batch_id_generator.next();
+        self.num_in_flight_per_thread[thread_id] += num_transactions;
+        self.cus_in_flight_per_thread[thread_id] += total_cus;
+        self.batches.insert(
+            batch_id,
+            BatchEntry {
+                thread_id,
+                num_transactions,
+                total_cus,
+            },
+        );
+
+        batch_id
+    }
+
+    /// Stop tracking the batch with given `batch_id`.
+    /// Removes the number of transactions for the scheduled thread.
+    /// Returns the thread id that the batch was scheduled on.
+    ///
+    /// # Panics
+    /// Panics if the batch id does not exist in the tracker.
+    pub fn complete_batch(&mut self, batch_id: TransactionBatchId) -> ThreadId {
+        let Some(BatchEntry {
+            thread_id,
+            num_transactions,
+            total_cus,
+        }) = self.batches.remove(&batch_id)
+        else {
+            panic!("batch id {batch_id} is not being tracked");
+        };
+        self.num_in_flight_per_thread[thread_id] -= num_transactions;
+        self.cus_in_flight_per_thread[thread_id] -= total_cus;
+
+        thread_id
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    #[should_panic(expected = "is not being tracked")]
+    fn test_in_flight_tracker_untracked_batch() {
+        let mut in_flight_tracker = InFlightTracker::new(2);
+        in_flight_tracker.complete_batch(TransactionBatchId::new(5));
+    }
+
+    #[test]
+    fn test_in_flight_tracker() {
+        let mut in_flight_tracker = InFlightTracker::new(2);
+
+        // Add a batch with 2 transactions, 10 kCUs to thread 0.
+        let batch_id_0 = in_flight_tracker.track_batch(2, 10_000, 0);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 0]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[10_000, 0]);
+
+        // Add a batch with 1 transaction, 15 kCUs to thread 1.
+        let batch_id_1 = in_flight_tracker.track_batch(1, 15_000, 1);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 1]);
+        assert_eq!(
+            in_flight_tracker.cus_in_flight_per_thread(),
+            &[10_000, 15_000]
+        );
+
+        in_flight_tracker.complete_batch(batch_id_0);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 1]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 15_000]);
+
+        in_flight_tracker.complete_batch(batch_id_1);
+        assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 0]);
+        assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 0]);
+    }
+}
diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs
index 3554f7eb23845a..927e9d77bbc084 100644
--- a/prio-graph-scheduler/src/lib.rs
+++ b/prio-graph-scheduler/src/lib.rs
@@ -1 +1,6 @@
 //! Solana Priority Graph Scheduler.
+pub mod transaction_state;
+pub mod scheduler_messages;
+pub mod id_generator;
+pub mod in_flight_tracker;
+pub mod thread_aware_account_locks;
\ No newline at end of file
diff --git a/prio-graph-scheduler/src/scheduler_messages.rs b/prio-graph-scheduler/src/scheduler_messages.rs
new file mode 100644
index 00000000000000..b5e11be6ba9d78
--- /dev/null
+++ b/prio-graph-scheduler/src/scheduler_messages.rs
@@ -0,0 +1,78 @@
+use {
+    solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket,
+    solana_sdk::{clock::Slot, transaction::SanitizedTransaction},
+    std::{fmt::Display, sync::Arc},
+};
+
+/// A unique identifier for a transaction batch.
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
+pub struct TransactionBatchId(u64);
+
+impl TransactionBatchId {
+    pub fn new(index: u64) -> Self {
+        Self(index)
+    }
+}
+
+impl Display for TransactionBatchId {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl From<u64> for TransactionBatchId {
+    fn from(id: u64) -> Self {
+        Self(id)
+    }
+}
+
+/// A unique identifier for a transaction.
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub struct TransactionId(u64);
+
+impl TransactionId {
+    pub fn new(index: u64) -> Self {
+        Self(index)
+    }
+}
+
+impl Display for TransactionId {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+impl From<u64> for TransactionId {
+    fn from(id: u64) -> Self {
+        Self(id)
+    }
+}
+
+/// Message: [Scheduler -> Worker]
+/// Transactions to be consumed (i.e. executed, recorded, and committed)
+pub struct ConsumeWork {
+    pub batch_id: TransactionBatchId,
+    pub ids: Vec<TransactionId>,
+    pub transactions: Vec<SanitizedTransaction>,
+    pub max_age_slots: Vec<Slot>,
+}
+
+/// Message: [Scheduler -> Worker]
+/// Transactions to be forwarded to the next leader(s)
+pub struct ForwardWork {
+    pub packets: Vec<Arc<ImmutableDeserializedPacket>>,
+}
+
+/// Message: [Worker -> Scheduler]
+/// Processed transactions.
+pub struct FinishedConsumeWork {
+    pub work: ConsumeWork,
+    pub retryable_indexes: Vec<usize>,
+}
+
+/// Message: [Worker -> Scheduler]
+/// Forwarded transactions.
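+/// Illustrative worker-side sketch; the channel and forwarding function are
+/// assumed names, not defined in this module:
+/// ```ignore
+/// let successful = try_forward(&work.packets);
+/// finished_forward_sender.send(FinishedForwardWork { work, successful })?;
+/// ```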
+pub struct FinishedForwardWork {
+    pub work: ForwardWork,
+    pub successful: bool,
+}
diff --git a/prio-graph-scheduler/src/thread_aware_account_locks.rs b/prio-graph-scheduler/src/thread_aware_account_locks.rs
new file mode 100644
index 00000000000000..d5de72547008c3
--- /dev/null
+++ b/prio-graph-scheduler/src/thread_aware_account_locks.rs
@@ -0,0 +1,742 @@
+use {
+    ahash::AHashMap,
+    solana_sdk::pubkey::Pubkey,
+    std::{
+        collections::hash_map::Entry,
+        fmt::{Debug, Display},
+        ops::{BitAnd, BitAndAssign, Sub},
+    },
+};
+
+pub(crate) const MAX_THREADS: usize = u64::BITS as usize;
+
+/// Identifier for a thread
+pub(crate) type ThreadId = usize; // 0..MAX_THREADS-1
+
+type LockCount = u32;
+
+/// A bit-set of threads an account is scheduled or can be scheduled for.
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub(crate) struct ThreadSet(u64);
+
+struct AccountWriteLocks {
+    thread_id: ThreadId,
+    lock_count: LockCount,
+}
+
+struct AccountReadLocks {
+    thread_set: ThreadSet,
+    lock_counts: [LockCount; MAX_THREADS],
+}
+
+/// Account locks.
+/// Write Locks - only one thread can hold a write lock at a time.
+/// Contains how many write locks are held by the thread.
+/// Read Locks - multiple threads can hold a read lock at a time.
+/// Contains thread-set for easily checking which threads are scheduled.
+#[derive(Default)]
+struct AccountLocks {
+    pub write_locks: Option<AccountWriteLocks>,
+    pub read_locks: Option<AccountReadLocks>,
+}
+
+/// Thread-aware account locks which allows for scheduling on threads
+/// that already hold locks on the account. This is useful for allowing
+/// queued transactions to be scheduled on a thread while the transaction
+/// is still being executed on the thread.
+pub(crate) struct ThreadAwareAccountLocks {
+    /// Number of threads.
+    num_threads: usize, // 0..MAX_THREADS
+    /// Locks for each account. An account should only have an entry if there
+    /// is at least one lock.
+    locks: AHashMap<Pubkey, AccountLocks>,
+}
+
+impl ThreadAwareAccountLocks {
+    /// Creates a new `ThreadAwareAccountLocks` with the given number of threads.
+    pub(crate) fn new(num_threads: usize) -> Self {
+        assert!(num_threads > 0, "num threads must be > 0");
+        assert!(
+            num_threads <= MAX_THREADS,
+            "num threads must be <= {MAX_THREADS}"
+        );
+
+        Self {
+            num_threads,
+            locks: AHashMap::new(),
+        }
+    }
+
+    /// Returns the `ThreadId` if the accounts are able to be locked
+    /// for the given thread, otherwise `None` is returned.
+    /// `allowed_threads` is a set of threads that the caller restricts locking to.
+    /// If accounts are schedulable, then they are locked for the thread
+    /// selected by the `thread_selector` function.
+    /// `thread_selector` is only called if all accounts are schedulable, meaning
+    /// that the `thread_set` passed to `thread_selector` is non-empty.
+    pub(crate) fn try_lock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
+        read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone,
+        allowed_threads: ThreadSet,
+        thread_selector: impl FnOnce(ThreadSet) -> ThreadId,
+    ) -> Option<ThreadId> {
+        let schedulable_threads = self.accounts_schedulable_threads(
+            write_account_locks.clone(),
+            read_account_locks.clone(),
+        )? & allowed_threads;
+        (!schedulable_threads.is_empty()).then(|| {
+            let thread_id = thread_selector(schedulable_threads);
+            self.lock_accounts(write_account_locks, read_account_locks, thread_id);
+            thread_id
+        })
+    }
+
+    /// Unlocks the accounts for the given thread.
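+    /// An unlock must mirror a prior successful lock; a sketch, assuming
+    /// `writes`/`reads` are the same account iterators used when locking:
+    /// ```ignore
+    /// if let Some(thread_id) = locks.try_lock_accounts(
+    ///     writes.iter(),
+    ///     reads.iter(),
+    ///     ThreadSet::any(num_threads),
+    ///     |threads| threads.contained_threads_iter().next().unwrap(),
+    /// ) {
+    ///     // ... after the batch finishes executing on `thread_id`:
+    ///     locks.unlock_accounts(writes.iter(), reads.iter(), thread_id);
+    /// }
+    /// ```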
+    pub(crate) fn unlock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+        thread_id: ThreadId,
+    ) {
+        for account in write_account_locks {
+            self.write_unlock_account(account, thread_id);
+        }
+
+        for account in read_account_locks {
+            self.read_unlock_account(account, thread_id);
+        }
+    }
+
+    /// Returns `ThreadSet` that the given accounts can be scheduled on.
+    fn accounts_schedulable_threads<'a>(
+        &self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+    ) -> Option<ThreadSet> {
+        let mut schedulable_threads = ThreadSet::any(self.num_threads);
+
+        for account in write_account_locks {
+            schedulable_threads &= self.write_schedulable_threads(account);
+            if schedulable_threads.is_empty() {
+                return None;
+            }
+        }
+
+        for account in read_account_locks {
+            schedulable_threads &= self.read_schedulable_threads(account);
+            if schedulable_threads.is_empty() {
+                return None;
+            }
+        }
+
+        Some(schedulable_threads)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads for the given readable account.
+    fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
+        self.schedulable_threads::<false>(account)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads for the given writable account.
+    fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet {
+        self.schedulable_threads::<true>(account)
+    }
+
+    /// Returns `ThreadSet` of schedulable threads.
+    /// If there are no locks, then all threads are schedulable.
+    /// If only write-locked, then only the thread holding the write lock is schedulable.
+    /// If a mix of locks, then only the write thread is schedulable.
+    /// If only read-locked, the only write-schedulable thread is the single thread
+    /// holding all read locks, if one exists; otherwise, no threads are write-schedulable.
+    /// If only read-locked, all threads are read-schedulable.
+    fn schedulable_threads<const WRITE: bool>(&self, account: &Pubkey) -> ThreadSet {
+        match self.locks.get(account) {
+            None => ThreadSet::any(self.num_threads),
+            Some(AccountLocks {
+                write_locks: None,
+                read_locks: Some(read_locks),
+            }) => {
+                if WRITE {
+                    read_locks
+                        .thread_set
+                        .only_one_contained()
+                        .map(ThreadSet::only)
+                        .unwrap_or_else(ThreadSet::none)
+                } else {
+                    ThreadSet::any(self.num_threads)
+                }
+            }
+            Some(AccountLocks {
+                write_locks: Some(write_locks),
+                read_locks: None,
+            }) => ThreadSet::only(write_locks.thread_id),
+            Some(AccountLocks {
+                write_locks: Some(write_locks),
+                read_locks: Some(read_locks),
+            }) => {
+                assert_eq!(
+                    read_locks.thread_set.only_one_contained(),
+                    Some(write_locks.thread_id)
+                );
+                read_locks.thread_set
+            }
+            Some(AccountLocks {
+                write_locks: None,
+                read_locks: None,
+            }) => unreachable!(),
+        }
+    }
+
+    /// Add locks for all writable and readable accounts on `thread_id`.
+    fn lock_accounts<'a>(
+        &mut self,
+        write_account_locks: impl Iterator<Item = &'a Pubkey>,
+        read_account_locks: impl Iterator<Item = &'a Pubkey>,
+        thread_id: ThreadId,
+    ) {
+        assert!(
+            thread_id < self.num_threads,
+            "thread_id must be < num_threads"
+        );
+        for account in write_account_locks {
+            self.write_lock_account(account, thread_id);
+        }
+
+        for account in read_account_locks {
+            self.read_lock_account(account, thread_id);
+        }
+    }
+
+    /// Locks the given `account` for writing on `thread_id`.
+    /// Panics if the account is already locked for writing on another thread.
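+    /// Write locks are counted per thread, so re-locking on the same thread is
+    /// valid; a sketch mirroring `test_write_locking` below:
+    /// ```ignore
+    /// locks.write_lock_account(&account, 1);
+    /// locks.write_lock_account(&account, 1); // lock_count == 2
+    /// locks.write_unlock_account(&account, 1);
+    /// locks.write_unlock_account(&account, 1); // entry fully removed
+    /// ```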
+ fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let entry = self.locks.entry(*account).or_default(); + + let AccountLocks { + write_locks, + read_locks, + } = entry; + + if let Some(read_locks) = read_locks { + assert_eq!( + read_locks.thread_set.only_one_contained(), + Some(thread_id), + "outstanding read lock must be on same thread" + ); + } + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + write_locks.lock_count += 1; + } else { + *write_locks = Some(AccountWriteLocks { + thread_id, + lock_count: 1, + }); + } + } + + /// Unlocks the given `account` for writing on `thread_id`. + /// Panics if the account is not locked for writing on `thread_id`. + fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let Entry::Occupied(mut entry) = self.locks.entry(*account) else { + panic!("write lock must exist for account: {account}"); + }; + + let AccountLocks { + write_locks: maybe_write_locks, + read_locks, + } = entry.get_mut(); + + let Some(write_locks) = maybe_write_locks else { + panic!("write lock must exist for account: {account}"); + }; + + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + + write_locks.lock_count -= 1; + if write_locks.lock_count == 0 { + *maybe_write_locks = None; + if read_locks.is_none() { + entry.remove(); + } + } + } + + /// Locks the given `account` for reading on `thread_id`. + /// Panics if the account is already locked for writing on another thread. + fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let AccountLocks { + write_locks, + read_locks, + } = self.locks.entry(*account).or_default(); + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + } + + match read_locks { + Some(read_locks) => { + read_locks.thread_set.insert(thread_id); + read_locks.lock_counts[thread_id] += 1; + } + None => { + let mut lock_counts = [0; MAX_THREADS]; + lock_counts[thread_id] = 1; + *read_locks = Some(AccountReadLocks { + thread_set: ThreadSet::only(thread_id), + lock_counts, + }); + } + } + } + + /// Unlocks the given `account` for reading on `thread_id`. + /// Panics if the account is not locked for reading on `thread_id`. 
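+    /// Read locks are likewise counted, per thread; a sketch:
+    /// ```ignore
+    /// locks.read_lock_account(&account, 2);
+    /// locks.read_lock_account(&account, 3); // threads 2 and 3 both hold reads
+    /// locks.read_unlock_account(&account, 2);
+    /// locks.read_unlock_account(&account, 3); // entry fully removed
+    /// ```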
+ fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { + let Entry::Occupied(mut entry) = self.locks.entry(*account) else { + panic!("read lock must exist for account: {account}"); + }; + + let AccountLocks { + write_locks, + read_locks: maybe_read_locks, + } = entry.get_mut(); + + let Some(read_locks) = maybe_read_locks else { + panic!("read lock must exist for account: {account}"); + }; + + assert!( + read_locks.thread_set.contains(thread_id), + "outstanding read lock must be on same thread" + ); + + read_locks.lock_counts[thread_id] -= 1; + if read_locks.lock_counts[thread_id] == 0 { + read_locks.thread_set.remove(thread_id); + if read_locks.thread_set.is_empty() { + *maybe_read_locks = None; + if write_locks.is_none() { + entry.remove(); + } + } + } + } +} + +impl BitAnd for ThreadSet { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } +} + +impl BitAndAssign for ThreadSet { + fn bitand_assign(&mut self, rhs: Self) { + self.0 &= rhs.0; + } +} + +impl Sub for ThreadSet { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self(self.0 & !rhs.0) + } +} + +impl Display for ThreadSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ThreadSet({:#0width$b})", self.0, width = MAX_THREADS) + } +} + +impl Debug for ThreadSet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } +} + +impl ThreadSet { + #[inline(always)] + pub(crate) const fn none() -> Self { + Self(0b0) + } + + #[inline(always)] + pub(crate) const fn any(num_threads: usize) -> Self { + if num_threads == MAX_THREADS { + Self(u64::MAX) + } else { + Self(Self::as_flag(num_threads) - 1) + } + } + + #[inline(always)] + pub(crate) const fn only(thread_id: ThreadId) -> Self { + Self(Self::as_flag(thread_id)) + } + + #[inline(always)] + pub(crate) fn num_threads(&self) -> u32 { + self.0.count_ones() + } + + #[inline(always)] + pub(crate) fn only_one_contained(&self) -> Option { + (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId) + } + + #[inline(always)] + pub(crate) fn is_empty(&self) -> bool { + self == &Self::none() + } + + #[inline(always)] + pub(crate) fn contains(&self, thread_id: ThreadId) -> bool { + self.0 & Self::as_flag(thread_id) != 0 + } + + #[inline(always)] + pub(crate) fn insert(&mut self, thread_id: ThreadId) { + self.0 |= Self::as_flag(thread_id); + } + + #[inline(always)] + pub(crate) fn remove(&mut self, thread_id: ThreadId) { + self.0 &= !Self::as_flag(thread_id); + } + + #[inline(always)] + pub(crate) fn contained_threads_iter(self) -> impl Iterator { + (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) + } + + #[inline(always)] + const fn as_flag(thread_id: ThreadId) -> u64 { + 0b1 << thread_id + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_NUM_THREADS: usize = 4; + const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS); + + // Simple thread selector to select the first schedulable thread + fn test_thread_selector(thread_set: ThreadSet) -> ThreadId { + thread_set.contained_threads_iter().next().unwrap() + } + + #[test] + #[should_panic(expected = "num threads must be > 0")] + fn test_too_few_num_threads() { + ThreadAwareAccountLocks::new(0); + } + + #[test] + #[should_panic(expected = "num threads must be <=")] + fn test_too_many_num_threads() { + ThreadAwareAccountLocks::new(MAX_THREADS + 1); + } + + #[test] + fn test_try_lock_accounts_none() { + let pk1 = 
Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 2); + locks.read_lock_account(&pk1, 3); + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + None + ); + } + + #[test] + fn test_try_lock_accounts_one() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk2, 3); + + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + Some(3) + ); + } + + #[test] + fn test_try_lock_accounts_multiple() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk2, 0); + locks.read_lock_account(&pk2, 0); + + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS - ThreadSet::only(0), // exclude 0 + test_thread_selector + ), + Some(1) + ); + } + + #[test] + fn test_try_lock_accounts_any() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + assert_eq!( + locks.try_lock_accounts( + [&pk1].into_iter(), + [&pk2].into_iter(), + TEST_ANY_THREADS, + test_thread_selector + ), + Some(0) + ); + } + + #[test] + fn test_accounts_schedulable_threads_no_outstanding_locks() { + let pk1 = Pubkey::new_unique(); + let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + assert_eq!( + locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()), + Some(TEST_ANY_THREADS) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()), + Some(TEST_ANY_THREADS) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_write_only() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.write_lock_account(&pk1, 2); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(ThreadSet::only(2)) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_read_only() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.read_lock_account(&pk1, 2); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(TEST_ANY_THREADS) + ); + + locks.read_lock_account(&pk1, 0); + assert_eq!( + locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + None + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(TEST_ANY_THREADS) + ); + } + + #[test] + fn test_accounts_schedulable_threads_outstanding_mixed() { + let pk1 = Pubkey::new_unique(); + let pk2 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + + locks.read_lock_account(&pk1, 2); + locks.write_lock_account(&pk1, 2); + assert_eq!( + 
locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), + Some(ThreadSet::only(2)) + ); + assert_eq!( + locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), + Some(ThreadSet::only(2)) + ); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_write_lock_account_write_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 0); + locks.write_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "outstanding read lock must be on same thread")] + fn test_write_lock_account_read_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 0); + locks.write_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "write lock must exist")] + fn test_write_unlock_account_not_locked() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_unlock_account(&pk1, 0); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_write_unlock_account_thread_mismatch() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 0); + } + + #[test] + #[should_panic(expected = "outstanding write lock must be on same thread")] + fn test_read_lock_account_write_conflict_panic() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 0); + locks.read_lock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "read lock must exist")] + fn test_read_unlock_account_not_locked() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_unlock_account(&pk1, 1); + } + + #[test] + #[should_panic(expected = "outstanding read lock must be on same thread")] + fn test_read_unlock_account_thread_mismatch() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 0); + locks.read_unlock_account(&pk1, 1); + } + + #[test] + fn test_write_locking() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.write_lock_account(&pk1, 1); + locks.write_lock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 1); + locks.write_unlock_account(&pk1, 1); + assert!(locks.locks.is_empty()); + } + + #[test] + fn test_read_locking() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.read_lock_account(&pk1, 1); + locks.read_lock_account(&pk1, 1); + locks.read_unlock_account(&pk1, 1); + locks.read_unlock_account(&pk1, 1); + assert!(locks.locks.is_empty()); + } + + #[test] + #[should_panic(expected = "thread_id must be < num_threads")] + fn test_lock_accounts_invalid_thread() { + let pk1 = Pubkey::new_unique(); + let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); + locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS); + } + + #[test] + fn test_thread_set() { + let mut thread_set = ThreadSet::none(); + assert!(thread_set.is_empty()); + assert_eq!(thread_set.num_threads(), 0); + assert_eq!(thread_set.only_one_contained(), None); + for idx in 0..MAX_THREADS { 
+            assert!(!thread_set.contains(idx));
+        }
+
+        thread_set.insert(4);
+        assert!(!thread_set.is_empty());
+        assert_eq!(thread_set.num_threads(), 1);
+        assert_eq!(thread_set.only_one_contained(), Some(4));
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 4);
+        }
+
+        thread_set.insert(2);
+        assert!(!thread_set.is_empty());
+        assert_eq!(thread_set.num_threads(), 2);
+        assert_eq!(thread_set.only_one_contained(), None);
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4);
+        }
+
+        thread_set.remove(4);
+        assert!(!thread_set.is_empty());
+        assert_eq!(thread_set.num_threads(), 1);
+        assert_eq!(thread_set.only_one_contained(), Some(2));
+        for idx in 0..MAX_THREADS {
+            assert_eq!(thread_set.contains(idx), idx == 2);
+        }
+    }
+
+    #[test]
+    fn test_thread_set_any_zero() {
+        let any_threads = ThreadSet::any(0);
+        assert_eq!(any_threads.num_threads(), 0);
+    }
+
+    #[test]
+    fn test_thread_set_any_max() {
+        let any_threads = ThreadSet::any(MAX_THREADS);
+        assert_eq!(any_threads.num_threads(), MAX_THREADS as u32);
+    }
+}
diff --git a/prio-graph-scheduler/src/transaction_state.rs b/prio-graph-scheduler/src/transaction_state.rs
new file mode 100644
index 00000000000000..9c9d783ab15369
--- /dev/null
+++ b/prio-graph-scheduler/src/transaction_state.rs
@@ -0,0 +1,359 @@
+use {
+    solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket,
+    solana_sdk::{clock::Slot, transaction::SanitizedTransaction},
+    std::sync::Arc,
+};
+
+/// Simple wrapper type to tie a sanitized transaction to a max age slot.
+pub struct SanitizedTransactionTTL {
+    pub transaction: SanitizedTransaction,
+    pub max_age_slot: Slot,
+}
+
+/// TransactionState is used to track the state of a transaction in the transaction scheduler
+/// and banking stage as a whole.
+///
+/// There are two states a transaction can be in:
+/// 1. `Unprocessed` - The transaction is available for scheduling.
+/// 2. `Pending` - The transaction is currently scheduled or being processed.
+///
+/// Newly received transactions are initially in the `Unprocessed` state.
+/// When a transaction is scheduled, it is transitioned to the `Pending` state,
+/// using the `transition_to_pending` method.
+/// When a transaction finishes processing it may be retryable. If it is retryable,
+/// the transaction is transitioned back to the `Unprocessed` state using the
+/// `transition_to_unprocessed` method. If it is not retryable, the state should
+/// be dropped.
+///
+/// For performance, when a transaction is transitioned to the `Pending` state, the
+/// internal `SanitizedTransaction` is moved out of the `TransactionState` and sent
+/// to the appropriate thread for processing. This is done to avoid cloning the
+/// `SanitizedTransaction`.
+#[allow(clippy::large_enum_variant)]
+pub enum TransactionState {
+    /// The transaction is available for scheduling.
+    Unprocessed {
+        transaction_ttl: SanitizedTransactionTTL,
+        packet: Arc<ImmutableDeserializedPacket>,
+        priority: u64,
+        cost: u64,
+        should_forward: bool,
+    },
+    /// The transaction is currently scheduled or being processed.
+    Pending {
+        packet: Arc<ImmutableDeserializedPacket>,
+        priority: u64,
+        cost: u64,
+        should_forward: bool,
+    },
+    /// Only used during transition.
+    Transitioning,
+}
+
+impl TransactionState {
+    /// Creates a new `TransactionState` in the `Unprocessed` state.
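+    /// Lifecycle sketch (arguments elided for brevity):
+    /// ```ignore
+    /// let mut state = TransactionState::new(transaction_ttl, packet, priority, cost);
+    /// let ttl = state.transition_to_pending();  // Unprocessed -> Pending
+    /// state.transition_to_unprocessed(ttl);     // Pending -> Unprocessed (retry path)
+    /// ```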
+    pub fn new(
+        transaction_ttl: SanitizedTransactionTTL,
+        packet: Arc<ImmutableDeserializedPacket>,
+        priority: u64,
+        cost: u64,
+    ) -> Self {
+        let should_forward = !packet.original_packet().meta().forwarded()
+            && packet.original_packet().meta().is_from_staked_node();
+        Self::Unprocessed {
+            transaction_ttl,
+            packet,
+            priority,
+            cost,
+            should_forward,
+        }
+    }
+
+    /// Return the priority of the transaction.
+    /// This is *not* the same as the `compute_unit_price` of the transaction.
+    /// The priority is used to order transactions for processing.
+    pub fn priority(&self) -> u64 {
+        match self {
+            Self::Unprocessed { priority, .. } => *priority,
+            Self::Pending { priority, .. } => *priority,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Return the cost of the transaction.
+    pub fn cost(&self) -> u64 {
+        match self {
+            Self::Unprocessed { cost, .. } => *cost,
+            Self::Pending { cost, .. } => *cost,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Return whether the packet should be attempted to be forwarded.
+    pub fn should_forward(&self) -> bool {
+        match self {
+            Self::Unprocessed {
+                should_forward: forwarded,
+                ..
+            } => *forwarded,
+            Self::Pending {
+                should_forward: forwarded,
+                ..
+            } => *forwarded,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Mark the packet as forwarded.
+    /// This is used to prevent the packet from being forwarded multiple times.
+    pub fn mark_forwarded(&mut self) {
+        match self {
+            Self::Unprocessed { should_forward, .. } => *should_forward = false,
+            Self::Pending { should_forward, .. } => *should_forward = false,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Return the packet of the transaction.
+    pub fn packet(&self) -> &Arc<ImmutableDeserializedPacket> {
+        match self {
+            Self::Unprocessed { packet, .. } => packet,
+            Self::Pending { packet, .. } => packet,
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Intended to be called when a transaction is scheduled. This method will
+    /// transition the transaction from `Unprocessed` to `Pending` and return the
+    /// `SanitizedTransactionTTL` for processing.
+    ///
+    /// # Panics
+    /// This method will panic if the transaction is already in the `Pending` state,
+    /// as this is an invalid state transition.
+    pub fn transition_to_pending(&mut self) -> SanitizedTransactionTTL {
+        match self.take() {
+            TransactionState::Unprocessed {
+                transaction_ttl,
+                packet,
+                priority,
+                cost,
+                should_forward: forwarded,
+            } => {
+                *self = TransactionState::Pending {
+                    packet,
+                    priority,
+                    cost,
+                    should_forward: forwarded,
+                };
+                transaction_ttl
+            }
+            TransactionState::Pending { .. } => {
+                panic!("transaction already pending");
+            }
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Intended to be called when a transaction is retried. This method will
+    /// transition the transaction from `Pending` to `Unprocessed`.
+    ///
+    /// # Panics
+    /// This method will panic if the transaction is already in the `Unprocessed`
+    /// state, as this is an invalid state transition.
+    pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) {
+        match self.take() {
+            TransactionState::Unprocessed { .. } => panic!("already unprocessed"),
+            TransactionState::Pending {
+                packet,
+                priority,
+                cost,
+                should_forward: forwarded,
+            } => {
+                *self = Self::Unprocessed {
+                    transaction_ttl,
+                    packet,
+                    priority,
+                    cost,
+                    should_forward: forwarded,
+                }
+            }
+            Self::Transitioning => unreachable!(),
+        }
+    }
+
+    /// Get a reference to the `SanitizedTransactionTTL` for the transaction.
+ /// + /// # Panics + /// This method will panic if the transaction is in the `Pending` state. + pub fn transaction_ttl(&self) -> &SanitizedTransactionTTL { + match self { + Self::Unprocessed { + transaction_ttl, .. + } => transaction_ttl, + Self::Pending { .. } => panic!("transaction is pending"), + Self::Transitioning => unreachable!(), + } + } + + /// Internal helper to transitioning between states. + /// Replaces `self` with a dummy state that will immediately be overwritten in transition. + fn take(&mut self) -> Self { + core::mem::replace(self, Self::Transitioning) + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + solana_sdk::{ + compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, + signature::Keypair, signer::Signer, system_instruction, transaction::Transaction, + }, + }; + + fn create_transaction_state(compute_unit_price: u64) -> TransactionState { + let from_keypair = Keypair::new(); + let ixs = vec![ + system_instruction::transfer( + &from_keypair.pubkey(), + &solana_sdk::pubkey::new_rand(), + 1, + ), + ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price), + ]; + let message = Message::new(&ixs, Some(&from_keypair.pubkey())); + let tx = Transaction::new(&[&from_keypair], message, Hash::default()); + + let packet = Arc::new( + ImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), + ); + let transaction_ttl = SanitizedTransactionTTL { + transaction: SanitizedTransaction::from_transaction_for_tests(tx), + max_age_slot: Slot::MAX, + }; + const TEST_TRANSACTION_COST: u64 = 5000; + TransactionState::new( + transaction_ttl, + packet, + compute_unit_price, + TEST_TRANSACTION_COST, + ) + } + + #[test] + #[should_panic(expected = "already pending")] + fn test_transition_to_pending_panic() { + let mut transaction_state = create_transaction_state(0); + transaction_state.transition_to_pending(); + transaction_state.transition_to_pending(); // invalid transition + } + + #[test] + fn test_transition_to_pending() { + let mut transaction_state = create_transaction_state(0); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + let _ = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + } + + #[test] + #[should_panic(expected = "already unprocessed")] + fn test_transition_to_unprocessed_panic() { + let mut transaction_state = create_transaction_state(0); + + // Manually clone `SanitizedTransactionTTL` + let SanitizedTransactionTTL { + transaction, + max_age_slot, + } = transaction_state.transaction_ttl(); + let transaction_ttl = SanitizedTransactionTTL { + transaction: transaction.clone(), + max_age_slot: *max_age_slot, + }; + transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition + } + + #[test] + fn test_transition_to_unprocessed() { + let mut transaction_state = create_transaction_state(0); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. 
} + )); + } + + #[test] + fn test_priority() { + let priority = 15; + let mut transaction_state = create_transaction_state(priority); + assert_eq!(transaction_state.priority(), priority); + + // ensure compute unit price is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert_eq!(transaction_state.priority(), priority); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert_eq!(transaction_state.priority(), priority); + } + + #[test] + #[should_panic(expected = "transaction is pending")] + fn test_transaction_ttl_panic() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + + let _ = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available + } + + #[test] + fn test_transaction_ttl() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + + // ensure transaction_ttl is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + + transaction_state.transition_to_unprocessed(transaction_ttl); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. 
} + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + } +} From 064d9a15fe14f060bd7b518a2fb4102c2a1beabf Mon Sep 17 00:00:00 2001 From: lewis Date: Fri, 11 Oct 2024 20:40:23 +0800 Subject: [PATCH 3/9] feat: extract new prio-graph-scheduler to reuse PrioGraphScheduler --- Cargo.lock | 15 + core/src/banking_stage.rs | 4 +- prio-graph-scheduler/Cargo.toml | 16 + prio-graph-scheduler/src/lib.rs | 15 +- .../src/prio_graph_scheduler.rs | 907 +++++++++++++ .../src/scheduler_controller.rs | 1161 +++++++++++++++++ prio-graph-scheduler/src/scheduler_error.rs | 9 + prio-graph-scheduler/src/scheduler_metrics.rs | 408 ++++++ .../src/transaction_priority_id.rs | 69 + .../src/transaction_state_container.rs | 259 ++++ 10 files changed, 2860 insertions(+), 3 deletions(-) create mode 100644 prio-graph-scheduler/src/prio_graph_scheduler.rs create mode 100644 prio-graph-scheduler/src/scheduler_controller.rs create mode 100644 prio-graph-scheduler/src/scheduler_error.rs create mode 100644 prio-graph-scheduler/src/scheduler_metrics.rs create mode 100644 prio-graph-scheduler/src/transaction_priority_id.rs create mode 100644 prio-graph-scheduler/src/transaction_state_container.rs diff --git a/Cargo.lock b/Cargo.lock index c3e4bcbf4f0d07..42352a5f7ea4c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7214,8 +7214,23 @@ name = "solana-prio-graph-scheduler" version = "2.1.0" dependencies = [ "ahash 0.8.10", + "arrayvec", + "assert_matches", + "crossbeam-channel", + "itertools 0.12.1", + "log", + "min-max-heap", + "prio-graph", "solana-core", + "solana-cost-model", + "solana-gossip", + "solana-ledger", + "solana-measure", + "solana-metrics", + "solana-poh", + "solana-runtime", "solana-sdk", + "thiserror", ] [[package]] diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index 78aec0b62b7e2d..41e7d95595dd84 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -62,7 +62,7 @@ pub mod qos_service; pub mod unprocessed_packet_batches; pub mod unprocessed_transaction_storage; -mod consume_worker; +pub mod consume_worker; mod decision_maker; mod forward_packet_batches_by_accounts; mod forward_worker; @@ -73,7 +73,7 @@ mod multi_iterator_scanner; mod packet_deserializer; mod packet_filter; mod packet_receiver; -mod read_write_account_set; +pub mod read_write_account_set; #[allow(dead_code)] mod scheduler_messages; mod transaction_scheduler; diff --git a/prio-graph-scheduler/Cargo.toml b/prio-graph-scheduler/Cargo.toml index 9bc4b097c82d19..47706ad7f40b1b 100644 --- a/prio-graph-scheduler/Cargo.toml +++ b/prio-graph-scheduler/Cargo.toml @@ -12,9 +12,25 @@ edition.workspace = true [dependencies] solana-core = { workspace = true } solana-sdk = { workspace = true } +solana-poh = { workspace = true } +solana-metrics = { workspace = true } +solana-ledger = { workspace = true } +solana-runtime = { workspace = true } +solana-gossip = { workspace = true } +solana-cost-model = { workspace = true } +solana-measure = { workspace = true } ahash = { workspace = true } +prio-graph = { workspace = true } +thiserror = { workspace = true } +itertools = { workspace = true } +log = { workspace = true } +crossbeam-channel = { workspace = true } +arrayvec = { workspace = true } +min-max-heap = { workspace = true } +[dev-dependencies] +assert_matches = { workspace = true } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs index 927e9d77bbc084..0fa6eac41017b1 100644 --- a/prio-graph-scheduler/src/lib.rs +++ 
b/prio-graph-scheduler/src/lib.rs
@@ -3,4 +3,17 @@ pub mod transaction_state;
 pub mod scheduler_messages;
 pub mod id_generator;
 pub mod in_flight_tracker;
-pub mod thread_aware_account_locks;
\ No newline at end of file
+pub mod thread_aware_account_locks;
+pub mod transaction_priority_id;
+pub mod scheduler_error;
+pub mod scheduler_metrics;
+// pub mod scheduler_controller;
+pub mod transaction_state_container;
+pub mod prio_graph_scheduler;
+
+#[macro_use]
+extern crate solana_metrics;
+
+#[cfg(test)]
+#[macro_use]
+extern crate assert_matches;
\ No newline at end of file
diff --git a/prio-graph-scheduler/src/prio_graph_scheduler.rs b/prio-graph-scheduler/src/prio_graph_scheduler.rs
new file mode 100644
index 00000000000000..a2f339d3baeaac
--- /dev/null
+++ b/prio-graph-scheduler/src/prio_graph_scheduler.rs
@@ -0,0 +1,907 @@
+use {
+    crate::scheduler_messages::{
+        ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId,
+    },
+    crate::transaction_priority_id::TransactionPriorityId,
+    crate::transaction_state::TransactionState,
+    crate::{
+        in_flight_tracker::InFlightTracker,
+        scheduler_error::SchedulerError,
+        thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet},
+        transaction_state::SanitizedTransactionTTL,
+        transaction_state_container::TransactionStateContainer,
+    },
+    crossbeam_channel::{Receiver, Sender, TryRecvError},
+    itertools::izip,
+    prio_graph::{AccessKind, PrioGraph},
+    solana_core::banking_stage::{
+        consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, read_write_account_set::ReadWriteAccountSet,
+    },
+    solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS,
+    solana_measure::measure_us,
+    solana_sdk::{
+        pubkey::Pubkey, saturating_add_assign, slot_history::Slot,
+        transaction::SanitizedTransaction,
+    },
+};
+
+pub struct PrioGraphScheduler {
+    in_flight_tracker: InFlightTracker,
+    account_locks: ThreadAwareAccountLocks,
+    consume_work_senders: Vec<Sender<ConsumeWork>>,
+    finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
+    look_ahead_window_size: usize,
+}
+
+impl PrioGraphScheduler {
+    pub fn new(
+        consume_work_senders: Vec<Sender<ConsumeWork>>,
+        finished_consume_work_receiver: Receiver<FinishedConsumeWork>,
+    ) -> Self {
+        let num_threads = consume_work_senders.len();
+        Self {
+            in_flight_tracker: InFlightTracker::new(num_threads),
+            account_locks: ThreadAwareAccountLocks::new(num_threads),
+            consume_work_senders,
+            finished_consume_work_receiver,
+            look_ahead_window_size: 2048,
+        }
+    }
+
+    /// Schedule transactions from the given `TransactionStateContainer` to be
+    /// consumed by the worker threads. Returns a summary of scheduling, or an
+    /// error.
+    /// `pre_graph_filter` is used to filter out transactions that should be
+    /// skipped and dropped before insertion to the prio-graph. This fn should
+    /// set `false` for transactions that should be dropped, and `true`
+    /// otherwise.
+    /// `pre_lock_filter` is used to filter out transactions after they have
+    /// made it to the top of the prio-graph, and immediately before locks are
+    /// checked and taken. This fn should return `true` for transactions that
+    /// should be scheduled, and `false` otherwise.
+    ///
+    /// Uses a `PrioGraph` to perform look-ahead during the scheduling of transactions.
+    /// This, combined with internal tracking of threads' in-flight transactions, allows
+    /// for load-balancing while prioritizing scheduling transactions onto threads that will
+    /// not cause conflicts in the near future.
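+    /// A sketch of plausible filter arguments; these pass-through closures are
+    /// illustrative only (real filters are supplied by the caller, e.g. the
+    /// banking stage):
+    /// ```ignore
+    /// let summary = scheduler.schedule(
+    ///     &mut container,
+    ///     |_txs, results| results.iter_mut().for_each(|ok| *ok = true), // keep all
+    ///     |_tx| true, // no pre-lock filtering
+    /// )?;
+    /// ```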
+ pub fn schedule( + &mut self, + container: &mut TransactionStateContainer, + pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]), + pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, + ) -> Result { + let num_threads = self.consume_work_senders.len(); + let max_cu_per_thread = MAX_BLOCK_UNITS / num_threads as u64; + + let mut schedulable_threads = ThreadSet::any(num_threads); + for thread_id in 0..num_threads { + if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] >= max_cu_per_thread { + schedulable_threads.remove(thread_id); + } + } + if schedulable_threads.is_empty() { + return Ok(SchedulingSummary { + num_scheduled: 0, + num_unschedulable: 0, + num_filtered_out: 0, + filter_time_us: 0, + }); + } + + let mut batches = Batches::new(num_threads); + // Some transactions may be unschedulable due to multi-thread conflicts. + // These transactions cannot be scheduled until some conflicting work is completed. + // However, the scheduler should not allow other transactions that conflict with + // these transactions to be scheduled before them. + let mut unschedulable_ids = Vec::new(); + let mut blocking_locks = ReadWriteAccountSet::default(); + let mut prio_graph = PrioGraph::new(|id: &TransactionPriorityId, _graph_node| *id); + + // Track metrics on filter. + let mut num_filtered_out: usize = 0; + let mut total_filter_time_us: u64 = 0; + + let mut window_budget = self.look_ahead_window_size; + let mut chunked_pops = |container: &mut TransactionStateContainer, + prio_graph: &mut PrioGraph<_, _, _, _>, + window_budget: &mut usize| { + while *window_budget > 0 { + const MAX_FILTER_CHUNK_SIZE: usize = 128; + let mut filter_array = [true; MAX_FILTER_CHUNK_SIZE]; + let mut ids = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); + let mut txs = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); + + let chunk_size = (*window_budget).min(MAX_FILTER_CHUNK_SIZE); + for _ in 0..chunk_size { + if let Some(id) = container.pop() { + ids.push(id); + } else { + break; + } + } + *window_budget = window_budget.saturating_sub(chunk_size); + + ids.iter().for_each(|id| { + let transaction = container.get_transaction_ttl(&id.id).unwrap(); + txs.push(&transaction.transaction); + }); + + let (_, filter_us) = + measure_us!(pre_graph_filter(&txs, &mut filter_array[..chunk_size])); + saturating_add_assign!(total_filter_time_us, filter_us); + + for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) { + if *filter_result { + let transaction = container.get_transaction_ttl(&id.id).unwrap(); + prio_graph.insert_transaction( + *id, + Self::get_transaction_account_access(transaction), + ); + } else { + saturating_add_assign!(num_filtered_out, 1); + container.remove_by_id(&id.id); + } + } + + if ids.len() != chunk_size { + break; + } + } + }; + + // Create the initial look-ahead window. + // Check transactions against filter, remove from container if it fails. + chunked_pops(container, &mut prio_graph, &mut window_budget); + + let mut unblock_this_batch = + Vec::with_capacity(self.consume_work_senders.len() * TARGET_NUM_TRANSACTIONS_PER_BATCH); + const MAX_TRANSACTIONS_PER_SCHEDULING_PASS: usize = 100_000; + let mut num_scheduled: usize = 0; + let mut num_sent: usize = 0; + let mut num_unschedulable: usize = 0; + while num_scheduled < MAX_TRANSACTIONS_PER_SCHEDULING_PASS { + // If nothing is in the main-queue of the `PrioGraph` then there's nothing left to schedule. 
+ if prio_graph.is_empty() { + break; + } + + while let Some(id) = prio_graph.pop() { + unblock_this_batch.push(id); + + // Should always be in the container, during initial testing phase panic. + // Later, we can replace with a continue in case this does happen. + let Some(transaction_state) = container.get_mut_transaction_state(&id.id) else { + panic!("transaction state must exist") + }; + + let maybe_schedule_info = try_schedule_transaction( + transaction_state, + &pre_lock_filter, + &mut blocking_locks, + &mut self.account_locks, + num_threads, + |thread_set| { + Self::select_thread( + thread_set, + &batches.total_cus, + self.in_flight_tracker.cus_in_flight_per_thread(), + &batches.transactions, + self.in_flight_tracker.num_in_flight_per_thread(), + ) + }, + ); + + match maybe_schedule_info { + Err(TransactionSchedulingError::Filtered) => { + container.remove_by_id(&id.id); + } + Err(TransactionSchedulingError::UnschedulableConflicts) => { + unschedulable_ids.push(id); + saturating_add_assign!(num_unschedulable, 1); + } + Ok(TransactionSchedulingInfo { + thread_id, + transaction, + max_age_slot, + cost, + }) => { + saturating_add_assign!(num_scheduled, 1); + batches.transactions[thread_id].push(transaction); + batches.ids[thread_id].push(id.id); + batches.max_age_slots[thread_id].push(max_age_slot); + saturating_add_assign!(batches.total_cus[thread_id], cost); + + // If target batch size is reached, send only this batch. + if batches.ids[thread_id].len() >= TARGET_NUM_TRANSACTIONS_PER_BATCH { + saturating_add_assign!( + num_sent, + self.send_batch(&mut batches, thread_id)? + ); + } + + // if the thread is at max_cu_per_thread, remove it from the schedulable threads + // if there are no more schedulable threads, stop scheduling. + if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] + + batches.total_cus[thread_id] + >= max_cu_per_thread + { + schedulable_threads.remove(thread_id); + if schedulable_threads.is_empty() { + break; + } + } + + if num_scheduled >= MAX_TRANSACTIONS_PER_SCHEDULING_PASS { + break; + } + } + } + } + + // Send all non-empty batches + saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); + + // Refresh window budget and do chunked pops + saturating_add_assign!(window_budget, unblock_this_batch.len()); + chunked_pops(container, &mut prio_graph, &mut window_budget); + + // Unblock all transactions that were blocked by the transactions that were just sent. + for id in unblock_this_batch.drain(..) { + prio_graph.unblock(&id); + } + } + + // Send batches for any remaining transactions + saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); + + // Push unschedulable ids back into the container + for id in unschedulable_ids { + container.push_id_into_queue(id); + } + + // Push remaining transactions back into the container + while let Some((id, _)) = prio_graph.pop_and_unblock() { + container.push_id_into_queue(id); + } + + assert_eq!( + num_scheduled, num_sent, + "number of scheduled and sent transactions must match" + ); + + Ok(SchedulingSummary { + num_scheduled, + num_unschedulable, + num_filtered_out, + filter_time_us: total_filter_time_us, + }) + } + + /// Receive completed batches of transactions without blocking. + /// Returns (num_transactions, num_retryable_transactions) on success. 
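+    /// Illustrative controller-loop usage (surrounding names assumed):
+    /// ```ignore
+    /// let (num_transactions, num_retryable) =
+    ///     scheduler.receive_completed(&mut container)?;
+    /// // Retryable transactions have been re-queued inside `container`.
+    /// ```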
+    pub fn receive_completed(
+        &mut self,
+        container: &mut TransactionStateContainer,
+    ) -> Result<(usize, usize), SchedulerError> {
+        let mut total_num_transactions: usize = 0;
+        let mut total_num_retryable: usize = 0;
+        loop {
+            let (num_transactions, num_retryable) = self.try_receive_completed(container)?;
+            if num_transactions == 0 {
+                break;
+            }
+            saturating_add_assign!(total_num_transactions, num_transactions);
+            saturating_add_assign!(total_num_retryable, num_retryable);
+        }
+        Ok((total_num_transactions, total_num_retryable))
+    }
+
+    /// Receive completed batches of transactions.
+    /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received.
+    fn try_receive_completed(
+        &mut self,
+        container: &mut TransactionStateContainer,
+    ) -> Result<(usize, usize), SchedulerError> {
+        match self.finished_consume_work_receiver.try_recv() {
+            Ok(FinishedConsumeWork {
+                work:
+                    ConsumeWork {
+                        batch_id,
+                        ids,
+                        transactions,
+                        max_age_slots,
+                    },
+                retryable_indexes,
+            }) => {
+                let num_transactions = ids.len();
+                let num_retryable = retryable_indexes.len();
+
+                // Free the locks
+                self.complete_batch(batch_id, &transactions);
+
+                // Retryable transactions should be inserted back into the container
+                let mut retryable_iter = retryable_indexes.into_iter().peekable();
+                for (index, (id, transaction, max_age_slot)) in
+                    izip!(ids, transactions, max_age_slots).enumerate()
+                {
+                    if let Some(retryable_index) = retryable_iter.peek() {
+                        if *retryable_index == index {
+                            container.retry_transaction(
+                                id,
+                                SanitizedTransactionTTL {
+                                    transaction,
+                                    max_age_slot,
+                                },
+                            );
+                            retryable_iter.next();
+                            continue;
+                        }
+                    }
+                    container.remove_by_id(&id);
+                }
+
+                Ok((num_transactions, num_retryable))
+            }
+            Err(TryRecvError::Empty) => Ok((0, 0)),
+            Err(TryRecvError::Disconnected) => Err(SchedulerError::DisconnectedRecvChannel(
+                "finished consume work",
+            )),
+        }
+    }
+
+    /// Mark a given `TransactionBatchId` as completed.
+    /// This will update the internal tracking, including account locks.
+    fn complete_batch(
+        &mut self,
+        batch_id: TransactionBatchId,
+        transactions: &[SanitizedTransaction],
+    ) {
+        let thread_id = self.in_flight_tracker.complete_batch(batch_id);
+        for transaction in transactions {
+            let message = transaction.message();
+            let account_keys = message.account_keys();
+            let write_account_locks = account_keys
+                .iter()
+                .enumerate()
+                .filter_map(|(index, key)| message.is_writable(index).then_some(key));
+            let read_account_locks = account_keys
+                .iter()
+                .enumerate()
+                .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key));
+            self.account_locks
+                .unlock_accounts(write_account_locks, read_account_locks, thread_id);
+        }
+    }
+
+    /// Send all batches of transactions to the worker threads.
+    /// Returns the number of transactions sent.
+    fn send_batches(&mut self, batches: &mut Batches) -> Result<usize, SchedulerError> {
+        (0..self.consume_work_senders.len())
+            .map(|thread_index| self.send_batch(batches, thread_index))
+            .sum()
+    }
+
+    /// Send a batch of transactions to the given thread's `ConsumeWork` channel.
+    /// Returns the number of transactions sent.
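+    /// Returns `Ok(0)` without sending anything if the thread's batch is empty.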
+    fn send_batch(
+        &mut self,
+        batches: &mut Batches,
+        thread_index: usize,
+    ) -> Result<usize, SchedulerError> {
+        if batches.ids[thread_index].is_empty() {
+            return Ok(0);
+        }
+
+        let (ids, transactions, max_age_slots, total_cus) = batches.take_batch(thread_index);
+
+        let batch_id = self
+            .in_flight_tracker
+            .track_batch(ids.len(), total_cus, thread_index);
+
+        let num_scheduled = ids.len();
+        let work = ConsumeWork {
+            batch_id,
+            ids,
+            transactions,
+            max_age_slots,
+        };
+        self.consume_work_senders[thread_index]
+            .send(work)
+            .map_err(|_| SchedulerError::DisconnectedSendChannel("consume work sender"))?;
+
+        Ok(num_scheduled)
+    }
+
+    /// Given the schedulable `thread_set`, select the thread with the least amount
+    /// of work queued up.
+    /// Currently, "work" is just defined as the number of transactions.
+    ///
+    /// If the `chain_thread` is available, this thread will be selected, regardless of
+    /// load-balancing.
+    ///
+    /// Panics if the `thread_set` is empty. This should never happen, see comment
+    /// on `ThreadAwareAccountLocks::try_lock_accounts`.
+    fn select_thread(
+        thread_set: ThreadSet,
+        batch_cus_per_thread: &[u64],
+        in_flight_cus_per_thread: &[u64],
+        batches_per_thread: &[Vec<SanitizedTransaction>],
+        in_flight_per_thread: &[usize],
+    ) -> ThreadId {
+        thread_set
+            .contained_threads_iter()
+            .map(|thread_id| {
+                (
+                    thread_id,
+                    batch_cus_per_thread[thread_id] + in_flight_cus_per_thread[thread_id],
+                    batches_per_thread[thread_id].len() + in_flight_per_thread[thread_id],
+                )
+            })
+            .min_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2)))
+            .map(|(thread_id, _, _)| thread_id)
+            .unwrap()
+    }
+
+    /// Gets accessed accounts (resources) for use in `PrioGraph`.
+    fn get_transaction_account_access(
+        transaction: &SanitizedTransactionTTL,
+    ) -> impl Iterator<Item = (Pubkey, AccessKind)> + '_ {
+        let message = transaction.transaction.message();
+        message
+            .account_keys()
+            .iter()
+            .enumerate()
+            .map(|(index, key)| {
+                if message.is_writable(index) {
+                    (*key, AccessKind::Write)
+                } else {
+                    (*key, AccessKind::Read)
+                }
+            })
+    }
+}
+
+/// Metrics from scheduling transactions.
+#[derive(Debug, PartialEq, Eq)]
+pub struct SchedulingSummary {
+    /// Number of transactions scheduled.
+    pub num_scheduled: usize,
+    /// Number of transactions that were not scheduled due to conflicts.
+    pub num_unschedulable: usize,
+    /// Number of transactions that were dropped due to filter.
+    pub num_filtered_out: usize,
+    /// Time spent filtering transactions
+    pub filter_time_us: u64,
+}
+
+struct Batches {
+    ids: Vec<Vec<TransactionId>>,
+    transactions: Vec<Vec<SanitizedTransaction>>,
+    max_age_slots: Vec<Vec<Slot>>,
+    total_cus: Vec<u64>,
+}
+
+impl Batches {
+    fn new(num_threads: usize) -> Self {
+        Self {
+            ids: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads],
+            transactions: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads],
+            max_age_slots: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads],
+            total_cus: vec![0; num_threads],
+        }
+    }
+
+    fn take_batch(
+        &mut self,
+        thread_id: ThreadId,
+    ) -> (
+        Vec<TransactionId>,
+        Vec<SanitizedTransaction>,
+        Vec<Slot>,
+        u64,
+    ) {
+        (
+            core::mem::replace(
+                &mut self.ids[thread_id],
+                Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH),
+            ),
+            core::mem::replace(
+                &mut self.transactions[thread_id],
+                Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH),
+            ),
+            core::mem::replace(
+                &mut self.max_age_slots[thread_id],
+                Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH),
+            ),
+            core::mem::replace(&mut self.total_cus[thread_id], 0),
+        )
+    }
+}
+
+/// A transaction has been scheduled to a thread.
+struct TransactionSchedulingInfo {
+    thread_id: ThreadId,
+    transaction: SanitizedTransaction,
+    max_age_slot: Slot,
+    cost: u64,
+}
+
+/// Error type for reasons a transaction could not be scheduled.
+enum TransactionSchedulingError {
+    /// Transaction was filtered out before locking.
+    Filtered,
+    /// Transaction cannot be scheduled due to conflicts, or
+    /// higher priority conflicting transactions are unschedulable.
+    UnschedulableConflicts,
+}
+
+fn try_schedule_transaction(
+    transaction_state: &mut TransactionState,
+    pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool,
+    blocking_locks: &mut ReadWriteAccountSet,
+    account_locks: &mut ThreadAwareAccountLocks,
+    num_threads: usize,
+    thread_selector: impl Fn(ThreadSet) -> ThreadId,
+) -> Result<TransactionSchedulingInfo, TransactionSchedulingError> {
+    let transaction = &transaction_state.transaction_ttl().transaction;
+    if !pre_lock_filter(transaction) {
+        return Err(TransactionSchedulingError::Filtered);
+    }
+
+    // Check if this transaction conflicts with any blocked transactions
+    let message = transaction.message();
+    if !blocking_locks.check_locks(message) {
+        blocking_locks.take_locks(message);
+        return Err(TransactionSchedulingError::UnschedulableConflicts);
+    }
+
+    // Schedule the transaction if it can be.
+    let message = transaction.message();
+    let account_keys = message.account_keys();
+    let write_account_locks = account_keys
+        .iter()
+        .enumerate()
+        .filter_map(|(index, key)| message.is_writable(index).then_some(key));
+    let read_account_locks = account_keys
+        .iter()
+        .enumerate()
+        .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key));
+
+    let Some(thread_id) = account_locks.try_lock_accounts(
+        write_account_locks,
+        read_account_locks,
+        ThreadSet::any(num_threads),
+        thread_selector,
+    ) else {
+        blocking_locks.take_locks(message);
+        return Err(TransactionSchedulingError::UnschedulableConflicts);
+    };
+
+    let sanitized_transaction_ttl = transaction_state.transition_to_pending();
+    let cost = transaction_state.cost();
+
+    Ok(TransactionSchedulingInfo {
+        thread_id,
+        transaction: sanitized_transaction_ttl.transaction,
+        max_age_slot: sanitized_transaction_ttl.max_age_slot,
+        cost,
+    })
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        crossbeam_channel::{unbounded, Receiver},
+        itertools::Itertools,
+        solana_core::banking_stage::{
+            consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH,
+            immutable_deserialized_packet::ImmutableDeserializedPacket,
+        },
+        solana_sdk::{
+            compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet,
+            pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction,
+            transaction::Transaction,
+        },
+        std::{borrow::Borrow, sync::Arc},
+    };
+
+    macro_rules! txid {
+        ($value:expr) => {
+            TransactionId::new($value)
+        };
+    }
+
+    macro_rules! txids {
+        ([$($element:expr),*]) => {
+            vec![ $(txid!($element)),* ]
+        };
+    }
+
+    fn create_test_frame(
+        num_threads: usize,
+    ) -> (
+        PrioGraphScheduler,
+        Vec<Receiver<ConsumeWork>>,
+        Sender<FinishedConsumeWork>,
+    ) {
+        let (consume_work_senders, consume_work_receivers) =
+            (0..num_threads).map(|_| unbounded()).unzip();
+        let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded();
+        let scheduler =
+            PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver);
+        (
+            scheduler,
+            consume_work_receivers,
+            finished_consume_work_sender,
+        )
+    }
+
+    fn prioritized_transfers(
+        from_keypair: &Keypair,
+        to_pubkeys: impl IntoIterator<Item = impl Borrow<Pubkey>>,
+        lamports: u64,
+        priority: u64,
+    ) -> SanitizedTransaction {
+        let to_pubkeys_lamports = to_pubkeys
+            .into_iter()
+            .map(|pubkey| *pubkey.borrow())
+            .zip(std::iter::repeat(lamports))
+            .collect_vec();
+        let mut ixs =
+            system_instruction::transfer_many(&from_keypair.pubkey(), &to_pubkeys_lamports);
+        let prioritization = ComputeBudgetInstruction::set_compute_unit_price(priority);
+        ixs.push(prioritization);
+        let message = Message::new(&ixs, Some(&from_keypair.pubkey()));
+        let tx = Transaction::new(&[from_keypair], message, Hash::default());
+        SanitizedTransaction::from_transaction_for_tests(tx)
+    }
+
+    fn create_container(
+        tx_infos: impl IntoIterator<
+            Item = (
+                impl Borrow<Keypair>,
+                impl IntoIterator<Item = impl Borrow<Pubkey>>,
+                u64,
+                u64,
+            ),
+        >,
+    ) -> TransactionStateContainer {
+        let mut container = TransactionStateContainer::with_capacity(10 * 1024);
+        for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in
+            tx_infos.into_iter().enumerate()
+        {
+            let id = TransactionId::new(index as u64);
+            let transaction = prioritized_transfers(
+                from_keypair.borrow(),
+                to_pubkeys,
+                lamports,
+                compute_unit_price,
+            );
+            let packet = Arc::new(
+                ImmutableDeserializedPacket::new(
+                    Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(),
+                )
+                .unwrap(),
+            );
+            let transaction_ttl = SanitizedTransactionTTL {
+                transaction,
+                max_age_slot: Slot::MAX,
+            };
+            const TEST_TRANSACTION_COST: u64 = 5000;
+            container.insert_new_transaction(
+                id,
+                transaction_ttl,
+                packet,
+                compute_unit_price,
+                TEST_TRANSACTION_COST,
+            );
+        }
+
+        container
+    }
+
+    fn collect_work(
+        receiver: &Receiver<ConsumeWork>,
+    ) -> (Vec<ConsumeWork>, Vec<Vec<TransactionId>>) {
+        receiver
+            .try_iter()
+            .map(|work| {
+                let ids = work.ids.clone();
+                (work, ids)
+            })
+            .unzip()
+    }
+
+    fn test_pre_graph_filter(_txs: &[&SanitizedTransaction], results: &mut [bool]) {
+        results.fill(true);
+    }
+
+    fn test_pre_lock_filter(_tx: &SanitizedTransaction) -> bool {
+        true
+    }
+
+    #[test]
+    fn test_schedule_disconnected_channel() {
+        let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
+        let mut container = create_container([(&Keypair::new(), &[Pubkey::new_unique()], 1, 1)]);
+
+        drop(work_receivers); // explicitly drop receivers
+        assert_matches!(
+            scheduler.schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter),
+            Err(SchedulerError::DisconnectedSendChannel(_))
+        );
+    }
+
+    #[test]
+    fn test_schedule_single_threaded_no_conflicts() {
+        let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
+        let mut container = create_container([
+            (&Keypair::new(), &[Pubkey::new_unique()], 1, 1),
+            (&Keypair::new(), &[Pubkey::new_unique()], 2, 2),
+        ]);
+
+        let scheduling_summary = scheduler
+            .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
+            .unwrap();
+        assert_eq!(scheduling_summary.num_scheduled, 2);
+        assert_eq!(scheduling_summary.num_unschedulable, 0);
+        
assert_eq!(collect_work(&work_receivers[0]).1, vec![txids!([1, 0])]); + } + + #[test] + fn test_schedule_single_threaded_conflict() { + let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); + let pubkey = Pubkey::new_unique(); + let mut container = create_container([ + (&Keypair::new(), &[pubkey], 1, 1), + (&Keypair::new(), &[pubkey], 1, 2), + ]); + + let scheduling_summary = scheduler + .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) + .unwrap(); + assert_eq!(scheduling_summary.num_scheduled, 2); + assert_eq!(scheduling_summary.num_unschedulable, 0); + assert_eq!( + collect_work(&work_receivers[0]).1, + vec![txids!([1]), txids!([0])] + ); + } + + #[test] + fn test_schedule_consume_single_threaded_multi_batch() { + let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); + let mut container = create_container( + (0..4 * TARGET_NUM_TRANSACTIONS_PER_BATCH) + .map(|i| (Keypair::new(), [Pubkey::new_unique()], i as u64, 1)), + ); + + // expect 4 full batches to be scheduled + let scheduling_summary = scheduler + .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) + .unwrap(); + assert_eq!( + scheduling_summary.num_scheduled, + 4 * TARGET_NUM_TRANSACTIONS_PER_BATCH + ); + assert_eq!(scheduling_summary.num_unschedulable, 0); + + let thread0_work_counts: Vec<_> = work_receivers[0] + .try_iter() + .map(|work| work.ids.len()) + .collect(); + assert_eq!(thread0_work_counts, [TARGET_NUM_TRANSACTIONS_PER_BATCH; 4]); + } + + #[test] + fn test_schedule_simple_thread_selection() { + let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(2); + let mut container = + create_container((0..4).map(|i| (Keypair::new(), [Pubkey::new_unique()], 1, i))); + + let scheduling_summary = scheduler + .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) + .unwrap(); + assert_eq!(scheduling_summary.num_scheduled, 4); + assert_eq!(scheduling_summary.num_unschedulable, 0); + assert_eq!(collect_work(&work_receivers[0]).1, [txids!([3, 1])]); + assert_eq!(collect_work(&work_receivers[1]).1, [txids!([2, 0])]); + } + + #[test] + fn test_schedule_priority_guard() { + let (mut scheduler, work_receivers, finished_work_sender) = create_test_frame(2); + // intentionally shorten the look-ahead window to cause unschedulable conflicts + scheduler.look_ahead_window_size = 2; + + let accounts = (0..8).map(|_| Keypair::new()).collect_vec(); + let mut container = create_container([ + (&accounts[0], &[accounts[1].pubkey()], 1, 6), + (&accounts[2], &[accounts[3].pubkey()], 1, 5), + (&accounts[4], &[accounts[5].pubkey()], 1, 4), + (&accounts[6], &[accounts[7].pubkey()], 1, 3), + (&accounts[1], &[accounts[2].pubkey()], 1, 2), + (&accounts[2], &[accounts[3].pubkey()], 1, 1), + ]); + + // The look-ahead window is intentionally shortened, high priority transactions + // [0, 1, 2, 3] do not conflict, and are scheduled onto threads in a + // round-robin fashion. This leads to transaction [4] being unschedulable due + // to conflicts with [0] and [1], which were scheduled to different threads. + // Transaction [5] is technically schedulable, onto thread 1 since it only + // conflicts with transaction [1]. However, [5] will not be scheduled because + // it conflicts with a higher-priority transaction [4] that is unschedulable. 
+
+        // The full prio-graph can be visualized as:
+        // [0] \
+        //      -> [4] -> [5]
+        // [1] /      ------/
+        // [2]
+        // [3]
+        // Because the look-ahead window is shortened to a size of 2, the scheduler does
+        // not have knowledge of the joining at transaction [4] until after [0] and [1]
+        // have been scheduled.
+        let scheduling_summary = scheduler
+            .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
+            .unwrap();
+        assert_eq!(scheduling_summary.num_scheduled, 4);
+        assert_eq!(scheduling_summary.num_unschedulable, 2);
+        let (thread_0_work, thread_0_ids) = collect_work(&work_receivers[0]);
+        assert_eq!(thread_0_ids, [txids!([0]), txids!([2])]);
+        assert_eq!(
+            collect_work(&work_receivers[1]).1,
+            [txids!([1]), txids!([3])]
+        );
+
+        // Cannot schedule even on next pass because of lock conflicts
+        let scheduling_summary = scheduler
+            .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
+            .unwrap();
+        assert_eq!(scheduling_summary.num_scheduled, 0);
+        assert_eq!(scheduling_summary.num_unschedulable, 2);
+
+        // Complete batch on thread 0. Remaining txs can be scheduled onto thread 1
+        finished_work_sender
+            .send(FinishedConsumeWork {
+                work: thread_0_work.into_iter().next().unwrap(),
+                retryable_indexes: vec![],
+            })
+            .unwrap();
+        scheduler.receive_completed(&mut container).unwrap();
+        let scheduling_summary = scheduler
+            .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter)
+            .unwrap();
+        assert_eq!(scheduling_summary.num_scheduled, 2);
+        assert_eq!(scheduling_summary.num_unschedulable, 0);
+
+        assert_eq!(
+            collect_work(&work_receivers[1]).1,
+            [txids!([4]), txids!([5])]
+        );
+    }
+
+    #[test]
+    fn test_schedule_pre_lock_filter() {
+        let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1);
+        let pubkey = Pubkey::new_unique();
+        let keypair = Keypair::new();
+        let mut container = create_container([
+            (&Keypair::new(), &[pubkey], 1, 1),
+            (&keypair, &[pubkey], 1, 2),
+            (&Keypair::new(), &[pubkey], 1, 3),
+        ]);
+
+        // 2nd transaction should be filtered out and dropped before locking.
+        let pre_lock_filter =
+            |tx: &SanitizedTransaction| tx.message().fee_payer() != &keypair.pubkey();
+        let scheduling_summary = scheduler
+            .schedule(&mut container, test_pre_graph_filter, pre_lock_filter)
+            .unwrap();
+        assert_eq!(scheduling_summary.num_scheduled, 2);
+        assert_eq!(scheduling_summary.num_unschedulable, 0);
+        assert_eq!(
+            collect_work(&work_receivers[0]).1,
+            vec![txids!([2]), txids!([0])]
+        );
+    }
+}
diff --git a/prio-graph-scheduler/src/scheduler_controller.rs b/prio-graph-scheduler/src/scheduler_controller.rs
new file mode 100644
index 00000000000000..5a04e1eb39d65b
--- /dev/null
+++ b/prio-graph-scheduler/src/scheduler_controller.rs
@@ -0,0 +1,1161 @@
+//! Control flow for BankingStage's transaction scheduler.
+//!
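+//! A rough wiring sketch (non-normative; channel construction and the
+//! surrounding stage setup are assumptions, only the types and the argument
+//! order come from this module and `PrioGraphScheduler`):
+//!
+//! ```ignore
+//! let scheduler = PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver);
+//! let controller = SchedulerController::new(
+//!     decision_maker,       // DecisionMaker
+//!     packet_deserializer,  // PacketDeserializer
+//!     bank_forks,           // Arc<RwLock<BankForks>>
+//!     scheduler,
+//!     worker_metrics,       // Vec<Arc<ConsumeWorkerMetrics>>
+//!     None,                 // no forwarder
+//! );
+//! controller.run()?;
+//! ```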
+
+use {
+    super::{
+        prio_graph_scheduler::PrioGraphScheduler,
+        scheduler_error::SchedulerError,
+        scheduler_metrics::{
+            SchedulerCountMetrics, SchedulerLeaderDetectionMetrics, SchedulerTimingMetrics,
+        },
+        transaction_id_generator::TransactionIdGenerator,
+        transaction_state::SanitizedTransactionTTL,
+        transaction_state_container::TransactionStateContainer,
+    },
+    crate::banking_stage::{
+        consume_worker::ConsumeWorkerMetrics,
+        consumer::Consumer,
+        decision_maker::{BufferedPacketsDecision, DecisionMaker},
+        forwarder::Forwarder,
+        immutable_deserialized_packet::ImmutableDeserializedPacket,
+        packet_deserializer::PacketDeserializer,
+        ForwardOption, LikeClusterInfo, TOTAL_BUFFERED_PACKETS,
+    },
+    arrayvec::ArrayVec,
+    crossbeam_channel::RecvTimeoutError,
+    solana_accounts_db::account_locks::validate_account_locks,
+    solana_cost_model::cost_model::CostModel,
+    solana_measure::measure_us,
+    solana_runtime::{bank::Bank, bank_forks::BankForks},
+    solana_runtime_transaction::instructions_processor::process_compute_budget_instructions,
+    solana_sdk::{
+        self,
+        clock::{FORWARD_TRANSACTIONS_TO_LEADER_AT_SLOT_OFFSET, MAX_PROCESSING_AGE},
+        fee::FeeBudgetLimits,
+        saturating_add_assign,
+        transaction::SanitizedTransaction,
+    },
+    solana_svm::transaction_error_metrics::TransactionErrorMetrics,
+    solana_svm_transaction::svm_message::SVMMessage,
+    std::{
+        sync::{Arc, RwLock},
+        time::{Duration, Instant},
+    },
+};
+
+/// Controls packet and transaction flow into scheduler, and scheduling execution.
+pub(crate) struct SchedulerController<T: LikeClusterInfo> {
+    /// Decision maker for determining what should be done with transactions.
+    decision_maker: DecisionMaker,
+    /// Packet/Transaction ingress.
+    packet_receiver: PacketDeserializer,
+    bank_forks: Arc<RwLock<BankForks>>,
+    /// Generates unique IDs for incoming transactions.
+    transaction_id_generator: TransactionIdGenerator,
+    /// Container for transaction state.
+    /// Shared resource between `packet_receiver` and `scheduler`.
+    container: TransactionStateContainer,
+    /// State for scheduling and communicating with worker threads.
+    scheduler: PrioGraphScheduler,
+    /// Metrics tracking time for leader bank detection.
+    leader_detection_metrics: SchedulerLeaderDetectionMetrics,
+    /// Metrics tracking counts on transactions in different states
+    /// over an interval and during a leader slot.
+    count_metrics: SchedulerCountMetrics,
+    /// Metrics tracking time spent in different code sections
+    /// over an interval and during a leader slot.
+    timing_metrics: SchedulerTimingMetrics,
+    /// Metric report handles for the worker threads.
+    worker_metrics: Vec<Arc<ConsumeWorkerMetrics>>,
+    /// State for forwarding packets to the leader, if enabled.
+    forwarder: Option<Forwarder<T>>,
+}
+
+impl<T: LikeClusterInfo> SchedulerController<T> {
+    pub fn new(
+        decision_maker: DecisionMaker,
+        packet_deserializer: PacketDeserializer,
+        bank_forks: Arc<RwLock<BankForks>>,
+        scheduler: PrioGraphScheduler,
+        worker_metrics: Vec<Arc<ConsumeWorkerMetrics>>,
+        forwarder: Option<Forwarder<T>>,
+    ) -> Self {
+        Self {
+            decision_maker,
+            packet_receiver: packet_deserializer,
+            bank_forks,
+            transaction_id_generator: TransactionIdGenerator::default(),
+            container: TransactionStateContainer::with_capacity(TOTAL_BUFFERED_PACKETS),
+            scheduler,
+            leader_detection_metrics: SchedulerLeaderDetectionMetrics::default(),
+            count_metrics: SchedulerCountMetrics::default(),
+            timing_metrics: SchedulerTimingMetrics::default(),
+            worker_metrics,
+            forwarder,
+        }
+    }
+
+    pub fn run(mut self) -> Result<(), SchedulerError> {
+        loop {
+            // BufferedPacketsDecision is shared with legacy BankingStage, which will forward
+            // packets. 
Initially, not renaming these decision variants but the actions taken + // are different, since new BankingStage will not forward packets. + // For `Forward` and `ForwardAndHold`, we want to receive packets but will not + // forward them to the next leader. In this case, `ForwardAndHold` is + // indistinguishable from `Hold`. + // + // `Forward` will drop packets from the buffer instead of forwarding. + // During receiving, since packets would be dropped from buffer anyway, we can + // bypass sanitization and buffering and immediately drop the packets. + let (decision, decision_time_us) = + measure_us!(self.decision_maker.make_consume_or_forward_decision()); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!(timing_metrics.decision_time_us, decision_time_us); + }); + let new_leader_slot = decision.bank_start().map(|b| b.working_bank.slot()); + self.leader_detection_metrics + .update_and_maybe_report(decision.bank_start()); + self.count_metrics + .maybe_report_and_reset_slot(new_leader_slot); + self.timing_metrics + .maybe_report_and_reset_slot(new_leader_slot); + + self.process_transactions(&decision)?; + self.receive_completed()?; + if !self.receive_and_buffer_packets(&decision) { + break; + } + // Report metrics only if there is data. + // Reset intervals when appropriate, regardless of report. + let should_report = self.count_metrics.interval_has_data(); + let priority_min_max = self.container.get_min_max_priority(); + self.count_metrics.update(|count_metrics| { + count_metrics.update_priority_stats(priority_min_max); + }); + self.count_metrics + .maybe_report_and_reset_interval(should_report); + self.timing_metrics + .maybe_report_and_reset_interval(should_report); + self.worker_metrics + .iter() + .for_each(|metrics| metrics.maybe_report_and_reset()); + } + + Ok(()) + } + + /// Process packets based on decision. 
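+    /// - `Consume`: schedule buffered transactions for the working bank.
+    /// - `Forward`: forward packets if a forwarder is configured, otherwise clear the container.
+    /// - `ForwardAndHold`: forward (or clean the queue) while keeping transactions buffered.
+    /// - `Hold`: do nothing.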
+ fn process_transactions( + &mut self, + decision: &BufferedPacketsDecision, + ) -> Result<(), SchedulerError> { + let forwarding_enabled = self.forwarder.is_some(); + match decision { + BufferedPacketsDecision::Consume(bank_start) => { + let (scheduling_summary, schedule_time_us) = measure_us!(self.scheduler.schedule( + &mut self.container, + |txs, results| { + Self::pre_graph_filter( + txs, + results, + &bank_start.working_bank, + MAX_PROCESSING_AGE, + ) + }, + |_| true // no pre-lock filter for now + )?); + + self.count_metrics.update(|count_metrics| { + saturating_add_assign!( + count_metrics.num_scheduled, + scheduling_summary.num_scheduled + ); + saturating_add_assign!( + count_metrics.num_unschedulable, + scheduling_summary.num_unschedulable + ); + saturating_add_assign!( + count_metrics.num_schedule_filtered_out, + scheduling_summary.num_filtered_out + ); + }); + + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!( + timing_metrics.schedule_filter_time_us, + scheduling_summary.filter_time_us + ); + saturating_add_assign!(timing_metrics.schedule_time_us, schedule_time_us); + }); + } + BufferedPacketsDecision::Forward => { + if forwarding_enabled { + let (_, forward_time_us) = measure_us!(self.forward_packets(false)); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us); + }); + } else { + let (_, clear_time_us) = measure_us!(self.clear_container()); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!(timing_metrics.clear_time_us, clear_time_us); + }); + } + } + BufferedPacketsDecision::ForwardAndHold => { + if forwarding_enabled { + let (_, forward_time_us) = measure_us!(self.forward_packets(true)); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us); + }); + } else { + let (_, clean_time_us) = measure_us!(self.clean_queue()); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!(timing_metrics.clean_time_us, clean_time_us); + }); + } + } + BufferedPacketsDecision::Hold => {} + } + + Ok(()) + } + + fn pre_graph_filter( + transactions: &[&SanitizedTransaction], + results: &mut [bool], + bank: &Bank, + max_age: usize, + ) { + let lock_results = vec![Ok(()); transactions.len()]; + let mut error_counters = TransactionErrorMetrics::default(); + let check_results = + bank.check_transactions(transactions, &lock_results, max_age, &mut error_counters); + + let fee_check_results: Vec<_> = check_results + .into_iter() + .zip(transactions) + .map(|(result, tx)| { + result?; // if there's already error do nothing + Consumer::check_fee_payer_unlocked(bank, tx.message(), &mut error_counters) + }) + .collect(); + + for (fee_check_result, result) in fee_check_results.into_iter().zip(results.iter_mut()) { + *result = fee_check_result.is_ok(); + } + } + + /// Forward packets to the next leader. + fn forward_packets(&mut self, hold: bool) { + const MAX_FORWARDING_DURATION: Duration = Duration::from_millis(100); + let start = Instant::now(); + let bank = self.bank_forks.read().unwrap().working_bank(); + let feature_set = &bank.feature_set; + let forwarder = self.forwarder.as_mut().expect("forwarder must exist"); + + // Pop from the container in chunks, filter using bank checks, then attempt to forward. + // This doubles as a way to clean the queue as well as forwarding transactions. 
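+        // Each chunk: pop up to CHUNK_SIZE ids, re-run the same age/status/fee-payer
+        // filter used when consuming, drop failures from the container, and hand the
+        // survivors to the forwarder (marking them as forwarded on success).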
+        const CHUNK_SIZE: usize = 64;
+        let mut num_forwarded: usize = 0;
+        let mut ids_to_add_back = Vec::new();
+        let mut max_time_reached = false;
+        while !self.container.is_empty() {
+            let mut filter_array = [true; CHUNK_SIZE];
+            let mut ids = Vec::with_capacity(CHUNK_SIZE);
+            let mut txs = Vec::with_capacity(CHUNK_SIZE);
+
+            for _ in 0..CHUNK_SIZE {
+                if let Some(id) = self.container.pop() {
+                    ids.push(id);
+                } else {
+                    break;
+                }
+            }
+            let chunk_size = ids.len();
+            ids.iter().for_each(|id| {
+                let transaction = self.container.get_transaction_ttl(&id.id).unwrap();
+                txs.push(&transaction.transaction);
+            });
+
+            // Use the same filter as for processing transactions:
+            // age, already processed, fee-check.
+            Self::pre_graph_filter(
+                &txs,
+                &mut filter_array,
+                &bank,
+                MAX_PROCESSING_AGE
+                    .saturating_sub(FORWARD_TRANSACTIONS_TO_LEADER_AT_SLOT_OFFSET as usize),
+            );
+
+            for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) {
+                if !*filter_result {
+                    self.container.remove_by_id(&id.id);
+                    continue;
+                }
+
+                ids_to_add_back.push(*id); // add back to the queue at end
+                let state = self.container.get_mut_transaction_state(&id.id).unwrap();
+                let sanitized_transaction = &state.transaction_ttl().transaction;
+                let immutable_packet = state.packet().clone();
+
+                // If not already forwarded and can be forwarded, add to forwardable packets.
+                if state.should_forward()
+                    && forwarder.try_add_packet(
+                        sanitized_transaction,
+                        immutable_packet,
+                        feature_set,
+                    )
+                {
+                    saturating_add_assign!(num_forwarded, 1);
+                    state.mark_forwarded();
+                }
+            }
+
+            if start.elapsed() >= MAX_FORWARDING_DURATION {
+                max_time_reached = true;
+                break;
+            }
+        }
+
+        // Forward each batch of transactions
+        forwarder.forward_batched_packets(&ForwardOption::ForwardTransaction);
+        forwarder.clear_batches();
+
+        // If we hit the time limit, drop everything that was not checked/processed.
+        // If we cannot run these simple checks in time, then we cannot run them during
+        // a leader slot.
+        if max_time_reached {
+            while let Some(id) = self.container.pop() {
+                self.container.remove_by_id(&id.id);
+            }
+        }
+
+        if hold {
+            for priority_id in ids_to_add_back {
+                self.container.push_id_into_queue(priority_id);
+            }
+        } else {
+            for priority_id in ids_to_add_back {
+                self.container.remove_by_id(&priority_id.id);
+            }
+        }
+
+        self.count_metrics.update(|count_metrics| {
+            saturating_add_assign!(count_metrics.num_forwarded, num_forwarded);
+        });
+    }
+
+    /// Clears the transaction state container.
+    /// This only clears pending transactions, and does **not** clear in-flight transactions.
+    fn clear_container(&mut self) {
+        let mut num_dropped_on_clear: usize = 0;
+        while let Some(id) = self.container.pop() {
+            self.container.remove_by_id(&id.id);
+            saturating_add_assign!(num_dropped_on_clear, 1);
+        }
+
+        self.count_metrics.update(|count_metrics| {
+            saturating_add_assign!(count_metrics.num_dropped_on_clear, num_dropped_on_clear);
+        });
+    }
+
+    /// Clean unprocessable transactions from the queue. These will be transactions that are
+    /// expired, already processed, or are no longer sanitizable.
+    /// This only clears pending transactions, and does **not** clear in-flight transactions.
+    fn clean_queue(&mut self) {
+        // Clean up any transactions that have already been processed, are too old, or do not have
+        // valid nonce accounts.
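+        // Drain the entire queue up front so the checks below can run over
+        // fixed-size chunks; transactions that pass are pushed back into the queue.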
+ const MAX_TRANSACTION_CHECKS: usize = 10_000; + let mut transaction_ids = Vec::with_capacity(MAX_TRANSACTION_CHECKS); + + while let Some(id) = self.container.pop() { + transaction_ids.push(id); + } + + let bank = self.bank_forks.read().unwrap().working_bank(); + + const CHUNK_SIZE: usize = 128; + let mut error_counters = TransactionErrorMetrics::default(); + let mut num_dropped_on_age_and_status: usize = 0; + for chunk in transaction_ids.chunks(CHUNK_SIZE) { + let lock_results = vec![Ok(()); chunk.len()]; + let sanitized_txs: Vec<_> = chunk + .iter() + .map(|id| { + &self + .container + .get_transaction_ttl(&id.id) + .expect("transaction must exist") + .transaction + }) + .collect(); + + let check_results = bank.check_transactions( + &sanitized_txs, + &lock_results, + MAX_PROCESSING_AGE, + &mut error_counters, + ); + + for (result, id) in check_results.into_iter().zip(chunk.iter()) { + if result.is_err() { + saturating_add_assign!(num_dropped_on_age_and_status, 1); + self.container.remove_by_id(&id.id); + } else { + self.container.push_id_into_queue(*id); + } + } + } + + self.count_metrics.update(|count_metrics| { + saturating_add_assign!( + count_metrics.num_dropped_on_age_and_status, + num_dropped_on_age_and_status + ); + }); + } + + /// Receives completed transactions from the workers and updates metrics. + fn receive_completed(&mut self) -> Result<(), SchedulerError> { + let ((num_transactions, num_retryable), receive_completed_time_us) = + measure_us!(self.scheduler.receive_completed(&mut self.container)?); + + self.count_metrics.update(|count_metrics| { + saturating_add_assign!(count_metrics.num_finished, num_transactions); + saturating_add_assign!(count_metrics.num_retryable, num_retryable); + }); + self.timing_metrics.update(|timing_metrics| { + saturating_add_assign!( + timing_metrics.receive_completed_time_us, + receive_completed_time_us + ); + }); + + Ok(()) + } + + /// Returns whether the packet receiver is still connected. 
+    fn receive_and_buffer_packets(&mut self, decision: &BufferedPacketsDecision) -> bool {
+        let remaining_queue_capacity = self.container.remaining_queue_capacity();
+
+        const MAX_PACKET_RECEIVE_TIME: Duration = Duration::from_millis(10);
+        let (recv_timeout, should_buffer) = match decision {
+            BufferedPacketsDecision::Consume(_) => (
+                if self.container.is_empty() {
+                    MAX_PACKET_RECEIVE_TIME
+                } else {
+                    Duration::ZERO
+                },
+                true,
+            ),
+            BufferedPacketsDecision::Forward => (MAX_PACKET_RECEIVE_TIME, self.forwarder.is_some()),
+            BufferedPacketsDecision::ForwardAndHold | BufferedPacketsDecision::Hold => {
+                (MAX_PACKET_RECEIVE_TIME, true)
+            }
+        };
+
+        let (received_packet_results, receive_time_us) = measure_us!(self
+            .packet_receiver
+            .receive_packets(recv_timeout, remaining_queue_capacity, |packet| {
+                packet.check_excessive_precompiles()?;
+                Ok(packet)
+            }));
+
+        self.timing_metrics.update(|timing_metrics| {
+            saturating_add_assign!(timing_metrics.receive_time_us, receive_time_us);
+        });
+
+        match received_packet_results {
+            Ok(receive_packet_results) => {
+                let num_received_packets = receive_packet_results.deserialized_packets.len();
+
+                self.count_metrics.update(|count_metrics| {
+                    saturating_add_assign!(count_metrics.num_received, num_received_packets);
+                });
+
+                if should_buffer {
+                    let (_, buffer_time_us) = measure_us!(
+                        self.buffer_packets(receive_packet_results.deserialized_packets)
+                    );
+                    self.timing_metrics.update(|timing_metrics| {
+                        saturating_add_assign!(timing_metrics.buffer_time_us, buffer_time_us);
+                    });
+                } else {
+                    self.count_metrics.update(|count_metrics| {
+                        saturating_add_assign!(
+                            count_metrics.num_dropped_on_receive,
+                            num_received_packets
+                        );
+                    });
+                }
+            }
+            Err(RecvTimeoutError::Timeout) => {}
+            Err(RecvTimeoutError::Disconnected) => return false,
+        }
+
+        true
+    }
+
+    fn buffer_packets(&mut self, packets: Vec<ImmutableDeserializedPacket>) {
+        // Convert to Arcs
+        let packets: Vec<_> = packets.into_iter().map(Arc::new).collect();
+        // Sanitize packets, generate IDs, and insert into the container.
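+        // Per chunk of 128 packets: build a sanitized transaction, validate account
+        // locks, process compute-budget instructions, run bank checks, and insert
+        // survivors into the container with a calculated priority and cost.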
+ let bank = self.bank_forks.read().unwrap().working_bank(); + let last_slot_in_epoch = bank.epoch_schedule().get_last_slot_in_epoch(bank.epoch()); + let transaction_account_lock_limit = bank.get_transaction_account_lock_limit(); + let vote_only = bank.vote_only_bank(); + + const CHUNK_SIZE: usize = 128; + let lock_results: [_; CHUNK_SIZE] = core::array::from_fn(|_| Ok(())); + + let mut arc_packets = ArrayVec::<_, CHUNK_SIZE>::new(); + let mut transactions = ArrayVec::<_, CHUNK_SIZE>::new(); + let mut fee_budget_limits_vec = ArrayVec::<_, CHUNK_SIZE>::new(); + + let mut error_counts = TransactionErrorMetrics::default(); + for chunk in packets.chunks(CHUNK_SIZE) { + let mut post_sanitization_count: usize = 0; + chunk + .iter() + .filter_map(|packet| { + packet + .build_sanitized_transaction( + vote_only, + bank.as_ref(), + bank.get_reserved_account_keys(), + ) + .map(|tx| (packet.clone(), tx)) + }) + .inspect(|_| saturating_add_assign!(post_sanitization_count, 1)) + .filter(|(_packet, tx)| { + validate_account_locks( + tx.message().account_keys(), + transaction_account_lock_limit, + ) + .is_ok() + }) + .filter_map(|(packet, tx)| { + process_compute_budget_instructions(SVMMessage::program_instructions_iter(&tx)) + .map(|compute_budget| (packet, tx, compute_budget.into())) + .ok() + }) + .for_each(|(packet, tx, fee_budget_limits)| { + arc_packets.push(packet); + transactions.push(tx); + fee_budget_limits_vec.push(fee_budget_limits); + }); + + let check_results = bank.check_transactions( + &transactions, + &lock_results[..transactions.len()], + MAX_PROCESSING_AGE, + &mut error_counts, + ); + let post_lock_validation_count = transactions.len(); + + let mut post_transaction_check_count: usize = 0; + let mut num_dropped_on_capacity: usize = 0; + let mut num_buffered: usize = 0; + for (((packet, transaction), fee_budget_limits), _check_result) in arc_packets + .drain(..) + .zip(transactions.drain(..)) + .zip(fee_budget_limits_vec.drain(..)) + .zip(check_results) + .filter(|(_, check_result)| check_result.is_ok()) + { + saturating_add_assign!(post_transaction_check_count, 1); + let transaction_id = self.transaction_id_generator.next(); + + let (priority, cost) = + Self::calculate_priority_and_cost(&transaction, &fee_budget_limits, &bank); + let transaction_ttl = SanitizedTransactionTTL { + transaction, + max_age_slot: last_slot_in_epoch, + }; + + if self.container.insert_new_transaction( + transaction_id, + transaction_ttl, + packet, + priority, + cost, + ) { + saturating_add_assign!(num_dropped_on_capacity, 1); + } + saturating_add_assign!(num_buffered, 1); + } + + // Update metrics for transactions that were dropped. 
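+            // These counts form a funnel: chunk.len() >= post_sanitization_count
+            // >= post_lock_validation_count >= post_transaction_check_count; each
+            // drop metric below is the difference between adjacent stages.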
+ let num_dropped_on_sanitization = chunk.len().saturating_sub(post_sanitization_count); + let num_dropped_on_lock_validation = + post_sanitization_count.saturating_sub(post_lock_validation_count); + let num_dropped_on_transaction_checks = + post_lock_validation_count.saturating_sub(post_transaction_check_count); + + self.count_metrics.update(|count_metrics| { + saturating_add_assign!( + count_metrics.num_dropped_on_capacity, + num_dropped_on_capacity + ); + saturating_add_assign!(count_metrics.num_buffered, num_buffered); + saturating_add_assign!( + count_metrics.num_dropped_on_sanitization, + num_dropped_on_sanitization + ); + saturating_add_assign!( + count_metrics.num_dropped_on_validate_locks, + num_dropped_on_lock_validation + ); + saturating_add_assign!( + count_metrics.num_dropped_on_receive_transaction_checks, + num_dropped_on_transaction_checks + ); + }); + } + } + + /// Calculate priority and cost for a transaction: + /// + /// Cost is calculated through the `CostModel`, + /// and priority is calculated through a formula here that attempts to sell + /// blockspace to the highest bidder. + /// + /// The priority is calculated as: + /// P = R / (1 + C) + /// where P is the priority, R is the reward, + /// and C is the cost towards block-limits. + /// + /// Current minimum costs are on the order of several hundred, + /// so the denominator is effectively C, and the +1 is simply + /// to avoid any division by zero due to a bug - these costs + /// are calculated by the cost-model and are not direct + /// from user input. They should never be zero. + /// Any difference in the prioritization is negligible for + /// the current transaction costs. + fn calculate_priority_and_cost( + transaction: &SanitizedTransaction, + fee_budget_limits: &FeeBudgetLimits, + bank: &Bank, + ) -> (u64, u64) { + let cost = CostModel::calculate_cost(transaction, &bank.feature_set).sum(); + let reward = bank.calculate_reward_for_transaction(transaction, fee_budget_limits); + + // We need a multiplier here to avoid rounding down too aggressively. + // For many transactions, the cost will be greater than the fees in terms of raw lamports. + // For the purposes of calculating prioritization, we multiply the fees by a large number so that + // the cost is a small fraction. + // An offset of 1 is used in the denominator to explicitly avoid division by zero. 
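+        // Worked example (illustrative numbers only): reward = 5_000 lamports and
+        // cost = 2_000 CUs give priority = 5_000 * 1_000_000 / (2_000 + 1)
+        // = 2_498_750 with integer division.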
+        const MULTIPLIER: u64 = 1_000_000;
+        (
+            reward
+                .saturating_mul(MULTIPLIER)
+                .saturating_div(cost.saturating_add(1)),
+            cost,
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use {
+        super::*,
+        crate::{
+            banking_stage::{
+                consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH,
+                scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId},
+                tests::create_slow_genesis_config,
+            },
+            banking_trace::BankingPacketBatch,
+            sigverify::SigverifyTracerPacketStats,
+        },
+        crossbeam_channel::{unbounded, Receiver, Sender},
+        itertools::Itertools,
+        solana_gossip::cluster_info::ClusterInfo,
+        solana_ledger::{
+            blockstore::Blockstore, genesis_utils::GenesisConfigInfo,
+            get_tmp_ledger_path_auto_delete, leader_schedule_cache::LeaderScheduleCache,
+        },
+        solana_perf::packet::{to_packet_batches, PacketBatch, NUM_PACKETS},
+        solana_poh::poh_recorder::{PohRecorder, Record, WorkingBankEntry},
+        solana_runtime::bank::Bank,
+        solana_sdk::{
+            compute_budget::ComputeBudgetInstruction, fee_calculator::FeeRateGovernor, hash::Hash,
+            message::Message, poh_config::PohConfig, pubkey::Pubkey, signature::Keypair,
+            signer::Signer, system_instruction, system_transaction, transaction::Transaction,
+        },
+        std::sync::{atomic::AtomicBool, Arc, RwLock},
+        tempfile::TempDir,
+    };
+
+    fn create_channels<T>(num: usize) -> (Vec<Sender<T>>, Vec<Receiver<T>>) {
+        (0..num).map(|_| unbounded()).unzip()
+    }
+
+    // Helper struct to create tests that hold channels, files, etc.
+    // such that our tests can be more easily set up and run.
+    struct TestFrame {
+        bank: Arc<Bank>,
+        mint_keypair: Keypair,
+        _ledger_path: TempDir,
+        _entry_receiver: Receiver<WorkingBankEntry>,
+        _record_receiver: Receiver<Record>,
+        poh_recorder: Arc<RwLock<PohRecorder>>,
+        banking_packet_sender: Sender<Arc<(Vec<PacketBatch>, Option<SigverifyTracerPacketStats>)>>,
+
+        consume_work_receivers: Vec<Receiver<ConsumeWork>>,
+        finished_consume_work_sender: Sender<FinishedConsumeWork>,
+    }
+
+    fn create_test_frame(num_threads: usize) -> (TestFrame, SchedulerController<Arc<ClusterInfo>>) {
+        let GenesisConfigInfo {
+            mut genesis_config,
+            mint_keypair,
+            ..
+ } = create_slow_genesis_config(u64::MAX); + genesis_config.fee_rate_governor = FeeRateGovernor::new(5000, 0); + let (bank, bank_forks) = Bank::new_no_wallclock_throttle_for_tests(&genesis_config); + + let ledger_path = get_tmp_ledger_path_auto_delete!(); + let blockstore = Blockstore::open(ledger_path.path()) + .expect("Expected to be able to open database ledger"); + let (poh_recorder, entry_receiver, record_receiver) = PohRecorder::new( + bank.tick_height(), + bank.last_blockhash(), + bank.clone(), + Some((4, 4)), + bank.ticks_per_slot(), + Arc::new(blockstore), + &Arc::new(LeaderScheduleCache::new_from_bank(&bank)), + &PohConfig::default(), + Arc::new(AtomicBool::default()), + ); + let poh_recorder = Arc::new(RwLock::new(poh_recorder)); + let decision_maker = DecisionMaker::new(Pubkey::new_unique(), poh_recorder.clone()); + + let (banking_packet_sender, banking_packet_receiver) = unbounded(); + let packet_deserializer = PacketDeserializer::new(banking_packet_receiver); + + let (consume_work_senders, consume_work_receivers) = create_channels(num_threads); + let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); + + let test_frame = TestFrame { + bank, + mint_keypair, + _ledger_path: ledger_path, + _entry_receiver: entry_receiver, + _record_receiver: record_receiver, + poh_recorder, + banking_packet_sender, + consume_work_receivers, + finished_consume_work_sender, + }; + + let scheduler_controller = SchedulerController::new( + decision_maker, + packet_deserializer, + bank_forks, + PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver), + vec![], // no actual workers with metrics to report, this can be empty + None, + ); + + (test_frame, scheduler_controller) + } + + fn create_and_fund_prioritized_transfer( + bank: &Bank, + mint_keypair: &Keypair, + from_keypair: &Keypair, + to_pubkey: &Pubkey, + lamports: u64, + compute_unit_price: u64, + recent_blockhash: Hash, + ) -> Transaction { + // Fund the sending key, so that the transaction does not get filtered by the fee-payer check. + { + let transfer = system_transaction::transfer( + mint_keypair, + &from_keypair.pubkey(), + 500_000, // just some amount that will always be enough + bank.last_blockhash(), + ); + bank.process_transaction(&transfer).unwrap(); + } + + let transfer = system_instruction::transfer(&from_keypair.pubkey(), to_pubkey, lamports); + let prioritization = ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price); + let message = Message::new(&[transfer, prioritization], Some(&from_keypair.pubkey())); + Transaction::new(&vec![from_keypair], message, recent_blockhash) + } + + fn to_banking_packet_batch(txs: &[Transaction]) -> BankingPacketBatch { + let packet_batch = to_packet_batches(txs, NUM_PACKETS); + Arc::new((packet_batch, None)) + } + + // Helper function to let test receive and then schedule packets. + // The order of operations here is convenient for testing, but does not + // match the order of operations in the actual scheduler. + // The actual scheduler will process immediately after the decision, + // in order to keep the decision as recent as possible for processing. + // In the tests, the decision will not become stale, so it is more convenient + // to receive first and then schedule. 
+    fn test_receive_then_schedule(
+        scheduler_controller: &mut SchedulerController<Arc<ClusterInfo>>,
+    ) {
+        let decision = scheduler_controller
+            .decision_maker
+            .make_consume_or_forward_decision();
+        assert!(matches!(decision, BufferedPacketsDecision::Consume(_)));
+        assert!(scheduler_controller.receive_completed().is_ok());
+        assert!(scheduler_controller.receive_and_buffer_packets(&decision));
+        assert!(scheduler_controller.process_transactions(&decision).is_ok());
+    }
+
+    #[test]
+    #[should_panic(expected = "batch id 0 is not being tracked")]
+    fn test_unexpected_batch_id() {
+        let (test_frame, scheduler_controller) = create_test_frame(1);
+        let TestFrame {
+            finished_consume_work_sender,
+            ..
+        } = &test_frame;
+
+        finished_consume_work_sender
+            .send(FinishedConsumeWork {
+                work: ConsumeWork {
+                    batch_id: TransactionBatchId::new(0),
+                    ids: vec![],
+                    transactions: vec![],
+                    max_age_slots: vec![],
+                },
+                retryable_indexes: vec![],
+            })
+            .unwrap();
+
+        scheduler_controller.run().unwrap();
+    }
+
+    #[test]
+    fn test_schedule_consume_single_threaded_no_conflicts() {
+        let (test_frame, mut scheduler_controller) = create_test_frame(1);
+        let TestFrame {
+            bank,
+            mint_keypair,
+            poh_recorder,
+            banking_packet_sender,
+            consume_work_receivers,
+            ..
+        } = &test_frame;
+
+        poh_recorder
+            .write()
+            .unwrap()
+            .set_bank_for_test(bank.clone());
+
+        // Send packet batch to the scheduler - should do nothing until we become the leader.
+        let tx1 = create_and_fund_prioritized_transfer(
+            bank,
+            mint_keypair,
+            &Keypair::new(),
+            &Pubkey::new_unique(),
+            1,
+            1000,
+            bank.last_blockhash(),
+        );
+        let tx2 = create_and_fund_prioritized_transfer(
+            bank,
+            mint_keypair,
+            &Keypair::new(),
+            &Pubkey::new_unique(),
+            1,
+            2000,
+            bank.last_blockhash(),
+        );
+        let tx1_hash = tx1.message().hash();
+        let tx2_hash = tx2.message().hash();
+
+        let txs = vec![tx1, tx2];
+        banking_packet_sender
+            .send(to_banking_packet_batch(&txs))
+            .unwrap();
+
+        test_receive_then_schedule(&mut scheduler_controller);
+        let consume_work = consume_work_receivers[0].try_recv().unwrap();
+        assert_eq!(consume_work.ids.len(), 2);
+        assert_eq!(consume_work.transactions.len(), 2);
+        let message_hashes = consume_work
+            .transactions
+            .iter()
+            .map(|tx| tx.message_hash())
+            .collect_vec();
+        assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]);
+    }
+
+    #[test]
+    fn test_schedule_consume_single_threaded_conflict() {
+        let (test_frame, mut scheduler_controller) = create_test_frame(1);
+        let TestFrame {
+            bank,
+            mint_keypair,
+            poh_recorder,
+            banking_packet_sender,
+            consume_work_receivers,
+            ..
+ } = &test_frame; + + poh_recorder + .write() + .unwrap() + .set_bank_for_test(bank.clone()); + + let pk = Pubkey::new_unique(); + let tx1 = create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &pk, + 1, + 1000, + bank.last_blockhash(), + ); + let tx2 = create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &pk, + 1, + 2000, + bank.last_blockhash(), + ); + let tx1_hash = tx1.message().hash(); + let tx2_hash = tx2.message().hash(); + + let txs = vec![tx1, tx2]; + banking_packet_sender + .send(to_banking_packet_batch(&txs)) + .unwrap(); + + // We expect 2 batches to be scheduled + test_receive_then_schedule(&mut scheduler_controller); + let consume_works = (0..2) + .map(|_| consume_work_receivers[0].try_recv().unwrap()) + .collect_vec(); + + let num_txs_per_batch = consume_works.iter().map(|cw| cw.ids.len()).collect_vec(); + let message_hashes = consume_works + .iter() + .flat_map(|cw| cw.transactions.iter().map(|tx| tx.message_hash())) + .collect_vec(); + assert_eq!(num_txs_per_batch, vec![1; 2]); + assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]); + } + + #[test] + fn test_schedule_consume_single_threaded_multi_batch() { + let (test_frame, mut scheduler_controller) = create_test_frame(1); + let TestFrame { + bank, + mint_keypair, + poh_recorder, + banking_packet_sender, + consume_work_receivers, + .. + } = &test_frame; + + poh_recorder + .write() + .unwrap() + .set_bank_for_test(bank.clone()); + + // Send multiple batches - all get scheduled + let txs1 = (0..2 * TARGET_NUM_TRANSACTIONS_PER_BATCH) + .map(|i| { + create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &Pubkey::new_unique(), + i as u64, + 1, + bank.last_blockhash(), + ) + }) + .collect_vec(); + let txs2 = (0..2 * TARGET_NUM_TRANSACTIONS_PER_BATCH) + .map(|i| { + create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &Pubkey::new_unique(), + i as u64, + 2, + bank.last_blockhash(), + ) + }) + .collect_vec(); + + banking_packet_sender + .send(to_banking_packet_batch(&txs1)) + .unwrap(); + banking_packet_sender + .send(to_banking_packet_batch(&txs2)) + .unwrap(); + + // We expect 4 batches to be scheduled + test_receive_then_schedule(&mut scheduler_controller); + let consume_works = (0..4) + .map(|_| consume_work_receivers[0].try_recv().unwrap()) + .collect_vec(); + + assert_eq!( + consume_works.iter().map(|cw| cw.ids.len()).collect_vec(), + vec![TARGET_NUM_TRANSACTIONS_PER_BATCH; 4] + ); + } + + #[test] + fn test_schedule_consume_simple_thread_selection() { + let (test_frame, mut scheduler_controller) = create_test_frame(2); + let TestFrame { + bank, + mint_keypair, + poh_recorder, + banking_packet_sender, + consume_work_receivers, + .. + } = &test_frame; + + poh_recorder + .write() + .unwrap() + .set_bank_for_test(bank.clone()); + + // Send 4 transactions w/o conflicts. 
2 should be scheduled on each thread + let txs = (0..4) + .map(|i| { + create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &Pubkey::new_unique(), + 1, + i * 10, + bank.last_blockhash(), + ) + }) + .collect_vec(); + banking_packet_sender + .send(to_banking_packet_batch(&txs)) + .unwrap(); + + // Priority Expectation: + // Thread 0: [3, 1] + // Thread 1: [2, 0] + let t0_expected = [3, 1] + .into_iter() + .map(|i| txs[i].message().hash()) + .collect_vec(); + let t1_expected = [2, 0] + .into_iter() + .map(|i| txs[i].message().hash()) + .collect_vec(); + + test_receive_then_schedule(&mut scheduler_controller); + let t0_actual = consume_work_receivers[0] + .try_recv() + .unwrap() + .transactions + .iter() + .map(|tx| *tx.message_hash()) + .collect_vec(); + let t1_actual = consume_work_receivers[1] + .try_recv() + .unwrap() + .transactions + .iter() + .map(|tx| *tx.message_hash()) + .collect_vec(); + + assert_eq!(t0_actual, t0_expected); + assert_eq!(t1_actual, t1_expected); + } + + #[test] + fn test_schedule_consume_retryable() { + let (test_frame, mut scheduler_controller) = create_test_frame(1); + let TestFrame { + bank, + mint_keypair, + poh_recorder, + banking_packet_sender, + consume_work_receivers, + finished_consume_work_sender, + .. + } = &test_frame; + + poh_recorder + .write() + .unwrap() + .set_bank_for_test(bank.clone()); + + // Send packet batch to the scheduler - should do nothing until we become the leader. + let tx1 = create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &Pubkey::new_unique(), + 1, + 1000, + bank.last_blockhash(), + ); + let tx2 = create_and_fund_prioritized_transfer( + bank, + mint_keypair, + &Keypair::new(), + &Pubkey::new_unique(), + 1, + 2000, + bank.last_blockhash(), + ); + let tx1_hash = tx1.message().hash(); + let tx2_hash = tx2.message().hash(); + + let txs = vec![tx1, tx2]; + banking_packet_sender + .send(to_banking_packet_batch(&txs)) + .unwrap(); + + test_receive_then_schedule(&mut scheduler_controller); + let consume_work = consume_work_receivers[0].try_recv().unwrap(); + assert_eq!(consume_work.ids.len(), 2); + assert_eq!(consume_work.transactions.len(), 2); + let message_hashes = consume_work + .transactions + .iter() + .map(|tx| tx.message_hash()) + .collect_vec(); + assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]); + + // Complete the batch - marking the second transaction as retryable + finished_consume_work_sender + .send(FinishedConsumeWork { + work: consume_work, + retryable_indexes: vec![1], + }) + .unwrap(); + + // Transaction should be rescheduled + test_receive_then_schedule(&mut scheduler_controller); + let consume_work = consume_work_receivers[0].try_recv().unwrap(); + assert_eq!(consume_work.ids.len(), 1); + assert_eq!(consume_work.transactions.len(), 1); + let message_hashes = consume_work + .transactions + .iter() + .map(|tx| tx.message_hash()) + .collect_vec(); + assert_eq!(message_hashes, vec![&tx1_hash]); + } +} diff --git a/prio-graph-scheduler/src/scheduler_error.rs b/prio-graph-scheduler/src/scheduler_error.rs new file mode 100644 index 00000000000000..9b8d4015448e57 --- /dev/null +++ b/prio-graph-scheduler/src/scheduler_error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum SchedulerError { + #[error("Sending channel disconnected: {0}")] + DisconnectedSendChannel(&'static str), + #[error("Recv channel disconnected: {0}")] + DisconnectedRecvChannel(&'static str), +} diff --git a/prio-graph-scheduler/src/scheduler_metrics.rs 
b/prio-graph-scheduler/src/scheduler_metrics.rs
new file mode 100644
index 00000000000000..922c105acba8c9
--- /dev/null
+++ b/prio-graph-scheduler/src/scheduler_metrics.rs
@@ -0,0 +1,408 @@
+use {
+    itertools::MinMaxResult,
+    solana_poh::poh_recorder::BankStart,
+    solana_sdk::{clock::Slot, timing::AtomicInterval},
+    std::time::Instant,
+};
+
+#[derive(Default)]
+pub struct SchedulerCountMetrics {
+    interval: IntervalSchedulerCountMetrics,
+    slot: SlotSchedulerCountMetrics,
+}
+
+impl SchedulerCountMetrics {
+    pub fn update(&mut self, update: impl Fn(&mut SchedulerCountMetricsInner)) {
+        update(&mut self.interval.metrics);
+        update(&mut self.slot.metrics);
+    }
+
+    pub fn maybe_report_and_reset_slot(&mut self, slot: Option<Slot>) {
+        self.slot.maybe_report_and_reset(slot);
+    }
+
+    pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) {
+        self.interval.maybe_report_and_reset(should_report);
+    }
+
+    pub fn interval_has_data(&self) -> bool {
+        self.interval.metrics.has_data()
+    }
+}
+
+#[derive(Default)]
+struct IntervalSchedulerCountMetrics {
+    interval: AtomicInterval,
+    metrics: SchedulerCountMetricsInner,
+}
+
+#[derive(Default)]
+struct SlotSchedulerCountMetrics {
+    slot: Option<Slot>,
+    metrics: SchedulerCountMetricsInner,
+}
+
+#[derive(Default)]
+pub struct SchedulerCountMetricsInner {
+    /// Number of packets received.
+    pub num_received: usize,
+    /// Number of packets buffered.
+    pub num_buffered: usize,
+
+    /// Number of transactions scheduled.
+    pub num_scheduled: usize,
+    /// Number of transactions that were unschedulable.
+    pub num_unschedulable: usize,
+    /// Number of transactions that were filtered out during scheduling.
+    pub num_schedule_filtered_out: usize,
+    /// Number of completed transactions received from workers.
+    pub num_finished: usize,
+    /// Number of transactions that were retryable.
+    pub num_retryable: usize,
+    /// Number of transactions that were scheduled to be forwarded.
+    pub num_forwarded: usize,
+
+    /// Number of transactions that were immediately dropped on receive.
+    pub num_dropped_on_receive: usize,
+    /// Number of transactions that were dropped due to sanitization failure.
+    pub num_dropped_on_sanitization: usize,
+    /// Number of transactions that were dropped due to failed lock validation.
+    pub num_dropped_on_validate_locks: usize,
+    /// Number of transactions that were dropped due to failed transaction
+    /// checks during receive.
+    pub num_dropped_on_receive_transaction_checks: usize,
+    /// Number of transactions that were dropped due to clearing.
+    pub num_dropped_on_clear: usize,
+    /// Number of transactions that were dropped due to age and status checks.
+    pub num_dropped_on_age_and_status: usize,
+    /// Number of transactions that were dropped due to exceeded capacity.
+    pub num_dropped_on_capacity: usize,
+    /// Min prioritization fees in the transaction container
+    pub min_prioritization_fees: u64,
+    /// Max prioritization fees in the transaction container
+    pub max_prioritization_fees: u64,
+}
+
+impl IntervalSchedulerCountMetrics {
+    fn maybe_report_and_reset(&mut self, should_report: bool) {
+        const REPORT_INTERVAL_MS: u64 = 1000;
+        if self.interval.should_update(REPORT_INTERVAL_MS) {
+            if should_report {
+                self.metrics.report("banking_stage_scheduler_counts", None);
+            }
+            self.metrics.reset();
+        }
+    }
+}
+
+impl SlotSchedulerCountMetrics {
+    fn maybe_report_and_reset(&mut self, slot: Option<Slot>) {
+        if self.slot != slot {
+            // Only report if there was an assigned slot.
+ if self.slot.is_some() { + self.metrics + .report("banking_stage_scheduler_slot_counts", self.slot); + } + self.metrics.reset(); + self.slot = slot; + } + } +} + +impl SchedulerCountMetricsInner { + fn report(&self, name: &'static str, slot: Option) { + let mut datapoint = create_datapoint!( + @point name, + ("num_received", self.num_received, i64), + ("num_buffered", self.num_buffered, i64), + ("num_scheduled", self.num_scheduled, i64), + ("num_unschedulable", self.num_unschedulable, i64), + ( + "num_schedule_filtered_out", + self.num_schedule_filtered_out, + i64 + ), + ("num_finished", self.num_finished, i64), + ("num_retryable", self.num_retryable, i64), + ("num_forwarded", self.num_forwarded, i64), + ("num_dropped_on_receive", self.num_dropped_on_receive, i64), + ( + "num_dropped_on_sanitization", + self.num_dropped_on_sanitization, + i64 + ), + ( + "num_dropped_on_validate_locks", + self.num_dropped_on_validate_locks, + i64 + ), + ( + "num_dropped_on_receive_transaction_checks", + self.num_dropped_on_receive_transaction_checks, + i64 + ), + ("num_dropped_on_clear", self.num_dropped_on_clear, i64), + ( + "num_dropped_on_age_and_status", + self.num_dropped_on_age_and_status, + i64 + ), + ("num_dropped_on_capacity", self.num_dropped_on_capacity, i64), + ("min_priority", self.get_min_priority(), i64), + ("max_priority", self.get_max_priority(), i64) + ); + if let Some(slot) = slot { + datapoint.add_field_i64("slot", slot as i64); + } + solana_metrics::submit(datapoint, log::Level::Info); + } + + pub fn has_data(&self) -> bool { + self.num_received != 0 + || self.num_buffered != 0 + || self.num_scheduled != 0 + || self.num_unschedulable != 0 + || self.num_schedule_filtered_out != 0 + || self.num_finished != 0 + || self.num_retryable != 0 + || self.num_forwarded != 0 + || self.num_dropped_on_receive != 0 + || self.num_dropped_on_sanitization != 0 + || self.num_dropped_on_validate_locks != 0 + || self.num_dropped_on_receive_transaction_checks != 0 + || self.num_dropped_on_clear != 0 + || self.num_dropped_on_age_and_status != 0 + || self.num_dropped_on_capacity != 0 + } + + fn reset(&mut self) { + self.num_received = 0; + self.num_buffered = 0; + self.num_scheduled = 0; + self.num_unschedulable = 0; + self.num_schedule_filtered_out = 0; + self.num_finished = 0; + self.num_retryable = 0; + self.num_forwarded = 0; + self.num_dropped_on_receive = 0; + self.num_dropped_on_sanitization = 0; + self.num_dropped_on_validate_locks = 0; + self.num_dropped_on_receive_transaction_checks = 0; + self.num_dropped_on_clear = 0; + self.num_dropped_on_age_and_status = 0; + self.num_dropped_on_capacity = 0; + self.min_prioritization_fees = u64::MAX; + self.max_prioritization_fees = 0; + } + + pub fn update_priority_stats(&mut self, min_max_fees: MinMaxResult) { + // update min/max priority + match min_max_fees { + itertools::MinMaxResult::NoElements => { + // do nothing + } + itertools::MinMaxResult::OneElement(e) => { + self.min_prioritization_fees = e; + self.max_prioritization_fees = e; + } + itertools::MinMaxResult::MinMax(min, max) => { + self.min_prioritization_fees = min; + self.max_prioritization_fees = max; + } + } + } + + pub fn get_min_priority(&self) -> u64 { + // to avoid getting u64::max recorded by metrics / in case of edge cases + if self.min_prioritization_fees != u64::MAX { + self.min_prioritization_fees + } else { + 0 + } + } + + pub fn get_max_priority(&self) -> u64 { + self.max_prioritization_fees + } +} + +#[derive(Default)] +pub struct SchedulerTimingMetrics { + interval: 
IntervalSchedulerTimingMetrics, + slot: SlotSchedulerTimingMetrics, +} + +impl SchedulerTimingMetrics { + pub fn update(&mut self, update: impl Fn(&mut SchedulerTimingMetricsInner)) { + update(&mut self.interval.metrics); + update(&mut self.slot.metrics); + } + + pub fn maybe_report_and_reset_slot(&mut self, slot: Option) { + self.slot.maybe_report_and_reset(slot); + } + + pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) { + self.interval.maybe_report_and_reset(should_report); + } +} + +#[derive(Default)] +struct IntervalSchedulerTimingMetrics { + interval: AtomicInterval, + metrics: SchedulerTimingMetricsInner, +} + +#[derive(Default)] +struct SlotSchedulerTimingMetrics { + slot: Option, + metrics: SchedulerTimingMetricsInner, +} + +#[derive(Default)] +pub struct SchedulerTimingMetricsInner { + /// Time spent making processing decisions. + pub decision_time_us: u64, + /// Time spent receiving packets. + pub receive_time_us: u64, + /// Time spent buffering packets. + pub buffer_time_us: u64, + /// Time spent filtering transactions during scheduling. + pub schedule_filter_time_us: u64, + /// Time spent scheduling transactions. + pub schedule_time_us: u64, + /// Time spent clearing transactions from the container. + pub clear_time_us: u64, + /// Time spent cleaning expired or processed transactions from the container. + pub clean_time_us: u64, + /// Time spent forwarding transactions. + pub forward_time_us: u64, + /// Time spent receiving completed transactions. + pub receive_completed_time_us: u64, +} + +impl IntervalSchedulerTimingMetrics { + fn maybe_report_and_reset(&mut self, should_report: bool) { + const REPORT_INTERVAL_MS: u64 = 1000; + if self.interval.should_update(REPORT_INTERVAL_MS) { + if should_report { + self.metrics.report("banking_stage_scheduler_timing", None); + } + self.metrics.reset(); + } + } +} + +impl SlotSchedulerTimingMetrics { + fn maybe_report_and_reset(&mut self, slot: Option) { + if self.slot != slot { + // Only report if there was an assigned slot. 
+ if self.slot.is_some() { + self.metrics + .report("banking_stage_scheduler_slot_timing", self.slot); + } + self.metrics.reset(); + self.slot = slot; + } + } +} + +impl SchedulerTimingMetricsInner { + fn report(&self, name: &'static str, slot: Option) { + let mut datapoint = create_datapoint!( + @point name, + ("decision_time_us", self.decision_time_us, i64), + ("receive_time_us", self.receive_time_us, i64), + ("buffer_time_us", self.buffer_time_us, i64), + ("schedule_filter_time_us", self.schedule_filter_time_us, i64), + ("schedule_time_us", self.schedule_time_us, i64), + ("clear_time_us", self.clear_time_us, i64), + ("clean_time_us", self.clean_time_us, i64), + ("forward_time_us", self.forward_time_us, i64), + ( + "receive_completed_time_us", + self.receive_completed_time_us, + i64 + ) + ); + if let Some(slot) = slot { + datapoint.add_field_i64("slot", slot as i64); + } + solana_metrics::submit(datapoint, log::Level::Info); + } + + fn reset(&mut self) { + self.decision_time_us = 0; + self.receive_time_us = 0; + self.buffer_time_us = 0; + self.schedule_filter_time_us = 0; + self.schedule_time_us = 0; + self.clear_time_us = 0; + self.clean_time_us = 0; + self.forward_time_us = 0; + self.receive_completed_time_us = 0; + } +} + +#[derive(Default)] +pub struct SchedulerLeaderDetectionMetrics { + inner: Option, +} + +struct SchedulerLeaderDetectionMetricsInner { + slot: Slot, + bank_creation_time: Instant, + bank_detected_time: Instant, +} + +impl SchedulerLeaderDetectionMetrics { + pub fn update_and_maybe_report(&mut self, bank_start: Option<&BankStart>) { + match (&self.inner, bank_start) { + (None, Some(bank_start)) => self.initialize_inner(bank_start), + (Some(_inner), None) => self.report_and_reset(), + (Some(inner), Some(bank_start)) if inner.slot != bank_start.working_bank.slot() => { + self.report_and_reset(); + self.initialize_inner(bank_start); + } + _ => {} + } + } + + fn initialize_inner(&mut self, bank_start: &BankStart) { + let bank_detected_time = Instant::now(); + self.inner = Some(SchedulerLeaderDetectionMetricsInner { + slot: bank_start.working_bank.slot(), + bank_creation_time: *bank_start.bank_creation_time, + bank_detected_time, + }); + } + + fn report_and_reset(&mut self) { + let SchedulerLeaderDetectionMetricsInner { + slot, + bank_creation_time, + bank_detected_time, + } = self.inner.take().expect("inner must be present"); + + let bank_detected_delay_us = bank_detected_time + .duration_since(bank_creation_time) + .as_micros() + .try_into() + .unwrap_or(i64::MAX); + let bank_detected_to_slot_end_detected_us = bank_detected_time + .elapsed() + .as_micros() + .try_into() + .unwrap_or(i64::MAX); + datapoint_info!( + "banking_stage_scheduler_leader_detection", + ("slot", slot, i64), + ("bank_detected_delay_us", bank_detected_delay_us, i64), + ( + "bank_detected_to_slot_end_detected_us", + bank_detected_to_slot_end_detected_us, + i64 + ), + ); + } +} diff --git a/prio-graph-scheduler/src/transaction_priority_id.rs b/prio-graph-scheduler/src/transaction_priority_id.rs new file mode 100644 index 00000000000000..f39927324e355a --- /dev/null +++ b/prio-graph-scheduler/src/transaction_priority_id.rs @@ -0,0 +1,69 @@ +use { + crate::scheduler_messages::TransactionId, + prio_graph::TopLevelId, + std::hash::{Hash, Hasher}, +}; + +/// A unique identifier tied with priority ordering for a transaction/packet: +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct TransactionPriorityId { + pub priority: u64, + pub id: TransactionId, +} + +impl TransactionPriorityId { 
+ pub fn new(priority: u64, id: TransactionId) -> Self { + Self { priority, id } + } +} + +impl Hash for TransactionPriorityId { + fn hash<H: Hasher>(&self, state: &mut H) { + self.id.hash(state) + } +} + +impl TopLevelId<Self> for TransactionPriorityId { + fn id(&self) -> Self { + *self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transaction_priority_id_ordering() { + // Higher priority first + { + let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); + let id2 = TransactionPriorityId::new(2, TransactionId::new(1)); + assert!(id1 < id2); + assert!(id1 <= id2); + assert!(id2 > id1); + assert!(id2 >= id1); + } + + // Equal priority, then compare by id + { + let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); + let id2 = TransactionPriorityId::new(1, TransactionId::new(2)); + assert!(id1 < id2); + assert!(id1 <= id2); + assert!(id2 > id1); + assert!(id2 >= id1); + } + + // Equal priority and id + { + let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); + let id2 = TransactionPriorityId::new(1, TransactionId::new(1)); + assert_eq!(id1, id2); + assert!(id1 >= id2); + assert!(id1 <= id2); + assert!(id2 >= id1); + assert!(id2 <= id1); + } + } +} diff --git a/prio-graph-scheduler/src/transaction_state_container.rs b/prio-graph-scheduler/src/transaction_state_container.rs new file mode 100644 index 00000000000000..8e2a51f5f7bb37 --- /dev/null +++ b/prio-graph-scheduler/src/transaction_state_container.rs @@ -0,0 +1,259 @@ +use { + super::{ + transaction_priority_id::TransactionPriorityId, + transaction_state::{SanitizedTransactionTTL, TransactionState}, + }, + crate::scheduler_messages::TransactionId, + itertools::MinMaxResult, + min_max_heap::MinMaxHeap, + solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, + std::{collections::HashMap, sync::Arc}, +}; + +/// This structure will hold `TransactionState` for the entirety of a +/// transaction's lifetime in the scheduler and BankingStage as a whole. +/// +/// Transaction Lifetime: +/// 1. Received from `SigVerify` by `BankingStage` +/// 2. Inserted into `TransactionStateContainer` by `BankingStage` +/// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state +/// 4. Processed by `ConsumeWorker` +/// a. If consumed, remove `Pending` state from the `TransactionStateContainer` +/// b. If retryable, transition back to `Unprocessed` state. +/// Re-insert to the queue, and return to step 3. +/// +/// The structure is composed of two main components: +/// 1. A priority queue of wrapped `TransactionId`s, which are used to +/// order transactions by priority for selection by the scheduler. +/// 2. A map of `TransactionId` to `TransactionState`, which is used to +/// track the state of each transaction. +/// +/// When `Pending`, the associated `TransactionId` is not in the queue, but +/// is still in the map. +/// The entry in the map should exist before insertion into the queue, and be +/// removed only after the id is removed from the queue. +/// +/// The container maintains a fixed capacity. If the queue is full when pushing +/// a new transaction, the lowest priority transaction will be dropped.
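+/// A sketch of the intended call flow (illustrative only, not part of this
+/// patch; it assumes a populated container and a `tx_ttl` rebuilt by the
+/// worker for a retryable transaction):
+///
+/// ```ignore
+/// // Pop the highest-priority id; its state stays in the map as `Pending`.
+/// if let Some(priority_id) = container.pop() {
+///     // ...hand the transaction to a `ConsumeWorker`; if it comes back
+///     // retryable, transition it back to `Unprocessed` and re-queue it:
+///     container.retry_transaction(priority_id.id, tx_ttl);
+/// }
+/// ```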
+pub struct TransactionStateContainer { + priority_queue: MinMaxHeap<TransactionPriorityId>, + id_to_transaction_state: HashMap<TransactionId, TransactionState>, +} + +impl TransactionStateContainer { + pub fn with_capacity(capacity: usize) -> Self { + Self { + priority_queue: MinMaxHeap::with_capacity(capacity), + id_to_transaction_state: HashMap::with_capacity(capacity), + } + } + + /// Returns true if the queue is empty. + pub fn is_empty(&self) -> bool { + self.priority_queue.is_empty() + } + + /// Returns the remaining capacity of the queue + pub fn remaining_queue_capacity(&self) -> usize { + self.priority_queue.capacity() - self.priority_queue.len() + } + + /// Get the top transaction id in the priority queue. + pub fn pop(&mut self) -> Option<TransactionPriorityId> { + self.priority_queue.pop_max() + } + + /// Get mutable transaction state by id. + pub fn get_mut_transaction_state( + &mut self, + id: &TransactionId, + ) -> Option<&mut TransactionState> { + self.id_to_transaction_state.get_mut(id) + } + + /// Get reference to `SanitizedTransactionTTL` by id. + /// Returns `None` if the transaction does not exist. + pub fn get_transaction_ttl( + &self, + id: &TransactionId, + ) -> Option<&SanitizedTransactionTTL> { + self.id_to_transaction_state + .get(id) + .map(|state| state.transaction_ttl()) + } + + /// Insert a new transaction into the container's queues and maps. + /// Returns `true` if a packet was dropped due to capacity limits. + pub fn insert_new_transaction( + &mut self, + transaction_id: TransactionId, + transaction_ttl: SanitizedTransactionTTL, + packet: Arc<ImmutableDeserializedPacket>, + priority: u64, + cost: u64, + ) -> bool { + let priority_id = TransactionPriorityId::new(priority, transaction_id); + self.id_to_transaction_state.insert( + transaction_id, + TransactionState::new(transaction_ttl, packet, priority, cost), + ); + self.push_id_into_queue(priority_id) + } + + /// Retries a transaction - inserts transaction back into map (but not packet). + /// This transitions the transaction to `Unprocessed` state. + pub fn retry_transaction( + &mut self, + transaction_id: TransactionId, + transaction_ttl: SanitizedTransactionTTL, + ) { + let transaction_state = self + .get_mut_transaction_state(&transaction_id) + .expect("transaction must exist"); + let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id); + transaction_state.transition_to_unprocessed(transaction_ttl); + self.push_id_into_queue(priority_id); + } + + /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority + /// transaction will be dropped (removed from the queue and map). + /// Returns `true` if a packet was dropped due to capacity limits. + pub fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool { + if self.remaining_queue_capacity() == 0 { + let popped_id = self.priority_queue.push_pop_min(priority_id); + self.remove_by_id(&popped_id.id); + true + } else { + self.priority_queue.push(priority_id); + false + } + } + + /// Remove transaction by id.
+ pub fn remove_by_id(&mut self, id: &TransactionId) { + self.id_to_transaction_state + .remove(id) + .expect("transaction must exist"); + } + + pub fn get_min_max_priority(&self) -> MinMaxResult<u64> { + match self.priority_queue.peek_min() { + Some(min) => match self.priority_queue.peek_max() { + Some(max) => MinMaxResult::MinMax(min.priority, max.priority), + None => MinMaxResult::OneElement(min.priority), + }, + None => MinMaxResult::NoElements, + } + } +} + +#[cfg(test)] +mod tests { + use { + super::*, + solana_sdk::{ + compute_budget::ComputeBudgetInstruction, + hash::Hash, + message::Message, + packet::Packet, + signature::Keypair, + signer::Signer, + slot_history::Slot, + system_instruction, + transaction::{SanitizedTransaction, Transaction}, + }, + }; + + /// Returns (transaction_ttl, packet, priority, cost) + fn test_transaction( + priority: u64, + ) -> ( + SanitizedTransactionTTL, + Arc<ImmutableDeserializedPacket>, + u64, + u64, + ) { + let from_keypair = Keypair::new(); + let ixs = vec![ + system_instruction::transfer( + &from_keypair.pubkey(), + &solana_sdk::pubkey::new_rand(), + 1, + ), + ComputeBudgetInstruction::set_compute_unit_price(priority), + ]; + let message = Message::new(&ixs, Some(&from_keypair.pubkey())); + let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new( + &[&from_keypair], + message, + Hash::default(), + )); + let packet = Arc::new( + ImmutableDeserializedPacket::new( + Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), + ) + .unwrap(), + ); + let transaction_ttl = SanitizedTransactionTTL { + transaction: tx, + max_age_slot: Slot::MAX, + }; + const TEST_TRANSACTION_COST: u64 = 5000; + (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) + } + + fn push_to_container(container: &mut TransactionStateContainer, num: usize) { + for id in 0..num as u64 { + let priority = id; + let (transaction_ttl, packet, priority, cost) = test_transaction(priority); + container.insert_new_transaction( + TransactionId::new(id), + transaction_ttl, + packet, + priority, + cost, + ); + } + } + + #[test] + fn test_is_empty() { + let mut container = TransactionStateContainer::with_capacity(1); + assert!(container.is_empty()); + + push_to_container(&mut container, 1); + assert!(!container.is_empty()); + } + + #[test] + fn test_priority_queue_capacity() { + let mut container = TransactionStateContainer::with_capacity(1); + push_to_container(&mut container, 5); + + assert_eq!(container.priority_queue.len(), 1); + assert_eq!(container.id_to_transaction_state.len(), 1); + assert_eq!( + container + .id_to_transaction_state + .iter() + .map(|ts| ts.1.priority()) + .next() + .unwrap(), + 4 + ); + } + + #[test] + fn test_get_mut_transaction_state() { + let mut container = TransactionStateContainer::with_capacity(5); + push_to_container(&mut container, 5); + + let existing_id = TransactionId::new(3); + let non_existing_id = TransactionId::new(7); + assert!(container.get_mut_transaction_state(&existing_id).is_some()); + assert!(container.get_mut_transaction_state(&existing_id).is_some()); + assert!(container .get_mut_transaction_state(&non_existing_id) .is_none()); + } +} From 8e8f1084d93fcc9dd45d88f7706dada840e7d9a3 Mon Sep 17 00:00:00 2001 From: lewis Date: Mon, 14 Oct 2024 14:35:19 +0800 Subject: [PATCH 4/9] feat: add a DeserializableTxPacket trait --- .../src/deserializable_packet.rs | 39 +++++++++++++++++++ prio-graph-scheduler/src/id_generator.rs | 2 - prio-graph-scheduler/src/lib.rs | 1 + 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644
prio-graph-scheduler/src/deserializable_packet.rs diff --git a/prio-graph-scheduler/src/deserializable_packet.rs b/prio-graph-scheduler/src/deserializable_packet.rs new file mode 100644 index 00000000000000..f67984e9a7625e --- /dev/null +++ b/prio-graph-scheduler/src/deserializable_packet.rs @@ -0,0 +1,39 @@ +use ahash::HashSet; +use solana_sdk::hash::Hash; +use solana_sdk::message::AddressLoader; +use solana_sdk::packet::Packet; +use solana_sdk::pubkey::Pubkey; +use solana_sdk::transaction::{SanitizedTransaction, SanitizedVersionedTransaction}; +use std::error::Error; + +/// A `DeserializableTxPacket` can be deserialized from a `Packet`. +/// +/// It is deserialized into a `SanitizedTransaction` so that it can be +/// scheduled in the transaction stream and scheduler. +pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized { + type DeserializeError: Error; + + fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError>; + + /// Deserializes the packet into a transaction and + /// computes the blake3 hash of the transaction message. + fn build_sanitized_transaction( + &self, + votes_only: bool, + address_loader: impl AddressLoader, + reserved_account_keys: &HashSet<Pubkey>, + ) -> Option<SanitizedTransaction>; + + fn original_packet(&self) -> &Packet; + + /// Deserialized into a versioned transaction, and then to a `SanitizedTransaction`. + fn transaction(&self) -> &SanitizedVersionedTransaction; + + fn message_hash(&self) -> &Hash; + + fn is_simple_vote(&self) -> bool; + + fn compute_unit_price(&self) -> u64; + + fn compute_unit_limit(&self) -> u64; +} diff --git a/prio-graph-scheduler/src/id_generator.rs b/prio-graph-scheduler/src/id_generator.rs index 0d6e5ee2098174..3090e4e044d473 100644 --- a/prio-graph-scheduler/src/id_generator.rs +++ b/prio-graph-scheduler/src/id_generator.rs @@ -1,5 +1,3 @@ -use crate::scheduler_messages::TransactionId; - /// Simple reverse-sequential ID generator for `TransactionId`s. /// These IDs uniquely identify transactions during the scheduling process.
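/// Starting at `u64::MAX` and counting down means earlier transactions
/// receive larger ids; because `TransactionPriorityId` orders by
/// `(priority, id)`, equal-priority transactions leave the max-heap in
/// arrival order, giving FIFO tie-breaking among equal priorities.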
pub struct IdGenerator { diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs index 0fa6eac41017b1..215c451f5e33a2 100644 --- a/prio-graph-scheduler/src/lib.rs +++ b/prio-graph-scheduler/src/lib.rs @@ -10,6 +10,7 @@ pub mod scheduler_metrics; // pub mod scheduler_controller; pub mod transaction_state_container; pub mod prio_graph_scheduler; +pub mod deserializable_packet; #[macro_use] extern crate solana_metrics; From 32ce8985e7d16f7de24cba498bcc3f6a18ffb269 Mon Sep 17 00:00:00 2001 From: lewis Date: Mon, 14 Oct 2024 15:03:45 +0800 Subject: [PATCH 5/9] feat: use DeserializableTxPacket to standardize crate --- .../immutable_deserialized_packet.rs | 12 +- .../src/deserializable_packet.rs | 46 +- .../src/prio_graph_scheduler.rs | 41 +- .../src/scheduler_controller.rs | 1161 ----------------- .../src/scheduler_messages.rs | 10 +- prio-graph-scheduler/src/transaction_state.rs | 607 ++++----- .../src/transaction_state_container.rs | 28 +- 7 files changed, 396 insertions(+), 1509 deletions(-) delete mode 100644 prio-graph-scheduler/src/scheduler_controller.rs diff --git a/core/src/banking_stage/immutable_deserialized_packet.rs b/core/src/banking_stage/immutable_deserialized_packet.rs index b03f3d5d64d4e8..7bb259494e4c31 100644 --- a/core/src/banking_stage/immutable_deserialized_packet.rs +++ b/core/src/banking_stage/immutable_deserialized_packet.rs @@ -41,12 +41,12 @@ pub enum DeserializedPacketError { #[derive(Debug, Eq)] pub struct ImmutableDeserializedPacket { - original_packet: Packet, - transaction: SanitizedVersionedTransaction, - message_hash: Hash, - is_simple_vote: bool, - compute_unit_price: u64, - compute_unit_limit: u32, + pub original_packet: Packet, + pub transaction: SanitizedVersionedTransaction, + pub message_hash: Hash, + pub is_simple_vote: bool, + pub compute_unit_price: u64, + pub compute_unit_limit: u32, } impl ImmutableDeserializedPacket { diff --git a/prio-graph-scheduler/src/deserializable_packet.rs b/prio-graph-scheduler/src/deserializable_packet.rs index f67984e9a7625e..79c2a834666d41 100644 --- a/prio-graph-scheduler/src/deserializable_packet.rs +++ b/prio-graph-scheduler/src/deserializable_packet.rs @@ -1,4 +1,5 @@ -use ahash::HashSet; +use std::collections::HashSet; +use solana_core::banking_stage::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}; use solana_sdk::hash::Hash; use solana_sdk::message::AddressLoader; use solana_sdk::packet::Packet; @@ -37,3 +38,46 @@ pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized { fn compute_unit_limit(&self) -> u64; } + + +/// TODO: migrate to solana_core +impl DeserializableTxPacket for ImmutableDeserializedPacket { + type DeserializeError = DeserializedPacketError; + + fn from_packet(packet: Packet) -> Result { + ImmutableDeserializedPacket::new(packet) + } + + fn build_sanitized_transaction( + &self, + votes_only: bool, + address_loader: impl AddressLoader, + reserved_account_keys: &HashSet, + ) -> Option { + self.build_sanitized_transaction(votes_only, address_loader, reserved_account_keys) + } + + fn original_packet(&self) -> &Packet { + &self.original_packet + } + + fn transaction(&self) -> &SanitizedVersionedTransaction { + &self.transaction + } + + fn message_hash(&self) -> &Hash { + &self.message_hash + } + + fn is_simple_vote(&self) -> bool { + self.is_simple_vote + } + + fn compute_unit_price(&self) -> u64 { + self.compute_unit_price + } + + fn compute_unit_limit(&self) -> u64 { + u64::from(self.compute_unit_limit) + } +} \ No newline at 
end of file diff --git a/prio-graph-scheduler/src/prio_graph_scheduler.rs b/prio-graph-scheduler/src/prio_graph_scheduler.rs index a2f339d3baeaac..4d66a7901d4976 100644 --- a/prio-graph-scheduler/src/prio_graph_scheduler.rs +++ b/prio-graph-scheduler/src/prio_graph_scheduler.rs @@ -1,14 +1,12 @@ use { - crate::scheduler_messages::{ - ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId, - }, - crate::transaction_priority_id::TransactionPriorityId, - crate::transaction_state::TransactionState, crate::{ + deserializable_packet::DeserializableTxPacket, in_flight_tracker::InFlightTracker, scheduler_error::SchedulerError, + scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId}, thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet}, - transaction_state::SanitizedTransactionTTL, + transaction_priority_id::TransactionPriorityId, + transaction_state::{SanitizedTransactionTTL, TransactionState}, transaction_state_container::TransactionStateContainer, }, crossbeam_channel::{Receiver, Sender, TryRecvError}, @@ -25,15 +23,16 @@ use { }, }; -pub struct PrioGraphScheduler { +pub struct PrioGraphScheduler<P: DeserializableTxPacket> { in_flight_tracker: InFlightTracker, account_locks: ThreadAwareAccountLocks, consume_work_senders: Vec<Sender<ConsumeWork>>, finished_consume_work_receiver: Receiver<FinishedConsumeWork>, look_ahead_window_size: usize, + phantom: std::marker::PhantomData
<P>
, } -impl PrioGraphScheduler { +impl<P: DeserializableTxPacket> PrioGraphScheduler<P>

{ pub fn new( consume_work_senders: Vec<Sender<ConsumeWork>>, finished_consume_work_receiver: Receiver<FinishedConsumeWork>, @@ -45,6 +44,7 @@ impl PrioGraphScheduler { consume_work_senders, finished_consume_work_receiver, look_ahead_window_size: 2048, + phantom: std::marker::PhantomData, } } @@ -66,7 +66,7 @@ impl PrioGraphScheduler { /// not cause conflicts in the near future. pub fn schedule( &mut self, - container: &mut TransactionStateContainer, + container: &mut TransactionStateContainer
<P>
, pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]), pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, ) -> Result<SchedulingSummary, SchedulerError> { @@ -102,7 +102,7 @@ let mut total_filter_time_us: u64 = 0; let mut window_budget = self.look_ahead_window_size; - let mut chunked_pops = |container: &mut TransactionStateContainer, + let mut chunked_pops = |container: &mut TransactionStateContainer
<P>
, prio_graph: &mut PrioGraph<_, _, _, _>, window_budget: &mut usize| { while *window_budget > 0 { @@ -281,7 +281,7 @@ impl PrioGraphScheduler { /// Returns (num_transactions, num_retryable_transactions) on success. pub fn receive_completed( &mut self, - container: &mut TransactionStateContainer, + container: &mut TransactionStateContainer
<P>
, ) -> Result<(usize, usize), SchedulerError> { let mut total_num_transactions: usize = 0; let mut total_num_retryable: usize = 0; @@ -300,7 +300,7 @@ impl PrioGraphScheduler { /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received. fn try_receive_completed( &mut self, - container: &mut TransactionStateContainer, + container: &mut TransactionStateContainer
<P>
, ) -> Result<(usize, usize), SchedulerError> { match self.finished_consume_work_receiver.try_recv() { Ok(FinishedConsumeWork { @@ -535,8 +535,8 @@ enum TransactionSchedulingError { UnschedulableConflicts, } -fn try_schedule_transaction( - transaction_state: &mut TransactionState, +fn try_schedule_transaction<P: DeserializableTxPacket>( + transaction_state: &mut TransactionState<P>

, pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, blocking_locks: &mut ReadWriteAccountSet, account_locks: &mut ThreadAwareAccountLocks, @@ -621,15 +621,17 @@ mod tests { fn create_test_frame( num_threads: usize, ) -> ( - PrioGraphScheduler, + PrioGraphScheduler, Vec>, Sender, ) { let (consume_work_senders, consume_work_receivers) = (0..num_threads).map(|_| unbounded()).unzip(); let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); - let scheduler = - PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver); + let scheduler = PrioGraphScheduler::::new( + consume_work_senders, + finished_consume_work_receiver, + ); ( scheduler, consume_work_receivers, @@ -666,8 +668,9 @@ mod tests { u64, ), >, - ) -> TransactionStateContainer { - let mut container = TransactionStateContainer::with_capacity(10 * 1024); + ) -> TransactionStateContainer { + let mut container = + TransactionStateContainer::::with_capacity(10 * 1024); for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in tx_infos.into_iter().enumerate() { diff --git a/prio-graph-scheduler/src/scheduler_controller.rs b/prio-graph-scheduler/src/scheduler_controller.rs deleted file mode 100644 index 5a04e1eb39d65b..00000000000000 --- a/prio-graph-scheduler/src/scheduler_controller.rs +++ /dev/null @@ -1,1161 +0,0 @@ -//! Control flow for BankingStage's transaction scheduler. -//! - -use { - super::{ - prio_graph_scheduler::PrioGraphScheduler, - scheduler_error::SchedulerError, - scheduler_metrics::{ - SchedulerCountMetrics, SchedulerLeaderDetectionMetrics, SchedulerTimingMetrics, - }, - transaction_id_generator::TransactionIdGenerator, - transaction_state::SanitizedTransactionTTL, - transaction_state_container::TransactionStateContainer, - }, - crate::banking_stage::{ - consume_worker::ConsumeWorkerMetrics, - consumer::Consumer, - decision_maker::{BufferedPacketsDecision, DecisionMaker}, - forwarder::Forwarder, - immutable_deserialized_packet::ImmutableDeserializedPacket, - packet_deserializer::PacketDeserializer, - ForwardOption, LikeClusterInfo, TOTAL_BUFFERED_PACKETS, - }, - arrayvec::ArrayVec, - crossbeam_channel::RecvTimeoutError, - solana_accounts_db::account_locks::validate_account_locks, - solana_cost_model::cost_model::CostModel, - solana_measure::measure_us, - solana_runtime::{bank::Bank, bank_forks::BankForks}, - solana_runtime_transaction::instructions_processor::process_compute_budget_instructions, - solana_sdk::{ - self, - clock::{FORWARD_TRANSACTIONS_TO_LEADER_AT_SLOT_OFFSET, MAX_PROCESSING_AGE}, - fee::FeeBudgetLimits, - saturating_add_assign, - transaction::SanitizedTransaction, - }, - solana_svm::transaction_error_metrics::TransactionErrorMetrics, - solana_svm_transaction::svm_message::SVMMessage, - std::{ - sync::{Arc, RwLock}, - time::{Duration, Instant}, - }, -}; - -/// Controls packet and transaction flow into scheduler, and scheduling execution. -pub(crate) struct SchedulerController { - /// Decision maker for determining what should be done with transactions. - decision_maker: DecisionMaker, - /// Packet/Transaction ingress. - packet_receiver: PacketDeserializer, - bank_forks: Arc>, - /// Generates unique IDs for incoming transactions. - transaction_id_generator: TransactionIdGenerator, - /// Container for transaction state. - /// Shared resource between `packet_receiver` and `scheduler`. - container: TransactionStateContainer, - /// State for scheduling and communicating with worker threads. 
- scheduler: PrioGraphScheduler, - /// Metrics tracking time for leader bank detection. - leader_detection_metrics: SchedulerLeaderDetectionMetrics, - /// Metrics tracking counts on transactions in different states - /// over an interval and during a leader slot. - count_metrics: SchedulerCountMetrics, - /// Metrics tracking time spent in difference code sections - /// over an interval and during a leader slot. - timing_metrics: SchedulerTimingMetrics, - /// Metric report handles for the worker threads. - worker_metrics: Vec>, - /// State for forwarding packets to the leader, if enabled. - forwarder: Option>, -} - -impl SchedulerController { - pub fn new( - decision_maker: DecisionMaker, - packet_deserializer: PacketDeserializer, - bank_forks: Arc>, - scheduler: PrioGraphScheduler, - worker_metrics: Vec>, - forwarder: Option>, - ) -> Self { - Self { - decision_maker, - packet_receiver: packet_deserializer, - bank_forks, - transaction_id_generator: TransactionIdGenerator::default(), - container: TransactionStateContainer::with_capacity(TOTAL_BUFFERED_PACKETS), - scheduler, - leader_detection_metrics: SchedulerLeaderDetectionMetrics::default(), - count_metrics: SchedulerCountMetrics::default(), - timing_metrics: SchedulerTimingMetrics::default(), - worker_metrics, - forwarder, - } - } - - pub fn run(mut self) -> Result<(), SchedulerError> { - loop { - // BufferedPacketsDecision is shared with legacy BankingStage, which will forward - // packets. Initially, not renaming these decision variants but the actions taken - // are different, since new BankingStage will not forward packets. - // For `Forward` and `ForwardAndHold`, we want to receive packets but will not - // forward them to the next leader. In this case, `ForwardAndHold` is - // indistinguishable from `Hold`. - // - // `Forward` will drop packets from the buffer instead of forwarding. - // During receiving, since packets would be dropped from buffer anyway, we can - // bypass sanitization and buffering and immediately drop the packets. - let (decision, decision_time_us) = - measure_us!(self.decision_maker.make_consume_or_forward_decision()); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.decision_time_us, decision_time_us); - }); - let new_leader_slot = decision.bank_start().map(|b| b.working_bank.slot()); - self.leader_detection_metrics - .update_and_maybe_report(decision.bank_start()); - self.count_metrics - .maybe_report_and_reset_slot(new_leader_slot); - self.timing_metrics - .maybe_report_and_reset_slot(new_leader_slot); - - self.process_transactions(&decision)?; - self.receive_completed()?; - if !self.receive_and_buffer_packets(&decision) { - break; - } - // Report metrics only if there is data. - // Reset intervals when appropriate, regardless of report. - let should_report = self.count_metrics.interval_has_data(); - let priority_min_max = self.container.get_min_max_priority(); - self.count_metrics.update(|count_metrics| { - count_metrics.update_priority_stats(priority_min_max); - }); - self.count_metrics - .maybe_report_and_reset_interval(should_report); - self.timing_metrics - .maybe_report_and_reset_interval(should_report); - self.worker_metrics - .iter() - .for_each(|metrics| metrics.maybe_report_and_reset()); - } - - Ok(()) - } - - /// Process packets based on decision. 
- fn process_transactions( - &mut self, - decision: &BufferedPacketsDecision, - ) -> Result<(), SchedulerError> { - let forwarding_enabled = self.forwarder.is_some(); - match decision { - BufferedPacketsDecision::Consume(bank_start) => { - let (scheduling_summary, schedule_time_us) = measure_us!(self.scheduler.schedule( - &mut self.container, - |txs, results| { - Self::pre_graph_filter( - txs, - results, - &bank_start.working_bank, - MAX_PROCESSING_AGE, - ) - }, - |_| true // no pre-lock filter for now - )?); - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!( - count_metrics.num_scheduled, - scheduling_summary.num_scheduled - ); - saturating_add_assign!( - count_metrics.num_unschedulable, - scheduling_summary.num_unschedulable - ); - saturating_add_assign!( - count_metrics.num_schedule_filtered_out, - scheduling_summary.num_filtered_out - ); - }); - - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!( - timing_metrics.schedule_filter_time_us, - scheduling_summary.filter_time_us - ); - saturating_add_assign!(timing_metrics.schedule_time_us, schedule_time_us); - }); - } - BufferedPacketsDecision::Forward => { - if forwarding_enabled { - let (_, forward_time_us) = measure_us!(self.forward_packets(false)); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us); - }); - } else { - let (_, clear_time_us) = measure_us!(self.clear_container()); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.clear_time_us, clear_time_us); - }); - } - } - BufferedPacketsDecision::ForwardAndHold => { - if forwarding_enabled { - let (_, forward_time_us) = measure_us!(self.forward_packets(true)); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.forward_time_us, forward_time_us); - }); - } else { - let (_, clean_time_us) = measure_us!(self.clean_queue()); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.clean_time_us, clean_time_us); - }); - } - } - BufferedPacketsDecision::Hold => {} - } - - Ok(()) - } - - fn pre_graph_filter( - transactions: &[&SanitizedTransaction], - results: &mut [bool], - bank: &Bank, - max_age: usize, - ) { - let lock_results = vec![Ok(()); transactions.len()]; - let mut error_counters = TransactionErrorMetrics::default(); - let check_results = - bank.check_transactions(transactions, &lock_results, max_age, &mut error_counters); - - let fee_check_results: Vec<_> = check_results - .into_iter() - .zip(transactions) - .map(|(result, tx)| { - result?; // if there's already error do nothing - Consumer::check_fee_payer_unlocked(bank, tx.message(), &mut error_counters) - }) - .collect(); - - for (fee_check_result, result) in fee_check_results.into_iter().zip(results.iter_mut()) { - *result = fee_check_result.is_ok(); - } - } - - /// Forward packets to the next leader. - fn forward_packets(&mut self, hold: bool) { - const MAX_FORWARDING_DURATION: Duration = Duration::from_millis(100); - let start = Instant::now(); - let bank = self.bank_forks.read().unwrap().working_bank(); - let feature_set = &bank.feature_set; - let forwarder = self.forwarder.as_mut().expect("forwarder must exist"); - - // Pop from the container in chunks, filter using bank checks, then attempt to forward. - // This doubles as a way to clean the queue as well as forwarding transactions. 
- const CHUNK_SIZE: usize = 64; - let mut num_forwarded: usize = 0; - let mut ids_to_add_back = Vec::new(); - let mut max_time_reached = false; - while !self.container.is_empty() { - let mut filter_array = [true; CHUNK_SIZE]; - let mut ids = Vec::with_capacity(CHUNK_SIZE); - let mut txs = Vec::with_capacity(CHUNK_SIZE); - - for _ in 0..CHUNK_SIZE { - if let Some(id) = self.container.pop() { - ids.push(id); - } else { - break; - } - } - let chunk_size = ids.len(); - ids.iter().for_each(|id| { - let transaction = self.container.get_transaction_ttl(&id.id).unwrap(); - txs.push(&transaction.transaction); - }); - - // use same filter we use for processing transactions: - // age, already processed, fee-check. - Self::pre_graph_filter( - &txs, - &mut filter_array, - &bank, - MAX_PROCESSING_AGE - .saturating_sub(FORWARD_TRANSACTIONS_TO_LEADER_AT_SLOT_OFFSET as usize), - ); - - for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) { - if !*filter_result { - self.container.remove_by_id(&id.id); - continue; - } - - ids_to_add_back.push(*id); // add back to the queue at end - let state = self.container.get_mut_transaction_state(&id.id).unwrap(); - let sanitized_transaction = &state.transaction_ttl().transaction; - let immutable_packet = state.packet().clone(); - - // If not already forwarded and can be forwarded, add to forwardable packets. - if state.should_forward() - && forwarder.try_add_packet( - sanitized_transaction, - immutable_packet, - feature_set, - ) - { - saturating_add_assign!(num_forwarded, 1); - state.mark_forwarded(); - } - } - - if start.elapsed() >= MAX_FORWARDING_DURATION { - max_time_reached = true; - break; - } - } - - // Forward each batch of transactions - forwarder.forward_batched_packets(&ForwardOption::ForwardTransaction); - forwarder.clear_batches(); - - // If we hit the time limit. Drop everything that was not checked/processed. - // If we cannot run these simple checks in time, then we cannot run them during - // leader slot. - if max_time_reached { - while let Some(id) = self.container.pop() { - self.container.remove_by_id(&id.id); - } - } - - if hold { - for priority_id in ids_to_add_back { - self.container.push_id_into_queue(priority_id); - } - } else { - for priority_id in ids_to_add_back { - self.container.remove_by_id(&priority_id.id); - } - } - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!(count_metrics.num_forwarded, num_forwarded); - }); - } - - /// Clears the transaction state container. - /// This only clears pending transactions, and does **not** clear in-flight transactions. - fn clear_container(&mut self) { - let mut num_dropped_on_clear: usize = 0; - while let Some(id) = self.container.pop() { - self.container.remove_by_id(&id.id); - saturating_add_assign!(num_dropped_on_clear, 1); - } - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!(count_metrics.num_dropped_on_clear, num_dropped_on_clear); - }); - } - - /// Clean unprocessable transactions from the queue. These will be transactions that are - /// expired, already processed, or are no longer sanitizable. - /// This only clears pending transactions, and does **not** clear in-flight transactions. - fn clean_queue(&mut self) { - // Clean up any transactions that have already been processed, are too old, or do not have - // valid nonce accounts. 
- const MAX_TRANSACTION_CHECKS: usize = 10_000; - let mut transaction_ids = Vec::with_capacity(MAX_TRANSACTION_CHECKS); - - while let Some(id) = self.container.pop() { - transaction_ids.push(id); - } - - let bank = self.bank_forks.read().unwrap().working_bank(); - - const CHUNK_SIZE: usize = 128; - let mut error_counters = TransactionErrorMetrics::default(); - let mut num_dropped_on_age_and_status: usize = 0; - for chunk in transaction_ids.chunks(CHUNK_SIZE) { - let lock_results = vec![Ok(()); chunk.len()]; - let sanitized_txs: Vec<_> = chunk - .iter() - .map(|id| { - &self - .container - .get_transaction_ttl(&id.id) - .expect("transaction must exist") - .transaction - }) - .collect(); - - let check_results = bank.check_transactions( - &sanitized_txs, - &lock_results, - MAX_PROCESSING_AGE, - &mut error_counters, - ); - - for (result, id) in check_results.into_iter().zip(chunk.iter()) { - if result.is_err() { - saturating_add_assign!(num_dropped_on_age_and_status, 1); - self.container.remove_by_id(&id.id); - } else { - self.container.push_id_into_queue(*id); - } - } - } - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!( - count_metrics.num_dropped_on_age_and_status, - num_dropped_on_age_and_status - ); - }); - } - - /// Receives completed transactions from the workers and updates metrics. - fn receive_completed(&mut self) -> Result<(), SchedulerError> { - let ((num_transactions, num_retryable), receive_completed_time_us) = - measure_us!(self.scheduler.receive_completed(&mut self.container)?); - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!(count_metrics.num_finished, num_transactions); - saturating_add_assign!(count_metrics.num_retryable, num_retryable); - }); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!( - timing_metrics.receive_completed_time_us, - receive_completed_time_us - ); - }); - - Ok(()) - } - - /// Returns whether the packet receiver is still connected. 
- fn receive_and_buffer_packets(&mut self, decision: &BufferedPacketsDecision) -> bool { - let remaining_queue_capacity = self.container.remaining_queue_capacity(); - - const MAX_PACKET_RECEIVE_TIME: Duration = Duration::from_millis(10); - let (recv_timeout, should_buffer) = match decision { - BufferedPacketsDecision::Consume(_) => ( - if self.container.is_empty() { - MAX_PACKET_RECEIVE_TIME - } else { - Duration::ZERO - }, - true, - ), - BufferedPacketsDecision::Forward => (MAX_PACKET_RECEIVE_TIME, self.forwarder.is_some()), - BufferedPacketsDecision::ForwardAndHold | BufferedPacketsDecision::Hold => { - (MAX_PACKET_RECEIVE_TIME, true) - } - }; - - let (received_packet_results, receive_time_us) = measure_us!(self - .packet_receiver - .receive_packets(recv_timeout, remaining_queue_capacity, |packet| { - packet.check_excessive_precompiles()?; - Ok(packet) - })); - - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.receive_time_us, receive_time_us); - }); - - match received_packet_results { - Ok(receive_packet_results) => { - let num_received_packets = receive_packet_results.deserialized_packets.len(); - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!(count_metrics.num_received, num_received_packets); - }); - - if should_buffer { - let (_, buffer_time_us) = measure_us!( - self.buffer_packets(receive_packet_results.deserialized_packets) - ); - self.timing_metrics.update(|timing_metrics| { - saturating_add_assign!(timing_metrics.buffer_time_us, buffer_time_us); - }); - } else { - self.count_metrics.update(|count_metrics| { - saturating_add_assign!( - count_metrics.num_dropped_on_receive, - num_received_packets - ); - }); - } - } - Err(RecvTimeoutError::Timeout) => {} - Err(RecvTimeoutError::Disconnected) => return false, - } - - true - } - - fn buffer_packets(&mut self, packets: Vec) { - // Convert to Arcs - let packets: Vec<_> = packets.into_iter().map(Arc::new).collect(); - // Sanitize packets, generate IDs, and insert into the container. 
- let bank = self.bank_forks.read().unwrap().working_bank(); - let last_slot_in_epoch = bank.epoch_schedule().get_last_slot_in_epoch(bank.epoch()); - let transaction_account_lock_limit = bank.get_transaction_account_lock_limit(); - let vote_only = bank.vote_only_bank(); - - const CHUNK_SIZE: usize = 128; - let lock_results: [_; CHUNK_SIZE] = core::array::from_fn(|_| Ok(())); - - let mut arc_packets = ArrayVec::<_, CHUNK_SIZE>::new(); - let mut transactions = ArrayVec::<_, CHUNK_SIZE>::new(); - let mut fee_budget_limits_vec = ArrayVec::<_, CHUNK_SIZE>::new(); - - let mut error_counts = TransactionErrorMetrics::default(); - for chunk in packets.chunks(CHUNK_SIZE) { - let mut post_sanitization_count: usize = 0; - chunk - .iter() - .filter_map(|packet| { - packet - .build_sanitized_transaction( - vote_only, - bank.as_ref(), - bank.get_reserved_account_keys(), - ) - .map(|tx| (packet.clone(), tx)) - }) - .inspect(|_| saturating_add_assign!(post_sanitization_count, 1)) - .filter(|(_packet, tx)| { - validate_account_locks( - tx.message().account_keys(), - transaction_account_lock_limit, - ) - .is_ok() - }) - .filter_map(|(packet, tx)| { - process_compute_budget_instructions(SVMMessage::program_instructions_iter(&tx)) - .map(|compute_budget| (packet, tx, compute_budget.into())) - .ok() - }) - .for_each(|(packet, tx, fee_budget_limits)| { - arc_packets.push(packet); - transactions.push(tx); - fee_budget_limits_vec.push(fee_budget_limits); - }); - - let check_results = bank.check_transactions( - &transactions, - &lock_results[..transactions.len()], - MAX_PROCESSING_AGE, - &mut error_counts, - ); - let post_lock_validation_count = transactions.len(); - - let mut post_transaction_check_count: usize = 0; - let mut num_dropped_on_capacity: usize = 0; - let mut num_buffered: usize = 0; - for (((packet, transaction), fee_budget_limits), _check_result) in arc_packets - .drain(..) - .zip(transactions.drain(..)) - .zip(fee_budget_limits_vec.drain(..)) - .zip(check_results) - .filter(|(_, check_result)| check_result.is_ok()) - { - saturating_add_assign!(post_transaction_check_count, 1); - let transaction_id = self.transaction_id_generator.next(); - - let (priority, cost) = - Self::calculate_priority_and_cost(&transaction, &fee_budget_limits, &bank); - let transaction_ttl = SanitizedTransactionTTL { - transaction, - max_age_slot: last_slot_in_epoch, - }; - - if self.container.insert_new_transaction( - transaction_id, - transaction_ttl, - packet, - priority, - cost, - ) { - saturating_add_assign!(num_dropped_on_capacity, 1); - } - saturating_add_assign!(num_buffered, 1); - } - - // Update metrics for transactions that were dropped. 
- let num_dropped_on_sanitization = chunk.len().saturating_sub(post_sanitization_count); - let num_dropped_on_lock_validation = - post_sanitization_count.saturating_sub(post_lock_validation_count); - let num_dropped_on_transaction_checks = - post_lock_validation_count.saturating_sub(post_transaction_check_count); - - self.count_metrics.update(|count_metrics| { - saturating_add_assign!( - count_metrics.num_dropped_on_capacity, - num_dropped_on_capacity - ); - saturating_add_assign!(count_metrics.num_buffered, num_buffered); - saturating_add_assign!( - count_metrics.num_dropped_on_sanitization, - num_dropped_on_sanitization - ); - saturating_add_assign!( - count_metrics.num_dropped_on_validate_locks, - num_dropped_on_lock_validation - ); - saturating_add_assign!( - count_metrics.num_dropped_on_receive_transaction_checks, - num_dropped_on_transaction_checks - ); - }); - } - } - - /// Calculate priority and cost for a transaction: - /// - /// Cost is calculated through the `CostModel`, - /// and priority is calculated through a formula here that attempts to sell - /// blockspace to the highest bidder. - /// - /// The priority is calculated as: - /// P = R / (1 + C) - /// where P is the priority, R is the reward, - /// and C is the cost towards block-limits. - /// - /// Current minimum costs are on the order of several hundred, - /// so the denominator is effectively C, and the +1 is simply - /// to avoid any division by zero due to a bug - these costs - /// are calculated by the cost-model and are not direct - /// from user input. They should never be zero. - /// Any difference in the prioritization is negligible for - /// the current transaction costs. - fn calculate_priority_and_cost( - transaction: &SanitizedTransaction, - fee_budget_limits: &FeeBudgetLimits, - bank: &Bank, - ) -> (u64, u64) { - let cost = CostModel::calculate_cost(transaction, &bank.feature_set).sum(); - let reward = bank.calculate_reward_for_transaction(transaction, fee_budget_limits); - - // We need a multiplier here to avoid rounding down too aggressively. - // For many transactions, the cost will be greater than the fees in terms of raw lamports. - // For the purposes of calculating prioritization, we multiply the fees by a large number so that - // the cost is a small fraction. - // An offset of 1 is used in the denominator to explicitly avoid division by zero. 
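To make the priority formula concrete, a worked example with illustrative numbers (not taken from the patch):
    // reward = 5_000 lamports, cost = 999 CUs:
    //   priority = 5_000 * 1_000_000 / (999 + 1) = 5_000_000
    // with cost = 1_999 CUs the denominator doubles and the priority halves:
    //   priority = 5_000 * 1_000_000 / (1_999 + 1) = 2_500_000
The cost is returned alongside the priority so the caller can also track it against block limits.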
- const MULTIPLIER: u64 = 1_000_000; - ( - reward - .saturating_mul(MULTIPLIER) - .saturating_div(cost.saturating_add(1)), - cost, - ) - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - crate::{ - banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, - scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId}, - tests::create_slow_genesis_config, - }, - banking_trace::BankingPacketBatch, - sigverify::SigverifyTracerPacketStats, - }, - crossbeam_channel::{unbounded, Receiver, Sender}, - itertools::Itertools, - solana_gossip::cluster_info::ClusterInfo, - solana_ledger::{ - blockstore::Blockstore, genesis_utils::GenesisConfigInfo, - get_tmp_ledger_path_auto_delete, leader_schedule_cache::LeaderScheduleCache, - }, - solana_perf::packet::{to_packet_batches, PacketBatch, NUM_PACKETS}, - solana_poh::poh_recorder::{PohRecorder, Record, WorkingBankEntry}, - solana_runtime::bank::Bank, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, fee_calculator::FeeRateGovernor, hash::Hash, - message::Message, poh_config::PohConfig, pubkey::Pubkey, signature::Keypair, - signer::Signer, system_instruction, system_transaction, transaction::Transaction, - }, - std::sync::{atomic::AtomicBool, Arc, RwLock}, - tempfile::TempDir, - }; - - fn create_channels(num: usize) -> (Vec>, Vec>) { - (0..num).map(|_| unbounded()).unzip() - } - - // Helper struct to create tests that hold channels, files, etc. - // such that our tests can be more easily set up and run. - struct TestFrame { - bank: Arc, - mint_keypair: Keypair, - _ledger_path: TempDir, - _entry_receiver: Receiver, - _record_receiver: Receiver, - poh_recorder: Arc>, - banking_packet_sender: Sender, Option)>>, - - consume_work_receivers: Vec>, - finished_consume_work_sender: Sender, - } - - fn create_test_frame(num_threads: usize) -> (TestFrame, SchedulerController>) { - let GenesisConfigInfo { - mut genesis_config, - mint_keypair, - .. 
- } = create_slow_genesis_config(u64::MAX); - genesis_config.fee_rate_governor = FeeRateGovernor::new(5000, 0); - let (bank, bank_forks) = Bank::new_no_wallclock_throttle_for_tests(&genesis_config); - - let ledger_path = get_tmp_ledger_path_auto_delete!(); - let blockstore = Blockstore::open(ledger_path.path()) - .expect("Expected to be able to open database ledger"); - let (poh_recorder, entry_receiver, record_receiver) = PohRecorder::new( - bank.tick_height(), - bank.last_blockhash(), - bank.clone(), - Some((4, 4)), - bank.ticks_per_slot(), - Arc::new(blockstore), - &Arc::new(LeaderScheduleCache::new_from_bank(&bank)), - &PohConfig::default(), - Arc::new(AtomicBool::default()), - ); - let poh_recorder = Arc::new(RwLock::new(poh_recorder)); - let decision_maker = DecisionMaker::new(Pubkey::new_unique(), poh_recorder.clone()); - - let (banking_packet_sender, banking_packet_receiver) = unbounded(); - let packet_deserializer = PacketDeserializer::new(banking_packet_receiver); - - let (consume_work_senders, consume_work_receivers) = create_channels(num_threads); - let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); - - let test_frame = TestFrame { - bank, - mint_keypair, - _ledger_path: ledger_path, - _entry_receiver: entry_receiver, - _record_receiver: record_receiver, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - finished_consume_work_sender, - }; - - let scheduler_controller = SchedulerController::new( - decision_maker, - packet_deserializer, - bank_forks, - PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver), - vec![], // no actual workers with metrics to report, this can be empty - None, - ); - - (test_frame, scheduler_controller) - } - - fn create_and_fund_prioritized_transfer( - bank: &Bank, - mint_keypair: &Keypair, - from_keypair: &Keypair, - to_pubkey: &Pubkey, - lamports: u64, - compute_unit_price: u64, - recent_blockhash: Hash, - ) -> Transaction { - // Fund the sending key, so that the transaction does not get filtered by the fee-payer check. - { - let transfer = system_transaction::transfer( - mint_keypair, - &from_keypair.pubkey(), - 500_000, // just some amount that will always be enough - bank.last_blockhash(), - ); - bank.process_transaction(&transfer).unwrap(); - } - - let transfer = system_instruction::transfer(&from_keypair.pubkey(), to_pubkey, lamports); - let prioritization = ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price); - let message = Message::new(&[transfer, prioritization], Some(&from_keypair.pubkey())); - Transaction::new(&vec![from_keypair], message, recent_blockhash) - } - - fn to_banking_packet_batch(txs: &[Transaction]) -> BankingPacketBatch { - let packet_batch = to_packet_batches(txs, NUM_PACKETS); - Arc::new((packet_batch, None)) - } - - // Helper function to let test receive and then schedule packets. - // The order of operations here is convenient for testing, but does not - // match the order of operations in the actual scheduler. - // The actual scheduler will process immediately after the decision, - // in order to keep the decision as recent as possible for processing. - // In the tests, the decision will not become stale, so it is more convenient - // to receive first and then schedule. 
- fn test_receive_then_schedule( - scheduler_controller: &mut SchedulerController>, - ) { - let decision = scheduler_controller - .decision_maker - .make_consume_or_forward_decision(); - assert!(matches!(decision, BufferedPacketsDecision::Consume(_))); - assert!(scheduler_controller.receive_completed().is_ok()); - assert!(scheduler_controller.receive_and_buffer_packets(&decision)); - assert!(scheduler_controller.process_transactions(&decision).is_ok()); - } - - #[test] - #[should_panic(expected = "batch id 0 is not being tracked")] - fn test_unexpected_batch_id() { - let (test_frame, scheduler_controller) = create_test_frame(1); - let TestFrame { - finished_consume_work_sender, - .. - } = &test_frame; - - finished_consume_work_sender - .send(FinishedConsumeWork { - work: ConsumeWork { - batch_id: TransactionBatchId::new(0), - ids: vec![], - transactions: vec![], - max_age_slots: vec![], - }, - retryable_indexes: vec![], - }) - .unwrap(); - - scheduler_controller.run().unwrap(); - } - - #[test] - fn test_schedule_consume_single_threaded_no_conflicts() { - let (test_frame, mut scheduler_controller) = create_test_frame(1); - let TestFrame { - bank, - mint_keypair, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - .. - } = &test_frame; - - poh_recorder - .write() - .unwrap() - .set_bank_for_test(bank.clone()); - - // Send packet batch to the scheduler - should do nothing until we become the leader. - let tx1 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - 1, - 1000, - bank.last_blockhash(), - ); - let tx2 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - 1, - 2000, - bank.last_blockhash(), - ); - let tx1_hash = tx1.message().hash(); - let tx2_hash = tx2.message().hash(); - - let txs = vec![tx1, tx2]; - banking_packet_sender - .send(to_banking_packet_batch(&txs)) - .unwrap(); - - test_receive_then_schedule(&mut scheduler_controller); - let consume_work = consume_work_receivers[0].try_recv().unwrap(); - assert_eq!(consume_work.ids.len(), 2); - assert_eq!(consume_work.transactions.len(), 2); - let message_hashes = consume_work - .transactions - .iter() - .map(|tx| tx.message_hash()) - .collect_vec(); - assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]); - } - - #[test] - fn test_schedule_consume_single_threaded_conflict() { - let (test_frame, mut scheduler_controller) = create_test_frame(1); - let TestFrame { - bank, - mint_keypair, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - .. 
- } = &test_frame; - - poh_recorder - .write() - .unwrap() - .set_bank_for_test(bank.clone()); - - let pk = Pubkey::new_unique(); - let tx1 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &pk, - 1, - 1000, - bank.last_blockhash(), - ); - let tx2 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &pk, - 1, - 2000, - bank.last_blockhash(), - ); - let tx1_hash = tx1.message().hash(); - let tx2_hash = tx2.message().hash(); - - let txs = vec![tx1, tx2]; - banking_packet_sender - .send(to_banking_packet_batch(&txs)) - .unwrap(); - - // We expect 2 batches to be scheduled - test_receive_then_schedule(&mut scheduler_controller); - let consume_works = (0..2) - .map(|_| consume_work_receivers[0].try_recv().unwrap()) - .collect_vec(); - - let num_txs_per_batch = consume_works.iter().map(|cw| cw.ids.len()).collect_vec(); - let message_hashes = consume_works - .iter() - .flat_map(|cw| cw.transactions.iter().map(|tx| tx.message_hash())) - .collect_vec(); - assert_eq!(num_txs_per_batch, vec![1; 2]); - assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]); - } - - #[test] - fn test_schedule_consume_single_threaded_multi_batch() { - let (test_frame, mut scheduler_controller) = create_test_frame(1); - let TestFrame { - bank, - mint_keypair, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - .. - } = &test_frame; - - poh_recorder - .write() - .unwrap() - .set_bank_for_test(bank.clone()); - - // Send multiple batches - all get scheduled - let txs1 = (0..2 * TARGET_NUM_TRANSACTIONS_PER_BATCH) - .map(|i| { - create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - i as u64, - 1, - bank.last_blockhash(), - ) - }) - .collect_vec(); - let txs2 = (0..2 * TARGET_NUM_TRANSACTIONS_PER_BATCH) - .map(|i| { - create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - i as u64, - 2, - bank.last_blockhash(), - ) - }) - .collect_vec(); - - banking_packet_sender - .send(to_banking_packet_batch(&txs1)) - .unwrap(); - banking_packet_sender - .send(to_banking_packet_batch(&txs2)) - .unwrap(); - - // We expect 4 batches to be scheduled - test_receive_then_schedule(&mut scheduler_controller); - let consume_works = (0..4) - .map(|_| consume_work_receivers[0].try_recv().unwrap()) - .collect_vec(); - - assert_eq!( - consume_works.iter().map(|cw| cw.ids.len()).collect_vec(), - vec![TARGET_NUM_TRANSACTIONS_PER_BATCH; 4] - ); - } - - #[test] - fn test_schedule_consume_simple_thread_selection() { - let (test_frame, mut scheduler_controller) = create_test_frame(2); - let TestFrame { - bank, - mint_keypair, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - .. - } = &test_frame; - - poh_recorder - .write() - .unwrap() - .set_bank_for_test(bank.clone()); - - // Send 4 transactions w/o conflicts. 
2 should be scheduled on each thread - let txs = (0..4) - .map(|i| { - create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - 1, - i * 10, - bank.last_blockhash(), - ) - }) - .collect_vec(); - banking_packet_sender - .send(to_banking_packet_batch(&txs)) - .unwrap(); - - // Priority Expectation: - // Thread 0: [3, 1] - // Thread 1: [2, 0] - let t0_expected = [3, 1] - .into_iter() - .map(|i| txs[i].message().hash()) - .collect_vec(); - let t1_expected = [2, 0] - .into_iter() - .map(|i| txs[i].message().hash()) - .collect_vec(); - - test_receive_then_schedule(&mut scheduler_controller); - let t0_actual = consume_work_receivers[0] - .try_recv() - .unwrap() - .transactions - .iter() - .map(|tx| *tx.message_hash()) - .collect_vec(); - let t1_actual = consume_work_receivers[1] - .try_recv() - .unwrap() - .transactions - .iter() - .map(|tx| *tx.message_hash()) - .collect_vec(); - - assert_eq!(t0_actual, t0_expected); - assert_eq!(t1_actual, t1_expected); - } - - #[test] - fn test_schedule_consume_retryable() { - let (test_frame, mut scheduler_controller) = create_test_frame(1); - let TestFrame { - bank, - mint_keypair, - poh_recorder, - banking_packet_sender, - consume_work_receivers, - finished_consume_work_sender, - .. - } = &test_frame; - - poh_recorder - .write() - .unwrap() - .set_bank_for_test(bank.clone()); - - // Send packet batch to the scheduler - should do nothing until we become the leader. - let tx1 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - 1, - 1000, - bank.last_blockhash(), - ); - let tx2 = create_and_fund_prioritized_transfer( - bank, - mint_keypair, - &Keypair::new(), - &Pubkey::new_unique(), - 1, - 2000, - bank.last_blockhash(), - ); - let tx1_hash = tx1.message().hash(); - let tx2_hash = tx2.message().hash(); - - let txs = vec![tx1, tx2]; - banking_packet_sender - .send(to_banking_packet_batch(&txs)) - .unwrap(); - - test_receive_then_schedule(&mut scheduler_controller); - let consume_work = consume_work_receivers[0].try_recv().unwrap(); - assert_eq!(consume_work.ids.len(), 2); - assert_eq!(consume_work.transactions.len(), 2); - let message_hashes = consume_work - .transactions - .iter() - .map(|tx| tx.message_hash()) - .collect_vec(); - assert_eq!(message_hashes, vec![&tx2_hash, &tx1_hash]); - - // Complete the batch - marking the second transaction as retryable - finished_consume_work_sender - .send(FinishedConsumeWork { - work: consume_work, - retryable_indexes: vec![1], - }) - .unwrap(); - - // Transaction should be rescheduled - test_receive_then_schedule(&mut scheduler_controller); - let consume_work = consume_work_receivers[0].try_recv().unwrap(); - assert_eq!(consume_work.ids.len(), 1); - assert_eq!(consume_work.transactions.len(), 1); - let message_hashes = consume_work - .transactions - .iter() - .map(|tx| tx.message_hash()) - .collect_vec(); - assert_eq!(message_hashes, vec![&tx1_hash]); - } -} diff --git a/prio-graph-scheduler/src/scheduler_messages.rs b/prio-graph-scheduler/src/scheduler_messages.rs index b5e11be6ba9d78..466fce7b1e7f54 100644 --- a/prio-graph-scheduler/src/scheduler_messages.rs +++ b/prio-graph-scheduler/src/scheduler_messages.rs @@ -1,5 +1,5 @@ use { - solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, + crate::deserializable_packet::DeserializableTxPacket, solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, std::{fmt::Display, sync::Arc}, }; @@ -59,8 +59,8 @@ pub struct ConsumeWork 
{ /// Message: [Scheduler -> Worker] /// Transactions to be forwarded to the next leader(s) -pub struct ForwardWork { - pub packets: Vec>, +pub struct ForwardWork { + pub packets: Vec>, } /// Message: [Worker -> Scheduler] @@ -72,7 +72,7 @@ pub struct FinishedConsumeWork { /// Message: [Worker -> Scheduler] /// Forwarded transactions. -pub struct FinishedForwardWork { - pub work: ForwardWork, +pub struct FinishedForwardWork { + pub work: ForwardWork
<P>
, pub successful: bool, } diff --git a/prio-graph-scheduler/src/transaction_state.rs b/prio-graph-scheduler/src/transaction_state.rs index 9c9d783ab15369..56575beeaf79a4 100644 --- a/prio-graph-scheduler/src/transaction_state.rs +++ b/prio-graph-scheduler/src/transaction_state.rs @@ -1,13 +1,13 @@ use { - solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, - solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, - std::sync::Arc, + crate::deserializable_packet::DeserializableTxPacket, + solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, + std::sync::Arc, }; /// Simple wrapper type to tie a sanitized transaction to max age slot. pub struct SanitizedTransactionTTL { - pub transaction: SanitizedTransaction, - pub max_age_slot: Slot, + pub transaction: SanitizedTransaction, + pub max_age_slot: Slot, } /// TransactionState is used to track the state of a transaction in the transaction scheduler @@ -30,330 +30,333 @@ pub struct SanitizedTransactionTTL { /// to the appropriate thread for processing. This is done to avoid cloning the /// `SanitizedTransaction`. #[allow(clippy::large_enum_variant)] -pub enum TransactionState { - /// The transaction is available for scheduling. - Unprocessed { - transaction_ttl: SanitizedTransactionTTL, - packet: Arc, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// The transaction is currently scheduled or being processed. - Pending { - packet: Arc, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// Only used during transition. - Transitioning, +pub enum TransactionState { + /// The transaction is available for scheduling. + Unprocessed { + transaction_ttl: SanitizedTransactionTTL, + packet: Arc
<P>
, + priority: u64, + cost: u64, + should_forward: bool, + }, + /// The transaction is currently scheduled or being processed. + Pending { + packet: Arc
<P>
, + priority: u64, + cost: u64, + should_forward: bool, + }, + /// Only used during transition. + Transitioning, } -impl TransactionState { - /// Creates a new `TransactionState` in the `Unprocessed` state. - pub fn new( - transaction_ttl: SanitizedTransactionTTL, - packet: Arc, - priority: u64, - cost: u64, - ) -> Self { - let should_forward = !packet.original_packet().meta().forwarded() - && packet.original_packet().meta().is_from_staked_node(); - Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward, - } - } +impl TransactionState
<P>
{ + /// Creates a new `TransactionState` in the `Unprocessed` state. + pub fn new( + transaction_ttl: SanitizedTransactionTTL, + packet: Arc
<P>
, + priority: u64, + cost: u64, + ) -> Self { + let should_forward = !packet.original_packet().meta().forwarded() + && packet.original_packet().meta().is_from_staked_node(); + Self::Unprocessed { + transaction_ttl, + packet, + priority, + cost, + should_forward, + } + } - /// Return the priority of the transaction. - /// This is *not* the same as the `compute_unit_price` of the transaction. - /// The priority is used to order transactions for processing. - pub fn priority(&self) -> u64 { - match self { - Self::Unprocessed { priority, .. } => *priority, - Self::Pending { priority, .. } => *priority, - Self::Transitioning => unreachable!(), - } - } + /// Return the priority of the transaction. + /// This is *not* the same as the `compute_unit_price` of the transaction. + /// The priority is used to order transactions for processing. + pub fn priority(&self) -> u64 { + match self { + Self::Unprocessed { priority, .. } => *priority, + Self::Pending { priority, .. } => *priority, + Self::Transitioning => unreachable!(), + } + } - /// Return the cost of the transaction. - pub fn cost(&self) -> u64 { - match self { - Self::Unprocessed { cost, .. } => *cost, - Self::Pending { cost, .. } => *cost, - Self::Transitioning => unreachable!(), - } - } + /// Return the cost of the transaction. + pub fn cost(&self) -> u64 { + match self { + Self::Unprocessed { cost, .. } => *cost, + Self::Pending { cost, .. } => *cost, + Self::Transitioning => unreachable!(), + } + } - /// Return whether packet should be attempted to be forwarded. - pub fn should_forward(&self) -> bool { - match self { - Self::Unprocessed { - should_forward: forwarded, - .. - } => *forwarded, - Self::Pending { - should_forward: forwarded, - .. - } => *forwarded, - Self::Transitioning => unreachable!(), - } - } + /// Return whether packet should be attempted to be forwarded. + pub fn should_forward(&self) -> bool { + match self { + Self::Unprocessed { + should_forward: forwarded, + .. + } => *forwarded, + Self::Pending { + should_forward: forwarded, + .. + } => *forwarded, + Self::Transitioning => unreachable!(), + } + } - /// Mark the packet as forwarded. - /// This is used to prevent the packet from being forwarded multiple times. - pub fn mark_forwarded(&mut self) { - match self { - Self::Unprocessed { should_forward, .. } => *should_forward = false, - Self::Pending { should_forward, .. } => *should_forward = false, - Self::Transitioning => unreachable!(), - } - } + /// Mark the packet as forwarded. + /// This is used to prevent the packet from being forwarded multiple times. + pub fn mark_forwarded(&mut self) { + match self { + Self::Unprocessed { should_forward, .. } => *should_forward = false, + Self::Pending { should_forward, .. } => *should_forward = false, + Self::Transitioning => unreachable!(), + } + } - /// Return the packet of the transaction. - pub fn packet(&self) -> &Arc { - match self { - Self::Unprocessed { packet, .. } => packet, - Self::Pending { packet, .. } => packet, - Self::Transitioning => unreachable!(), - } - } + /// Return the packet of the transaction. + pub fn packet(&self) -> &Arc
<P>
{ + match self { + Self::Unprocessed { packet, .. } => packet, + Self::Pending { packet, .. } => packet, + Self::Transitioning => unreachable!(), + } + } - /// Intended to be called when a transaction is scheduled. This method will - /// transition the transaction from `Unprocessed` to `Pending` and return the - /// `SanitizedTransactionTTL` for processing. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Pending` state, - /// as this is an invalid state transition. - pub fn transition_to_pending(&mut self) -> SanitizedTransactionTTL { - match self.take() { - TransactionState::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - }; - transaction_ttl - } - TransactionState::Pending { .. } => { - panic!("transaction already pending"); - } - Self::Transitioning => unreachable!(), - } - } + /// Intended to be called when a transaction is scheduled. This method will + /// transition the transaction from `Unprocessed` to `Pending` and return the + /// `SanitizedTransactionTTL` for processing. + /// + /// # Panics + /// This method will panic if the transaction is already in the `Pending` state, + /// as this is an invalid state transition. + pub fn transition_to_pending(&mut self) -> SanitizedTransactionTTL { + match self.take() { + TransactionState::Unprocessed { + transaction_ttl, + packet, + priority, + cost, + should_forward: forwarded, + } => { + *self = TransactionState::Pending { + packet, + priority, + cost, + should_forward: forwarded, + }; + transaction_ttl + } + TransactionState::Pending { .. } => { + panic!("transaction already pending"); + } + Self::Transitioning => unreachable!(), + } + } - /// Intended to be called when a transaction is retried. This method will - /// transition the transaction from `Pending` to `Unprocessed`. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Unprocessed` - /// state, as this is an invalid state transition. - pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) { - match self.take() { - TransactionState::Unprocessed { .. } => panic!("already unprocessed"), - TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } - } - Self::Transitioning => unreachable!(), - } - } + /// Intended to be called when a transaction is retried. This method will + /// transition the transaction from `Pending` to `Unprocessed`. + /// + /// # Panics + /// This method will panic if the transaction is already in the `Unprocessed` + /// state, as this is an invalid state transition. + pub fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) { + match self.take() { + TransactionState::Unprocessed { .. } => panic!("already unprocessed"), + TransactionState::Pending { + packet, + priority, + cost, + should_forward: forwarded, + } => { + *self = Self::Unprocessed { + transaction_ttl, + packet, + priority, + cost, + should_forward: forwarded, + } + } + Self::Transitioning => unreachable!(), + } + } - /// Get a reference to the `SanitizedTransactionTTL` for the transaction. - /// - /// # Panics - /// This method will panic if the transaction is in the `Pending` state. 
- pub fn transaction_ttl(&self) -> &SanitizedTransactionTTL { - match self { - Self::Unprocessed { - transaction_ttl, .. - } => transaction_ttl, - Self::Pending { .. } => panic!("transaction is pending"), - Self::Transitioning => unreachable!(), - } - } + /// Get a reference to the `SanitizedTransactionTTL` for the transaction. + /// + /// # Panics + /// This method will panic if the transaction is in the `Pending` state. + pub fn transaction_ttl(&self) -> &SanitizedTransactionTTL { + match self { + Self::Unprocessed { + transaction_ttl, .. + } => transaction_ttl, + Self::Pending { .. } => panic!("transaction is pending"), + Self::Transitioning => unreachable!(), + } + } - /// Internal helper to transitioning between states. - /// Replaces `self` with a dummy state that will immediately be overwritten in transition. - fn take(&mut self) -> Self { - core::mem::replace(self, Self::Transitioning) - } + /// Internal helper to transitioning between states. + /// Replaces `self` with a dummy state that will immediately be overwritten in transition. + fn take(&mut self) -> Self { + core::mem::replace(self, Self::Transitioning) + } } #[cfg(test)] mod tests { - use { - super::*, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, - signature::Keypair, signer::Signer, system_instruction, transaction::Transaction, - }, - }; + use { + super::*, + solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, + solana_sdk::{ + compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, + signature::Keypair, signer::Signer, system_instruction, transaction::Transaction, + }, + }; - fn create_transaction_state(compute_unit_price: u64) -> TransactionState { - let from_keypair = Keypair::new(); - let ixs = vec![ - system_instruction::transfer( - &from_keypair.pubkey(), - &solana_sdk::pubkey::new_rand(), - 1, - ), - ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price), - ]; - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = Transaction::new(&[&from_keypair], message, Hash::default()); + fn create_transaction_state( + compute_unit_price: u64, + ) -> TransactionState { + let from_keypair = Keypair::new(); + let ixs = vec![ + system_instruction::transfer( + &from_keypair.pubkey(), + &solana_sdk::pubkey::new_rand(), + 1, + ), + ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price), + ]; + let message = Message::new(&ixs, Some(&from_keypair.pubkey())); + let tx = Transaction::new(&[&from_keypair], message, Hash::default()); - let packet = Arc::new( - ImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction: SanitizedTransaction::from_transaction_for_tests(tx), - max_age_slot: Slot::MAX, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - TransactionState::new( - transaction_ttl, - packet, - compute_unit_price, - TEST_TRANSACTION_COST, - ) - } + let packet = Arc::new( + ImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), + ); + let transaction_ttl = SanitizedTransactionTTL { + transaction: SanitizedTransaction::from_transaction_for_tests(tx), + max_age_slot: Slot::MAX, + }; + const TEST_TRANSACTION_COST: u64 = 5000; + TransactionState::new( + transaction_ttl, + packet, + compute_unit_price, + TEST_TRANSACTION_COST, + ) + } - #[test] - #[should_panic(expected = "already pending")] - fn test_transition_to_pending_panic() 
{ - let mut transaction_state = create_transaction_state(0); - transaction_state.transition_to_pending(); - transaction_state.transition_to_pending(); // invalid transition - } + #[test] + #[should_panic(expected = "already pending")] + fn test_transition_to_pending_panic() { + let mut transaction_state = create_transaction_state(0); + transaction_state.transition_to_pending(); + transaction_state.transition_to_pending(); // invalid transition + } - #[test] - fn test_transition_to_pending() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - } + #[test] + fn test_transition_to_pending() { + let mut transaction_state = create_transaction_state(0); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + let _ = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + } - #[test] - #[should_panic(expected = "already unprocessed")] - fn test_transition_to_unprocessed_panic() { - let mut transaction_state = create_transaction_state(0); + #[test] + #[should_panic(expected = "already unprocessed")] + fn test_transition_to_unprocessed_panic() { + let mut transaction_state = create_transaction_state(0); - // Manually clone `SanitizedTransactionTTL` - let SanitizedTransactionTTL { - transaction, - max_age_slot, - } = transaction_state.transaction_ttl(); - let transaction_ttl = SanitizedTransactionTTL { - transaction: transaction.clone(), - max_age_slot: *max_age_slot, - }; - transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition - } + // Manually clone `SanitizedTransactionTTL` + let SanitizedTransactionTTL { + transaction, + max_age_slot, + } = transaction_state.transaction_ttl(); + let transaction_ttl = SanitizedTransactionTTL { + transaction: transaction.clone(), + max_age_slot: *max_age_slot, + }; + transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition + } - #[test] - fn test_transition_to_unprocessed() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - } + #[test] + fn test_transition_to_unprocessed() { + let mut transaction_state = create_transaction_state(0); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. 
} + )); + } - #[test] - fn test_priority() { - let priority = 15; - let mut transaction_state = create_transaction_state(priority); - assert_eq!(transaction_state.priority(), priority); + #[test] + fn test_priority() { + let priority = 15; + let mut transaction_state = create_transaction_state(priority); + assert_eq!(transaction_state.priority(), priority); - // ensure compute unit price is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert_eq!(transaction_state.priority(), priority); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert_eq!(transaction_state.priority(), priority); - } + // ensure compute unit price is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert_eq!(transaction_state.priority(), priority); + transaction_state.transition_to_unprocessed(transaction_ttl); + assert_eq!(transaction_state.priority(), priority); + } - #[test] - #[should_panic(expected = "transaction is pending")] - fn test_transaction_ttl_panic() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + #[test] + #[should_panic(expected = "transaction is pending")] + fn test_transaction_ttl_panic() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available - } + let _ = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); + let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available + } - #[test] - fn test_transaction_ttl() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + #[test] + fn test_transaction_ttl() { + let mut transaction_state = create_transaction_state(0); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - // ensure transaction_ttl is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); + // ensure transaction_ttl is not lost through state transitions + let transaction_ttl = transaction_state.transition_to_pending(); + assert!(matches!( + transaction_state, + TransactionState::Pending { .. } + )); - transaction_state.transition_to_unprocessed(transaction_ttl); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. 
} - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - } + transaction_state.transition_to_unprocessed(transaction_ttl); + let transaction_ttl = transaction_state.transaction_ttl(); + assert!(matches!( + transaction_state, + TransactionState::Unprocessed { .. } + )); + assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); + } } diff --git a/prio-graph-scheduler/src/transaction_state_container.rs b/prio-graph-scheduler/src/transaction_state_container.rs index 8e2a51f5f7bb37..c162034f6f890d 100644 --- a/prio-graph-scheduler/src/transaction_state_container.rs +++ b/prio-graph-scheduler/src/transaction_state_container.rs @@ -3,10 +3,9 @@ use { transaction_priority_id::TransactionPriorityId, transaction_state::{SanitizedTransactionTTL, TransactionState}, }, - crate::scheduler_messages::TransactionId, + crate::{deserializable_packet::DeserializableTxPacket, scheduler_messages::TransactionId}, itertools::MinMaxResult, min_max_heap::MinMaxHeap, - solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, std::{collections::HashMap, sync::Arc}, }; @@ -35,12 +34,12 @@ use { /// /// The container maintains a fixed capacity. If the queue is full when pushing /// a new transaction, the lowest priority transaction will be dropped. -pub struct TransactionStateContainer { +pub struct TransactionStateContainer { priority_queue: MinMaxHeap, - id_to_transaction_state: HashMap, + id_to_transaction_state: HashMap>, } -impl TransactionStateContainer { +impl TransactionStateContainer
<P>
{ pub fn with_capacity(capacity: usize) -> Self { Self { priority_queue: MinMaxHeap::with_capacity(capacity), @@ -67,16 +66,13 @@ impl TransactionStateContainer { pub fn get_mut_transaction_state( &mut self, id: &TransactionId, - ) -> Option<&mut TransactionState> { + ) -> Option<&mut TransactionState
<P>
> { self.id_to_transaction_state.get_mut(id) } /// Get reference to `SanitizedTransactionTTL` by id. /// Panics if the transaction does not exist. - pub fn get_transaction_ttl( - &self, - id: &TransactionId, - ) -> Option<&SanitizedTransactionTTL> { + pub fn get_transaction_ttl(&self, id: &TransactionId) -> Option<&SanitizedTransactionTTL> { self.id_to_transaction_state .get(id) .map(|state| state.transaction_ttl()) @@ -88,7 +84,7 @@ impl TransactionStateContainer { &mut self, transaction_id: TransactionId, transaction_ttl: SanitizedTransactionTTL, - packet: Arc, + packet: Arc
<P>
, priority: u64, cost: u64, ) -> bool { @@ -150,8 +146,7 @@ impl TransactionStateContainer { #[cfg(test)] mod tests { use { - super::*, - solana_sdk::{ + super::*, solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, solana_sdk::{ compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, @@ -161,7 +156,7 @@ mod tests { slot_history::Slot, system_instruction, transaction::{SanitizedTransaction, Transaction}, - }, + } }; /// Returns (transaction_ttl, priority, cost) @@ -202,7 +197,10 @@ mod tests { (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) } - fn push_to_container(container: &mut TransactionStateContainer, num: usize) { + fn push_to_container( + container: &mut TransactionStateContainer, + num: usize, + ) { for id in 0..num as u64 { let priority = id; let (transaction_ttl, packet, priority, cost) = test_transaction(priority); From d554cec35f3f35a727d24f896e840dc704bbd5ec Mon Sep 17 00:00:00 2001 From: lewis Date: Mon, 14 Oct 2024 16:20:57 +0800 Subject: [PATCH 6/9] feat: prio-graph crate remove dependency of solana-core --- Cargo.lock | 9 +- prio-graph-scheduler/Cargo.toml | 17 +- .../src/deserializable_packet.rs | 44 --- prio-graph-scheduler/src/lib.rs | 186 +++++++++++- .../src/prio_graph_scheduler.rs | 21 +- .../src/read_write_account_set.rs | 287 ++++++++++++++++++ prio-graph-scheduler/src/transaction_state.rs | 10 +- .../src/transaction_state_container.rs | 8 +- 8 files changed, 507 insertions(+), 75 deletions(-) create mode 100644 prio-graph-scheduler/src/read_write_account_set.rs diff --git a/Cargo.lock b/Cargo.lock index cb5ebbf4ef1ea3..10f8bb02973ee8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7241,20 +7241,27 @@ dependencies = [ "ahash 0.8.10", "arrayvec", "assert_matches", + "bincode", "crossbeam-channel", "itertools 0.12.1", "log", "min-max-heap", "prio-graph", - "solana-core", + "solana-compute-budget", "solana-cost-model", "solana-gossip", "solana-ledger", "solana-measure", "solana-metrics", + "solana-perf", "solana-poh", + "solana-prio-graph-scheduler", "solana-runtime", + "solana-runtime-transaction", + "solana-sanitize", "solana-sdk", + "solana-short-vec", + "solana-svm-transaction", "thiserror", ] diff --git a/prio-graph-scheduler/Cargo.toml b/prio-graph-scheduler/Cargo.toml index 47706ad7f40b1b..1316a486bda9cb 100644 --- a/prio-graph-scheduler/Cargo.toml +++ b/prio-graph-scheduler/Cargo.toml @@ -10,7 +10,6 @@ license.workspace = true edition.workspace = true [dependencies] -solana-core = { workspace = true } solana-sdk = { workspace = true } solana-poh = { workspace = true } solana-metrics = { workspace = true } @@ -31,9 +30,25 @@ min-max-heap = { workspace = true } [dev-dependencies] assert_matches = { workspace = true } +solana-compute-budget = { workspace = true } +solana-perf = { workspace = true } +solana-runtime-transaction = { workspace = true } +solana-sanitize = { workspace = true } +solana-short-vec = { workspace = true } +solana-svm-transaction = { workspace = true } +# let dev-context-only-utils works when running this crate. 
+solana-prio-graph-scheduler = { path = ".", features = [ + "dev-context-only-utils", +] } +solana-sdk = { workspace = true, features = ["dev-context-only-utils"] } + +bincode = { workspace = true } [package.metadata.docs.rs] targets = ["x86_64-unknown-linux-gnu"] [lints] workspace = true + +[features] +dev-context-only-utils = ["solana-runtime/dev-context-only-utils"] diff --git a/prio-graph-scheduler/src/deserializable_packet.rs b/prio-graph-scheduler/src/deserializable_packet.rs index 79c2a834666d41..0f54b0ec6de047 100644 --- a/prio-graph-scheduler/src/deserializable_packet.rs +++ b/prio-graph-scheduler/src/deserializable_packet.rs @@ -1,5 +1,4 @@ use std::collections::HashSet; -use solana_core::banking_stage::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}; use solana_sdk::hash::Hash; use solana_sdk::message::AddressLoader; use solana_sdk::packet::Packet; @@ -37,47 +36,4 @@ pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized { fn compute_unit_price(&self) -> u64; fn compute_unit_limit(&self) -> u64; -} - - -/// TODO: migrate to solana_core -impl DeserializableTxPacket for ImmutableDeserializedPacket { - type DeserializeError = DeserializedPacketError; - - fn from_packet(packet: Packet) -> Result { - ImmutableDeserializedPacket::new(packet) - } - - fn build_sanitized_transaction( - &self, - votes_only: bool, - address_loader: impl AddressLoader, - reserved_account_keys: &HashSet, - ) -> Option { - self.build_sanitized_transaction(votes_only, address_loader, reserved_account_keys) - } - - fn original_packet(&self) -> &Packet { - &self.original_packet - } - - fn transaction(&self) -> &SanitizedVersionedTransaction { - &self.transaction - } - - fn message_hash(&self) -> &Hash { - &self.message_hash - } - - fn is_simple_vote(&self) -> bool { - self.is_simple_vote - } - - fn compute_unit_price(&self) -> u64 { - self.compute_unit_price - } - - fn compute_unit_limit(&self) -> u64 { - u64::from(self.compute_unit_limit) - } } \ No newline at end of file diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs index 215c451f5e33a2..8e4ddb9d76bb45 100644 --- a/prio-graph-scheduler/src/lib.rs +++ b/prio-graph-scheduler/src/lib.rs @@ -1,20 +1,192 @@ //! Solana Priority Graph Scheduler. -pub mod transaction_state; -pub mod scheduler_messages; pub mod id_generator; pub mod in_flight_tracker; -pub mod thread_aware_account_locks; -pub mod transaction_priority_id; pub mod scheduler_error; +pub mod scheduler_messages; pub mod scheduler_metrics; +pub mod thread_aware_account_locks; +pub mod transaction_priority_id; +pub mod transaction_state; // pub mod scheduler_controller; -pub mod transaction_state_container; -pub mod prio_graph_scheduler; pub mod deserializable_packet; +pub mod prio_graph_scheduler; +pub mod transaction_state_container; #[macro_use] extern crate solana_metrics; #[cfg(test)] #[macro_use] -extern crate assert_matches; \ No newline at end of file +extern crate assert_matches; + +/// Consumer will create chunks of transactions from buffer with up to this size. 
+pub const TARGET_NUM_TRANSACTIONS_PER_BATCH: usize = 64; + +mod read_write_account_set; + +#[cfg(test)] +mod tests { + use { + crate::deserializable_packet::DeserializableTxPacket, + solana_compute_budget::compute_budget_limits::ComputeBudgetLimits, + solana_perf::packet::Packet, + solana_runtime_transaction::instructions_processor::process_compute_budget_instructions, + solana_sanitize::SanitizeError, + solana_sdk::{ + hash::Hash, + message::Message, + pubkey::Pubkey, + signature::Signature, + transaction::{ + AddressLoader, SanitizedTransaction, SanitizedVersionedTransaction, + VersionedTransaction, + }, + }, + solana_short_vec::decode_shortu16_len, + solana_svm_transaction::instruction::SVMInstruction, + std::{cmp::Ordering, collections::HashSet, mem::size_of}, + thiserror::Error, + }; + + #[derive(Debug, Error)] + pub enum MockDeserializedPacketError { + #[error("ShortVec Failed to Deserialize")] + // short_vec::decode_shortu16_len() currently returns () on error + ShortVecError(()), + #[error("Deserialization Error: {0}")] + DeserializationError(#[from] bincode::Error), + #[error("overflowed on signature size {0}")] + SignatureOverflowed(usize), + #[error("packet failed sanitization {0}")] + SanitizeError(#[from] SanitizeError), + #[error("transaction failed prioritization")] + PrioritizationFailure, + } + + #[derive(Debug, Eq)] + pub struct MockImmutableDeserializedPacket { + pub original_packet: Packet, + pub transaction: SanitizedVersionedTransaction, + pub message_hash: Hash, + pub is_simple_vote: bool, + pub compute_unit_price: u64, + pub compute_unit_limit: u32, + } + + impl DeserializableTxPacket for MockImmutableDeserializedPacket { + type DeserializeError = MockDeserializedPacketError; + fn from_packet(packet: Packet) -> Result { + let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?; + let sanitized_transaction = + SanitizedVersionedTransaction::try_from(versioned_transaction)?; + let message_bytes = packet_message(&packet)?; + let message_hash = Message::hash_raw_message(message_bytes); + let is_simple_vote = packet.meta().is_simple_vote_tx(); + + // drop transaction if prioritization fails. + let ComputeBudgetLimits { + mut compute_unit_price, + compute_unit_limit, + .. + } = process_compute_budget_instructions( + sanitized_transaction + .get_message() + .program_instructions_iter() + .map(|(pubkey, ix)| (pubkey, SVMInstruction::from(ix))), + ) + .map_err(|_| MockDeserializedPacketError::PrioritizationFailure)?; + + // set compute unit price to zero for vote transactions + if is_simple_vote { + compute_unit_price = 0; + }; + + Ok(Self { + original_packet: packet, + transaction: sanitized_transaction, + message_hash, + is_simple_vote, + compute_unit_price, + compute_unit_limit, + }) + } + + fn original_packet(&self) -> &Packet { + &self.original_packet + } + + fn transaction(&self) -> &SanitizedVersionedTransaction { + &self.transaction + } + + fn message_hash(&self) -> &Hash { + &self.message_hash + } + + fn is_simple_vote(&self) -> bool { + self.is_simple_vote + } + + fn compute_unit_price(&self) -> u64 { + self.compute_unit_price + } + + fn compute_unit_limit(&self) -> u64 { + u64::from(self.compute_unit_limit) + } + + // This function deserializes packets into transactions, computes the blake3 hash of transaction + // messages. 
+ fn build_sanitized_transaction( + &self, + votes_only: bool, + address_loader: impl AddressLoader, + reserved_account_keys: &HashSet, + ) -> Option { + if votes_only && !self.is_simple_vote() { + return None; + } + let tx = SanitizedTransaction::try_new( + self.transaction().clone(), + *self.message_hash(), + self.is_simple_vote(), + address_loader, + reserved_account_keys, + ) + .ok()?; + Some(tx) + } + } + + // PartialEq MUST be consistent with PartialOrd and Ord + impl PartialEq for MockImmutableDeserializedPacket { + fn eq(&self, other: &Self) -> bool { + self.compute_unit_price() == other.compute_unit_price() + } + } + + impl PartialOrd for MockImmutableDeserializedPacket { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for MockImmutableDeserializedPacket { + fn cmp(&self, other: &Self) -> Ordering { + self.compute_unit_price().cmp(&other.compute_unit_price()) + } + } + + /// Read the transaction message from packet data + fn packet_message(packet: &Packet) -> Result<&[u8], MockDeserializedPacketError> { + let (sig_len, sig_size) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(MockDeserializedPacketError::ShortVecError(()))?; + sig_len + .checked_mul(size_of::()) + .and_then(|v| v.checked_add(sig_size)) + .and_then(|msg_start| packet.data(msg_start..)) + .ok_or(MockDeserializedPacketError::SignatureOverflowed(sig_size)) + } +} diff --git a/prio-graph-scheduler/src/prio_graph_scheduler.rs b/prio-graph-scheduler/src/prio_graph_scheduler.rs index 4d66a7901d4976..209e0799951d22 100644 --- a/prio-graph-scheduler/src/prio_graph_scheduler.rs +++ b/prio-graph-scheduler/src/prio_graph_scheduler.rs @@ -2,19 +2,18 @@ use { crate::{ deserializable_packet::DeserializableTxPacket, in_flight_tracker::InFlightTracker, + read_write_account_set::ReadWriteAccountSet, scheduler_error::SchedulerError, scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId}, thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet}, transaction_priority_id::TransactionPriorityId, transaction_state::{SanitizedTransactionTTL, TransactionState}, transaction_state_container::TransactionStateContainer, + TARGET_NUM_TRANSACTIONS_PER_BATCH, }, crossbeam_channel::{Receiver, Sender, TryRecvError}, itertools::izip, prio_graph::{AccessKind, PrioGraph}, - solana_core::banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, read_write_account_set::ReadWriteAccountSet, - }, solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS, solana_measure::measure_us, solana_sdk::{ @@ -592,12 +591,10 @@ fn try_schedule_transaction( mod tests { use { super::*, + crate::tests::MockImmutableDeserializedPacket, + crate::TARGET_NUM_TRANSACTIONS_PER_BATCH, crossbeam_channel::{unbounded, Receiver}, itertools::Itertools, - solana_core::banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, - immutable_deserialized_packet::ImmutableDeserializedPacket, - }, solana_sdk::{ compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction, @@ -621,14 +618,14 @@ mod tests { fn create_test_frame( num_threads: usize, ) -> ( - PrioGraphScheduler, + PrioGraphScheduler, Vec>, Sender, ) { let (consume_work_senders, consume_work_receivers) = (0..num_threads).map(|_| unbounded()).unzip(); let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); - let scheduler = PrioGraphScheduler::::new( + let scheduler 
= PrioGraphScheduler::::new( consume_work_senders, finished_consume_work_receiver, ); @@ -668,9 +665,9 @@ mod tests { u64, ), >, - ) -> TransactionStateContainer { + ) -> TransactionStateContainer { let mut container = - TransactionStateContainer::::with_capacity(10 * 1024); + TransactionStateContainer::::with_capacity(10 * 1024); for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in tx_infos.into_iter().enumerate() { @@ -682,7 +679,7 @@ mod tests { compute_unit_price, ); let packet = Arc::new( - ImmutableDeserializedPacket::new( + MockImmutableDeserializedPacket::from_packet( Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(), ) .unwrap(), diff --git a/prio-graph-scheduler/src/read_write_account_set.rs b/prio-graph-scheduler/src/read_write_account_set.rs new file mode 100644 index 00000000000000..ea5fc0ffa1a13e --- /dev/null +++ b/prio-graph-scheduler/src/read_write_account_set.rs @@ -0,0 +1,287 @@ +use { + ahash::AHashSet, + solana_sdk::{message::SanitizedMessage, pubkey::Pubkey}, +}; + +/// Wrapper struct to accumulate locks for a batch of transactions. +#[derive(Debug, Default)] +pub struct ReadWriteAccountSet { + /// Set of accounts that are locked for read + read_set: AHashSet, + /// Set of accounts that are locked for write + write_set: AHashSet, +} + +impl ReadWriteAccountSet { + /// Returns true if all account locks were available and false otherwise. + pub fn check_locks(&self, message: &SanitizedMessage) -> bool { + message + .account_keys() + .iter() + .enumerate() + .all(|(index, pubkey)| { + if message.is_writable(index) { + self.can_write(pubkey) + } else { + self.can_read(pubkey) + } + }) + } + + /// Add all account locks. + /// Returns true if all account locks were available and false otherwise. + pub fn take_locks(&mut self, message: &SanitizedMessage) -> bool { + message + .account_keys() + .iter() + .enumerate() + .fold(true, |all_available, (index, pubkey)| { + if message.is_writable(index) { + all_available & self.add_write(pubkey) + } else { + all_available & self.add_read(pubkey) + } + }) + } + + /// Clears the read and write sets + #[allow(dead_code)] + pub fn clear(&mut self) { + self.read_set.clear(); + self.write_set.clear(); + } + + /// Check if an account can be read-locked + fn can_read(&self, pubkey: &Pubkey) -> bool { + !self.write_set.contains(pubkey) + } + + /// Check if an account can be write-locked + fn can_write(&self, pubkey: &Pubkey) -> bool { + !self.write_set.contains(pubkey) && !self.read_set.contains(pubkey) + } + + /// Add an account to the read-set. + /// Returns true if the lock was available. + fn add_read(&mut self, pubkey: &Pubkey) -> bool { + let can_read = self.can_read(pubkey); + self.read_set.insert(*pubkey); + + can_read + } + + /// Add an account to the write-set. + /// Returns true if the lock was available. 
+ fn add_write(&mut self, pubkey: &Pubkey) -> bool { + let can_write = self.can_write(pubkey); + self.write_set.insert(*pubkey); + + can_write + } +} + +#[cfg(test)] +mod tests { + use { + super::ReadWriteAccountSet, + solana_ledger::genesis_utils::GenesisConfigInfo, + solana_runtime::{bank::Bank, bank_forks::BankForks, genesis_utils::create_genesis_config}, + solana_sdk::{ + account::AccountSharedData, + address_lookup_table::{ + self, + state::{AddressLookupTable, LookupTableMeta}, + }, + hash::Hash, + message::{ + v0::{self, MessageAddressTableLookup}, + MessageHeader, VersionedMessage, + }, + pubkey::Pubkey, + signature::Keypair, + signer::Signer, + transaction::{MessageHash, SanitizedTransaction, VersionedTransaction}, + }, + std::{ + borrow::Cow, + sync::{Arc, RwLock}, + }, + }; + + fn create_test_versioned_message( + write_keys: &[Pubkey], + read_keys: &[Pubkey], + address_table_lookups: Vec, + ) -> VersionedMessage { + VersionedMessage::V0(v0::Message { + header: MessageHeader { + num_required_signatures: write_keys.len() as u8, + num_readonly_signed_accounts: 0, + num_readonly_unsigned_accounts: read_keys.len() as u8, + }, + recent_blockhash: Hash::default(), + account_keys: write_keys.iter().chain(read_keys.iter()).copied().collect(), + address_table_lookups, + instructions: vec![], + }) + } + + fn create_test_sanitized_transaction( + write_keypair: &Keypair, + read_keys: &[Pubkey], + address_table_lookups: Vec, + bank: &Bank, + ) -> SanitizedTransaction { + let message = create_test_versioned_message( + &[write_keypair.pubkey()], + read_keys, + address_table_lookups, + ); + SanitizedTransaction::try_create( + VersionedTransaction::try_new(message, &[write_keypair]).unwrap(), + MessageHash::Compute, + Some(false), + bank, + bank.get_reserved_account_keys(), + ) + .unwrap() + } + + fn create_test_address_lookup_table( + bank: Arc, + num_addresses: usize, + ) -> (Arc, Pubkey) { + let mut addresses = Vec::with_capacity(num_addresses); + addresses.resize_with(num_addresses, Pubkey::new_unique); + let address_lookup_table = AddressLookupTable { + meta: LookupTableMeta { + authority: None, + ..LookupTableMeta::default() + }, + addresses: Cow::Owned(addresses), + }; + + let address_table_key = Pubkey::new_unique(); + let data = address_lookup_table.serialize_for_tests().unwrap(); + let mut account = + AccountSharedData::new(1, data.len(), &address_lookup_table::program::id()); + account.set_data(data); + bank.store_account(&address_table_key, &account); + + let slot = bank.slot() + 1; + ( + Arc::new(Bank::new_from_parent(bank, &Pubkey::new_unique(), slot)), + address_table_key, + ) + } + + fn create_test_bank() -> (Arc, Arc>) { + let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000); + Bank::new_no_wallclock_throttle_for_tests(&genesis_config) + } + + // Helper function (could potentially use test_case in future). 
+ // conflict_index = 0 means write lock conflict with static key + // conflict_index = 1 means read lock conflict with static key + // conflict_index = 2 means write lock conflict with address table key + // conflict_index = 3 means read lock conflict with address table key + fn test_check_and_take_locks(conflict_index: usize, add_write: bool, expectation: bool) { + let (bank, _bank_forks) = create_test_bank(); + let (bank, table_address) = create_test_address_lookup_table(bank, 2); + let tx = create_test_sanitized_transaction( + &Keypair::new(), + &[Pubkey::new_unique()], + vec![MessageAddressTableLookup { + account_key: table_address, + writable_indexes: vec![0], + readonly_indexes: vec![1], + }], + &bank, + ); + let message = tx.message(); + + let mut account_locks = ReadWriteAccountSet::default(); + + let conflict_key = message.account_keys().get(conflict_index).unwrap(); + if add_write { + account_locks.add_write(conflict_key); + } else { + account_locks.add_read(conflict_key); + } + assert_eq!(expectation, account_locks.check_locks(message)); + assert_eq!(expectation, account_locks.take_locks(message)); + } + + #[test] + fn test_check_and_take_locks_write_write_conflict() { + test_check_and_take_locks(0, true, false); // static key conflict + test_check_and_take_locks(2, true, false); // lookup key conflict + } + + #[test] + fn test_check_and_take_locks_read_write_conflict() { + test_check_and_take_locks(0, false, false); // static key conflict + test_check_and_take_locks(2, false, false); // lookup key conflict + } + + #[test] + fn test_check_and_take_locks_write_read_conflict() { + test_check_and_take_locks(1, true, false); // static key conflict + test_check_and_take_locks(3, true, false); // lookup key conflict + } + + #[test] + fn test_check_and_take_locks_read_read_non_conflict() { + test_check_and_take_locks(1, false, true); // static key conflict + test_check_and_take_locks(3, false, true); // lookup key conflict + } + + #[test] + pub fn test_write_write_conflict() { + let mut account_locks = ReadWriteAccountSet::default(); + let account = Pubkey::new_unique(); + assert!(account_locks.can_write(&account)); + account_locks.add_write(&account); + assert!(!account_locks.can_write(&account)); + } + + #[test] + pub fn test_read_write_conflict() { + let mut account_locks = ReadWriteAccountSet::default(); + let account = Pubkey::new_unique(); + assert!(account_locks.can_read(&account)); + account_locks.add_read(&account); + assert!(!account_locks.can_write(&account)); + assert!(account_locks.can_read(&account)); + } + + #[test] + pub fn test_write_read_conflict() { + let mut account_locks = ReadWriteAccountSet::default(); + let account = Pubkey::new_unique(); + assert!(account_locks.can_write(&account)); + account_locks.add_write(&account); + assert!(!account_locks.can_write(&account)); + assert!(!account_locks.can_read(&account)); + } + + #[test] + pub fn test_read_read_non_conflict() { + let mut account_locks = ReadWriteAccountSet::default(); + let account = Pubkey::new_unique(); + assert!(account_locks.can_read(&account)); + account_locks.add_read(&account); + assert!(account_locks.can_read(&account)); + } + + #[test] + pub fn test_write_write_different_keys() { + let mut account_locks = ReadWriteAccountSet::default(); + let account1 = Pubkey::new_unique(); + let account2 = Pubkey::new_unique(); + assert!(account_locks.can_write(&account1)); + account_locks.add_write(&account1); + assert!(account_locks.can_write(&account2)); + assert!(account_locks.can_read(&account2)); + } +} 
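A minimal usage sketch for `ReadWriteAccountSet` (illustrative only, not part of this patch; `candidate_messages` is a hypothetical slice of sanitized messages):

    use solana_sdk::message::SanitizedMessage;

    /// Greedily select the indexes of transactions whose account locks do not
    /// conflict with anything already accepted into the batch.
    fn pack_non_conflicting_batch(candidate_messages: &[SanitizedMessage]) -> Vec<usize> {
        let mut locks = ReadWriteAccountSet::default();
        let mut batch = Vec::new();
        for (index, message) in candidate_messages.iter().enumerate() {
            // `check_locks` is read-only; `take_locks` records every key it
            // visits even when it returns false, so gate it behind the check.
            if locks.check_locks(message) {
                locks.take_locks(message);
                batch.push(index);
            }
        }
        batch
    }

Note the design choice in `take_locks`: it folds with a non-short-circuiting `&`, so every lock is still recorded once a conflict has been seen. Callers that only want locks taken on success must pre-check with `check_locks`, as in the sketch.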
diff --git a/prio-graph-scheduler/src/transaction_state.rs b/prio-graph-scheduler/src/transaction_state.rs index 56575beeaf79a4..422f4c1f8c6506 100644 --- a/prio-graph-scheduler/src/transaction_state.rs +++ b/prio-graph-scheduler/src/transaction_state.rs @@ -205,17 +205,15 @@ impl TransactionState
<P>
{ #[cfg(test)] mod tests { use { - super::*, - solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, - solana_sdk::{ + super::*, crate::tests::MockImmutableDeserializedPacket, solana_sdk::{ compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, signature::Keypair, signer::Signer, system_instruction, transaction::Transaction, - }, + } }; fn create_transaction_state( compute_unit_price: u64, - ) -> TransactionState { + ) -> TransactionState { let from_keypair = Keypair::new(); let ixs = vec![ system_instruction::transfer( @@ -229,7 +227,7 @@ mod tests { let tx = Transaction::new(&[&from_keypair], message, Hash::default()); let packet = Arc::new( - ImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), + MockImmutableDeserializedPacket::from_packet(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), ); let transaction_ttl = SanitizedTransactionTTL { transaction: SanitizedTransaction::from_transaction_for_tests(tx), diff --git a/prio-graph-scheduler/src/transaction_state_container.rs b/prio-graph-scheduler/src/transaction_state_container.rs index c162034f6f890d..6e6ad444eea977 100644 --- a/prio-graph-scheduler/src/transaction_state_container.rs +++ b/prio-graph-scheduler/src/transaction_state_container.rs @@ -146,7 +146,7 @@ impl TransactionStateContainer
<P>
{ #[cfg(test)] mod tests { use { - super::*, solana_core::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, solana_sdk::{ + super::*, crate::tests::MockImmutableDeserializedPacket, solana_sdk::{ compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, @@ -164,7 +164,7 @@ mod tests { priority: u64, ) -> ( SanitizedTransactionTTL, - Arc, + Arc, u64, u64, ) { @@ -184,7 +184,7 @@ mod tests { Hash::default(), )); let packet = Arc::new( - ImmutableDeserializedPacket::new( + MockImmutableDeserializedPacket::from_packet( Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), ) .unwrap(), @@ -198,7 +198,7 @@ mod tests { } fn push_to_container( - container: &mut TransactionStateContainer, + container: &mut TransactionStateContainer, num: usize, ) { for id in 0..num as u64 { From 4f6d25d28dd7500a554b20aa13de6cb2bb586b89 Mon Sep 17 00:00:00 2001 From: lewis Date: Mon, 14 Oct 2024 17:55:14 +0800 Subject: [PATCH 7/9] feat: migrate logic in bank_stage to use prio-graph crate --- Cargo.lock | 1 + core/Cargo.toml | 1 + core/src/banking_stage.rs | 15 +- core/src/banking_stage/consume_worker.rs | 4 +- .../forward_packet_batches_by_accounts.rs | 1 + core/src/banking_stage/forward_worker.rs | 27 +- .../immutable_deserialized_packet.rs | 20 +- .../banking_stage/latest_unprocessed_votes.rs | 13 +- core/src/banking_stage/packet_deserializer.rs | 9 +- core/src/banking_stage/packet_filter.rs | 1 + core/src/banking_stage/packet_receiver.rs | 1 + .../scheduler_controller.rs | 42 +- core/src/banking_stage/scheduler_messages.rs | 66 -- .../batch_id_generator.rs | 14 - .../in_flight_tracker.rs | 123 --- .../transaction_scheduler/mod.rs | 11 - .../prio_graph_scheduler.rs | 907 ------------------ .../transaction_scheduler/scheduler_error.rs | 9 - .../scheduler_metrics.rs | 408 -------- .../thread_aware_account_locks.rs | 742 -------------- .../transaction_id_generator.rs | 21 - .../transaction_priority_id.rs | 69 -- .../transaction_state.rs | 359 ------- .../transaction_state_container.rs | 261 ----- .../unprocessed_packet_batches.rs | 8 +- .../unprocessed_transaction_storage.rs | 5 +- .../src/deserializable_packet.rs | 2 +- prio-graph-scheduler/src/lib.rs | 2 +- .../src/prio_graph_scheduler.rs | 2 +- prio-graph-scheduler/src/transaction_state.rs | 2 +- .../src/transaction_state_container.rs | 2 +- 31 files changed, 79 insertions(+), 3069 deletions(-) rename core/src/banking_stage/{transaction_scheduler => }/scheduler_controller.rs (98%) delete mode 100644 core/src/banking_stage/scheduler_messages.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/batch_id_generator.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/mod.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/prio_graph_scheduler.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/scheduler_error.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/scheduler_metrics.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/transaction_id_generator.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/transaction_priority_id.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/transaction_state.rs delete mode 100644 core/src/banking_stage/transaction_scheduler/transaction_state_container.rs diff --git a/Cargo.lock b/Cargo.lock 
index 10f8bb02973ee8..6d867885203277 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6430,6 +6430,7 @@ dependencies = [ "solana-net-utils", "solana-perf", "solana-poh", + "solana-prio-graph-scheduler", "solana-program-runtime", "solana-quic-client", "solana-rayon-threadlimit", diff --git a/core/Cargo.toml b/core/Cargo.toml index 4d0797908627f5..4f3da3cacf81d1 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -90,6 +90,7 @@ solana-version = { workspace = true } solana-vote = { workspace = true } solana-vote-program = { workspace = true } solana-wen-restart = { workspace = true } +solana-prio-graph-scheduler = { workspace = true } strum = { workspace = true, features = ["derive"] } strum_macros = { workspace = true } sys-info = { workspace = true } diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index 41e7d95595dd84..eeb76fe4d68024 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -17,12 +17,8 @@ use { }, crate::{ banking_stage::{ - consume_worker::ConsumeWorker, - packet_deserializer::PacketDeserializer, - transaction_scheduler::{ - prio_graph_scheduler::PrioGraphScheduler, - scheduler_controller::SchedulerController, scheduler_error::SchedulerError, - }, + consume_worker::ConsumeWorker, packet_deserializer::PacketDeserializer, + scheduler_controller::SchedulerController, }, banking_trace::BankingPacketReceiver, tracer_packet_stats::TracerPacketStats, @@ -36,6 +32,9 @@ use { solana_measure::measure_us, solana_perf::{data_budget::DataBudget, packet::PACKETS_PER_BATCH}, solana_poh::poh_recorder::{PohRecorder, TransactionRecorder}, + solana_prio_graph_scheduler::{ + prio_graph_scheduler::PrioGraphScheduler, scheduler_error::SchedulerError, + }, solana_runtime::{ bank_forks::BankForks, prioritization_fee_cache::PrioritizationFeeCache, vote_sender_types::ReplayVoteSender, @@ -74,9 +73,7 @@ mod packet_deserializer; mod packet_filter; mod packet_receiver; pub mod read_write_account_set; -#[allow(dead_code)] -mod scheduler_messages; -mod transaction_scheduler; +mod scheduler_controller; // Fixed thread size seems to be fastest on GCP setup pub const NUM_THREADS: u32 = 6; diff --git a/core/src/banking_stage/consume_worker.rs b/core/src/banking_stage/consume_worker.rs index b676168bb04d4d..b89457bfca7bb6 100644 --- a/core/src/banking_stage/consume_worker.rs +++ b/core/src/banking_stage/consume_worker.rs @@ -2,11 +2,11 @@ use { super::{ consumer::{Consumer, ExecuteAndCommitTransactionsOutput, ProcessTransactionBatchOutput}, leader_slot_timing_metrics::LeaderExecuteAndCommitTimings, - scheduler_messages::{ConsumeWork, FinishedConsumeWork}, }, crossbeam_channel::{Receiver, RecvError, SendError, Sender}, solana_measure::measure_us, solana_poh::leader_bank_notifier::LeaderBankNotifier, + solana_prio_graph_scheduler::scheduler_messages::{ConsumeWork, FinishedConsumeWork}, solana_runtime::bank::Bank, solana_sdk::timing::AtomicInterval, solana_svm::transaction_error_metrics::TransactionErrorMetrics, @@ -694,7 +694,6 @@ mod tests { crate::banking_stage::{ committer::Committer, qos_service::QosService, - scheduler_messages::{TransactionBatchId, TransactionId}, tests::{create_slow_genesis_config, sanitize_transactions, simulate_poh}, }, crossbeam_channel::unbounded, @@ -703,6 +702,7 @@ mod tests { get_tmp_ledger_path_auto_delete, leader_schedule_cache::LeaderScheduleCache, }, solana_poh::poh_recorder::{PohRecorder, WorkingBankEntry}, + solana_prio_graph_scheduler::scheduler_messages::{TransactionBatchId, TransactionId}, solana_runtime::{ bank_forks::BankForks, 
prioritization_fee_cache::PrioritizationFeeCache, vote_sender_types::ReplayVoteReceiver, diff --git a/core/src/banking_stage/forward_packet_batches_by_accounts.rs b/core/src/banking_stage/forward_packet_batches_by_accounts.rs index 1d86cfb9753b1b..acee6171f6de79 100644 --- a/core/src/banking_stage/forward_packet_batches_by_accounts.rs +++ b/core/src/banking_stage/forward_packet_batches_by_accounts.rs @@ -8,6 +8,7 @@ use { }, solana_feature_set::FeatureSet, solana_perf::packet::Packet, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_sdk::transaction::SanitizedTransaction, std::sync::Arc, }; diff --git a/core/src/banking_stage/forward_worker.rs b/core/src/banking_stage/forward_worker.rs index 61cf311f0a8cf8..7d7c9d0d500c4a 100644 --- a/core/src/banking_stage/forward_worker.rs +++ b/core/src/banking_stage/forward_worker.rs @@ -1,36 +1,40 @@ use { super::{ - forwarder::Forwarder, - scheduler_messages::{FinishedForwardWork, ForwardWork}, + forwarder::Forwarder, immutable_deserialized_packet::ImmutableDeserializedPacket, ForwardOption, }, crate::banking_stage::LikeClusterInfo, crossbeam_channel::{Receiver, RecvError, SendError, Sender}, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, + solana_prio_graph_scheduler::scheduler_messages::{FinishedForwardWork, ForwardWork}, thiserror::Error, }; +type DefaultForwardWork = ForwardWork; +type DefaultFinishedForwardWork = FinishedForwardWork; + #[derive(Debug, Error)] pub enum ForwardWorkerError { #[error("Failed to receive work from scheduler: {0}")] Recv(#[from] RecvError), #[error("Failed to send finalized forward work to scheduler: {0}")] - Send(#[from] SendError), + Send(#[from] SendError), } pub(crate) struct ForwardWorker { - forward_receiver: Receiver, + forward_receiver: Receiver, forward_option: ForwardOption, forwarder: Forwarder, - forwarded_sender: Sender, + forwarded_sender: Sender, } #[allow(dead_code)] impl ForwardWorker { pub fn new( - forward_receiver: Receiver, + forward_receiver: Receiver, forward_option: ForwardOption, forwarder: Forwarder, - forwarded_sender: Sender, + forwarded_sender: Sender, ) -> Self { Self { forward_receiver, @@ -47,7 +51,7 @@ impl ForwardWorker { } } - fn forward_loop(&self, work: ForwardWork) -> Result<(), ForwardWorkerError> { + fn forward_loop(&self, work: DefaultForwardWork) -> Result<(), ForwardWorkerError> { for work in try_drain_iter(work, &self.forward_receiver) { let (res, _num_packets, _forward_us, _leader_pubkey) = self.forwarder.forward_packets( &self.forward_option, @@ -64,7 +68,7 @@ impl ForwardWorker { Ok(()) } - fn failed_forward_drain(&self, work: ForwardWork) -> Result<(), ForwardWorkerError> { + fn failed_forward_drain(&self, work: DefaultForwardWork) -> Result<(), ForwardWorkerError> { for work in try_drain_iter(work, &self.forward_receiver) { self.forwarded_sender.send(FinishedForwardWork { work, @@ -98,6 +102,7 @@ mod tests { }, solana_perf::packet::to_packet_batches, solana_poh::poh_recorder::{PohRecorder, WorkingBankEntry}, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_runtime::bank::Bank, solana_sdk::{ genesis_config::GenesisConfig, poh_config::PohConfig, pubkey::Pubkey, @@ -119,8 +124,8 @@ mod tests { _entry_receiver: Receiver, _poh_simulator: JoinHandle<()>, - forward_sender: Sender, - forwarded_receiver: Receiver, + forward_sender: Sender, + forwarded_receiver: Receiver, } fn setup_test_frame() -> (TestFrame, ForwardWorker>) { diff --git 
a/core/src/banking_stage/immutable_deserialized_packet.rs b/core/src/banking_stage/immutable_deserialized_packet.rs index 7bb259494e4c31..2a29e6f3b10ac6 100644 --- a/core/src/banking_stage/immutable_deserialized_packet.rs +++ b/core/src/banking_stage/immutable_deserialized_packet.rs @@ -2,6 +2,7 @@ use { super::packet_filter::PacketFilterFailure, solana_compute_budget::compute_budget_limits::ComputeBudgetLimits, solana_perf::packet::Packet, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_runtime_transaction::instructions_processor::process_compute_budget_instructions, solana_sanitize::SanitizeError, solana_sdk::{ @@ -49,8 +50,9 @@ pub struct ImmutableDeserializedPacket { pub compute_unit_limit: u32, } -impl ImmutableDeserializedPacket { - pub fn new(packet: Packet) -> Result { +impl DeserializableTxPacket for ImmutableDeserializedPacket { + type DeserializeError = DeserializedPacketError; + fn new(packet: Packet) -> Result { let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?; let sanitized_transaction = SanitizedVersionedTransaction::try_from(versioned_transaction)?; let message_bytes = packet_message(&packet)?; @@ -85,33 +87,33 @@ impl ImmutableDeserializedPacket { }) } - pub fn original_packet(&self) -> &Packet { + fn original_packet(&self) -> &Packet { &self.original_packet } - pub fn transaction(&self) -> &SanitizedVersionedTransaction { + fn transaction(&self) -> &SanitizedVersionedTransaction { &self.transaction } - pub fn message_hash(&self) -> &Hash { + fn message_hash(&self) -> &Hash { &self.message_hash } - pub fn is_simple_vote(&self) -> bool { + fn is_simple_vote(&self) -> bool { self.is_simple_vote } - pub fn compute_unit_price(&self) -> u64 { + fn compute_unit_price(&self) -> u64 { self.compute_unit_price } - pub fn compute_unit_limit(&self) -> u64 { + fn compute_unit_limit(&self) -> u64 { u64::from(self.compute_unit_limit) } // This function deserializes packets into transactions, computes the blake3 hash of transaction // messages. 
- pub fn build_sanitized_transaction( + fn build_sanitized_transaction( &self, votes_only: bool, address_loader: impl AddressLoader, diff --git a/core/src/banking_stage/latest_unprocessed_votes.rs b/core/src/banking_stage/latest_unprocessed_votes.rs index bb97142bda5e81..63e3893893ad95 100644 --- a/core/src/banking_stage/latest_unprocessed_votes.rs +++ b/core/src/banking_stage/latest_unprocessed_votes.rs @@ -2,12 +2,7 @@ use { super::{ forward_packet_batches_by_accounts::ForwardPacketBatchesByAccounts, immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}, - }, - itertools::Itertools, - rand::{thread_rng, Rng}, - solana_perf::packet::Packet, - solana_runtime::{bank::Bank, epoch_stakes::EpochStakes}, - solana_sdk::{ + }, itertools::Itertools, rand::{thread_rng, Rng}, solana_perf::packet::Packet, solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_runtime::{bank::Bank, epoch_stakes::EpochStakes}, solana_sdk::{ account::from_account, clock::{Slot, UnixTimestamp}, feature_set::{self}, @@ -16,9 +11,7 @@ use { pubkey::Pubkey, slot_hashes::SlotHashes, sysvar, - }, - solana_vote_program::vote_instruction::VoteInstruction, - std::{ + }, solana_vote_program::vote_instruction::VoteInstruction, std::{ cmp, collections::HashMap, ops::DerefMut, @@ -26,7 +19,7 @@ use { atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, Arc, RwLock, }, - }, + } }; #[derive(PartialEq, Eq, Debug, Copy, Clone)] diff --git a/core/src/banking_stage/packet_deserializer.rs b/core/src/banking_stage/packet_deserializer.rs index 78fab3718252f4..f7acc13f56bf47 100644 --- a/core/src/banking_stage/packet_deserializer.rs +++ b/core/src/banking_stage/packet_deserializer.rs @@ -4,15 +4,10 @@ use { super::{ immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}, packet_filter::PacketFilterFailure, - }, - crate::{ + }, crate::{ banking_trace::{BankingPacketBatch, BankingPacketReceiver}, sigverify::SigverifyTracerPacketStats, - }, - crossbeam_channel::RecvTimeoutError, - solana_perf::packet::PacketBatch, - solana_sdk::saturating_add_assign, - std::time::{Duration, Instant}, + }, crossbeam_channel::RecvTimeoutError, solana_perf::packet::PacketBatch, solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_sdk::saturating_add_assign, std::time::{Duration, Instant} }; /// Results from deserializing packet batches. 
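The import changes above are the recurring pattern of this patch: banking-stage modules now program against the DeserializableTxPacket trait from the new crate instead of the concrete ImmutableDeserializedPacket. A minimal sketch of the kind of generic helper this enables, using only trait methods visible in this patch (the helper itself is illustrative, not part of the change):

use solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket;
use std::sync::Arc;

// Works for ImmutableDeserializedPacket, the test-only mock, or any future
// packet type, because it only touches trait methods.
fn retain_priced_non_votes<P: DeserializableTxPacket>(packets: &mut Vec<Arc<P>>) {
    // Drop simple votes; they take a separate path in the banking stage.
    packets.retain(|p| !p.is_simple_vote());
    // Highest compute-unit price first, matching scheduler priority order.
    packets.sort_by_key(|p| std::cmp::Reverse(p.compute_unit_price()));
}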
diff --git a/core/src/banking_stage/packet_filter.rs b/core/src/banking_stage/packet_filter.rs index 4c38d70762e35e..7b480f3997ba56 100644 --- a/core/src/banking_stage/packet_filter.rs +++ b/core/src/banking_stage/packet_filter.rs @@ -3,6 +3,7 @@ use { solana_builtins_default_costs::BUILTIN_INSTRUCTION_COSTS, solana_sdk::{ed25519_program, saturating_add_assign, secp256k1_program}, thiserror::Error, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, }; #[derive(Debug, Error, PartialEq)] diff --git a/core/src/banking_stage/packet_receiver.rs b/core/src/banking_stage/packet_receiver.rs index 6b77d103c69670..9288d1e4ed4b17 100644 --- a/core/src/banking_stage/packet_receiver.rs +++ b/core/src/banking_stage/packet_receiver.rs @@ -9,6 +9,7 @@ use { crate::{banking_trace::BankingPacketReceiver, tracer_packet_stats::TracerPacketStats}, crossbeam_channel::RecvTimeoutError, solana_measure::{measure::Measure, measure_us}, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_sdk::{saturating_add_assign, timing::timestamp}, std::{sync::atomic::Ordering, time::Duration}, }; diff --git a/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs b/core/src/banking_stage/scheduler_controller.rs similarity index 98% rename from core/src/banking_stage/transaction_scheduler/scheduler_controller.rs rename to core/src/banking_stage/scheduler_controller.rs index 995b1a5782702b..cd5651d069fc23 100644 --- a/core/src/banking_stage/transaction_scheduler/scheduler_controller.rs +++ b/core/src/banking_stage/scheduler_controller.rs @@ -2,16 +2,6 @@ //! use { - super::{ - prio_graph_scheduler::PrioGraphScheduler, - scheduler_error::SchedulerError, - scheduler_metrics::{ - SchedulerCountMetrics, SchedulerLeaderDetectionMetrics, SchedulerTimingMetrics, - }, - transaction_id_generator::TransactionIdGenerator, - transaction_state::SanitizedTransactionTTL, - transaction_state_container::TransactionStateContainer, - }, crate::banking_stage::{ consume_worker::ConsumeWorkerMetrics, consumer::Consumer, @@ -26,6 +16,17 @@ use { solana_accounts_db::account_locks::validate_account_locks, solana_cost_model::cost_model::CostModel, solana_measure::measure_us, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, + solana_prio_graph_scheduler::{ + id_generator::IdGenerator, + prio_graph_scheduler::PrioGraphScheduler, + scheduler_error::SchedulerError, + scheduler_metrics::{ + SchedulerCountMetrics, SchedulerLeaderDetectionMetrics, SchedulerTimingMetrics, + }, + transaction_state::SanitizedTransactionTTL, + transaction_state_container::TransactionStateContainer, + }, solana_runtime::{bank::Bank, bank_forks::BankForks}, solana_runtime_transaction::instructions_processor::process_compute_budget_instructions, solana_sdk::{ @@ -51,12 +52,12 @@ pub(crate) struct SchedulerController { packet_receiver: PacketDeserializer, bank_forks: Arc>, /// Generates unique IDs for incoming transactions. - transaction_id_generator: TransactionIdGenerator, + transaction_id_generator: IdGenerator, /// Container for transaction state. /// Shared resource between `packet_receiver` and `scheduler`. - container: TransactionStateContainer, + container: TransactionStateContainer, /// State for scheduling and communicating with worker threads. - scheduler: PrioGraphScheduler, + scheduler: PrioGraphScheduler, /// Metrics tracking time for leader bank detection. 
leader_detection_metrics: SchedulerLeaderDetectionMetrics, /// Metrics tracking counts on transactions in different states @@ -76,7 +77,7 @@ impl SchedulerController { decision_maker: DecisionMaker, packet_deserializer: PacketDeserializer, bank_forks: Arc>, - scheduler: PrioGraphScheduler, + scheduler: PrioGraphScheduler, worker_metrics: Vec>, forwarder: Option>, ) -> Self { @@ -84,8 +85,10 @@ impl SchedulerController { decision_maker, packet_receiver: packet_deserializer, bank_forks, - transaction_id_generator: TransactionIdGenerator::default(), - container: TransactionStateContainer::with_capacity(TOTAL_BUFFERED_PACKETS), + transaction_id_generator: IdGenerator::default(), + container: TransactionStateContainer::::with_capacity( + TOTAL_BUFFERED_PACKETS, + ), scheduler, leader_detection_metrics: SchedulerLeaderDetectionMetrics::default(), count_metrics: SchedulerCountMetrics::default(), @@ -661,9 +664,7 @@ mod tests { super::*, crate::{ banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, - scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId}, - tests::create_slow_genesis_config, + consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, tests::create_slow_genesis_config, }, banking_trace::BankingPacketBatch, sigverify::SigverifyTracerPacketStats, @@ -677,6 +678,9 @@ mod tests { }, solana_perf::packet::{to_packet_batches, PacketBatch, NUM_PACKETS}, solana_poh::poh_recorder::{PohRecorder, Record, WorkingBankEntry}, + solana_prio_graph_scheduler::scheduler_messages::{ + ConsumeWork, FinishedConsumeWork, TransactionBatchId, + }, solana_runtime::bank::Bank, solana_sdk::{ compute_budget::ComputeBudgetInstruction, fee_calculator::FeeRateGovernor, hash::Hash, diff --git a/core/src/banking_stage/scheduler_messages.rs b/core/src/banking_stage/scheduler_messages.rs deleted file mode 100644 index ee5c4ebeef9738..00000000000000 --- a/core/src/banking_stage/scheduler_messages.rs +++ /dev/null @@ -1,66 +0,0 @@ -use { - super::immutable_deserialized_packet::ImmutableDeserializedPacket, - solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, - std::{fmt::Display, sync::Arc}, -}; - -/// A unique identifier for a transaction batch. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct TransactionBatchId(u64); - -impl TransactionBatchId { - pub fn new(index: u64) -> Self { - Self(index) - } -} - -impl Display for TransactionBatchId { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -/// A unique identifier for a transaction. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] -pub struct TransactionId(u64); - -impl TransactionId { - pub fn new(index: u64) -> Self { - Self(index) - } -} - -impl Display for TransactionId { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -/// Message: [Scheduler -> Worker] -/// Transactions to be consumed (i.e. executed, recorded, and committed) -pub struct ConsumeWork { - pub batch_id: TransactionBatchId, - pub ids: Vec, - pub transactions: Vec, - pub max_age_slots: Vec, -} - -/// Message: [Scheduler -> Worker] -/// Transactions to be forwarded to the next leader(s) -pub struct ForwardWork { - pub packets: Vec>, -} - -/// Message: [Worker -> Scheduler] -/// Processed transactions. -pub struct FinishedConsumeWork { - pub work: ConsumeWork, - pub retryable_indexes: Vec, -} - -/// Message: [Worker -> Scheduler] -/// Forwarded transactions. 
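// These message types move into solana_prio_graph_scheduler::scheduler_messages
// (see the imports added above). A minimal round-trip sketch with the shapes
// shown in this deleted module, channel wiring assumed as elsewhere in the patch:
//
//     let (work_tx, work_rx) = crossbeam_channel::unbounded();
//     let (done_tx, done_rx) = crossbeam_channel::unbounded();
//     work_tx.send(ConsumeWork {
//         batch_id: TransactionBatchId::new(0),
//         ids: vec![],
//         transactions: vec![],
//         max_age_slots: vec![],
//     }).unwrap();
//     // A worker executes the batch, then reports which indexes to retry:
//     let work = work_rx.recv().unwrap();
//     done_tx.send(FinishedConsumeWork { work, retryable_indexes: vec![] }).unwrap();
//     assert!(done_rx.recv().unwrap().retryable_indexes.is_empty());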
-pub struct FinishedForwardWork { - pub work: ForwardWork, - pub successful: bool, -} diff --git a/core/src/banking_stage/transaction_scheduler/batch_id_generator.rs b/core/src/banking_stage/transaction_scheduler/batch_id_generator.rs deleted file mode 100644 index 6effc80f8537b4..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/batch_id_generator.rs +++ /dev/null @@ -1,14 +0,0 @@ -use crate::banking_stage::scheduler_messages::TransactionBatchId; - -#[derive(Default)] -pub struct BatchIdGenerator { - next_id: u64, -} - -impl BatchIdGenerator { - pub fn next(&mut self) -> TransactionBatchId { - let id = self.next_id; - self.next_id = self.next_id.wrapping_sub(1); - TransactionBatchId::new(id) - } -} diff --git a/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs b/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs deleted file mode 100644 index 243f14c66920a0..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/in_flight_tracker.rs +++ /dev/null @@ -1,123 +0,0 @@ -use { - super::{batch_id_generator::BatchIdGenerator, thread_aware_account_locks::ThreadId}, - crate::banking_stage::scheduler_messages::TransactionBatchId, - std::collections::HashMap, -}; - -/// Tracks the number of transactions that are in flight for each thread. -pub struct InFlightTracker { - num_in_flight_per_thread: Vec, - cus_in_flight_per_thread: Vec, - batches: HashMap, - batch_id_generator: BatchIdGenerator, -} - -struct BatchEntry { - thread_id: ThreadId, - num_transactions: usize, - total_cus: u64, -} - -impl InFlightTracker { - pub fn new(num_threads: usize) -> Self { - Self { - num_in_flight_per_thread: vec![0; num_threads], - cus_in_flight_per_thread: vec![0; num_threads], - batches: HashMap::new(), - batch_id_generator: BatchIdGenerator::default(), - } - } - - /// Returns the number of transactions that are in flight for each thread. - pub fn num_in_flight_per_thread(&self) -> &[usize] { - &self.num_in_flight_per_thread - } - - /// Returns the number of cus that are in flight for each thread. - pub fn cus_in_flight_per_thread(&self) -> &[u64] { - &self.cus_in_flight_per_thread - } - - /// Tracks number of transactions and CUs in-flight for the `thread_id`. - /// Returns a `TransactionBatchId` that can be used to stop tracking the batch - /// when it is complete. - pub fn track_batch( - &mut self, - num_transactions: usize, - total_cus: u64, - thread_id: ThreadId, - ) -> TransactionBatchId { - let batch_id = self.batch_id_generator.next(); - self.num_in_flight_per_thread[thread_id] += num_transactions; - self.cus_in_flight_per_thread[thread_id] += total_cus; - self.batches.insert( - batch_id, - BatchEntry { - thread_id, - num_transactions, - total_cus, - }, - ); - - batch_id - } - - /// Stop tracking the batch with given `batch_id`. - /// Removes the number of transactions for the scheduled thread. - /// Returns the thread id that the batch was scheduled on. - /// - /// # Panics - /// Panics if the batch id does not exist in the tracker. 
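// A minimal usage sketch, mirroring the unit tests below:
//
//     let mut tracker = InFlightTracker::new(2);
//     let batch_id = tracker.track_batch(2, 10_000, 0); // 2 txs, 10k CUs on thread 0
//     assert_eq!(tracker.num_in_flight_per_thread(), &[2, 0]);
//     assert_eq!(tracker.cus_in_flight_per_thread(), &[10_000, 0]);
//     assert_eq!(tracker.complete_batch(batch_id), 0); // frees thread 0's counts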
- pub fn complete_batch(&mut self, batch_id: TransactionBatchId) -> ThreadId { - let Some(BatchEntry { - thread_id, - num_transactions, - total_cus, - }) = self.batches.remove(&batch_id) - else { - panic!("batch id {batch_id} is not being tracked"); - }; - self.num_in_flight_per_thread[thread_id] -= num_transactions; - self.cus_in_flight_per_thread[thread_id] -= total_cus; - - thread_id - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - #[should_panic(expected = "is not being tracked")] - fn test_in_flight_tracker_untracked_batch() { - let mut in_flight_tracker = InFlightTracker::new(2); - in_flight_tracker.complete_batch(TransactionBatchId::new(5)); - } - - #[test] - fn test_in_flight_tracker() { - let mut in_flight_tracker = InFlightTracker::new(2); - - // Add a batch with 2 transactions, 10 kCUs to thread 0. - let batch_id_0 = in_flight_tracker.track_batch(2, 10_000, 0); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 0]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[10_000, 0]); - - // Add a batch with 1 transaction, 15 kCUs to thread 1. - let batch_id_1 = in_flight_tracker.track_batch(1, 15_000, 1); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[2, 1]); - assert_eq!( - in_flight_tracker.cus_in_flight_per_thread(), - &[10_000, 15_000] - ); - - in_flight_tracker.complete_batch(batch_id_0); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 1]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 15_000]); - - in_flight_tracker.complete_batch(batch_id_1); - assert_eq!(in_flight_tracker.num_in_flight_per_thread(), &[0, 0]); - assert_eq!(in_flight_tracker.cus_in_flight_per_thread(), &[0, 0]); - } -} diff --git a/core/src/banking_stage/transaction_scheduler/mod.rs b/core/src/banking_stage/transaction_scheduler/mod.rs deleted file mode 100644 index 17991b762eb104..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod batch_id_generator; -mod in_flight_tracker; -pub(crate) mod prio_graph_scheduler; -pub(crate) mod scheduler_controller; -pub(crate) mod scheduler_error; -mod scheduler_metrics; -mod thread_aware_account_locks; -mod transaction_id_generator; -mod transaction_priority_id; -mod transaction_state; -mod transaction_state_container; diff --git a/core/src/banking_stage/transaction_scheduler/prio_graph_scheduler.rs b/core/src/banking_stage/transaction_scheduler/prio_graph_scheduler.rs deleted file mode 100644 index 59ce92173ed26e..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/prio_graph_scheduler.rs +++ /dev/null @@ -1,907 +0,0 @@ -use { - super::{ - in_flight_tracker::InFlightTracker, - scheduler_error::SchedulerError, - thread_aware_account_locks::{ThreadAwareAccountLocks, ThreadId, ThreadSet}, - transaction_state::SanitizedTransactionTTL, - transaction_state_container::TransactionStateContainer, - }, - crate::banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, - read_write_account_set::ReadWriteAccountSet, - scheduler_messages::{ConsumeWork, FinishedConsumeWork, TransactionBatchId, TransactionId}, - transaction_scheduler::{ - transaction_priority_id::TransactionPriorityId, transaction_state::TransactionState, - }, - }, - crossbeam_channel::{Receiver, Sender, TryRecvError}, - itertools::izip, - prio_graph::{AccessKind, PrioGraph}, - solana_cost_model::block_cost_limits::MAX_BLOCK_UNITS, - solana_measure::measure_us, - solana_sdk::{ - pubkey::Pubkey, saturating_add_assign, slot_history::Slot, - 
transaction::SanitizedTransaction, - }, -}; - -pub(crate) struct PrioGraphScheduler { - in_flight_tracker: InFlightTracker, - account_locks: ThreadAwareAccountLocks, - consume_work_senders: Vec>, - finished_consume_work_receiver: Receiver, - look_ahead_window_size: usize, -} - -impl PrioGraphScheduler { - pub(crate) fn new( - consume_work_senders: Vec>, - finished_consume_work_receiver: Receiver, - ) -> Self { - let num_threads = consume_work_senders.len(); - Self { - in_flight_tracker: InFlightTracker::new(num_threads), - account_locks: ThreadAwareAccountLocks::new(num_threads), - consume_work_senders, - finished_consume_work_receiver, - look_ahead_window_size: 2048, - } - } - - /// Schedule transactions from the given `TransactionStateContainer` to be - /// consumed by the worker threads. Returns summary of scheduling, or an - /// error. - /// `pre_graph_filter` is used to filter out transactions that should be - /// skipped and dropped before insertion to the prio-graph. This fn should - /// set `false` for transactions that should be dropped, and `true` - /// otherwise. - /// `pre_lock_filter` is used to filter out transactions after they have - /// made it to the top of the prio-graph, and immediately before locks are - /// checked and taken. This fn should return `true` for transactions that - /// should be scheduled, and `false` otherwise. - /// - /// Uses a `PrioGraph` to perform look-ahead during the scheduling of transactions. - /// This, combined with internal tracking of threads' in-flight transactions, allows - /// for load-balancing while prioritizing scheduling transactions onto threads that will - /// not cause conflicts in the near future. - pub(crate) fn schedule( - &mut self, - container: &mut TransactionStateContainer, - pre_graph_filter: impl Fn(&[&SanitizedTransaction], &mut [bool]), - pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, - ) -> Result { - let num_threads = self.consume_work_senders.len(); - let max_cu_per_thread = MAX_BLOCK_UNITS / num_threads as u64; - - let mut schedulable_threads = ThreadSet::any(num_threads); - for thread_id in 0..num_threads { - if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] >= max_cu_per_thread { - schedulable_threads.remove(thread_id); - } - } - if schedulable_threads.is_empty() { - return Ok(SchedulingSummary { - num_scheduled: 0, - num_unschedulable: 0, - num_filtered_out: 0, - filter_time_us: 0, - }); - } - - let mut batches = Batches::new(num_threads); - // Some transactions may be unschedulable due to multi-thread conflicts. - // These transactions cannot be scheduled until some conflicting work is completed. - // However, the scheduler should not allow other transactions that conflict with - // these transactions to be scheduled before them. - let mut unschedulable_ids = Vec::new(); - let mut blocking_locks = ReadWriteAccountSet::default(); - let mut prio_graph = PrioGraph::new(|id: &TransactionPriorityId, _graph_node| *id); - - // Track metrics on filter. 
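// The two filter parameters documented on schedule() above have exactly the
// shapes used by the pass-through test filters at the bottom of this file:
//
//     let pre_graph_filter = |_txs: &[&SanitizedTransaction], ok: &mut [bool]| ok.fill(true);
//     let pre_lock_filter = |_tx: &SanitizedTransaction| true;
//     let summary = scheduler.schedule(&mut container, pre_graph_filter, pre_lock_filter)?;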
- let mut num_filtered_out: usize = 0; - let mut total_filter_time_us: u64 = 0; - - let mut window_budget = self.look_ahead_window_size; - let mut chunked_pops = |container: &mut TransactionStateContainer, - prio_graph: &mut PrioGraph<_, _, _, _>, - window_budget: &mut usize| { - while *window_budget > 0 { - const MAX_FILTER_CHUNK_SIZE: usize = 128; - let mut filter_array = [true; MAX_FILTER_CHUNK_SIZE]; - let mut ids = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); - let mut txs = Vec::with_capacity(MAX_FILTER_CHUNK_SIZE); - - let chunk_size = (*window_budget).min(MAX_FILTER_CHUNK_SIZE); - for _ in 0..chunk_size { - if let Some(id) = container.pop() { - ids.push(id); - } else { - break; - } - } - *window_budget = window_budget.saturating_sub(chunk_size); - - ids.iter().for_each(|id| { - let transaction = container.get_transaction_ttl(&id.id).unwrap(); - txs.push(&transaction.transaction); - }); - - let (_, filter_us) = - measure_us!(pre_graph_filter(&txs, &mut filter_array[..chunk_size])); - saturating_add_assign!(total_filter_time_us, filter_us); - - for (id, filter_result) in ids.iter().zip(&filter_array[..chunk_size]) { - if *filter_result { - let transaction = container.get_transaction_ttl(&id.id).unwrap(); - prio_graph.insert_transaction( - *id, - Self::get_transaction_account_access(transaction), - ); - } else { - saturating_add_assign!(num_filtered_out, 1); - container.remove_by_id(&id.id); - } - } - - if ids.len() != chunk_size { - break; - } - } - }; - - // Create the initial look-ahead window. - // Check transactions against filter, remove from container if it fails. - chunked_pops(container, &mut prio_graph, &mut window_budget); - - let mut unblock_this_batch = - Vec::with_capacity(self.consume_work_senders.len() * TARGET_NUM_TRANSACTIONS_PER_BATCH); - const MAX_TRANSACTIONS_PER_SCHEDULING_PASS: usize = 100_000; - let mut num_scheduled: usize = 0; - let mut num_sent: usize = 0; - let mut num_unschedulable: usize = 0; - while num_scheduled < MAX_TRANSACTIONS_PER_SCHEDULING_PASS { - // If nothing is in the main-queue of the `PrioGraph` then there's nothing left to schedule. - if prio_graph.is_empty() { - break; - } - - while let Some(id) = prio_graph.pop() { - unblock_this_batch.push(id); - - // Should always be in the container, during initial testing phase panic. - // Later, we can replace with a continue in case this does happen. 
- let Some(transaction_state) = container.get_mut_transaction_state(&id.id) else { - panic!("transaction state must exist") - }; - - let maybe_schedule_info = try_schedule_transaction( - transaction_state, - &pre_lock_filter, - &mut blocking_locks, - &mut self.account_locks, - num_threads, - |thread_set| { - Self::select_thread( - thread_set, - &batches.total_cus, - self.in_flight_tracker.cus_in_flight_per_thread(), - &batches.transactions, - self.in_flight_tracker.num_in_flight_per_thread(), - ) - }, - ); - - match maybe_schedule_info { - Err(TransactionSchedulingError::Filtered) => { - container.remove_by_id(&id.id); - } - Err(TransactionSchedulingError::UnschedulableConflicts) => { - unschedulable_ids.push(id); - saturating_add_assign!(num_unschedulable, 1); - } - Ok(TransactionSchedulingInfo { - thread_id, - transaction, - max_age_slot, - cost, - }) => { - saturating_add_assign!(num_scheduled, 1); - batches.transactions[thread_id].push(transaction); - batches.ids[thread_id].push(id.id); - batches.max_age_slots[thread_id].push(max_age_slot); - saturating_add_assign!(batches.total_cus[thread_id], cost); - - // If target batch size is reached, send only this batch. - if batches.ids[thread_id].len() >= TARGET_NUM_TRANSACTIONS_PER_BATCH { - saturating_add_assign!( - num_sent, - self.send_batch(&mut batches, thread_id)? - ); - } - - // if the thread is at max_cu_per_thread, remove it from the schedulable threads - // if there are no more schedulable threads, stop scheduling. - if self.in_flight_tracker.cus_in_flight_per_thread()[thread_id] - + batches.total_cus[thread_id] - >= max_cu_per_thread - { - schedulable_threads.remove(thread_id); - if schedulable_threads.is_empty() { - break; - } - } - - if num_scheduled >= MAX_TRANSACTIONS_PER_SCHEDULING_PASS { - break; - } - } - } - } - - // Send all non-empty batches - saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); - - // Refresh window budget and do chunked pops - saturating_add_assign!(window_budget, unblock_this_batch.len()); - chunked_pops(container, &mut prio_graph, &mut window_budget); - - // Unblock all transactions that were blocked by the transactions that were just sent. - for id in unblock_this_batch.drain(..) { - prio_graph.unblock(&id); - } - } - - // Send batches for any remaining transactions - saturating_add_assign!(num_sent, self.send_batches(&mut batches)?); - - // Push unschedulable ids back into the container - for id in unschedulable_ids { - container.push_id_into_queue(id); - } - - // Push remaining transactions back into the container - while let Some((id, _)) = prio_graph.pop_and_unblock() { - container.push_id_into_queue(id); - } - - assert_eq!( - num_scheduled, num_sent, - "number of scheduled and sent transactions must match" - ); - - Ok(SchedulingSummary { - num_scheduled, - num_unschedulable, - num_filtered_out, - filter_time_us: total_filter_time_us, - }) - } - - /// Receive completed batches of transactions without blocking. - /// Returns (num_transactions, num_retryable_transactions) on success. 
- pub fn receive_completed( - &mut self, - container: &mut TransactionStateContainer, - ) -> Result<(usize, usize), SchedulerError> { - let mut total_num_transactions: usize = 0; - let mut total_num_retryable: usize = 0; - loop { - let (num_transactions, num_retryable) = self.try_receive_completed(container)?; - if num_transactions == 0 { - break; - } - saturating_add_assign!(total_num_transactions, num_transactions); - saturating_add_assign!(total_num_retryable, num_retryable); - } - Ok((total_num_transactions, total_num_retryable)) - } - - /// Receive completed batches of transactions. - /// Returns `Ok((num_transactions, num_retryable))` if a batch was received, `Ok((0, 0))` if no batch was received. - fn try_receive_completed( - &mut self, - container: &mut TransactionStateContainer, - ) -> Result<(usize, usize), SchedulerError> { - match self.finished_consume_work_receiver.try_recv() { - Ok(FinishedConsumeWork { - work: - ConsumeWork { - batch_id, - ids, - transactions, - max_age_slots, - }, - retryable_indexes, - }) => { - let num_transactions = ids.len(); - let num_retryable = retryable_indexes.len(); - - // Free the locks - self.complete_batch(batch_id, &transactions); - - // Retryable transactions should be inserted back into the container - let mut retryable_iter = retryable_indexes.into_iter().peekable(); - for (index, (id, transaction, max_age_slot)) in - izip!(ids, transactions, max_age_slots).enumerate() - { - if let Some(retryable_index) = retryable_iter.peek() { - if *retryable_index == index { - container.retry_transaction( - id, - SanitizedTransactionTTL { - transaction, - max_age_slot, - }, - ); - retryable_iter.next(); - continue; - } - } - container.remove_by_id(&id); - } - - Ok((num_transactions, num_retryable)) - } - Err(TryRecvError::Empty) => Ok((0, 0)), - Err(TryRecvError::Disconnected) => Err(SchedulerError::DisconnectedRecvChannel( - "finished consume work", - )), - } - } - - /// Mark a given `TransactionBatchId` as completed. - /// This will update the internal tracking, including account locks. - fn complete_batch( - &mut self, - batch_id: TransactionBatchId, - transactions: &[SanitizedTransaction], - ) { - let thread_id = self.in_flight_tracker.complete_batch(batch_id); - for transaction in transactions { - let message = transaction.message(); - let account_keys = message.account_keys(); - let write_account_locks = account_keys - .iter() - .enumerate() - .filter_map(|(index, key)| message.is_writable(index).then_some(key)); - let read_account_locks = account_keys - .iter() - .enumerate() - .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)); - self.account_locks - .unlock_accounts(write_account_locks, read_account_locks, thread_id); - } - } - - /// Send all batches of transactions to the worker threads. - /// Returns the number of transactions sent. - fn send_batches(&mut self, batches: &mut Batches) -> Result { - (0..self.consume_work_senders.len()) - .map(|thread_index| self.send_batch(batches, thread_index)) - .sum() - } - - /// Send a batch of transactions to the given thread's `ConsumeWork` channel. - /// Returns the number of transactions sent. 
- fn send_batch( - &mut self, - batches: &mut Batches, - thread_index: usize, - ) -> Result { - if batches.ids[thread_index].is_empty() { - return Ok(0); - } - - let (ids, transactions, max_age_slots, total_cus) = batches.take_batch(thread_index); - - let batch_id = self - .in_flight_tracker - .track_batch(ids.len(), total_cus, thread_index); - - let num_scheduled = ids.len(); - let work = ConsumeWork { - batch_id, - ids, - transactions, - max_age_slots, - }; - self.consume_work_senders[thread_index] - .send(work) - .map_err(|_| SchedulerError::DisconnectedSendChannel("consume work sender"))?; - - Ok(num_scheduled) - } - - /// Given the schedulable `thread_set`, select the thread with the least amount - /// of work queued up. - /// Currently, "work" is just defined as the number of transactions. - /// - /// If the `chain_thread` is available, this thread will be selected, regardless of - /// load-balancing. - /// - /// Panics if the `thread_set` is empty. This should never happen, see comment - /// on `ThreadAwareAccountLocks::try_lock_accounts`. - fn select_thread( - thread_set: ThreadSet, - batch_cus_per_thread: &[u64], - in_flight_cus_per_thread: &[u64], - batches_per_thread: &[Vec], - in_flight_per_thread: &[usize], - ) -> ThreadId { - thread_set - .contained_threads_iter() - .map(|thread_id| { - ( - thread_id, - batch_cus_per_thread[thread_id] + in_flight_cus_per_thread[thread_id], - batches_per_thread[thread_id].len() + in_flight_per_thread[thread_id], - ) - }) - .min_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2))) - .map(|(thread_id, _, _)| thread_id) - .unwrap() - } - - /// Gets accessed accounts (resources) for use in `PrioGraph`. - fn get_transaction_account_access( - transaction: &SanitizedTransactionTTL, - ) -> impl Iterator + '_ { - let message = transaction.transaction.message(); - message - .account_keys() - .iter() - .enumerate() - .map(|(index, key)| { - if message.is_writable(index) { - (*key, AccessKind::Write) - } else { - (*key, AccessKind::Read) - } - }) - } -} - -/// Metrics from scheduling transactions. -#[derive(Debug, PartialEq, Eq)] -pub(crate) struct SchedulingSummary { - /// Number of transactions scheduled. - pub num_scheduled: usize, - /// Number of transactions that were not scheduled due to conflicts. - pub num_unschedulable: usize, - /// Number of transactions that were dropped due to filter. - pub num_filtered_out: usize, - /// Time spent filtering transactions - pub filter_time_us: u64, -} - -struct Batches { - ids: Vec>, - transactions: Vec>, - max_age_slots: Vec>, - total_cus: Vec, -} - -impl Batches { - fn new(num_threads: usize) -> Self { - Self { - ids: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], - transactions: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], - max_age_slots: vec![Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH); num_threads], - total_cus: vec![0; num_threads], - } - } - - fn take_batch( - &mut self, - thread_id: ThreadId, - ) -> ( - Vec, - Vec, - Vec, - u64, - ) { - ( - core::mem::replace( - &mut self.ids[thread_id], - Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), - ), - core::mem::replace( - &mut self.transactions[thread_id], - Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), - ), - core::mem::replace( - &mut self.max_age_slots[thread_id], - Vec::with_capacity(TARGET_NUM_TRANSACTIONS_PER_BATCH), - ), - core::mem::replace(&mut self.total_cus[thread_id], 0), - ) - } -} - -/// A transaction has been scheduled to a thread. 
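// Worked example for select_thread above: the lowest combined CU load wins,
// with queued transaction count as the tie-breaker.
//
//     let picked = PrioGraphScheduler::select_thread(
//         ThreadSet::any(2),
//         &[0, 0],            // CUs batched per thread
//         &[10_000, 5_000],   // CUs in flight per thread
//         &[vec![], vec![]],  // batched transaction ids per thread
//         &[3, 1],            // transactions in flight per thread
//     );
//     assert_eq!(picked, 1); // thread 1 carries the lighter CU load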
-struct TransactionSchedulingInfo { - thread_id: ThreadId, - transaction: SanitizedTransaction, - max_age_slot: Slot, - cost: u64, -} - -/// Error type for reasons a transaction could not be scheduled. -enum TransactionSchedulingError { - /// Transaction was filtered out before locking. - Filtered, - /// Transaction cannot be scheduled due to conflicts, or - /// higher priority conflicting transactions are unschedulable. - UnschedulableConflicts, -} - -fn try_schedule_transaction( - transaction_state: &mut TransactionState, - pre_lock_filter: impl Fn(&SanitizedTransaction) -> bool, - blocking_locks: &mut ReadWriteAccountSet, - account_locks: &mut ThreadAwareAccountLocks, - num_threads: usize, - thread_selector: impl Fn(ThreadSet) -> ThreadId, -) -> Result { - let transaction = &transaction_state.transaction_ttl().transaction; - if !pre_lock_filter(transaction) { - return Err(TransactionSchedulingError::Filtered); - } - - // Check if this transaction conflicts with any blocked transactions - let message = transaction.message(); - if !blocking_locks.check_locks(message) { - blocking_locks.take_locks(message); - return Err(TransactionSchedulingError::UnschedulableConflicts); - } - - // Schedule the transaction if it can be. - let message = transaction.message(); - let account_keys = message.account_keys(); - let write_account_locks = account_keys - .iter() - .enumerate() - .filter_map(|(index, key)| message.is_writable(index).then_some(key)); - let read_account_locks = account_keys - .iter() - .enumerate() - .filter_map(|(index, key)| (!message.is_writable(index)).then_some(key)); - - let Some(thread_id) = account_locks.try_lock_accounts( - write_account_locks, - read_account_locks, - ThreadSet::any(num_threads), - thread_selector, - ) else { - blocking_locks.take_locks(message); - return Err(TransactionSchedulingError::UnschedulableConflicts); - }; - - let sanitized_transaction_ttl = transaction_state.transition_to_pending(); - let cost = transaction_state.cost(); - - Ok(TransactionSchedulingInfo { - thread_id, - transaction: sanitized_transaction_ttl.transaction, - max_age_slot: sanitized_transaction_ttl.max_age_slot, - cost, - }) -} - -#[cfg(test)] -mod tests { - use { - super::*, - crate::banking_stage::{ - consumer::TARGET_NUM_TRANSACTIONS_PER_BATCH, - immutable_deserialized_packet::ImmutableDeserializedPacket, - }, - crossbeam_channel::{unbounded, Receiver}, - itertools::Itertools, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, - pubkey::Pubkey, signature::Keypair, signer::Signer, system_instruction, - transaction::Transaction, - }, - std::{borrow::Borrow, sync::Arc}, - }; - - macro_rules! txid { - ($value:expr) => { - TransactionId::new($value) - }; - } - - macro_rules! 
txids { - ([$($element:expr),*]) => { - vec![ $(txid!($element)),* ] - }; - } - - fn create_test_frame( - num_threads: usize, - ) -> ( - PrioGraphScheduler, - Vec>, - Sender, - ) { - let (consume_work_senders, consume_work_receivers) = - (0..num_threads).map(|_| unbounded()).unzip(); - let (finished_consume_work_sender, finished_consume_work_receiver) = unbounded(); - let scheduler = - PrioGraphScheduler::new(consume_work_senders, finished_consume_work_receiver); - ( - scheduler, - consume_work_receivers, - finished_consume_work_sender, - ) - } - - fn prioritized_tranfers( - from_keypair: &Keypair, - to_pubkeys: impl IntoIterator>, - lamports: u64, - priority: u64, - ) -> SanitizedTransaction { - let to_pubkeys_lamports = to_pubkeys - .into_iter() - .map(|pubkey| *pubkey.borrow()) - .zip(std::iter::repeat(lamports)) - .collect_vec(); - let mut ixs = - system_instruction::transfer_many(&from_keypair.pubkey(), &to_pubkeys_lamports); - let prioritization = ComputeBudgetInstruction::set_compute_unit_price(priority); - ixs.push(prioritization); - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = Transaction::new(&[from_keypair], message, Hash::default()); - SanitizedTransaction::from_transaction_for_tests(tx) - } - - fn create_container( - tx_infos: impl IntoIterator< - Item = ( - impl Borrow, - impl IntoIterator>, - u64, - u64, - ), - >, - ) -> TransactionStateContainer { - let mut container = TransactionStateContainer::with_capacity(10 * 1024); - for (index, (from_keypair, to_pubkeys, lamports, compute_unit_price)) in - tx_infos.into_iter().enumerate() - { - let id = TransactionId::new(index as u64); - let transaction = prioritized_tranfers( - from_keypair.borrow(), - to_pubkeys, - lamports, - compute_unit_price, - ); - let packet = Arc::new( - ImmutableDeserializedPacket::new( - Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(), - ) - .unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction, - max_age_slot: Slot::MAX, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - container.insert_new_transaction( - id, - transaction_ttl, - packet, - compute_unit_price, - TEST_TRANSACTION_COST, - ); - } - - container - } - - fn collect_work( - receiver: &Receiver, - ) -> (Vec, Vec>) { - receiver - .try_iter() - .map(|work| { - let ids = work.ids.clone(); - (work, ids) - }) - .unzip() - } - - fn test_pre_graph_filter(_txs: &[&SanitizedTransaction], results: &mut [bool]) { - results.fill(true); - } - - fn test_pre_lock_filter(_tx: &SanitizedTransaction) -> bool { - true - } - - #[test] - fn test_schedule_disconnected_channel() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); - let mut container = create_container([(&Keypair::new(), &[Pubkey::new_unique()], 1, 1)]); - - drop(work_receivers); // explicitly drop receivers - assert_matches!( - scheduler.schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter), - Err(SchedulerError::DisconnectedSendChannel(_)) - ); - } - - #[test] - fn test_schedule_single_threaded_no_conflicts() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); - let mut container = create_container([ - (&Keypair::new(), &[Pubkey::new_unique()], 1, 1), - (&Keypair::new(), &[Pubkey::new_unique()], 2, 2), - ]); - - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 2); - assert_eq!(scheduling_summary.num_unschedulable, 0); - 
assert_eq!(collect_work(&work_receivers[0]).1, vec![txids!([1, 0])]); - } - - #[test] - fn test_schedule_single_threaded_conflict() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); - let pubkey = Pubkey::new_unique(); - let mut container = create_container([ - (&Keypair::new(), &[pubkey], 1, 1), - (&Keypair::new(), &[pubkey], 1, 2), - ]); - - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 2); - assert_eq!(scheduling_summary.num_unschedulable, 0); - assert_eq!( - collect_work(&work_receivers[0]).1, - vec![txids!([1]), txids!([0])] - ); - } - - #[test] - fn test_schedule_consume_single_threaded_multi_batch() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); - let mut container = create_container( - (0..4 * TARGET_NUM_TRANSACTIONS_PER_BATCH) - .map(|i| (Keypair::new(), [Pubkey::new_unique()], i as u64, 1)), - ); - - // expect 4 full batches to be scheduled - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!( - scheduling_summary.num_scheduled, - 4 * TARGET_NUM_TRANSACTIONS_PER_BATCH - ); - assert_eq!(scheduling_summary.num_unschedulable, 0); - - let thread0_work_counts: Vec<_> = work_receivers[0] - .try_iter() - .map(|work| work.ids.len()) - .collect(); - assert_eq!(thread0_work_counts, [TARGET_NUM_TRANSACTIONS_PER_BATCH; 4]); - } - - #[test] - fn test_schedule_simple_thread_selection() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(2); - let mut container = - create_container((0..4).map(|i| (Keypair::new(), [Pubkey::new_unique()], 1, i))); - - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 4); - assert_eq!(scheduling_summary.num_unschedulable, 0); - assert_eq!(collect_work(&work_receivers[0]).1, [txids!([3, 1])]); - assert_eq!(collect_work(&work_receivers[1]).1, [txids!([2, 0])]); - } - - #[test] - fn test_schedule_priority_guard() { - let (mut scheduler, work_receivers, finished_work_sender) = create_test_frame(2); - // intentionally shorten the look-ahead window to cause unschedulable conflicts - scheduler.look_ahead_window_size = 2; - - let accounts = (0..8).map(|_| Keypair::new()).collect_vec(); - let mut container = create_container([ - (&accounts[0], &[accounts[1].pubkey()], 1, 6), - (&accounts[2], &[accounts[3].pubkey()], 1, 5), - (&accounts[4], &[accounts[5].pubkey()], 1, 4), - (&accounts[6], &[accounts[7].pubkey()], 1, 3), - (&accounts[1], &[accounts[2].pubkey()], 1, 2), - (&accounts[2], &[accounts[3].pubkey()], 1, 1), - ]); - - // The look-ahead window is intentionally shortened, high priority transactions - // [0, 1, 2, 3] do not conflict, and are scheduled onto threads in a - // round-robin fashion. This leads to transaction [4] being unschedulable due - // to conflicts with [0] and [1], which were scheduled to different threads. - // Transaction [5] is technically schedulable, onto thread 1 since it only - // conflicts with transaction [1]. However, [5] will not be scheduled because - // it conflicts with a higher-priority transaction [4] that is unschedulable. 
- // The full prio-graph can be visualized as: - // [0] \ - // -> [4] -> [5] - // [1] / ------/ - // [2] - // [3] - // Because the look-ahead window is shortened to a size of 4, the scheduler does - // not have knowledge of the joining at transaction [4] until after [0] and [1] - // have been scheduled. - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 4); - assert_eq!(scheduling_summary.num_unschedulable, 2); - let (thread_0_work, thread_0_ids) = collect_work(&work_receivers[0]); - assert_eq!(thread_0_ids, [txids!([0]), txids!([2])]); - assert_eq!( - collect_work(&work_receivers[1]).1, - [txids!([1]), txids!([3])] - ); - - // Cannot schedule even on next pass because of lock conflicts - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 0); - assert_eq!(scheduling_summary.num_unschedulable, 2); - - // Complete batch on thread 0. Remaining txs can be scheduled onto thread 1 - finished_work_sender - .send(FinishedConsumeWork { - work: thread_0_work.into_iter().next().unwrap(), - retryable_indexes: vec![], - }) - .unwrap(); - scheduler.receive_completed(&mut container).unwrap(); - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, test_pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 2); - assert_eq!(scheduling_summary.num_unschedulable, 0); - - assert_eq!( - collect_work(&work_receivers[1]).1, - [txids!([4]), txids!([5])] - ); - } - - #[test] - fn test_schedule_pre_lock_filter() { - let (mut scheduler, work_receivers, _finished_work_sender) = create_test_frame(1); - let pubkey = Pubkey::new_unique(); - let keypair = Keypair::new(); - let mut container = create_container([ - (&Keypair::new(), &[pubkey], 1, 1), - (&keypair, &[pubkey], 1, 2), - (&Keypair::new(), &[pubkey], 1, 3), - ]); - - // 2nd transaction should be filtered out and dropped before locking. 
- let pre_lock_filter = - |tx: &SanitizedTransaction| tx.message().fee_payer() != &keypair.pubkey(); - let scheduling_summary = scheduler - .schedule(&mut container, test_pre_graph_filter, pre_lock_filter) - .unwrap(); - assert_eq!(scheduling_summary.num_scheduled, 2); - assert_eq!(scheduling_summary.num_unschedulable, 0); - assert_eq!( - collect_work(&work_receivers[0]).1, - vec![txids!([2]), txids!([0])] - ); - } -} diff --git a/core/src/banking_stage/transaction_scheduler/scheduler_error.rs b/core/src/banking_stage/transaction_scheduler/scheduler_error.rs deleted file mode 100644 index 9b8d4015448e57..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/scheduler_error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use thiserror::Error; - -#[derive(Debug, Error)] -pub enum SchedulerError { - #[error("Sending channel disconnected: {0}")] - DisconnectedSendChannel(&'static str), - #[error("Recv channel disconnected: {0}")] - DisconnectedRecvChannel(&'static str), -} diff --git a/core/src/banking_stage/transaction_scheduler/scheduler_metrics.rs b/core/src/banking_stage/transaction_scheduler/scheduler_metrics.rs deleted file mode 100644 index bb8cbbe617396a..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/scheduler_metrics.rs +++ /dev/null @@ -1,408 +0,0 @@ -use { - itertools::MinMaxResult, - solana_poh::poh_recorder::BankStart, - solana_sdk::{clock::Slot, timing::AtomicInterval}, - std::time::Instant, -}; - -#[derive(Default)] -pub struct SchedulerCountMetrics { - interval: IntervalSchedulerCountMetrics, - slot: SlotSchedulerCountMetrics, -} - -impl SchedulerCountMetrics { - pub fn update(&mut self, update: impl Fn(&mut SchedulerCountMetricsInner)) { - update(&mut self.interval.metrics); - update(&mut self.slot.metrics); - } - - pub fn maybe_report_and_reset_slot(&mut self, slot: Option) { - self.slot.maybe_report_and_reset(slot); - } - - pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) { - self.interval.maybe_report_and_reset(should_report); - } - - pub fn interval_has_data(&self) -> bool { - self.interval.metrics.has_data() - } -} - -#[derive(Default)] -struct IntervalSchedulerCountMetrics { - interval: AtomicInterval, - metrics: SchedulerCountMetricsInner, -} - -#[derive(Default)] -struct SlotSchedulerCountMetrics { - slot: Option, - metrics: SchedulerCountMetricsInner, -} - -#[derive(Default)] -pub struct SchedulerCountMetricsInner { - /// Number of packets received. - pub num_received: usize, - /// Number of packets buffered. - pub num_buffered: usize, - - /// Number of transactions scheduled. - pub num_scheduled: usize, - /// Number of transactions that were unschedulable. - pub num_unschedulable: usize, - /// Number of transactions that were filtered out during scheduling. - pub num_schedule_filtered_out: usize, - /// Number of completed transactions received from workers. - pub num_finished: usize, - /// Number of transactions that were retryable. - pub num_retryable: usize, - /// Number of transactions that were scheduled to be forwarded. - pub num_forwarded: usize, - - /// Number of transactions that were immediately dropped on receive. - pub num_dropped_on_receive: usize, - /// Number of transactions that were dropped due to sanitization failure. - pub num_dropped_on_sanitization: usize, - /// Number of transactions that were dropped due to failed lock validation. - pub num_dropped_on_validate_locks: usize, - /// Number of transactions that were dropped due to failed transaction - /// checks during receive. 
- pub num_dropped_on_receive_transaction_checks: usize, - /// Number of transactions that were dropped due to clearing. - pub num_dropped_on_clear: usize, - /// Number of transactions that were dropped due to age and status checks. - pub num_dropped_on_age_and_status: usize, - /// Number of transactions that were dropped due to exceeded capacity. - pub num_dropped_on_capacity: usize, - /// Min prioritization fees in the transaction container - pub min_prioritization_fees: u64, - /// Max prioritization fees in the transaction container - pub max_prioritization_fees: u64, -} - -impl IntervalSchedulerCountMetrics { - fn maybe_report_and_reset(&mut self, should_report: bool) { - const REPORT_INTERVAL_MS: u64 = 1000; - if self.interval.should_update(REPORT_INTERVAL_MS) { - if should_report { - self.metrics.report("banking_stage_scheduler_counts", None); - } - self.metrics.reset(); - } - } -} - -impl SlotSchedulerCountMetrics { - fn maybe_report_and_reset(&mut self, slot: Option) { - if self.slot != slot { - // Only report if there was an assigned slot. - if self.slot.is_some() { - self.metrics - .report("banking_stage_scheduler_slot_counts", self.slot); - } - self.metrics.reset(); - self.slot = slot; - } - } -} - -impl SchedulerCountMetricsInner { - fn report(&self, name: &'static str, slot: Option) { - let mut datapoint = create_datapoint!( - @point name, - ("num_received", self.num_received, i64), - ("num_buffered", self.num_buffered, i64), - ("num_scheduled", self.num_scheduled, i64), - ("num_unschedulable", self.num_unschedulable, i64), - ( - "num_schedule_filtered_out", - self.num_schedule_filtered_out, - i64 - ), - ("num_finished", self.num_finished, i64), - ("num_retryable", self.num_retryable, i64), - ("num_forwarded", self.num_forwarded, i64), - ("num_dropped_on_receive", self.num_dropped_on_receive, i64), - ( - "num_dropped_on_sanitization", - self.num_dropped_on_sanitization, - i64 - ), - ( - "num_dropped_on_validate_locks", - self.num_dropped_on_validate_locks, - i64 - ), - ( - "num_dropped_on_receive_transaction_checks", - self.num_dropped_on_receive_transaction_checks, - i64 - ), - ("num_dropped_on_clear", self.num_dropped_on_clear, i64), - ( - "num_dropped_on_age_and_status", - self.num_dropped_on_age_and_status, - i64 - ), - ("num_dropped_on_capacity", self.num_dropped_on_capacity, i64), - ("min_priority", self.get_min_priority(), i64), - ("max_priority", self.get_max_priority(), i64) - ); - if let Some(slot) = slot { - datapoint.add_field_i64("slot", slot as i64); - } - solana_metrics::submit(datapoint, log::Level::Info); - } - - pub fn has_data(&self) -> bool { - self.num_received != 0 - || self.num_buffered != 0 - || self.num_scheduled != 0 - || self.num_unschedulable != 0 - || self.num_schedule_filtered_out != 0 - || self.num_finished != 0 - || self.num_retryable != 0 - || self.num_forwarded != 0 - || self.num_dropped_on_receive != 0 - || self.num_dropped_on_sanitization != 0 - || self.num_dropped_on_validate_locks != 0 - || self.num_dropped_on_receive_transaction_checks != 0 - || self.num_dropped_on_clear != 0 - || self.num_dropped_on_age_and_status != 0 - || self.num_dropped_on_capacity != 0 - } - - fn reset(&mut self) { - self.num_received = 0; - self.num_buffered = 0; - self.num_scheduled = 0; - self.num_unschedulable = 0; - self.num_schedule_filtered_out = 0; - self.num_finished = 0; - self.num_retryable = 0; - self.num_forwarded = 0; - self.num_dropped_on_receive = 0; - self.num_dropped_on_sanitization = 0; - self.num_dropped_on_validate_locks = 0; - 
self.num_dropped_on_receive_transaction_checks = 0; - self.num_dropped_on_clear = 0; - self.num_dropped_on_age_and_status = 0; - self.num_dropped_on_capacity = 0; - self.min_prioritization_fees = u64::MAX; - self.max_prioritization_fees = 0; - } - - pub fn update_priority_stats(&mut self, min_max_fees: MinMaxResult) { - // update min/max priority - match min_max_fees { - itertools::MinMaxResult::NoElements => { - // do nothing - } - itertools::MinMaxResult::OneElement(e) => { - self.min_prioritization_fees = e; - self.max_prioritization_fees = e; - } - itertools::MinMaxResult::MinMax(min, max) => { - self.min_prioritization_fees = min; - self.max_prioritization_fees = max; - } - } - } - - pub fn get_min_priority(&self) -> u64 { - // to avoid getting u64::max recorded by metrics / in case of edge cases - if self.min_prioritization_fees != u64::MAX { - self.min_prioritization_fees - } else { - 0 - } - } - - pub fn get_max_priority(&self) -> u64 { - self.max_prioritization_fees - } -} - -#[derive(Default)] -pub struct SchedulerTimingMetrics { - interval: IntervalSchedulerTimingMetrics, - slot: SlotSchedulerTimingMetrics, -} - -impl SchedulerTimingMetrics { - pub fn update(&mut self, update: impl Fn(&mut SchedulerTimingMetricsInner)) { - update(&mut self.interval.metrics); - update(&mut self.slot.metrics); - } - - pub fn maybe_report_and_reset_slot(&mut self, slot: Option) { - self.slot.maybe_report_and_reset(slot); - } - - pub fn maybe_report_and_reset_interval(&mut self, should_report: bool) { - self.interval.maybe_report_and_reset(should_report); - } -} - -#[derive(Default)] -struct IntervalSchedulerTimingMetrics { - interval: AtomicInterval, - metrics: SchedulerTimingMetricsInner, -} - -#[derive(Default)] -struct SlotSchedulerTimingMetrics { - slot: Option, - metrics: SchedulerTimingMetricsInner, -} - -#[derive(Default)] -pub struct SchedulerTimingMetricsInner { - /// Time spent making processing decisions. - pub decision_time_us: u64, - /// Time spent receiving packets. - pub receive_time_us: u64, - /// Time spent buffering packets. - pub buffer_time_us: u64, - /// Time spent filtering transactions during scheduling. - pub schedule_filter_time_us: u64, - /// Time spent scheduling transactions. - pub schedule_time_us: u64, - /// Time spent clearing transactions from the container. - pub clear_time_us: u64, - /// Time spent cleaning expired or processed transactions from the container. - pub clean_time_us: u64, - /// Time spent forwarding transactions. - pub forward_time_us: u64, - /// Time spent receiving completed transactions. - pub receive_completed_time_us: u64, -} - -impl IntervalSchedulerTimingMetrics { - fn maybe_report_and_reset(&mut self, should_report: bool) { - const REPORT_INTERVAL_MS: u64 = 1000; - if self.interval.should_update(REPORT_INTERVAL_MS) { - if should_report { - self.metrics.report("banking_stage_scheduler_timing", None); - } - self.metrics.reset(); - } - } -} - -impl SlotSchedulerTimingMetrics { - fn maybe_report_and_reset(&mut self, slot: Option) { - if self.slot != slot { - // Only report if there was an assigned slot. 
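// Sketch of the slot-boundary behavior implemented below (API shapes from
// this module; the slot number is illustrative):
//
//     let mut timing = SchedulerTimingMetrics::default();
//     timing.update(|m| m.decision_time_us += 15);
//     timing.maybe_report_and_reset_slot(Some(42)); // adopts slot 42; nothing reported yet
//     timing.maybe_report_and_reset_slot(None);     // slot ended: reports slot 42, then resets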
- if self.slot.is_some() { - self.metrics - .report("banking_stage_scheduler_slot_timing", self.slot); - } - self.metrics.reset(); - self.slot = slot; - } - } -} - -impl SchedulerTimingMetricsInner { - fn report(&self, name: &'static str, slot: Option<Slot>) { - let mut datapoint = create_datapoint!( - @point name, - ("decision_time_us", self.decision_time_us, i64), - ("receive_time_us", self.receive_time_us, i64), - ("buffer_time_us", self.buffer_time_us, i64), - ("schedule_filter_time_us", self.schedule_filter_time_us, i64), - ("schedule_time_us", self.schedule_time_us, i64), - ("clear_time_us", self.clear_time_us, i64), - ("clean_time_us", self.clean_time_us, i64), - ("forward_time_us", self.forward_time_us, i64), - ( - "receive_completed_time_us", - self.receive_completed_time_us, - i64 - ) - ); - if let Some(slot) = slot { - datapoint.add_field_i64("slot", slot as i64); - } - solana_metrics::submit(datapoint, log::Level::Info); - } - - fn reset(&mut self) { - self.decision_time_us = 0; - self.receive_time_us = 0; - self.buffer_time_us = 0; - self.schedule_filter_time_us = 0; - self.schedule_time_us = 0; - self.clear_time_us = 0; - self.clean_time_us = 0; - self.forward_time_us = 0; - self.receive_completed_time_us = 0; - } -} - -#[derive(Default)] -pub struct SchedulerLeaderDetectionMetrics { - inner: Option<SchedulerLeaderDetectionMetricsInner>, -} - -struct SchedulerLeaderDetectionMetricsInner { - slot: Slot, - bank_creation_time: Instant, - bank_detected_time: Instant, -} - -impl SchedulerLeaderDetectionMetrics { - pub fn update_and_maybe_report(&mut self, bank_start: Option<&BankStart>) { - match (&self.inner, bank_start) { - (None, Some(bank_start)) => self.initialize_inner(bank_start), - (Some(_inner), None) => self.report_and_reset(), - (Some(inner), Some(bank_start)) if inner.slot != bank_start.working_bank.slot() => { - self.report_and_reset(); - self.initialize_inner(bank_start); - } - _ => {} - } - } - - fn initialize_inner(&mut self, bank_start: &BankStart) { - let bank_detected_time = Instant::now(); - self.inner = Some(SchedulerLeaderDetectionMetricsInner { - slot: bank_start.working_bank.slot(), - bank_creation_time: *bank_start.bank_creation_time, - bank_detected_time, - }); - } - - fn report_and_reset(&mut self) { - let SchedulerLeaderDetectionMetricsInner { - slot, - bank_creation_time, - bank_detected_time, - } = self.inner.take().expect("inner must be present"); - - let bank_detected_delay_us = bank_detected_time - .duration_since(bank_creation_time) - .as_micros() - .try_into() - .unwrap_or(i64::MAX); - let bank_detected_to_slot_end_detected_us = bank_detected_time - .elapsed() - .as_micros() - .try_into() - .unwrap_or(i64::MAX); - datapoint_info!( - "banking_stage_scheduler_leader_detection", - ("slot", slot, i64), - ("bank_detected_delay_us", bank_detected_delay_us, i64), - ( - "bank_detected_to_slot_end_detected_us", - bank_detected_to_slot_end_detected_us, - i64 - ), - ); - } -} diff --git a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs deleted file mode 100644 index b279102756eed4..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs +++ /dev/null @@ -1,742 +0,0 @@ -use { - ahash::AHashMap, - solana_sdk::pubkey::Pubkey, - std::{ - collections::hash_map::Entry, - fmt::{Debug, Display}, - ops::{BitAnd, BitAndAssign, Sub}, - }, -}; - -pub(crate) const MAX_THREADS: usize = u64::BITS as usize; - -/// Identifier for a thread -pub(crate) type ThreadId = usize; //
0..MAX_THREADS-1 - -type LockCount = u32; - -/// A bit-set of threads an account is scheduled or can be scheduled for. -#[derive(Copy, Clone, PartialEq, Eq)] -pub(crate) struct ThreadSet(u64); - -struct AccountWriteLocks { - thread_id: ThreadId, - lock_count: LockCount, -} - -struct AccountReadLocks { - thread_set: ThreadSet, - lock_counts: [LockCount; MAX_THREADS], -} - -/// Account locks. -/// Write Locks - only one thread can hold a write lock at a time. -/// Contains how many write locks are held by the thread. -/// Read Locks - multiple threads can hold a read lock at a time. -/// Contains thread-set for easily checking which threads are scheduled. -#[derive(Default)] -struct AccountLocks { - pub write_locks: Option<AccountWriteLocks>, - pub read_locks: Option<AccountReadLocks>, -} - -/// Thread-aware account locks which allows for scheduling on threads -/// that already hold locks on the account. This is useful for allowing -/// queued transactions to be scheduled on a thread while the transaction -/// is still being executed on the thread. -pub(crate) struct ThreadAwareAccountLocks { - /// Number of threads. - num_threads: usize, // 0..MAX_THREADS - /// Locks for each account. An account should only have an entry if there - /// is at least one lock. - locks: AHashMap<Pubkey, AccountLocks>, -} - -impl ThreadAwareAccountLocks { - /// Creates a new `ThreadAwareAccountLocks` with the given number of threads. - pub(crate) fn new(num_threads: usize) -> Self { - assert!(num_threads > 0, "num threads must be > 0"); - assert!( - num_threads <= MAX_THREADS, - "num threads must be <= {MAX_THREADS}" - ); - - Self { - num_threads, - locks: AHashMap::new(), - } - } - - /// Returns the `ThreadId` if the accounts are able to be locked - /// for the given thread, otherwise `None` is returned. - /// `allowed_threads` is a set of threads that the caller restricts locking to. - /// If accounts are schedulable, then they are locked for the thread - /// selected by the `thread_selector` function. - /// `thread_selector` is only called if all accounts are schedulable, meaning - /// that the `thread_set` passed to `thread_selector` is non-empty. - pub(crate) fn try_lock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone, - read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone, - allowed_threads: ThreadSet, - thread_selector: impl FnOnce(ThreadSet) -> ThreadId, - ) -> Option<ThreadId> { - let schedulable_threads = self.accounts_schedulable_threads( - write_account_locks.clone(), - read_account_locks.clone(), - )? & allowed_threads; - (!schedulable_threads.is_empty()).then(|| { - let thread_id = thread_selector(schedulable_threads); - self.lock_accounts(write_account_locks, read_account_locks, thread_id); - thread_id - }) - } - - /// Unlocks the accounts for the given thread. - pub(crate) fn unlock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator<Item = &'a Pubkey>, - read_account_locks: impl Iterator<Item = &'a Pubkey>, - thread_id: ThreadId, - ) { - for account in write_account_locks { - self.write_unlock_account(account, thread_id); - } - - for account in read_account_locks { - self.read_unlock_account(account, thread_id); - } - } - - /// Returns `ThreadSet` that the given accounts can be scheduled on.
- fn accounts_schedulable_threads<'a>( - &self, - write_account_locks: impl Iterator<Item = &'a Pubkey>, - read_account_locks: impl Iterator<Item = &'a Pubkey>, - ) -> Option<ThreadSet> { - let mut schedulable_threads = ThreadSet::any(self.num_threads); - - for account in write_account_locks { - schedulable_threads &= self.write_schedulable_threads(account); - if schedulable_threads.is_empty() { - return None; - } - } - - for account in read_account_locks { - schedulable_threads &= self.read_schedulable_threads(account); - if schedulable_threads.is_empty() { - return None; - } - } - - Some(schedulable_threads) - } - - /// Returns `ThreadSet` of schedulable threads for the given readable account. - fn read_schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - self.schedulable_threads::<false>(account) - } - - /// Returns `ThreadSet` of schedulable threads for the given writable account. - fn write_schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - self.schedulable_threads::<true>(account) - } - - /// Returns `ThreadSet` of schedulable threads. - /// If there are no locks, then all threads are schedulable. - /// If only write-locked, then only the thread holding the write lock is schedulable. - /// If a mix of locks, then only the write thread is schedulable. - /// If only read-locked, the only write-schedulable thread is if a single thread - /// holds all read locks. Otherwise, no threads are write-schedulable. - /// If only read-locked, all threads are read-schedulable. - fn schedulable_threads<const WRITE: bool>(&self, account: &Pubkey) -> ThreadSet { - match self.locks.get(account) { - None => ThreadSet::any(self.num_threads), - Some(AccountLocks { - write_locks: None, - read_locks: Some(read_locks), - }) => { - if WRITE { - read_locks - .thread_set - .only_one_contained() - .map(ThreadSet::only) - .unwrap_or_else(ThreadSet::none) - } else { - ThreadSet::any(self.num_threads) - } - } - Some(AccountLocks { - write_locks: Some(write_locks), - read_locks: None, - }) => ThreadSet::only(write_locks.thread_id), - Some(AccountLocks { - write_locks: Some(write_locks), - read_locks: Some(read_locks), - }) => { - assert_eq!( - read_locks.thread_set.only_one_contained(), - Some(write_locks.thread_id) - ); - read_locks.thread_set - } - Some(AccountLocks { - write_locks: None, - read_locks: None, - }) => unreachable!(), - } - } - - /// Add locks for all writable and readable accounts on `thread_id`. - fn lock_accounts<'a>( - &mut self, - write_account_locks: impl Iterator<Item = &'a Pubkey>, - read_account_locks: impl Iterator<Item = &'a Pubkey>, - thread_id: ThreadId, - ) { - assert!( - thread_id < self.num_threads, - "thread_id must be < num_threads" - ); - for account in write_account_locks { - self.write_lock_account(account, thread_id); - } - - for account in read_account_locks { - self.read_lock_account(account, thread_id); - } - } - - /// Locks the given `account` for writing on `thread_id`. - /// Panics if the account is already locked for writing on another thread.
- fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let entry = self.locks.entry(*account).or_default(); - - let AccountLocks { - write_locks, - read_locks, - } = entry; - - if let Some(read_locks) = read_locks { - assert_eq!( - read_locks.thread_set.only_one_contained(), - Some(thread_id), - "outstanding read lock must be on same thread" - ); - } - - if let Some(write_locks) = write_locks { - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - write_locks.lock_count += 1; - } else { - *write_locks = Some(AccountWriteLocks { - thread_id, - lock_count: 1, - }); - } - } - - /// Unlocks the given `account` for writing on `thread_id`. - /// Panics if the account is not locked for writing on `thread_id`. - fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let Entry::Occupied(mut entry) = self.locks.entry(*account) else { - panic!("write lock must exist for account: {account}"); - }; - - let AccountLocks { - write_locks: maybe_write_locks, - read_locks, - } = entry.get_mut(); - - let Some(write_locks) = maybe_write_locks else { - panic!("write lock must exist for account: {account}"); - }; - - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - - write_locks.lock_count -= 1; - if write_locks.lock_count == 0 { - *maybe_write_locks = None; - if read_locks.is_none() { - entry.remove(); - } - } - } - - /// Locks the given `account` for reading on `thread_id`. - /// Panics if the account is already locked for writing on another thread. - fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let AccountLocks { - write_locks, - read_locks, - } = self.locks.entry(*account).or_default(); - - if let Some(write_locks) = write_locks { - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - } - - match read_locks { - Some(read_locks) => { - read_locks.thread_set.insert(thread_id); - read_locks.lock_counts[thread_id] += 1; - } - None => { - let mut lock_counts = [0; MAX_THREADS]; - lock_counts[thread_id] = 1; - *read_locks = Some(AccountReadLocks { - thread_set: ThreadSet::only(thread_id), - lock_counts, - }); - } - } - } - - /// Unlocks the given `account` for reading on `thread_id`. - /// Panics if the account is not locked for reading on `thread_id`. 
- fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - let Entry::Occupied(mut entry) = self.locks.entry(*account) else { - panic!("read lock must exist for account: {account}"); - }; - - let AccountLocks { - write_locks, - read_locks: maybe_read_locks, - } = entry.get_mut(); - - let Some(read_locks) = maybe_read_locks else { - panic!("read lock must exist for account: {account}"); - }; - - assert!( - read_locks.thread_set.contains(thread_id), - "outstanding read lock must be on same thread" - ); - - read_locks.lock_counts[thread_id] -= 1; - if read_locks.lock_counts[thread_id] == 0 { - read_locks.thread_set.remove(thread_id); - if read_locks.thread_set.is_empty() { - *maybe_read_locks = None; - if write_locks.is_none() { - entry.remove(); - } - } - } - } -} - -impl BitAnd for ThreadSet { - type Output = Self; - - fn bitand(self, rhs: Self) -> Self::Output { - Self(self.0 & rhs.0) - } -} - -impl BitAndAssign for ThreadSet { - fn bitand_assign(&mut self, rhs: Self) { - self.0 &= rhs.0; - } -} - -impl Sub for ThreadSet { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - Self(self.0 & !rhs.0) - } -} - -impl Display for ThreadSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "ThreadSet({:#0width$b})", self.0, width = MAX_THREADS) - } -} - -impl Debug for ThreadSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(self, f) - } -} - -impl ThreadSet { - #[inline(always)] - pub(crate) const fn none() -> Self { - Self(0b0) - } - - #[inline(always)] - pub(crate) const fn any(num_threads: usize) -> Self { - if num_threads == MAX_THREADS { - Self(u64::MAX) - } else { - Self(Self::as_flag(num_threads) - 1) - } - } - - #[inline(always)] - pub(crate) const fn only(thread_id: ThreadId) -> Self { - Self(Self::as_flag(thread_id)) - } - - #[inline(always)] - pub(crate) fn num_threads(&self) -> u32 { - self.0.count_ones() - } - - #[inline(always)] - pub(crate) fn only_one_contained(&self) -> Option<ThreadId> { - (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId) - } - - #[inline(always)] - pub(crate) fn is_empty(&self) -> bool { - self == &Self::none() - } - - #[inline(always)] - pub(crate) fn contains(&self, thread_id: ThreadId) -> bool { - self.0 & Self::as_flag(thread_id) != 0 - } - - #[inline(always)] - pub(crate) fn insert(&mut self, thread_id: ThreadId) { - self.0 |= Self::as_flag(thread_id); - } - - #[inline(always)] - pub(crate) fn remove(&mut self, thread_id: ThreadId) { - self.0 &= !Self::as_flag(thread_id); - } - - #[inline(always)] - pub(crate) fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> { - (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) - } - - #[inline(always)] - const fn as_flag(thread_id: ThreadId) -> u64 { - 0b1 << thread_id - } -} - -#[cfg(test)] -mod tests { - use super::*; - - const TEST_NUM_THREADS: usize = 4; - const TEST_ANY_THREADS: ThreadSet = ThreadSet::any(TEST_NUM_THREADS); - - // Simple thread selector to select the first schedulable thread - fn test_thread_selector(thread_set: ThreadSet) -> ThreadId { - thread_set.contained_threads_iter().next().unwrap() - } - - #[test] - #[should_panic(expected = "num threads must be > 0")] - fn test_too_few_num_threads() { - ThreadAwareAccountLocks::new(0); - } - - #[test] - #[should_panic(expected = "num threads must be <=")] - fn test_too_many_num_threads() { - ThreadAwareAccountLocks::new(MAX_THREADS + 1); - } - - #[test] - fn test_try_lock_accounts_none() { - let pk1 =
Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 2); - locks.read_lock_account(&pk1, 3); - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - None - ); - } - - #[test] - fn test_try_lock_accounts_one() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk2, 3); - - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - Some(3) - ); - } - - #[test] - fn test_try_lock_accounts_multiple() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk2, 0); - locks.read_lock_account(&pk2, 0); - - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS - ThreadSet::only(0), // exclude 0 - test_thread_selector - ), - Some(1) - ); - } - - #[test] - fn test_try_lock_accounts_any() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - assert_eq!( - locks.try_lock_accounts( - [&pk1].into_iter(), - [&pk2].into_iter(), - TEST_ANY_THREADS, - test_thread_selector - ), - Some(0) - ); - } - - #[test] - fn test_accounts_schedulable_threads_no_outstanding_locks() { - let pk1 = Pubkey::new_unique(); - let locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - assert_eq!( - locks.accounts_schedulable_threads([&pk1].into_iter(), std::iter::empty()), - Some(TEST_ANY_THREADS) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1].into_iter()), - Some(TEST_ANY_THREADS) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_write_only() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.write_lock_account(&pk1, 2); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(ThreadSet::only(2)) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_read_only() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.read_lock_account(&pk1, 2); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(TEST_ANY_THREADS) - ); - - locks.read_lock_account(&pk1, 0); - assert_eq!( - locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - None - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(TEST_ANY_THREADS) - ); - } - - #[test] - fn test_accounts_schedulable_threads_outstanding_mixed() { - let pk1 = Pubkey::new_unique(); - let pk2 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - - locks.read_lock_account(&pk1, 2); - locks.write_lock_account(&pk1, 2); - assert_eq!( - 
locks.accounts_schedulable_threads([&pk1, &pk2].into_iter(), std::iter::empty()), - Some(ThreadSet::only(2)) - ); - assert_eq!( - locks.accounts_schedulable_threads(std::iter::empty(), [&pk1, &pk2].into_iter()), - Some(ThreadSet::only(2)) - ); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_write_lock_account_write_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 0); - locks.write_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "outstanding read lock must be on same thread")] - fn test_write_lock_account_read_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 0); - locks.write_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "write lock must exist")] - fn test_write_unlock_account_not_locked() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_unlock_account(&pk1, 0); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_write_unlock_account_thread_mismatch() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 0); - } - - #[test] - #[should_panic(expected = "outstanding write lock must be on same thread")] - fn test_read_lock_account_write_conflict_panic() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 0); - locks.read_lock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "read lock must exist")] - fn test_read_unlock_account_not_locked() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_unlock_account(&pk1, 1); - } - - #[test] - #[should_panic(expected = "outstanding read lock must be on same thread")] - fn test_read_unlock_account_thread_mismatch() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 0); - locks.read_unlock_account(&pk1, 1); - } - - #[test] - fn test_write_locking() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.write_lock_account(&pk1, 1); - locks.write_lock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 1); - locks.write_unlock_account(&pk1, 1); - assert!(locks.locks.is_empty()); - } - - #[test] - fn test_read_locking() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.read_lock_account(&pk1, 1); - locks.read_lock_account(&pk1, 1); - locks.read_unlock_account(&pk1, 1); - locks.read_unlock_account(&pk1, 1); - assert!(locks.locks.is_empty()); - } - - #[test] - #[should_panic(expected = "thread_id must be < num_threads")] - fn test_lock_accounts_invalid_thread() { - let pk1 = Pubkey::new_unique(); - let mut locks = ThreadAwareAccountLocks::new(TEST_NUM_THREADS); - locks.lock_accounts([&pk1].into_iter(), std::iter::empty(), TEST_NUM_THREADS); - } - - #[test] - fn test_thread_set() { - let mut thread_set = ThreadSet::none(); - assert!(thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 0); - assert_eq!(thread_set.only_one_contained(), None); - for idx in 0..MAX_THREADS { 
- assert!(!thread_set.contains(idx)); - } - - thread_set.insert(4); - assert!(!thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 1); - assert_eq!(thread_set.only_one_contained(), Some(4)); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 4); - } - - thread_set.insert(2); - assert!(!thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 2); - assert_eq!(thread_set.only_one_contained(), None); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 2 || idx == 4); - } - - thread_set.remove(4); - assert!(!thread_set.is_empty()); - assert_eq!(thread_set.num_threads(), 1); - assert_eq!(thread_set.only_one_contained(), Some(2)); - for idx in 0..MAX_THREADS { - assert_eq!(thread_set.contains(idx), idx == 2); - } - } - - #[test] - fn test_thread_set_any_zero() { - let any_threads = ThreadSet::any(0); - assert_eq!(any_threads.num_threads(), 0); - } - - #[test] - fn test_thread_set_any_max() { - let any_threads = ThreadSet::any(MAX_THREADS); - assert_eq!(any_threads.num_threads(), MAX_THREADS as u32); - } -} diff --git a/core/src/banking_stage/transaction_scheduler/transaction_id_generator.rs b/core/src/banking_stage/transaction_scheduler/transaction_id_generator.rs deleted file mode 100644 index f54523890f9caf..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/transaction_id_generator.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::banking_stage::scheduler_messages::TransactionId; - -/// Simple reverse-sequential ID generator for `TransactionId`s. -/// These IDs uniquely identify transactions during the scheduling process. -pub struct TransactionIdGenerator { - next_id: u64, -} - -impl Default for TransactionIdGenerator { - fn default() -> Self { - Self { next_id: u64::MAX } - } -} - -impl TransactionIdGenerator { - pub fn next(&mut self) -> TransactionId { - let id = self.next_id; - self.next_id = self.next_id.wrapping_sub(1); - TransactionId::new(id) - } -} diff --git a/core/src/banking_stage/transaction_scheduler/transaction_priority_id.rs b/core/src/banking_stage/transaction_scheduler/transaction_priority_id.rs deleted file mode 100644 index 9857a689519502..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/transaction_priority_id.rs +++ /dev/null @@ -1,69 +0,0 @@ -use { - crate::banking_stage::scheduler_messages::TransactionId, - prio_graph::TopLevelId, - std::hash::{Hash, Hasher}, -}; - -/// A unique identifier tied with priority ordering for a transaction/packet: -#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub(crate) struct TransactionPriorityId { - pub(crate) priority: u64, - pub(crate) id: TransactionId, -} - -impl TransactionPriorityId { - pub(crate) fn new(priority: u64, id: TransactionId) -> Self { - Self { priority, id } - } -} - -impl Hash for TransactionPriorityId { - fn hash<H: Hasher>(&self, state: &mut H) { - self.id.hash(state) - } -} - -impl TopLevelId<Self> for TransactionPriorityId { - fn id(&self) -> Self { - *self - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_transaction_priority_id_ordering() { - // Higher priority first - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(2, TransactionId::new(1)); - assert!(id1 < id2); - assert!(id1 <= id2); - assert!(id2 > id1); - assert!(id2 >= id1); - } - - // Equal priority then compare by id - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(1, TransactionId::new(2)); - assert!(id1 < id2); -
assert!(id1 <= id2); - assert!(id2 > id1); - assert!(id2 >= id1); - } - - // Equal priority and id - { - let id1 = TransactionPriorityId::new(1, TransactionId::new(1)); - let id2 = TransactionPriorityId::new(1, TransactionId::new(1)); - assert_eq!(id1, id2); - assert!(id1 >= id2); - assert!(id1 <= id2); - assert!(id2 >= id1); - assert!(id2 <= id1); - } - } -} diff --git a/core/src/banking_stage/transaction_scheduler/transaction_state.rs b/core/src/banking_stage/transaction_scheduler/transaction_state.rs deleted file mode 100644 index 85af8217309e93..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/transaction_state.rs +++ /dev/null @@ -1,359 +0,0 @@ -use { - crate::banking_stage::immutable_deserialized_packet::ImmutableDeserializedPacket, - solana_sdk::{clock::Slot, transaction::SanitizedTransaction}, - std::sync::Arc, -}; - -/// Simple wrapper type to tie a sanitized transaction to max age slot. -pub(crate) struct SanitizedTransactionTTL { - pub(crate) transaction: SanitizedTransaction, - pub(crate) max_age_slot: Slot, -} - -/// TransactionState is used to track the state of a transaction in the transaction scheduler -/// and banking stage as a whole. -/// -/// There are two states a transaction can be in: -/// 1. `Unprocessed` - The transaction is available for scheduling. -/// 2. `Pending` - The transaction is currently scheduled or being processed. -/// -/// Newly received transactions are initially in the `Unprocessed` state. -/// When a transaction is scheduled, it is transitioned to the `Pending` state, -/// using the `transition_to_pending` method. -/// When a transaction finishes processing it may be retryable. If it is retryable, -/// the transaction is transitioned back to the `Unprocessed` state using the -/// `transition_to_unprocessed` method. If it is not retryable, the state should -/// be dropped. -/// -/// For performance, when a transaction is transitioned to the `Pending` state, the -/// internal `SanitizedTransaction` is moved out of the `TransactionState` and sent -/// to the appropriate thread for processing. This is done to avoid cloning the -/// `SanitizedTransaction`. -#[allow(clippy::large_enum_variant)] -pub(crate) enum TransactionState { - /// The transaction is available for scheduling. - Unprocessed { - transaction_ttl: SanitizedTransactionTTL, - packet: Arc<ImmutableDeserializedPacket>, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// The transaction is currently scheduled or being processed. - Pending { - packet: Arc<ImmutableDeserializedPacket>, - priority: u64, - cost: u64, - should_forward: bool, - }, - /// Only used during transition. - Transitioning, -} - -impl TransactionState { - /// Creates a new `TransactionState` in the `Unprocessed` state. - pub(crate) fn new( - transaction_ttl: SanitizedTransactionTTL, - packet: Arc<ImmutableDeserializedPacket>, - priority: u64, - cost: u64, - ) -> Self { - let should_forward = !packet.original_packet().meta().forwarded() - && packet.original_packet().meta().is_from_staked_node(); - Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward, - } - } - - /// Return the priority of the transaction. - /// This is *not* the same as the `compute_unit_price` of the transaction. - /// The priority is used to order transactions for processing. - pub(crate) fn priority(&self) -> u64 { - match self { - Self::Unprocessed { priority, .. } => *priority, - Self::Pending { priority, .. } => *priority, - Self::Transitioning => unreachable!(), - } - } - - /// Return the cost of the transaction.
- pub(crate) fn cost(&self) -> u64 { - match self { - Self::Unprocessed { cost, .. } => *cost, - Self::Pending { cost, .. } => *cost, - Self::Transitioning => unreachable!(), - } - } - - /// Return whether packet should be attempted to be forwarded. - pub(crate) fn should_forward(&self) -> bool { - match self { - Self::Unprocessed { - should_forward: forwarded, - .. - } => *forwarded, - Self::Pending { - should_forward: forwarded, - .. - } => *forwarded, - Self::Transitioning => unreachable!(), - } - } - - /// Mark the packet as forwarded. - /// This is used to prevent the packet from being forwarded multiple times. - pub(crate) fn mark_forwarded(&mut self) { - match self { - Self::Unprocessed { should_forward, .. } => *should_forward = false, - Self::Pending { should_forward, .. } => *should_forward = false, - Self::Transitioning => unreachable!(), - } - } - - /// Return the packet of the transaction. - pub(crate) fn packet(&self) -> &Arc<ImmutableDeserializedPacket> { - match self { - Self::Unprocessed { packet, .. } => packet, - Self::Pending { packet, .. } => packet, - Self::Transitioning => unreachable!(), - } - } - - /// Intended to be called when a transaction is scheduled. This method will - /// transition the transaction from `Unprocessed` to `Pending` and return the - /// `SanitizedTransactionTTL` for processing. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Pending` state, - /// as this is an invalid state transition. - pub(crate) fn transition_to_pending(&mut self) -> SanitizedTransactionTTL { - match self.take() { - TransactionState::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - }; - transaction_ttl - } - TransactionState::Pending { .. } => { - panic!("transaction already pending"); - } - Self::Transitioning => unreachable!(), - } - } - - /// Intended to be called when a transaction is retried. This method will - /// transition the transaction from `Pending` to `Unprocessed`. - /// - /// # Panics - /// This method will panic if the transaction is already in the `Unprocessed` - /// state, as this is an invalid state transition. - pub(crate) fn transition_to_unprocessed(&mut self, transaction_ttl: SanitizedTransactionTTL) { - match self.take() { - TransactionState::Unprocessed { .. } => panic!("already unprocessed"), - TransactionState::Pending { - packet, - priority, - cost, - should_forward: forwarded, - } => { - *self = Self::Unprocessed { - transaction_ttl, - packet, - priority, - cost, - should_forward: forwarded, - } - } - Self::Transitioning => unreachable!(), - } - } - - /// Get a reference to the `SanitizedTransactionTTL` for the transaction. - /// - /// # Panics - /// This method will panic if the transaction is in the `Pending` state. - pub(crate) fn transaction_ttl(&self) -> &SanitizedTransactionTTL { - match self { - Self::Unprocessed { - transaction_ttl, .. - } => transaction_ttl, - Self::Pending { .. } => panic!("transaction is pending"), - Self::Transitioning => unreachable!(), - } - } - - /// Internal helper for transitioning between states. - /// Replaces `self` with a dummy state that will immediately be overwritten in transition.
- fn take(&mut self) -> Self { - core::mem::replace(self, Self::Transitioning) - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, hash::Hash, message::Message, packet::Packet, - signature::Keypair, signer::Signer, system_instruction, transaction::Transaction, - }, - }; - - fn create_transaction_state(compute_unit_price: u64) -> TransactionState { - let from_keypair = Keypair::new(); - let ixs = vec![ - system_instruction::transfer( - &from_keypair.pubkey(), - &solana_sdk::pubkey::new_rand(), - 1, - ), - ComputeBudgetInstruction::set_compute_unit_price(compute_unit_price), - ]; - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = Transaction::new(&[&from_keypair], message, Hash::default()); - - let packet = Arc::new( - ImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction: SanitizedTransaction::from_transaction_for_tests(tx), - max_age_slot: Slot::MAX, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - TransactionState::new( - transaction_ttl, - packet, - compute_unit_price, - TEST_TRANSACTION_COST, - ) - } - - #[test] - #[should_panic(expected = "already pending")] - fn test_transition_to_pending_panic() { - let mut transaction_state = create_transaction_state(0); - transaction_state.transition_to_pending(); - transaction_state.transition_to_pending(); // invalid transition - } - - #[test] - fn test_transition_to_pending() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - } - - #[test] - #[should_panic(expected = "already unprocessed")] - fn test_transition_to_unprocessed_panic() { - let mut transaction_state = create_transaction_state(0); - - // Manually clone `SanitizedTransactionTTL` - let SanitizedTransactionTTL { - transaction, - max_age_slot, - } = transaction_state.transaction_ttl(); - let transaction_ttl = SanitizedTransactionTTL { - transaction: transaction.clone(), - max_age_slot: *max_age_slot, - }; - transaction_state.transition_to_unprocessed(transaction_ttl); // invalid transition - } - - #[test] - fn test_transition_to_unprocessed() { - let mut transaction_state = create_transaction_state(0); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. 
} - )); - } - - #[test] - fn test_priority() { - let priority = 15; - let mut transaction_state = create_transaction_state(priority); - assert_eq!(transaction_state.priority(), priority); - - // ensure compute unit price is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert_eq!(transaction_state.priority(), priority); - transaction_state.transition_to_unprocessed(transaction_ttl); - assert_eq!(transaction_state.priority(), priority); - } - - #[test] - #[should_panic(expected = "transaction is pending")] - fn test_transaction_ttl_panic() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - - let _ = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - let _ = transaction_state.transaction_ttl(); // pending state, the transaction ttl is not available - } - - #[test] - fn test_transaction_ttl() { - let mut transaction_state = create_transaction_state(0); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - - // ensure transaction_ttl is not lost through state transitions - let transaction_ttl = transaction_state.transition_to_pending(); - assert!(matches!( - transaction_state, - TransactionState::Pending { .. } - )); - - transaction_state.transition_to_unprocessed(transaction_ttl); - let transaction_ttl = transaction_state.transaction_ttl(); - assert!(matches!( - transaction_state, - TransactionState::Unprocessed { .. } - )); - assert_eq!(transaction_ttl.max_age_slot, Slot::MAX); - } -} diff --git a/core/src/banking_stage/transaction_scheduler/transaction_state_container.rs b/core/src/banking_stage/transaction_scheduler/transaction_state_container.rs deleted file mode 100644 index ed78b41983fa2a..00000000000000 --- a/core/src/banking_stage/transaction_scheduler/transaction_state_container.rs +++ /dev/null @@ -1,261 +0,0 @@ -use { - super::{ - transaction_priority_id::TransactionPriorityId, - transaction_state::{SanitizedTransactionTTL, TransactionState}, - }, - crate::banking_stage::{ - immutable_deserialized_packet::ImmutableDeserializedPacket, - scheduler_messages::TransactionId, - }, - itertools::MinMaxResult, - min_max_heap::MinMaxHeap, - std::{collections::HashMap, sync::Arc}, -}; - -/// This structure will hold `TransactionState` for the entirety of a -/// transaction's lifetime in the scheduler and BankingStage as a whole. -/// -/// Transaction Lifetime: -/// 1. Received from `SigVerify` by `BankingStage` -/// 2. Inserted into `TransactionStateContainer` by `BankingStage` -/// 3. Popped in priority-order by scheduler, and transitioned to `Pending` state -/// 4. Processed by `ConsumeWorker` -/// a. If consumed, remove `Pending` state from the `TransactionStateContainer` -/// b. If retryable, transition back to `Unprocessed` state. -/// Re-insert to the queue, and return to step 3. -/// -/// The structure is composed of two main components: -/// 1. A priority queue of wrapped `TransactionId`s, which are used to -/// order transactions by priority for selection by the scheduler. -/// 2. A map of `TransactionId` to `TransactionState`, which is used to -/// track the state of each transaction. 
-/// -/// When `Pending`, the associated `TransactionId` is not in the queue, but -/// is still in the map. -/// The entry in the map should exist before insertion into the queue, and be -/// removed only after the id is removed from the queue. -/// -/// The container maintains a fixed capacity. If the queue is full when pushing -/// a new transaction, the lowest priority transaction will be dropped. -pub(crate) struct TransactionStateContainer { - priority_queue: MinMaxHeap<TransactionPriorityId>, - id_to_transaction_state: HashMap<TransactionId, TransactionState>, -} - -impl TransactionStateContainer { - pub(crate) fn with_capacity(capacity: usize) -> Self { - Self { - priority_queue: MinMaxHeap::with_capacity(capacity), - id_to_transaction_state: HashMap::with_capacity(capacity), - } - } - - /// Returns true if the queue is empty. - pub(crate) fn is_empty(&self) -> bool { - self.priority_queue.is_empty() - } - - /// Returns the remaining capacity of the queue - pub(crate) fn remaining_queue_capacity(&self) -> usize { - self.priority_queue.capacity() - self.priority_queue.len() - } - - /// Get the top transaction id in the priority queue. - pub(crate) fn pop(&mut self) -> Option<TransactionPriorityId> { - self.priority_queue.pop_max() - } - - /// Get mutable transaction state by id. - pub(crate) fn get_mut_transaction_state( - &mut self, - id: &TransactionId, - ) -> Option<&mut TransactionState> { - self.id_to_transaction_state.get_mut(id) - } - - /// Get reference to `SanitizedTransactionTTL` by id. - /// Panics if the transaction does not exist. - pub(crate) fn get_transaction_ttl( - &self, - id: &TransactionId, - ) -> Option<&SanitizedTransactionTTL> { - self.id_to_transaction_state - .get(id) - .map(|state| state.transaction_ttl()) - } - - /// Insert a new transaction into the container's queues and maps. - /// Returns `true` if a packet was dropped due to capacity limits. - pub(crate) fn insert_new_transaction( - &mut self, - transaction_id: TransactionId, - transaction_ttl: SanitizedTransactionTTL, - packet: Arc<ImmutableDeserializedPacket>, - priority: u64, - cost: u64, - ) -> bool { - let priority_id = TransactionPriorityId::new(priority, transaction_id); - self.id_to_transaction_state.insert( - transaction_id, - TransactionState::new(transaction_ttl, packet, priority, cost), - ); - self.push_id_into_queue(priority_id) - } - - /// Retries a transaction - inserts transaction back into map (but not packet). - /// This transitions the transaction to `Unprocessed` state. - pub(crate) fn retry_transaction( - &mut self, - transaction_id: TransactionId, - transaction_ttl: SanitizedTransactionTTL, - ) { - let transaction_state = self - .get_mut_transaction_state(&transaction_id) - .expect("transaction must exist"); - let priority_id = TransactionPriorityId::new(transaction_state.priority(), transaction_id); - transaction_state.transition_to_unprocessed(transaction_ttl); - self.push_id_into_queue(priority_id); - } - - /// Pushes a transaction id into the priority queue. If the queue is full, the lowest priority - /// transaction will be dropped (removed from the queue and map). - /// Returns `true` if a packet was dropped due to capacity limits. - pub(crate) fn push_id_into_queue(&mut self, priority_id: TransactionPriorityId) -> bool { - if self.remaining_queue_capacity() == 0 { - let popped_id = self.priority_queue.push_pop_min(priority_id); - self.remove_by_id(&popped_id.id); - true - } else { - self.priority_queue.push(priority_id); - false - } - } - - /// Remove transaction by id.
- pub(crate) fn remove_by_id(&mut self, id: &TransactionId) { - self.id_to_transaction_state - .remove(id) - .expect("transaction must exist"); - } - - pub(crate) fn get_min_max_priority(&self) -> MinMaxResult<u64> { - match self.priority_queue.peek_min() { - Some(min) => match self.priority_queue.peek_max() { - Some(max) => MinMaxResult::MinMax(min.priority, max.priority), - None => MinMaxResult::OneElement(min.priority), - }, - None => MinMaxResult::NoElements, - } - } -} - -#[cfg(test)] -mod tests { - use { - super::*, - solana_sdk::{ - compute_budget::ComputeBudgetInstruction, - hash::Hash, - message::Message, - packet::Packet, - signature::Keypair, - signer::Signer, - slot_history::Slot, - system_instruction, - transaction::{SanitizedTransaction, Transaction}, - }, - }; - - /// Returns (transaction_ttl, priority, cost) - fn test_transaction( - priority: u64, - ) -> ( - SanitizedTransactionTTL, - Arc<ImmutableDeserializedPacket>, - u64, - u64, - ) { - let from_keypair = Keypair::new(); - let ixs = vec![ - system_instruction::transfer( - &from_keypair.pubkey(), - &solana_sdk::pubkey::new_rand(), - 1, - ), - ComputeBudgetInstruction::set_compute_unit_price(priority), - ]; - let message = Message::new(&ixs, Some(&from_keypair.pubkey())); - let tx = SanitizedTransaction::from_transaction_for_tests(Transaction::new( - &[&from_keypair], - message, - Hash::default(), - )); - let packet = Arc::new( - ImmutableDeserializedPacket::new( - Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), - ) - .unwrap(), - ); - let transaction_ttl = SanitizedTransactionTTL { - transaction: tx, - max_age_slot: Slot::MAX, - }; - const TEST_TRANSACTION_COST: u64 = 5000; - (transaction_ttl, packet, priority, TEST_TRANSACTION_COST) - } - - fn push_to_container(container: &mut TransactionStateContainer, num: usize) { - for id in 0..num as u64 { - let priority = id; - let (transaction_ttl, packet, priority, cost) = test_transaction(priority); - container.insert_new_transaction( - TransactionId::new(id), - transaction_ttl, - packet, - priority, - cost, - ); - } - } - - #[test] - fn test_is_empty() { - let mut container = TransactionStateContainer::with_capacity(1); - assert!(container.is_empty()); - - push_to_container(&mut container, 1); - assert!(!container.is_empty()); - } - - #[test] - fn test_priority_queue_capacity() { - let mut container = TransactionStateContainer::with_capacity(1); - push_to_container(&mut container, 5); - - assert_eq!(container.priority_queue.len(), 1); - assert_eq!(container.id_to_transaction_state.len(), 1); - assert_eq!( - container - .id_to_transaction_state - .iter() - .map(|ts| ts.1.priority()) - .next() - .unwrap(), - 4 - ); - } - - #[test] - fn test_get_mut_transaction_state() { - let mut container = TransactionStateContainer::with_capacity(5); - push_to_container(&mut container, 5); - - let existing_id = TransactionId::new(3); - let non_existing_id = TransactionId::new(7); - assert!(container.get_mut_transaction_state(&existing_id).is_some()); - assert!(container.get_mut_transaction_state(&existing_id).is_some()); - assert!(container - .get_mut_transaction_state(&non_existing_id) - .is_none()); - } -} diff --git a/core/src/banking_stage/unprocessed_packet_batches.rs b/core/src/banking_stage/unprocessed_packet_batches.rs index f92eeb09c57b54..c2616867f3f0be 100644 --- a/core/src/banking_stage/unprocessed_packet_batches.rs +++ b/core/src/banking_stage/unprocessed_packet_batches.rs @@ -1,13 +1,9 @@ use { - super::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}, -
min_max_heap::MinMaxHeap, - solana_perf::packet::Packet, - solana_sdk::hash::Hash, - std::{ + super::immutable_deserialized_packet::{DeserializedPacketError, ImmutableDeserializedPacket}, min_max_heap::MinMaxHeap, solana_perf::packet::Packet, solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_sdk::hash::Hash, std::{ cmp::Ordering, collections::{hash_map::Entry, HashMap}, sync::Arc, - }, + } }; /// Holds deserialized messages, as well as computed message_hash and other things needed to create diff --git a/core/src/banking_stage/unprocessed_transaction_storage.rs b/core/src/banking_stage/unprocessed_transaction_storage.rs index f612f5eaf08b11..5eeaea5d082e62 100644 --- a/core/src/banking_stage/unprocessed_transaction_storage.rs +++ b/core/src/banking_stage/unprocessed_transaction_storage.rs @@ -20,6 +20,7 @@ use { solana_accounts_db::account_locks::validate_account_locks, solana_feature_set::FeatureSet, solana_measure::measure_us, + solana_prio_graph_scheduler::deserializable_packet::DeserializableTxPacket, solana_runtime::bank::Bank, solana_sdk::{ clock::FORWARD_TRANSACTIONS_TO_LEADER_AT_SLOT_OFFSET, hash::Hash, saturating_add_assign, @@ -1326,7 +1327,9 @@ mod tests { VoteSource::Tpu, ); - transaction_storage.insert_batch(vec![ImmutableDeserializedPacket::new(vote.clone())?]); + transaction_storage.insert_batch(vec![ImmutableDeserializedPacket::new( + vote.clone(), + )?]); assert_eq!(1, transaction_storage.len()); // When processing packets, return all packets as retryable so that they diff --git a/prio-graph-scheduler/src/deserializable_packet.rs b/prio-graph-scheduler/src/deserializable_packet.rs index 0f54b0ec6de047..d814351652c61b 100644 --- a/prio-graph-scheduler/src/deserializable_packet.rs +++ b/prio-graph-scheduler/src/deserializable_packet.rs @@ -13,7 +13,7 @@ use std::error::Error; pub trait DeserializableTxPacket: PartialEq + PartialOrd + Eq + Sized { type DeserializeError: Error; - fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError>; + fn new(packet: Packet) -> Result<Self, Self::DeserializeError>; /// This function deserializes packets into transactions, /// computes the blake3 hash of transaction messages.
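For downstream implementors, the rename above only swaps `from_packet` for the conventional `new` constructor name; the contract is unchanged. A minimal, self-contained sketch of a type implementing a trait shaped like `DeserializableTxPacket` (the `RawPacket`, `ParseError`, and `ParsedPacket` names are illustrative stand-ins, not part of this patch):

```rust
use std::error::Error;
use std::fmt;

// Hypothetical stand-in for solana_perf::packet::Packet.
struct RawPacket(Vec<u8>);

#[derive(Debug)]
struct ParseError;

impl fmt::Display for ParseError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "failed to deserialize packet")
    }
}

impl Error for ParseError {}

// Trait shaped like the patch's `DeserializableTxPacket` after the rename.
trait DeserializableTxPacket: Sized {
    type DeserializeError: Error;
    // Renamed from `from_packet` so it reads like a conventional constructor.
    fn new(packet: RawPacket) -> Result<Self, Self::DeserializeError>;
}

struct ParsedPacket {
    payload: Vec<u8>,
}

impl DeserializableTxPacket for ParsedPacket {
    type DeserializeError = ParseError;

    fn new(packet: RawPacket) -> Result<Self, Self::DeserializeError> {
        // Real implementations deserialize a transaction and hash its message;
        // this sketch just rejects empty packets.
        if packet.0.is_empty() {
            return Err(ParseError);
        }
        Ok(Self { payload: packet.0 })
    }
}

fn main() {
    let parsed = ParsedPacket::new(RawPacket(vec![1, 2, 3])).unwrap();
    assert_eq!(parsed.payload, vec![1, 2, 3]);
}
```

Call sites change mechanically, `T::from_packet(p)` becoming `T::new(p)`, as the hunks below show.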
diff --git a/prio-graph-scheduler/src/lib.rs b/prio-graph-scheduler/src/lib.rs index 8e4ddb9d76bb45..9dbb4027c26995 100644 --- a/prio-graph-scheduler/src/lib.rs +++ b/prio-graph-scheduler/src/lib.rs @@ -75,7 +75,7 @@ mod tests { impl DeserializableTxPacket for MockImmutableDeserializedPacket { type DeserializeError = MockDeserializedPacketError; - fn from_packet(packet: Packet) -> Result<Self, Self::DeserializeError> { + fn new(packet: Packet) -> Result<Self, Self::DeserializeError> { let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?; let sanitized_transaction = SanitizedVersionedTransaction::try_from(versioned_transaction)?; diff --git a/prio-graph-scheduler/src/prio_graph_scheduler.rs b/prio-graph-scheduler/src/prio_graph_scheduler.rs index 209e0799951d22..bf391919ed7605 100644 --- a/prio-graph-scheduler/src/prio_graph_scheduler.rs +++ b/prio-graph-scheduler/src/prio_graph_scheduler.rs @@ -679,7 +679,7 @@ mod tests { compute_unit_price, ); let packet = Arc::new( - MockImmutableDeserializedPacket::from_packet( + MockImmutableDeserializedPacket::new( Packet::from_data(None, transaction.to_versioned_transaction()).unwrap(), ) .unwrap(), diff --git a/prio-graph-scheduler/src/transaction_state.rs b/prio-graph-scheduler/src/transaction_state.rs index 422f4c1f8c6506..1e7b7a645622d9 100644 --- a/prio-graph-scheduler/src/transaction_state.rs +++ b/prio-graph-scheduler/src/transaction_state.rs @@ -227,7 +227,7 @@ mod tests { let tx = Transaction::new(&[&from_keypair], message, Hash::default()); let packet = Arc::new( - MockImmutableDeserializedPacket::from_packet(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), + MockImmutableDeserializedPacket::new(Packet::from_data(None, tx.clone()).unwrap()).unwrap(), ); let transaction_ttl = SanitizedTransactionTTL { transaction: SanitizedTransaction::from_transaction_for_tests(tx), diff --git a/prio-graph-scheduler/src/transaction_state_container.rs b/prio-graph-scheduler/src/transaction_state_container.rs index 6e6ad444eea977..4e1e998ba6f751 100644 --- a/prio-graph-scheduler/src/transaction_state_container.rs +++ b/prio-graph-scheduler/src/transaction_state_container.rs @@ -184,7 +184,7 @@ mod tests { Hash::default(), )); let packet = Arc::new( - MockImmutableDeserializedPacket::from_packet( + MockImmutableDeserializedPacket::new( Packet::from_data(None, tx.to_versioned_transaction()).unwrap(), ) .unwrap(), From 7f6ad5642b6f1e85ed0883a353c3e3faf4ea0b9b Mon Sep 17 00:00:00 2001 From: lewis <bupt222@gmail.com> Date: Tue, 15 Oct 2024 10:58:26 +0800 Subject: [PATCH 8/9] chore: remove useless pub --- core/src/banking_stage.rs | 6 +++--- .../banking_stage/immutable_deserialized_packet.rs | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index eeb76fe4d68024..17a9b9fb5856ed 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -61,18 +61,18 @@ pub mod qos_service; pub mod unprocessed_packet_batches; pub mod unprocessed_transaction_storage; -pub mod consume_worker; +mod consume_worker; mod decision_maker; mod forward_packet_batches_by_accounts; mod forward_worker; -pub mod immutable_deserialized_packet; +mod immutable_deserialized_packet; mod latest_unprocessed_votes; mod leader_slot_timing_metrics; mod multi_iterator_scanner; mod packet_deserializer; mod packet_filter; mod packet_receiver; -pub mod read_write_account_set; +mod read_write_account_set; mod scheduler_controller; // Fixed thread size seems to be fastest on GCP setup diff --git a/core/src/banking_stage/immutable_deserialized_packet.rs
b/core/src/banking_stage/immutable_deserialized_packet.rs index 2a29e6f3b10ac6..fe0f27e5aa9599 100644 --- a/core/src/banking_stage/immutable_deserialized_packet.rs +++ b/core/src/banking_stage/immutable_deserialized_packet.rs @@ -42,12 +42,12 @@ pub enum DeserializedPacketError { #[derive(Debug, Eq)] pub struct ImmutableDeserializedPacket { - pub original_packet: Packet, - pub transaction: SanitizedVersionedTransaction, - pub message_hash: Hash, - pub is_simple_vote: bool, - pub compute_unit_price: u64, - pub compute_unit_limit: u32, + original_packet: Packet, + transaction: SanitizedVersionedTransaction, + message_hash: Hash, + is_simple_vote: bool, + compute_unit_price: u64, + compute_unit_limit: u32, } impl DeserializableTxPacket for ImmutableDeserializedPacket { From 8d5bfe156250bc67ff6eb191051f4b586ec04628 Mon Sep 17 00:00:00 2001 From: lewis <bupt222@gmail.com> Date: Thu, 17 Oct 2024 14:55:18 +0800 Subject: [PATCH 9/9] chore: public some api --- .../src/thread_aware_account_locks.rs | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/prio-graph-scheduler/src/thread_aware_account_locks.rs b/prio-graph-scheduler/src/thread_aware_account_locks.rs index d5de72547008c3..f5b134d4f4f296 100644 --- a/prio-graph-scheduler/src/thread_aware_account_locks.rs +++ b/prio-graph-scheduler/src/thread_aware_account_locks.rs @@ -8,16 +8,16 @@ use { }, }; -pub(crate) const MAX_THREADS: usize = u64::BITS as usize; +pub const MAX_THREADS: usize = u64::BITS as usize; /// Identifier for a thread -pub(crate) type ThreadId = usize; // 0..MAX_THREADS-1 +pub type ThreadId = usize; // 0..MAX_THREADS-1 type LockCount = u32; /// A bit-set of threads an account is scheduled or can be scheduled for. #[derive(Copy, Clone, PartialEq, Eq)] -pub(crate) struct ThreadSet(u64); +pub struct ThreadSet(u64); struct AccountWriteLocks { thread_id: ThreadId, @@ -44,7 +44,7 @@ struct AccountLocks { /// that already hold locks on the account. This is useful for allowing /// queued transactions to be scheduled on a thread while the transaction /// is still being executed on the thread. -pub(crate) struct ThreadAwareAccountLocks { +pub struct ThreadAwareAccountLocks { /// Number of threads. num_threads: usize, // 0..MAX_THREADS /// Locks for each account. An account should only have an entry if there @@ -54,7 +54,7 @@ pub struct ThreadAwareAccountLocks { impl ThreadAwareAccountLocks { /// Creates a new `ThreadAwareAccountLocks` with the given number of threads. - pub(crate) fn new(num_threads: usize) -> Self { + pub fn new(num_threads: usize) -> Self { assert!(num_threads > 0, "num threads must be > 0"); assert!( num_threads <= MAX_THREADS, "num threads must be <= {MAX_THREADS}" ); @@ -74,7 +74,7 @@ impl ThreadAwareAccountLocks { /// selected by the `thread_selector` function. /// `thread_selector` is only called if all accounts are schedulable, meaning /// that the `thread_set` passed to `thread_selector` is non-empty. - pub(crate) fn try_lock_accounts<'a>( + pub fn try_lock_accounts<'a>( &mut self, write_account_locks: impl Iterator<Item = &'a Pubkey> + Clone, read_account_locks: impl Iterator<Item = &'a Pubkey> + Clone, @@ -93,7 +93,7 @@ impl ThreadAwareAccountLocks { } /// Unlocks the accounts for the given thread.
- pub(crate) fn unlock_accounts<'a>( + pub fn unlock_accounts<'a>( &mut self, write_account_locks: impl Iterator<Item = &'a Pubkey>, read_account_locks: impl Iterator<Item = &'a Pubkey>, @@ -371,12 +371,12 @@ impl Debug for ThreadSet { impl ThreadSet { #[inline(always)] - pub(crate) const fn none() -> Self { + pub const fn none() -> Self { Self(0b0) } #[inline(always)] - pub(crate) const fn any(num_threads: usize) -> Self { + pub const fn any(num_threads: usize) -> Self { if num_threads == MAX_THREADS { Self(u64::MAX) } else { @@ -385,42 +385,42 @@ impl ThreadSet { } #[inline(always)] - pub(crate) const fn only(thread_id: ThreadId) -> Self { + pub const fn only(thread_id: ThreadId) -> Self { Self(Self::as_flag(thread_id)) } #[inline(always)] - pub(crate) fn num_threads(&self) -> u32 { + pub fn num_threads(&self) -> u32 { self.0.count_ones() } #[inline(always)] - pub(crate) fn only_one_contained(&self) -> Option<ThreadId> { + pub fn only_one_contained(&self) -> Option<ThreadId> { (self.num_threads() == 1).then_some(self.0.trailing_zeros() as ThreadId) } #[inline(always)] - pub(crate) fn is_empty(&self) -> bool { + pub fn is_empty(&self) -> bool { self == &Self::none() } #[inline(always)] - pub(crate) fn contains(&self, thread_id: ThreadId) -> bool { + pub fn contains(&self, thread_id: ThreadId) -> bool { self.0 & Self::as_flag(thread_id) != 0 } #[inline(always)] - pub(crate) fn insert(&mut self, thread_id: ThreadId) { + pub fn insert(&mut self, thread_id: ThreadId) { self.0 |= Self::as_flag(thread_id); } #[inline(always)] - pub(crate) fn remove(&mut self, thread_id: ThreadId) { + pub fn remove(&mut self, thread_id: ThreadId) { self.0 &= !Self::as_flag(thread_id); } #[inline(always)] - pub(crate) fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> { + pub fn contained_threads_iter(self) -> impl Iterator<Item = ThreadId> {
 (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) }
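With `ThreadSet` and `ThreadAwareAccountLocks` now public, the lock table can be driven from outside the crate. A minimal usage sketch against the signatures shown above (the module path is inferred from the file layout, and the pubkeys and thread count are arbitrary; this is illustrative, not part of the patch):

```rust
use solana_prio_graph_scheduler::thread_aware_account_locks::{
    ThreadAwareAccountLocks, ThreadId, ThreadSet,
};
use solana_sdk::pubkey::Pubkey;

fn main() {
    let num_threads = 4;
    let mut locks = ThreadAwareAccountLocks::new(num_threads);

    let writable = Pubkey::new_unique();
    let readable = Pubkey::new_unique();

    // Pick the first schedulable thread; the selector only runs when the
    // candidate set passed to it is non-empty.
    let selector = |set: ThreadSet| -> ThreadId {
        set.contained_threads_iter().next().unwrap()
    };

    let thread_id = locks
        .try_lock_accounts(
            [&writable].into_iter(),
            [&readable].into_iter(),
            ThreadSet::any(num_threads),
            selector,
        )
        .expect("no conflicting locks, so some thread must be schedulable");

    // A second write lock on the same account can only land on the thread
    // that already holds the write lock.
    assert_eq!(
        locks.try_lock_accounts(
            [&writable].into_iter(),
            std::iter::empty(),
            ThreadSet::any(num_threads),
            selector,
        ),
        Some(thread_id)
    );

    // Locks are counted, so both write grabs and the read grab must be
    // released for the entries to be dropped.
    locks.unlock_accounts([&writable].into_iter(), [&readable].into_iter(), thread_id);
    locks.unlock_accounts([&writable].into_iter(), std::iter::empty(), thread_id);
}
```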