Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: revamp zero prove function #793

Merged
merged 8 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/prove_stdio.sh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ cargo build --release --jobs "$num_procs"
start_time=$(date +%s%N)

cmd=("${REPO_ROOT}/target/release/leader" --runtime in-memory \
--load-strategy on-demand -n 1 \
--load-strategy on-demand \
--block-batch-size "$BLOCK_BATCH_SIZE")

if [[ "$USE_TEST_CONFIG" == "use_test_config" ]]; then
Expand Down
6 changes: 3 additions & 3 deletions zero/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{debug_utils::save_inputs_to_disk, prover_state::p_state};

registry!();

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct SegmentProof {
pub save_inputs_on_error: bool,
}
Expand Down Expand Up @@ -207,7 +207,7 @@ impl Drop for SegmentProofSpan {
}
}

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct SegmentAggProof {
pub save_inputs_on_error: bool,
}
Expand Down Expand Up @@ -289,7 +289,7 @@ impl Monoid for SegmentAggProof {
}
}

#[derive(Deserialize, Serialize, RemoteExecute)]
#[derive(Deserialize, Serialize, RemoteExecute, Clone)]
pub struct BatchAggProof {
pub save_inputs_on_error: bool,
}
Expand Down
198 changes: 175 additions & 23 deletions zero/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use anyhow::{Context, Result};
use evm_arithmetization::Field;
use evm_arithmetization::SegmentDataIterator;
use futures::{
future, future::BoxFuture, stream::FuturesUnordered, FutureExt, TryFutureExt, TryStreamExt,
future::BoxFuture,
future::{self, try_join, try_join_all},
stream::FuturesUnordered,
FutureExt as _, StreamExt as _, TryFutureExt as _, TryStreamExt as _,
};
use hashbrown::HashMap;
use num_traits::ToPrimitive as _;
Expand All @@ -23,10 +26,10 @@ use plonky2::plonk::circuit_data::CircuitConfig;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{oneshot, Semaphore};
use tokio::sync::{mpsc, oneshot, Semaphore};
use trace_decoder::observer::DummyObserver;
use trace_decoder::{BlockTrace, OtherBlockData, WireDisposition};
use tracing::{error, info};
use tracing::{debug, error, info};

use crate::fs::generate_block_proof_file_name;
use crate::ops;
Expand Down Expand Up @@ -116,6 +119,8 @@ impl BlockProverInput {
WIRE_DISPOSITION,
)?;

let batch_count = block_generation_inputs.len();

// Create segment proof.
let seg_prove_ops = ops::SegmentProof {
save_inputs_on_error,
Expand All @@ -131,29 +136,176 @@ impl BlockProverInput {
save_inputs_on_error,
};

// Segment the batches, prove segments and aggregate them to resulting batch
// proofs.
let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs
.iter()
.enumerate()
.map(|(idx, txn_batch)| {
let segment_data_iterator =
SegmentDataIterator::<Field>::new(txn_batch, Some(max_cpu_len_log));

Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops)
.fold(&seg_agg_ops)
.run(&proof_runtime.heavy_proof)
.map(move |e| {
e.map(|p| (idx, crate::proof_types::BatchAggregatableProof::from(p)))
})
// Generate channels to communicate segments of each batch to a batch proving
// task. We generate segments and send them to the proving task, where they
// are proven in parallel.
let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..batch_count)
.map(|_idx| {
let (segment_tx, segment_rx) =
mpsc::channel::<Option<evm_arithmetization::AllData>>(1);
(segment_tx, segment_rx)
})
.collect();
.unzip();

// The size of this channel does not matter much, as it is only used to collect
// batch proofs.
let (batch_proof_tx, mut batch_proof_rx) =
mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32);

// Spin up a single task that, for each batch, generates the segments of that
// batch and sends them to the proving task.
let segment_generation_task = tokio::spawn(async move {
let mut batch_segment_futures: FuturesUnordered<_> = FuturesUnordered::new();

for (batch_idx, (txn_batch, segment_tx)) in block_generation_inputs
.into_iter()
.zip(segment_senders)
.enumerate()
{
batch_segment_futures.push(async move {
let segment_data_iterator =
SegmentDataIterator::<Field>::new(&txn_batch, Some(max_cpu_len_log));
for (segment_idx, segment_data) in segment_data_iterator.enumerate() {
segment_tx
.send(Some(segment_data))
.await
.context(format!("failed to send segment data for batch {batch_idx} segment {segment_idx}"))?;
0xaatif marked this conversation as resolved.
Show resolved Hide resolved
}
// Mark the end of the batch segments by sending `None`
segment_tx
.send(None)
.await
.context(format!("failed to send end segment data indicator for batch {batch_idx}"))?;
anyhow::Ok(())
});
}
while let Some(it) = batch_segment_futures.next().await {
// In case of an error, propagate the error to the main task
it?;
}
atanmarko marked this conversation as resolved.
Show resolved Hide resolved
let () = batch_segment_futures.try_collect().await?;
anyhow::Ok(())
});

