Skip to content

Commit

Permalink
fix: add memory_layout to preprocessing so verifier doesn't rely on t…
Browse files Browse the repository at this point in the history
…he prover (#494)

* fix: add memory_layout to preprocessing so verifier doesn't rely on the prover

* add a test to show that the proof's memory_layout is not used

* fmt

* setup memory_layout based on attributes not program.trace()

* remove unnecessary assertions in tests

* build warnings
  • Loading branch information
sagar-a16z authored Nov 1, 2024
1 parent ab44f03 commit dbb76f3
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 34 deletions.
18 changes: 16 additions & 2 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,14 @@ where
let (io_device, trace) = program.trace();

let preprocessing: crate::jolt::vm::JoltPreprocessing<C, F, PCS, ProofTranscript> =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22);
RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 22,
);

let (jolt_proof, jolt_commitments, _) =
<RV32IJoltVM as Jolt<_, PCS, C, M, ProofTranscript>>::prove(
Expand Down Expand Up @@ -187,7 +194,14 @@ where
let (io_device, trace) = program.trace();

let preprocessing: crate::jolt::vm::JoltPreprocessing<C, F, PCS, ProofTranscript> =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 22);
RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 22,
);

let (jolt_proof, jolt_commitments, _) =
<RV32IJoltVM as Jolt<_, PCS, C, M, ProofTranscript>>::prove(
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/host/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ impl Program {

// TODO(moodlezoup): Make this generic over InstructionSet
#[tracing::instrument(skip_all, name = "Program::trace")]
pub fn trace(mut self) -> (JoltDevice, Vec<JoltTraceStep<RV32I>>) {
pub fn trace(&mut self) -> (JoltDevice, Vec<JoltTraceStep<RV32I>>) {
self.build();
let elf = self.elf.unwrap();
let elf = self.elf.clone().unwrap();
let (raw_trace, io_device) =
tracer::trace(&elf, &self.input, self.max_input_size, self.max_output_size);

Expand Down
41 changes: 32 additions & 9 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::r1cs::constraints::R1CSConstraints;
use crate::r1cs::spartan::{self, UniformSpartanProof};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use common::constants::RAM_START_ADDRESS;
use common::rv_trace::NUM_CIRCUIT_FLAGS;
use common::rv_trace::{MemoryLayout, NUM_CIRCUIT_FLAGS};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use strum::EnumCount;
Expand Down Expand Up @@ -60,6 +60,7 @@ where
pub instruction_lookups: InstructionLookupsPreprocessing<C, F>,
pub bytecode: BytecodePreprocessing<F>,
pub read_write_memory: ReadWriteMemoryPreprocessing,
pub memory_layout: MemoryLayout,
}

#[derive(Clone, Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -276,6 +277,7 @@ where
#[tracing::instrument(skip_all, name = "Jolt::preprocess")]
fn preprocess(
bytecode: Vec<ELFInstruction>,
memory_layout: MemoryLayout,
memory_init: Vec<(u64, u8)>,
max_bytecode_size: usize,
max_memory_address: usize,
Expand Down Expand Up @@ -336,6 +338,7 @@ where

JoltPreprocessing {
generators,
memory_layout,
instruction_lookups: instruction_lookups_preprocessing,
bytecode: bytecode_preprocessing,
read_write_memory: read_write_memory_preprocessing,
Expand Down Expand Up @@ -368,7 +371,12 @@ where
JoltTraceStep::pad(&mut trace);

let mut transcript = ProofTranscript::new(b"Jolt transcript");
Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length);
Self::fiat_shamir_preamble(
&mut transcript,
&program_io,
&program_io.memory_layout,
trace_length,
);

let instruction_polynomials =
InstructionLookupsProof::<
Expand Down Expand Up @@ -539,11 +547,16 @@ where
opening_accumulator
.compare_to(debug_info.opening_accumulator, &preprocessing.generators);
}
Self::fiat_shamir_preamble(&mut transcript, &proof.program_io, proof.trace_length);
Self::fiat_shamir_preamble(
&mut transcript,
&proof.program_io,
&preprocessing.memory_layout,
proof.trace_length,
);

// Regenerate the uniform Spartan key
let padded_trace_length = proof.trace_length.next_power_of_two();
let memory_start = RAM_START_ADDRESS - proof.program_io.memory_layout.ram_witness_offset;
let memory_start = RAM_START_ADDRESS - preprocessing.memory_layout.ram_witness_offset;
let r1cs_builder =
Self::Constraints::construct_constraints(padded_trace_length, memory_start);
let spartan_key = spartan::UniformSpartanProof::<C, _, F, ProofTranscript>::setup(
Expand Down Expand Up @@ -586,6 +599,7 @@ where
Self::verify_memory(
&mut preprocessing.read_write_memory,
&preprocessing.generators,
&preprocessing.memory_layout,
proof.read_write_memory,
&commitments,
proof.program_io,
Expand Down Expand Up @@ -657,19 +671,27 @@ where
)
}

#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all)]
fn verify_memory<'a>(
preprocessing: &mut ReadWriteMemoryPreprocessing,
generators: &PCS::Setup,
memory_layout: &MemoryLayout,
proof: ReadWriteMemoryProof<F, PCS, ProofTranscript>,
commitment: &'a JoltCommitments<PCS, ProofTranscript>,
program_io: JoltDevice,
opening_accumulator: &mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>,
transcript: &mut ProofTranscript,
) -> Result<(), ProofVerifyError> {
assert!(program_io.inputs.len() <= program_io.memory_layout.max_input_size as usize);
assert!(program_io.outputs.len() <= program_io.memory_layout.max_output_size as usize);
preprocessing.program_io = Some(program_io);
assert!(program_io.inputs.len() <= memory_layout.max_input_size as usize);
assert!(program_io.outputs.len() <= memory_layout.max_output_size as usize);
// pair the memory layout with the program io from the proof
preprocessing.program_io = Some(JoltDevice {
inputs: program_io.inputs,
outputs: program_io.outputs,
panic: program_io.panic,
memory_layout: memory_layout.clone(),
});

ReadWriteMemoryProof::verify(
proof,
Expand Down Expand Up @@ -701,15 +723,16 @@ where
fn fiat_shamir_preamble(
transcript: &mut ProofTranscript,
program_io: &JoltDevice,
memory_layout: &MemoryLayout,
trace_length: usize,
) {
transcript.append_u64(trace_length as u64);
transcript.append_u64(C as u64);
transcript.append_u64(M as u64);
transcript.append_u64(Self::InstructionSet::COUNT as u64);
transcript.append_u64(Self::Subtables::COUNT as u64);
transcript.append_u64(program_io.memory_layout.max_input_size);
transcript.append_u64(program_io.memory_layout.max_output_size);
transcript.append_u64(memory_layout.max_input_size);
transcript.append_u64(memory_layout.max_output_size);
transcript.append_bytes(&program_io.inputs);
transcript.append_bytes(&program_io.outputs);
transcript.append_u64(program_io.panic as u64);
Expand Down
6 changes: 3 additions & 3 deletions jolt-core/src/jolt/vm/read_write_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ use super::{JoltPolynomials, JoltStuff, JoltTraceStep};
pub struct ReadWriteMemoryPreprocessing {
min_bytecode_address: u64,
pub bytecode_bytes: Vec<u8>,
// HACK: The verifier will populate this field by copying it
// over from the `ReadWriteMemoryProof`. Having `program_io` in
// this preprocessing struct allows the verifier to access it
// HACK: The verifier will populate this field by copying inputs/outputs from the
// `ReadWriteMemoryProof` and the memory layout from preprocessing.
// Having `program_io` in this preprocessing struct allows the verifier to access it
// to compute the v_init and v_final openings, with no impact
// on existing function signatures.
pub program_io: Option<JoltDevice>,
Expand Down
104 changes: 87 additions & 17 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(artifact_guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) =
<RV32IJoltVM as Jolt<F, PCS, C, M, ProofTranscript>>::prove(
io_device,
Expand Down Expand Up @@ -371,8 +377,14 @@ mod tests {
let (bytecode, memory_init) = program.decode();
let (io_device, trace) = program.trace();

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) =
<RV32IJoltVM as Jolt<
_,
Expand Down Expand Up @@ -401,8 +413,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) =
<RV32IJoltVM as Jolt<
_,
Expand Down Expand Up @@ -431,8 +449,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) = <RV32IJoltVM as Jolt<
_,
Zeromorph<Bn254, KeccakTranscript>,
Expand Down Expand Up @@ -462,8 +486,14 @@ mod tests {
let (io_device, trace) = program.trace();
drop(guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (jolt_proof, jolt_commitments, debug_info) = <RV32IJoltVM as Jolt<
_,
HyperKZG<Bn254, KeccakTranscript>,
Expand Down Expand Up @@ -495,8 +525,14 @@ mod tests {
io_device.outputs[0] = 0; // change the output to 0
drop(artifact_guard);

let preprocessing =
RV32IJoltVM::preprocess(bytecode.clone(), memory_init, 1 << 20, 1 << 20, 1 << 20);
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
io_device.memory_layout.clone(),
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) = <RV32IJoltVM as Jolt<
Fr,
HyperKZG<Bn254, KeccakTranscript>,
Expand All @@ -506,12 +542,46 @@ mod tests {
>>::prove(
io_device, trace, preprocessing.clone()
);
let verification_result =
let _verification_result =
RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info);
assert!(
verification_result.is_ok(),
"Verification failed with error: {:?}",
verification_result.err()
}

#[test]
#[should_panic]
fn malicious_trace() {
let artifact_guard = FIB_FILE_LOCK.lock().unwrap();
let mut program = host::Program::new("fibonacci-guest");
program.set_input(&1u8); // change input to 1 so that termination bit equal true
let (bytecode, memory_init) = program.decode();
let (mut io_device, trace) = program.trace();
let memory_layout = io_device.memory_layout.clone();
drop(artifact_guard);

// change memory address of output & termination bit to the same address as input
// changes here should not be able to spoof the verifier result
io_device.memory_layout.output_start = io_device.memory_layout.input_start;
io_device.memory_layout.output_end = io_device.memory_layout.input_end;
io_device.memory_layout.termination = io_device.memory_layout.input_start;

// Since the preprocessing is done with the original memory layout, the verifier should fail
let preprocessing = RV32IJoltVM::preprocess(
bytecode.clone(),
memory_layout,
memory_init,
1 << 20,
1 << 20,
1 << 20,
);
let (proof, commitments, debug_info) = <RV32IJoltVM as Jolt<
Fr,
HyperKZG<Bn254, KeccakTranscript>,
C,
M,
KeccakTranscript,
>>::prove(
io_device, trace, preprocessing.clone()
);
let _verification_result =
RV32IJoltVM::verify(preprocessing, proof, commitments, debug_info);
}
}
6 changes: 6 additions & 0 deletions jolt-sdk/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ impl MacroBuilder {
}

fn make_preprocess_func(&self) -> TokenStream2 {
let attributes = parse_attributes(&self.attr);
let max_input_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_input_size);
let max_output_size = proc_macro2::Literal::u64_unsuffixed(attributes.max_output_size);
let set_mem_size = self.make_set_linker_parameters();
let guest_name = self.get_guest_name();
let imports = self.make_imports();
Expand All @@ -199,11 +202,13 @@ impl MacroBuilder {
#set_std
#set_mem_size
let (bytecode, memory_init) = program.decode();
let memory_layout = MemoryLayout::new(#max_input_size, #max_output_size);

// TODO(moodlezoup): Feed in size parameters via macro
let preprocessing: JoltPreprocessing<4, jolt::F, jolt::PCS, jolt::ProofTranscript> =
RV32IJoltVM::preprocess(
bytecode,
memory_layout,
memory_init,
1 << 20,
1 << 20,
Expand Down Expand Up @@ -409,6 +414,7 @@ impl MacroBuilder {
RV32IJoltProof,
BytecodeRow,
MemoryOp,
MemoryLayout,
MEMORY_OPS_PER_INSTRUCTION,
instruction::add::ADDInstruction,
tracer,
Expand Down
2 changes: 1 addition & 1 deletion jolt-sdk/src/host_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub use jolt_core::{field::JoltField, poly::commitment::hyperkzg::HyperKZG};

pub use common::{
constants::MEMORY_OPS_PER_INSTRUCTION,
rv_trace::{MemoryOp, RV32IM},
rv_trace::{MemoryLayout, MemoryOp, RV32IM},
};
pub use jolt_core::host;
pub use jolt_core::jolt::instruction;
Expand Down

0 comments on commit dbb76f3

Please sign in to comment.