Skip to content

Commit

Permalink
Remove one of the StrobemerIndex::find*() functions
Browse files Browse the repository at this point in the history
Instead, have a common function that has a parameter which tells it how many
bit of the hash to ignore (0 for finding a full hash and aux_len for a
partial one).
  • Loading branch information
marcelm committed Oct 7, 2024
1 parent 2491b3d commit be5135f
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ void StrobemerIndex::print_diagnostics(const std::string& logfile_name, int k) c

for (size_t it = 0; it < randstrobes.size(); it++) {
seed_length = strobe2_offset(it) + k;
auto count = get_count(find(get_hash(it)));
auto count = get_count_full(find_full(get_hash(it)));

if (seed_length < max_size){
log_count[seed_length] ++;
Expand Down
84 changes: 28 additions & 56 deletions src/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,35 +56,26 @@ struct StrobemerIndex {
void populate(float f, unsigned n_threads);
void print_diagnostics(const std::string& logfile_name, int k) const;
int pick_bits(size_t size) const;
size_t find(randstrobe_hash_t key) const {
constexpr int MAX_LINEAR_SEARCH = 4;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_start = randstrobe_start_indices[top_N];
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
if (position_start == position_end) {
return end();
}

if (position_end - position_start < MAX_LINEAR_SEARCH) {
for ( ; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key) return position_start;
if (randstrobes[position_start].hash > key) return end();
}
return end();
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };
// Find first entry that matches the given key
size_t find_full(randstrobe_hash_t key) const {
return find(key, 0);
}

auto pos = std::lower_bound(randstrobes.begin() + position_start,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
if (pos->hash == key) return pos - randstrobes.begin();
return end();
/*
* Find the first entry that matches the main hash (ignoring the aux_len
* least significant bits)
*/
size_t find_partial(randstrobe_hash_t key) const {
return find(key, parameters.randstrobe.aux_len);
}

//Returns the first entry that matches the main hash
size_t partial_find(randstrobe_hash_t key) const {
const unsigned int aux_len = parameters.randstrobe.aux_len;
/*
* Find first entry whose hash matches the given key, but ignore the
* b least significant bits
*/
size_t find(randstrobe_hash_t key, uint8_t b) const {
const unsigned int aux_len = b;
randstrobe_hash_t key_prefix = key >> aux_len;

constexpr int MAX_LINEAR_SEARCH = 4;
Expand Down Expand Up @@ -171,46 +162,27 @@ struct StrobemerIndex {
return randstrobes.size();
}

unsigned int get_count(bucket_index_t position) const {
unsigned int get_count_full(bucket_index_t position) const {
return get_count(position, 0);
}

unsigned int get_count_partial(bucket_index_t position) const {
return get_count(position, parameters.randstrobe.aux_len);
}

unsigned int get_count(bucket_index_t position, uint8_t b) const {
// For 95% of cases, the result will be small and a brute force search
// is the best option. Once, we go over MAX_LINEAR_SEARCH, though, we
// use a binary search to get the next position
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// In the human genome, if we assume that the frequency
// a hash will be queried is proportional to the frequency it appears in the table,
// with MAX_LINEAR_SEARCH=8, the actual value will be 96%.

// Since the result depends on position, this function must be used on the smallest position which points to the
// seed with the given hash to yield the number of seeds with this hash.

constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const auto key = randstrobes[position].hash;
const unsigned int top_N = key >> (64 - bits);
bucket_index_t position_end = randstrobe_start_indices[top_N + 1];
uint64_t count = 1;

if (position_end - position < MAX_LINEAR_SEARCH) {
for (bucket_index_t position_start = position + 1; position_start < position_end; ++position_start) {
if (randstrobes[position_start].hash == key){
count += 1;
}
else{
break;
}
}
return count;
}
auto cmp = [](const RefRandstrobe lhs, const RefRandstrobe rhs) {return lhs.hash < rhs.hash; };

auto pos = std::upper_bound(randstrobes.begin() + position,
randstrobes.begin() + position_end,
RefRandstrobe{key, 0, 0},
cmp);
return (pos - randstrobes.begin() - 1) - position + 1;
}

unsigned int get_partial_count(bucket_index_t position) const {
constexpr unsigned int MAX_LINEAR_SEARCH = 8;
const unsigned int aux_len = parameters.randstrobe.aux_len;
const unsigned int aux_len = b;

const auto key = randstrobes[position].hash;
randstrobe_hash_t key_prefix = key >> aux_len;
Expand Down
12 changes: 6 additions & 6 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
int nr_good_hits = 0;
int total_hits = 0;
for (const auto &q : query_randstrobes) {
size_t position = index.find(q.hash);
size_t position = index.find_full(q.hash);
if (position != index.end()){
total_hits++;
if (index.is_filtered(position)) {
Expand All @@ -225,7 +225,7 @@ std::tuple<float, int, std::vector<Nam>> find_nams(
// already queried
continue;
}
size_t partial_pos = index.partial_find(q.hash);
size_t partial_pos = index.find_partial(q.hash);
if (partial_pos != index.end()) {
total_hits++;
if (index.is_partial_filtered(partial_pos)) {
Expand Down Expand Up @@ -278,9 +278,9 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
hits_rc.reserve(5000);

for (auto &qr : query_randstrobes) {
size_t position = index.find(qr.hash);
size_t position = index.find_full(qr.hash);
if (position != index.end()) {
unsigned int count = index.get_count(position);
unsigned int count = index.get_count_full(position);
RescueHit rh{position, count, qr.start, qr.end, false};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand All @@ -294,9 +294,9 @@ std::pair<int, std::vector<Nam>> find_nams_rescue(
// already queried
continue;
}
size_t partial_pos = index.partial_find(qr.hash);
size_t partial_pos = index.find_partial(qr.hash);
if (partial_pos != index.end()) {
unsigned int partial_count = index.get_partial_count(partial_pos);
unsigned int partial_count = index.get_count_partial(partial_pos);
RescueHit rh{partial_pos, partial_count, qr.partial_start, qr.partial_end, true};
if (qr.is_reverse){
hits_rc.push_back(rh);
Expand Down
2 changes: 1 addition & 1 deletion src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ NB_MODULE(strobealign_extension, m_) {
.def(nb::init<References&, IndexParameters&>())
.def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector<RefRandstrobe> {
std::vector<RefRandstrobe> v;
auto position = index.find(key);
auto position = index.find_full(key);
while (position != index.end() && index.get_hash(position) == key) {
v.push_back(index.get_randstrobe(position));
position++;
Expand Down
3 changes: 3 additions & 0 deletions src/randstrobes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct QueryRandstrobe {
randstrobe_hash_t hash;
unsigned int start;
unsigned int end;
/* Start and end of the main syncmer (relevant if the randstrobe couldn’t
* be found in the index and we fall back to a partial hit)
*/
unsigned int partial_start;
unsigned int partial_end;
bool is_reverse;
Expand Down

0 comments on commit be5135f

Please sign in to comment.