Skip to content

Commit

Permalink
feat: concurrent metadata fetching
Browse files Browse the repository at this point in the history
  • Loading branch information
aochagavia committed Feb 2, 2024
1 parent b05a137 commit fe6a7fa
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 236 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ tracing = "0.1.37"
elsa = "1.9.0"
bitvec = "1.0.1"
serde = { version = "1.0", features = ["derive"], optional = true }
futures = { version = "0.3.30", default-features = false, features = ["alloc"] }
tokio = { version = "1.35.1", features = ["rt", "sync"] }

[dev-dependencies]
insta = "1.31.0"
indexmap = "2.0.0"
proptest = "1.2.0"
tracing-test = { version = "0.2.4", features = ["no-env-filter"] }
tokio = { version = "1.35.1", features = ["time"] }
27 changes: 22 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::{
any::Any,
fmt::{Debug, Display},
hash::Hash,
rc::Rc,
};

/// The solver is based around the fact that for for every package name we are trying to find a
Expand Down Expand Up @@ -61,21 +62,37 @@ pub trait VersionSet: Debug + Display + Clone + Eq + Hash {
/// packages that are available in the system.
pub trait DependencyProvider<VS: VersionSet, N: PackageName = String>: Sized {
/// Returns the [`Pool`] that is used to allocate the Ids returned from this instance
fn pool(&self) -> &Pool<VS, N>;
fn pool(&self) -> Rc<Pool<VS, N>>;

/// Sort the specified solvables based on which solvable to try first. The solver will
/// iteratively try to select the highest version. If a conflict is found with the highest
/// version the next version is tried. This continues until a solution is found.
fn sort_candidates(&self, solver: &SolverCache<VS, N, Self>, solvables: &mut [SolvableId]);

/// Returns a list of solvables that should be considered when a package with the given name is
/// Obtains a list of solvables that should be considered when a package with the given name is
/// requested.
///
/// Returns `None` if no such package exist.
fn get_candidates(&self, name: NameId) -> Option<Candidates>;
/// # Async
///
/// The returned future will be awaited by a tokio runtime blocking the main thread. You are
/// free to use other runtimes in your implementation, as long as the runtime-specific code runs
/// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can
/// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to
/// retrieve necessary information from the network, and `await` the returned task handle.
#[allow(async_fn_in_trait)]
async fn get_candidates(&self, name: NameId) -> Option<Candidates>;

/// Returns the dependencies for the specified solvable.
fn get_dependencies(&self, solvable: SolvableId) -> Dependencies;
///
/// # Async
///
/// The returned future will be awaited by a tokio runtime blocking the main thread. You are
/// free to use other runtimes in your implementation, as long as the runtime-specific code runs
/// in threads controlled by that runtime (and _not_ in the main thread). For instance, you can
/// use `async_std::task::spawn` to spawn a new task, use `async_std::io` inside the task to
/// retrieve necessary information from the network, and `await` the returned task handle.
#[allow(async_fn_in_trait)]
async fn get_dependencies(&self, solvable: SolvableId) -> Dependencies;

/// Whether the solver should stop the dependency resolution algorithm.
///
Expand Down
22 changes: 11 additions & 11 deletions src/problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::collections::{HashMap, HashSet};
use std::fmt;
use std::fmt::{Display, Formatter};
use std::hash::Hash;

use std::rc::Rc;

use itertools::Itertools;
Expand Down Expand Up @@ -52,7 +51,7 @@ impl Problem {
let unresolved_node = graph.add_node(ProblemNode::UnresolvedDependency);

for clause_id in &self.clauses {
let clause = &solver.clauses[*clause_id].kind;
let clause = &solver.clauses.borrow()[*clause_id].kind;
match clause {
Clause::InstallRoot => (),
Clause::Excluded(solvable, reason) => {
Expand All @@ -65,7 +64,7 @@ impl Problem {
&Clause::Requires(package_id, version_set_id) => {
let package_node = Self::add_node(&mut graph, &mut nodes, package_id);

let candidates = solver.cache.get_or_cache_sorted_candidates(version_set_id).unwrap_or_else(|_| {
let candidates = solver.async_runtime.block_on(solver.cache.get_or_cache_sorted_candidates(version_set_id)).unwrap_or_else(|_| {
unreachable!("The version set was used in the solver, so it must have been cached. Therefore cancellation is impossible here and we cannot get an `Err(...)`")
});
if candidates.is_empty() {
Expand Down Expand Up @@ -162,10 +161,11 @@ impl Problem {
>(
&self,
solver: &'a Solver<VS, N, D>,
pool: Rc<Pool<VS, N>>,
merged_solvable_display: &'a M,
) -> DisplayUnsat<'a, VS, N, M> {
let graph = self.graph(solver);
DisplayUnsat::new(graph, solver.pool(), merged_solvable_display)
DisplayUnsat::new(graph, pool, merged_solvable_display)
}
}

Expand Down Expand Up @@ -512,7 +512,7 @@ pub struct DisplayUnsat<'pool, VS: VersionSet, N: PackageName + Display, M: Solv
merged_candidates: HashMap<SolvableId, Rc<MergedProblemNode>>,
installable_set: HashSet<NodeIndex>,
missing_set: HashSet<NodeIndex>,
pool: &'pool Pool<VS, N>,
pool: Rc<Pool<VS, N>>,
merged_solvable_display: &'pool M,
}

Expand All @@ -521,10 +521,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay<VS, N>>
{
pub(crate) fn new(
graph: ProblemGraph,
pool: &'pool Pool<VS, N>,
pool: Rc<Pool<VS, N>>,
merged_solvable_display: &'pool M,
) -> Self {
let merged_candidates = graph.simplify(pool);
let merged_candidates = graph.simplify(&pool);
let installable_set = graph.get_installable_set();
let missing_set = graph.get_missing_set();

Expand Down Expand Up @@ -666,10 +666,10 @@ impl<'pool, VS: VersionSet, N: PackageName + Display, M: SolvableDisplay<VS, N>>
let version = if let Some(merged) = self.merged_candidates.get(&solvable_id) {
reported.extend(merged.ids.iter().cloned());
self.merged_solvable_display
.display_candidates(self.pool, &merged.ids)
.display_candidates(&self.pool, &merged.ids)
} else {
self.merged_solvable_display
.display_candidates(self.pool, &[solvable_id])
.display_candidates(&self.pool, &[solvable_id])
};

let excluded = graph
Expand Down Expand Up @@ -790,9 +790,9 @@ impl<VS: VersionSet, N: PackageName + Display, M: SolvableDisplay<VS, N>> fmt::D
writeln!(
f,
"{indent}{} {} is locked, but another version is required as reported above",
locked.name.display(self.pool),
locked.name.display(&self.pool),
self.merged_solvable_display
.display_candidates(self.pool, &[solvable_id])
.display_candidates(&self.pool, &[solvable_id])
)?;
}
ConflictCause::Excluded(_, _) => continue,
Expand Down
40 changes: 23 additions & 17 deletions src/solver/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use elsa::FrozenMap;
use std::any::Any;
use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;

/// Keeps a cache of previously computed and/or requested information about solvables and version
/// sets.
Expand Down Expand Up @@ -65,7 +66,7 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
}

/// Returns a reference to the pool used by the solver
pub fn pool(&self) -> &Pool<VS, N> {
pub fn pool(&self) -> Rc<Pool<VS, N>> {
self.provider.pool()
}

Expand All @@ -74,7 +75,7 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
///
/// If the provider has requested the solving process to be cancelled, the cancellation value
/// will be returned as an `Err(...)`.
pub fn get_or_cache_candidates(
pub async fn get_or_cache_candidates(
&self,
package_name: NameId,
) -> Result<&Candidates, Box<dyn Any>> {
Expand All @@ -93,6 +94,7 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
let candidates = self
.provider
.get_candidates(package_name)
.await
.unwrap_or_default();

// Store information about which solvables dependency information is easy to
Expand Down Expand Up @@ -126,23 +128,24 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
///
/// If the provider has requested the solving process to be cancelled, the cancellation value
/// will be returned as an `Err(...)`.
pub fn get_or_cache_matching_candidates(
pub async fn get_or_cache_matching_candidates(
&self,
version_set_id: VersionSetId,
) -> Result<&[SolvableId], Box<dyn Any>> {
match self.version_set_candidates.get(&version_set_id) {
Some(candidates) => Ok(candidates),
None => {
let package_name = self.pool().resolve_version_set_package_name(version_set_id);
let version_set = self.pool().resolve_version_set(version_set_id);
let candidates = self.get_or_cache_candidates(package_name)?;
let pool = self.pool();
let package_name = pool.resolve_version_set_package_name(version_set_id);
let version_set = pool.resolve_version_set(version_set_id);
let candidates = self.get_or_cache_candidates(package_name).await?;

let matching_candidates = candidates
.candidates
.iter()
.copied()
.filter(|&p| {
let version = self.pool().resolve_internal_solvable(p).solvable().inner();
let version = pool.resolve_internal_solvable(p).solvable().inner();
version_set.contains(version)
})
.collect();
Expand All @@ -158,23 +161,24 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
///
/// If the provider has requested the solving process to be cancelled, the cancellation value
/// will be returned as an `Err(...)`.
pub fn get_or_cache_non_matching_candidates(
pub async fn get_or_cache_non_matching_candidates(
&self,
version_set_id: VersionSetId,
) -> Result<&[SolvableId], Box<dyn Any>> {
match self.version_set_inverse_candidates.get(&version_set_id) {
Some(candidates) => Ok(candidates),
None => {
let package_name = self.pool().resolve_version_set_package_name(version_set_id);
let version_set = self.pool().resolve_version_set(version_set_id);
let candidates = self.get_or_cache_candidates(package_name)?;
let pool = self.pool();
let package_name = pool.resolve_version_set_package_name(version_set_id);
let version_set = pool.resolve_version_set(version_set_id);
let candidates = self.get_or_cache_candidates(package_name).await?;

let matching_candidates = candidates
.candidates
.iter()
.copied()
.filter(|&p| {
let version = self.pool().resolve_internal_solvable(p).solvable().inner();
let version = pool.resolve_internal_solvable(p).solvable().inner();
!version_set.contains(version)
})
.collect();
Expand All @@ -191,16 +195,18 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
///
/// If the provider has requested the solving process to be cancelled, the cancellation value
/// will be returned as an `Err(...)`.
pub fn get_or_cache_sorted_candidates(
pub async fn get_or_cache_sorted_candidates(
&self,
version_set_id: VersionSetId,
) -> Result<&[SolvableId], Box<dyn Any>> {
match self.version_set_to_sorted_candidates.get(&version_set_id) {
Some(candidates) => Ok(candidates),
None => {
let package_name = self.pool().resolve_version_set_package_name(version_set_id);
let matching_candidates = self.get_or_cache_matching_candidates(version_set_id)?;
let candidates = self.get_or_cache_candidates(package_name)?;
let matching_candidates = self
.get_or_cache_matching_candidates(version_set_id)
.await?;
let candidates = self.get_or_cache_candidates(package_name).await?;

// Sort all the candidates in order in which they should be tried by the solver.
let mut sorted_candidates = Vec::new();
Expand Down Expand Up @@ -228,7 +234,7 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
///
/// If the provider has requested the solving process to be cancelled, the cancellation value
/// will be returned as an `Err(...)`.
pub fn get_or_cache_dependencies(
pub async fn get_or_cache_dependencies(
&self,
solvable_id: SolvableId,
) -> Result<&Dependencies, Box<dyn Any>> {
Expand All @@ -242,7 +248,7 @@ impl<VS: VersionSet, N: PackageName, D: DependencyProvider<VS, N>> SolverCache<V
return Err(value);
}

let dependencies = self.provider.get_dependencies(solvable_id);
let dependencies = self.provider.get_dependencies(solvable_id).await;
let dependencies_id = self.solvable_dependencies.alloc(dependencies);
self.solvable_to_dependencies
.insert_copy(solvable_id, dependencies_id);
Expand Down
Loading

0 comments on commit fe6a7fa

Please sign in to comment.