Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the part of getting variable value from assignment table. #290

Merged
merged 1 commit into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ namespace nil {
constexpr static const std::size_t value_bits = ValueBits;
typedef typename boost::uint_t<value_bits>::least value_type;
BOOST_STATIC_ASSERT(word_bits % value_bits == 0);

constexpr static const std::size_t block_values = block_bits / value_bits;
typedef std::array<value_type, block_values> cache_type;

protected:
BOOST_STATIC_ASSERT(block_bits % value_bits == 0);

inline void process_block(std::size_t block_seen = block_bits) {
using namespace nil::crypto3::detail;
// Convert the input into words
Expand Down
8 changes: 4 additions & 4 deletions libs/zk/include/nil/crypto3/zk/math/expression_evaluator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ namespace nil {
*/
expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: expr(expr)
, get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -140,7 +140,7 @@ namespace nil {
const math::expression<VariableType>& expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> get_var_value;
std::function<const ValueType&(const VariableType &var)> get_var_value;

};

Expand Down Expand Up @@ -207,7 +207,7 @@ namespace nil {
*/
cached_expression_evaluator(
const math::expression<VariableType>& expr,
std::function<ValueType(const VariableType&)> get_var_value)
std::function<const ValueType&(const VariableType&)> get_var_value)
: _expr(expr)
, _get_var_value(get_var_value) {
}
Expand Down Expand Up @@ -304,7 +304,7 @@ namespace nil {
const math::expression<VariableType>& _expr;

// A function used to retrieve the value of a variable.
std::function<ValueType(const VariableType &var)> _get_var_value;
std::function<const ValueType&(const VariableType &var)> _get_var_value;

// Shows how many times each subexpression appears. We count have the expression
// itself as a key, but apparently it's waay too slow. Just map the hash->count, assume
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@
#include <nil/crypto3/zk/snark/arithmetization/plonk/padding.hpp>
#include <nil/crypto3/random/algebraic_engine.hpp>
#include <nil/crypto3/math/polynomial/polynomial_dfs.hpp>
#include <nil/crypto3/zk/snark/arithmetization/plonk/variable.hpp>

namespace nil {
namespace blueprint {
template<typename ArithmetizationType>
class assignment;
} // namespace blueprint

namespace crypto3 {
namespace zk {
namespace snark {
Expand All @@ -55,6 +57,7 @@ namespace nil {
class plonk_private_table {
public:
using witnesses_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -82,6 +85,31 @@ namespace nil {
return _witnesses[index].size();
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}

ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
assert(index < _witnesses.size());
return _witnesses[index];
Expand Down Expand Up @@ -126,6 +154,7 @@ namespace nil {
using public_input_container_type = std::vector<ColumnType>;
using constant_container_type = std::vector<ColumnType>;
using selector_container_type = std::vector<ColumnType>;
using VariableType = plonk_variable<ColumnType>;

protected:

Expand Down Expand Up @@ -286,6 +315,7 @@ namespace nil {
using public_input_container_type = typename public_table_type::public_input_container_type;
using constant_container_type = typename public_table_type::constant_container_type;
using selector_container_type = typename public_table_type::selector_container_type;
using VariableType = plonk_variable<ColumnType>;

protected:
// These are normally created by the assigner, or read from a file.
Expand All @@ -309,6 +339,31 @@ namespace nil {
, _public_table(public_inputs_amount, constants_amount, selectors_amount) {
}

const ColumnType& get_variable_value_without_rotation(const VariableType& var) const {
switch (var.type) {
case VariableType::column_type::witness:
return witness(var.index);
case VariableType::column_type::public_input:
return public_input(var.index);
case VariableType::column_type::constant:
return constant(var.index);
case VariableType::column_type::selector:
return selector(var.index);
default:
std::cerr << "Invalid column type" << std::endl;
abort();
}
}

ColumnType get_variable_value(const VariableType& var, std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
if (var.rotation == 0) {
return get_variable_value_without_rotation(var);
}
return math::polynomial_shift(
this->get_variable_value_without_rotation(var),
var.rotation, domain->m);
}

const ColumnType& witness(std::uint32_t index) const {
return _private_table.witness(index);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace nil {
const plonk_assignment_table<FieldType> &assignments) const {
math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments, row_index](const VariableType &var) {
[&assignments, row_index](const VariableType &var) -> const typename VariableType::assignment_type& {
std::size_t rows_amount = assignments.rows_amount();
switch (var.type) {
case VariableType::column_type::witness:
Expand All @@ -100,8 +100,8 @@ namespace nil {
case VariableType::column_type::selector:
return assignments.selector(var.index)[(rows_amount + row_index + var.rotation) % rows_amount];
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
return VariableType::assignment_type::zero();
std::cerr << "Invalid column type" << std::endl;
abort();
}
});

Expand All @@ -110,38 +110,33 @@ namespace nil {

math::polynomial<typename VariableType::assignment_type>
evaluate(const plonk_polynomial_table<FieldType> &assignments,
std::shared_ptr<math::evaluation_domain<FieldType>>
domain) const {
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;

math::expression_evaluator<polynomial_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const VariableType &var) {
polynomial_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
std::shared_ptr<math::evaluation_domain<FieldType>> domain) const {
using polynomial_type = math::polynomial<typename VariableType::assignment_type>;
using polynomial_variable_type = plonk_variable<polynomial_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<VariableType, polynomial_variable_type> converter;
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_variable_type, polynomial_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_variable_type& var) {
if (var.rotation == 0)
return;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const VariableType &var) -> const polynomial_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}

if (var.rotation != 0) {
assignment =
math::polynomial_shift(assignment, domain->get_domain_element(var.rotation));
}
return assignment;
return rotated_variable_values[var];
});
return evaluator.evaluate();
}
Expand All @@ -152,46 +147,43 @@ namespace nil {
using polynomial_dfs_type = math::polynomial_dfs<typename VariableType::assignment_type>;
using polynomial_dfs_variable_type = plonk_variable<polynomial_dfs_type>;

// Convert scalar values to polynomials inside the expression.
math::expression_variable_type_converter<variable_type, polynomial_dfs_variable_type> converter(
[&assignments](const typename VariableType::assignment_type& coeff) {
polynomial_dfs_type(0, assignments.rows_amount(), coeff);
});
math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converter.convert(*this),
[&domain, &assignments](const polynomial_dfs_variable_type &var) {
polynomial_dfs_type assignment;
switch (var.type) {
case VariableType::column_type::witness:
assignment = assignments.witness(var.index);
break;
case VariableType::column_type::public_input:
assignment = assignments.public_input(var.index);
break;
case VariableType::column_type::constant:
assignment = assignments.constant(var.index);
break;
case VariableType::column_type::selector:
assignment = assignments.selector(var.index);
break;
default:
BOOST_ASSERT_MSG(false, "Invalid column type");
}

if (var.rotation != 0) {
assignment = math::polynomial_shift(assignment, var.rotation, domain->m);
auto converted_expression = converter.convert(*this);

// For each variable with a rotation pre-compute its value.
std::unordered_map<polynomial_dfs_variable_type, polynomial_dfs_type> rotated_variable_values;

math::expression_for_each_variable_visitor<polynomial_dfs_variable_type> visitor(
[&rotated_variable_values, &assignments, &domain](const polynomial_dfs_variable_type& var) {
if (var.rotation == 0)
return ;
rotated_variable_values[var] = assignments.get_variable_value(var, domain);
});
visitor.visit(converted_expression);

math::expression_evaluator<polynomial_dfs_variable_type> evaluator(
converted_expression,
[&domain, &assignments, &rotated_variable_values]
(const polynomial_dfs_variable_type &var) -> const polynomial_dfs_type& {
if (var.rotation == 0) {
return assignments.get_variable_value_without_rotation(var, domain);
}
return assignment;
return rotated_variable_values[var];
}
);

return evaluator.evaluate();
}

typename VariableType::assignment_type
evaluate(detail::plonk_evaluation_map<VariableType> &assignments) const {
typename VariableType::assignment_type evaluate(detail::plonk_evaluation_map<VariableType> &assignments) const {
math::expression_evaluator<VariableType> evaluator(
*this,
[&assignments](const VariableType &var) {
[&assignments](const VariableType &var) -> const typename VariableType::assignment_type& {
std::tuple<std::size_t, int, typename VariableType::column_type> key =
std::make_tuple(var.index, var.rotation, var.type);

Expand Down
Loading
Loading