Skip to content

Commit

Permalink
Refactor generation mode class
Browse files Browse the repository at this point in the history
This is done to avoid ugly constructions when checking generation mode.
  • Loading branch information
aleasims committed Feb 15, 2024
1 parent f60eee7 commit 0800c9d
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 68 deletions.
84 changes: 41 additions & 43 deletions include/nil/blueprint/assigner.hpp

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion include/nil/blueprint/extract_constructor_parameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ namespace nil {
&assignment, generation_mode gen_mode
) {
std::vector<var> res = {};
if (std::uint8_t(gen_mode & generation_mode::ASSIGNMENTS)) {
if (gen_mode.has_assignments()) {
ptr_type input_ptr = static_cast<ptr_type>(
typename BlueprintFieldType::integral_type(var_value(assignment, variables[input_value]).data));
for (std::size_t i = 0; i < input_length; i++) {
Expand Down
134 changes: 113 additions & 21 deletions include/nil/blueprint/handle_component.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,116 @@

namespace nil {
namespace blueprint {
/**
* @brief Assigner generation mode, defining which output types will be produced.
*
* A number of flags may be set:
*
* - CIRCUIT - generate circuit;
* - ASSIGNMENTS - generate assignment table;
* - FALSE_ASSIGNMENTS;
* - SIZE_ESTIMATION - print circuit stats (generate nothing).
*
* Binary AND and OR can be applied to modes:
* `mode_a | mode_b`, `mode_a & mode_b`.
**/
class generation_mode {
private:
enum modes : uint8_t {
NONE = 0,
CIRCUIT = 1 << 0,
ASSIGNMENTS = 1 << 1,
FALSE_ASSIGNMENTS = 1 << 2,
SIZE_ESTIMATION = 1 << 3
};

enum class generation_mode : uint8_t {
NONE = 0,
CIRCUIT = 1 << 0,
ASSIGNMENTS = 1 << 1,
FALSE_ASSIGNMENTS = 1 << 2,
SIZE_ESTIMATION = 1 << 3
};
public:
constexpr generation_mode() : mode_(NONE) {
}

constexpr enum generation_mode operator |( const enum generation_mode self, const enum generation_mode val )
{
return (enum generation_mode)(uint8_t(self) | uint8_t(val));
}
constexpr generation_mode(uint8_t mode) : mode_(mode) {
}

constexpr enum generation_mode operator &( const enum generation_mode self, const enum generation_mode val )
{
return (enum generation_mode)(uint8_t(self) & uint8_t(val));
}
constexpr generation_mode(const generation_mode& other) : mode_(other.mode_) {
}

/// @brief "Do nothing" mode.
constexpr static generation_mode none() {
return generation_mode(NONE);
}

/// @brief Generate circuit.
constexpr static generation_mode circuit() {
return generation_mode(CIRCUIT);
}

/// @brief Generate assignment table.
constexpr static generation_mode assignments() {
return generation_mode(ASSIGNMENTS);
}

constexpr static generation_mode false_assignments() {
return generation_mode(FALSE_ASSIGNMENTS);
}

/// @brief Print circuit statistics (generate nothing).
constexpr static generation_mode size_estimation() {
return generation_mode(SIZE_ESTIMATION);
}

constexpr bool operator==(generation_mode other) const {
return mode_ == other.mode_;
}

constexpr bool operator!=(generation_mode other) const {
return mode_ != other.mode_;
}

constexpr generation_mode operator|(const generation_mode other) const {
return generation_mode(mode_ | other.mode_);
}

constexpr generation_mode operator&(const generation_mode other) const {
return generation_mode(mode_ & other.mode_);
}

generation_mode& operator=(const generation_mode& other) {
mode_ = other.mode_;
return *this;
}

generation_mode& operator|=(const generation_mode& other) {
mode_ |= other.mode_;
return *this;
}

generation_mode& operator&=(const generation_mode& other) {
mode_ &= other.mode_;
return *this;
}

/// @brief Whether generate circuit or not in this mode.
constexpr bool has_circuit() const {
return mode_ & CIRCUIT;
}

/// @brief Whether generate assignment table or not in this mode.
constexpr bool has_assignments() const {
return mode_ & ASSIGNMENTS;
}

constexpr bool has_false_assignments() const {
return mode_ & FALSE_ASSIGNMENTS;
}

/// @brief Whether print circuit statistics or not in this mode.
constexpr bool has_size_estimation() const {
return mode_ & SIZE_ESTIMATION;
}

private:
uint8_t mode_;
};

struct common_component_parameters {
std::uint32_t start_row;
Expand All @@ -113,7 +205,7 @@ namespace nil {
bool found = (used_rows.find(v.get().rotation) != used_rows.end());
if (!found && (v.get().type == var::column_type::witness || v.get().type == var::column_type::constant)) {
var new_v;
if (std::uint8_t(gen_mode & generation_mode::ASSIGNMENTS)) {
if (gen_mode.has_assignments()) {
new_v = save_shared_var(assignment, v);
} else {
const auto& shared_idx = assignment.shared_column_size(0);
Expand Down Expand Up @@ -212,7 +304,7 @@ namespace nil {

BOOST_LOG_TRIVIAL(debug) << "Using component \"" << component_instance.component_name << "\"";

if (std::uint8_t(param.gen_mode & generation_mode::SIZE_ESTIMATION)) {
if (param.gen_mode.has_size_estimation()) {
statistics.add_record(
component_instance.component_name,
component_instance.rows_amount,
Expand All @@ -230,11 +322,11 @@ namespace nil {
// generate circuit in any case for fill selectors
generate_circuit(component_instance, bp, assignment, instance_input, param.start_row);

if (std::uint8_t(param.gen_mode & generation_mode::ASSIGNMENTS)) {
if (param.gen_mode.has_assignments()) {
return generate_assignments(component_instance, assignment, instance_input, param.start_row,
param.target_prover_idx);
} else {
if (std::uint8_t(param.gen_mode & generation_mode::FALSE_ASSIGNMENTS)) {
if (param.gen_mode.has_false_assignments()) {
const auto rows_amount = ComponentType::get_rows_amount(p.witness.size(), 0, args...);
// disable selector
for (std::uint32_t i = 0; i < rows_amount; i++) {
Expand Down Expand Up @@ -295,7 +387,7 @@ namespace nil {
std::vector<var> output = component_result.all_vars();

//touch result variables
if (std::uint8_t(gen_mode & generation_mode::ASSIGNMENTS) == 0) {
if (!gen_mode.has_assignments()) {
const auto result_vars = component_result.all_vars();
for (const auto &v : result_vars) {
if (v.type == var::column_type::witness) {
Expand Down Expand Up @@ -324,7 +416,7 @@ namespace nil {
using var = crypto3::zk::snark::plonk_variable<typename BlueprintFieldType::value_type>;

//touch result variables
if (std::uint8_t(gen_mode & generation_mode::ASSIGNMENTS) == 0) {
if (!gen_mode.has_assignments()) {
for (const auto &v : result) {
if (v.type == var::column_type::witness) {
assignment.witness(v.index, v.rotation) = BlueprintFieldType::value_type::zero();
Expand Down
2 changes: 1 addition & 1 deletion include/nil/blueprint/integers/bit_de_composition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace nil {
auto result = get_component_result<BlueprintFieldType, ArithmetizationParams, component_type>
(bp, assignment, statistics, param, instance_input, BitsAmount, Mode).output;

if (std::uint8_t(param.gen_mode & generation_mode::ASSIGNMENTS)) {
if (param.gen_mode.has_assignments()) {
ptr_type result_ptr = static_cast<ptr_type>(
typename BlueprintFieldType::integral_type(var_value(assignment, variables[result_value]).data));
for (var v : result) {
Expand Down
2 changes: 1 addition & 1 deletion include/nil/blueprint/recursive_prover/fri_array_swap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ namespace nil {
std::vector<var> res = get_component_result<BlueprintFieldType, ArithmetizationParams, component_type>
(bp, assignment, statistics, param, instance_input, array_size / 2).output;

if (std::uint8_t(param.gen_mode & generation_mode::ASSIGNMENTS)) {
if (param.gen_mode.has_assignments()) {
ptr_type result_ptr = static_cast<ptr_type>(typename BlueprintFieldType::integral_type(
var_value(assignment, frame.scalars[result_value]).data));
for (std::size_t i = 0; i < array_size; i++) {
Expand Down
2 changes: 1 addition & 1 deletion include/nil/blueprint/recursive_prover/fri_cosets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ namespace nil {
const auto& result = get_component_result<BlueprintFieldType, ArithmetizationParams, component_type>
(bp, assignment, statistics, param, instance_input, res_length, omega).output;

if (std::uint8_t(param.gen_mode & generation_mode::ASSIGNMENTS)) {
if (param.gen_mode.has_assignments()) {
ptr_type result_ptr = static_cast<ptr_type>(
typename BlueprintFieldType::integral_type(var_value(assignment, variables[result_value]).data));
for (std::size_t i = 0; i < result.size(); i++) {
Expand Down

0 comments on commit 0800c9d

Please sign in to comment.