Skip to content

Commit

Permalink
Custom commits (#105)
Browse files Browse the repository at this point in the history
* Custom commits working
  • Loading branch information
RogerTaule authored Nov 8, 2024
1 parent cb18246 commit 3fd2d24
Show file tree
Hide file tree
Showing 41 changed files with 1,061 additions and 278 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ jobs:
with:
repository: 0xPolygonHermez/pil2-proofman-js
token: ${{ secrets.ZISK_CI_TOKEN }}
ref: 0.0.6
ref: 0.0.7
path: pil2-proofman-js

- name: Install pil2-proofman-js dependencies
Expand All @@ -147,7 +147,7 @@ jobs:
with:
repository: 0xPolygonHermez/pil2-compiler
token: ${{ secrets.GITHUB_TOKEN }}
ref: develop
ref: feature/custom-cols
path: pil2-compiler

- name: Install pil2-compiler dependencies
Expand Down
8 changes: 7 additions & 1 deletion cli/assets/templates/pil_helpers_trace.rs.tt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,10 @@ pub use proofman_macros::trace;
trace!({ air.name }Row, { air.name }Trace<F> \{
{{ for column in air.columns }} { column.name }: { column.type },{{ endfor }}
});
{{ endfor }}{{ endfor }}
{{ endfor }}{{ endfor }}

{{ for air_group in air_groups }}{{ for air in air_group.airs }}{{ for custom_commit in air.custom_columns }}
trace!({ air.name }{custom_commit.name}Row, { air.name }{custom_commit.name}Trace<F> \{
{{ for column in custom_commit.custom_columns }} { column.name }: { column.type },{{ endfor }}
});
{{ endfor }}{{ endfor }}{{ endfor }}
29 changes: 26 additions & 3 deletions cli/src/commands/pil_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,14 @@ struct AirCtx {
name: String,
num_rows: u32,
columns: Vec<ColumnCtx>,
custom_columns: Vec<CustomCommitsCtx>,
}

#[derive(Debug, Serialize)]
struct CustomCommitsCtx {
name: String,
custom_columns: Vec<ColumnCtx>,
}
#[derive(Debug, Serialize)]
struct ColumnCtx {
name: String,
Expand Down Expand Up @@ -111,6 +117,7 @@ impl PilHelpersCmd {
name: air.name.as_ref().unwrap().clone(),
num_rows: air.num_rows.unwrap(),
columns: Vec::new(),
custom_columns: Vec::new(),
})
.collect(),
});
Expand Down Expand Up @@ -139,6 +146,16 @@ impl PilHelpersCmd {
// Build columns data for traces
for (airgroup_id, airgroup) in pilout.air_groups.iter().enumerate() {
for (air_id, _) in airgroup.airs.iter().enumerate() {
let air = wcctxs[airgroup_id].airs.get_mut(air_id).unwrap();
air.custom_columns = pilout.air_groups[airgroup_id].airs[air_id]
.custom_commits
.iter()
.map(|commit| CustomCommitsCtx {
name: commit.name.clone().unwrap().to_case(Case::Pascal),
custom_columns: Vec::new(),
})
.collect();

// Search symbols where airgroup_id == airgroup_id && air_id == air_id && type == WitnessCol
pilout
.symbols
Expand All @@ -149,8 +166,8 @@ impl PilHelpersCmd {
&& symbol.air_id.is_some()
&& symbol.air_id.unwrap() == air_id as u32
&& symbol.stage.is_some()
&& symbol.stage.unwrap() == 1
&& symbol.r#type == SymbolType::WitnessCol as i32
&& ((symbol.r#type == SymbolType::WitnessCol as i32 && symbol.stage.unwrap() == 1)
|| (symbol.r#type == SymbolType::CustomCol as i32 && symbol.stage.unwrap() == 0))
})
.for_each(|symbol| {
let air = wcctxs[airgroup_id].airs.get_mut(air_id).unwrap();
Expand All @@ -165,7 +182,13 @@ impl PilHelpersCmd {
.rev()
.fold("F".to_string(), |acc, &length| format!("[{}; {}]", acc, length))
};
air.columns.push(ColumnCtx { name: name.to_owned(), r#type });
if symbol.r#type == SymbolType::WitnessCol as i32 {
air.columns.push(ColumnCtx { name: name.to_owned(), r#type });
} else {
air.custom_columns[symbol.commit_id.unwrap() as usize]
.custom_columns
.push(ColumnCtx { name: name.to_owned(), r#type });
}
});
}
}
Expand Down
31 changes: 31 additions & 0 deletions common/src/air_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{collections::HashMap, os::raw::c_void, sync::Arc};
use p3_field::Field;
use proofman_starks_lib_c::{
get_airval_id_by_name_c, get_n_airgroupvals_c, get_n_airvals_c, get_n_evals_c, get_airgroupval_id_by_name_c,
get_n_custom_commits_c,
};

use crate::SetupCtx;
Expand All @@ -18,6 +19,7 @@ pub struct StepsParams {
pub xdivxsub: *mut c_void,
pub p_const_pols: *mut c_void,
pub p_const_tree: *mut c_void,
pub custom_commits: [*mut c_void; 10],
}

impl From<&StepsParams> for *mut c_void {
Expand All @@ -37,12 +39,14 @@ pub struct AirInstance<F> {
pub idx: Option<usize>,
pub global_idx: Option<usize>,
pub buffer: Vec<F>,
pub custom_commits: Vec<Vec<F>>,
pub airgroup_values: Vec<F>,
pub airvalues: Vec<F>,
pub evals: Vec<F>,
pub commits_calculated: HashMap<usize, bool>,
pub airgroupvalue_calculated: HashMap<usize, bool>,
pub airvalue_calculated: HashMap<usize, bool>,
pub custom_commits_calculated: Vec<HashMap<usize, bool>>,
}

impl<F: Field> AirInstance<F> {
Expand All @@ -55,6 +59,15 @@ impl<F: Field> AirInstance<F> {
) -> Self {
let ps = setup_ctx.get_setup(airgroup_id, air_id);

let custom_commits_calculated = vec![HashMap::new(); get_n_custom_commits_c(ps.p_setup.p_stark_info) as usize];

let mut custom_commits = Vec::new();

let n_custom_commits = get_n_custom_commits_c(ps.p_setup.p_stark_info);
for _ in 0..n_custom_commits {
custom_commits.push(Vec::new());
}

AirInstance {
airgroup_id,
air_id,
Expand All @@ -63,19 +76,33 @@ impl<F: Field> AirInstance<F> {
idx: None,
global_idx: None,
buffer,
custom_commits,
airgroup_values: vec![F::zero(); get_n_airgroupvals_c(ps.p_setup.p_stark_info) as usize * 3],
airvalues: vec![F::zero(); get_n_airvals_c(ps.p_setup.p_stark_info) as usize * 3],
evals: vec![F::zero(); get_n_evals_c(ps.p_setup.p_stark_info) as usize * 3],
commits_calculated: HashMap::new(),
airgroupvalue_calculated: HashMap::new(),
airvalue_calculated: HashMap::new(),
custom_commits_calculated,
}
}

pub fn get_buffer_ptr(&self) -> *mut u8 {
self.buffer.as_ptr() as *mut u8
}

pub fn get_custom_commits_ptr(&self) -> [*mut c_void; 10] {
let mut ptrs = [std::ptr::null_mut(); 10];
for (i, custom_commit) in self.custom_commits.iter().enumerate() {
ptrs[i] = custom_commit.as_ptr() as *mut c_void;
}
ptrs
}

pub fn set_custom_commit_id_buffer(&mut self, buffer: Vec<F>, commit_id: u64) {
self.custom_commits[commit_id as usize] = buffer;
}

pub fn set_airvalue(&mut self, setup_ctx: &SetupCtx<F>, name: &str, value: F) {
let ps = setup_ctx.get_setup(self.airgroup_id, self.air_id);

Expand Down Expand Up @@ -142,6 +169,10 @@ impl<F: Field> AirInstance<F> {
self.commits_calculated.insert(id, true);
}

pub fn set_custom_commit_calculated(&mut self, commit_id: usize, id: usize) {
self.custom_commits_calculated[commit_id].insert(id, true);
}

pub fn set_air_instance_id(&mut self, air_instance_id: usize, idx: usize) {
self.air_instance_id = Some(air_instance_id);
self.idx = Some(idx);
Expand Down
8 changes: 8 additions & 0 deletions common/src/buffer_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,12 @@ pub trait BufferAllocator<F>: Send + Sync {
airgroup_id: usize,
air_id: usize,
) -> Result<(u64, Vec<u64>), Box<dyn Error>>;

fn get_buffer_info_custom_commit(
&self,
sctx: &SetupCtx<F>,
airgroup_id: usize,
air_id: usize,
custom_commit_name: &str,
) -> Result<(u64, Vec<u64>, u64), Box<dyn Error>>;
}
11 changes: 11 additions & 0 deletions common/src/global_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ pub struct ProofValueMap {
#[serde(default)]
pub id: u64,
}
#[derive(Clone, Deserialize)]
pub struct PublicMap {
pub name: String,
#[serde(default)]
pub stage: u64,
#[serde(default)]
pub lengths: Vec<u64>,
}

#[derive(Clone, Deserialize)]
pub struct GlobalInfo {
Expand All @@ -35,6 +43,9 @@ pub struct GlobalInfo {

#[serde(rename = "proofValuesMap")]
pub proof_values_map: Option<Vec<ProofValueMap>>,

#[serde(rename = "publicsMap")]
pub publics_map: Option<Vec<PublicMap>>,
}

#[derive(Clone, Deserialize)]
Expand Down
51 changes: 48 additions & 3 deletions common/src/proof_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@ use crate::{AirInstancesRepository, GlobalInfo, VerboseMode, WitnessPilout};

pub struct PublicInputs {
pub inputs: RwLock<Vec<u8>>,
pub inputs_set: RwLock<Vec<bool>>,
}

impl Default for PublicInputs {
fn default() -> Self {
Self { inputs: RwLock::new(Vec::new()) }
Self { inputs: RwLock::new(Vec::new()), inputs_set: RwLock::new(Vec::new()) }
}
}

impl PublicInputs {
pub fn new(n_publics: usize) -> Self {
Self {
inputs: RwLock::new(vec![0; n_publics * std::mem::size_of::<u64>()]),
inputs_set: RwLock::new(vec![false; n_publics]),
}
}
}

Expand Down Expand Up @@ -83,11 +93,11 @@ impl<F: Field> ProofCtx<F> {
values: RwLock::new(vec![F::zero(); global_info.n_proof_values * 3]),
values_set: RwLock::new(HashMap::new()),
};

let n_publics = global_info.n_publics;
Self {
pilout,
global_info,
public_inputs: PublicInputs::default(),
public_inputs: PublicInputs::new(n_publics),
proof_values,
challenges: Challenges::default(),
buff_helper: BuffHelper::default(),
Expand Down Expand Up @@ -146,4 +156,39 @@ impl<F: Field> ProofCtx<F> {
pub fn set_proof_value_calculated(&self, id: usize) {
self.proof_values.values_set.write().unwrap().insert(id, true);
}

pub fn set_public_value(&self, value: u64, public_id: u64) {
self.public_inputs.inputs.write().unwrap()[(public_id as usize) * 8..(public_id as usize + 1) * 8]
.copy_from_slice(&value.to_le_bytes());

self.public_inputs.inputs_set.write().unwrap()[public_id as usize] = true;
}

pub fn set_public_value_by_name(&self, value: u64, public_name: &str) {
let n_publics: usize = self.global_info.publics_map.as_ref().expect("REASON").len();
let public_id = (0..n_publics)
.find(|&i| {
let public = self.global_info.publics_map.as_ref().expect("REASON").get(i).unwrap();
public.name == public_name
})
.unwrap_or_else(|| panic!("Name {} not found in publics_map", public_name));

self.set_public_value(value, public_id as u64);
}

pub fn get_public_value(&self, public_name: &str) -> u64 {
let n_publics: usize = self.global_info.publics_map.as_ref().expect("REASON").len();
let public_id = (0..n_publics)
.find(|&i| {
let public = self.global_info.publics_map.as_ref().expect("REASON").get(i).unwrap();
public.name == public_name
})
.unwrap_or_else(|| panic!("Name {} not found in publics_map", public_name));

u64::from_le_bytes(
self.public_inputs.inputs.read().unwrap()[public_id * 8..(public_id + 1) * 8]
.try_into()
.expect("Expected 8 bytes for u64"),
)
}
}
1 change: 1 addition & 0 deletions common/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ pub trait Prover<F: Field> {
fn num_stages(&self) -> u32;
fn get_challenges(&self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>, transcript: &FFITranscript);
fn calculate_stage(&mut self, stage_id: u32, setup_ctx: Arc<SetupCtx<F>>, proof_ctx: Arc<ProofCtx<F>>);
fn check_stage(&self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>);
fn commit_stage(&mut self, stage_id: u32, proof_ctx: Arc<ProofCtx<F>>) -> ProverStatus;
fn calculate_xdivxsub(&mut self, proof_ctx: Arc<ProofCtx<F>>);
fn calculate_lev(&mut self, proof_ctx: Arc<ProofCtx<F>>);
Expand Down
22 changes: 11 additions & 11 deletions common/src/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ impl From<&SetupC> for *mut c_void {
}

#[derive(Debug)]
pub struct ConstPols<F> {
pub const_pols: RwLock<Vec<MaybeUninit<F>>>,
pub struct Pols<F> {
pub values: RwLock<Vec<MaybeUninit<F>>>,
}

impl<F> Default for ConstPols<F> {
impl<F> Default for Pols<F> {
fn default() -> Self {
Self { const_pols: RwLock::new(Vec::new()) }
Self { values: RwLock::new(Vec::new()) }
}
}

Expand All @@ -47,8 +47,8 @@ pub struct Setup<F> {
pub airgroup_id: usize,
pub air_id: usize,
pub p_setup: SetupC,
pub const_pols: ConstPols<F>,
pub const_tree: ConstPols<F>,
pub const_pols: Pols<F>,
pub const_tree: Pols<F>,
}

impl<F> Setup<F> {
Expand Down Expand Up @@ -80,8 +80,8 @@ impl<F> Setup<F> {
air_id,
airgroup_id,
p_setup: SetupC { p_stark_info, p_expressions_bin, p_prover_helpers },
const_pols: ConstPols::default(),
const_tree: ConstPols::default(),
const_pols: Pols::default(),
const_tree: Pols::default(),
}
}

Expand Down Expand Up @@ -109,7 +109,7 @@ impl<F> Setup<F> {

let p_const_pols_address = const_pols.as_ptr() as *mut c_void;
load_const_pols_c(p_const_pols_address, const_pols_path.as_str(), const_size as u64);
*self.const_pols.const_pols.write().unwrap() = const_pols;
*self.const_pols.values.write().unwrap() = const_pols;
}

pub fn load_const_pols_tree(&self, global_info: &GlobalInfo, setup_type: &ProofType, save_file: bool) {
Expand All @@ -132,11 +132,11 @@ impl<F> Setup<F> {
if PathBuf::from(&const_pols_tree_path).exists() {
load_const_tree_c(p_const_tree_address, const_pols_tree_path.as_str(), const_tree_size as u64);
} else {
let const_pols = self.const_pols.const_pols.read().unwrap();
let const_pols = self.const_pols.values.read().unwrap();
let p_const_pols_address = (*const_pols).as_ptr() as *mut c_void;
let tree_filename = if save_file { const_pols_tree_path.as_str() } else { "" };
calculate_const_tree_c(p_stark_info, p_const_pols_address, p_const_tree_address, tree_filename);
};
*self.const_tree.const_pols.write().unwrap() = const_tree;
*self.const_tree.values.write().unwrap() = const_tree;
}
}
Loading

0 comments on commit 3fd2d24

Please sign in to comment.