From 55325ba662ee44f57e3e3db3d9a0ba927b2b4dd5 Mon Sep 17 00:00:00 2001 From: LTLA Date: Sun, 14 Jul 2024 23:27:38 -0700 Subject: [PATCH] Minor optimizations for mismatches due to ambiguity. In particular, there is no need to store a rolling queue of ambiguous positions; we only need the most recent position to determine if there are any ambiguous bases in the current window of the sequence. Also expanded the test scenarios for scanning ambiguous bases. --- include/kaori/ScanTemplate.hpp | 87 ++++++++++++++++----------- tests/src/ScanTemplate.cpp | 105 +++++++++++++++++++++++++++++++-- 2 files changed, 151 insertions(+), 41 deletions(-) diff --git a/include/kaori/ScanTemplate.hpp b/include/kaori/ScanTemplate.hpp index 577c074..f08daea 100644 --- a/include/kaori/ScanTemplate.hpp +++ b/include/kaori/ScanTemplate.hpp @@ -2,7 +2,10 @@ #define KAORI_SCAN_TEMPLATE_HPP #include -#include +#include +#include +#include + #include "utils.hpp" /** @@ -63,6 +66,7 @@ class ScanTemplate { if (b != '-') { add_base_to_hash(forward_ref, b); add_mask_to_hash(forward_mask); + forward_mask_ambiguous.set(i); } else { shift_hash(forward_ref); shift_hash(forward_mask); @@ -85,6 +89,7 @@ class ScanTemplate { if (b != '-') { add_base_to_hash(reverse_ref, complement_base(b)); add_mask_to_hash(reverse_mask); + reverse_mask_ambiguous.set(i); } else { shift_hash(reverse_ref); shift_hash(reverse_mask); @@ -126,10 +131,13 @@ class ScanTemplate { /** * @cond */ - std::bitset state, ambiguous; + std::bitset state; const char * seq; size_t len; - std::deque bad; + + std::bitset ambiguous; // we only need a yes/no for the ambiguous state, so we can use a smaller bitset. + size_t last_ambiguous; // contains the position of the most recent ambiguous base; should only be read if any_ambiguous = true. + bool any_ambiguous = false; // indicates whether ambiguous.count() > 0. /** * @endcond */ @@ -156,13 +164,20 @@ class ScanTemplate { if (is_standard_base(base)) { add_base_to_hash(out.state, base); - if (!out.bad.empty()) { - shift_hash(out.ambiguous); + + if (out.any_ambiguous) { + out.ambiguous <<= 1; } } else { add_other_to_hash(out.state); - add_other_to_hash(out.ambiguous); - out.bad.push_back(i); + + if (out.any_ambiguous) { + out.ambiguous <<= 1; + } else { + out.any_ambiguous = true; + } + out.ambiguous.set(0); + out.last_ambiguous = i; } } } else { @@ -181,28 +196,32 @@ class ScanTemplate { * On return, `state` is updated with the details of the current match at a particular position on the read sequence. */ void next(State& state) const { - if (!state.bad.empty() && state.bad.front() == state.position) { - state.bad.pop_front(); - if (state.bad.empty()) { - // This should effectively clear the ambiguous bitset, allowing - // us to skip its shifting if there are no more ambiguous - // bases. We do it here because we won't get an opportunity to - // do it later; as 'bad' is empty, the shift below is skipped. - shift_hash(state.ambiguous); - } - } - size_t right = state.position + length; char base = state.seq[right]; + if (is_standard_base(base)) { add_base_to_hash(state.state, base); // no need to trim off the end, the mask will handle that. - if (!state.bad.empty()) { - shift_hash(state.ambiguous); + if (state.any_ambiguous) { + state.ambiguous <<= 1; + + // If the last ambiguous position is equal to 'position', the + // ensuing increment to the latter will shift it out of the + // hash... at which point, we've got no ambiguity left. + if (state.last_ambiguous == state.position) { + state.any_ambiguous = false; + } } + } else { add_other_to_hash(state.state); - add_other_to_hash(state.ambiguous); - state.bad.push_back(right); + + if (state.any_ambiguous) { + state.ambiguous <<= 1; + } else { + state.any_ambiguous = true; + } + state.ambiguous.set(0); + state.last_ambiguous = right; } ++state.position; @@ -221,6 +240,9 @@ class ScanTemplate { int mismatches; bool forward, reverse; + std::bitset forward_mask_ambiguous; // we only need a yes/no for whether a position is an ambiguous base, so we can use a smaller bitset. + std::bitset reverse_mask_ambiguous; + static void add_mask_to_hash(std::bitset& current) { shift_hash(current); current.set(0); @@ -230,22 +252,17 @@ class ScanTemplate { return; } - static int strand_match(const State& match, const std::bitset& ref, const std::bitset& mask) { + static int strand_match(const State& match, const std::bitset& ref, const std::bitset& mask, const std::bitset& mask_ambiguous) { // pop count here is equal to the number of non-ambiguous mismatches * // 2 + number of ambiguous mismatches * 3. This is because // non-ambiguous bases are encoded by 1 set bit per 4 bases (so 2 are // left after a XOR'd mismatch), while ambiguous mismatches are encoded // by all set bits per 4 bases (which means that 3 are left after XOR). - int pcount = ((match.state & mask) ^ ref).count(); - - // Counting the number of ambiguous bases after masking. Each ambiguous - // base is represented by 4 set bits, so we divide by 4 to get the number - // of bases; then we multiply by three to remove their contribution. The - // difference is then divided by two to get the number of non-ambig mm's. - if (!match.bad.empty()) { - int acount = (match.ambiguous & mask).count(); - acount /= 4; - return acount + (pcount - acount * 3) / 2; + size_t pcount = ((match.state & mask) ^ ref).count(); + + if (match.any_ambiguous) { + size_t acount = (match.ambiguous & mask_ambiguous).count(); + return (pcount - acount) / 2; // i.e., acount + (pcount - acount * 3) / 2; } else { return pcount / 2; } @@ -253,10 +270,10 @@ class ScanTemplate { void full_match(State& match) const { if (forward) { - match.forward_mismatches = strand_match(match, forward_ref, forward_mask); + match.forward_mismatches = strand_match(match, forward_ref, forward_mask, forward_mask_ambiguous); } if (reverse) { - match.reverse_mismatches = strand_match(match, reverse_ref, reverse_mask); + match.reverse_mismatches = strand_match(match, reverse_ref, reverse_mask, reverse_mask_ambiguous); } } diff --git a/tests/src/ScanTemplate.cpp b/tests/src/ScanTemplate.cpp index 043244b..5c149f4 100644 --- a/tests/src/ScanTemplate.cpp +++ b/tests/src/ScanTemplate.cpp @@ -191,22 +191,42 @@ TEST(ScanTemplate, BadBases) { EXPECT_FALSE(out.finished); } + std::bitset<16> amask; + for (size_t i = 0; i < thing.size(); ++i) { + amask.set(i); + } + // Runs into an N later. { - std::string seq = "aaaaaaaaaaACGNAAAATTTTa"; + std::string seq = "aaaaaaaaaaACGNAAAATTTTacgatcgatcagctag"; auto out = stuff.initialize(seq.c_str(), seq.size()); for (int i = 0; i < 10; ++i) { stuff.next(out); - EXPECT_TRUE(out.forward_mismatches > 0); + EXPECT_GT(out.forward_mismatches, 1); EXPECT_FALSE(out.finished); + EXPECT_EQ(out.any_ambiguous, i > 1); // at the third base (i.e., i = 2), the N comes into the window. } stuff.next(out); EXPECT_EQ(out.forward_mismatches, 1); EXPECT_FALSE(out.finished); - EXPECT_EQ(out.ambiguous.count(), 4); - EXPECT_EQ(out.bad.size(), 1); + EXPECT_EQ((out.ambiguous & amask).count(), 1); + EXPECT_TRUE(out.any_ambiguous); + + for (int i = 0; i < 3; ++i) { // next three shifts still overlap the N. + stuff.next(out); + EXPECT_GT(out.forward_mismatches, 1); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 1); + EXPECT_TRUE(out.any_ambiguous); + } + + stuff.next(out); // past the N now, so ambiguity is now dropped. + EXPECT_GT(out.forward_mismatches, 1); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 0); + EXPECT_FALSE(out.any_ambiguous); } // Clears existing Ns. @@ -216,13 +236,86 @@ TEST(ScanTemplate, BadBases) { for (int i = 0; i < 4; ++i) { stuff.next(out); - EXPECT_TRUE(out.forward_mismatches > 0); + EXPECT_GT(out.forward_mismatches, 0); EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 4 - i); + EXPECT_TRUE(out.any_ambiguous); } stuff.next(out); EXPECT_EQ(out.forward_mismatches, 0); EXPECT_FALSE(out.finished); - EXPECT_TRUE(out.bad.empty()); + EXPECT_EQ((out.ambiguous & amask).count(), 0); + EXPECT_FALSE(out.any_ambiguous); + } + + // Works with separated N's. + { + std::string seq = "aaaaaaaANGTAAAATTTNaaaaaaaaaaaaaa"; + auto out = stuff.initialize(seq.c_str(), seq.size()); + + for (int i = 0; i <= 6; ++i) { + stuff.next(out); + EXPECT_GT(out.forward_mismatches, 2); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 1); + EXPECT_TRUE(out.any_ambiguous); + } + + stuff.next(out); + EXPECT_EQ(out.forward_mismatches, 2); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 2); + EXPECT_TRUE(out.any_ambiguous); + + stuff.next(out); + EXPECT_GT(out.forward_mismatches, 2); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 2); + EXPECT_TRUE(out.any_ambiguous); + + for (int i = 0; i <= 9; ++i) { + stuff.next(out); + EXPECT_GT(out.forward_mismatches, 2); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 1); + EXPECT_TRUE(out.any_ambiguous); + } + + stuff.next(out); + EXPECT_GT(out.forward_mismatches, 2); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 0); + EXPECT_FALSE(out.any_ambiguous); + } + + // Works in reverse. + { + std::string seq = "aaNaaAAAATTTTACGTaaNaaaaa"; + kaori::ScanTemplate<16> stuff(thing.c_str(), thing.size(), kaori::SearchStrand::REVERSE); + auto out = stuff.initialize(seq.c_str(), seq.size()); + + for (int i = 0; i <= 4; ++i) { + stuff.next(out); + EXPECT_GT(out.reverse_mismatches, 1); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), (i <= 2)); // i.e., before we pass the N. + EXPECT_EQ(out.any_ambiguous, (i <= 2)); + } + + stuff.next(out); + EXPECT_EQ(out.reverse_mismatches, 0); + EXPECT_FALSE(out.finished); + EXPECT_EQ((out.ambiguous & amask).count(), 0); + EXPECT_FALSE(out.any_ambiguous); + + size_t remaining = seq.size() - thing.size() - out.position; + for (size_t i = 0; i < remaining; ++i) { + stuff.next(out); + EXPECT_GT(out.reverse_mismatches, 1); + EXPECT_EQ(out.finished, i + 1 == remaining); + EXPECT_EQ((out.ambiguous & amask).count(), (i > 1)); // i.e., after we hit the next N. + EXPECT_EQ(out.any_ambiguous, (i > 1)); + } } }