From 76028d7f6004296005cadbd85033ee94a77842f2 Mon Sep 17 00:00:00 2001 From: Tim de Jager Date: Fri, 4 Oct 2024 13:32:07 +0200 Subject: [PATCH] nwip: total ordering for the dependency provider --- crates/rattler_conda_types/src/channel/mod.rs | 12 +- crates/rattler_conda_types/src/version/mod.rs | 100 +++++++- .../rattler_solve/src/resolvo/conda_util.rs | 232 +++++++++++++++++- crates/rattler_solve/src/resolvo/mod.rs | 9 +- rust-toolchain | 2 +- 5 files changed, 341 insertions(+), 14 deletions(-) diff --git a/crates/rattler_conda_types/src/channel/mod.rs b/crates/rattler_conda_types/src/channel/mod.rs index b9050ec04..802a455c3 100644 --- a/crates/rattler_conda_types/src/channel/mod.rs +++ b/crates/rattler_conda_types/src/channel/mod.rs @@ -78,7 +78,7 @@ impl ChannelConfig { /// Represents a channel description as either a name (e.g. `conda-forge`) or a /// base url. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq)] pub enum NamedChannelOrUrl { /// A named channel Name(String), @@ -90,6 +90,16 @@ pub enum NamedChannelOrUrl { Path(Utf8TypedPathBuf), } +impl std::hash::Hash for NamedChannelOrUrl { + fn hash(&self, state: &mut H) { + match self { + NamedChannelOrUrl::Name(name) => name.hash(state), + NamedChannelOrUrl::Url(url) => url.as_str().hash(state), + NamedChannelOrUrl::Path(path) => path.as_str().hash(state), + } + } +} + impl NamedChannelOrUrl { /// Returns the string representation of the channel. /// diff --git a/crates/rattler_conda_types/src/version/mod.rs b/crates/rattler_conda_types/src/version/mod.rs index 52ab8daa8..0651ab399 100644 --- a/crates/rattler_conda_types/src/version/mod.rs +++ b/crates/rattler_conda_types/src/version/mod.rs @@ -1037,7 +1037,7 @@ mod test { use crate::version::StrictVersion; - use super::Version; + use super::{Component, Version}; // Tests are inspired by: https://github.com/conda/conda/blob/33a142c16530fcdada6c377486f1c1a385738a96/tests/models/test_version.py @@ -1049,7 +1049,7 @@ mod test { Restart, } - let versions = [ + let versions_str = [ " 0.4", "== 0.4.0", " < 0.4.1.rc", @@ -1079,18 +1079,21 @@ mod test { " < 2!0.4.1", // epoch increased again ]; - let ops = versions.iter().map(|&v| { - let (op, version) = if let Some((op, version)) = v.trim().split_once(' ') { + let mut versions = Vec::new(); + + let ops = versions_str.iter().map(|&v| { + let (op, version_str) = if let Some((op, version)) = v.trim().split_once(' ') { (op, version.trim()) } else { ("", v.trim()) }; - let version: Version = version.parse().unwrap(); + let version: Version = version_str.parse().unwrap(); let op = match op { "<" => CmpOp::Less, "==" => CmpOp::Equal, _ => CmpOp::Restart, }; + versions.push(version.clone()); (op, version) }); @@ -1127,6 +1130,10 @@ mod test { } previous = Some(version); } + + // Try to see if the sort works + let mut cloned_versions = versions.clone(); + cloned_versions.sort(); } #[test] @@ -1397,4 +1404,87 @@ mod test { expected ); } + + #[test] + fn test_component_total_order() { + // Create instances of each variant + let components = vec![ + Component::Dev, + Component::UnderscoreOrDash { is_dash: false }, + Component::Iden(Box::from("alpha")), + Component::Iden(Box::from("beta")), + Component::Numeral(1), + Component::Numeral(2), + Component::Post, + ]; + + // Check that each component equals itself + for a in &components { + assert_eq!(a.cmp(a), Ordering::Equal); + } + + for (i, a) in components.iter().enumerate() { + for b in components[i + 1..].iter() { + let ord = a.cmp(b); + assert_eq!( + ord, + Ordering::Less, + "Expected {:?} < {:?}, but found {:?}", + a, + b, + ord + ); + } + // Check the reverse ordering as well + // I think this should automatically check transitivity + // If a <= b and b <= c, then a <= c + for b in components[..i].iter() { + let ord = a.cmp(b); + assert_eq!( + ord, + Ordering::Greater, + "Expected {:?} > {:?}, but found {:?}", + a, + b, + ord + ); + } + } + + // Check antisymmetry: If a <= b and b <= a, then a == b + // for a in &components { + // for b in &components { + // let ord_ab = a.cmp(b); + // let ord_ba = b.cmp(a); + // if ord_ab != Ordering::Greater && ord_ba != Ordering::Greater { + // assert_eq!( + // ord_ab, + // ord_ba.reverse(), + // "Antisymmetry violated between {:?} and {:?}", + // a, + // b + // ); + // } + // } + // } + + // for a in &components { + // for b in &components { + // for c in &components { + // let ord_ab = a.cmp(b); + // let ord_bc = b.cmp(c); + // let ord_ac = a.cmp(c); + // if ord_ab != Ordering::Greater && ord_bc != Ordering::Greater { + // assert!( + // ord_ac != Ordering::Greater, + // "Transitivity violated between {:?}, {:?}, and {:?}", + // a, + // b, + // c + // ); + // } + // } + // } + // } + } } diff --git a/crates/rattler_solve/src/resolvo/conda_util.rs b/crates/rattler_solve/src/resolvo/conda_util.rs index d10a1f529..6d87d0da5 100644 --- a/crates/rattler_solve/src/resolvo/conda_util.rs +++ b/crates/rattler_solve/src/resolvo/conda_util.rs @@ -1,17 +1,243 @@ -use std::{cmp::Ordering, collections::HashMap}; +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + ops::Deref, +}; use futures::future::FutureExt; +use itertools::Itertools; use rattler_conda_types::Version; -use resolvo::{Dependencies, Requirement, SolvableId, SolverCache, VersionSetId}; +use resolvo::{Dependencies, NameId, Requirement, SolvableId, SolverCache, VersionSetId}; use crate::resolvo::CondaDependencyProvider; +use super::SolverPackageRecord; + #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub(super) enum CompareStrategy { Default, LowestVersion, } +/// Sorts the candidates based on the strategy. +/// and some different rules +pub struct SolvableSorter<'a, 'repo> { + solver: &'a SolverCache>, + strategy: CompareStrategy, +} + +impl<'a, 'repo> SolvableSorter<'a, 'repo> { + pub fn new( + solver: &'a SolverCache>, + strategy: CompareStrategy, + ) -> Self { + Self { solver, strategy } + } + + fn solvable_record(&self, id: SolvableId) -> SolverPackageRecord<'repo> { + let pool = &self.solver.provider().pool; + let solvable = pool.resolve_solvable(a); + solvable.record + } + + /// This function can be used for the initial sorting of the candidates. + pub fn sort_by_name_version_build(&self, solvables: &mut [SolvableId]) { + solvables.sort_by(|a, b| self.initial_sort(*a, *b)); + } + + /// Sort the candidates based on: + /// 1. Whether the package has tracked features + /// 2. The version of the package + /// 3. The build number of the package + fn initial_sort(&self, a: SolvableId, b: SolvableId) -> Ordering { + let a_record = &self.solvable_record(a); + let b_record = &self.solvable_record(b); + + // First compare by "tracked_features". If one of the packages has a tracked + // feature it is sorted below the one that doesn't have the tracked feature. + let a_has_tracked_features = !a_record.track_features().is_empty(); + let b_has_tracked_features = !b_record.track_features().is_empty(); + match a_has_tracked_features.cmp(&b_has_tracked_features) { + Ordering::Less => return Ordering::Less, + Ordering::Greater => return Ordering::Greater, + Ordering::Equal => {} + }; + + // Otherwise, select the variant with the highest version + match (self.strategy, a_record.version().cmp(b_record.version())) { + (CompareStrategy::Default, Ordering::Greater) + | (CompareStrategy::LowestVersion, Ordering::Less) => return Ordering::Less, + (CompareStrategy::Default, Ordering::Less) + | (CompareStrategy::LowestVersion, Ordering::Greater) => return Ordering::Greater, + (_, Ordering::Equal) => {} + }; + + // Otherwise, select the variant with the highest build number + match a_record.build_number().cmp(&b_record.build_number()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => return Ordering::Less, + Ordering::Equal => return Ordering::Equal, + }; + } + + fn find_first_unsorted(&self, solvables: &[SolvableId]) -> Option { + // Find the first solvable record pair that have the same, name, version and build number + // and return its index, this assumes that solvables have been sorted by name, version and build number + for (i, solvable) in solvables.iter().enumerate() { + if i + 1 < solvables.len() { + let next_solvable = solvables[i + 1]; + let solvable_record = self.solvable_record(*solvable); + let next_solvable_record = self.solvable_record(next_solvable); + + if solvable_record.name() == next_solvable_record.name() + && solvable_record.version() == next_solvable_record.version() + && solvable_record.build_number() == next_solvable_record.build_number() + { + return Some(i); + } + } + } + None + } + + /// Sorts the solvables by the highest version of the dependencies shared by the solvables. + /// what this function does is: + /// 1. Find the first unsorted solvable in the list + /// 2. Get the dependencies for each solvable + /// 3. Get the known dependencies for each solvable, filter out the unknown dependencies + /// 4. Retain the dependencies that are shared by all the solvables + /// 5. Create a max vector which is the maximum version of each of the shared dependencies + /// 6. Calculate a total score by counting how often the solvable has a dependency that is in the max vector + /// 7. Sort by the total score and use timestamp of the record as a tie breaker + pub(crate) fn sort_by_highest_version( + &self, + solvables: &mut [SolvableId], + highest_version_spec: &HashMap>, + ) { + let first_unsorted = self.find_first_unsorted(solvables); + let first_unsorted = match first_unsorted { + Some(i) => i, + None => return, + }; + + // Split the solvables into two parts, the ordered and the ones that need ordering + let (_, needs_ordering) = solvables.split_at_mut(first_unsorted); + + // Get the dependencies for each solvable + let dependencies = needs_ordering + .iter() + .map(|id| { + self.solver + .get_or_cache_dependencies(*id) + .now_or_never() + .expect("get_or_cache_dependencies failed") + .map(|deps| (id, deps)) + }) + .collect::, _>>(); + + let dependencies = match dependencies { + Ok(dependencies) => dependencies, + // Solver cancelation, lets just return + Err(_) => return, + }; + + // Get the known dependencies for each solvable, filter out the unknown dependencies + let id_and_deps = dependencies + .into_iter() + // Only consider known dependencies + .filter_map(|(i, deps)| match deps { + Dependencies::Known(known_dependencies) => Some((i, known_dependencies)), + Dependencies::Unknown(_) => None, + }) + .map(|(i, known)| { + // Map all known dependencies to the package names + let dep_ids = known.requirements.iter().filter_map(|req| match req { + Requirement::Single(version_set_id) => Some(( + self.solver + .provider() + .pool + .resolve_version_set_package_name(*version_set_id), + *version_set_id, + )), + // Ignore union requirements + Requirement::Union(_) => None, + }); + (i, dep_ids.collect::>()) + }) + .collect_vec(); + + let unique_names: HashSet<_> = unique_name_ids( + id_and_deps + .iter() + .map(|(_, names)| names.iter().map(|(name, _)| *name).collect()), + ); + + // Only retain the dependencies that are shared by all solvables + let shared_dependencies = id_and_deps + .into_iter() + .map(|(i, names)| { + ( + i, + names + .into_iter() + .filter(|(name, _)| unique_names.contains(name)) + .collect::>(), + ) + }) + .collect_vec(); + + // Map the shared dependencies to the highest version of each dependency + + // Get the set of dependencies that each solvable has + } +} + +struct Sorter { + max_map: HashMap, + +} + +fn max_transforms() -> + + +// TODO: remove once we have make NameId Ord +// +#[repr(transparent)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +struct NameIdWrapper(pub NameId); + +impl Deref for NameIdWrapper { + type Target = NameId; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Ord for NameIdWrapper { + fn cmp(&self, other: &Self) -> Ordering { + self.0 .0.cmp(&other.0 .0) + } +} + +impl PartialOrd for NameIdWrapper { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// Get the unique package names from a list of vectors of package names. +fn unique_name_ids<'a>(vectors: impl IntoIterator>) -> HashSet { + let mut iter = vectors.into_iter(); + if let Some(first_set) = iter.next() { + iter.fold(first_set.clone(), |mut acc: HashSet, set| { + acc.retain(|item| set.contains(item)); + acc + }) + } else { + HashSet::new() // Return empty set if input is empty + } +} /// Returns the order of two candidates based on the order used by conda. #[allow(clippy::too_many_arguments)] pub(super) fn compare_candidates( @@ -58,8 +284,6 @@ pub(super) fn compare_candidates( Ordering::Equal => {} }; - // return Ordering::Equal; - // Otherwise, compare the dependencies of the variants. If there are similar // dependencies select the variant that selects the highest version of the // dependency. diff --git a/crates/rattler_solve/src/resolvo/mod.rs b/crates/rattler_solve/src/resolvo/mod.rs index 0159aca24..ba4335a7e 100644 --- a/crates/rattler_solve/src/resolvo/mod.rs +++ b/crates/rattler_solve/src/resolvo/mod.rs @@ -484,9 +484,12 @@ impl<'a> DependencyProvider for CondaDependencyProvider<'a> { } } }; - solvables.sort_by(|&p1, &p2| { - conda_util::compare_candidates(p1, p2, solver, &mut highest_version_spec, strategy) - }); + + let sorter = conda_util::SolvableSorter::new(solver, strategy); + // First initial sort + sorter.sort_by_name_version_build(solvables); + // Sort by highest version + sorter.sort_by_highest_version(solvables, &mut highest_version_spec); } async fn get_candidates(&self, name: NameId) -> Option { diff --git a/rust-toolchain b/rust-toolchain index aaceec04e..dbd41264a 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.80.0 +1.81.0