Skip to content

Commit

Permalink
Merge pull request #450 from ksahlin/mcs-refactor-docs
Browse files Browse the repository at this point in the history
Multi-context seeds refactoring and docs
  • Loading branch information
marcelm authored Oct 8, 2024
2 parents 3209621 + b64e5b5 commit 33b6526
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 76 deletions.
12 changes: 12 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# Strobealign Changelog

## Development version

* #388 and #426: Increase accuracy and mapping rate for reads shorter than
about 200 bp by introducing multi-context seeds.
Previously, seeds always consisted of two k-mers and would only be found if
both occur in query and reference.
With this change, strobealign falls back to looking up just one of the k-mers
when appropriate.
This feature is currently *experimental* and only enabled when using the
`--mcs` command-line option.
Contributed by Ivan Tolstoganov (@Itolstoganov).

## v0.14.0 (2024-10-03)

* #401: The default number of threads is now 1 instead of 3.
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void StrobemerIndex::print_diagnostics(const std::string& logfile_name, int k) c

for (size_t it = 0; it < randstrobes.size(); it++) {
seed_length = strobe2_offset(it) + k;
auto count = get_count(find(get_hash(it)));
auto count = get_count_full(find_full(get_hash(it)));

if (seed_length < max_size){
log_count[seed_length] ++;
Expand Down
84 changes: 28 additions & 56 deletions src/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,26 @@ struct StrobemerIndex {
void populate(float f, unsigned n_threads);
void print_diagnostics(const std::string& logfile_name, int k) const;
int pick_bits(size_t size) const;
size_t find(randstrobe_hash_t key) const {
constexpr int MAX_LINEAR_SEARCH = 4;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_start = randstrobe_start_indices[top_N];
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
if (position_start == position_end) {
return end();
}

if (position_end - position_start < MAX_LINEAR_SEARCH) {
for ( ; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key) return position_start;
if (randstrobes[position_start].hash > key) return end();
}
return end();
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };
// Find first entry that matches the given key
size_t find_full(randstrobe_hash_t key) const {
return find(key, 0);
}

auto pos = std::lower_bound(randstrobes.begin() + position_start,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
if (pos->hash == key) return pos - randstrobes.begin();
return end();
/*
* Find the first entry that matches the main hash (ignoring the aux_len
* least significant bits)
*/
size_t find_partial(randstrobe_hash_t key) const {
return find(key, parameters.randstrobe.aux_len);
}

//Returns the first entry that matches the main hash
size_t partial_find(randstrobe_hash_t key) const {
const unsigned int aux_len = parameters.randstrobe.aux_len;
/*
* Find first entry whose hash matches the given key, but ignore the
* b least significant bits
*/
size_t find(randstrobe_hash_t key, uint8_t b) const {
const unsigned int aux_len = b;
randstrobe_hash_t key_prefix = key >> aux_len;

constexpr int MAX_LINEAR_SEARCH = 4;
Expand Down Expand Up @@ -171,46 +162,27 @@ struct StrobemerIndex {
return randstrobes.size();
}

unsigned int get_count(bucket_index_t position) const {
unsigned int get_count_full(bucket_index_t position) const {
return get_count(position, 0);
}

unsigned int get_count_partial(bucket_index_t position) const {
return get_count(position, parameters.randstrobe.aux_len);
}

unsigned int get_count(bucket_index_t position, uint8_t b) const {
// For 95% of cases, the result will be small and a brute force search
// is the best option. Once, we go over MAX_LINEAR_SEARCH, though, we
// use a binary search to get the next position
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// with MAX_LINEAR_SEARCH=8, the actual value will be 96%.

// Since the result depends on position, this function must be used on the smallest position which points to the
// seed with the given hash to yield the number of seeds with this hash.

constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const auto key = randstrobes[position].hash;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
uint64_t count = 1;

if (position_end - position < MAX_LINEAR_SEARCH) {
for (bucket_index_t position_start = position + 1; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key){
count += 1;
}
else{
break;
}
}
return count;
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };

auto pos = std::upper_bound(randstrobes.begin() + position,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
return (pos - randstrobes.begin() - 1) - position + 1;
}

unsigned int get_partial_count(bucket_index_t position) const {
constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const unsigned int aux_len = parameters.randstrobe.aux_len;
const unsigned int aux_len = b;

const auto key = randstrobes[position].hash;
randstrobe_hash_t key_prefix = key >> aux_len;
Expand Down
40 changes: 22 additions & 18 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ struct Match {
int ref_end;
};

struct PartialSeed {
size_t hash;
unsigned int start;
bool operator==(const Match& lhs, const Match& rhs) {
return (lhs.query_start == rhs.query_start) && (lhs.query_end == rhs.query_end) && (lhs.ref_start == rhs.ref_start) && (lhs.ref_end == rhs.ref_end);
}

/*
* A partial hit is a hit where not the full randstrobe hash could be found in
* the index but only the "main" hash (only the first aux_len bits).
*/
struct PartialHit {
randstrobe_hash_t hash;
unsigned int start; // position in strobemer index
bool is_reverse;
bool operator==(const PartialSeed& rhs) const {
bool operator==(const PartialHit& rhs) const {
return (hash == rhs.hash) && (start == rhs.start) && (is_reverse == rhs.is_reverse);
}
};
Expand Down Expand Up @@ -57,10 +65,6 @@ inline void add_to_matches_map_partial(
}
}

bool operator==(const Match& lhs, const Match& rhs) {
return (lhs.query_start == rhs.query_start) && (lhs.query_end == rhs.query_end) && (lhs.ref_start == rhs.ref_start) && (lhs.ref_end == rhs.ref_end);
}

void merge_matches_into_nams(
robin_hood::unordered_map<unsigned int, std::vector<Match>>& matches_map,
int k,
Expand Down Expand Up @@ -200,7 +204,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
const StrobemerIndex& index,
bool use_mcs
) {
std::vector<PartialSeed> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
std::vector<PartialHit> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
if (use_mcs) {
partial_queried.reserve(10);
}
Expand All @@ -210,7 +214,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
int nr_good_hits = 0;
int total_hits = 0;
for (const auto &q : query_randstrobes) {
size_t position = index.find(q.hash);
size_t position = index.find_full(q.hash);
if (position != index.end()){
total_hits++;
if (index.is_filtered(position)) {
Expand All @@ -220,12 +224,12 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
add_to_matches_map_full(matches_map[q.is_reverse], q.start, q.end, index, position);
}
else if (use_mcs) {
PartialSeed ph{q.hash >> index.get_aux_len(), q.partial_start, q.is_reverse};
PartialHit ph{q.hash >> index.get_aux_len(), q.partial_start, q.is_reverse};
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
// already queried
continue;
}
size_t partial_pos = index.partial_find(q.hash);
size_t partial_pos = index.find_partial(q.hash);
if (partial_pos != index.end()) {
total_hits++;
if (index.is_partial_filtered(partial_pos)) {
Expand Down Expand Up @@ -267,7 +271,7 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
< std::tie(rhs.count, rhs.query_start, rhs.query_end);
}
};
std::vector<PartialSeed> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
std::vector<PartialHit> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
partial_queried.reserve(10);
std::array<robin_hood::unordered_map<unsigned int, std::vector<Match>>, 2> matches_map;
std::vector<RescueHit> hits_fw;
Expand All @@ -278,9 +282,9 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
hits_rc.reserve(5000);

for (auto &qr : query_randstrobes) {
size_t position = index.find(qr.hash);
size_t position = index.find_full(qr.hash);
if (position != index.end()) {
unsigned int count = index.get_count(position);
unsigned int count = index.get_count_full(position);
RescueHit rh{position, count, qr.start, qr.end, false};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand All @@ -289,14 +293,14 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
}
}
else if (use_mcs) {
PartialSeed ph = {qr.hash >> index.get_aux_len(), qr.partial_start, qr.is_reverse};
PartialHit ph = {qr.hash >> index.get_aux_len(), qr.partial_start, qr.is_reverse};
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
// already queried
continue;
}
size_t partial_pos = index.partial_find(qr.hash);
size_t partial_pos = index.find_partial(qr.hash);
if (partial_pos != index.end()) {
unsigned int partial_count = index.get_partial_count(partial_pos);
unsigned int partial_count = index.get_count_partial(partial_pos);
RescueHit rh{partial_pos, partial_count, qr.partial_start, qr.partial_end, true};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand Down
2 changes: 1 addition & 1 deletion src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ NB_MODULE(strobealign_extension, m_) {
.def(nb::init<References&, IndexParameters&>())
.def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector<RefRandstrobe> {
std::vector<RefRandstrobe> v;
auto position = index.find(key);
auto position = index.find_full(key);
while (position != index.end() && index.get_hash(position) == key) {
v.push_back(index.get_randstrobe(position));
position++;
Expand Down
11 changes: 11 additions & 0 deletions src/randstrobes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ static inline syncmer_hash_t syncmer_smer_hash(uint64_t packed) {
return xxh64(packed);
}

/*
* This function combines two individual syncmer hashes into a single hash
* for the randstrobe.
*
* The syncmer with the smaller hash is designated as the "main", the other is
* the "auxiliary".
* The combined hash is obtained by setting the top aux_len bits to the bits of
* the main hash and the bottom 64 - aux_len bits to the bits of the auxiliary
* hash. Since entries in the index are sorted by randstrobe hash, this allows
* us to search for the main syncmer only by masking out the lower aux_len bits.
*/
static inline randstrobe_hash_t randstrobe_hash(syncmer_hash_t hash1, syncmer_hash_t hash2, size_t aux_len) {
// Make the function symmetric
if (hash1 > hash2) {
Expand Down
3 changes: 3 additions & 0 deletions src/randstrobes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct QueryRandstrobe {
randstrobe_hash_t hash;
unsigned int start;
unsigned int end;
/* Start and end of the main syncmer (relevant if the randstrobe couldn’t
* be found in the index and we fall back to a partial hit)
*/
unsigned int partial_start;
unsigned int partial_end;
bool is_reverse;
Expand Down

0 comments on commit 33b6526

Please sign in to comment.