Skip to content

Commit

Permalink
Minor optimizations for mismatches due to ambiguity.
Browse files Browse the repository at this point in the history
In particular, there is no need to store a rolling queue of ambiguous
positions; we only need the most recent position to determine if there are
any ambiguous bases in the current window of the sequence.

Also expanded the test scenarios for scanning ambiguous bases.
  • Loading branch information
LTLA committed Jul 15, 2024
1 parent f0b00db commit 55325ba
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 41 deletions.
87 changes: 52 additions & 35 deletions include/kaori/ScanTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
#define KAORI_SCAN_TEMPLATE_HPP

#include <bitset>
#include <deque>
#include <vector>
#include <stdexcept>
#include <string>

#include "utils.hpp"

/**
Expand Down Expand Up @@ -63,6 +66,7 @@ class ScanTemplate {
if (b != '-') {
add_base_to_hash(forward_ref, b);
add_mask_to_hash(forward_mask);
forward_mask_ambiguous.set(i);
} else {
shift_hash(forward_ref);
shift_hash(forward_mask);
Expand All @@ -85,6 +89,7 @@ class ScanTemplate {
if (b != '-') {
add_base_to_hash(reverse_ref, complement_base(b));
add_mask_to_hash(reverse_mask);
reverse_mask_ambiguous.set(i);
} else {
shift_hash(reverse_ref);
shift_hash(reverse_mask);
Expand Down Expand Up @@ -126,10 +131,13 @@ class ScanTemplate {
/**
* @cond
*/
std::bitset<N> state, ambiguous;
std::bitset<N> state;
const char * seq;
size_t len;
std::deque<size_t> bad;

