Skip to content

Commit

Permalink
o1vm/mips: use biguint instead of Fp in witness builder
Browse files Browse the repository at this point in the history
  • Loading branch information
dannywillems committed Nov 25, 2024
1 parent 20b29d3 commit a033449
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 57 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions o1vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ libflate.workspace = true
log.workspace = true
mina-curves.workspace = true
mina-poseidon.workspace = true
num-bigint.workspace = true
o1-utils.workspace = true
os_pipe.workspace = true
poly-commitment.workspace = true
Expand Down
8 changes: 4 additions & 4 deletions o1vm/src/interpreters/mips/tests_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use crate::{
},
preimage_oracle::PreImageOracleT,
};
use ark_ff::Zero;
use num_bigint::BigUint;
use rand::{CryptoRng, Rng, RngCore};
use std::{fs, path::PathBuf};

// FIXME: we should parametrize the tests with different fields.
use ark_bn254::Fr as Fp;

use super::column::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE};

const PAGE_INDEX_EXECUTABLE_MEMORY: u32 = 1;

pub(crate) struct OnDiskPreImageOracle;
Expand Down Expand Up @@ -90,8 +90,8 @@ where
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
scratch_state: [Fp::from(0); SCRATCH_SIZE],
scratch_state_inverse: [Fp::from(0); SCRATCH_SIZE_INVERSE],
scratch_state: std::array::from_fn(|_| BigUint::zero()),
scratch_state_inverse: std::array::from_fn(|_| BigUint::zero()),
selector: crate::interpreters::mips::column::N_MIPS_SEL_COLS,
halt: false,
// Keccak related
Expand Down
48 changes: 23 additions & 25 deletions o1vm/src/interpreters/mips/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ use crate::{
preimage_oracle::PreImageOracleT,
utils::memory_size,
};
use ark_ff::Field;
use ark_ff::{PrimeField, Zero};
use core::panic;
use kimchi::o1_utils::Two;
use log::{debug, info};
use num_bigint::BigUint;
use o1_utils::FieldHelpers;
use std::{
array,
fs::File,
Expand Down Expand Up @@ -69,7 +70,7 @@ impl SyscallEnv {
/// machine has access to its internal state and some external memory. In
/// addition to that, it has access to the environment of the Keccak interpreter
/// that is used to verify the preimage requested during the execution.
pub struct Env<Fp, PreImageOracle: PreImageOracleT> {
pub struct Env<Fp: PrimeField, PreImageOracle: PreImageOracleT> {
pub instruction_counter: u64,
pub memory: Vec<(u32, Vec<u8>)>,
pub last_memory_accesses: [usize; 3],
Expand All @@ -79,8 +80,8 @@ pub struct Env<Fp, PreImageOracle: PreImageOracleT> {
pub registers_write_index: Registers<u64>,
pub scratch_state_idx: usize,
pub scratch_state_idx_inverse: usize,
pub scratch_state: [Fp; SCRATCH_SIZE],
pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE],
pub scratch_state: [BigUint; SCRATCH_SIZE],
pub scratch_state_inverse: [BigUint; SCRATCH_SIZE_INVERSE],
pub halt: bool,
pub syscall_env: SyscallEnv,
pub selector: usize,
Expand All @@ -92,11 +93,11 @@ pub struct Env<Fp, PreImageOracle: PreImageOracleT> {
pub hash_counter: u64,
}

fn fresh_scratch_state<Fp: Field, const N: usize>() -> [Fp; N] {
array::from_fn(|_| Fp::zero())
fn fresh_scratch_state<const N: usize>() -> [BigUint; N] {
array::from_fn(|_| BigUint::zero())
}

impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreImageOracle> {
impl<Fp: PrimeField, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreImageOracle> {
type Position = Column;

fn alloc_scratch(&mut self) -> Self::Position {
Expand Down Expand Up @@ -325,31 +326,28 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
};
// write the non deterministic advice inv_or_zero
let pos = self.alloc_scratch_inverse();
if *x == 0 {
self.write_field_column(pos, Fp::zero());
} else {
self.write_field_column(pos, Fp::from(*x));
};
self.write_biguint_column(pos, BigUint::from(*x));
// return the result
res
}

fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable {
// We replicate is_zero(x-y), but working on field elt,
// to avoid subtraction overflow in the witness interpreter for u32
let to_zero_test = Fp::from(*x) - Fp::from(*y);
let res = {
let pos = self.alloc_scratch();
let is_zero: u64 = if to_zero_test == Fp::zero() { 1 } else { 0 };
let is_zero: u64 = if x == y { 1 } else { 0 };
self.write_column(pos, is_zero);
is_zero
};
let pos = self.alloc_scratch_inverse();
if to_zero_test == Fp::zero() {
self.write_field_column(pos, Fp::zero());
if x > y {
let res = BigUint::from(x - y);
self.write_biguint_column(pos, res);
} else {
self.write_field_column(pos, to_zero_test);
};
let res = Fp::modulus_biguint().clone() - BigUint::from(y - x);
self.write_biguint_column(pos, res)
}
res
}

Expand Down Expand Up @@ -739,10 +737,10 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI

// Store preimage key in the witness excluding the MSB as 248 bits
// so it can be used for the communication channel between Keccak
let bytes31 = (1..32).fold(Fp::zero(), |acc, i| {
acc * Fp::two_pow(8) + Fp::from(self.preimage_key.unwrap()[i])
let bytes31 = (1..32).fold(BigUint::zero(), |acc, i| {
acc * BigUint::from((1 << 8) as u32) + BigUint::from(self.preimage_key.unwrap()[i])
});
self.write_field_column(Self::Position::ScratchState(MIPS_PREIMAGE_KEY), bytes31);
self.write_biguint_column(Self::Position::ScratchState(MIPS_PREIMAGE_KEY), bytes31);

debug!("Preimage has been read entirely, triggering Keccak process");
self.keccak_env = Some(KeccakEnv::<Fp>::new(
Expand Down Expand Up @@ -819,7 +817,7 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
}
}

impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
impl<Fp: PrimeField, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
pub fn create(page_size: usize, state: State, preimage_oracle: PreImageOracle) -> Self {
let initial_instruction_pointer = state.pc;
let next_instruction_pointer = state.next_pc;
Expand Down Expand Up @@ -906,10 +904,10 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
}

pub fn write_column(&mut self, column: Column, value: u64) {
self.write_field_column(column, value.into())
self.write_biguint_column(column, value.into())
}

pub fn write_field_column(&mut self, column: Column, value: Fp) {
pub fn write_biguint_column(&mut self, column: Column, value: BigUint) {
match column {
Column::ScratchState(idx) => self.scratch_state[idx] = value,
Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value,
Expand Down
2 changes: 1 addition & 1 deletion o1vm/src/legacy/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub fn main() -> ExitCode {
for i in 0..N_MIPS_REL_COLS {
match i.cmp(&SCRATCH_SIZE) {
Ordering::Less => mips_trace.trace.get_mut(&instr).unwrap().witness.cols[i]
.push(mips_wit_env.scratch_state[i]),
.push(mips_wit_env.scratch_state[i].clone().into()),
Ordering::Equal => mips_trace.trace.get_mut(&instr).unwrap().witness.cols[i]
.push(Fp::from(mips_wit_env.instruction_counter)),
Ordering::Greater => {
Expand Down
20 changes: 11 additions & 9 deletions o1vm/src/pickles/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ark_ff::UniformRand;
use kimchi::circuits::domains::EvaluationDomains;
use kimchi_msm::expr::E;
use log::debug;
Expand All @@ -7,6 +6,7 @@ use mina_poseidon::{
constants::PlonkSpongeConstantsKimchi,
sponge::{DefaultFqSponge, DefaultFrSponge},
};
use num_bigint::BigUint;
use o1vm::{
cannon::{self, Meta, Start, State},
cannon_cli,
Expand Down Expand Up @@ -95,34 +95,36 @@ pub fn main() -> ExitCode {
constraints
};

let mut curr_proof_inputs: ProofInputs<Vesta> = ProofInputs::new(DOMAIN_SIZE);
let mut curr_proof_inputs: ProofInputs = ProofInputs::new(DOMAIN_SIZE);
while !mips_wit_env.halt {
let _instr: Instruction = mips_wit_env.step(&configuration, &meta, &start);
for (scratch, scratch_chunk) in mips_wit_env
.scratch_state
.iter()
.zip(curr_proof_inputs.evaluations.scratch.iter_mut())
{
scratch_chunk.push(*scratch);
scratch_chunk.push(scratch.clone());
}
for (scratch, scratch_chunk) in mips_wit_env
.scratch_state_inverse
.iter()
.zip(curr_proof_inputs.evaluations.scratch_inverse.iter_mut())
{
scratch_chunk.push(*scratch);
scratch_chunk.push(scratch.clone());
}
curr_proof_inputs
.evaluations
.instruction_counter
.push(Fp::from(mips_wit_env.instruction_counter));
.push(BigUint::from(mips_wit_env.instruction_counter));
// FIXME: Might be another value
curr_proof_inputs.evaluations.error.push(Fp::rand(&mut rng));

curr_proof_inputs
.evaluations
.selector
.push(Fp::from((mips_wit_env.selector - N_MIPS_REL_COLS) as u64));
.error
.push(BigUint::from(42_u32));

curr_proof_inputs.evaluations.selector.push(BigUint::from(
(mips_wit_env.selector - N_MIPS_REL_COLS) as u64,
));

if curr_proof_inputs.evaluations.instruction_counter.len() == DOMAIN_SIZE {
// FIXME
Expand Down
7 changes: 4 additions & 3 deletions o1vm/src/pickles/proof.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use kimchi::{curve::KimchiCurve, proof::PointEvaluations};
use num_bigint::BigUint;
use poly_commitment::{ipa::OpeningProof, PolyComm};

use crate::interpreters::mips::column::{N_MIPS_SEL_COLS, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE};
Expand All @@ -11,11 +12,11 @@ pub struct WitnessColumns<G, S> {
pub selector: S,
}

pub struct ProofInputs<G: KimchiCurve> {
pub evaluations: WitnessColumns<Vec<G::ScalarField>, Vec<G::ScalarField>>,
pub struct ProofInputs {
pub evaluations: WitnessColumns<Vec<BigUint>, Vec<BigUint>>,
}

impl<G: KimchiCurve> ProofInputs<G> {
impl ProofInputs {
pub fn new(domain_size: usize) -> Self {
ProofInputs {
evaluations: WitnessColumns {
Expand Down
30 changes: 21 additions & 9 deletions o1vm/src/pickles/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use kimchi::{
};
use log::debug;
use mina_poseidon::{sponge::ScalarChallenge, FqSponge};
use o1_utils::ExtendedDensePolynomial;
use num_bigint::BigUint;
use o1_utils::{ExtendedDensePolynomial, FieldHelpers};
use poly_commitment::{
commitment::{absorb_commitment, PolyComm},
ipa::{DensePolynomialOrEvaluations, OpeningProof, SRS},
Expand Down Expand Up @@ -59,7 +60,7 @@ pub fn prove<
>(
domain: EvaluationDomains<G::ScalarField>,
srs: &SRS<G>,
inputs: ProofInputs<G>,
inputs: ProofInputs,
constraints: &[E<G::ScalarField>],
rng: &mut RNG,
) -> Result<Proof<G>, ProverError>
Expand Down Expand Up @@ -93,29 +94,40 @@ where
let domain_size = domain.d1.size as usize;

// Build the selectors
let selector: [Vec<G::ScalarField>; N_MIPS_SEL_COLS] = array::from_fn(|i| {
let selector: [Vec<BigUint>; N_MIPS_SEL_COLS] = array::from_fn(|i| {
let mut s_i = Vec::with_capacity(domain_size);
for s in &selector {
s_i.push(if G::ScalarField::from(i as u64) == *s {
G::ScalarField::one()
s_i.push(if BigUint::from(i as u64) == *s {
BigUint::one()
} else {
G::ScalarField::zero()
BigUint::zero()
})
}
s_i
});

let eval_col = |evals: Vec<G::ScalarField>| {
let eval_col = |evals: Vec<BigUint>| {
let evals: Vec<G::ScalarField> = evals
.into_par_iter()
.map(|x| G::ScalarField::from_biguint(&x).unwrap())
.collect();
Evaluations::<G::ScalarField, D<G::ScalarField>>::from_vec_and_domain(evals, domain.d1)
.interpolate()
};
// Doing in parallel
let scratch = scratch.into_par_iter().map(eval_col).collect::<Vec<_>>();
let scratch_inverse = scratch_inverse
.into_par_iter()
.map(|mut evals| {
.map(|evals| {
let mut evals: Vec<G::ScalarField> = evals
.into_par_iter()
.map(|x| G::ScalarField::from_biguint(&x).unwrap())
.collect();
ark_ff::batch_inversion(&mut evals);
eval_col(evals)
Evaluations::<G::ScalarField, D<G::ScalarField>>::from_vec_and_domain(
evals, domain.d1,
)
.interpolate()
})
.collect::<Vec<_>>();
let selector = selector.into_par_iter().map(eval_col).collect::<Vec<_>>();
Expand Down
13 changes: 7 additions & 6 deletions o1vm/src/pickles/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use mina_poseidon::{
constants::PlonkSpongeConstantsKimchi,
sponge::{DefaultFqSponge, DefaultFrSponge},
};
use num_bigint::BigUint;
use o1_utils::tests::make_test_rng;
use poly_commitment::SRS;
use strum::IntoEnumIterator;
Expand Down Expand Up @@ -55,24 +56,24 @@ fn test_regression_constraints_with_selectors() {
assert_eq!(max_degree, MAXIMUM_DEGREE_CONSTRAINTS);
}

fn zero_to_n_minus_one(n: usize) -> Vec<Fq> {
(0..n).map(|i| Fq::from((i) as u64)).collect()
fn zero_to_n_minus_one(n: usize) -> Vec<BigUint> {
(0..n).map(|i| BigUint::from((i) as u64)).collect()
}

#[test]
fn test_small_circuit() {
let domain = EvaluationDomains::<Fq>::create(8).unwrap();
let srs = SRS::create(8);
let proof_input = ProofInputs::<Pallas> {
let proof_input = ProofInputs {
evaluations: WitnessColumns {
scratch: std::array::from_fn(|_| zero_to_n_minus_one(8)),
scratch_inverse: std::array::from_fn(|_| (0..8).map(|_| Fq::zero()).collect()),
scratch_inverse: std::array::from_fn(|_| (0..8).map(|_| BigUint::zero()).collect()),
instruction_counter: zero_to_n_minus_one(8)
.into_iter()
.map(|x| x + Fq::one())
.map(|x| x + BigUint::one())
.collect(),
error: (0..8)
.map(|i| -Fq::from((i * SCRATCH_SIZE + (i + 1)) as u64))
.map(|i| BigUint::from((i * SCRATCH_SIZE + (i + 1)) as u64))
.collect(),
selector: zero_to_n_minus_one(8),
},
Expand Down

0 comments on commit a033449

Please sign in to comment.