Skip to content

Commit

Permalink
Refactor the part of getting variable value from assignment table.
Browse files Browse the repository at this point in the history
  • Loading branch information
martun committed Jul 16, 2024
1 parent 93edae7 commit 25ef30f
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 224 deletions.
8 changes: 4 additions & 4 deletions libs/zk/include/nil/crypto3/zk/math/expression_evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ namespace nil {
*/
expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: expr(expr)
, get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -140,7 +140,7 @@ namespace nil {
const math::expression<VariableType>& expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> get_var_value;
std::function<const ValueType&(const VariableType &var)> get_var_value;

};

Expand Down Expand Up @@ -207,7 +207,7 @@ namespace nil {
*/
cached_expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: _expr(expr)
, _get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -304,7 +304,7 @@ namespace nil {
const math::expression<VariableType>& _expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> _get_var_value;
std::function<const ValueType&(const VariableType &var)> _get_var_value;

// Shows how many times each subexpression appears. We count have the expression
// itself as a key, but apparently it's waay too slow. Just map the hash->count, assume
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
#include <nil/crypto3/zk/snark/arithmetization/plonk/padding.hpp>
#include <nil/crypto3/random/algebraic_engine.hpp>
#include <nil/crypto3/math/polynomial/polynomial_dfs.hpp>
#include <nil/crypto3/zk/snark/arithmetization/plonk/variable.hpp>

namespace nil {
namespace blueprint {
template<typename ArithmetizationType>
class assignment;
} // namespace blueprint

namespace crypto3 {
namespace zk {
namespace snark {
Expand All @@ -55,6 +57,7 @@ namespace nil {
class plonk_private_table {
public:
using witnesses_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -82,6 +85,31 @@ namespace nil {
return _witnesses[index].size();
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}

ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
assert(index < _witnesses.size());
return _witnesses[index];
Expand Down Expand Up @@ -126,6 +154,7 @@ namespace nil {
using public_input_container_type = std::vector<ColumnType>;
using constant_container_type = std::vector<ColumnType>;
using selector_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -286,6 +315,7 @@ namespace nil {
using public_input_container_type = typename public_table_type::public_input_container_type;
using constant_container_type = typename public_table_type::constant_container_type;
using selector_container_type = typename public_table_type::selector_container_type;
using VariableType = plonk_variable<ColumnType>;

protected:
// These are normally created by the assigner, or read from a file.
Expand All @@ -309,6 +339,31 @@ namespace nil {
, _public_table(public_inputs_amount, constants_amount, selectors_amount) {
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}

ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
return _private_table.witness(index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace nil {
const plonk_assignment_table<FieldType> &assignments) const {
math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments, row_index](const VariableType &var) {
[&assignments, row_index](const VariableType &var) -> const typename VariableType::assignment_type& {
std::size_t rows_amount = assignments.rows_amount();
switch (var.type) {
case VariableType::column_type::witness:
Expand All @@ -100,8 +100,8 @@ namespace nil {
case VariableType::column_type::selector:
return assignments.selector(var.index)[(rows_amount + row_index + var.rotation) % rows_amount];
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
return VariableType::assignment_type::zero();
std::cerr << "Invalid column type" << std::endl;
abort();
}
});

Expand All @@ -110,38 +110,33 @@ namespace nil {

math::polynomial<typename VariableType::assignment_type>
evaluate(const plonk_polynomial_table<FieldType> &assignments,
std::shared_ptr<math::evaluation_domain<FieldType>>
domain) const {
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;

math::expression_evaluator<polynomial_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const VariableType &var) {
polynomial_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
}

if (var.rotation != 0) {
assignment =
math::polynomial_shift(assignment, domain->get_domain_element(var.rotation));
std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_variable_type, polynomial_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_variable_type& var) {
if (var.rotation == 0)
return;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const VariableType &var) -> const polynomial_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}
return assignment;
return rotated_variable_values[var];
});
return evaluator.evaluate();
}
Expand All @@ -152,43 +147,41 @@ namespace nil {
using polynomial_dfs_type = math::polynomial_dfs<typename VariableType::assignment_type>;
using polynomial_dfs_variable_type = plonk_variable<polynomial_dfs_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<variable_type, polynomial_dfs_variable_type> converter(
[&assignments](const typename VariableType::assignment_type& coeff) {
polynomial_dfs_type(0, assignments.rows_amount(), coeff);
});
math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const polynomial_dfs_variable_type &var) {
polynomial_dfs_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
}

if (var.rotation != 0) {
assignment = math::polynomial_shift(assignment, var.rotation, domain->m);
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_dfs_variable_type, polynomial_dfs_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_dfs_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_dfs_variable_type& var) {
if (var.rotation == 0)
return ;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}
return assignment;
return rotated_variable_values[var];
}
);

return evaluator.evaluate();
}

typename VariableType::assignment_type
evaluate(detail::plonk_evaluation_map<VariableType> &assignments) const {
evaluate(detail::plonk_evaluation_map<VariableType> &assignments) const -> const typename VariableType::assignment_type& {
math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments](const VariableType &var) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,29 +93,7 @@ namespace nil {
// We may have variable values in required sizes in some cases.
if (variable_values_out.find(var) != variable_values_out.end())
continue;
polynomial_dfs_type assignment;
switch (var.type) {
case polynomial_dfs_variable_type::column_type::witness:
assignment = assignments.witness(var.index);
break;
case polynomial_dfs_variable_type::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case polynomial_dfs_variable_type::column_type::constant:
assignment = assignments.constant(var.index);
break;
case polynomial_dfs_variable_type::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
std::cerr << "Invalid column type";
std::abort();
break;
}

if (var.rotation != 0) {
assignment = math::polynomial_shift(assignment, var.rotation, domain->m);
}
polynomial_dfs_type assignment = assignments.get_variable_value(var, domain);
if (count > 1) {
assignment.resize(extended_domain_size, domain, extended_domain);
}
Expand Down Expand Up @@ -206,9 +184,11 @@ namespace nil {
extended_domain_sizes[i], variable_values);

math::cached_expression_evaluator<polynomial_dfs_variable_type> evaluator(
expressions[i], [&assignments=variable_values, domain_size=extended_domain_sizes[i]](const polynomial_dfs_variable_type &var) {
return assignments[var];
});
expressions[i], [&assignments=variable_values, domain_size=extended_domain_sizes[i]]
(const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& {
return assignments[var];
}
);

F[0] += evaluator.evaluate();
}
Expand Down
Loading

0 comments on commit 25ef30f

Please sign in to comment.