diff --git a/src/index.cpp b/src/index.cpp index 85b26076..61a5681a 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -335,7 +335,7 @@ void StrobemerIndex::print_diagnostics(const std::string& logfile_name, int k) c for (size_t it = 0; it < randstrobes.size(); it++) { seed_length = strobe2_offset(it) + k; - auto count = get_count(find(get_hash(it))); + auto count = get_count_full(find_full(get_hash(it))); if (seed_length < max_size){ log_count[seed_length] ++; diff --git a/src/index.hpp b/src/index.hpp index e7ce41e9..6e82a220 100644 --- a/src/index.hpp +++ b/src/index.hpp @@ -56,35 +56,26 @@ struct StrobemerIndex { void populate(float f, unsigned n_threads); void print_diagnostics(const std::string& logfile_name, int k) const; int pick_bits(size_t size) const; - size_t find(randstrobe_hash_t key) const { - constexpr int MAX_LINEAR_SEARCH = 4; - const unsigned int top_N = key >> (64 - bits); - bucket_index_t position_start = randstrobe_start_indices[top_N]; - bucket_index_t position_end = randstrobe_start_indices[top_N + 1]; - if (position_start == position_end) { - return end(); - } - if (position_end - position_start < MAX_LINEAR_SEARCH) { - for ( ; position_start < position_end; ++position_start) { - if (randstrobes[position_start].hash == key) return position_start; - if (randstrobes[position_start].hash > key) return end(); - } - return end(); - } - auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; }; + // Find first entry that matches the given key + size_t find_full(randstrobe_hash_t key) const { + return find(key, 0); + } - auto pos = std::lower_bound(randstrobes.begin() + position_start, - randstrobes.begin() + position_end, - RefRandstrobe{key, 0, 0}, - cmp); - if (pos->hash == key) return pos - randstrobes.begin(); - return end(); + /* + * Find the first entry that matches the main hash (ignoring the aux_len + * least significant bits) + */ + size_t find_partial(randstrobe_hash_t key) const { + return find(key, parameters.randstrobe.aux_len); } - //Returns the first entry that matches the main hash - size_t partial_find(randstrobe_hash_t key) const { - const unsigned int aux_len = parameters.randstrobe.aux_len; + /* + * Find first entry whose hash matches the given key, but ignore the + * b least significant bits + */ + size_t find(randstrobe_hash_t key, uint8_t b) const { + const unsigned int aux_len = b; randstrobe_hash_t key_prefix = key >> aux_len; constexpr int MAX_LINEAR_SEARCH = 4; @@ -171,46 +162,27 @@ struct StrobemerIndex { return randstrobes.size(); } - unsigned int get_count(bucket_index_t position) const { + unsigned int get_count_full(bucket_index_t position) const { + return get_count(position, 0); + } + + unsigned int get_count_partial(bucket_index_t position) const { + return get_count(position, parameters.randstrobe.aux_len); + } + + unsigned int get_count(bucket_index_t position, uint8_t b) const { // For 95% of cases, the result will be small and a brute force search // is the best option. Once, we go over MAX_LINEAR_SEARCH, though, we // use a binary search to get the next position - // In the human genome, if we assume that the frequency - // a hash will be queried is proportional to the frequency it appears in the table, + // In the human genome, if we assume that the frequency + // a hash will be queried is proportional to the frequency it appears in the table, // with MAX_LINEAR_SEARCH=8, the actual value will be 96%. // Since the result depends on position, this function must be used on the smallest position which points to the // seed with the given hash to yield the number of seeds with this hash. constexpr unsigned int MAX_LINEAR_SEARCH = 8; - const auto key = randstrobes[position].hash; - const unsigned int top_N = key >> (64 - bits); - bucket_index_t position_end = randstrobe_start_indices[top_N + 1]; - uint64_t count = 1; - - if (position_end - position < MAX_LINEAR_SEARCH) { - for (bucket_index_t position_start = position + 1; position_start < position_end; ++position_start) { - if (randstrobes[position_start].hash == key){ - count += 1; - } - else{ - break; - } - } - return count; - } - auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; }; - - auto pos = std::upper_bound(randstrobes.begin() + position, - randstrobes.begin() + position_end, - RefRandstrobe{key, 0, 0}, - cmp); - return (pos - randstrobes.begin() - 1) - position + 1; - } - - unsigned int get_partial_count(bucket_index_t position) const { - constexpr unsigned int MAX_LINEAR_SEARCH = 8; - const unsigned int aux_len = parameters.randstrobe.aux_len; + const unsigned int aux_len = b; const auto key = randstrobes[position].hash; randstrobe_hash_t key_prefix = key >> aux_len; diff --git a/src/nam.cpp b/src/nam.cpp index 78c1b8b1..3a21c16d 100644 --- a/src/nam.cpp +++ b/src/nam.cpp @@ -210,7 +210,7 @@ std::tuple> find_nams( int nr_good_hits = 0; int total_hits = 0; for (const auto &q : query_randstrobes) { - size_t position = index.find(q.hash); + size_t position = index.find_full(q.hash); if (position != index.end()){ total_hits++; if (index.is_filtered(position)) { @@ -225,7 +225,7 @@ std::tuple> find_nams( // already queried continue; } - size_t partial_pos = index.partial_find(q.hash); + size_t partial_pos = index.find_partial(q.hash); if (partial_pos != index.end()) { total_hits++; if (index.is_partial_filtered(partial_pos)) { @@ -278,9 +278,9 @@ std::pair> find_nams_rescue( hits_rc.reserve(5000); for (auto &qr : query_randstrobes) { - size_t position = index.find(qr.hash); + size_t position = index.find_full(qr.hash); if (position != index.end()) { - unsigned int count = index.get_count(position); + unsigned int count = index.get_count_full(position); RescueHit rh{position, count, qr.start, qr.end, false}; if (qr.is_reverse){ hits_rc.push_back(rh); @@ -294,9 +294,9 @@ std::pair> find_nams_rescue( // already queried continue; } - size_t partial_pos = index.partial_find(qr.hash); + size_t partial_pos = index.find_partial(qr.hash); if (partial_pos != index.end()) { - unsigned int partial_count = index.get_partial_count(partial_pos); + unsigned int partial_count = index.get_count_partial(partial_pos); RescueHit rh{partial_pos, partial_count, qr.partial_start, qr.partial_end, true}; if (qr.is_reverse){ hits_rc.push_back(rh); diff --git a/src/python/strobealign.cpp b/src/python/strobealign.cpp index 7d9929cb..2e42fcbb 100644 --- a/src/python/strobealign.cpp +++ b/src/python/strobealign.cpp @@ -119,7 +119,7 @@ NB_MODULE(strobealign_extension, m_) { .def(nb::init()) .def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector { std::vector v; - auto position = index.find(key); + auto position = index.find_full(key); while (position != index.end() && index.get_hash(position) == key) { v.push_back(index.get_randstrobe(position)); position++; diff --git a/src/randstrobes.hpp b/src/randstrobes.hpp index 06f5bd14..bc91edf0 100644 --- a/src/randstrobes.hpp +++ b/src/randstrobes.hpp @@ -65,6 +65,9 @@ struct QueryRandstrobe { randstrobe_hash_t hash; unsigned int start; unsigned int end; + /* Start and end of the main syncmer (relevant if the randstrobe couldn’t + * be found in the index and we fall back to a partial hit) + */ unsigned int partial_start; unsigned int partial_end; bool is_reverse;