Skip to content

Commit

Permalink
Allow num_inputs to be different per witness_sec
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Dec 30, 2024
1 parent 2f2ecd7 commit fc385b6
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 454 deletions.
297 changes: 136 additions & 161 deletions spartan_parallel/src/custom_dense_mlpoly.rs

Large diffs are not rendered by default.

123 changes: 53 additions & 70 deletions spartan_parallel/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub struct Instance<S: SpartanExtensionField> {
pub digest: Vec<u8>,
}

impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
impl<S: SpartanExtensionField> Instance<S> {
/// Constructs a new `Instance` and an associated satisfying assignment
pub fn new(
num_instances: usize,
Expand All @@ -38,6 +38,8 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
B: &Vec<Vec<(usize, usize, [u8; 32])>>,
C: &Vec<Vec<(usize, usize, [u8; 32])>>,
) -> Result<Instance<S>, R1CSError> {
let ZERO = S::field_zero();

let (max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded) = {
let max_num_vars_padded = {
let mut max_num_vars_padded = max_num_vars;
Expand Down Expand Up @@ -82,12 +84,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
}
}

(
max_num_vars_padded,
num_vars_padded,
max_num_cons_padded,
num_cons_padded,
)
(max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded)
};

let bytes_to_scalar =
Expand Down Expand Up @@ -124,7 +121,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
// we do not need to pad otherwise because the dummy constraints are implicit in the sum-check protocol
if num_cons[b] == 0 || num_cons[b] == 1 {
for i in tups.len()..num_cons_padded[b] {
mat.push((i, num_vars[b], S::field_zero()));
mat.push((i, num_vars[b], ZERO));
}
}

Expand Down Expand Up @@ -245,10 +242,10 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
/// Verify the correctness of each block execution, as well as extracting all memory operations
///
/// Input composition: (if every segment exists)
/// INPUT + VAR Challenges BLOCK_W2 BLOCK_W3 BLOCK_W3_SHIFTED
/// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
/// v i0 ... PA0 PD0 ... VA0 VD0 ... | tau r r^2 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | v x pi D pi D pi D | v x pi D pi D pi D
/// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR
/// INPUT + VAR BLOCK_W2 Challenges BLOCK_W3 BLOCK_W3_SHIFTED
/// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
/// v i0 ... PA0 PD0 ... VA0 VD0 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | tau r r^2 ... | v x pi D pi D pi D | v x pi D pi D pi D
/// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR
///
/// VAR:
/// We assume that the witnesses are of the following format:
Expand All @@ -271,7 +268,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
/// - VMR3 = r^3 * VT
/// - VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3)
/// The final product is stored in X = MC[NV - 1]
///
///
/// If in COMMIT_MODE, commit instance by num_vars_per_block, rounded to the nearest power of four
pub fn gen_block_inst<const PRINT_SIZE: bool, const COMMIT_MODE: bool>(
num_instances: usize,
Expand Down Expand Up @@ -306,20 +303,12 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
max_size_per_group.insert(next_group_size(*num_vars), num_vars.next_power_of_two());
}
}
num_vars_per_block
.iter()
.map(|i| {
max_size_per_group
.get(&next_group_size(*i))
.unwrap()
.clone()
})
.collect()
num_vars_per_block.iter().map(|i| max_size_per_group.get(&next_group_size(*i)).unwrap().clone()).collect()
} else {
vec![num_vars; num_instances]
};

