Skip to content

Commit

Permalink
Add stride implementation for the node class. (#144)
Browse files Browse the repository at this point in the history
* added stride implementation for the node class.
* added additional checks for ill-defined parameters.
* update test.grc.expected due to the changes in node class.
  • Loading branch information
drslebedev authored Aug 5, 2023
1 parent 691415a commit 07a83d8
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 87 deletions.
5 changes: 5 additions & 0 deletions include/annotated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ struct BlockingIO {};
*/
struct PerformDecimationInterpolation {};

/**
* @brief Annotates node, indicating to perform stride
*/
struct PerformStride {};

/**
* @brief Annotates templated node, indicating which port data types are supported.
*/
Expand Down
104 changes: 75 additions & 29 deletions include/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,15 @@ class node : protected std::tuple<Arguments...> {
using A = Annotated<T, description, Args...>;

public:
using base_t = node<Derived, Arguments...>;
using derived_t = Derived;
using node_template_parameters = meta::typelist<Arguments...>;
using Description = typename node_template_parameters::template find_or_default<is_doc, EmptyDoc>;
constexpr static tag_propagation_policy_t tag_policy = tag_propagation_policy_t::TPP_ALL_TO_ALL;
A<uint64_t, "numerator", Doc<"decimation/interpolation settings">> numerator = 1_UZ;
A<uint64_t, "denominator", Doc<"decimation/interpolation settings">> denominator = 1_UZ;
A<uint64_t, "stride", Doc<"stride doc">> stride = 1_UZ;
using base_t = node<Derived, Arguments...>;
using derived_t = Derived;
using node_template_parameters = meta::typelist<Arguments...>;
using Description = typename node_template_parameters::template find_or_default<is_doc, EmptyDoc>;
constexpr static tag_propagation_policy_t tag_policy = tag_propagation_policy_t::TPP_ALL_TO_ALL;
A<std::size_t, "numerator", Doc<"The top number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)">> numerator = 1_UZ;
A<std::size_t, "denominator", Doc<"The bottom number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)">> denominator = 1_UZ;
A<std::size_t, "stride", Doc<"Number of samples between two data processing: overlap (stride < N), skip (stride > N), undefined-default (stride = 0)">> stride = 0_UZ;
std::size_t stride_counter = 0_UZ;
const std::size_t unique_id = _unique_id_counter++;
const std::string unique_name = fmt::format("{}#{}", fair::meta::type_name<Derived>(), unique_id);
A<std::string, "user-defined name", Doc<"N.B. may not be unique -> ::unique_name">> name{ std::string(fair::meta::type_name<Derived>()) };
Expand Down Expand Up @@ -533,6 +534,7 @@ class node : protected std::tuple<Arguments...> {
constexpr bool is_source_node = input_types::size == 0;
constexpr bool is_sink_node = output_types::size == 0;

// TODO: these checks can be moved to setting changed
if constexpr (node_template_parameters::template contains<PerformDecimationInterpolation>) {
static_assert(!is_sink_node, "Decimation/interpolation is not available for sink blocks. Remove 'PerformDecimationInterpolation' from the block definition.");
static_assert(!is_source_node, "Decimation/interpolation is not available for source blocks. Remove 'PerformDecimationInterpolation' from the block definition.");
Expand All @@ -545,6 +547,14 @@ class node : protected std::tuple<Arguments...> {
}
}

if constexpr (node_template_parameters::template contains<PerformStride>) {
static_assert(!is_source_node, "Stride is not available for source blocks. Remove 'PerformStride' from the block definition.");
} else {
if (stride != 0) {
throw std::runtime_error(fmt::format("Block is not defined as `PerformStride`, but stride = {}, it must equal to 0.", stride));
}
}

update_ports_status();

if constexpr (is_source_node) {
Expand Down Expand Up @@ -612,29 +622,33 @@ class node : protected std::tuple<Arguments...> {
return { requested_work, 0_UZ, ports_status.in_at_least_one_port_has_data ? work_return_status_t::INSUFFICIENT_INPUT_ITEMS : work_return_status_t::DONE };
}

if (numerator != 1. || denominator != 1.) {
bool is_ill_defined = (denominator > ports_status.in_max_samples);
assert(!is_ill_defined);
if (is_ill_defined) {
return { requested_work, 0_UZ, work_return_status_t::ERROR };
}

ports_status.in_samples = static_cast<std::size_t>(ports_status.in_samples / denominator) * denominator; // remove reminder
if constexpr (node_template_parameters::template contains<PerformDecimationInterpolation>) {
if (numerator != 1. || denominator != 1.) {
// TODO: this ill-defined checks can be done only once after parameters were changed
const double ratio = static_cast<double>(numerator) / static_cast<double>(denominator);
bool is_ill_defined = (denominator > ports_status.in_max_samples) || (ports_status.in_min_samples * ratio > ports_status.out_max_samples)
|| (ports_status.in_max_samples * ratio < ports_status.out_min_samples);
assert(!is_ill_defined);
if (is_ill_defined) {
return { requested_work, 0_UZ, work_return_status_t::ERROR };
}

const std::size_t out_min_limit = ports_status.out_min_samples;
const std::size_t out_max_limit = std::min(ports_status.out_available, ports_status.out_max_samples);
const double ratio = static_cast<double>(numerator) / static_cast<double>(denominator);
ports_status.in_samples = static_cast<std::size_t>(ports_status.in_samples / denominator) * denominator; // remove reminder

std::size_t in_min_samples = static_cast<std::size_t>(static_cast<double>(out_min_limit) / ratio);
if (in_min_samples % denominator != 0) in_min_samples += denominator;
std::size_t in_min_wo_reminder = static_cast<std::size_t>(in_min_samples / denominator) * denominator;
const std::size_t out_min_limit = ports_status.out_min_samples;
const std::size_t out_max_limit = std::min(ports_status.out_available, ports_status.out_max_samples);

const std::size_t in_max_samples = static_cast<std::size_t>(static_cast<double>(out_max_limit) / ratio);
std::size_t in_max_wo_reminder = static_cast<std::size_t>(in_max_samples / denominator) * denominator;
std::size_t in_min_samples = static_cast<std::size_t>(static_cast<double>(out_min_limit) / ratio);
if (in_min_samples % denominator != 0) in_min_samples += denominator;
std::size_t in_min_wo_reminder = static_cast<std::size_t>(in_min_samples / denominator) * denominator;

if (ports_status.in_samples < in_min_wo_reminder) return { requested_work, 0_UZ, work_return_status_t::INSUFFICIENT_INPUT_ITEMS };
ports_status.in_samples = std::clamp(ports_status.in_samples, in_min_wo_reminder, in_max_wo_reminder);
ports_status.out_samples = numerator * (ports_status.in_samples / denominator);
const std::size_t in_max_samples = static_cast<std::size_t>(static_cast<double>(out_max_limit) / ratio);
std::size_t in_max_wo_reminder = static_cast<std::size_t>(in_max_samples / denominator) * denominator;

if (ports_status.in_samples < in_min_wo_reminder) return { requested_work, 0_UZ, work_return_status_t::INSUFFICIENT_INPUT_ITEMS };
ports_status.in_samples = std::clamp(ports_status.in_samples, in_min_wo_reminder, in_max_wo_reminder);
ports_status.out_samples = numerator * (ports_status.in_samples / denominator);
}
}

// TODO: special case for ports_status.in_samples == 0 ?
Expand Down Expand Up @@ -708,6 +722,38 @@ class node : protected std::tuple<Arguments...> {
// case sources: HW triggered vs. generating data per invocation (generators via Port::MIN)
// case sinks: HW triggered vs. fixed-size consumer (may block/never finish for insufficient input data and fixed Port::MIN>0)

std::size_t n_samples_to_consume = ports_status.in_samples; // default stride == 0
if constexpr (node_template_parameters::template contains<PerformStride>) {
if (stride != 0) {
const bool first_time_stride = stride_counter == 0;
if (first_time_stride) {
// sample processing are done as usual, ports_status.in_samples samples will be processed
if (stride.value > stride_counter + ports_status.in_available) { // stride can not be consumed at once -> start stride_counter
stride_counter += ports_status.in_available;
n_samples_to_consume = ports_status.in_available;
} else { // if the stride can be consumed at once -> no stride_counter is needed
stride_counter = 0;
n_samples_to_consume = stride.value;
}
} else {
// |====================|...|====================|==============----| -> ====== is the stride
// ^first ^we are here (1) or ^here (2)
// if it is not the "first time" stride -> just consume (1) all samples or (2) missing rest of the samples
// forward tags but no additional sample processing are done ->return
if (stride.value > stride_counter + ports_status.in_available) {
stride_counter += ports_status.in_available;
n_samples_to_consume = ports_status.in_available;
} else { // stride is at the end -> reset stride_counter
n_samples_to_consume = stride.value - stride_counter;
stride_counter = 0;
}
const bool success = consume_readers(self(), n_samples_to_consume);
forward_tags();
return { requested_work, n_samples_to_consume, success ? work_return_status_t::OK : work_return_status_t::ERROR };
}
}
}

const auto input_spans = meta::tuple_transform([in_samples = ports_status.in_samples](auto &input_port) noexcept { return input_port.streamReader().get(in_samples); }, input_ports(&self()));
auto writers_tuple = meta::tuple_transform([out_samples = ports_status.out_samples](auto &output_port) noexcept { return output_port.streamWriter().reserve_output_range(out_samples); },
output_ports(&self()));
Expand All @@ -719,7 +765,7 @@ class node : protected std::tuple<Arguments...> {
}(std::make_index_sequence<traits::node::input_ports<Derived>::size>(), std::make_index_sequence<traits::node::output_ports<Derived>::size>());

write_to_outputs(ports_status.out_samples, writers_tuple);
const bool success = consume_readers(self(), ports_status.in_samples);
const bool success = consume_readers(self(), n_samples_to_consume);
forward_tags();
return { requested_work, ports_status.in_samples, success ? ret : work_return_status_t::ERROR };
} else if constexpr (HasProcessOneFunction<Derived>) {
Expand Down Expand Up @@ -759,7 +805,7 @@ class node : protected std::tuple<Arguments...> {

write_to_outputs(ports_status.out_samples, writers_tuple);

const bool success = consume_readers(self(), ports_status.in_samples);
const bool success = consume_readers(self(), n_samples_to_consume);

#ifdef _DEBUG
if (!success) {
Expand Down
32 changes: 16 additions & 16 deletions test/grc/test.grc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ blocks:
event_count: 100
name: main_source
numerator: 1
stride: 1
stride: 0
unique_name: good::fixed_source<double>#0
denominator::description: denominator
denominator::documentation: decimation/interpolation settings
denominator::documentation: "The bottom number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
denominator::unit: ""
denominator::visible: 0
description: ""
Expand All @@ -22,11 +22,11 @@ blocks:
name::unit: ""
name::visible: 0
numerator::description: numerator
numerator::documentation: decimation/interpolation settings
numerator::documentation: "The top number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
numerator::unit: ""
numerator::visible: 0
stride::description: stride
stride::documentation: stride doc
stride::documentation: "Number of samples between two data processing: overlap (stride < N), skip (stride > N), undefined-default (stride = 0)"
stride::unit: ""
stride::visible: 0
unknown_property: 42
Expand All @@ -36,10 +36,10 @@ blocks:
denominator: 1
name: multiplier
numerator: 1
stride: 1
stride: 0
unique_name: good::multiply<double>#0
denominator::description: denominator
denominator::documentation: decimation/interpolation settings
denominator::documentation: "The bottom number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
denominator::unit: ""
denominator::visible: 0
description: ""
Expand All @@ -52,11 +52,11 @@ blocks:
name::unit: ""
name::visible: 0
numerator::description: numerator
numerator::documentation: decimation/interpolation settings
numerator::documentation: "The top number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
numerator::unit: ""
numerator::visible: 0
stride::description: stride
stride::documentation: stride doc
stride::documentation: "Number of samples between two data processing: overlap (stride < N), skip (stride > N), undefined-default (stride = 0)"
stride::unit: ""
stride::visible: 0
- name: counter
Expand All @@ -65,10 +65,10 @@ blocks:
denominator: 1
name: counter
numerator: 1
stride: 1
stride: 0
unique_name: builtin_counter<double>#0
denominator::description: denominator
denominator::documentation: decimation/interpolation settings
denominator::documentation: "The bottom number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
denominator::unit: ""
denominator::visible: 0
description: ""
Expand All @@ -81,11 +81,11 @@ blocks:
name::unit: ""
name::visible: 0
numerator::description: numerator
numerator::documentation: decimation/interpolation settings
numerator::documentation: "The top number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
numerator::unit: ""
numerator::visible: 0
stride::description: stride
stride::documentation: stride doc
stride::documentation: "Number of samples between two data processing: overlap (stride < N), skip (stride > N), undefined-default (stride = 0)"
stride::unit: ""
stride::visible: 0
- name: sink
Expand All @@ -94,11 +94,11 @@ blocks:
denominator: 1
name: sink
numerator: 1
stride: 1
stride: 0
total_count: 100
unique_name: good::cout_sink<double>#0
denominator::description: denominator
denominator::documentation: decimation/interpolation settings
denominator::documentation: "The bottom number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
denominator::unit: ""
denominator::visible: 0
description: ""
Expand All @@ -111,11 +111,11 @@ blocks:
name::unit: ""
name::visible: 0
numerator::description: numerator
numerator::documentation: decimation/interpolation settings
numerator::documentation: "The top number of a fraction = numerator/denominator: decimation (fraction < 1), interpolation (fraction > 1), no effect (fraction = 1)"
numerator::unit: ""
numerator::visible: 0
stride::description: stride
stride::documentation: stride doc
stride::documentation: "Number of samples between two data processing: overlap (stride < N), skip (stride > N), undefined-default (stride = 0)"
stride::unit: ""
stride::visible: 0
unknown_property: 42
Expand Down
Loading

0 comments on commit 07a83d8

Please sign in to comment.