Skip to content

Commit

Permalink
Fix StrobemerIndex.find() in Python bindings
Browse files Browse the repository at this point in the history
It previously returned an empty list (regression introduced in
59b4d04)
  • Loading branch information
marcelm committed Feb 29, 2024
1 parent d420610 commit 4141d21
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ requires = [
"ninja; platform_system!='Windows'",
"nanobind>=0.2.0",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
11 changes: 5 additions & 6 deletions src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ NB_MODULE(strobealign_extension, m_) {
.def_ro("syncmer", &IndexParameters::syncmer)
.def_ro("randstrobe", &IndexParameters::randstrobe)
;
nb::class_<RefRandstrobe>(m, "RefRandstrobeWithHash", "Randstrobe on a reference")
nb::class_<RefRandstrobe>(m, "RefRandstrobe", "Randstrobe on a reference")
.def_ro("position", &RefRandstrobe::position)
.def_ro("hash", &RefRandstrobe::hash)
.def_prop_ro("reference_index", &RefRandstrobe::reference_index)
.def_prop_ro("strobe2_offset", &RefRandstrobe::strobe2_offset)
;
Expand All @@ -119,11 +120,9 @@ NB_MODULE(strobealign_extension, m_) {
.def("find", [](const StrobemerIndex& index, uint64_t key) -> std::vector<RefRandstrobe> {
std::vector<RefRandstrobe> v;
auto position = index.find(key);
if (position != index.end()) {
/*while (index.randstrobes[position].hash == key) {
v.push_back(index.randstrobes[position]);
position++;
}*/
while (position != index.end() && index.get_hash(position) == key) {
v.push_back(index.get_randstrobe(position));
position++;
}
return v;
})
Expand Down
21 changes: 21 additions & 0 deletions tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,24 @@ def test_indexing_and_nams_finding():
ref_aligned = ref[nam.ref_start:nam.ref_end]
query_aligned = query[nam.query_start:nam.query_end]
score = nam.score


def test_index_find():
refs = strobealign.References.from_fasta("tests/phix.fasta")
index_parameters = strobealign.IndexParameters.from_read_length(100)
index = strobealign.StrobemerIndex(refs, index_parameters)
index.populate()

query = "TGCGTTTATGGTACGCTGGACTTTGTGGGATACCCTCGCTTTCCTGCTCCTGTTGAGTTTATTGCTGCCG"
query_randstrobes = strobealign.randstrobes_query(query, index_parameters)
assert query_randstrobes
# First randstrobe must be found
assert index.find(query_randstrobes[0].hash)

n = 0
for qr in query_randstrobes:
for rs in index.find(qr.hash):
n += 1
assert rs.hash == qr.hash
# Ensure the for loop did test something
assert n > 1

0 comments on commit 4141d21

Please sign in to comment.