Skip to content

Commit

Permalink
MRG: fix performance regression in manysearch by removing unnecessa…
Browse files Browse the repository at this point in the history
…ry downsampling (#464)

* add support for ignoring abundance

* cargo fmt

* avoid downsampling until we know there is overlap

* change downsample to true; add panic assertion

* move downsampling side guard

* eliminate redundant overlap check

* move calc_abund_stats

* extract abundance code into own function; avoid downsampling if poss

* cleanup

* fmt
  • Loading branch information
ctb authored Oct 8, 2024
1 parent 6bce2f4 commit 88e406f
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 72 deletions.
8 changes: 6 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/// Python interface Rust code for sourmash_plugin_branchwater.
use pyo3::prelude::*;
use singlesketch::singlesketch;

#[macro_use]
extern crate simple_error;
Expand All @@ -24,7 +23,7 @@ mod singlesketch;
use camino::Utf8PathBuf as PathBuf;

#[pyfunction]
#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None))]
#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None, ignore_abundance=false))]
fn do_manysearch(
querylist_path: String,
siglist_path: String,
Expand All @@ -33,14 +32,18 @@ fn do_manysearch(
scaled: usize,
moltype: String,
output_path: Option<String>,
ignore_abundance: Option<bool>,
) -> anyhow::Result<u8> {
let againstfile_path: PathBuf = siglist_path.clone().into();
let selection = build_selection(ksize, scaled, &moltype);
eprintln!("selection scaled: {:?}", selection.scaled());
let allow_failed_sigpaths = true;

let ignore_abundance = ignore_abundance.unwrap_or(false);

// if siglist_path is revindex, run mastiff_manysearch; otherwise run manysearch
if is_revindex_database(&againstfile_path) {
// note: mastiff_manysearch ignores abundance automatically.
match mastiff_manysearch::mastiff_manysearch(
querylist_path,
againstfile_path,
Expand All @@ -63,6 +66,7 @@ fn do_manysearch(
threshold,
output_path,
allow_failed_sigpaths,
ignore_abundance,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand Down
181 changes: 112 additions & 69 deletions src/manysearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use std::sync::atomic::AtomicUsize;

use crate::utils::{csvwriter_thread, load_collection, load_sketches, ReportType, SearchResult};
use sourmash::ani_utils::ani_from_containment;
use sourmash::errors::SourmashError;
use sourmash::selection::Selection;
use sourmash::signature::SigsTrait;
use sourmash::sketch::minhash::KmerMinHash;

pub fn manysearch(
query_filepath: String,
Expand All @@ -21,6 +23,7 @@ pub fn manysearch(
threshold: f64,
output: Option<String>,
allow_failed_sigpaths: bool,
ignore_abundance: bool,
) -> Result<()> {
// Load query collection
let query_collection = load_collection(
Expand Down Expand Up @@ -71,76 +74,71 @@ pub fn manysearch(
Ok(against_sig) => {
if let Some(against_mh) = against_sig.minhash() {
for query in query_sketchlist.iter() {
// to do - let user choose?
let calc_abund_stats = against_mh.track_abundance();

let against_mh_ds = against_mh.downsample_scaled(query.minhash.scaled()).unwrap();
let overlap =
query.minhash.count_common(&against_mh_ds, false).unwrap() as f64;

// avoid calculating details unless there is overlap
let overlap = query
.minhash
.count_common(against_mh, true)
.expect("incompatible sketches")
as f64;

let query_size = query.minhash.size() as f64;
let containment_query_in_target = overlap / query_size;
// only calculate results if we have shared hashes
if overlap > 0.0 {
let query_size = query.minhash.size() as f64;
let containment_query_in_target = overlap / query_size;
if containment_query_in_target > threshold {
let target_size = against_mh.size() as f64;
let containment_target_in_query = overlap / target_size;

let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

let (total_weighted_hashes, n_weighted_found, average_abund, median_abund, std_abund) = if calc_abund_stats {
match query.minhash.inflated_abundances(&against_mh_ds) {
Ok((abunds, sum_weighted_overlap)) => {
let sum_all_abunds = against_mh_ds.sum_abunds() as usize;
let average_abund = sum_weighted_overlap as f64 / abunds.len() as f64;
let median_abund = median(abunds.iter().cloned()).unwrap();
let std_abund = stddev(abunds.iter().cloned());
(Some(sum_all_abunds), Some(sum_weighted_overlap as usize), Some(average_abund), Some(median_abund), Some(std_abund))
}
Err(e) => {
eprintln!("Error calculating abundances for query: {}, against: {}; Error: {}", query.name, against_sig.name(), e);
continue;
}
}
} else {
(None, None, None, None, None)
};

results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
average_abund,
median_abund,
std_abund,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
n_weighted_found,
total_weighted_hashes,
});
}
if containment_query_in_target > threshold {
let target_size = against_mh.size() as f64;
let containment_target_in_query = overlap / target_size;

let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

let calc_abund_stats =
against_mh.track_abundance() && !ignore_abundance;
let (
total_weighted_hashes,
n_weighted_found,
average_abund,
median_abund,
std_abund,
) = if calc_abund_stats {
downsample_and_inflate_abundances(&query.minhash, against_mh)
.ok()?
} else {
(None, None, None, None, None)
};

results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
average_abund,
median_abund,
std_abund,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
n_weighted_found,
total_weighted_hashes,
});
}
}
} else {
Expand Down Expand Up @@ -197,3 +195,48 @@ pub fn manysearch(

Ok(())
}

fn downsample_and_inflate_abundances(
query: &KmerMinHash,
against: &KmerMinHash,
) -> Result<
(
Option<usize>,
Option<usize>,
Option<f64>,
Option<f64>,
Option<f64>,
),
SourmashError,
> {
let query_scaled = query.scaled();
let against_scaled = against.scaled();

let abunds: Vec<u64>;
let sum_weighted: u64;
let sum_all_abunds: usize;

// avoid downsampling if we can
if against_scaled != query_scaled {
let against_ds = against
.downsample_scaled(query.scaled())
.expect("cannot downsample sketch");
(abunds, sum_weighted) = query.inflated_abundances(&against_ds)?;
sum_all_abunds = against_ds.sum_abunds() as usize;
} else {
(abunds, sum_weighted) = query.inflated_abundances(against)?;
sum_all_abunds = against.sum_abunds() as usize;
}

let average_abund = sum_weighted as f64 / abunds.len() as f64;
let median_abund = median(abunds.iter().cloned()).expect("error");
let std_abund = stddev(abunds.iter().cloned());

Ok((
Some(sum_all_abunds),
Some(sum_weighted as usize),
Some(average_abund),
Some(median_abund),
Some(std_abund),
))
}
5 changes: 4 additions & 1 deletion src/python/sourmash_plugin_branchwater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def __init__(self, p):
p.add_argument('-N', '--no-pretty-print', action='store_false',
dest='pretty_print',
help="do not display results (e.g. for large output)")
p.add_argument('--ignore-abundance', action='store_true',
help="do not do expensive abundance calculations")

def main(self, args):
print_version()
Expand All @@ -80,7 +82,8 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
args.output)
args.output,
args.ignore_abundance)
if status == 0:
notify(f"...manysearch is done! results in '{args.output}'")

Expand Down

0 comments on commit 88e406f

Please sign in to comment.