From bd7771b55afa138d5efc34dcd365b00aed6bebd7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 31 Jul 2024 23:38:57 +0800 Subject: [PATCH] tower product sumcheck prove/verify --- ceno_zkvm/src/error.rs | 2 +- ceno_zkvm/src/instructions/riscv/add.rs | 3 +- ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/scheme.rs | 12 +- ceno_zkvm/src/scheme/constants.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 410 ++++++++---------------- ceno_zkvm/src/scheme/utils.rs | 263 +++++++++++++++ ceno_zkvm/src/scheme/verifier.rs | 157 ++++++++- ceno_zkvm/src/utils.rs | 1 + multilinear_extensions/src/mle.rs | 98 ++++++ sumcheck/src/prover_v2.rs | 33 +- 11 files changed, 685 insertions(+), 297 deletions(-) create mode 100644 ceno_zkvm/src/scheme/utils.rs diff --git a/ceno_zkvm/src/error.rs b/ceno_zkvm/src/error.rs index 18c319036..ba8b412ea 100644 --- a/ceno_zkvm/src/error.rs +++ b/ceno_zkvm/src/error.rs @@ -4,7 +4,7 @@ use singer_utils::error::UtilError; pub enum ZKVMError { CircuitError, UtilError(UtilError), - VerifyError, + VerifyError(&'static str), } impl From for ZKVMError { diff --git a/ceno_zkvm/src/instructions/riscv/add.rs b/ceno_zkvm/src/instructions/riscv/add.rs index 38ea9fb58..aca0d33d1 100644 --- a/ceno_zkvm/src/instructions/riscv/add.rs +++ b/ceno_zkvm/src/instructions/riscv/add.rs @@ -121,7 +121,7 @@ mod test { use singer_utils::{structs_v2::CircuitBuilderV2, util_v2::InstructionV2}; use transcript::Transcript; - use crate::scheme::{prover::ZKVMProver, verifier::ZKVMVerifier}; + use crate::scheme::{constants::NUM_PRODUCT_FANIN, prover::ZKVMProver, verifier::ZKVMVerifier}; use super::AddInstruction; @@ -163,6 +163,7 @@ mod test { .verify( &mut proof, &mut v_transcript, + NUM_PRODUCT_FANIN, &PointAndEval::default(), &challenges, ) diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index f2256c12a..b095a86ba 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -3,4 +3,5 @@ pub mod instructions; pub mod scheme; // #[cfg(test)] pub use utils::u64vec; +mod structs; mod utils; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index f0d62529d..7442f748c 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,8 +1,11 @@ use ff_ext::ExtensionField; use sumcheck::structs::IOPProverMessage; -mod constants; +use crate::structs::TowerProofs; + +pub mod constants; pub mod prover; +mod utils; pub mod verifier; #[derive(Clone)] @@ -10,9 +13,12 @@ pub struct ZKVMProof { // TODO support >1 opcodes pub num_instances: usize, + // product constraints + pub record_r_out_evals: Vec, + pub record_w_out_evals: Vec, + pub tower_proof: TowerProofs, + // main constraint and select sumcheck proof - pub out_record_r_eval: E, - pub out_record_w_eval: E, pub main_sel_sumcheck_proofs: Vec>, pub r_records_in_evals: Vec, pub w_records_in_evals: Vec, diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 1ea901dea..fd7e6738f 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -1,3 +1,3 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; pub(crate) const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 2; -pub(crate) const PRODUCT_ARGUMENT_SIZE: usize = 2; +pub const NUM_PRODUCT_FANIN: usize = 2; diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index bf1aeccac..b7d541434 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -1,31 +1,28 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::{cmp::max, collections::BTreeMap, mem, sync::Arc}; use ff_ext::ExtensionField; -use gkr::{entered_span, exit_span}; +use gkr::{entered_span, exit_span, structs::Point}; + use itertools::Itertools; use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{DenseMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, - op_mle, + mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2}, }; -use rayon::{ - iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, - ParallelIterator, - }, - prelude::ParallelSliceMut, -}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use simple_frontend::structs::WitnessId; -use singer_utils::{structs_v2::Circuit, util_v2::Expression}; -use sumcheck::structs::IOPProverStateV2; +use singer_utils::structs_v2::Circuit; +use sumcheck::structs::{IOPProverMessage, IOPProverStateV2}; use transcript::Transcript; use crate::{ error::ZKVMError, - scheme::constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, MIN_PAR_SIZE, PRODUCT_ARGUMENT_SIZE}, + scheme::{ + constants::{MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, NUM_PRODUCT_FANIN}, + utils::{infer_tower_product_witness, interleaving_mles_to_mles, wit_infer_by_expr}, + }, + structs::{TowerProofs, TowerProver, TowerProverSpec}, utils::get_challenge_pows, }; @@ -87,9 +84,9 @@ impl ZKVMProver { r_records_wit, log2_num_instances, log2_r_count, - PRODUCT_ARGUMENT_SIZE, + NUM_PRODUCT_FANIN, ); - assert_eq!(r_records_last_layer.len(), PRODUCT_ARGUMENT_SIZE); + assert_eq!(r_records_last_layer.len(), NUM_PRODUCT_FANIN); exit_span!(span); // infer all tower witness after last layer @@ -97,7 +94,7 @@ impl ZKVMProver { let r_wit_layers = infer_tower_product_witness( log2_num_instances + log2_r_count, r_records_last_layer, - PRODUCT_ARGUMENT_SIZE, + NUM_PRODUCT_FANIN, ); exit_span!(span); @@ -107,16 +104,16 @@ impl ZKVMProver { w_records_wit, log2_num_instances, log2_w_count, - PRODUCT_ARGUMENT_SIZE, + NUM_PRODUCT_FANIN, ); - assert_eq!(w_records_last_layer.len(), PRODUCT_ARGUMENT_SIZE); + assert_eq!(w_records_last_layer.len(), NUM_PRODUCT_FANIN); exit_span!(span); let span = entered_span!("wit_inference::tower_witness_w_layers"); let w_wit_layers = infer_tower_product_witness( log2_num_instances + log2_w_count, w_records_last_layer, - PRODUCT_ARGUMENT_SIZE, + NUM_PRODUCT_FANIN, ); exit_span!(span); @@ -125,15 +122,15 @@ impl ZKVMProver { assert_eq!(r_wit_layers.len(), (log2_num_instances + log2_r_count)); assert_eq!(w_wit_layers.len(), (log2_num_instances + log2_w_count)); assert!(r_wit_layers.iter().enumerate().all(|(i, r_wit_layer)| { - let expected_size = 1 << i; - r_wit_layer.len() == PRODUCT_ARGUMENT_SIZE + let expected_size = 1 << (ceil_log2(NUM_PRODUCT_FANIN) * i); + r_wit_layer.len() == NUM_PRODUCT_FANIN && r_wit_layer .iter() .all(|f| f.evaluations().len() == expected_size) })); assert!(w_wit_layers.iter().enumerate().all(|(i, w_wit_layer)| { - let expected_size = 1 << i; - w_wit_layer.len() == PRODUCT_ARGUMENT_SIZE + let expected_size = 1 << (ceil_log2(NUM_PRODUCT_FANIN) * i); + w_wit_layer.len() == NUM_PRODUCT_FANIN && w_wit_layer .iter() .all(|f| f.evaluations().len() == expected_size) @@ -142,22 +139,43 @@ impl ZKVMProver { // product constraint tower sumcheck let span = entered_span!("sumcheck::tower"); - // TODO + // final evals for verifier + let record_r_out_evals: Vec = r_wit_layers[0] + .iter() + .map(|w| w.get_ext_field_vec()[0]) + .collect(); + let record_w_out_evals: Vec = w_wit_layers[0] + .iter() + .map(|w| w.get_ext_field_vec()[0]) + .collect(); + assert!( + record_r_out_evals.len() == NUM_PRODUCT_FANIN + && record_w_out_evals.len() == NUM_PRODUCT_FANIN + ); + let (next_rt, tower_proof) = TowerProver::create_proof( + vec![ + TowerProverSpec { + witness: r_wit_layers, + }, + TowerProverSpec { + witness: w_wit_layers, + }, + ], + NUM_PRODUCT_FANIN, + transcript, + ); + assert_eq!( + next_rt.len(), + log2_num_instances + max(log2_r_count, log2_w_count) // TODO add lookup count + ); exit_span!(span); - // main constraints degree > 1 + selector sumcheck + // selector main constraints degree > 1 sumcheck let span = entered_span!("sumcheck::main_sel"); - // TODO fix rt_r/rt_w to use real let (rt_r, rt_w): (Vec, Vec) = ( - (0..(log2_num_instances + log2_r_count)) - .map(|i| E::from(i as u64)) - .collect(), - (0..(log2_num_instances + log2_w_count)) - .map(|i| E::from(i as u64)) - .collect(), + next_rt[0..(log2_num_instances + log2_r_count)].to_vec(), + next_rt[0..(log2_num_instances + log2_w_count)].to_vec(), ); - // TODO fix record_r_eval, record_w_eval - let (record_r_eval, record_w_eval) = (E::from(5u64), E::from(7u64)); let mut virtual_poly = VirtualPolynomialV2::::new(log2_num_instances); let alpha_pow = get_challenge_pows(MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, transcript); let (alpha_read, alpha_write) = (&alpha_pow[0], &alpha_pow[1]); @@ -230,8 +248,6 @@ impl ZKVMProver { ); let input_open_point = main_sel_sumcheck_proofs.point.clone(); assert!(input_open_point.len() == log2_num_instances); - println!("evals {:?}", main_sel_evals,); - println!("point {:?}", input_open_point); exit_span!(span); let span = entered_span!("witin::evals"); @@ -243,8 +259,9 @@ impl ZKVMProver { Ok(ZKVMProof { num_instances, - out_record_r_eval: record_r_eval, - out_record_w_eval: record_w_eval, + record_r_out_evals, + record_w_out_evals, + tower_proof, main_sel_sumcheck_proofs: main_sel_sumcheck_proofs.proofs, r_records_in_evals, w_records_in_evals, @@ -253,245 +270,92 @@ impl ZKVMProver { } } -/// interleaving multiple mles into mles for the product/logup arguments last layer witness -fn interleaving_mles_to_mles<'a, E: ExtensionField>( - mles: &[ArcMultilinearExtension], - log2_num_instances: usize, - log2_per_instance_size: usize, - product_argument_size: usize, -) -> Vec> { - assert!(product_argument_size.is_power_of_two()); - let mle_group_len = mles.len() / product_argument_size; - let log_product_argument_size = ceil_log2(product_argument_size); - mles.chunks(mle_group_len) - .map(|records_mle| { - // interleaving records witness into single vector - let mut evaluations = vec![ - E::ONE; - 1 << (log2_num_instances + log2_per_instance_size - - log_product_argument_size) - ]; - let per_instance_size = 1 << (log2_per_instance_size - log_product_argument_size); - records_mle - .iter() - .enumerate() - .for_each(|(record_i, record_mle)| match record_mle.evaluations() { - FieldType::Ext(record_mle) => record_mle - .par_iter() - .zip(evaluations.par_chunks_mut(per_instance_size)) - .with_min_len(MIN_PAR_SIZE) - .for_each(|(value, instance)| { - assert_eq!(instance.len(), per_instance_size); - instance[record_i] = *value; - }), - _ => { - unreachable!("must be extension field") - } - }); - evaluations.into_mle().into() - }) - .collect::>>() +/// TowerProofs +impl TowerProofs { + pub fn new(spec_size: usize) -> Self { + TowerProofs { + proofs: vec![], + specs_eval: vec![vec![]; spec_size], + } + } + pub fn push_sumcheck_proofs(&mut self, proofs: Vec>) { + self.proofs.push(proofs); + } + + pub fn push_evals(&mut self, spec_index: usize, evals: Vec) { + self.specs_eval[spec_index].push(evals); + } + + pub fn spec_size(&self) -> usize { + return self.specs_eval.len(); + } } -/// infer tower witness from last layer -fn infer_tower_product_witness<'a, E: ExtensionField>( - num_vars: usize, - last_layer: Vec>, - product_argument_size: usize, -) -> Vec>> { - assert!(last_layer.len() == product_argument_size); - let mut r_wit_layers = (0..num_vars - 1).fold(vec![last_layer], |mut acc, i| { - let next_layer = acc.last().unwrap(); - let cur_len = next_layer[0].evaluations().len() / product_argument_size; - let cur_layer: Vec> = (0..product_argument_size) - .map(|index| { - let mut evaluations = vec![E::ONE; cur_len]; - next_layer.iter().for_each(|f| match f.evaluations() { - FieldType::Ext(f) => { - let start: usize = index * cur_len; - f[start..][..cur_len] - .par_iter() - .zip(evaluations.par_iter_mut()) - .with_min_len(MIN_PAR_SIZE) - .map(|(v, evaluations)| *evaluations *= *v) - .collect() - } - _ => unreachable!("must be extension field"), - }); - println!("i {} evaluation {:?} ", i, evaluations); - evaluations.into_mle().into() - }) +/// Tower Prover +impl TowerProver { + pub fn create_proof<'a, E: ExtensionField>( + mut specs: Vec>, + num_product_fanin: usize, + transcript: &mut Transcript, + ) -> (Point, TowerProofs) { + let mut proofs = TowerProofs::new(specs.len()); + assert!(specs.len() > 0); + let log2_num_product_fanin = ceil_log2(num_product_fanin); + // -1 for sliding windows size 2: (cur_layer, next_layer) w.r.t total size + let max_round = specs.iter().map(|m| m.witness.len()).max().unwrap() - 1; + + // TODO soundness question: should we generate alpha for each layer? + let alpha_pows = get_challenge_pows(specs.len(), transcript); + let initial_rt: Point = (0..log2_num_product_fanin) + .map(|_| transcript.get_and_append_challenge(b"product_sum").elements) .collect_vec(); - acc.push(cur_layer); - acc - }); - r_wit_layers.reverse(); - r_wit_layers -} -fn wit_infer_by_expr<'a, E: ExtensionField>( - witnesses: &BTreeMap>, - challenges: &[E], - expr: &Expression, -) -> ArcMultilinearExtension<'a, E> { - expr.evaluate::>( - &|witness_id| { - let a: ArcMultilinearExtension = Arc::new( - witnesses - .get(&witness_id) - .expect("non exist witness") - .clone(), - ); - a - }, - &|scalar| { - let scalar: ArcMultilinearExtension = Arc::new( - DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]), - ); - scalar - }, - &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be aquire once for each power - let challenge = challenges[challenge_id as usize]; - let challenge: ArcMultilinearExtension = - Arc::new(DenseMultilinearExtension::from_evaluations_ext_vec( - 0, - vec![challenge.pow(&[pow as u64]) * scalar + offset], - )); - challenge - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] + *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a + b[0]) - .collect(), - )), - (_, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a + b) - .collect(), - )), - } - }) - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), - (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] * *b) - .collect(), - )), - (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a * b[0]) - .collect(), - )), - (_, _) => { - unimplemented!("r,w only support degree 1 expression") - } + let next_rt = (0..(max_round - 1)).fold(initial_rt, |rt, round| { + let mut virtual_poly = VirtualPolynomialV2::::new(rt.len()); + + let eq: ArcMultilinearExtension = build_eq_x_r_vec(&rt).into_mle().into(); + + specs.iter_mut().enumerate().for_each(|(i, s)| { + if (round + 1) < s.witness.len() { + let layer_polys = mem::take(&mut s.witness[round + 1]); + + // sanity check + assert_eq!(layer_polys.len(), num_product_fanin); + layer_polys + .iter() + .all(|f| f.evaluations().len() == 1 << (log2_num_product_fanin * round)); + + // \sum_s eq(rt, s) * alpha^{i} * ([in_i0[s] * in_i1[s] * .... in_i{num_product_fanin}[s]]) + virtual_poly + .add_mle_list(vec![vec![eq.clone()], layer_polys].concat(), alpha_pows[i]); } - }) - }, - &|a, scalar| { - op_mle!(|a| { - Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| scalar * a) - .collect(), - )) - }) - }, - ) -} + }); + let (sumcheck_proofs, state) = + IOPProverStateV2::prove_parallel(virtual_poly, transcript); + proofs.push_sumcheck_proofs(sumcheck_proofs.proofs); -#[cfg(test)] -mod tests { - use ff::Field; - use goldilocks::{ExtensionField, GoldilocksExt2}; - use multilinear_extensions::{ - commutative_op_mle_pair, - mle::{FieldType, IntoMLE}, - op_mle, - util::ceil_log2, - virtual_poly_v2::ArcMultilinearExtension, - }; - - use crate::scheme::prover::{infer_tower_product_witness, interleaving_mles_to_mles}; - - #[test] - fn test_infer_tower_witness() { - type E = GoldilocksExt2; - let product_argument_size = 2; - let last_layer: Vec> = vec![ - vec![E::ONE, E::from(2u64)].into_mle().into(), - vec![E::from(3u64), E::from(4u64)].into_mle().into(), - ]; - let num_vars = ceil_log2(last_layer[0].evaluations().len()) + 1; - let res = infer_tower_product_witness(num_vars, last_layer.clone(), 2); - let (left, right) = (&res[0][0], &res[0][1]); - let final_product = commutative_op_mle_pair!( - |left, right| { - assert!(left.len() == 1 && right.len() == 1); - left[0] * right[0] - }, - |out| E::from_base(&out) - ); - let expected_final_product: E = last_layer - .iter() - .map(|f| match f.evaluations() { - FieldType::Ext(e) => e.iter().cloned().reduce(|a, b| a * b).unwrap(), - _ => unreachable!(""), - }) - .product(); - assert_eq!(res.len(), num_vars); - assert!( - res.iter() - .all(|layer_wit| layer_wit.len() == product_argument_size) - ); - assert_eq!(final_product, expected_final_product); - } + let mut rt_prime = (0..log2_num_product_fanin) + .map(|_| transcript.get_and_append_challenge(b"merge").elements) + .collect_vec(); + rt_prime.extend(sumcheck_proofs.point); + + let evals = state.get_mle_final_evaluations(); + let mut evals_iter = evals.iter(); + specs.iter().enumerate().for_each(|(i, s)| { + if (round + 1) < s.witness.len() { + evals_iter.next(); // skip first eq + // collect evals belong to current spec + proofs.push_evals( + i, + (0..s.witness.len()) + .map(|_| *evals_iter.next().expect("insufficient evals length")) + .collect::>(), + ); + } + }); + rt_prime + }); - #[test] - fn test_interleaving_mles_to_mles() { - type E = GoldilocksExt2; - let product_argument_size = 2; - // [[1, 2], [3, 4]] - let input_mles: Vec> = vec![ - vec![E::ONE, E::from(2u64)].into_mle().into(), - vec![E::from(3u64), E::from(4u64)].into_mle().into(), - ]; - let res = interleaving_mles_to_mles(&input_mles, 1, 2, product_argument_size); - // [[1, 1, 2, 1], [3, 1, 4, 1]] - assert!(res[0].get_ext_field_vec() == vec![E::ONE, E::ONE, E::from(2u64), E::ONE],); - assert!(res[1].get_ext_field_vec() == vec![E::from(3u64), E::ONE, E::from(4u64), E::ONE]); + (next_rt, proofs) } } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs new file mode 100644 index 000000000..3410bb42c --- /dev/null +++ b/ceno_zkvm/src/scheme/utils.rs @@ -0,0 +1,263 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use ff_ext::ExtensionField; +use itertools::Itertools; +use multilinear_extensions::{ + commutative_op_mle_pair, + mle::{DenseMultilinearExtension, FieldType, IntoMLE}, + op_mle, + util::ceil_log2, + virtual_poly_v2::ArcMultilinearExtension, +}; +use rayon::{ + iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, + ParallelIterator, + }, + prelude::ParallelSliceMut, +}; +use simple_frontend::structs::WitnessId; +use singer_utils::util_v2::Expression; + +use crate::scheme::constants::MIN_PAR_SIZE; + +/// interleaving multiple mles into mles for the product/logup arguments last layer witness +pub(crate) fn interleaving_mles_to_mles<'a, E: ExtensionField>( + mles: &[ArcMultilinearExtension], + log2_num_instances: usize, + log2_per_instance_size: usize, + num_product_fanin: usize, +) -> Vec> { + assert!(num_product_fanin.is_power_of_two()); + let mle_group_len = mles.len() / num_product_fanin; + let log_num_product_fanin = ceil_log2(num_product_fanin); + mles.chunks(mle_group_len) + .map(|records_mle| { + // interleaving records witness into single vector + let mut evaluations = vec![ + E::ONE; + 1 << (log2_num_instances + log2_per_instance_size + - log_num_product_fanin) + ]; + let per_instance_size = 1 << (log2_per_instance_size - log_num_product_fanin); + records_mle + .iter() + .enumerate() + .for_each(|(record_i, record_mle)| match record_mle.evaluations() { + FieldType::Ext(record_mle) => record_mle + .par_iter() + .zip(evaluations.par_chunks_mut(per_instance_size)) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(value, instance)| { + assert_eq!(instance.len(), per_instance_size); + instance[record_i] = *value; + }), + _ => { + unreachable!("must be extension field") + } + }); + evaluations.into_mle().into() + }) + .collect::>>() +} + +/// infer tower witness from last layer +pub(crate) fn infer_tower_product_witness<'a, E: ExtensionField>( + num_vars: usize, + last_layer: Vec>, + num_product_fanin: usize, +) -> Vec>> { + assert!(last_layer.len() == num_product_fanin); + let mut r_wit_layers = (0..num_vars - 1).fold(vec![last_layer], |mut acc, _| { + let next_layer = acc.last().unwrap(); + let cur_len = next_layer[0].evaluations().len() / num_product_fanin; + let cur_layer: Vec> = (0..num_product_fanin) + .map(|index| { + let mut evaluations = vec![E::ONE; cur_len]; + next_layer.iter().for_each(|f| match f.evaluations() { + FieldType::Ext(f) => { + let start: usize = index * cur_len; + f[start..][..cur_len] + .par_iter() + .zip(evaluations.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .map(|(v, evaluations)| *evaluations *= *v) + .collect() + } + _ => unreachable!("must be extension field"), + }); + evaluations.into_mle().into() + }) + .collect_vec(); + acc.push(cur_layer); + acc + }); + r_wit_layers.reverse(); + r_wit_layers +} + +pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField>( + witnesses: &BTreeMap>, + challenges: &[E], + expr: &Expression, +) -> ArcMultilinearExtension<'a, E> { + expr.evaluate::>( + &|witness_id| { + let a: ArcMultilinearExtension = Arc::new( + witnesses + .get(&witness_id) + .expect("non exist witness") + .clone(), + ); + a + }, + &|scalar| { + let scalar: ArcMultilinearExtension = Arc::new( + DenseMultilinearExtension::from_evaluations_vec(0, vec![scalar]), + ); + scalar + }, + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be aquire once for each power + let challenge = challenges[challenge_id as usize]; + let challenge: ArcMultilinearExtension = + Arc::new(DenseMultilinearExtension::from_evaluations_ext_vec( + 0, + vec![challenge.pow(&[pow as u64]) * scalar + offset], + )); + challenge + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] + b[0]], + )), + (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(b.len()), + b.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|b| a[0] + *b) + .collect(), + )), + (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| *a + b[0]) + .collect(), + )), + (_, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .zip(b.par_iter()) + .with_min_len(MIN_PAR_SIZE) + .map(|(a, b)| *a + b) + .collect(), + )), + } + }) + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + vec![a[0] * b[0]], + )), + (1, _) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(b.len()), + b.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|b| a[0] * *b) + .collect(), + )), + (_, 1) => Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| *a * b[0]) + .collect(), + )), + (_, _) => { + unimplemented!("r,w only support degree 1 expression") + } + } + }) + }, + &|a, scalar| { + op_mle!(|a| { + Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| scalar * a) + .collect(), + )) + }) + }, + ) +} + +#[cfg(test)] +mod tests { + use ff::Field; + use goldilocks::{ExtensionField, GoldilocksExt2}; + use multilinear_extensions::{ + commutative_op_mle_pair, + mle::{FieldType, IntoMLE}, + util::ceil_log2, + virtual_poly_v2::ArcMultilinearExtension, + }; + + use crate::scheme::utils::{infer_tower_product_witness, interleaving_mles_to_mles}; + + #[test] + fn test_infer_tower_witness() { + type E = GoldilocksExt2; + let num_product_fanin = 2; + let last_layer: Vec> = vec![ + vec![E::ONE, E::from(2u64)].into_mle().into(), + vec![E::from(3u64), E::from(4u64)].into_mle().into(), + ]; + let num_vars = ceil_log2(last_layer[0].evaluations().len()) + 1; + let res = infer_tower_product_witness(num_vars, last_layer.clone(), 2); + let (left, right) = (&res[0][0], &res[0][1]); + let final_product = commutative_op_mle_pair!( + |left, right| { + assert!(left.len() == 1 && right.len() == 1); + left[0] * right[0] + }, + |out| E::from_base(&out) + ); + let expected_final_product: E = last_layer + .iter() + .map(|f| match f.evaluations() { + FieldType::Ext(e) => e.iter().cloned().reduce(|a, b| a * b).unwrap(), + _ => unreachable!(""), + }) + .product(); + assert_eq!(res.len(), num_vars); + assert!( + res.iter() + .all(|layer_wit| layer_wit.len() == num_product_fanin) + ); + assert_eq!(final_product, expected_final_product); + } + + #[test] + fn test_interleaving_mles_to_mles() { + type E = GoldilocksExt2; + let num_product_fanin = 2; + // [[1, 2], [3, 4]] + let input_mles: Vec> = vec![ + vec![E::ONE, E::from(2u64)].into_mle().into(), + vec![E::from(3u64), E::from(4u64)].into_mle().into(), + ]; + let res = interleaving_mles_to_mles(&input_mles, 1, 2, num_product_fanin); + // [[1, 1, 2, 1], [3, 1, 4, 1]] + assert!(res[0].get_ext_field_vec() == vec![E::ONE, E::ONE, E::from(2u64), E::ONE],); + assert!(res[1].get_ext_field_vec() == vec![E::from(3u64), E::ONE, E::from(4u64), E::ONE]); + } +} diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2c8e5bfd9..8ffcdf8b2 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,13 +1,23 @@ -use std::{marker::PhantomData, mem}; +use std::{cmp::max, marker::PhantomData}; use ff_ext::ExtensionField; -use gkr::{structs::PointAndEval, util::ceil_log2}; -use multilinear_extensions::virtual_poly::{build_eq_x_r_vec_sequential, VPAuxInfo}; +use gkr::{ + structs::{Point, PointAndEval}, + util::ceil_log2, +}; +use itertools::{izip, Itertools}; +use multilinear_extensions::{ + mle::{IntoMLE, MultilinearExtension}, + virtual_poly::{build_eq_x_r_vec_sequential, VPAuxInfo}, +}; use singer_utils::structs_v2::Circuit; use sumcheck::structs::{IOPProof, IOPVerifierState}; use transcript::Transcript; -use crate::{error::ZKVMError, utils::get_challenge_pows}; +use crate::{ + error::ZKVMError, scheme::constants::NUM_PRODUCT_FANIN, structs::TowerProofs, + utils::get_challenge_pows, +}; use super::{constants::MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, ZKVMProof}; @@ -21,18 +31,12 @@ impl ZKVMVerifier { } pub fn verify( &self, - proof: &mut ZKVMProof, + proof: &ZKVMProof, transcript: &mut Transcript, + num_product_fanin: usize, _out_evals: &PointAndEval, _challenges: &[E], // derive challenge from PCS ) -> Result<(), ZKVMError> { - // TODO remove rng - let num_instances = proof.num_instances; - let log2_num_instances = ceil_log2(num_instances); - // verify and reduce product tower sumcheck - - // verify zero statement (degree > 1) + sel sumcheck - // TODO fix rt_r/rt_w to use real let (r_counts_per_instance, w_counts_per_instance) = ( self.circuit.r_expressions.len(), self.circuit.w_expressions.len(), @@ -41,6 +45,33 @@ impl ZKVMVerifier { ceil_log2(r_counts_per_instance), ceil_log2(w_counts_per_instance), ); + + let num_instances = proof.num_instances; + let log2_num_instances = ceil_log2(num_instances); + + // verify and reduce product tower sumcheck + let tower_proofs = &proof.tower_proof; + + // check read/write set equality + if proof.record_r_out_evals.iter().product::() + != proof.record_w_out_evals.iter().product() + { + return Err(ZKVMError::VerifyError("rw set equality check failed")); + } + let expected_max_round = log2_num_instances + max(log2_r_count, log2_w_count); // TODO add lookup + let _rt = TowerVerify::verify( + vec![ + proof.record_r_out_evals.clone(), + proof.record_w_out_evals.clone(), + ], + tower_proofs, + expected_max_round, + num_product_fanin, + transcript, + )?; + + // verify zero statement (degree > 1) + sel sumcheck + // TODO fix rt_r/rt_w to use real let (rt_r, rt_w): (Vec, Vec) = ( (0..(log2_num_instances + log2_r_count)) .map(|i| E::from(i as u64)) @@ -52,8 +83,9 @@ impl ZKVMVerifier { let alpha_pow = get_challenge_pows(MAINCONSTRAIN_SUMCHECK_BATCH_SIZE, transcript); let (alpha_read, alpha_write) = (&alpha_pow[0], &alpha_pow[1]); - let claim_sum = *alpha_read * (proof.out_record_r_eval - E::ONE) - + *alpha_write * (proof.out_record_w_eval - E::ONE); + // let claim_sum = *alpha_read * (proof.record_r_sel_eval - E::ONE) + // + *alpha_write * (proof.record_w_sel_eval - E::ONE); + let claim_sum = E::ONE; // TODO FIXME println!( "verifier alpha_read {:?} alpha_write {:?}", alpha_read, alpha_write @@ -62,7 +94,7 @@ impl ZKVMVerifier { claim_sum, &IOPProof { point: vec![], // final claimed point will be derive from sumcheck protocol - proofs: mem::take(&mut proof.main_sel_sumcheck_proofs), + proofs: proof.main_sel_sumcheck_proofs.clone(), }, &VPAuxInfo { max_degree: 2, @@ -93,6 +125,7 @@ impl ZKVMVerifier { (0..w_counts_per_instance) .map(|i| sel_w * proof.w_records_in_evals[i] * eq_w[i] * alpha_write) .sum::(), + // write padding (w_counts_per_instance..w_counts_per_instance.next_power_of_two()) .map(|i| sel_w * (eq_w[i] * alpha_write - E::ONE)) .sum::(), @@ -100,7 +133,9 @@ impl ZKVMVerifier { .iter() .sum::(); if computed_evals != expected_evaluation { - return Err(ZKVMError::VerifyError); + return Err(ZKVMError::VerifyError( + "main + sel constraints verify failed", + )); } // verify records (degree = 1) statement, thus no sumcheck let _input_opening_point = main_sel_eval_point; @@ -109,3 +144,93 @@ impl ZKVMVerifier { Ok(()) } } + +pub struct TowerVerify; + +impl TowerVerify { + pub fn verify( + initial_evals: Vec>, + tower_proofs: &TowerProofs, + expected_max_round: usize, + num_product_fanin: usize, + transcript: &mut Transcript, + ) -> Result, ZKVMError> { + let log2_num_product_fanin = ceil_log2(num_product_fanin); + // sanity check + assert!(initial_evals.len() == tower_proofs.spec_size()); + assert!( + initial_evals + .iter() + .all(|evals| evals.len() == num_product_fanin) + ); + + let alpha_pows = get_challenge_pows(tower_proofs.spec_size(), transcript); + let initial_rt: Point = (0..log2_num_product_fanin) + .map(|_| transcript.get_and_append_challenge(b"product_sum").elements) + .collect_vec(); + // initial_claim = \sum_j alpha^j * record_{j}[rt] + let initial_claim = izip!(initial_evals, alpha_pows.iter()) + .map(|(evals, alpha)| evals.into_mle().evaluate(&initial_rt) * alpha) + .sum(); + + let next_rt = (0..(expected_max_round - 1)).fold( + PointAndEval { + point: initial_rt, + eval: initial_claim, + }, + |point_and_eval, round| { + let (_rt, out_claim) = (&point_and_eval.point, &point_and_eval.eval); + let sumcheck_claim = IOPVerifierState::verify( + *out_claim, + &IOPProof { + point: vec![], // final claimed point will be derive from sumcheck protocol + proofs: tower_proofs.proofs[round].clone(), + }, + &VPAuxInfo { + max_degree: NUM_PRODUCT_FANIN + 1, // + 1 for eq + num_variables: (round + 1) * log2_num_product_fanin, + phantom: PhantomData, + }, + transcript, + ); + + // TODO check expected_evaluation + let point: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); + + // derive single eval + // rt' = r_merge || rt + // r_merge.len() == ceil_log2(num_product_fanin) + let mut rt_prime = (0..log2_num_product_fanin) + .map(|_| transcript.get_and_append_challenge(b"merge").elements) + .collect_vec(); + let coeffs = build_eq_x_r_vec_sequential(&rt_prime); + rt_prime.extend(point); + assert_eq!(coeffs.len(), num_product_fanin); + let spec_evals = (0..tower_proofs.spec_size()).map(|spec_index| { + if round < tower_proofs.specs_eval[spec_index].len() { + // merged evaluation + izip!( + tower_proofs.specs_eval[spec_index][round].iter(), + coeffs.iter() + ) + .map(|(a, b)| *a * b) + .sum::() + } else { + E::ZERO + } + }); + // sum evaluation from different specs + let next_eval = spec_evals + .zip(alpha_pows.iter()) + .map(|(eval, alpha)| eval * alpha) + .sum(); + PointAndEval { + point: rt_prime, + eval: next_eval, + } + }, + ); + + Ok(next_rt.point) + } +} diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 0961ae869..8bdd208fb 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -1,5 +1,6 @@ use ff_ext::ExtensionField; use itertools::Itertools; +use multilinear_extensions::virtual_poly_v2::ArcMultilinearExtension; use transcript::Transcript; pub(crate) fn i64_to_base_field(x: i64) -> E::BaseField { diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 11f3b21b1..03fa9eea3 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -997,6 +997,104 @@ macro_rules! op_mle { }; } +#[macro_export] +macro_rules! op_mle_3 { + (|$f1:ident, $f2:ident, $f3:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { + match (&$f1.evaluations(), &$f2.evaluations(), &$f3.evaluations()) { + ( + $crate::mle::FieldType::Base(f1), + $crate::mle::FieldType::Base(f2), + $crate::mle::FieldType::Base(f3), + ) => { + let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { + &f1[start..][..offset] + } else { + &f1[..] + }; + let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { + &f2[start..][..offset] + } else { + &f2[..] + }; + let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { + &f3[start..][..offset] + } else { + &f3[..] + }; + let $bb_out = $op; + $op_bb_out + } + ( + $crate::mle::FieldType::Ext(f1), + $crate::mle::FieldType::Base(f2), + $crate::mle::FieldType::Base(f3), + ) => { + let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { + &f1[start..][..offset] + } else { + &f1[..] + }; + let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { + &f2[start..][..offset] + } else { + &f2[..] + }; + let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { + &f3[start..][..offset] + } else { + &f3[..] + }; + $op + } + ( + $crate::mle::FieldType::Ext(f1), + $crate::mle::FieldType::Ext(f2), + $crate::mle::FieldType::Ext(f3), + ) => { + let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { + &f1[start..][..offset] + } else { + &f1[..] + }; + let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { + &f2[start..][..offset] + } else { + &f2[..] + }; + let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { + &f3[start..][..offset] + } else { + &f3[..] + }; + $op + } + ( + $crate::mle::FieldType::Ext(f1), + $crate::mle::FieldType::Ext(f2), + $crate::mle::FieldType::Base(f3), + ) => { + let $f1 = if let Some((start, offset)) = $f1.evaluations_range() { + &f1[start..][..offset] + } else { + &f1[..] + }; + let $f2 = if let Some((start, offset)) = $f2.evaluations_range() { + &f2[start..][..offset] + } else { + &f2[..] + }; + let $f3 = if let Some((start, offset)) = $f3.evaluations_range() { + &f3[start..][..offset] + } else { + &f3[..] + }; + $op + } + _ => unreachable!(), + } + }; +} + /// macro support op(a, b) and tackles type matching internally. /// Please noted that op must satisfy commutative rule w.r.t op(b, a) operand swap. #[macro_export] diff --git a/sumcheck/src/prover_v2.rs b/sumcheck/src/prover_v2.rs index f786ace38..b7a8e8922 100644 --- a/sumcheck/src/prover_v2.rs +++ b/sumcheck/src/prover_v2.rs @@ -6,7 +6,7 @@ use ff_ext::ExtensionField; use multilinear_extensions::{ commutative_op_mle_pair, mle::{DenseMultilinearExtension, MultilinearExtension}, - op_mle, + op_mle, op_mle_3, virtual_poly_v2::VirtualPolynomialV2, }; use rayon::{ @@ -728,7 +728,36 @@ impl<'a, E: ExtensionField> IOPProverStateV2<'a, E> { ) .to_vec() } - _ => unimplemented!("do not support degree > 2"), + 3 => { + let (f1, f2, f3) = ( + &self.poly.flattened_ml_extensions[products[0]], + &self.poly.flattened_ml_extensions[products[1]], + &self.poly.flattened_ml_extensions[products[2]], + ); + op_mle_3!( + |f1, f2, f3| (0..f1.len()) + .into_iter() + .step_by(2) + .map(|b| { + // f = c x + d + let c1 = f1[b + 1] - f1[b]; + let c2 = f2[b + 1] - f2[b]; + let c3 = f3[b + 1] - f3[b]; + AdditiveArray([ + f1[b] * (f2[b] * f3[b]), + f1[b + 1] * (f2[b + 1] * f3[b + 1]), + (c1 + f1[b + 1]) + * ((c2 + f2[b + 1]) * (c3 + f3[b + 1])), + (c1 + c1 + f1[b + 1]) + * ((c2 + c2 + f2[b + 1]) * (c3 + c3 + f3[b + 1])), + ]) + }) + .sum::>(), + |sum| AdditiveArray(sum.0.map(E::from)) + ) + .to_vec() + } + _ => unimplemented!("do not support degree > 3"), }; exit_span!(span); sum.iter_mut().for_each(|sum| *sum *= coefficient);