std::bitset<N/4> ambiguous; // we only need a yes/no for the ambiguous state, so we can use a smaller bitset.
size_t last_ambiguous; // contains the position of the most recent ambiguous base; should only be read if any_ambiguous = true.
bool any_ambiguous = false; // indicates whether ambiguous.count() > 0.
/**
* @endcond
*/
Expand All @@ -156,13 +164,20 @@ class ScanTemplate {

if (is_standard_base(base)) {
add_base_to_hash(out.state, base);
if (!out.bad.empty()) {
shift_hash(out.ambiguous);

if (out.any_ambiguous) {
out.ambiguous <<= 1;
}
} else {
add_other_to_hash(out.state);
add_other_to_hash(out.ambiguous);
out.bad.push_back(i);

if (out.any_ambiguous) {
out.ambiguous <<= 1;
} else {
out.any_ambiguous = true;
}
out.ambiguous.set(0);
out.last_ambiguous = i;
}
}
} else {
Expand All @@ -181,28 +196,32 @@ class ScanTemplate {
* On return, `state` is updated with the details of the current match at a particular position on the read sequence.
*/
void next(State& state) const {
if (!state.bad.empty() && state.bad.front() == state.position) {
state.bad.pop_front();
if (state.bad.empty()) {
// This should effectively clear the ambiguous bitset, allowing
// us to skip its shifting if there are no more ambiguous
// bases. We do it here because we won't get an opportunity to
// do it later; as 'bad' is empty, the shift below is skipped.
shift_hash(state.ambiguous);
}
}

size_t right = state.position + length;
char base = state.seq[right];

if (is_standard_base(base)) {
add_base_to_hash(state.state, base); // no need to trim off the end, the mask will handle that.
if (!state.bad.empty()) {
shift_hash(state.ambiguous);
if (state.any_ambiguous) {
state.ambiguous <<= 1;

// If the last ambiguous position is equal to 'position', the
// ensuing increment to the latter will shift it out of the
// hash... at which point, we've got no ambiguity left.
if (state.last_ambiguous == state.position) {
state.any_ambiguous = false;
}
}

} else {
add_other_to_hash(state.state);
add_other_to_hash(state.ambiguous);
state.bad.push_back(right);

if (state.any_ambiguous) {
state.ambiguous <<= 1;
} else {
state.any_ambiguous = true;
}
state.ambiguous.set(0);
state.last_ambiguous = right;
}

++state.position;
Expand All @@ -221,6 +240,9 @@ class ScanTemplate {
int mismatches;
bool forward, reverse;

std::bitset<N/4> forward_mask_ambiguous; // we only need a yes/no for whether a position is an ambiguous base, so we can use a smaller bitset.
std::bitset<N/4> reverse_mask_ambiguous;

static void add_mask_to_hash(std::bitset<N>& current) {
shift_hash(current);
current.set(0);
Expand All @@ -230,33 +252,28 @@ class ScanTemplate {
return;
}

static int strand_match(const State& match, const std::bitset<N>& ref, const std::bitset<N>& mask) {
static int strand_match(const State& match, const std::bitset<N>& ref, const std::bitset<N>& mask, const std::bitset<N/4>& mask_ambiguous) {
// pop count here is equal to the number of non-ambiguous mismatches *
// 2 + number of ambiguous mismatches * 3. This is because
// non-ambiguous bases are encoded by 1 set bit per 4 bases (so 2 are
// left after a XOR'd mismatch), while ambiguous mismatches are encoded
// by all set bits per 4 bases (which means that 3 are left after XOR).
int pcount = ((match.state & mask) ^ ref).count();

// Counting the number of ambiguous bases after masking. Each ambiguous
// base is represented by 4 set bits, so we divide by 4 to get the number
// of bases; then we multiply by three to remove their contribution. The
// difference is then divided by two to get the number of non-ambig mm's.
if (!match.bad.empty()) {
int acount = (match.ambiguous & mask).count();
acount /= 4;
return acount + (pcount - acount * 3) / 2;
size_t pcount = ((match.state & mask) ^ ref).count();

if (match.any_ambiguous) {
size_t acount = (match.ambiguous & mask_ambiguous).count();
return (pcount - acount) / 2; // i.e., acount + (pcount - acount * 3) / 2;
} else {
return pcount / 2;
}
}

void full_match(State& match) const {
if (forward) {
match.forward_mismatches = strand_match(match, forward_ref, forward_mask);
match.forward_mismatches = strand_match(match, forward_ref, forward_mask, forward_mask_ambiguous);
}
if (reverse) {
match.reverse_mismatches = strand_match(match, reverse_ref, reverse_mask);
match.reverse_mismatches = strand_match(match, reverse_ref, reverse_mask, reverse_mask_ambiguous);
}
}

Expand Down
105 changes: 99 additions & 6 deletions tests/src/ScanTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,22 +191,42 @@ TEST(ScanTemplate, BadBases) {
EXPECT_FALSE(out.finished);
}

std::bitset<16> amask;
for (size_t i = 0; i < thing.size(); ++i) {
amask.set(i);
}

// Runs into an N later.
{
std::string seq = "aaaaaaaaaaACGNAAAATTTTa";
std::string seq = "aaaaaaaaaaACGNAAAATTTTacgatcgatcagctag";
auto out = stuff.initialize(seq.c_str(), seq.size());

for (int i = 0; i < 10; ++i) {
stuff.next(out);
EXPECT_TRUE(out.forward_mismatches > 0);
EXPECT_GT(out.forward_mismatches, 1);
EXPECT_FALSE(out.finished);
EXPECT_EQ(out.any_ambiguous, i > 1); // at the third base (i.e., i = 2), the N comes into the window.
}

stuff.next(out);
EXPECT_EQ(out.forward_mismatches, 1);
EXPECT_FALSE(out.finished);
EXPECT_EQ(out.ambiguous.count(), 4);
EXPECT_EQ(out.bad.size(), 1);
EXPECT_EQ((out.ambiguous & amask).count(), 1);
EXPECT_TRUE(out.any_ambiguous);

for (int i = 0; i < 3; ++i) { // next three shifts still overlap the N.
stuff.next(out);
EXPECT_GT(out.forward_mismatches, 1);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 1);
EXPECT_TRUE(out.any_ambiguous);
}

stuff.next(out); // past the N now, so ambiguity is now dropped.
EXPECT_GT(out.forward_mismatches, 1);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 0);
EXPECT_FALSE(out.any_ambiguous);
}

// Clears existing Ns.
Expand All @@ -216,13 +236,86 @@ TEST(ScanTemplate, BadBases) {

for (int i = 0; i < 4; ++i) {
stuff.next(out);
EXPECT_TRUE(out.forward_mismatches > 0);
EXPECT_GT(out.forward_mismatches, 0);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 4 - i);
EXPECT_TRUE(out.any_ambiguous);
}

stuff.next(out);
EXPECT_EQ(out.forward_mismatches, 0);
EXPECT_FALSE(out.finished);
EXPECT_TRUE(out.bad.empty());
EXPECT_EQ((out.ambiguous & amask).count(), 0);
EXPECT_FALSE(out.any_ambiguous);
}

// Works with separated N's.
{
std::string seq = "aaaaaaaANGTAAAATTTNaaaaaaaaaaaaaa";
auto out = stuff.initialize(seq.c_str(), seq.size());

for (int i = 0; i <= 6; ++i) {
stuff.next(out);
EXPECT_GT(out.forward_mismatches, 2);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 1);
EXPECT_TRUE(out.any_ambiguous);
}

stuff.next(out);
EXPECT_EQ(out.forward_mismatches, 2);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 2);
EXPECT_TRUE(out.any_ambiguous);

stuff.next(out);
EXPECT_GT(out.forward_mismatches, 2);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 2);
EXPECT_TRUE(out.any_ambiguous);

for (int i = 0; i <= 9; ++i) {
stuff.next(out);
EXPECT_GT(out.forward_mismatches, 2);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 1);
EXPECT_TRUE(out.any_ambiguous);
}

stuff.next(out);
EXPECT_GT(out.forward_mismatches, 2);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 0);
EXPECT_FALSE(out.any_ambiguous);
}

// Works in reverse.
{
std::string seq = "aaNaaAAAATTTTACGTaaNaaaaa";
kaori::ScanTemplate<16> stuff(thing.c_str(), thing.size(), kaori::SearchStrand::REVERSE);
auto out = stuff.initialize(seq.c_str(), seq.size());

for (int i = 0; i <= 4; ++i) {
stuff.next(out);
EXPECT_GT(out.reverse_mismatches, 1);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), (i <= 2)); // i.e., before we pass the N.
EXPECT_EQ(out.any_ambiguous, (i <= 2));
}

stuff.next(out);
EXPECT_EQ(out.reverse_mismatches, 0);
EXPECT_FALSE(out.finished);
EXPECT_EQ((out.ambiguous & amask).count(), 0);
EXPECT_FALSE(out.any_ambiguous);

size_t remaining = seq.size() - thing.size() - out.position;
for (size_t i = 0; i < remaining; ++i) {
stuff.next(out);
EXPECT_GT(out.reverse_mismatches, 1);
EXPECT_EQ(out.finished, i + 1 == remaining);
EXPECT_EQ((out.ambiguous & amask).count(), (i > 1)); // i.e., after we hit the next N.
EXPECT_EQ(out.any_ambiguous, (i > 1));
}
}
}

0 comments on commit 55325ba

Please sign in to comment.