Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-context seeds refactoring and docs #450

Merged
merged 3 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
# Strobealign Changelog

## Development version

* #388 and #426: Increase accuracy and mapping rate for reads shorter than
about 200 bp by introducing multi-context seeds.
Previously, seeds always consisted of two k-mers and would only be found if
both occur in query and reference.
With this change, strobealign falls back to looking up just one of the k-mers
when appropriate.
This feature is currently *experimental* and only enabled when using the
`--mcs` command-line option.
Contributed by Ivan Tolstoganov (@Itolstoganov).

## v0.14.0 (2024-10-03)

* #401: The default number of threads is now 1 instead of 3.
Expand Down
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void StrobemerIndex::print_diagnostics(const std::string& logfile_name, int k) c

for (size_t it = 0; it < randstrobes.size(); it++) {
seed_length = strobe2_offset(it) + k;
auto count = get_count(find(get_hash(it)));
auto count = get_count_full(find_full(get_hash(it)));

if (seed_length < max_size){
log_count[seed_length] ++;
Expand Down
84 changes: 28 additions & 56 deletions src/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,26 @@ struct StrobemerIndex {
void populate(float f, unsigned n_threads);
void print_diagnostics(const std::string& logfile_name, int k) const;
int pick_bits(size_t size) const;
size_t find(randstrobe_hash_t key) const {
constexpr int MAX_LINEAR_SEARCH = 4;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_start = randstrobe_start_indices[top_N];
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
if (position_start == position_end) {
return end();
}

if (position_end - position_start < MAX_LINEAR_SEARCH) {
for ( ; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key) return position_start;
if (randstrobes[position_start].hash > key) return end();
}
return end();
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };
// Find first entry that matches the given key
size_t find_full(randstrobe_hash_t key) const {
return find(key, 0);
}

auto pos = std::lower_bound(randstrobes.begin() + position_start,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
if (pos->hash == key) return pos - randstrobes.begin();
return end();
/*
* Find the first entry that matches the main hash (ignoring the aux_len
* least significant bits)
*/
size_t find_partial(randstrobe_hash_t key) const {
return find(key, parameters.randstrobe.aux_len);
}

//Returns the first entry that matches the main hash
size_t partial_find(randstrobe_hash_t key) const {
const unsigned int aux_len = parameters.randstrobe.aux_len;
/*
* Find first entry whose hash matches the given key, but ignore the
* b least significant bits
*/
size_t find(randstrobe_hash_t key, uint8_t b) const {
const unsigned int aux_len = b;
randstrobe_hash_t key_prefix = key >> aux_len;

constexpr int MAX_LINEAR_SEARCH = 4;
Expand Down Expand Up @@ -171,46 +162,27 @@ struct StrobemerIndex {
return randstrobes.size();
}

unsigned int get_count(bucket_index_t position) const {
unsigned int get_count_full(bucket_index_t position) const {
return get_count(position, 0);
}

unsigned int get_count_partial(bucket_index_t position) const {
return get_count(position, parameters.randstrobe.aux_len);
}

unsigned int get_count(bucket_index_t position, uint8_t b) const {
// For 95% of cases, the result will be small and a brute force search
// is the best option. Once, we go over MAX_LINEAR_SEARCH, though, we
// use a binary search to get the next position
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// with MAX_LINEAR_SEARCH=8, the actual value will be 96%.

// Since the result depends on position, this function must be used on the smallest position which points to the
// seed with the given hash to yield the number of seeds with this hash.

constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const auto key = randstrobes[position].hash;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
uint64_t count = 1;

if (position_end - position < MAX_LINEAR_SEARCH) {
for (bucket_index_t position_start = position + 1; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key){
count += 1;
}
else{
break;
}
}
return count;
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };

auto pos = std::upper_bound(randstrobes.begin() + position,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
return (pos - randstrobes.begin() - 1) - position + 1;
}

unsigned int get_partial_count(bucket_index_t position) const {
constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const unsigned int aux_len = parameters.randstrobe.aux_len;
const unsigned int aux_len = b;

const auto key = randstrobes[position].hash;
randstrobe_hash_t key_prefix = key >> aux_len;
Expand Down
40 changes: 22 additions & 18 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,19 @@ struct Match {
int ref_end;
};

struct PartialSeed {
size_t hash;
unsigned int start;
bool operator==(const Match& lhs, const Match& rhs) {
return (lhs.query_start == rhs.query_start) && (lhs.query_end == rhs.query_end) && (lhs.ref_start == rhs.ref_start) && (lhs.ref_end == rhs.ref_end);
}

/*
* A partial hit is a hit where not the full randstrobe hash could be found in
* the index but only the "main" hash (only the first aux_len bits).
*/
struct PartialHit {
randstrobe_hash_t hash;
unsigned int start; // position in strobemer index
bool is_reverse;
bool operator==(const PartialSeed& rhs) const {
bool operator==(const PartialHit& rhs) const {
return (hash == rhs.hash) && (start == rhs.start) && (is_reverse == rhs.is_reverse);
}
};
Expand Down Expand Up @@ -57,10 +65,6 @@ inline void add_to_matches_map_partial(
}
}

bool operator==(const Match& lhs, const Match& rhs) {
return (lhs.query_start == rhs.query_start) && (lhs.query_end == rhs.query_end) && (lhs.ref_start == rhs.ref_start) && (lhs.ref_end == rhs.ref_end);
}

void merge_matches_into_nams(
robin_hood::unordered_map<unsigned int, std::vector<Match>>& matches_map,
int k,
Expand Down Expand Up @@ -200,7 +204,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
const StrobemerIndex& index,
bool use_mcs
) {
std::vector<PartialSeed> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
std::vector<PartialHit> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
if (use_mcs) {
partial_queried.reserve(10);
}
Expand All @@ -210,7 +214,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
int nr_good_hits = 0;
int total_hits = 0;
for (const auto &q : query_randstrobes) {
size_t position = index.find(q.hash);
size_t position = index.find_full(q.hash);
if (position != index.end()){
total_hits++;
if (index.is_filtered(position)) {
Expand All @@ -220,12 +224,12 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
add_to_matches_map_full(matches_map[q.is_reverse], q.start, q.end, index, position);
}
else if (use_mcs) {
PartialSeed ph{q.hash >> index.get_aux_len(), q.partial_start, q.is_reverse};
PartialHit ph{q.hash >> index.get_aux_len(), q.partial_start, q.is_reverse};
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
// already queried
continue;
}
size_t partial_pos = index.partial_find(q.hash);
size_t partial_pos = index.find_partial(q.hash);
if (partial_pos != index.end()) {
total_hits++;
if (index.is_partial_filtered(partial_pos)) {
Expand Down Expand Up @@ -267,7 +271,7 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
< std::tie(rhs.count, rhs.query_start, rhs.query_end);
}
};
std::vector<PartialSeed> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
std::vector<PartialHit> partial_queried; // TODO: is a small set more efficient than linear search in a small vector?
partial_queried.reserve(10);
std::array<robin_hood::unordered_map<unsigned int, std::vector<Match>>, 2> matches_map;
std::vector<RescueHit> hits_fw;
Expand All @@ -278,9 +282,9 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
hits_rc.reserve(5000);

for (auto &qr : query_randstrobes) {
size_t position = index.find(qr.hash);
size_t position = index.find_full(qr.hash);
if (position != index.end()) {
unsigned int count = index.get_count(position);
unsigned int count = index.get_count_full(position);
RescueHit rh{position, count, qr.start, qr.end, false};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand All @@ -289,14 +293,14 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
}
}
else if (use_mcs) {
PartialSeed ph = {qr.hash >> index.get_aux_len(), qr.partial_start, qr.is_reverse};
PartialHit ph = {qr.hash >> index.get_aux_len(), qr.partial_start, qr.is_reverse};
if (std::find(partial_queried.begin(), partial_queried.end(), ph) != partial_queried.end()) {
// already queried
continue;
}
size_t partial_pos = index.partial_find(qr.hash);
size_t partial_pos = index.find_partial(qr.hash);
if (partial_pos != index.end()) {
unsigned int partial_count = index.get_partial_count(partial_pos);
unsigned int partial_count = index.get_count_partial(partial_pos);
RescueHit rh{partial_pos, partial_count, qr.partial_start, qr.partial_end, true};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand Down
2 changes: 1 addition & 1 deletion src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ NB_MODULE(strobealign_extension, m_) {
.def(nb::init<References&, IndexParameters&>())
.def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector<RefRandstrobe> {
std::vector<RefRandstrobe> v;
auto position = index.find(key);
auto position = index.find_full(key);
while (position != index.end() && index.get_hash(position) == key) {
v.push_back(index.get_randstrobe(position));
position++;
Expand Down
11 changes: 11 additions & 0 deletions src/randstrobes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,17 @@ static inline syncmer_hash_t syncmer_smer_hash(uint64_t packed) {
return xxh64(packed);
}

/*
* This function combines two individual syncmer hashes into a single hash
* for the randstrobe.
*
* The syncmer with the smaller hash is designated as the "main", the other is
* the "auxiliary".
* The combined hash is obtained by setting the top aux_len bits to the bits of
* the main hash and the bottom 64 - aux_len bits to the bits of the auxiliary
* hash. Since entries in the index are sorted by randstrobe hash, this allows
* us to search for the main syncmer only by masking out the lower aux_len bits.
*/
static inline randstrobe_hash_t randstrobe_hash(syncmer_hash_t hash1, syncmer_hash_t hash2, size_t aux_len) {
// Make the function symmetric
if (hash1 > hash2) {
Expand Down
3 changes: 3 additions & 0 deletions src/randstrobes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct QueryRandstrobe {
randstrobe_hash_t hash;
unsigned int start;
unsigned int end;
/* Start and end of the main syncmer (relevant if the randstrobe couldn’t
* be found in the index and we fall back to a partial hit)
*/
unsigned int partial_start;
unsigned int partial_end;
bool is_reverse;
Expand Down