let proof_runtime_ = proof_runtime.clone();
atanmarko marked this conversation as resolved.
Show resolved Hide resolved
let batches_proving_task = tokio::spawn(async move {
let mut batch_proving_futures = FuturesUnordered::new();
// Spawn a proving subtask for each batch where we generate segment proofs
// and aggregate them into a batch proof.
for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() {
let batch_proof_tx = batch_proof_tx.clone();
let seg_prove_ops = seg_prove_ops.clone();
let seg_agg_ops = seg_agg_ops.clone();
let proof_runtime = proof_runtime_.clone();
// Tasks to dispatch proving jobs and aggregate segment proofs of one batch
batch_proving_futures.push(async move {
let mut batch_segment_aggregatable_proofs = Vec::new();

// This channel collects the segment proofs of one batch, which are
// proven in parallel. The size of this channel does not matter much,
// as it is only used to collect segment aggregatable proofs.
let (segment_proof_tx, mut segment_proof_rx) =
mpsc::channel::<(usize, crate::proof_types::SegmentAggregatableProof)>(32);

// Wait for segments and dispatch them to the segment proof worker task.
// The segment proof worker task will prove the segment and send it back.
let mut segment_counter = 0;
let mut segment_proving_tasks = Vec::new();
while let Some(Some(segment_data)) = segment_rx.recv().await {
let seg_prove_ops = seg_prove_ops.clone();
let proof_runtime = proof_runtime.clone();
let segment_proof_tx = segment_proof_tx.clone();
// Prove one segment in a dedicated async task.
let segment_proving_task = tokio::spawn(async move {
debug!(%batch_idx, %segment_counter, "proving batch segment");
let seg_aggregatable_proof= Directive::map(
IndexedStream::from([segment_data]),
&seg_prove_ops,
)
.run(&proof_runtime.heavy_proof)
.await?
.into_values_sorted()
.await?
.into_iter()
.next()
.context(format!(
"failed to get segment proof, batch: {batch_idx}, segment: {segment_counter}"
))?;

segment_proof_tx
.send((segment_counter, seg_aggregatable_proof))
.await
.context(format!(
"unable to send segment proof, batch: {batch_idx}, segment: {segment_counter}"
))?;
anyhow::Ok(())
});

segment_proving_tasks.push(segment_proving_task);
segment_counter += 1;
}
drop(segment_proof_tx);
// Wait for all the segment proving tasks of one batch to finish.
while let Some((segment_idx, segment_aggregatable_proof)) = segment_proof_rx.recv().await {
batch_segment_aggregatable_proofs.push((segment_idx, segment_aggregatable_proof));
}
atanmarko marked this conversation as resolved.
Show resolved Hide resolved
try_join_all(segment_proving_tasks).await?;
batch_segment_aggregatable_proofs.sort_by(|(a, _), (b, _)| a.cmp(b));
debug!(%block_number, batch=%batch_idx, "finished proving all segments");
// We have proved all the segments in a batch,
// now we need to aggregate them to the batch proof.
// Fold the segment aggregated proof stream into a single batch proof.
let batch_proof = if batch_segment_aggregatable_proofs.len() == 1 {
atanmarko marked this conversation as resolved.
Show resolved Hide resolved
// If there is only one segment aggregated proof, just transform it into a batch proof.
(batch_idx, crate::proof_types::BatchAggregatableProof::from(
batch_segment_aggregatable_proofs.pop().map(|(_, it)| it).unwrap(),
))
} else {
Directive::fold(IndexedStream::from(batch_segment_aggregatable_proofs.into_iter().map(|(_, it)| it)), &seg_agg_ops)
.run(&proof_runtime.light_proof)
.map(move |e| {
e.map(|p| {
(
batch_idx,
crate::proof_types::BatchAggregatableProof::from(p),
)
})
})
.await?
};
debug!(%block_number, batch=%batch_idx, "generated batch proof for block");
batch_proof_tx.send(batch_proof).await.context(format!(
"unable to send batch proof, block: {block_number}, batch: {batch_idx}"
))?;
anyhow::Ok(())
});
}
// Wait for all the batch proving tasks to finish. Exit early on error.
while let Some(it) = batch_proving_futures.next().await {
it?;
}
anyhow::Ok(())
});

// Collect all the batch proofs.
let mut batch_proofs: Vec<(usize, crate::proof_types::BatchAggregatableProof)> = Vec::new();
while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await {
batch_proofs.push((batch_idx, batch_proof));
}
atanmarko marked this conversation as resolved.
Show resolved Hide resolved
debug!(%block_number, "collected all batch proofs");

// Wait for the segment generation and proving tasks to finish.
let _ = try_join(segment_generation_task, batches_proving_task).await?;

batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b));

// Fold the batch aggregated proof stream into a single proof.
let final_batch_proof =
Directive::fold(IndexedStream::new(batch_proof_futs), &batch_agg_ops)
.run(&proof_runtime.light_proof)
.await?;
let final_batch_proof = Directive::fold(
IndexedStream::from(batch_proofs.into_iter().map(|(_, it)| it)),
&batch_agg_ops,
)
.run(&proof_runtime.light_proof)
.await?;

if let crate::proof_types::BatchAggregatableProof::BatchAgg(proof) = final_batch_proof {
let block_number = block_number
Expand Down
Loading