diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 3183e0c9ea..cb2ae91ddf 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -133,6 +133,12 @@ def time_set_abundances(self): for i in range(500): mh.set_abundances(mins) + def time_set_abundances_noclear(self): + mh = self.mh + mins = self.populated_mh.get_mins(with_abundance=True) + for i in range(500): + mh.set_abundances(mins, clear=False) + class PeakmemMinAbundanceSuite(PeakmemMinHashSuite): def setup(self): PeakmemMinHashSuite.setup(self) diff --git a/include/sourmash.h b/include/sourmash.h index bc6c597cc0..e187005126 100644 --- a/include/sourmash.h +++ b/include/sourmash.h @@ -119,6 +119,8 @@ void kmerminhash_add_from(SourmashKmerMinHash *ptr, const SourmashKmerMinHash *o void kmerminhash_add_hash(SourmashKmerMinHash *ptr, uint64_t h); +void kmerminhash_add_hash_with_abundance(SourmashKmerMinHash *ptr, uint64_t h, uint64_t abundance); + void kmerminhash_add_many(SourmashKmerMinHash *ptr, const uint64_t *hashes_ptr, uintptr_t insize); void kmerminhash_add_protein(SourmashKmerMinHash *ptr, const char *sequence); @@ -130,6 +132,8 @@ void kmerminhash_add_word(SourmashKmerMinHash *ptr, const char *word); double kmerminhash_angular_similarity(const SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other); +void kmerminhash_clear(SourmashKmerMinHash *ptr); + uint64_t kmerminhash_count_common(const SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other, bool downsample); @@ -192,7 +196,8 @@ uint64_t kmerminhash_seed(const SourmashKmerMinHash *ptr); void kmerminhash_set_abundances(SourmashKmerMinHash *ptr, const uint64_t *hashes_ptr, const uint64_t *abunds_ptr, - uintptr_t insize); + uintptr_t insize, + bool clear); double kmerminhash_similarity(const SourmashKmerMinHash *ptr, const SourmashKmerMinHash *other, diff --git a/sourmash/_minhash.py b/sourmash/_minhash.py index 3151e7149c..eee206db88 100644 --- a/sourmash/_minhash.py +++ b/sourmash/_minhash.py @@ -363,6 +363,20 @@ def add_hash(self, h): "Add a single hash value." return self._methodcall(lib.kmerminhash_add_hash, h) + def add_hash_with_abundance(self, h, a): + "Add a single hash value with an abundance." + if self.track_abundance: + return self._methodcall(lib.kmerminhash_add_hash_with_abundance, h, a) + else: + raise RuntimeError( + "Use track_abundance=True when constructing " + "the MinHash to use add_hash_with_abundance." + ) + + def clear(self): + "Clears all hashes and abundances." + return self._methodcall(lib.kmerminhash_clear) + def translate_codon(self, codon): "Translate a codon into an amino acid." try: @@ -544,18 +558,19 @@ def __iadd__(self, other): merge = __iadd__ - def set_abundances(self, values): + def set_abundances(self, values, clear=True): """Set abundances for hashes from ``values``, where ``values[hash] = abund`` """ if self.track_abundance: hashes = [] abunds = [] + for h, v in values.items(): hashes.append(h) abunds.append(v) - self._methodcall(lib.kmerminhash_set_abundances, hashes, abunds, len(hashes)) + self._methodcall(lib.kmerminhash_set_abundances, hashes, abunds, len(hashes), clear) else: raise RuntimeError( "Use track_abundance=True when constructing " diff --git a/src/core/src/ffi/minhash.rs b/src/core/src/ffi/minhash.rs index efeef3b2ee..55f74cf9ca 100644 --- a/src/core/src/ffi/minhash.rs +++ b/src/core/src/ffi/minhash.rs @@ -90,6 +90,13 @@ unsafe fn kmerminhash_add_protein(ptr: *mut SourmashKmerMinHash, sequence: *cons } } +#[no_mangle] +pub unsafe extern "C" fn kmerminhash_clear(ptr: *mut SourmashKmerMinHash) { + let mh = SourmashKmerMinHash::as_rust_mut(ptr); + + mh.clear(); +} + #[no_mangle] pub unsafe extern "C" fn kmerminhash_add_hash(ptr: *mut SourmashKmerMinHash, h: u64) { let mh = SourmashKmerMinHash::as_rust_mut(ptr); @@ -97,6 +104,13 @@ pub unsafe extern "C" fn kmerminhash_add_hash(ptr: *mut SourmashKmerMinHash, h: mh.add_hash(h); } +#[no_mangle] +pub unsafe extern "C" fn kmerminhash_add_hash_with_abundance(ptr: *mut SourmashKmerMinHash, h: u64, abundance: u64) { + let mh = SourmashKmerMinHash::as_rust_mut(ptr); + + mh.add_hash_with_abundance(h, abundance); +} + #[no_mangle] pub unsafe extern "C" fn kmerminhash_add_word(ptr: *mut SourmashKmerMinHash, word: *const c_char) { let mh = SourmashKmerMinHash::as_rust_mut(ptr); @@ -228,6 +242,7 @@ unsafe fn kmerminhash_set_abundances( hashes_ptr: *const u64, abunds_ptr: *const u64, insize: usize, + clear: bool, ) -> Result<()> { let mh = SourmashKmerMinHash::as_rust_mut(ptr); @@ -247,7 +262,9 @@ unsafe fn kmerminhash_set_abundances( pairs.sort(); // Reset the minhash - mh.clear(); + if clear { + mh.clear(); + } mh.add_many_with_abund(&pairs)?; diff --git a/src/core/src/sketch/minhash.rs b/src/core/src/sketch/minhash.rs index 52c6b725b9..73afd7a4c7 100644 --- a/src/core/src/sketch/minhash.rs +++ b/src/core/src/sketch/minhash.rs @@ -424,6 +424,22 @@ impl KmerMinHash { } } + pub fn set_hash_with_abundance(&mut self, hash: u64, abundance: u64) { + let mut found = false; + if let Ok(pos) = self.mins.binary_search(&hash) { + if self.mins[pos] == hash { + found = true; + if let Some(ref mut abunds) = self.abunds { + abunds[pos] = abundance; + } + } + } + + if !found { + self.add_hash_with_abundance(hash, abundance); + } + } + pub fn add_word(&mut self, word: &[u8]) { let hash = _hash_murmur(word, self.seed); self.add_hash(hash); diff --git a/tests/test__minhash.py b/tests/test__minhash.py index 3395f6d6db..866ceb1769 100644 --- a/tests/test__minhash.py +++ b/tests/test__minhash.py @@ -1041,6 +1041,48 @@ def test_abundance_simple(): assert a.get_mins(with_abundance=True) == {2110480117637990133: 2} +def test_add_hash_with_abundance(): + a = MinHash(20, 5, False, track_abundance=True) + + a.add_hash_with_abundance(10, 1) + assert a.get_mins(with_abundance=True) == {10: 1} + + a.add_hash_with_abundance(20, 2) + assert a.get_mins(with_abundance=True) == {10: 1, 20: 2} + + a.add_hash_with_abundance(10, 2) + assert a.get_mins(with_abundance=True) == {10: 3, 20: 2} + + +def test_add_hash_with_abundance_2(): + a = MinHash(20, 5, False, track_abundance=False) + + with pytest.raises(RuntimeError) as e: + a.add_hash_with_abundance(10, 1) + + assert "track_abundance=True when constructing" in e.value.args[0] + + +def test_clear(): + a = MinHash(20, 5, False, track_abundance=True) + + a.add_hash(10) + assert a.get_mins(with_abundance=True) == {10: 1} + + a.clear() + assert a.get_mins(with_abundance=True) == {} + + +def test_clear_2(): + a = MinHash(20, 5, False, track_abundance=False) + + a.add_hash(10) + assert a.get_mins() == [10] + + a.clear() + assert a.get_mins() == [] + + def test_abundance_simple_2(): a = MinHash(20, 5, False, track_abundance=True) b = MinHash(20, 5, False, track_abundance=True) @@ -1127,6 +1169,50 @@ def test_set_abundance_2(): assert new_mh.get_mins(with_abundance=True) == mins +def test_set_abundance_clear(): + # on empty minhash, clear should have no effect + a = MinHash(20, 5, False, track_abundance=True) + b = MinHash(20, 5, False, track_abundance=True) + + a.set_abundances({1: 3, 2: 4}, clear=True) + b.set_abundances({1: 3, 2: 4}, clear=False) + + assert a.get_mins() == b.get_mins() + + +def test_set_abundance_clear_2(): + # default should be clear=True + a = MinHash(20, 5, False, track_abundance=True) + + a.add_hash(10) + assert a.get_mins(with_abundance=True) == {10: 1} + + a.set_abundances({20: 2}) + assert a.get_mins(with_abundance=True) == {20: 2} + + +def test_set_abundance_clear_3(): + a = MinHash(20, 5, False, track_abundance=True) + + a.add_hash(10) + assert a.get_mins(with_abundance=True) == {10: 1} + + a.set_abundances({20: 1, 30: 4}, clear=False) + assert a.get_mins(with_abundance=True) == {10: 1, 20: 1, 30: 4} + + +def test_set_abundance_clear_4(): + # setting the abundance of an already set hash should add + # the abundances together + a = MinHash(20, 5, False, track_abundance=True) + + a.set_abundances({20: 2, 10: 1}, clear=False) # should also sort the hashes + assert a.get_mins(with_abundance=True) == {10: 1, 20: 2} + + a.set_abundances({20: 1, 10: 2}, clear=False) + assert a.get_mins(with_abundance=True) == {10: 3, 20: 3} + + def test_reset_abundance_initialized(): a = MinHash(1, 4, track_abundance=True) a.add_sequence('ATGC')