Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch usize conversion and Err more consistently #1499

Open
wants to merge 11 commits into
base: testnet3
Choose a base branch
from
5 changes: 5 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ snarkVM is a big project, so (non-)adherence to best practices related to perfor
- if possible, reuse collections; an example would be a loop that needs a clean vector on each iteration: instead of creating and allocating it over and over, create it _before_ the loop and use `.clear()` on every iteration instead
- try to keep the sizes of `enum` variants uniform; use `Box<T>` on ones that are large

### Cross-platform consistency
- First and foremost, types which contain consensus- or cryptographic logic should have a consistent size across platforms. Their serialized output should not contain `usize`. For defense in depth, we serialize `usize` as `u64`.
- For clarity, use `u32` and `u64` as much or long as possible, especially in type definitions.
- Given that we only target 32- and 64-bit systems, casting `usize` as `u64` and casting `u32` as `usize` will always be safe and doesn't need a `try_from::`. In serialization code, for defense in depth it is still encouraged to use `try_from::`.

### Misc. performance

- avoid the `format!()` macro; if it is used only to convert a single value to a `String`, use `.to_string()` instead, which is also available to all the implementors of `Display`
Expand Down
4 changes: 2 additions & 2 deletions algorithms/examples/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ pub fn main() -> Result<()> {

// Parse the variant.
match args[1].as_str() {
"batched" => batched::msm(bases.as_slice(), scalars.as_slice()),
"standard" => standard::msm(bases.as_slice(), scalars.as_slice()),
"batched" => batched::msm(bases.as_slice(), scalars.as_slice())?,
"standard" => standard::msm(bases.as_slice(), scalars.as_slice())?,
_ => panic!("Invalid variant: use 'batched' or 'standard'"),
};

Expand Down
17 changes: 9 additions & 8 deletions algorithms/src/crypto_hash/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
let capacity = F::size_in_bits() - 1;
let mut dest_limbs = Vec::<F>::new();

let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);
let params = get_params(TargetField::size_in_bits_u32(), F::size_in_bits_u32(), ty);

let adjustment_factor_lookup_table = {
let mut table = Vec::<F>::new();
Expand All @@ -342,9 +342,9 @@ impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
let first = &src_limbs[i];
let second = if i + 1 < src_len { Some(&src_limbs[i + 1]) } else { None };

let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one());
let first_max_bits_per_limb = params.bits_per_limb as usize + crate::overhead!(first.1 + F::one());
let second_max_bits_per_limb = if let Some(second) = second {
params.bits_per_limb + crate::overhead!(second.1 + F::one())
params.bits_per_limb as usize + crate::overhead!(second.1 + F::one())
} else {
0
};
Expand Down Expand Up @@ -382,18 +382,19 @@ impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
elem: &<TargetField as PrimeField>::BigInteger,
optimization_type: OptimizationType,
) -> SmallVec<[F; 10]> {
let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);
let params = get_params(TargetField::size_in_bits_u32(), F::size_in_bits_u32(), optimization_type);

// Push the lower limbs first
let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
let mut cur = *elem;
for _ in 0..params.num_limbs {
let cur_bits = cur.to_bits_be(); // `to_bits` is big endian
let cur_mod_r =
<F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
.unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
let cur_mod_r = <F as PrimeField>::BigInteger::from_bits_be(
&cur_bits[cur_bits.len() - params.bits_per_limb as usize..],
)
.unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
limbs.push(F::from_bigint(cur_mod_r).unwrap());
cur.divn(params.bits_per_limb as u32);
cur.divn(params.bits_per_limb);
}

