Skip to content

Commit

Permalink
Merge pull request #326 from ksahlin/refactor
Browse files Browse the repository at this point in the history
Refactoring
  • Loading branch information
marcelm authored Aug 21, 2023
2 parents 3223dc5 + c0e362b commit 40748fe
Show file tree
Hide file tree
Showing 15 changed files with 651 additions and 648 deletions.
699 changes: 223 additions & 476 deletions src/aln.cpp

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions src/aln.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,23 @@ struct AlignmentStatistics {
}
};

struct mapping_params {
struct MappingParameters {
int r { 150 };
int max_secondary { 0 };
float dropoff_threshold { 0.5 };
int R { 2 };
int maxTries { 20 };
int rescue_level { 2 };
int max_tries { 20 };
int rescue_cutoff;
bool is_sam_out { true };
CigarOps cigar_ops{CigarOps::M};
bool output_unmapped { true };
bool details{false};

void verify() const {
if (max_tries < 1) {
throw BadParameter("max_tries must be greater than zero");
}
}
};

class i_dist_est {
Expand All @@ -88,7 +94,7 @@ void align_PE_read(
AlignmentStatistics& statistics,
i_dist_est& isize_est,
const Aligner& aligner,
const mapping_params& map_param,
const MappingParameters& map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index
Expand All @@ -100,7 +106,7 @@ void align_SE_read(
std::string& outstring,
AlignmentStatistics& statistics,
const Aligner& aligner,
const mapping_params& map_param,
const MappingParameters& map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index
Expand Down
4 changes: 2 additions & 2 deletions src/cmdline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ CommandLineOptions parse_command_line_arguments(int argc, char **argv) {
// Search parameters
if (f) { opt.f = args::get(f); }
if (S) { opt.dropoff_threshold = args::get(S); }
if (M) { opt.maxTries = args::get(M); }
if (R) { opt.R = args::get(R); }
if (M) { opt.max_tries = args::get(M); }
if (R) { opt.rescue_level = args::get(R); }

// Reference and read files
opt.ref_filename = args::get(ref_filename);
Expand Down
4 changes: 2 additions & 2 deletions src/cmdline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ struct CommandLineOptions {
// Search parameters
float f { 0.0002 };
float dropoff_threshold { 0.5 };
int maxTries { 20 };
int R { 2 };
int max_tries { 20 };
int rescue_level { 2 };

// Reference and read files
std::string ref_filename; // This is either a fasta file or an index file - if fasta, indexing will be run
Expand Down
15 changes: 8 additions & 7 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ void warn_if_no_optimizations() {
}
}

void log_parameters(const IndexParameters& index_parameters, const mapping_params& map_param, const alignment_params& aln_params) {
void log_parameters(const IndexParameters& index_parameters, const MappingParameters& map_param, const alignment_params& aln_params) {
logger.debug() << "Using" << std::endl
<< "k: " << index_parameters.syncmer.k << std::endl
<< "s: " << index_parameters.syncmer.s << std::endl
<< "w_min: " << index_parameters.randstrobe.w_min << std::endl
<< "w_max: " << index_parameters.randstrobe.w_max << std::endl
<< "Read length (r): " << map_param.r << std::endl
<< "Maximum seed length: " << index_parameters.randstrobe.max_dist + index_parameters.syncmer.k << std::endl
<< "R: " << map_param.R << std::endl
<< "R: " << map_param.rescue_level << std::endl
<< "Expected [w_min, w_max] in #syncmers: [" << index_parameters.randstrobe.w_min << ", " << index_parameters.randstrobe.w_max << "]" << std::endl
<< "Expected [w_min, w_max] in #nucleotides: [" << (index_parameters.syncmer.k - index_parameters.syncmer.s + 1) * index_parameters.randstrobe.w_min << ", " << (index_parameters.syncmer.k - index_parameters.syncmer.s + 1) * index_parameters.randstrobe.w_max << "]" << std::endl
<< "A: " << aln_params.match << std::endl
Expand Down Expand Up @@ -168,16 +168,17 @@ int run_strobealign(int argc, char **argv) {
aln_params.gap_extend = opt.E;
aln_params.end_bonus = opt.end_bonus;

mapping_params map_param;
MappingParameters map_param;
map_param.r = opt.r;
map_param.max_secondary = opt.max_secondary;
map_param.dropoff_threshold = opt.dropoff_threshold;
map_param.R = opt.R;
map_param.maxTries = opt.maxTries;
map_param.rescue_level = opt.rescue_level;
map_param.max_tries = opt.max_tries;
map_param.is_sam_out = opt.is_sam_out;
map_param.cigar_ops = opt.cigar_eqx ? CigarOps::EQX : CigarOps::M;
map_param.output_unmapped = opt.output_unmapped;
map_param.details = opt.details;
map_param.verify();

log_parameters(index_parameters, map_param, aln_params);
logger.debug() << "Threads: " << opt.n_threads << std::endl;
Expand Down Expand Up @@ -257,7 +258,7 @@ int run_strobealign(int argc, char **argv) {
// Map/align reads

Timer map_align_timer;
map_param.rescue_cutoff = map_param.R < 100 ? map_param.R * index.filter_cutoff : 1000;
map_param.rescue_cutoff = map_param.rescue_level < 100 ? map_param.rescue_level * index.filter_cutoff : 1000;
logger.debug() << "Using rescue cutoff: " << map_param.rescue_cutoff << std::endl;

std::streambuf* buf;
Expand Down Expand Up @@ -331,7 +332,7 @@ int main(int argc, char **argv) {
try {
return run_strobealign(argc, argv);
} catch (BadParameter& e) {
logger.error() << "A mapping or seeding parameter is invalid: " << e.what() << std::endl;
logger.error() << "A parameter is invalid: " << e.what() << std::endl;
} catch (const std::runtime_error& e) {
logger.error() << "strobealign: " << e.what() << std::endl;
}
Expand Down
74 changes: 37 additions & 37 deletions src/nam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,28 @@
namespace {

struct Hit {
int query_s;
int query_e;
int ref_s;
int ref_e;
int query_start;
int query_end;
int ref_start;
int ref_end;
bool is_rc = false;
};

void add_to_hits_per_ref(
robin_hood::unordered_map<unsigned int, std::vector<Hit>>& hits_per_ref,
int query_s,
int query_e,
int query_start,
int query_end,
bool is_rc,
const StrobemerIndex& index,
size_t position,
int min_diff
) {
for (const auto hash = index.get_hash(position); index.get_hash(position) == hash; ++position) {
int ref_s = index.get_strobe1_position(position);
int ref_e = ref_s + index.strobe2_offset(position) + index.k();
int diff = std::abs((query_e - query_s) - (ref_e - ref_s));
int ref_start = index.get_strobe1_position(position);
int ref_end = ref_start + index.strobe2_offset(position) + index.k();
int diff = std::abs((query_end - query_start) - (ref_end - ref_start));
if (diff <= min_diff) {
hits_per_ref[index.reference_index(position)].push_back(Hit{query_s, query_e, ref_s, ref_e, is_rc});
hits_per_ref[index.reference_index(position)].push_back(Hit{query_start, query_end, ref_start, ref_end, is_rc});
min_diff = diff;
}
}
Expand All @@ -41,7 +41,7 @@ std::vector<Nam> merge_hits_into_nams(
if (sort) {
std::sort(hits.begin(), hits.end(), [](const Hit& a, const Hit& b) -> bool {
// first sort on query starts, then on reference starts
return (a.query_s < b.query_s) || ( (a.query_s == b.query_s) && (a.ref_s < b.ref_s) );
return (a.query_start < b.query_start) || ( (a.query_start == b.query_start) && (a.ref_start < b.ref_start) );
}
);
}
Expand All @@ -53,24 +53,24 @@ std::vector<Nam> merge_hits_into_nams(
for (auto & o : open_nams) {

// Extend NAM
if (( o.is_rc == h.is_rc) && (o.query_prev_hit_startpos < h.query_s) && (h.query_s <= o.query_e ) && (o.ref_prev_hit_startpos < h.ref_s) && (h.ref_s <= o.ref_e) ){
if ( (h.query_e > o.query_e) && (h.ref_e > o.ref_e) ) {
o.query_e = h.query_e;
o.ref_e = h.ref_e;
if (( o.is_rc == h.is_rc) && (o.query_prev_hit_startpos < h.query_start) && (h.query_start <= o.query_end ) && (o.ref_prev_hit_startpos < h.ref_start) && (h.ref_start <= o.ref_end) ){
if ( (h.query_end > o.query_end) && (h.ref_end > o.ref_end) ) {
o.query_end = h.query_end;
o.ref_end = h.ref_end;
// o.previous_query_start = h.query_s;
// o.previous_ref_start = h.ref_s; // keeping track so that we don't . Can be caused by interleaved repeats.
o.query_prev_hit_startpos = h.query_s; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_s; // log the last strobemer hit in case of outputting paf
o.query_prev_hit_startpos = h.query_start; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_start; // log the last strobemer hit in case of outputting paf
o.n_hits ++;
// o.score += (float)1/ (float)h.count;
is_added = true;
break;
}
else if ((h.query_e <= o.query_e) && (h.ref_e <= o.ref_e)) {
else if ((h.query_end <= o.query_end) && (h.ref_end <= o.ref_end)) {
// o.previous_query_start = h.query_s;
// o.previous_ref_start = h.ref_s; // keeping track so that we don't . Can be caused by interleaved repeats.
o.query_prev_hit_startpos = h.query_s; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_s; // log the last strobemer hit in case of outputting paf
o.query_prev_hit_startpos = h.query_start; // log the last strobemer hit in case of outputting paf
o.ref_prev_hit_startpos = h.ref_start; // log the last strobemer hit in case of outputting paf
o.n_hits ++;
// o.score += (float)1/ (float)h.count;
is_added = true;
Expand All @@ -84,27 +84,27 @@ std::vector<Nam> merge_hits_into_nams(
Nam n;
n.nam_id = nam_id_cnt;
nam_id_cnt ++;
n.query_s = h.query_s;
n.query_e = h.query_e;
n.ref_s = h.ref_s;
n.ref_e = h.ref_e;
n.query_start = h.query_start;
n.query_end = h.query_end;
n.ref_start = h.ref_start;
n.ref_end = h.ref_end;
n.ref_id = ref_id;
// n.previous_query_start = h.query_s;
// n.previous_ref_start = h.ref_s;
n.query_prev_hit_startpos = h.query_s;
n.ref_prev_hit_startpos = h.ref_s;
n.query_prev_hit_startpos = h.query_start;
n.ref_prev_hit_startpos = h.ref_start;
n.n_hits = 1;
n.is_rc = h.is_rc;
// n.score += (float)1 / (float)h.count;
open_nams.push_back(n);
}

// Only filter if we have advanced at least k nucleotides
if (h.query_s > prev_q_start + k) {
if (h.query_start > prev_q_start + k) {

// Output all NAMs from open_matches to final_nams that the current hit have passed
for (auto &n : open_nams) {
if (n.query_e < h.query_s) {
if (n.query_end < h.query_start) {
int n_max_span = std::max(n.query_span(), n.ref_span());
int n_min_span = std::min(n.query_span(), n.ref_span());
float n_score;
Expand All @@ -116,10 +116,10 @@ std::vector<Nam> merge_hits_into_nams(
}

// Remove all NAMs from open_matches that the current hit have passed
auto c = h.query_s;
auto predicate = [c](decltype(open_nams)::value_type const &nam) { return nam.query_e < c; };
auto c = h.query_start;
auto predicate = [c](decltype(open_nams)::value_type const &nam) { return nam.query_end < c; };
open_nams.erase(std::remove_if(open_nams.begin(), open_nams.end(), predicate), open_nams.end());
prev_q_start = h.query_s;
prev_q_start = h.query_start;
}
}

Expand Down Expand Up @@ -180,13 +180,13 @@ std::vector<Nam> find_nams_rescue(
struct RescueHit {
unsigned int count;
size_t position;
unsigned int query_s;
unsigned int query_e;
unsigned int query_start;
unsigned int query_end;
bool is_rc;

bool operator< (const RescueHit& rhs) const {
return std::tie(count, query_s, query_e, is_rc)
< std::tie(rhs.count, rhs.query_s, rhs.query_e, rhs.is_rc);
return std::tie(count, query_start, query_end, is_rc)
< std::tie(rhs.count, rhs.query_start, rhs.query_end, rhs.is_rc);
}
};

Expand Down Expand Up @@ -218,7 +218,7 @@ std::vector<Nam> find_nams_rescue(
if ((rh.count > filter_cutoff && cnt >= 5) || rh.count > 1000) {
break;
}
add_to_hits_per_ref(hits_per_ref, rh.query_s, rh.query_e, rh.is_rc, index, rh.position, 1000);
add_to_hits_per_ref(hits_per_ref, rh.query_start, rh.query_end, rh.is_rc, index, rh.position, 1000);
cnt++;
}
}
Expand All @@ -227,6 +227,6 @@ std::vector<Nam> find_nams_rescue(
}

std::ostream& operator<<(std::ostream& os, const Nam& n) {
os << "Nam(ref_id=" << n.ref_id << ", query: " << n.query_s << ".." << n.query_e << ", ref: " << n.ref_s << ".." << n.ref_e << ", score=" << n.score << ")";
os << "Nam(ref_id=" << n.ref_id << ", query: " << n.query_start << ".." << n.query_end << ", ref: " << n.ref_start << ".." << n.ref_end << ", score=" << n.score << ")";
return os;
}
12 changes: 6 additions & 6 deletions src/nam.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
// Non-overlapping approximate match
struct Nam {
int nam_id;
int query_s;
int query_e;
int query_start;
int query_end;
int query_prev_hit_startpos;
int ref_s;
int ref_e;
int ref_start;
int ref_end;
int ref_prev_hit_startpos;
int n_hits = 0;
int ref_id;
Expand All @@ -22,11 +22,11 @@ struct Nam {
bool is_rc = false;

int ref_span() const {
return ref_e - ref_s;
return ref_end - ref_start;
}

int query_span() const {
return query_e - query_s;
return query_end - query_start;
}
};

Expand Down
8 changes: 4 additions & 4 deletions src/paf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
* 12 mapping quality (0-255; 255 for missing)
*/
void output_hits_paf_PE(std::string &paf_output, const Nam &n, const std::string &query_name, const References& references, int k, int read_len) {
if (n.ref_s < 0 ) {
if (n.ref_start < 0 ) {
return;
}
paf_output.append(query_name);
paf_output.append("\t");
paf_output.append(std::to_string(read_len));
paf_output.append("\t");
paf_output.append(std::to_string(n.query_s));
paf_output.append(std::to_string(n.query_start));
paf_output.append("\t");
paf_output.append(std::to_string(n.query_prev_hit_startpos + k));
paf_output.append("\t");
Expand All @@ -32,13 +32,13 @@ void output_hits_paf_PE(std::string &paf_output, const Nam &n, const std::string
paf_output.append("\t");
paf_output.append(std::to_string(references.lengths[n.ref_id]));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_s));
paf_output.append(std::to_string(n.ref_start));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k));
paf_output.append("\t");
paf_output.append(std::to_string(n.n_hits));
paf_output.append("\t");
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k - n.ref_s));
paf_output.append(std::to_string(n.ref_prev_hit_startpos + k - n.ref_start));
paf_output.append("\t255\n");
}

Expand Down
2 changes: 1 addition & 1 deletion src/pc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ void perform_task(
AlignmentStatistics& statistics,
int& done,
const alignment_params &aln_params,
const mapping_params &map_param,
const MappingParameters &map_param,
const IndexParameters& index_parameters,
const References& references,
const StrobemerIndex& index,
Expand Down
2 changes: 1 addition & 1 deletion src/pc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class OutputBuffer {

void perform_task(InputBuffer &input_buffer, OutputBuffer &output_buffer,
AlignmentStatistics& statistics, int& done, const alignment_params &aln_params,
const mapping_params &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id);
const MappingParameters &map_param, const IndexParameters& index_parameters, const References& references, const StrobemerIndex& index, const std::string& read_group_id);

bool same_name(const std::string& n1, const std::string& n2);

Expand Down
8 changes: 4 additions & 4 deletions src/python/strobealign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ NB_MODULE(strobealign_extension, m_) {
nb::bind_vector<QueryRandstrobeVector>(m, "QueryRandstrobeVector");

nb::class_<Nam>(m, "Nam")
.def_ro("query_start", &Nam::query_s)
.def_ro("query_end", &Nam::query_e)
.def_ro("ref_start", &Nam::ref_s)
.def_ro("ref_end", &Nam::ref_e)
.def_ro("query_start", &Nam::query_start)
.def_ro("query_end", &Nam::query_end)
.def_ro("ref_start", &Nam::ref_start)
.def_ro("ref_end", &Nam::ref_end)
.def_ro("score", &Nam::score)
.def_ro("n_hits", &Nam::n_hits)
.def_ro("reference_index", &Nam::ref_id)
Expand Down
Loading

0 comments on commit 40748fe

Please sign in to comment.