Skip to content

Commit

Permalink
Introduce Options classes to pass named optional arguments. (#14)
Browse files Browse the repository at this point in the history
This streamlines many user-visible constructors by aggregating all non-required
parameters into a separate class for more self-documenting argument passing. As a
result, we can now pass along the DuplicateAction from the handler constructors
to the inner MismatchTries. We also introduce a separate SearchStrand enum to
provide a consistent method to specify the strand to search in all handlers.

We remove all default arguments in constructors (same for overloads). Now that
we just have a single class to contain all optional parameters, the caller can
just default-construct an instance and pass it in, which is more explicit. If
we need to customize the options, we use an IIFE to create Options instances in
the constructor - we now do this extensively in the tests.

Note that the MismatchTrie and ScanTemplate functions do not get Options, as
they are low-level (and simple) enough to have relatively few arguments.
  • Loading branch information
LTLA authored Aug 13, 2023
1 parent 4cbe987 commit fd20c70
Show file tree
Hide file tree
Showing 23 changed files with 1,303 additions and 511 deletions.
85 changes: 65 additions & 20 deletions include/kaori/BarcodeSearch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ void fill_library(
const std::vector<const char*>& options,
std::unordered_map<std::string, int>& exact,
Trie& trie,
bool reverse,
DuplicateAction duplicates
bool reverse
) {
size_t len = trie.get_length();

Expand All @@ -45,7 +44,7 @@ void fill_library(

// Note that this must be called, even if the sequence is duplicated;
// otherwise the trie's internal counter will not be properly incremented.
auto status = trie.add(current.c_str(), duplicates);
auto status = trie.add(current.c_str());

if (!status.has_ambiguous) {
if (!status.is_duplicate || status.duplicate_replaced) {
Expand Down Expand Up @@ -104,23 +103,43 @@ void matcher_in_the_rye(const std::string& x, const Cache& cache, const Trie& tr
* Instances of this class use caching to avoid redundant work when a mismatching sequence has been previously encountered.
*/
class SimpleBarcodeSearch {
public:
/**
* @brief Optional parameters for `SimpleBarcodeSearch`.
*
* Every field has an in-class default, so a default-constructed instance can be passed
* directly to the `SimpleBarcodeSearch` constructor when no customization is needed.
*/
struct Options {
/**
* Maximum number of mismatches for any search performed by `SimpleBarcodeSearch::search`.
* Defaults to zero.
*/
int max_mismatches = 0;

/**
* Whether to reverse-complement the barcode sequences before indexing them.
* Defaults to `false`.
*/
bool reverse = false;

/**
* How duplicated barcode sequences should be handled.
* Defaults to `DuplicateAction::ERROR`.
*/
DuplicateAction duplicates = DuplicateAction::ERROR;
};

public:
/**
* Default constructor.
* This is only provided for composition purposes; methods of this class should only be called on a properly constructed instance.
*/
SimpleBarcodeSearch() {}

/**
* @param barcode_pool Pool of barcode sequences.
* @param max_mismatches Maximum number of mismatches for any search performed by this class.
* @param reverse Whether to reverse-complement the barcode sequences.
* @param duplicates How duplicated `sequences` in `barcode_pool` should be handled.
* @param options Optional parameters for the search.
*/
SimpleBarcodeSearch(const BarcodePool& barcode_pool, int max_mismatches = 0, bool reverse = false, DuplicateAction duplicates = DuplicateAction::ERROR) :
trie(barcode_pool.length),
max_mm(max_mismatches)
SimpleBarcodeSearch(const BarcodePool& barcode_pool, const Options& options) :
trie(barcode_pool.length, options.duplicates),
max_mm(options.max_mismatches)
{
fill_library(barcode_pool.pool, exact, trie, reverse, duplicates);
fill_library(barcode_pool.pool, exact, trie, options.reverse);
return;
}

Expand Down Expand Up @@ -272,6 +291,37 @@ struct HasMore<total, total> {
*/
template<size_t num_segments>
class SegmentedBarcodeSearch {
public:
/**
* @brief Optional parameters for a `SegmentedBarcodeSearch`.
*
* A default-constructed instance allows zero mismatches in every segment,
* does not reverse-complement the barcodes, and uses `DuplicateAction::ERROR` for duplicates.
*/
struct Options {
/**
* @param max_mismatch_per_segment Maximum number of mismatches per segment.
* This is used to fill `max_mismatches`, i.e., every segment receives the same value.
* Individual entries of `max_mismatches` can be modified after construction.
*
* Note that this constructor is not `explicit`, so an `int` is implicitly convertible to `Options`.
*/
Options(int max_mismatch_per_segment = 0) {
max_mismatches.fill(max_mismatch_per_segment);
}

/**
* Maximum number of mismatches in each segment for `SegmentedBarcodeSearch::search()`.
* All values should be non-negative.
* Defaults to an all-zero array in the `Options()` constructor.
* This member has no in-class initializer; it is populated by the constructor.
*/
std::array<int, num_segments> max_mismatches;

/**
* Whether to reverse-complement the barcode sequences before indexing them.
* Defaults to `false`.
*/
bool reverse = false;

/**
* How duplicated barcode sequences should be handled.
* Defaults to `DuplicateAction::ERROR`.
*/
DuplicateAction duplicates = DuplicateAction::ERROR;
};

public:
/**
* Default constructor.
Expand All @@ -282,25 +332,20 @@ class SegmentedBarcodeSearch {
* @param barcode_pool Pool of barcode sequences.
* @param segments Size of each segment.
* All values should be positive and their sum should be equal to the barcode length.
* @param max_mismatches Maximum number of mismatches in each segment.
* All values should be non-negative.
* @param reverse Whether to reverse-complement the barcode sequences.
* @param duplicates How duplicated `sequences` in `barcode_pool` should be handled.
* @param options Optional parameters.
*/
SegmentedBarcodeSearch(
const BarcodePool& barcode_pool,
std::array<int, num_segments> segments,
std::array<int, num_segments> max_mismatches,
bool reverse = false,
DuplicateAction duplicates = DuplicateAction::ERROR
const Options& options
) :
trie(segments),
max_mm(max_mismatches)
trie(segments, options.duplicates),
max_mm(options.max_mismatches)
{
if (barcode_pool.length != trie.get_length()) {
throw std::runtime_error("variable sequences should have the same length as the sum of segment lengths");
}
fill_library(barcode_pool.pool, exact, trie, reverse, duplicates);
fill_library(barcode_pool.pool, exact, trie, options.reverse);
return;
}

Expand Down
73 changes: 31 additions & 42 deletions include/kaori/MismatchTrie.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <stdexcept>
#include <numeric>
#include "utils.hpp"
#include "BarcodePool.hpp"

/**
* @file MismatchTrie.hpp
Expand All @@ -31,19 +30,21 @@ class MismatchTrie {

public:
/**
* @param barcode_length Length of the barcodes in the pool.
* Default constructor.
* This is only provided to enable composition; the resulting object should not be used until it is copy-assigned from a properly constructed instance.
*/
MismatchTrie(size_t barcode_length = 0) : length(barcode_length), pointers(4, status_not_present), counter(0) {}
MismatchTrie() {}

/**
* @param barcode_pool Pool of known barcode sequences.
* @param duplicates How duplicated sequences in `barcode_pool` should be handled.
* @param barcode_length Length of the barcodes in the pool.
* @param duplicates How duplicate sequences across `add()` calls should be handled.
*/
MismatchTrie(const BarcodePool& barcode_pool, DuplicateAction duplicates = DuplicateAction::ERROR) : MismatchTrie(barcode_pool.length) {
for (auto s : barcode_pool.pool) {
add(s, duplicates);
}
}
// Creates an empty trie for barcodes of length `barcode_length`.
// No sequences are inserted here; they are added afterwards via add().
MismatchTrie(size_t barcode_length, DuplicateAction duplicates) :
length(barcode_length),
// Start with four slots, all marked as not present
// (presumably one per nucleotide for the root node - TODO confirm against base_shift()).
pointers(4, status_not_present),
// Store the duplicate-handling policy once so that add() does not need it as an argument.
// Inside the parentheses, `duplicates` refers to the constructor parameter, not the member.
duplicates(duplicates),
// No sequences added yet; add() increments this after each insertion.
counter(0)
{}

public:
/**
Expand Down Expand Up @@ -85,7 +86,7 @@ class MismatchTrie {
}
}

void end(int shift, int position, DuplicateAction duplicates, AddStatus& status) {
void end(int shift, int position, AddStatus& status) {
auto& current = pointers[position + shift];
if (current >= 0) {
status.is_duplicate = true;
Expand All @@ -112,7 +113,7 @@ class MismatchTrie {
}
}

void recursive_add(size_t i, int position, const char* barcode_seq, DuplicateAction duplicates, AddStatus& status) {
void recursive_add(size_t i, int position, const char* barcode_seq, AddStatus& status) {
// Processing a stretch of non-ambiguous codes, where possible.
// This reduces the recursion depth among the (hopefully fewer) ambiguous codes.
while (1) {
Expand All @@ -122,7 +123,7 @@ class MismatchTrie {
}

if ((++i) == length) {
end(shift, position, duplicates, status);
end(shift, position, status);
return;
} else {
next(shift, position);
Expand All @@ -135,11 +136,11 @@ class MismatchTrie {
auto process = [&](char base) -> void {
auto shift = base_shift(base);
if (i + 1 == length) {
end(shift, position, duplicates, status);
end(shift, position, status);
} else {
auto curpos = position;
next(shift, curpos);
recursive_add(i + 1, curpos, barcode_seq, duplicates, status);
recursive_add(i + 1, curpos, barcode_seq, status);
}
};

Expand Down Expand Up @@ -175,15 +176,14 @@ class MismatchTrie {
/**
* @param[in] barcode_seq Pointer to a character array containing a barcode sequence.
* The array should have length equal to `get_length()` and should only contain IUPAC nucleotides or their lower-case equivalents (excepting U or gap characters).
* @param duplicates How duplicate sequences across `add()` calls should be handled.
*
* @return The barcode sequence is added to the trie.
* The index of the newly added sequence is defined as the number of sequences that were previously added.
* The status of the addition is returned.
*/
AddStatus add(const char* barcode_seq, DuplicateAction duplicates) {
AddStatus add(const char* barcode_seq) {
AddStatus status;
recursive_add(0, 0, barcode_seq, duplicates, status);
recursive_add(0, 0, barcode_seq, status);
++counter;
return status;
}
Expand Down Expand Up @@ -239,6 +239,7 @@ class MismatchTrie {
*/

private:
DuplicateAction duplicates;
int counter;

public:
Expand Down Expand Up @@ -312,16 +313,16 @@ class MismatchTrie {
class AnyMismatches : public MismatchTrie {
public:
/**
* @param barcode_length Length of the barcode sequences.
* Default constructor.
* This is only provided to enable composition; the resulting object should not be used until it is copy-assigned from a properly constructed instance.
*/
AnyMismatches(size_t barcode_length = 0) : MismatchTrie(barcode_length) {}
AnyMismatches() {}

/**
* @param barcode_pool Pool of known barcode sequences.
* @param duplicates How duplicate sequences in `barcode_pool` should be handled.
* @param barcode_length Length of the barcode sequences.
* @param duplicates How duplicate sequences across `add()` calls should be handled.
*/
AnyMismatches(const BarcodePool& barcode_pool, DuplicateAction duplicates = DuplicateAction::ERROR) :
MismatchTrie(barcode_pool, duplicates) {}
AnyMismatches(size_t barcode_length, DuplicateAction duplicates) : MismatchTrie(barcode_length, duplicates) {}

public:
/**
Expand Down Expand Up @@ -424,36 +425,24 @@ class SegmentedMismatches : public MismatchTrie {
public:
/**
* Default constructor.
* This is only provided to enable composition; the resulting object should not be used until it is copy-assigned from a properly constructed instance.
*/
SegmentedMismatches() {}

/**
* @param segments Length of each segment of the sequence.
* Each entry should be positive and the sum should be equal to the total length of the barcode sequence.
* @param duplicates How duplicate sequences across `add()` calls should be handled.
*/
SegmentedMismatches(std::array<int, num_segments> segments) : MismatchTrie(std::accumulate(segments.begin(), segments.end(), 0)), boundaries(segments) {
/**
 * Builds an empty segmented trie.
 * The total barcode length handed to the base `MismatchTrie` is the sum of all segment lengths.
 *
 * @param segments Length of each segment of the sequence.
 * Each entry should be positive and the sum should be equal to the total length of the barcode sequence.
 * @param duplicates How duplicate sequences across `add()` calls should be handled.
 */
SegmentedMismatches(std::array<int, num_segments> segments, DuplicateAction duplicates) :
    MismatchTrie(std::accumulate(segments.begin(), segments.end(), 0), duplicates),
    boundaries(segments)
{
    // `boundaries` starts out as a copy of the segment lengths; replace each entry
    // with the running total so that entry i holds the sum of segments 0 through i.
    std::partial_sum(boundaries.begin(), boundaries.end(), boundaries.begin());
}

/**
* @param barcode_pool Possible set of known sequences for the variable region.
* @param segments Length of each segment of the sequence.
* Each entry should be positive and the sum should be equal to the total length of the barcode sequence.
* @param duplicates How duplicated sequences in `barcode_pool` should be handled.
*/
SegmentedMismatches(const BarcodePool& barcode_pool, std::array<int, num_segments> segments, DuplicateAction duplicates = DuplicateAction::ERROR) :
SegmentedMismatches(segments)
{
if (length != barcode_pool.length) {
throw std::runtime_error("length of barcode sequences should equal the sum of segment lengths");
}
for (auto s : barcode_pool.pool) {
add(s, duplicates);
}
}

public:
/**
* @brief Result of the segmented search.
Expand Down
7 changes: 3 additions & 4 deletions include/kaori/ScanTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ class ScanTemplate {
* Variable regions should be marked with `-`.
* @param template_length Length of the array pointed to by `template_seq`.
* This should be less than or equal to `max_size`.
* @param search_forward Should the search be performed on the forward strand of the read sequence?
* @param search_reverse Should the search be performed on the reverse strand of the read sequence?
* @param strand Strand(s) of the read sequence to search.
*/
ScanTemplate(const char* template_seq, size_t template_length, bool search_forward, bool search_reverse) :
length(template_length), forward(search_forward), reverse(search_reverse)
ScanTemplate(const char* template_seq, size_t template_length, SearchStrand strand) :
length(template_length), forward(search_forward(strand)), reverse(search_reverse(strand))
{
if (length > max_size) {
throw std::runtime_error("maximum template size should be " + std::to_string(max_size) + " bp");
Expand Down
Loading

0 comments on commit fd20c70

Please sign in to comment.