diff --git a/Cargo.toml b/Cargo.toml index db72e3a..c69b739 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "resolvo" version = "0.3.0" -authors = ["Adolfo Ochagavía ", "Bas Zalmstra ", "Tim de Jager " ] +authors = ["Adolfo Ochagavía ", "Bas Zalmstra ", "Tim de Jager "] description = "Fast package resolver written in Rust (CDCL based SAT solving)" keywords = ["dependency", "solver", "version"] categories = ["algorithms"] @@ -10,17 +10,25 @@ repository = "https://github.com/mamba-org/resolvo" license = "BSD-3-Clause" edition = "2021" readme = "README.md" +resolver = "2" [dependencies] -itertools = "0.11.0" +itertools = "0.12.1" petgraph = "0.6.4" tracing = "0.1.37" elsa = "1.9.0" bitvec = "1.0.1" serde = { version = "1.0", features = ["derive"], optional = true } +futures = { version = "0.3.30", default-features = false, features = ["alloc"] } +event-listener = "5.0.0" + +tokio = { version = "1.35.1", features = ["rt"], optional = true } +async-std = { version = "1.12.0", default-features = false, features = ["alloc", "default"], optional = true } [dev-dependencies] insta = "1.31.0" indexmap = "2.0.0" proptest = "1.2.0" tracing-test = { version = "0.2.4", features = ["no-env-filter"] } +tokio = { version = "1.35.1", features = ["time", "rt"] } +resolvo = { path = ".", features = ["tokio"] } \ No newline at end of file diff --git a/rust-toolchain b/rust-toolchain index 0834888..7c7053a 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.72.0 +1.75.0 diff --git a/src/lib.rs b/src/lib.rs index 5985145..e520eda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,6 +14,7 @@ pub(crate) mod internal; mod pool; pub mod problem; pub mod range; +pub mod runtime; mod solvable; mod solver; @@ -30,9 +31,10 @@ use std::{ any::Any, fmt::{Debug, Display}, hash::Hash, + rc::Rc, }; -/// The solver is based around the fact that for for every package name we are trying to find a +/// The solver is based around the fact that for every package name we are trying to find a /// single variant. Variants are grouped by their respective package name. A package name is /// anything that we can compare and hash for uniqueness checks. /// @@ -44,7 +46,7 @@ pub trait PackageName: Eq + Hash {} impl PackageName for N {} -/// A [`VersionSet`] is describes a set of "versions". The trait defines whether a given version +/// A [`VersionSet`] describes a set of "versions". The trait defines whether a given version /// is part of the set or not. /// /// One could implement [`VersionSet`] for [`std::ops::Range`] where the implementation @@ -61,21 +63,26 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash { /// packages that are available in the system. pub trait DependencyProvider: Sized { /// Returns the [`Pool`] that is used to allocate the Ids returned from this instance - fn pool(&self) -> &Pool; + fn pool(&self) -> Rc>; /// Sort the specified solvables based on which solvable to try first. The solver will /// iteratively try to select the highest version. If a conflict is found with the highest /// version the next version is tried. This continues until a solution is found. - fn sort_candidates(&self, solver: &SolverCache, solvables: &mut [SolvableId]); + #[allow(async_fn_in_trait)] + async fn sort_candidates( + &self, + solver: &SolverCache, + solvables: &mut [SolvableId], + ); - /// Returns a list of solvables that should be considered when a package with the given name is + /// Obtains a list of solvables that should be considered when a package with the given name is /// requested. - /// - /// Returns `None` if no such package exist. - fn get_candidates(&self, name: NameId) -> Option; + #[allow(async_fn_in_trait)] + async fn get_candidates(&self, name: NameId) -> Option; /// Returns the dependencies for the specified solvable. - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; + #[allow(async_fn_in_trait)] + async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies; /// Whether the solver should stop the dependency resolution algorithm. /// @@ -126,6 +133,7 @@ pub struct Candidates { } /// Holds information about the dependencies of a package. +#[derive(Debug, Clone)] pub enum Dependencies { /// The dependencies are known. Known(KnownDependencies), diff --git a/src/problem.rs b/src/problem.rs index 060e61a..caffb10 100644 --- a/src/problem.rs +++ b/src/problem.rs @@ -4,7 +4,6 @@ use std::collections::{HashMap, HashSet}; use std::fmt; use std::fmt::{Display, Formatter}; use std::hash::Hash; - use std::rc::Rc; use itertools::Itertools; @@ -12,10 +11,10 @@ use petgraph::graph::{DiGraph, EdgeIndex, EdgeReference, NodeIndex}; use petgraph::visit::{Bfs, DfsPostOrder, EdgeRef}; use petgraph::Direction; -use crate::internal::id::StringId; use crate::{ - internal::id::{ClauseId, SolvableId, VersionSetId}, + internal::id::{ClauseId, SolvableId, StringId, VersionSetId}, pool::Pool, + runtime::AsyncRuntime, solver::{clause::Clause, Solver}, DependencyProvider, PackageName, SolvableDisplay, VersionSet, }; @@ -41,9 +40,9 @@ impl Problem { } /// Generates a graph representation of the problem (see [`ProblemGraph`] for details) - pub fn graph>( + pub fn graph, RT: AsyncRuntime>( &self, - solver: &Solver, + solver: &Solver, ) -> ProblemGraph { let mut graph = DiGraph::::default(); let mut nodes: HashMap = HashMap::default(); @@ -53,7 +52,7 @@ impl Problem { let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency); for clause_id in &self.clauses { - let clause = &solver.clauses[*clause_id].kind; + let clause = &solver.clauses.borrow()[*clause_id].kind; match clause { Clause::InstallRoot => (), Clause::Excluded(solvable, reason) => { @@ -73,7 +72,7 @@ impl Problem { &Clause::Requires(package_id, version_set_id) => { let package_node = Self::add_node(&mut graph, &mut nodes, package_id); - let candidates = solver.cache.get_or_cache_sorted_candidates(version_set_id).unwrap_or_else(|_| { + let candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(version_set_id)).unwrap_or_else(|_| { unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`") }); if candidates.is_empty() { @@ -167,13 +166,15 @@ impl Problem { N: PackageName + Display, D: DependencyProvider, M: SolvableDisplay, + RT: AsyncRuntime, >( &self, - solver: &'a Solver, + solver: &'a Solver, + pool: Rc>, merged_solvable_display: &'a M, ) -> DisplayUnsat<'a, VS, N, M> { let graph = self.graph(solver); - DisplayUnsat::new(graph, solver.pool(), merged_solvable_display) + DisplayUnsat::new(graph, pool, merged_solvable_display) } } @@ -515,7 +516,7 @@ pub struct DisplayUnsat<'pool, VS: VersionSet, N: PackageName + Display, M: Solv merged_candidates: HashMap>, installable_set: HashSet, missing_set: HashSet, - pool: &'pool Pool, + pool: Rc>, merged_solvable_display: &'pool M, } @@ -524,10 +525,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay> { pub(crate) fn new( graph: ProblemGraph, - pool: &'pool Pool, + pool: Rc>, merged_solvable_display: &'pool M, ) -> Self { - let merged_candidates = graph.simplify(pool); + let merged_candidates = graph.simplify(&pool); let installable_set = graph.get_installable_set(); let missing_set = graph.get_missing_set(); @@ -669,10 +670,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay> let version = if let Some(merged) = self.merged_candidates.get(&solvable_id) { reported.extend(merged.ids.iter().cloned()); self.merged_solvable_display - .display_candidates(self.pool, &merged.ids) + .display_candidates(&self.pool, &merged.ids) } else { self.merged_solvable_display - .display_candidates(self.pool, &[solvable_id]) + .display_candidates(&self.pool, &[solvable_id]) }; let excluded = graph @@ -796,9 +797,9 @@ impl> fmt::D writeln!( f, "{indent}{} {} is locked, but another version is required as reported above", - locked.name.display(self.pool), + locked.name.display(&self.pool), self.merged_solvable_display - .display_candidates(self.pool, &[solvable_id]) + .display_candidates(&self.pool, &[solvable_id]) )?; } ConflictCause::Excluded => continue, diff --git a/src/range.rs b/src/range.rs index a08e1ad..ce7cf07 100644 --- a/src/range.rs +++ b/src/range.rs @@ -409,7 +409,7 @@ pub mod tests { segments.push((start_bound, Unbounded)); } - return Range { segments }.check_invariants(); + Range { segments }.check_invariants() }) } diff --git a/src/runtime.rs b/src/runtime.rs new file mode 100644 index 0000000..800330b --- /dev/null +++ b/src/runtime.rs @@ -0,0 +1,78 @@ +//! Solving in resolvo is a compute heavy operation. However, while computing the solver will +//! request additional information from the [`crate::DependencyProvider`] and a dependency provider +//! might want to perform multiple requests concurrently. To that end the +//! [`crate::DependencyProvider`]s methods are async. The implementer can implement the async +//! operations in any way they choose including with any runtime they choose. +//! However, the solver itself is completely single threaded, but it still has to await the calls to +//! the dependency provider. Using the [`AsyncRuntime`] allows the caller of the solver to choose +//! how to await the futures. +//! +//! By default, the solver uses the [`NowOrNeverRuntime`] runtime which polls any future once. If +//! the future yields (thus requiring an additional poll) the runtime panics. If the methods of +//! [`crate::DependencyProvider`] do not yield (e.g. do not `.await`) this will suffice. +//! +//! Only if the [`crate::DependencyProvider`] implementation yields you will need to provide a +//! [`AsyncRuntime`] to the solver. +//! +//! ## `tokio` +//! +//! The [`AsyncRuntime`] trait is implemented both for [`tokio::runtime::Handle`] and for +//! [`tokio::runtime::Runtime`]. +//! +//! ## `async-std` +//! +//! Use the [`AsyncStdRuntime`] struct to block on async methods from the +//! [`crate::DependencyProvider`] using the `async-std` executor. + +use futures::FutureExt; +use std::future::Future; + +/// A trait to wrap an async runtime. +pub trait AsyncRuntime { + /// Runs the given future on the current thread, blocking until it is complete, and yielding its + /// resolved result. + fn block_on(&self, f: F) -> F::Output; +} + +/// The simplest runtime possible evaluates and consumes the future, returning the resulting +/// output if the future is ready after the first call to [`Future::poll`]. If the future does +/// yield the runtime panics. +/// +/// This assumes that the passed in future never yields. For purely blocking computations this +/// is the preferred method since it also incurs very little overhead and doesn't require the +/// inclusion of a heavy-weight runtime. +#[derive(Default, Copy, Clone)] +pub struct NowOrNeverRuntime; + +impl AsyncRuntime for NowOrNeverRuntime { + fn block_on(&self, f: F) -> F::Output { + f.now_or_never() + .expect("can only use non-yielding futures with the NowOrNeverRuntime") + } +} + +#[cfg(feature = "tokio")] +impl AsyncRuntime for tokio::runtime::Handle { + fn block_on(&self, f: F) -> F::Output { + self.block_on(f) + } +} + +#[cfg(feature = "tokio")] +impl AsyncRuntime for tokio::runtime::Runtime { + fn block_on(&self, f: F) -> F::Output { + self.block_on(f) + } +} + +/// An implementation of [`AsyncRuntime`] that spawns and awaits any passed future on the current +/// thread. +#[cfg(feature = "async-std")] +pub struct AsyncStdRuntime; + +#[cfg(feature = "async-std")] +impl AsyncRuntime for AsyncStdRuntime { + fn block_on(&self, f: F) -> F::Output { + async_std::task::block_on(f) + } +} diff --git a/src/solver/cache.rs b/src/solver/cache.rs index 8d25d3d..eed36db 100644 --- a/src/solver/cache.rs +++ b/src/solver/cache.rs @@ -10,9 +10,8 @@ use crate::{ }; use bitvec::vec::BitVec; use elsa::FrozenMap; -use std::any::Any; -use std::cell::RefCell; -use std::marker::PhantomData; +use event_listener::Event; +use std::{any::Any, cell::RefCell, collections::HashMap, marker::PhantomData, rc::Rc}; /// Keeps a cache of previously computed and/or requested information about solvables and version /// sets. @@ -22,6 +21,7 @@ pub struct SolverCache, package_name_to_candidates: FrozenCopyMap, + package_name_to_candidates_in_flight: RefCell>>, /// A mapping of `VersionSetId` to the candidates that match that set. version_set_candidates: FrozenMap>, @@ -53,6 +53,7 @@ impl> SolverCache> SolverCache &Pool { + pub fn pool(&self) -> Rc> { self.provider.pool() } @@ -74,7 +75,7 @@ impl> SolverCache Result<&Candidates, Box> { @@ -89,32 +90,63 @@ impl> SolverCache { + // Found an in-flight request, wait for that request to finish and return the computed result. + in_flight.listen().await; + self.package_name_to_candidates + .get_copy(&package_name) + .expect("after waiting for a request the result should be available") } - } + None => { + // Prepare an in-flight notifier for other requests coming in. + self.package_name_to_candidates_in_flight + .borrow_mut() + .insert(package_name, Rc::new(Event::new())); + + // Otherwise we have to get them from the DependencyProvider + let candidates = self + .provider + .get_candidates(package_name) + .await + .unwrap_or_default(); + + // Store information about which solvables dependency information is easy to + // retrieve. + { + let mut hint_dependencies_available = + self.hint_dependencies_available.borrow_mut(); + for hint_candidate in candidates.hint_dependencies_available.iter() { + let idx = hint_candidate.to_usize(); + if hint_dependencies_available.len() <= idx { + hint_dependencies_available.resize(idx + 1, false); + } + hint_dependencies_available.set(idx, true) + } + } + + // Allocate an ID so we can refer to the candidates from everywhere + let candidates_id = self.candidates.alloc(candidates); + self.package_name_to_candidates + .insert_copy(package_name, candidates_id); - // Allocate an ID so we can refer to the candidates from everywhere - let candidates_id = self.candidates.alloc(candidates); - self.package_name_to_candidates - .insert_copy(package_name, candidates_id); + // Remove the in-flight request now that we inserted the result and notify any waiters + let notifier = self + .package_name_to_candidates_in_flight + .borrow_mut() + .remove(&package_name) + .expect("notifier should be there"); + notifier.notify(usize::MAX); - candidates_id + candidates_id + } + } } }; @@ -126,23 +158,24 @@ impl> SolverCache Result<&[SolvableId], Box> { match self.version_set_candidates.get(&version_set_id) { Some(candidates) => Ok(candidates), None => { - let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let pool = self.pool(); + let package_name = pool.resolve_version_set_package_name(version_set_id); + let version_set = pool.resolve_version_set(version_set_id); + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates .iter() .copied() .filter(|&p| { - let version = self.pool().resolve_internal_solvable(p).solvable().inner(); + let version = pool.resolve_internal_solvable(p).solvable().inner(); version_set.contains(version) }) .collect(); @@ -158,23 +191,24 @@ impl> SolverCache Result<&[SolvableId], Box> { match self.version_set_inverse_candidates.get(&version_set_id) { Some(candidates) => Ok(candidates), None => { - let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let version_set = self.pool().resolve_version_set(version_set_id); - let candidates = self.get_or_cache_candidates(package_name)?; + let pool = self.pool(); + let package_name = pool.resolve_version_set_package_name(version_set_id); + let version_set = pool.resolve_version_set(version_set_id); + let candidates = self.get_or_cache_candidates(package_name).await?; let matching_candidates = candidates .candidates .iter() .copied() .filter(|&p| { - let version = self.pool().resolve_internal_solvable(p).solvable().inner(); + let version = pool.resolve_internal_solvable(p).solvable().inner(); !version_set.contains(version) }) .collect(); @@ -191,7 +225,7 @@ impl> SolverCache Result<&[SolvableId], Box> { @@ -199,13 +233,17 @@ impl> SolverCache Ok(candidates), None => { let package_name = self.pool().resolve_version_set_package_name(version_set_id); - let matching_candidates = self.get_or_cache_matching_candidates(version_set_id)?; - let candidates = self.get_or_cache_candidates(package_name)?; + let matching_candidates = self + .get_or_cache_matching_candidates(version_set_id) + .await?; + let candidates = self.get_or_cache_candidates(package_name).await?; // Sort all the candidates in order in which they should be tried by the solver. let mut sorted_candidates = Vec::new(); sorted_candidates.extend_from_slice(matching_candidates); - self.provider.sort_candidates(self, &mut sorted_candidates); + self.provider + .sort_candidates(self, &mut sorted_candidates) + .await; // If we have a solvable that we favor, we sort that to the front. This ensures // that the version that is favored is picked first. @@ -228,7 +266,7 @@ impl> SolverCache Result<&Dependencies, Box> { @@ -242,7 +280,7 @@ impl> SolverCache, + conflicting_clauses: Vec, + negative_assertions: Vec<(SolvableId, ClauseId)>, + clauses_to_watch: Vec, +} + /// Drives the SAT solving process -pub struct Solver> { +pub struct Solver< + VS: VersionSet, + N: PackageName, + D: DependencyProvider, + RT: AsyncRuntime = NowOrNeverRuntime, +> { + /// The [Pool] used by the solver + pub pool: Rc>, + pub(crate) async_runtime: RT, pub(crate) cache: SolverCache, - pub(crate) clauses: Arena, + pub(crate) clauses: RefCell>, requires_clauses: Vec<(SolvableId, VersionSetId, ClauseId)>, watches: WatchMap, @@ -43,8 +66,8 @@ pub struct Solver> learnt_why: Mapping>, learnt_clause_ids: Vec, - clauses_added_for_package: HashSet, - clauses_added_for_solvable: HashSet, + clauses_added_for_package: RefCell>, + clauses_added_for_solvable: RefCell>, decision_tracker: DecisionTracker, @@ -52,15 +75,20 @@ pub struct Solver> root_requirements: Vec, } -impl> Solver { - /// Create a solver, using the provided pool +impl> + Solver +{ + /// Create a solver, using the provided pool and async runtime. pub fn new(provider: D) -> Self { + let pool = provider.pool(); Self { cache: SolverCache::new(provider), - clauses: Arena::new(), + pool, + async_runtime: NowOrNeverRuntime, + clauses: RefCell::new(Arena::new()), requires_clauses: Default::default(), watches: WatchMap::new(), - negative_assertions: Vec::new(), + negative_assertions: Default::default(), learnt_clauses: Arena::new(), learnt_why: Mapping::new(), learnt_clause_ids: Vec::new(), @@ -70,11 +98,6 @@ impl> Solver &Pool { - self.cache.pool() - } } /// The root cause of a solver error. @@ -104,7 +127,29 @@ pub(crate) enum PropagationError { Cancelled(Box), } -impl> Solver { +impl, RT: AsyncRuntime> + Solver +{ + /// Set the runtime of the solver to `runtime`. + pub fn with_runtime(self, runtime: RT2) -> Solver { + Solver { + pool: self.pool, + async_runtime: runtime, + cache: self.cache, + clauses: self.clauses, + requires_clauses: self.requires_clauses, + watches: self.watches, + negative_assertions: self.negative_assertions, + learnt_clauses: self.learnt_clauses, + learnt_why: self.learnt_why, + learnt_clause_ids: self.learnt_clause_ids, + clauses_added_for_package: self.clauses_added_for_package, + clauses_added_for_solvable: self.clauses_added_for_solvable, + decision_tracker: self.decision_tracker, + root_requirements: self.root_requirements, + } + } + /// Solves the provided `jobs` and returns a transaction from the found solution /// /// Returns a [`Problem`] if no solution was found, which provides ways to inspect the causes @@ -123,7 +168,7 @@ impl> Sol // The first clause will always be the install root clause. Here we verify that this is // indeed the case. - let root_clause = self.clauses.alloc(ClauseState::root()); + let root_clause = self.clauses.borrow_mut().alloc(ClauseState::root()); assert_eq!(root_clause, ClauseId::install_root()); // Run SAT @@ -145,26 +190,6 @@ impl> Sol Ok(steps) } - /// Adds a clause to the solver and immediately starts watching its literals. - fn add_and_watch_clause(&mut self, clause: ClauseState) -> ClauseId { - let clause_id = self.clauses.alloc(clause); - let clause = &self.clauses[clause_id]; - - // Add in requires clause lookup - if let &Clause::Requires(solvable_id, version_set_id) = &clause.kind { - self.requires_clauses - .push((solvable_id, version_set_id, clause_id)); - } - - // Start watching the literals of the clause - let clause = &mut self.clauses[clause_id]; - if clause.has_watches() { - self.watches.start_watching(clause, clause_id); - } - - clause_id - } - /// Adds clauses for a solvable. These clauses include requirements and constrains on other /// solvables. /// @@ -172,217 +197,334 @@ impl> Sol /// /// If the provider has requested the solving process to be cancelled, the cancellation value /// will be returned as an `Err(...)`. - fn add_clauses_for_solvable( - &mut self, - solvable_id: SolvableId, - ) -> Result<(Vec, Vec), Box> { - if self.clauses_added_for_solvable.contains(&solvable_id) { - return Ok((Vec::new(), Vec::new())); + async fn add_clauses_for_solvables( + &self, + solvable_ids: impl IntoIterator, + ) -> Result> { + let mut output = AddClauseOutput::default(); + + pub enum TaskResult<'i> { + Dependencies { + solvable_id: SolvableId, + dependencies: Dependencies, + }, + SortedCandidates { + solvable_id: SolvableId, + version_set_id: VersionSetId, + candidates: &'i [SolvableId], + }, + NonMatchingCandidates { + solvable_id: SolvableId, + version_set_id: VersionSetId, + non_matching_candidates: &'i [SolvableId], + }, + Candidates { + name_id: NameId, + package_candidates: &'i Candidates, + }, } - let mut new_clauses = Vec::new(); - let mut conflicting_clauses = Vec::new(); - let mut queue = vec![solvable_id]; - let mut seen = HashSet::new(); - seen.insert(solvable_id); + // Mark the initial seen solvables as seen + let mut pending_solvables = vec![]; + { + let mut clauses_added_for_solvable = self.clauses_added_for_solvable.borrow_mut(); + for solvable_id in solvable_ids { + if clauses_added_for_solvable.insert(solvable_id) { + pending_solvables.push(solvable_id); + } + } + } - while let Some(solvable_id) = queue.pop() { - let solvable = self.pool().resolve_internal_solvable(solvable_id); - tracing::trace!( - "┝━ adding clauses for dependencies of {}", - solvable.display(self.pool()) - ); + let mut seen = pending_solvables.iter().copied().collect::>(); + let mut pending_futures = FuturesUnordered::new(); + loop { + // Iterate over all pending solvables and request their dependencies. + for solvable_id in pending_solvables.drain(..) { + // Get the solvable information and request its requirements and constraints + let solvable = self.pool.resolve_internal_solvable(solvable_id); + tracing::trace!( + "┝━ adding clauses for dependencies of {}", + solvable.display(&self.pool) + ); - // Determine the dependencies of the current solvable. There are two cases here: - // 1. The solvable is the root solvable which only provides required dependencies. - // 2. The solvable is a package candidate in which case we request the corresponding - // dependencies from the `DependencyProvider`. - let (requirements, constrains) = match solvable.inner { - SolvableInner::Root => (self.root_requirements.clone(), Vec::new()), - SolvableInner::Package(_) => { - let deps = self.cache.get_or_cache_dependencies(solvable_id)?; - match deps { - Dependencies::Known(deps) => { - (deps.requirements.clone(), deps.constrains.clone()) - } + let get_dependencies_fut = match solvable.inner { + SolvableInner::Root => ready(Ok(TaskResult::Dependencies { + solvable_id, + dependencies: Dependencies::Known(KnownDependencies { + requirements: self.root_requirements.clone(), + constrains: vec![], + }), + })) + .left_future(), + SolvableInner::Package(_) => async move { + let deps = self.cache.get_or_cache_dependencies(solvable_id).await?; + Ok(TaskResult::Dependencies { + solvable_id, + dependencies: deps.clone(), + }) + } + .right_future(), + }; + + pending_futures.push(get_dependencies_fut.boxed_local()); + } + + let Some(result) = pending_futures.next().await else { + // No more pending results + break; + }; + + let mut clauses_added_for_solvable = self.clauses_added_for_solvable.borrow_mut(); + let mut clauses_added_for_package = self.clauses_added_for_package.borrow_mut(); + + match result? { + TaskResult::Dependencies { + solvable_id, + dependencies, + } => { + // Get the solvable information and request its requirements and constraints + let solvable = self.pool.resolve_internal_solvable(solvable_id); + tracing::trace!( + "dependencies available for {}", + solvable.display(&self.pool) + ); + + let (requirements, constrains) = match dependencies { + Dependencies::Known(deps) => (deps.requirements, deps.constrains), Dependencies::Unknown(reason) => { // There is no information about the solvable's dependencies, so we add // an exclusion clause for it let clause_id = self .clauses - .alloc(ClauseState::exclude(solvable_id, *reason)); + .borrow_mut() + .alloc(ClauseState::exclude(solvable_id, reason)); // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable_id, clause_id)); - - new_clauses.push(clause_id); + output.negative_assertions.push((solvable_id, clause_id)); + // There might be a conflict now if self.decision_tracker.assigned_value(solvable_id) == Some(true) { - conflicting_clauses.push(clause_id); + output.conflicting_clauses.push(clause_id); } continue; } - } - } - }; + }; - // Add clauses for the requirements - for version_set_id in requirements { - let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; - - // Find all the solvables that match for the given version set - let candidates = self.cache.get_or_cache_sorted_candidates(version_set_id)?; - - // Queue requesting the dependencies of the candidates as well if they are cheaply - // available from the dependency provider. - for &candidate in candidates { - if seen.insert(candidate) - && self.cache.are_dependencies_available_for(candidate) - && !self.clauses_added_for_solvable.contains(&candidate) - { - queue.push(candidate); + for version_set_id in chain(requirements.iter(), constrains.iter()).copied() { + let dependency_name = + self.pool.resolve_version_set_package_name(version_set_id); + + if clauses_added_for_package.insert(dependency_name) { + tracing::trace!( + "┝━ adding clauses for package '{}'", + self.pool.resolve_package_name(dependency_name) + ); + + pending_futures.push( + async move { + let package_candidates = + self.cache.get_or_cache_candidates(dependency_name).await?; + Ok(TaskResult::Candidates { + name_id: dependency_name, + package_candidates, + }) + } + .boxed_local(), + ); + } } - } - // Add the requires clause - let no_candidates = candidates.is_empty(); - let (clause, conflict) = ClauseState::requires( - solvable_id, - version_set_id, - candidates, - &self.decision_tracker, - ); - - let clause_id = self.add_and_watch_clause(clause); + for version_set_id in requirements { + // Find all the solvable that match for the given version set + pending_futures.push( + async move { + let candidates = self + .cache + .get_or_cache_sorted_candidates(version_set_id) + .await?; + Ok(TaskResult::SortedCandidates { + solvable_id, + version_set_id, + candidates, + }) + } + .boxed_local(), + ); + } - if conflict { - conflicting_clauses.push(clause_id); - } else if no_candidates { - // Add assertions for unit clauses (i.e. those with no matching candidates) - self.negative_assertions.push((solvable_id, clause_id)); + for version_set_id in constrains { + // Find all the solvables that match for the given version set + pending_futures.push( + async move { + let non_matching_candidates = self + .cache + .get_or_cache_non_matching_candidates(version_set_id) + .await?; + Ok(TaskResult::NonMatchingCandidates { + solvable_id, + version_set_id, + non_matching_candidates, + }) + } + .boxed_local(), + ) + } } + TaskResult::Candidates { + name_id, + package_candidates, + } => { + // Get the solvable information and request its requirements and constraints + let solvable = self.pool.resolve_package_name(name_id); + tracing::trace!("package candidates available for {}", solvable); + + let locked_solvable_id = package_candidates.locked; + let candidates = &package_candidates.candidates; + + // Check the assumption that no decision has been made about any of the solvables. + for &candidate in candidates { + debug_assert!( + self.decision_tracker.assigned_value(candidate).is_none(), + "a decision has been made about a candidate of a package that was not properly added yet." + ); + } - new_clauses.push(clause_id); - } + // Each candidate gets a clause to disallow other candidates. + for (i, &candidate) in candidates.iter().enumerate() { + for &other_candidate in &candidates[i + 1..] { + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); - // Add clauses for the constraints - for version_set_id in constrains { - let dependency_name = self.pool().resolve_version_set_package_name(version_set_id); - self.add_clauses_for_package(dependency_name)?; + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + output.clauses_to_watch.push(clause_id); + } + } - // Find all the solvables that match for the given version set - let constrained_candidates = self - .cache - .get_or_cache_non_matching_candidates(version_set_id)?; + // If there is a locked solvable, forbid other solvables. + if let Some(locked_solvable_id) = locked_solvable_id { + for &other_candidate in candidates { + if other_candidate != locked_solvable_id { + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); + + debug_assert!(self.clauses.borrow_mut()[clause_id].has_watches()); + output.clauses_to_watch.push(clause_id); + } + } + } - // Add forbidden clauses for the candidates - for forbidden_candidate in constrained_candidates.iter().copied().collect_vec() { - let (clause, conflict) = ClauseState::constrains( - solvable_id, - forbidden_candidate, - version_set_id, - &self.decision_tracker, - ); + // Add a clause for solvables that are externally excluded. + for (solvable, reason) in package_candidates.excluded.iter().copied() { + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::exclude(solvable, reason)); - let clause_id = self.add_and_watch_clause(clause); + // Exclusions are negative assertions, tracked outside of the watcher system + output.negative_assertions.push((solvable, clause_id)); - if conflict { - conflicting_clauses.push(clause_id); + // Conflicts should be impossible here + debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true)); } - - new_clauses.push(clause_id) } - } - - // Start by stating the clauses have been added. - self.clauses_added_for_solvable.insert(solvable_id); - } + TaskResult::SortedCandidates { + solvable_id, + version_set_id, + candidates, + } => { + let version_set_name = self.pool.resolve_package_name( + self.pool.resolve_version_set_package_name(version_set_id), + ); + let version_set = self.pool.resolve_version_set(version_set_id); + tracing::trace!( + "sorted candidates available for {} {}", + version_set_name, + version_set + ); - Ok((new_clauses, conflicting_clauses)) - } + // Queue requesting the dependencies of the candidates as well if they are cheaply + // available from the dependency provider. + for &candidate in candidates { + if seen.insert(candidate) + && self.cache.are_dependencies_available_for(candidate) + && clauses_added_for_solvable.insert(candidate) + { + pending_solvables.push(candidate); + } + } - /// Adds all clauses for a specific package name. - /// - /// These clauses include: - /// - /// 1. making sure that only a single candidate for the package is selected (forbid multiple) - /// 2. if there is a locked candidate then that candidate is the only selectable candidate. - /// - /// If this function is called with the same package name twice, the clauses will only be added - /// once. - /// - /// There is no need to propagate after adding these clauses because none of the clauses are - /// assertions (only a single literal) and we assume that no decision has been made about any - /// of the solvables involved. This assumption is checked when debug_assertions are enabled. - /// - /// If the provider has requested the solving process to be cancelled, the cancellation value - /// will be returned as an `Err(...)`. - fn add_clauses_for_package(&mut self, package_name: NameId) -> Result<(), Box> { - if self.clauses_added_for_package.contains(&package_name) { - return Ok(()); - } + // Add the requirements clause + let no_candidates = candidates.is_empty(); + let (clause, conflict) = ClauseState::requires( + solvable_id, + version_set_id, + candidates, + &self.decision_tracker, + ); - tracing::trace!( - "┝━ adding clauses for package '{}'", - self.pool().resolve_package_name(package_name) - ); + let clause_id = self.clauses.borrow_mut().alloc(clause); + let clause = &self.clauses.borrow()[clause_id]; - let package_candidates = self.cache.get_or_cache_candidates(package_name)?; - let locked_solvable_id = package_candidates.locked; - let candidates = &package_candidates.candidates; + let &Clause::Requires(solvable_id, version_set_id) = &clause.kind else { + unreachable!(); + }; - // Check the assumption that no decision has been made about any of the solvables. - for &candidate in candidates { - debug_assert!( - self.decision_tracker.assigned_value(candidate).is_none(), - "a decision has been made about a candidate of a package that was not properly added yet." - ); - } + if clause.has_watches() { + output.clauses_to_watch.push(clause_id); + } - // Each candidate gets a clause to disallow other candidates. - for (i, &candidate) in candidates.iter().enumerate() { - for &other_candidate in &candidates[i + 1..] { - let clause_id = self - .clauses - .alloc(ClauseState::forbid_multiple(candidate, other_candidate)); + output + .new_requires_clauses + .push((solvable_id, version_set_id, clause_id)); - let clause = &mut self.clauses[clause_id]; - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); - } - } + if conflict { + output.conflicting_clauses.push(clause_id); + } else if no_candidates { + // Add assertions for unit clauses (i.e. those with no matching candidates) + output.negative_assertions.push((solvable_id, clause_id)); + } + } + TaskResult::NonMatchingCandidates { + solvable_id, + version_set_id, + non_matching_candidates, + } => { + let version_set_name = self.pool.resolve_package_name( + self.pool.resolve_version_set_package_name(version_set_id), + ); + let version_set = self.pool.resolve_version_set(version_set_id); + tracing::trace!( + "non matching candidates available for {} {}", + version_set_name, + version_set + ); - // If there is a locked solvable, forbid other solvables. - if let Some(locked_solvable_id) = locked_solvable_id { - for &other_candidate in candidates { - if other_candidate != locked_solvable_id { - let clause_id = self - .clauses - .alloc(ClauseState::lock(locked_solvable_id, other_candidate)); + // Add forbidden clauses for the candidates + for &forbidden_candidate in non_matching_candidates { + let (clause, conflict) = ClauseState::constrains( + solvable_id, + forbidden_candidate, + version_set_id, + &self.decision_tracker, + ); - let clause = &mut self.clauses[clause_id]; + let clause_id = self.clauses.borrow_mut().alloc(clause); + output.clauses_to_watch.push(clause_id); - debug_assert!(clause.has_watches()); - self.watches.start_watching(clause, clause_id); + if conflict { + output.conflicting_clauses.push(clause_id); + } + } } } } - // Add a clause for solvables that are externally excluded. - for (solvable, reason) in package_candidates.excluded.iter().copied() { - let clause_id = self.clauses.alloc(ClauseState::exclude(solvable, reason)); - - // Exclusions are negative assertions, tracked outside of the watcher system - self.negative_assertions.push((solvable, clause_id)); - - // Conflicts should be impossible here - debug_assert!(self.decision_tracker.assigned_value(solvable) != Some(true)); - } - - self.clauses_added_for_package.insert(package_name); - Ok(()) + Ok(output) } /// Run the CDCL algorithm to solve the SAT problem @@ -410,7 +552,6 @@ impl> Sol assert!(self.decision_tracker.is_empty()); let mut level = 0; - let mut new_clauses = Vec::new(); loop { // A level of 0 means the decision loop has been completely reset because a partial // solution was invalidated by newly added clauses. @@ -424,7 +565,7 @@ impl> Sol // solution that satisfies the user requirements. tracing::info!( "╤══ install {} at level {level}", - SolvableId::root().display(self.pool()) + SolvableId::root().display(&self.pool) ); self.decision_tracker .try_add_decision( @@ -434,14 +575,14 @@ impl> Sol .expect("already decided"); // Add the clauses for the root solvable. - let (mut clauses, conflicting_clauses) = - self.add_clauses_for_solvable(SolvableId::root())?; - if let Some(clause_id) = conflicting_clauses.into_iter().next() { + let output = self + .async_runtime + .block_on(self.add_clauses_for_solvables(vec![SolvableId::root()]))?; + if let Err(clause_id) = self.process_add_clause_output(output) { return Err(UnsolvableOrCancelled::Unsolvable( self.analyze_unsolvable(clause_id), )); } - new_clauses.append(&mut clauses); } // Propagate decisions from assignments above @@ -459,7 +600,7 @@ impl> Sol // The conflict was caused because new clauses have been added dynamically. // We need to start over. tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); + clause=self.clauses.borrow()[clause_id].debug(&self.pool)); level = 0; self.decision_tracker.clear(); continue; @@ -486,7 +627,12 @@ impl> Sol // Filter only decisions that led to a positive assignment .filter(|d| d.value) // Select solvables for which we do not yet have dependencies - .filter(|d| !self.clauses_added_for_solvable.contains(&d.solvable_id)) + .filter(|d| { + !self + .clauses_added_for_solvable + .borrow() + .contains(&d.solvable_id) + }) .map(|d| (d.solvable_id, d.derived_from)) .collect(); @@ -502,31 +648,52 @@ impl> Sol .copied() .format_with("\n- ", |(id, derived_from), f| f(&format_args!( "{} (derived from {:?})", - id.display(self.pool()), - self.clauses[derived_from].debug(self.pool()), + id.display(&self.pool), + self.clauses.borrow()[derived_from].debug(&self.pool), ))) ); - for (solvable, _) in new_solvables { - // Add the clauses for this particular solvable. - let (mut clauses_for_solvable, conflicting_causes) = - self.add_clauses_for_solvable(solvable)?; - new_clauses.append(&mut clauses_for_solvable); + // Concurrently get the solvable's clauses + let output = self.async_runtime.block_on(self.add_clauses_for_solvables( + new_solvables.iter().map(|(solvable_id, _)| *solvable_id), + ))?; - for &clause_id in &conflicting_causes { - // Backtrack in the case of conflicts - tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", - clause=self.clauses[clause_id].debug(self.pool())); - } + // Serially process the outputs, to reduce the need for synchronization + for &clause_id in &output.conflicting_clauses { + tracing::debug!("├─ added clause {clause:?} introduces a conflict which invalidates the partial solution", + clause=self.clauses.borrow()[clause_id].debug(&self.pool)); + } - if !conflicting_causes.is_empty() { - self.decision_tracker.clear(); - level = 0; - } + if let Err(_first_conflicting_clause_id) = self.process_add_clause_output(output) { + self.decision_tracker.clear(); + level = 0; } } } + fn process_add_clause_output(&mut self, mut output: AddClauseOutput) -> Result<(), ClauseId> { + let mut clauses = self.clauses.borrow_mut(); + for clause_id in output.clauses_to_watch { + debug_assert!( + clauses[clause_id].has_watches(), + "attempting to watch a clause without watches!" + ); + self.watches + .start_watching(&mut clauses[clause_id], clause_id); + } + + self.requires_clauses + .append(&mut output.new_requires_clauses); + self.negative_assertions + .append(&mut output.negative_assertions); + + if let Some(&clause_id) = output.conflicting_clauses.first() { + return Err(clause_id); + } + + Ok(()) + } + /// Resolves all dependencies /// /// Repeatedly chooses the next variable to assign, and calls [`Solver::set_propagate_learn`] to @@ -557,7 +724,7 @@ impl> Sol /// ensures that if there are conflicts they are delt with as early as possible. fn decide(&mut self) -> Option<(SolvableId, SolvableId, ClauseId)> { let mut best_decision = None; - for &(solvable_id, deps, clause_id) in self.requires_clauses.iter() { + for &(solvable_id, deps, clause_id) in &self.requires_clauses { // Consider only clauses in which we have decided to install the solvable if self.decision_tracker.assigned_value(solvable_id) != Some(true) { continue; @@ -605,8 +772,8 @@ impl> Sol if let Some((count, (candidate, _solvable_id, clause_id))) = best_decision { tracing::info!( "deciding to assign {}, ({:?}, {} possible candidates)", - candidate.display(self.pool()), - self.clauses[clause_id].debug(self.pool()), + candidate.display(&self.pool), + self.clauses.borrow()[clause_id].debug(&self.pool), count, ); } @@ -637,8 +804,8 @@ impl> Sol tracing::info!( "╤══ Install {} at level {level} (required by {})", - solvable.display(self.pool()), - required_by.display(self.pool()), + solvable.display(&self.pool), + required_by.display(&self.pool), ); // Add the decision to the tracker @@ -688,28 +855,28 @@ impl> Sol { tracing::info!( "├─ Propagation conflicted: could not set {solvable} to {attempted_value}", - solvable = conflicting_solvable.display(self.pool()) + solvable = conflicting_solvable.display(&self.pool) ); tracing::info!( "│ During unit propagation for clause: {:?}", - self.clauses[conflicting_clause].debug(self.pool()) + self.clauses.borrow()[conflicting_clause].debug(&self.pool) ); tracing::info!( "│ Previously decided value: {}. Derived from: {:?}", !attempted_value, - self.clauses[self + self.clauses.borrow()[self .decision_tracker .find_clause_for_assignment(conflicting_solvable) .unwrap()] - .debug(self.pool()), + .debug(&self.pool), ); } if level == 1 { tracing::info!("╘══ UNSOLVABLE"); for decision in self.decision_tracker.stack() { - let clause = &self.clauses[decision.derived_from]; + let clause = &self.clauses.borrow()[decision.derived_from]; let level = self.decision_tracker.level(decision.solvable_id); let action = if decision.value { "install" } else { "forbid" }; @@ -720,8 +887,8 @@ impl> Sol tracing::info!( "* ({level}) {action} {}. Reason: {:?}", - decision.solvable_id.display(self.pool()), - clause.debug(self.pool()), + decision.solvable_id.display(&self.pool), + clause.debug(&self.pool), ); } @@ -744,7 +911,7 @@ impl> Sol .expect("bug: solvable was already decided!"); tracing::debug!( "├─ Propagate after learn: {} = {decision}", - literal.solvable_id.display(self.pool()) + literal.solvable_id.display(&self.pool) ); Ok(level) @@ -774,7 +941,7 @@ impl> Sol if decided { tracing::trace!( "├─ Propagate assertion {} = {}", - solvable_id.display(self.pool()), + solvable_id.display(&self.pool), value ); } @@ -783,7 +950,7 @@ impl> Sol // Assertions derived from learnt rules for learn_clause_idx in 0..self.learnt_clause_ids.len() { let clause_id = self.learnt_clause_ids[learn_clause_idx]; - let clause = &self.clauses[clause_id]; + let clause = &self.clauses.borrow()[clause_id]; let Clause::Learnt(learnt_index) = clause.kind else { unreachable!(); }; @@ -811,7 +978,7 @@ impl> Sol if decided { tracing::trace!( "├─ Propagate assertion {} = {}", - literal.solvable_id.display(self.pool()), + literal.solvable_id.display(&self.pool), decision ); } @@ -831,13 +998,14 @@ impl> Sol } // Get mutable access to both clauses. + let mut clauses = self.clauses.borrow_mut(); let (predecessor_clause, clause) = if let Some(prev_clause_id) = predecessor_clause_id { let (predecessor_clause, clause) = - self.clauses.get_two_mut(prev_clause_id, clause_id); + clauses.get_two_mut(prev_clause_id, clause_id); (Some(predecessor_clause), clause) } else { - (None, &mut self.clauses[clause_id]) + (None, &mut clauses[clause_id]) }; // Update the prev_clause_id for the next run @@ -909,9 +1077,9 @@ impl> Sol _ => { tracing::debug!( "├─ Propagate {} = {}. {:?}", - remaining_watch.solvable_id.display(self.cache.pool()), + remaining_watch.solvable_id.display(&self.cache.pool()), remaining_watch.satisfying_value(), - clause.debug(self.cache.pool()), + clause.debug(&self.cache.pool()), ); } } @@ -964,7 +1132,7 @@ impl> Sol tracing::info!("=== ANALYZE UNSOLVABLE"); let mut involved = HashSet::new(); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -974,7 +1142,7 @@ impl> Sol let mut seen = HashSet::new(); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, clause_id, &mut problem, @@ -995,14 +1163,14 @@ impl> Sol assert_ne!(why, ClauseId::install_root()); Self::analyze_unsolvable_clause( - &self.clauses, + &self.clauses.borrow(), &self.learnt_why, why, &mut problem, &mut seen, ); - self.clauses[why].kind.visit_literals( + self.clauses.borrow()[why].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1046,7 +1214,7 @@ impl> Sol loop { learnt_why.push(clause_id); - self.clauses[clause_id].kind.visit_literals( + self.clauses.borrow()[clause_id].kind.visit_literals( &self.learnt_clauses, &self.cache.version_set_to_sorted_candidates, |literal| { @@ -1115,10 +1283,13 @@ impl> Sol let learnt_id = self.learnt_clauses.alloc(learnt.clone()); self.learnt_why.insert(learnt_id, learnt_why); - let clause_id = self.clauses.alloc(ClauseState::learnt(learnt_id, &learnt)); + let clause_id = self + .clauses + .borrow_mut() + .alloc(ClauseState::learnt(learnt_id, &learnt)); self.learnt_clause_ids.push(clause_id); - let clause = &mut self.clauses[clause_id]; + let clause = &mut self.clauses.borrow_mut()[clause_id]; if clause.has_watches() { self.watches.start_watching(clause, clause_id); } @@ -1128,7 +1299,7 @@ impl> Sol tracing::debug!( "│ - {}{}", if lit.negate { "NOT " } else { "" }, - lit.solvable_id.display(self.pool()) + lit.solvable_id.display(&self.pool) ); } diff --git a/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap new file mode 100644 index 0000000..5365a68 --- /dev/null +++ b/tests/snapshots/solver__resolve_with_concurrent_metadata_fetching.snap @@ -0,0 +1,8 @@ +--- +source: tests/solver.rs +expression: result +--- +child1=3 +child2=2 +parent=4 + diff --git a/tests/solver.rs b/tests/solver.rs index 42b9ebc..a88b874 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -5,6 +5,12 @@ use resolvo::{ KnownDependencies, NameId, Pool, SolvableId, Solver, SolverCache, UnsolvableOrCancelled, VersionSet, VersionSetId, }; +use std::cell::RefCell; +use std::collections::HashSet; +use std::rc::Rc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; use std::{ any::Any, cell::Cell, @@ -58,7 +64,7 @@ impl Pack { } fn offset(&self, version_offset: i32) -> Pack { - let mut pack = self.clone(); + let mut pack = *self; pack.version = pack.version.wrapping_add_signed(version_offset); pack } @@ -139,12 +145,20 @@ impl FromStr for Spec { /// This provides sorting functionality for our `BundleBox` packaging system #[derive(Default)] struct BundleBoxProvider { - pool: Pool>, + pool: Rc>>, packages: IndexMap>, favored: HashMap, locked: HashMap, excluded: HashMap>, cancel_solving: Cell, + // TODO: simplify? + concurrent_requests: Arc, + concurrent_requests_max: Rc>, + sleep_before_return: bool, + + // A mapping of packages that we have requested candidates for. This way we can keep track of duplicate requests. + requested_candidates: RefCell>, + requested_dependencies: RefCell>, } struct BundleBoxPackageDependencies { @@ -159,7 +173,7 @@ impl BundleBoxProvider { pub fn requirements(&self, requirements: &[&str]) -> Vec { requirements - .into_iter() + .iter() .map(|dep| Spec::from_str(dep).unwrap()) .map(|spec| { let dep_name = self.pool.intern_package_name(&spec.name); @@ -202,13 +216,13 @@ impl BundleBoxProvider { constrains: &[&str], ) { let dependencies = dependencies - .into_iter() + .iter() .map(|dep| Spec::from_str(dep)) .collect::, _>>() .unwrap(); let constrains = constrains - .into_iter() + .iter() .map(|dep| Spec::from_str(dep)) .collect::, _>>() .unwrap(); @@ -224,14 +238,26 @@ impl BundleBoxProvider { }, ); } + + // Sends a value from the dependency provider to the solver, introducing a minimal delay to force + // concurrency to be used (unless there is no async runtime available) + async fn maybe_delay(&self, value: T) -> T { + if self.sleep_before_return { + tokio::time::sleep(Duration::from_millis(10)).await; + self.concurrent_requests.fetch_sub(1, Ordering::SeqCst); + value + } else { + value + } + } } impl DependencyProvider> for BundleBoxProvider { - fn pool(&self) -> &Pool> { - &self.pool + fn pool(&self) -> Rc>> { + self.pool.clone() } - fn sort_candidates( + async fn sort_candidates( &self, _solver: &SolverCache, String, Self>, solvables: &mut [SolvableId], @@ -244,9 +270,23 @@ impl DependencyProvider> for BundleBoxProvider { }); } - fn get_candidates(&self, name: NameId) -> Option { + async fn get_candidates(&self, name: NameId) -> Option { + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + + assert!( + self.requested_candidates.borrow_mut().insert(name), + "duplicate get_candidates request" + ); + let package_name = self.pool.resolve_package_name(name); - let package = self.packages.get(package_name)?; + let Some(package) = self.packages.get(package_name) else { + return self.maybe_delay(None).await; + }; let mut candidates = Candidates { candidates: Vec::with_capacity(package.len()), @@ -271,10 +311,30 @@ impl DependencyProvider> for BundleBoxProvider { } } - Some(candidates) + self.maybe_delay(Some(candidates)).await } - fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies { + tracing::info!( + "get dependencies for {}", + self.pool + .resolve_solvable(solvable) + .name_id() + .display(&self.pool) + ); + + let concurrent_requests = self.concurrent_requests.fetch_add(1, Ordering::SeqCst); + self.concurrent_requests_max.set( + self.concurrent_requests_max + .get() + .max(concurrent_requests + 1), + ); + + assert!( + self.requested_dependencies.borrow_mut().insert(solvable), + "duplicate get_dependencies request" + ); + let candidate = self.pool.resolve_solvable(solvable); let package_name = self.pool.resolve_package_name(candidate.name_id()); let pack = candidate.inner(); @@ -282,16 +342,18 @@ impl DependencyProvider> for BundleBoxProvider { if pack.cancel_during_get_dependencies { self.cancel_solving.set(true); let reason = self.pool.intern_string("cancelled"); - return Dependencies::Unknown(reason); + return self.maybe_delay(Dependencies::Unknown(reason)).await; } if pack.unknown_deps { let reason = self.pool.intern_string("could not retrieve deps"); - return Dependencies::Unknown(reason); + return self.maybe_delay(Dependencies::Unknown(reason)).await; } let Some(deps) = self.packages.get(package_name).and_then(|v| v.get(pack)) else { - return Dependencies::Known(Default::default()); + return self + .maybe_delay(Dependencies::Known(Default::default())) + .await; }; let mut result = KnownDependencies { @@ -310,7 +372,7 @@ impl DependencyProvider> for BundleBoxProvider { result.constrains.push(dep_spec); } - Dependencies::Known(result) + self.maybe_delay(Dependencies::Known(result)).await } fn should_cancel_with_value(&self) -> Option> { @@ -341,6 +403,7 @@ fn transaction_to_string(pool: &Pool, solvables: &Vec String { let requirements = provider.requirements(specs); + let pool = provider.pool(); let mut solver = Solver::new(provider); match solver.solve(requirements) { Ok(_) => panic!("expected unsat, but a solution was found"), @@ -349,12 +412,12 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { let graph = problem.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); - graph.graphviz(&mut output, solver.pool(), true).unwrap(); + graph.graphviz(&mut output, &pool, true).unwrap(); writeln!(output, "\n").unwrap(); // Format a user friendly error message problem - .display_user_friendly(&solver, &DefaultSolvableDisplay) + .display_user_friendly(&solver, pool, &DefaultSolvableDisplay) .to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), @@ -362,22 +425,31 @@ fn solve_unsat(provider: BundleBoxProvider, specs: &[&str]) -> String { } /// Solve the problem and returns either a solution represented as a string or an error string. -fn solve_snapshot(provider: BundleBoxProvider, specs: &[&str]) -> String { +fn solve_snapshot(mut provider: BundleBoxProvider, specs: &[&str]) -> String { + // The test dependency provider requires time support for sleeping + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + .unwrap(); + + provider.sleep_before_return = true; + let requirements = provider.requirements(specs); - let mut solver = Solver::new(provider); + let pool = provider.pool(); + let mut solver = Solver::new(provider).with_runtime(runtime); match solver.solve(requirements) { - Ok(solvables) => transaction_to_string(solver.pool(), &solvables), + Ok(solvables) => transaction_to_string(&pool, &solvables), Err(UnsolvableOrCancelled::Unsolvable(problem)) => { // Write the problem graphviz to stderr let graph = problem.graph(&solver); let mut output = stderr(); writeln!(output, "UNSOLVABLE:").unwrap(); - graph.graphviz(&mut output, solver.pool(), true).unwrap(); + graph.graphviz(&mut output, &pool, true).unwrap(); writeln!(output, "\n").unwrap(); // Format a user friendly error message problem - .display_user_friendly(&solver, &DefaultSolvableDisplay) + .display_user_friendly(&solver, pool, &DefaultSolvableDisplay) .to_string() } Err(UnsolvableOrCancelled::Cancelled(reason)) => *reason.downcast().unwrap(), @@ -389,16 +461,14 @@ fn solve_snapshot(provider: BundleBoxProvider, specs: &[&str]) -> String { fn test_unit_propagation_1() { let provider = BundleBoxProvider::from_packages(&[("asdf", 1, vec![])]); let root_requirements = provider.requirements(&["asdf"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(root_requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 1); } @@ -411,25 +481,20 @@ fn test_unit_propagation_nested() { ("dummy", 6u32, vec![]), ]); let requirements = provider.requirements(&["asdf"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 1); - let solvable = solver.pool().resolve_solvable(solved[1]); + let solvable = pool.resolve_solvable(solved[1]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "efgh" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "efgh"); assert_eq!(solvable.inner().version, 4); } @@ -443,28 +508,39 @@ fn test_resolve_multiple() { ("efgh", 5, vec![]), ]); let requirements = provider.requirements(&["asdf", "efgh"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 2); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 2); - let solvable = solver.pool().resolve_solvable(solved[1]); + let solvable = pool.resolve_solvable(solved[1]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "efgh" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "efgh"); assert_eq!(solvable.inner().version, 5); } +#[test] +fn test_resolve_with_concurrent_metadata_fetching() { + let provider = BundleBoxProvider::from_packages(&[ + ("parent", 4, vec!["child1", "child2"]), + ("child1", 3, vec![]), + ("child2", 2, vec![]), + ]); + + let max_concurrent_requests = provider.concurrent_requests_max.clone(); + + let result = solve_snapshot(provider, &["parent"]); + insta::assert_snapshot!(result); + + assert_eq!(2, max_concurrent_requests.get()); +} + /// In case of a conflict the version should not be selected with the conflict #[test] fn test_resolve_with_conflict() { @@ -490,17 +566,15 @@ fn test_resolve_with_nonexisting() { ("b", 1, vec!["idontexist"]), ]); let requirements = provider.requirements(&["asdf"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 3); } @@ -526,15 +600,16 @@ fn test_resolve_with_nested_deps() { ("opentelemetry-grpc", 1, vec!["opentelemetry-api 1"]), ]); let requirements = provider.requirements(&["apache-airflow"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), + pool.resolve_package_name(solvable.name_id()), "apache-airflow" ); assert_eq!(solvable.inner().version, 1); @@ -552,15 +627,16 @@ fn test_resolve_with_unknown_deps() { ); provider.add_package("opentelemetry-api", Pack::new(2), &[], &[]); let requirements = provider.requirements(&["opentelemetry-api"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), + pool.resolve_package_name(solvable.name_id()), "opentelemetry-api" ); assert_eq!(solvable.inner().version, 2); @@ -596,15 +672,13 @@ fn test_resolve_locked_top_level() { let requirements = provider.requirements(&["asdf"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); let solvable_id = solved[0]; - assert_eq!( - solver.pool().resolve_solvable(solvable_id).inner().version, - 3 - ); + assert_eq!(pool.resolve_solvable(solvable_id).inner().version, 3); } /// Should ignore lock when it is not a top level package and a newer version exists without it @@ -619,16 +693,14 @@ fn test_resolve_ignored_locked_top_level() { provider.set_locked("fgh", 1); let requirements = provider.requirements(&["asdf"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); assert_eq!(solved.len(), 1); - let solvable = solver.pool().resolve_solvable(solved[0]); + let solvable = pool.resolve_solvable(solved[0]); - assert_eq!( - solver.pool().resolve_package_name(solvable.name_id()), - "asdf" - ); + assert_eq!(pool.resolve_package_name(solvable.name_id()), "asdf"); assert_eq!(solvable.inner().version, 4); } @@ -679,10 +751,11 @@ fn test_resolve_cyclic() { let provider = BundleBoxProvider::from_packages(&[("a", 2, vec!["b 0..10"]), ("b", 5, vec!["a 2..4"])]); let requirements = provider.requirements(&["a 0..100"]); + let pool = provider.pool(); let mut solver = Solver::new(provider); let solved = solver.solve(requirements).unwrap(); - let result = transaction_to_string(&solver.pool(), &solved); + let result = transaction_to_string(&pool, &solved); insta::assert_snapshot!(result, @r###" a=2 b=5 @@ -823,8 +896,8 @@ fn test_unsat_constrains() { ("b", 42, vec![]), ]); - provider.add_package("c", 10.into(), &vec![], &vec!["b 0..50"]); - provider.add_package("c", 8.into(), &vec![], &vec!["b 0..50"]); + provider.add_package("c", 10.into(), &[], &["b 0..50"]); + provider.add_package("c", 8.into(), &[], &["b 0..50"]); let error = solve_unsat(provider, &["a", "c"]); insta::assert_snapshot!(error); } @@ -839,8 +912,8 @@ fn test_unsat_constrains_2() { ("b", 2, vec!["c 2"]), ]); - provider.add_package("c", 1.into(), &vec![], &vec!["a 3"]); - provider.add_package("c", 2.into(), &vec![], &vec!["a 3"]); + provider.add_package("c", 1.into(), &[], &["a 3"]); + provider.add_package("c", 2.into(), &[], &["a 3"]); let error = solve_unsat(provider, &["a"]); insta::assert_snapshot!(error); }