// then we reserve, so that the limbs are ``big limb first''
Expand Down
24 changes: 15 additions & 9 deletions algorithms/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,36 @@ pub enum SNARKError {
#[error("{}", _0)]
AnyhowError(#[from] anyhow::Error),

#[error("Batch size was different between public input and proof")]
BatchSizeMismatch,

#[error("Circuit not found")]
CircuitNotFound,

#[error("{}", _0)]
ConstraintFieldError(#[from] ConstraintFieldError),

#[error("{}: {}", _0, _1)]
Crate(&'static str, String),

#[error("Batch size was zero; must be at least 1")]
EmptyBatch,

#[error("Expected a circuit-specific SRS in SNARK")]
ExpectedCircuitSpecificSRS,

#[error(transparent)]
IntError(#[from] std::num::TryFromIntError),

#[error("{}", _0)]
Message(String),

#[error(transparent)]
ParseIntError(#[from] std::num::ParseIntError),

#[error("{}", _0)]
SynthesisError(SynthesisError),

#[error("Batch size was zero; must be at least 1")]
EmptyBatch,

#[error("Batch size was different between public input and proof")]
BatchSizeMismatch,

#[error("Circuit not found")]
CircuitNotFound,

#[error("terminated")]
Terminated,
}
Expand Down
49 changes: 25 additions & 24 deletions algorithms/src/fft/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct EvaluationDomain<F: FftField> {
/// The size of the domain.
pub size: u64,
pub size: usize,
/// `log_2(self.size)`.
pub log_size_of_group: u32,
/// Size of the domain as a field element.
Expand Down Expand Up @@ -114,7 +114,7 @@ impl<F: FftField> EvaluationDomain<F> {
/// having `num_coeffs` coefficients.
pub fn new(num_coeffs: usize) -> Option<Self> {
// Compute the size of our evaluation domain
let size = num_coeffs.checked_next_power_of_two()? as u64;
let size = num_coeffs.checked_next_power_of_two()?;
let log_size_of_group = size.trailing_zeros();

// libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
Expand All @@ -124,12 +124,13 @@ impl<F: FftField> EvaluationDomain<F> {

// Compute the generator for the multiplicative subgroup.
// It should be the 2^(log_size_of_group) root of unity.
let group_gen = F::get_root_of_unity(size as usize)?;
let group_gen = F::get_root_of_unity(size)?;
let size_u64 = size as u64;

// Check that it is indeed the 2^(log_size_of_group) root of unity.
debug_assert_eq!(group_gen.pow([size]), F::one());
debug_assert_eq!(group_gen.pow([size_u64]), F::one());

let size_as_field_element = F::from(size);
let size_as_field_element = F::from(size_u64);
let size_inv = size_as_field_element.inverse()?;

Some(EvaluationDomain {
Expand All @@ -152,7 +153,7 @@ impl<F: FftField> EvaluationDomain<F> {

/// Return the size of `self`.
pub fn size(&self) -> usize {
self.size as usize
self.size
}

/// Compute an FFT.
Expand Down Expand Up @@ -254,8 +255,8 @@ impl<F: FftField> EvaluationDomain<F> {
/// `tau`.
pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
// Evaluate all Lagrange polynomials
let size = self.size as usize;
let t_size = tau.pow([self.size]);
let size = self.size();
let t_size = tau.pow([self.size as u64]);
let one = F::one();
if t_size.is_one() {
let mut u = vec![F::zero(); size];
Expand Down Expand Up @@ -297,7 +298,7 @@ impl<F: FftField> EvaluationDomain<F> {
/// This evaluates the vanishing polynomial for this domain at tau.
/// For multiplicative subgroups, this polynomial is `z(X) = X^self.size - 1`.
pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
tau.pow([self.size]) - F::one()
tau.pow([self.size as u64]) - F::one()
}

/// Return an iterator over the elements of the domain.
Expand Down Expand Up @@ -373,7 +374,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold and check that the type is Fr
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Forward,
Expand Down Expand Up @@ -402,7 +403,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Inverse,
Expand All @@ -423,7 +424,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Inverse,
Expand All @@ -450,7 +451,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Forward,
Expand Down Expand Up @@ -481,7 +482,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Inverse,
Expand Down Expand Up @@ -515,7 +516,7 @@ impl<F: FftField> EvaluationDomain<F> {
// SNP TODO: how to set threshold
if self.size >= 32 && std::mem::size_of::<T>() == 32 {
let result = snarkvm_algorithms_cuda::NTT(
self.size as usize,
self.size(),
x_s,
snarkvm_algorithms_cuda::NTTInputOutputOrder::NN,
snarkvm_algorithms_cuda::NTTDirection::Inverse,
Expand Down Expand Up @@ -583,17 +584,17 @@ impl<F: FftField> EvaluationDomain<F> {
// [1, g, g^2, ..., g^{(n/2) - 1}]
#[cfg(feature = "serial")]
pub fn roots_of_unity(&self, root: F) -> Vec<F> {
compute_powers_serial((self.size as usize) / 2, root)
compute_powers_serial((self.size()) / 2, root)
}

/// Computes the first `self.size / 2` roots of unity.
#[cfg(not(feature = "serial"))]
pub fn roots_of_unity(&self, root: F) -> Vec<F> {
// TODO: check if this method can replace parallel compute powers.
let log_size = log2(self.size as usize);
let log_size = log2(self.size());
// early exit for short inputs
if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
compute_powers_serial((self.size as usize) / 2, root)
compute_powers_serial((self.size()) / 2, root)
} else {
let mut temp = root;
// w, w^2, w^4, w^8, ..., w^(2^(log_size - 1))
Expand Down Expand Up @@ -783,19 +784,19 @@ const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;

#[inline]
pub(super) fn bitrev(a: u64, log_len: u32) -> u64 {
a.reverse_bits() >> (64 - log_len)
pub(super) fn bitrev(a: usize, log_len: usize) -> usize {
a.reverse_bits() >> (std::mem::size_of::<usize>() * 8 - log_len)
}

pub(crate) fn derange<T>(xi: &mut [T]) {
derange_helper(xi, log2(xi.len()))
}

fn derange_helper<T>(xi: &mut [T], log_len: u32) {
for idx in 1..(xi.len() as u64 - 1) {
let ridx = bitrev(idx, log_len);
for idx in 1..(xi.len() - 1) {
let ridx = bitrev(idx, log_len as usize);
if idx < ridx {
xi.swap(idx as usize, ridx as usize);
xi.swap(idx, ridx);
}
}
}
Expand Down Expand Up @@ -864,7 +865,7 @@ impl<F: FftField> Iterator for Elements<F> {
type Item = F;

fn next(&mut self) -> Option<F> {
if self.cur_pow == self.domain.size {
if self.cur_pow == self.domain.size as u64 {
None
} else {
let cur_elem = self.cur_elem;
Expand Down
3 changes: 1 addition & 2 deletions algorithms/src/fft/polynomial/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

use crate::fft::{EvaluationDomain, Evaluations, Polynomial};
use snarkvm_fields::{Field, PrimeField};
use snarkvm_utilities::serialize::*;

use std::{collections::BTreeMap, fmt};

/// Stores a sparse polynomial in coefficient form.
#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
#[derive(Clone, PartialEq, Eq, Hash, Default)]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

side note: it might be a good idea to check if any other objects can be "trimmed" in a similar fashion for faster compilation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a thumbsup! rust-lang/rust#50927

#[must_use]
pub struct SparsePolynomial<F: Field> {
/// The coefficient a_i of `x^i` is stored as (i, a_i) in `self.coeffs`.
Expand Down
1 change: 1 addition & 0 deletions algorithms/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#![allow(clippy::module_inception)]
#![allow(clippy::type_complexity)]
#![cfg_attr(test, allow(clippy::assertions_on_result_states))]
#![warn(clippy::cast_possible_truncation)]
howardwu marked this conversation as resolved.
Show resolved Hide resolved

#[cfg(feature = "wasm")]
#[macro_use]
Expand Down
2 changes: 1 addition & 1 deletion algorithms/src/msm/fixed_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use rayon::prelude::*;
pub struct FixedBase;

impl FixedBase {
pub fn get_mul_window_size(num_scalars: usize) -> usize {
pub fn get_mul_window_size(num_scalars: usize) -> u32 {
match num_scalars < 32 {
true => 3,
false => super::ln_without_floats(num_scalars),
Expand Down
4 changes: 2 additions & 2 deletions algorithms/src/msm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub use variable_base::*;
/// [`Explanation of usage`]
///
/// [`Explanation of usage`]: https://github.com/scipr-lab/zexe/issues/79#issue-556220473
fn ln_without_floats(a: usize) -> usize {
fn ln_without_floats(a: usize) -> u32 {
// log2(a) * ln(2)
(crate::fft::domain::log2(a) * 69 / 100) as usize
crate::fft::domain::log2(a) * 69 / 100
}
4 changes: 2 additions & 2 deletions algorithms/src/msm/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn variable_base_test_with_bls12() {
let g = (0..SAMPLES).map(|_| G1Projective::rand(&mut rng).to_affine()).collect::<Vec<_>>();

let naive = naive_variable_base_msm(g.as_slice(), v.as_slice());
let fast = VariableBase::msm(g.as_slice(), v.as_slice());
let fast = VariableBase::msm(g.as_slice(), v.as_slice()).unwrap();

assert_eq!(naive.to_affine(), fast.to_affine());
}
Expand All @@ -60,7 +60,7 @@ fn variable_base_test_with_bls12_unequal_numbers() {
let g = (0..SAMPLES).map(|_| G1Projective::rand(&mut rng).to_affine()).collect::<Vec<_>>();

let naive = naive_variable_base_msm(g.as_slice(), v.as_slice());
let fast = VariableBase::msm(g.as_slice(), v.as_slice());
let fast = VariableBase::msm(g.as_slice(), v.as_slice()).unwrap();

assert_eq!(naive.to_affine(), fast.to_affine());
}
Loading