diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3143f2287..62af3ab96 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -122,7 +122,7 @@ pub struct ZKVMProof> { pub raw_pi: Vec>, // the evaluation of raw_pi. pub pi_evals: Vec, - opcode_proofs: BTreeMap)>, + opcode_proofs: BTreeMap>)>, table_proofs: BTreeMap)>, } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2c8cae8bc..ff6adbc6e 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,3 +1,4 @@ +use core::assert_eq; use ff_ext::ExtensionField; use std::{ collections::{BTreeMap, BTreeSet, HashMap}, @@ -8,7 +9,7 @@ use ff::Field; use itertools::{Itertools, enumerate, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - mle::{IntoMLE, MultilinearExtension}, + mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::ArcMultilinearExtension, @@ -36,6 +37,7 @@ use crate::{ }, utils::{get_challenge_pows, next_pow2_instance_padding, optimal_sumcheck_threads}, virtual_polys::VirtualPolynomials, + witness::RowMajorMatrix, }; use super::{PublicValues, ZKVMOpcodeProof, ZKVMProof, ZKVMTableProof}; @@ -90,33 +92,59 @@ impl> ZKVMProver { } exit_span!(span); + // TODO: is it better to set different size of different opcode? + let shard_size = 1048576; + // commit to main traces - let mut commitments = BTreeMap::new(); - let mut wits = BTreeMap::new(); + // TODO: (1) is it ok to store mle? (2) replace tuple with struct? + #[allow(clippy::type_complexity)] + let mut wits_and_commitments: BTreeMap< + String, + Vec<( + RowMajorMatrix<_>, + Vec>, + PCS::CommitmentWithWitness, + )>, + > = BTreeMap::new(); let commit_to_traces_span = entered_span!("commit_to_traces", profiling_1 = true); // commit to opcode circuits first and then commit to table circuits, sorted by name for (circuit_name, witness) in witnesses.into_iter_sorted() { let num_instances = witness.num_instances(); + tracing::debug!( + "committing {} witnesses of size {}..", + circuit_name, + num_instances + ); + if num_instances == 0 { + wits_and_commitments.insert(circuit_name.clone(), Vec::new()); + continue; + } let span = entered_span!( "commit to iteration", circuit_name = circuit_name, profiling_2 = true ); - let witness = match num_instances { - 0 => vec![], - _ => { - let witness = witness.into_mles(); - commitments.insert( - circuit_name.clone(), - PCS::batch_commit_and_write(&self.pk.pp, &witness, &mut transcript) - .map_err(ZKVMError::PCSError)?, - ); - witness - } - }; + + let witness_shards = witness.shard_by_rows(shard_size); + if witness_shards.len() > 1 { + tracing::info!( + "split {circuit_name} witness into {} shards", + witness_shards.len() + ); + } + let witness_and_commitment: Vec<_> = witness_shards + .into_iter() + .map(|witness| -> Result<_, ZKVMError> { + let witness_mles = witness.clone().into_mles(); + let commitment = + PCS::batch_commit_and_write(&self.pk.pp, &witness_mles, &mut transcript) + .map_err(ZKVMError::PCSError)?; + Ok((witness, witness_mles, commitment)) + }) + .collect::, _>>()?; + wits_and_commitments.insert(circuit_name, witness_and_commitment); exit_span!(span); - wits.insert(circuit_name, (witness, num_instances)); } exit_span!(commit_to_traces_span); @@ -135,13 +163,14 @@ impl> ZKVMProver { .iter() // Sorted by key. .zip_eq(transcripts.iter_mut().enumerate()) { - let (witness, num_instances) = wits + let mut witness_and_wit: Vec<_> = wits_and_commitments .remove(circuit_name) .ok_or(ZKVMError::WitnessNotFound(circuit_name.clone()))?; - if witness.is_empty() { + + if witness_and_wit.is_empty() { continue; } - let wits_commit = commitments.remove(circuit_name).unwrap(); + // TODO: add an enum for circuit type either in constraint_system or vk let cs = pk.get_cs(); let is_opcode_circuit = cs.lk_table_expressions.is_empty() @@ -157,31 +186,36 @@ impl> ZKVMProver { cs.w_expressions.len(), cs.lk_expressions.len(), ); - let opcode_proof = self.create_opcode_proof( - circuit_name, - &self.pk.pp, - pk, - witness.into_iter().map(|w| w.into()).collect_vec(), - wits_commit, - &pi, - num_instances, - transcript, - &challenges, - )?; - tracing::info!( - "generated proof for opcode {} with num_instances={}", - circuit_name, - num_instances - ); + let opcode_proof: Vec<_> = witness_and_wit.into_iter().enumerate().map(|(idx, (witness, mles, wits_commit))| -> Result<_, ZKVMError> { + let num_instances = witness.num_instances(); + let proof = self.create_opcode_proof( + circuit_name, + &self.pk.pp, + pk, + mles.into_iter().map(|v| v.into()).collect_vec(), + wits_commit, + &pi, + num_instances, + transcript, + &challenges, + )?; + tracing::info!( + "generated proof for opcode {circuit_name} with num_instances={num_instances}, shard idx {idx}" + ); + Ok(proof) + }).collect::, _>>()?; vm_proof .opcode_proofs .insert(circuit_name.clone(), (i, opcode_proof)); } else { + assert_eq!(witness_and_wit.len(), 1); + let (witness, mles, wits_commit) = witness_and_wit.remove(0); + let num_instances = witness.num_instances(); let (table_proof, pi_in_evals) = self.create_table_proof( circuit_name, &self.pk.pp, pk, - witness.into_iter().map(|v| v.into()).collect_vec(), + mles.into_iter().map(|v| v.into()).collect_vec(), wits_commit, &pi, transcript, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e3cb38b2d..3759d8513 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -61,10 +61,11 @@ impl> ZKVMVerifier does_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. + // TODO: make it less adhoc let num_instances = vm_proof .opcode_proofs .get(&HaltInstruction::::name()) - .map(|(_, p)| p.num_instances) + .map(|(_, p)| p[0].num_instances) .unwrap_or(0); if num_instances != (does_halt as usize) { return Err(ZKVMError::VerifyError(format!( @@ -119,8 +120,10 @@ impl> ZKVMVerifier for (name, (_, proof)) in vm_proof.opcode_proofs.iter() { tracing::debug!("read {}'s commit", name); - PCS::write_commitment(&proof.wits_commit, &mut transcript) - .map_err(ZKVMError::PCSError)?; + for p in proof { + PCS::write_commitment(&p.wits_commit, &mut transcript) + .map_err(ZKVMError::PCSError)?; + } } for (name, (_, proof)) in vm_proof.table_proofs.iter() { tracing::debug!("read {}'s commit", name); @@ -140,7 +143,7 @@ impl> ZKVMVerifier let point_eval = PointAndEval::default(); let mut transcripts = transcript.fork(self.vk.circuit_vks.len()); - for (name, (i, opcode_proof)) in vm_proof.opcode_proofs { + for (name, (i, opcode_proofs)) in vm_proof.opcode_proofs { let transcript = &mut transcripts[i]; let circuit_vk = self @@ -148,35 +151,39 @@ impl> ZKVMVerifier .circuit_vks .get(&name) .ok_or(ZKVMError::VKNotFound(name.clone()))?; - let _rand_point = self.verify_opcode_proof( - &name, - &self.vk.vp, - circuit_vk, - &opcode_proof, - pi_evals, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; - tracing::info!("verified proof for opcode {}", name); + for opcode_proof in &opcode_proofs { + let _rand_point = self.verify_opcode_proof( + &name, + &self.vk.vp, + circuit_vk, + opcode_proof, + pi_evals, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; + } - // getting the number of dummy padding item that we used in this opcode circuit - let num_lks = circuit_vk.get_cs().lk_expressions.len(); - let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; - let num_padded_instance = - next_pow2_instance_padding(opcode_proof.num_instances) - opcode_proof.num_instances; - dummy_table_item_multiplicity += num_padded_lks_per_instance - * opcode_proof.num_instances - + num_lks.next_power_of_two() * num_padded_instance; - - prod_r *= opcode_proof.record_r_out_evals.iter().product::(); - prod_w *= opcode_proof.record_w_out_evals.iter().product::(); - - logup_sum += - opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); - logup_sum += - opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + tracing::info!("verified proof for opcode {}", name); + for opcode_proof in &opcode_proofs { + // getting the number of dummy padding item that we used in this opcode circuit + let num_lks = circuit_vk.get_cs().lk_expressions.len(); + let num_padded_lks_per_instance = next_pow2_instance_padding(num_lks) - num_lks; + let num_padded_instance = next_pow2_instance_padding(opcode_proof.num_instances) + - opcode_proof.num_instances; + dummy_table_item_multiplicity += num_padded_lks_per_instance + * opcode_proof.num_instances + + num_lks.next_power_of_two() * num_padded_instance; + + prod_r *= opcode_proof.record_r_out_evals.iter().product::(); + prod_w *= opcode_proof.record_w_out_evals.iter().product::(); + + logup_sum += + opcode_proof.lk_p1_out_eval * opcode_proof.lk_q1_out_eval.invert().unwrap(); + logup_sum += + opcode_proof.lk_p2_out_eval * opcode_proof.lk_q2_out_eval.invert().unwrap(); + } } for (name, (i, table_proof)) in vm_proof.table_proofs { @@ -471,6 +478,17 @@ impl> ZKVMVerifier } // verify zero expression (degree = 1) statement, thus no sumcheck + for (expr, name) in cs + .assert_zero_expressions + .iter() + .zip_eq(cs.assert_zero_expressions_namespace_map.iter()) + { + if eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) + != E::ZERO + { + tracing::error!("checking zero expression {name} failed."); + } + } if cs.assert_zero_expressions.iter().any(|expr| { eval_by_expr_with_instance(&[], &proof.wits_in_evals, pi, challenges, expr) != E::ZERO }) { diff --git a/ceno_zkvm/src/witness.rs b/ceno_zkvm/src/witness.rs index 97150e086..d05a4629f 100644 --- a/ceno_zkvm/src/witness.rs +++ b/ceno_zkvm/src/witness.rs @@ -104,6 +104,33 @@ impl> RowMajorMatrix .chain(padding_iter) .collect::>() } + + pub fn shard_by_rows(&self, shard_rows: usize) -> Vec { + let padded_row_num = self.num_instances() + self.num_padding_instances(); + if padded_row_num <= shard_rows { + return vec![self.clone()]; + } + // padded_row_num and chunk_rows should both be pow of 2. + assert_eq!(padded_row_num % shard_rows, 0); + let shard_num = self.num_instances().div_ceil(shard_rows); + let mut shards = Vec::new(); + for i in 0..shard_num { + let start = i * shard_rows * self.num_col; + let end = ((i + 1) * shard_rows * self.num_col).min(self.values.len()); + let values: Vec<_> = self.values[start..end].to_vec(); + + shards.push(Self { + num_col: self.num_col, + values, + padding_strategy: self.padding_strategy.clone(), + }); + } + assert_eq!( + self.num_instances(), + shards.iter().map(|c| { c.num_instances() }).sum::() + ); + shards + } } impl> RowMajorMatrix {