Skip to content

Commit

Permalink
feat: retrieve prover input per block (#499)
Browse files Browse the repository at this point in the history
* feat: retrieve prover input per block

* fix: cleanup

* fix: into implementation

* fix: nitpick

* fix: review

* fix: review and cleanup
  • Loading branch information
atanmarko authored Aug 22, 2024
1 parent 67dbf7a commit 6bcf06b
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 191 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 42 additions & 19 deletions zero_bin/leader/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;

use alloy::rpc::types::{BlockId, BlockNumberOrTag, BlockTransactionsKind};
use alloy::transports::http::reqwest::Url;
use anyhow::Result;
use paladin::runtime::Runtime;
Expand Down Expand Up @@ -34,31 +36,52 @@ pub(crate) async fn client_main(
block_interval: BlockInterval,
mut params: ProofParams,
) -> Result<()> {
let cached_provider = rpc::provider::CachedProvider::new(build_http_retry_provider(
rpc_params.rpc_url.clone(),
rpc_params.backoff,
rpc_params.max_retries,
use futures::{FutureExt, StreamExt};

let cached_provider = Arc::new(rpc::provider::CachedProvider::new(
build_http_retry_provider(
rpc_params.rpc_url.clone(),
rpc_params.backoff,
rpc_params.max_retries,
),
));

let prover_input = rpc::prover_input(
&cached_provider,
block_interval,
params.checkpoint_block_number.into(),
rpc_params.rpc_type,
)
.await?;
// Grab interval checkpoint block state trie
let checkpoint_state_trie_root = cached_provider
.get_block(
params.checkpoint_block_number.into(),
BlockTransactionsKind::Hashes,
)
.await?
.header
.state_root;

let mut block_prover_inputs = Vec::new();
let mut block_interval = block_interval.into_bounded_stream()?;
while let Some(block_num) = block_interval.next().await {
let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
// Get future of prover input for particular block.
let block_prover_input = rpc::block_prover_input(
cached_provider.clone(),
block_id,
checkpoint_state_trie_root,
rpc_params.rpc_type,
)
.boxed();
block_prover_inputs.push(block_prover_input);
}

// If `keep_intermediate_proofs` is not set we only keep the last block
// proof from the interval. It contains all the necessary information to
// verify the whole sequence.
let proved_blocks = prover_input
.prove(
&runtime,
params.previous_proof.take(),
params.save_inputs_on_error,
params.proof_output_dir.clone(),
)
.await;
let proved_blocks = prover::prove(
block_prover_inputs,
&runtime,
params.previous_proof.take(),
params.save_inputs_on_error,
params.proof_output_dir.clone(),
)
.await;
runtime.close().await?;
let proved_blocks = proved_blocks?;

Expand Down
20 changes: 13 additions & 7 deletions zero_bin/leader/src/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::io::{Read, Write};
use anyhow::Result;
use paladin::runtime::Runtime;
use proof_gen::proof_types::GeneratedBlockProof;
use prover::ProverInput;
use prover::{BlockProverInput, BlockProverInputFuture};
use tracing::info;

/// The main function for the stdio mode.
Expand All @@ -16,13 +16,19 @@ pub(crate) async fn stdio_main(
std::io::stdin().read_to_string(&mut buffer)?;

let des = &mut serde_json::Deserializer::from_str(&buffer);
let prover_input = ProverInput {
blocks: serde_path_to_error::deserialize(des)?,
};
let block_prover_inputs = serde_path_to_error::deserialize::<_, Vec<BlockProverInput>>(des)?
.into_iter()
.map(Into::into)
.collect::<Vec<BlockProverInputFuture>>();

let proved_blocks = prover_input
.prove(&runtime, previous, save_inputs_on_error, None)
.await;
let proved_blocks = prover::prove(
block_prover_inputs,
&runtime,
previous,
save_inputs_on_error,
None,
)
.await;
runtime.close().await?;
let proved_blocks = proved_blocks?;

Expand Down
165 changes: 82 additions & 83 deletions zero_bin/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,20 @@ use trace_decoder::{BlockTrace, OtherBlockData};
use tracing::info;
use zero_bin_common::fs::generate_block_proof_file_name;

#[derive(Debug, Deserialize, Serialize)]
pub type BlockProverInputFuture = std::pin::Pin<
Box<dyn Future<Output = std::result::Result<BlockProverInput, anyhow::Error>> + Send>,
>;

impl From<BlockProverInput> for BlockProverInputFuture {
fn from(item: BlockProverInput) -> Self {
async fn _from(item: BlockProverInput) -> Result<BlockProverInput, anyhow::Error> {
Ok(item)
}
Box::pin(_from(item))
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BlockProverInput {
pub block_trace: BlockTrace,
pub other_data: OtherBlockData,
Expand Down Expand Up @@ -113,91 +126,77 @@ impl BlockProverInput {
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct ProverInput {
pub blocks: Vec<BlockProverInput>,
/// Prove all the blocks in the input.
/// Return the list of block numbers that are proved and if the proof data
/// is not saved to disk, return the generated block proofs as well.
pub async fn prove(
block_prover_inputs: Vec<BlockProverInputFuture>,
runtime: &Runtime,
previous_proof: Option<GeneratedBlockProof>,
save_inputs_on_error: bool,
proof_output_dir: Option<PathBuf>,
) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);

let mut results = FuturesOrdered::new();
for block_prover_input in block_prover_inputs {
let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();
let proof_output_dir = proof_output_dir.clone();
let previos_block_proof = prev.take();
let fut = async move {
// Get the prover input data from the external source (e.g. Erigon node).
let block = block_prover_input.await?;
let block_number = block.get_block_number();
info!("Proving block {block_number}");

// Prove the block
let block_proof = block
.prove(runtime, previos_block_proof, save_inputs_on_error)
.then(move |proof| async move {
let proof = proof?;
let block_number = proof.b_height;

// Write latest generated proof to disk if proof_output_dir is provided
// or alternatively return proof as function result.
let return_proof: Option<GeneratedBlockProof> =
if let Some(output_dir) = proof_output_dir {
write_proof_to_dir(output_dir, &proof).await?;
None
} else {
Some(proof.clone())
};

if tx.send(proof).is_err() {
anyhow::bail!("Failed to send proof");
}

Ok((block_number, return_proof))
})
.await?;

Ok(block_proof)
}
.boxed();
prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));
results.push_back(fut);
}

results.try_collect().await
}

impl ProverInput {
/// Prove all the blocks in the input.
/// Return the list of block numbers that are proved and if the proof data
/// is not saved to disk, return the generated block proofs as well.
pub async fn prove(
self,
runtime: &Runtime,
previous_proof: Option<GeneratedBlockProof>,
save_inputs_on_error: bool,
proof_output_dir: Option<PathBuf>,
) -> Result<Vec<(BlockNumber, Option<GeneratedBlockProof>)>> {
let mut prev: Option<BoxFuture<Result<GeneratedBlockProof>>> =
previous_proof.map(|proof| Box::pin(futures::future::ok(proof)) as BoxFuture<_>);

let results: FuturesOrdered<_> = self
.blocks
.into_iter()
.map(|block| {
let block_number = block.get_block_number();
info!("Proving block {block_number}");

let (tx, rx) = oneshot::channel::<GeneratedBlockProof>();

// Prove the block
let proof_output_dir = proof_output_dir.clone();
let fut = block
.prove(runtime, prev.take(), save_inputs_on_error)
.then(move |proof| async move {
let proof = proof?;
let block_number = proof.b_height;

// Write latest generated proof to disk if proof_output_dir is provided
let return_proof: Option<GeneratedBlockProof> =
if proof_output_dir.is_some() {
ProverInput::write_proof(proof_output_dir, &proof).await?;
None
} else {
Some(proof.clone())
};

if tx.send(proof).is_err() {
anyhow::bail!("Failed to send proof");
}

Ok((block_number, return_proof))
})
.boxed();

prev = Some(Box::pin(rx.map_err(anyhow::Error::new)));

fut
})
.collect();
/// Write the proof to the `output_dir` directory.
async fn write_proof_to_dir(output_dir: PathBuf, proof: &GeneratedBlockProof) -> Result<()> {
let proof_serialized = serde_json::to_vec(proof)?;
let block_proof_file_path =
generate_block_proof_file_name(&output_dir.to_str(), proof.b_height);

results.try_collect().await
if let Some(parent) = block_proof_file_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}

/// Write the proof to the disk (if `output_dir` is provided) or stdout.
pub(crate) async fn write_proof(
output_dir: Option<PathBuf>,
proof: &GeneratedBlockProof,
) -> Result<()> {
let proof_serialized = serde_json::to_vec(proof)?;
let block_proof_file_path =
output_dir.map(|path| generate_block_proof_file_name(&path.to_str(), proof.b_height));
match block_proof_file_path {
Some(p) => {
if let Some(parent) = p.parent() {
tokio::fs::create_dir_all(parent).await?;
}

let mut f = tokio::fs::File::create(p).await?;
f.write_all(&proof_serialized)
.await
.context("Failed to write proof to disk")
}
None => tokio::io::stdout()
.write_all(&proof_serialized)
.await
.context("Failed to write proof to stdout"),
}
}
let mut f = tokio::fs::File::create(block_proof_file_path).await?;
f.write_all(&proof_serialized)
.await
.context("Failed to write proof to disk")
}
1 change: 1 addition & 0 deletions zero_bin/rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tower = { workspace = true, features = ["retry"] }
trace_decoder = { workspace = true }
tracing-subscriber = { workspace = true }
url = { workspace = true }
itertools = {workspace = true}

# Local dependencies
compat = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion zero_bin/rpc/src/jerigon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct ZeroTxResult {
}

pub async fn block_prover_input<ProviderT, TransportT>(
cached_provider: &CachedProvider<ProviderT, TransportT>,
cached_provider: std::sync::Arc<CachedProvider<ProviderT, TransportT>>,
target_block_id: BlockId,
checkpoint_state_trie_root: B256,
) -> anyhow::Result<BlockProverInput>
Expand Down
Loading

0 comments on commit 6bcf06b

Please sign in to comment.