if PRINT_SIZE {
if PRINT_SIZE && !COMMIT_MODE {
println!("\n\n--\nBLOCK INSTS");
println!(
"{:10} {:>4} {:>4} {:>4} {:>4}",
Expand Down Expand Up @@ -348,37 +337,30 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
let V_VD = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 1;
let V_VL = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 2;
let V_VT = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 3;
// in CHALLENGES, not used if !has_mem_op
let V_tau = |b: usize| num_vars_padded_per_block[b];
let V_r = |b: usize, i: usize| num_vars_padded_per_block[b] + i;
// in BLOCK_W2 / INPUT_W2
let V_input_dot_prod = |b: usize, i: usize| {
if i == 0 {
V_input(0)
} else {
2 * num_vars_padded_per_block[b] + 2 + i
num_vars_padded_per_block[b] + 2 + i
}
};
let V_output_dot_prod =
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i;
let V_output_dot_prod = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i;
// in BLOCK_W2 / PHY_W2
let V_PMR =
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i;
let V_PMC =
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1;
let V_PMR = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i;
let V_PMC = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1;
// in BLOCK_W2 / VIR_W2
let V_VMR1 = |b: usize, i: usize| {
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i
};
let V_VMR2 = |b: usize, i: usize| {
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1
};
let V_VMR3 = |b: usize, i: usize| {
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2
};
let V_VMC = |b: usize, i: usize| {
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3
};
let V_VMR1 =
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i;
let V_VMR2 =
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1;
let V_VMR3 =
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2;
let V_VMC =
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3;
// in CHALLENGES, not used if !has_mem_op
let V_tau = |b: usize| 2 * num_vars_padded_per_block[b];
let V_r = |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + i;
// in BLOCK_W3
let V_v = |b: usize| 3 * num_vars_padded_per_block[b];
let V_x = |b: usize| 3 * num_vars_padded_per_block[b] + 1;
Expand Down Expand Up @@ -703,7 +685,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
B_list.push(B);
C_list.push(C);

if PRINT_SIZE {
if PRINT_SIZE && !COMMIT_MODE {
let max_nnz = max(tmp_nnz_A, max(tmp_nnz_B, tmp_nnz_C));
let total_var = num_vars_per_block[b]
+ 2 * num_inputs_unpadded.next_power_of_two()
Expand All @@ -724,7 +706,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
}
}

if PRINT_SIZE {
if PRINT_SIZE && !COMMIT_MODE {
println!("Total Num of Blocks: {}", num_instances);
println!("Total Inst Commit Size: {}", total_inst_commit_size);
println!("Total Var Commit Size: {}", total_var_commit_size);
Expand All @@ -744,10 +726,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
max_cons_per_group.insert(num_vars_padded_per_block[i], block_num_cons[i]);
}
}
num_vars_padded_per_block
.iter()
.map(|i| max_cons_per_group.get(i).unwrap().clone())
.collect()
num_vars_padded_per_block.iter().map(|i| max_cons_per_group.get(i).unwrap().clone()).collect()
} else {
block_num_cons
}
Expand All @@ -759,10 +738,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
block_max_num_cons,
num_cons_padded_per_block,
block_num_vars,
num_vars_padded_per_block
.into_iter()
.map(|i| 8 * i)
.collect(),
num_vars_padded_per_block.into_iter().map(|i| 8 * i).collect(),
&A_list,
&B_list,
&C_list,
Expand Down Expand Up @@ -816,9 +792,14 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
/// D2 = D1 * (ls[i+1] - STORE)
/// Where STORE = 0
/// Input composition:
/// Op[k] Op[k + 1] D2 & bits of ts[k + 1] - ts[k]
/// 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4
/// v D1 a d ls ts _ _ | v D1 a d ls ts _ _ | D2 EQ B0 B1 ...
/// bits of ts[k + 1] - ts[k] Op[k] Op[k + 1]
/// 0 1 2 3 4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
/// D2 EQ B0 B1 ... | v D1 a d ls ts _ _ | v D1 a d ls ts _ _
///
/// If ADDR_NONCONSEC, address comparison of VIR uses <= instead of +1, with the following expression
/// ts | addr
/// 0 1 2 3 4 | 0 1 2 3 4 5
/// D2 EQ B0 B1 ... | D4 INV EQ B0 B1 ...
pub fn gen_pairwise_check_inst<const PRINT_SIZE: bool>(
max_ts_width: usize,
mem_addr_ts_bits_size: usize,
Expand All @@ -834,14 +815,14 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
"", "con", "var", "nnz", "exec"
);
}

// Variable used by printing
let mut total_inst_commit_size = 0;
let mut total_var_commit_size = 0;
let mut total_cons_exec_size = 0;

let pairwise_check_num_vars = max(8, mem_addr_ts_bits_size);
let pairwise_check_max_num_cons = 8 + max_ts_width;
let pairwise_check_num_cons = vec![2, 4, 8 + max_ts_width];
let pairwise_check_num_non_zero_entries: usize = max(13 + max_ts_width, 5 + 2 * max_ts_width);

let pairwise_check_inst = {
Expand Down Expand Up @@ -972,23 +953,24 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
let (A, B, C) = {
let width = pairwise_check_num_vars;

let V_valid = 0;
// TS_BITS
let V_D2 = 0;
let V_EQ = 1;
let V_B = |i| 2 + i;
// OP[K], OP[K + 1]
let V_valid = width;
let V_cnst = V_valid;
let V_D1 = 1;
let V_addr = 2;
let V_data = 3;
let V_ls = 4;
let V_ts = 5;
let V_D2 = 2 * width;
let V_EQ = 2 * width + 1;
let V_B = |i| 2 * width + 2 + i;
let V_D1 = width + 1;
let V_addr = width + 2;
let V_data = width + 3;
let V_ls = width + 4;
let V_ts = width + 5;

let mut A: Vec<(usize, usize, [u8; 32])> = Vec::new();
let mut B: Vec<(usize, usize, [u8; 32])> = Vec::new();
let mut C: Vec<(usize, usize, [u8; 32])> = Vec::new();

let mut num_cons = 0;
// Sortedness
// (v[k] - 1) * v[k + 1] = 0
(A, B, C) = Instance::<S>::gen_constr(
A,
Expand All @@ -1000,6 +982,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
vec![],
);
num_cons += 1;
// Sortedness
// D1[k] = v[k + 1] * (1 - addr[k + 1] + addr[k])
(A, B, C) = Instance::<S>::gen_constr(
A,
Expand Down Expand Up @@ -1403,4 +1386,4 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
perm_root_inst,
)
}
}
}
Loading

0 comments on commit fc385b6

Please sign in to comment.