From 128731563ccc85e9e828e2f17495f4dbaf72bda9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 6 Jan 2025 15:19:28 +0800 Subject: [PATCH 01/12] migrate field from in-housed goldilock to plonky3 --- Cargo.lock | 167 +++++++++++++++++++++++++++- Cargo.toml | 3 + ff_ext/Cargo.toml | 4 +- ff_ext/src/lib.rs | 115 +++++++++++++------ multilinear_extensions/Cargo.toml | 3 + multilinear_extensions/src/mle.rs | 13 ++- poseidon/Cargo.toml | 4 + poseidon/src/poseidon.rs | 2 +- poseidon/src/poseidon_goldilocks.rs | 6 +- 9 files changed, 275 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1d9cc9427..b742a9176 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -742,7 +742,9 @@ version = "0.1.0" dependencies = [ "ff", "goldilocks", - "poseidon", + "p3-field", + "p3-goldilocks", + "p3-poseidon", "serde", ] @@ -779,6 +781,12 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +[[package]] +name = "gcd" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" + [[package]] name = "generic-array" version = "0.14.7" @@ -1151,6 +1159,9 @@ dependencies = [ "goldilocks", "itertools 0.13.0", "log", + "p3-field", + "p3-goldilocks", + "p3-poseidon", "rayon", "serde", "tracing", @@ -1314,6 +1325,18 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "nums" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf3c74f925fb8cfc49a8022f2afce48a0683b70f9e439885594e84c5edbf5b01" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", + "rand", +] + [[package]] name = "object" version = "0.36.5" @@ -1341,6 +1364,128 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "p3-dft" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "itertools 0.13.0", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-util", + "tracing", +] + +[[package]] +name = "p3-field" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "nums", + "p3-maybe-rayon", + "p3-util", + "rand", + "serde", + "tracing", +] + +[[package]] +name = "p3-goldilocks" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "num-bigint", + "p3-dft", + "p3-field", + "p3-mds", + "p3-poseidon", + "p3-poseidon2", + "p3-symmetric", + "p3-util", + "rand", + "serde", +] + +[[package]] +name = "p3-matrix" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "itertools 0.13.0", + "p3-field", + "p3-maybe-rayon", + "p3-util", + "rand", + "serde", + "tracing", + "transpose", +] + +[[package]] +name = "p3-maybe-rayon" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" + +[[package]] +name = "p3-mds" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "itertools 0.13.0", + "p3-dft", + "p3-field", + "p3-matrix", + "p3-symmetric", + "p3-util", + "rand", +] + +[[package]] +name = "p3-poseidon" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "p3-field", + "p3-mds", + "p3-symmetric", + "rand", +] + +[[package]] +name = "p3-poseidon2" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "gcd", + "p3-field", + "p3-mds", + "p3-symmetric", + "rand", +] + +[[package]] +name = "p3-symmetric" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "itertools 0.13.0", + "p3-field", + "serde", +] + +[[package]] +name = "p3-util" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +dependencies = [ + "serde", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -1482,7 +1627,11 @@ dependencies = [ "ark-std", "criterion", "ff", + "ff_ext", "goldilocks", + "p3-field", + "p3-goldilocks", + "p3-poseidon", "plonky2", "rand", "serde", @@ -1966,6 +2115,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" @@ -2276,6 +2431,16 @@ dependencies = [ "serde", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 0a5361732..bd6785d88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,9 @@ num-derive = "0.4" num-traits = "0.2" paste = "1" plonky2 = "0.2" +p3-field = { git = "https://github.com/plonky3/plonky3" } +p3-goldilocks = { git = "https://github.com/plonky3/plonky3" } +p3-poseidon = { git = "https://github.com/plonky3/plonky3" } poseidon = { path = "./poseidon" } pprof2 = { version = "0.13", features = ["flamegraph"] } prettytable-rs = "^0.10" diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index 3b55f3581..e324e2f5b 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -12,5 +12,7 @@ version.workspace = true [dependencies] ff.workspace = true goldilocks.workspace = true -poseidon.workspace = true serde.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-poseidon.workspace = true diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index 32d77a565..48d4598fd 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -1,33 +1,48 @@ #![deny(clippy::cargo)] + pub use ff; -use ff::FromUniformBytes; -use goldilocks::SmallField; -use poseidon::poseidon::Poseidon; +use p3_field::{ExtensionField as P3ExtensionField, Field as P3Field}; use serde::Serialize; -use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}; - -pub trait ExtensionField: - Serialize - + FromUniformBytes<64> - + From - + Add - + Sub - + Mul - + for<'a> Add<&'a Self::BaseField, Output = Self> - + for<'a> Sub<&'a Self::BaseField, Output = Self> - + for<'a> Mul<&'a Self::BaseField, Output = Self> - + AddAssign - + SubAssign - + MulAssign - + for<'a> AddAssign<&'a Self::BaseField> - + for<'a> SubAssign<&'a Self::BaseField> - + for<'a> MulAssign<&'a Self::BaseField> - + Ord - + std::hash::Hash + +// TODO remove SmallField +pub trait SmallField: Serialize + P3Field { + /// MODULUS as u64 + const MODULUS_U64: u64; + + /// Identifier string + const NAME: &'static str; + + /// Convert a byte string into a list of field elements + fn bytes_to_field_elements(bytes: &[u8]) -> Vec; + + /// Convert a field elements to a u64. + fn to_canonical_u64(&self) -> u64; + + /// Convert a field elements to a u64. Do not normalize it. + fn to_noncanonical_u64(&self) -> u64; +} + +pub trait ExtensionField: P3ExtensionField +// + FromUniformBytes<64> +// + From +// + Add +// + Sub +// + Mul +// // + for<'a> Add<&'a Self::BaseField, Output = Self> +// + for<'a> Sub<&'a Self::BaseField, Output = Self> +// + for<'a> Mul<&'a Self::BaseField, Output = Self> +// + AddAssign +// + SubAssign +// + MulAssign +// + for<'a> AddAssign<&'a Self::BaseField> +// + for<'a> SubAssign<&'a Self::BaseField> +// + for<'a> MulAssign<&'a Self::BaseField> +// + Ord +// + std::hash::Hash { const DEGREE: usize; - type BaseField: SmallField + FromUniformBytes<64> + Poseidon + Ord; + type BaseField: SmallField + Ord + P3Field; fn from_bases(bases: &[Self::BaseField]) -> Self; @@ -41,30 +56,68 @@ pub trait ExtensionField: } mod impl_goldilocks { - use crate::ExtensionField; - use goldilocks::{ExtensionField as GoldilocksEF, Goldilocks, GoldilocksExt2}; + use crate::{ExtensionField, SmallField}; + use p3_field::{ + FieldAlgebra, FieldExtensionAlgebra, PrimeField64, extension::BinomialExtensionField, + }; + use p3_goldilocks::Goldilocks; + + impl SmallField for Goldilocks { + /// Identifier string + const NAME: &'static str = "Goldilocks"; + const MODULUS_U64: u64 = Self::ORDER_U64; + + /// Convert a byte string into a list of field elements + fn bytes_to_field_elements(bytes: &[u8]) -> Vec { + bytes + .chunks(8) + .map(|chunk| { + let mut array = [0u8; 8]; + array[..chunk.len()].copy_from_slice(chunk); + unsafe { std::ptr::read_unaligned(array.as_ptr() as *const u64) } + }) + .map(Self::from_canonical_u64) + .collect::>() + } + + /// Convert a field elements to a u64. + fn to_canonical_u64(&self) -> u64 { + self.as_canonical_u64() + } + + /// Convert a field elements to a u64. Do not normalize it. + fn to_noncanonical_u64(&self) -> u64 { + self.as_canonical_u64() + } + } - impl ExtensionField for GoldilocksExt2 { + impl ExtensionField for BinomialExtensionField { const DEGREE: usize = 2; type BaseField = Goldilocks; fn from_bases(bases: &[Goldilocks]) -> Self { debug_assert_eq!(bases.len(), 2); - Self([bases[0], bases[1]]) + Self::from_base_slice(bases) + // Self([bases[0], bases[1]]) } fn as_bases(&self) -> &[Goldilocks] { - self.0.as_slice() + self.as_base_slice() } /// Convert limbs into self fn from_limbs(limbs: &[Self::BaseField]) -> Self { - Self([limbs[0], limbs[1]]) + // Self([limbs[0], limbs[1]]) + Self::from_base_slice(&limbs[0..2]) } fn to_canonical_u64_vec(&self) -> Vec { - ::to_canonical_u64_vec(self) + self.as_base_slice() + .iter() + .map(|v: &Self::BaseField| v.as_canonical_u64()) + .collect() + // ::to_canonical_u64_vec(self) } } } diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index 1a8777641..a4abdb6ea 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -18,6 +18,9 @@ itertools.workspace = true rayon.workspace = true serde.workspace = true tracing.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-poseidon.workspace = true [dev-dependencies] env_logger = "0.11" diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index b4e8df983..d43412faa 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -5,10 +5,11 @@ use ark_std::{end_timer, rand::RngCore, start_timer}; use core::hash::Hash; use ff::Field; use ff_ext::ExtensionField; +use p3_field::FieldAlgebra; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use std::fmt::Debug; pub trait MultilinearExtension: Send + Sync { @@ -122,7 +123,7 @@ impl> IntoMLEs { @@ -159,7 +160,7 @@ impl FieldType { } /// Stores a multilinear polynomial in dense evaluation form. -#[derive(Clone, PartialEq, Eq, Default, Debug, Serialize, Deserialize)] +#[derive(Clone, PartialEq, Eq, Default, Debug, Serialize)] pub struct DenseMultilinearExtension { /// The evaluation over {0,1}^`num_vars` pub evaluations: FieldType, @@ -489,7 +490,7 @@ impl MultilinearExtension for DenseMultilinearExtension FieldType::Ext(evaluations) => { (0..evaluations.len()).step_by(2).for_each(|b| { evaluations[b >> 1] = - evaluations[b] + (evaluations[b + 1] - evaluations[b]) * point + evaluations[b] + (evaluations[b + 1] - evaluations[b]) * *point }); } FieldType::Unreachable => unreachable!(), @@ -568,7 +569,7 @@ impl MultilinearExtension for DenseMultilinearExtension lo.par_iter_mut() .zip(hi) .with_min_len(64) - .for_each(|(lo, hi)| *lo += (*hi - *lo) * point); + .for_each(|(lo, hi)| *lo += (*hi - *lo) * *point); current_eval_size = half_size; } FieldType::Unreachable => unreachable!(), @@ -672,7 +673,7 @@ impl MultilinearExtension for DenseMultilinearExtension .par_iter_mut() .chunks(2) .with_min_len(64) - .for_each(|mut buf| *buf[0] = *buf[0] + (*buf[1] - *buf[0]) * point); + .for_each(|mut buf| *buf[0] = *buf[0] + (*buf[1] - *buf[0]) * *point); // sequentially update buf[b1, b2,..bt] = buf[b1, b2,..bt, 0] for index in 0..1 << (max_log2_size - 1) { diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index eff0f50b7..69709a28f 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -15,6 +15,10 @@ ff.workspace = true goldilocks.workspace = true serde.workspace = true unroll = "0.1" +ff_ext = { path = "../ff_ext" } +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-poseidon.workspace = true [dev-dependencies] ark-std.workspace = true diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index ed5d76d14..ee896d25a 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -1,7 +1,7 @@ use crate::constants::{ ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, }; -use goldilocks::SmallField; +use ff_ext::ExtensionField; use unroll::unroll_for_loops; pub trait Poseidon: AdaptedField { diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs index eaab6fcd0..5844fed0d 100644 --- a/poseidon/src/poseidon_goldilocks.rs +++ b/poseidon/src/poseidon_goldilocks.rs @@ -2,7 +2,9 @@ use crate::{ constants::N_PARTIAL_ROUNDS, poseidon::{AdaptedField, Poseidon}, }; -use goldilocks::{EPSILON, Goldilocks, SmallField}; +use goldilocks::EPSILON; +use p3_field::PrimeField64; +use p3_goldilocks::Goldilocks; #[cfg(target_arch = "x86_64")] use std::hint::unreachable_unchecked; @@ -214,7 +216,7 @@ impl Poseidon for Goldilocks { } impl AdaptedField for Goldilocks { - const ORDER: u64 = Goldilocks::MODULUS_U64; + const ORDER: u64 = Goldilocks::ORDER_U64; fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self { reduce96((n_lo, n_hi)) From c372730230a0101500d13579bcb32e16f66b4b5f Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 6 Jan 2025 16:42:58 +0800 Subject: [PATCH 02/12] mle works --- Cargo.lock | 1 + ff_ext/Cargo.toml | 1 + ff_ext/src/lib.rs | 91 +++++++- multilinear_extensions/src/mle.rs | 11 +- multilinear_extensions/src/test.rs | 43 ++-- multilinear_extensions/src/virtual_poly.rs | 5 +- poseidon/benches/hashing.rs | 256 ++++++++++----------- poseidon/src/digest.rs | 6 +- poseidon/src/poseidon.rs | 9 +- poseidon/src/poseidon_goldilocks.rs | 11 +- poseidon/src/poseidon_hash.rs | 180 +++++++-------- transcript/src/basic.rs | 1 - 12 files changed, 340 insertions(+), 275 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b742a9176..f46fc40cb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -745,6 +745,7 @@ dependencies = [ "p3-field", "p3-goldilocks", "p3-poseidon", + "rand_core", "serde", ] diff --git a/ff_ext/Cargo.toml b/ff_ext/Cargo.toml index e324e2f5b..1a2bd840a 100644 --- a/ff_ext/Cargo.toml +++ b/ff_ext/Cargo.toml @@ -13,6 +13,7 @@ version.workspace = true ff.workspace = true goldilocks.workspace = true serde.workspace = true +rand_core.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true p3-poseidon.workspace = true diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index 48d4598fd..cc67be770 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -1,9 +1,76 @@ #![deny(clippy::cargo)] +use std::{array::from_fn, iter::repeat_with}; + pub use ff; -use p3_field::{ExtensionField as P3ExtensionField, Field as P3Field}; +use p3_field::{ + ExtensionField as P3ExtensionField, Field as P3Field, PackedValue, + extension::BinomialExtensionField, +}; +use p3_goldilocks::Goldilocks; +use rand_core::RngCore; use serde::Serialize; +pub type GoldilocksExt2 = BinomialExtensionField; + +fn array_try_from_uniform_bytes< + F: Copy + Default + FromUniformBytes, + const W: usize, + const N: usize, +>( + bytes: &[u8], +) -> Option<[F; N]> { + let mut array = [F::default(); N]; + for i in 0..N { + array[i] = F::try_from_uniform_bytes(from_fn(|j| bytes[i * W + j]))?; + } + Some(array) +} + +pub trait FromUniformBytes: Sized { + type Bytes: Copy + Default + AsRef<[u8]> + AsMut<[u8]>; + + fn from_uniform_bytes(mut fill: impl FnMut(&mut [u8])) -> Self { + let mut bytes = Self::Bytes::default(); + loop { + fill(bytes.as_mut()); + if let Some(value) = Self::try_from_uniform_bytes(bytes) { + return value; + } + } + } + + fn try_from_uniform_bytes(bytes: Self::Bytes) -> Option; + + fn random(mut rng: impl RngCore) -> Self { + Self::from_uniform_bytes(|bytes| rng.fill_bytes(bytes.as_mut())) + } + + fn random_vec(n: usize, mut rng: impl RngCore) -> Vec { + repeat_with(|| Self::random(&mut rng)).take(n).collect() + } +} + +macro_rules! impl_from_uniform_bytes_for_binomial_extension { + ($base:ty, $degree:literal) => { + impl FromUniformBytes for p3_field::extension::BinomialExtensionField<$base, $degree> { + type Bytes = [u8; <$base as FromUniformBytes>::Bytes::WIDTH * $degree]; + + fn try_from_uniform_bytes(bytes: Self::Bytes) -> Option { + Some(p3_field::FieldExtensionAlgebra::from_base_slice( + &array_try_from_uniform_bytes::< + $base, + { <$base as FromUniformBytes>::Bytes::WIDTH }, + $degree, + >(&bytes)?, + )) + } + } + }; +} + +impl_from_uniform_bytes_for_binomial_extension!(p3_goldilocks::Goldilocks, 2); + // TODO remove SmallField pub trait SmallField: Serialize + P3Field { /// MODULUS as u64 @@ -22,7 +89,7 @@ pub trait SmallField: Serialize + P3Field { fn to_noncanonical_u64(&self) -> u64; } -pub trait ExtensionField: P3ExtensionField +pub trait ExtensionField: P3ExtensionField + FromUniformBytes // + FromUniformBytes<64> // + From // + Add @@ -42,7 +109,7 @@ pub trait ExtensionField: P3ExtensionField { const DEGREE: usize; - type BaseField: SmallField + Ord + P3Field; + type BaseField: SmallField + Ord + P3Field + FromUniformBytes; fn from_bases(bases: &[Self::BaseField]) -> Self; @@ -56,12 +123,20 @@ pub trait ExtensionField: P3ExtensionField } mod impl_goldilocks { - use crate::{ExtensionField, SmallField}; - use p3_field::{ - FieldAlgebra, FieldExtensionAlgebra, PrimeField64, extension::BinomialExtensionField, - }; + use crate::{ExtensionField, FromUniformBytes, GoldilocksExt2, SmallField}; + use p3_field::{FieldAlgebra, FieldExtensionAlgebra, PrimeField64}; use p3_goldilocks::Goldilocks; + impl FromUniformBytes for Goldilocks { + type Bytes = [u8; 8]; + + fn try_from_uniform_bytes(bytes: [u8; 8]) -> Option { + let value = u64::from_le_bytes(bytes); + let is_canonical = value < Self::ORDER_U64; + is_canonical.then(|| Self::from_canonical_u64(value)) + } + } + impl SmallField for Goldilocks { /// Identifier string const NAME: &'static str = "Goldilocks"; @@ -91,7 +166,7 @@ mod impl_goldilocks { } } - impl ExtensionField for BinomialExtensionField { + impl ExtensionField for GoldilocksExt2 { const DEGREE: usize = 2; type BaseField = Goldilocks; diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index d43412faa..39de4fc3f 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -4,7 +4,7 @@ use crate::{op_mle, util::ceil_log2}; use ark_std::{end_timer, rand::RngCore, start_timer}; use core::hash::Hash; use ff::Field; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, FromUniformBytes}; use p3_field::FieldAlgebra; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, @@ -330,15 +330,6 @@ impl DenseMultilinearExtension { end_timer!(start); list } - - pub fn to_ext_field(&self) -> Self { - op_mle!(self, |evaluations| { - DenseMultilinearExtension::from_evaluations_ext_vec( - self.num_vars(), - evaluations.iter().cloned().map(E::from).collect(), - ) - }) - } } #[allow(clippy::wrong_self_convention)] diff --git a/multilinear_extensions/src/test.rs b/multilinear_extensions/src/test.rs index 91b176e71..7e6a7f562 100644 --- a/multilinear_extensions/src/test.rs +++ b/multilinear_extensions/src/test.rs @@ -1,9 +1,10 @@ use ark_std::test_rng; -use ff::Field; -use ff_ext::ExtensionField; -use goldilocks::{Goldilocks, GoldilocksExt2}; +use ff_ext::{ExtensionField, FromUniformBytes}; +use p3_field::{FieldAlgebra, extension::BinomialExtensionField}; +use p3_goldilocks::Goldilocks; -type E = GoldilocksExt2; +type F = Goldilocks; +type E = BinomialExtensionField; use crate::{ mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, @@ -46,29 +47,31 @@ fn test_eq_xr() { fn test_fix_high_variables() { let poly: DenseMultilinearExtension = DenseMultilinearExtension::from_evaluations_vec(3, vec![ - Goldilocks::from(13), - Goldilocks::from(97), - Goldilocks::from(11), - Goldilocks::from(101), - Goldilocks::from(7), - Goldilocks::from(103), - Goldilocks::from(5), - Goldilocks::from(107), + F::from_canonical_u64(13), + F::from_canonical_u64(97), + F::from_canonical_u64(11), + F::from_canonical_u64(101), + F::from_canonical_u64(7), + F::from_canonical_u64(103), + F::from_canonical_u64(5), + F::from_canonical_u64(107), ]); - let partial_point = vec![E::from(3), E::from(5)]; + let partial_point = vec![E::from_canonical_u64(3), E::from_canonical_u64(5)]; let expected1 = DenseMultilinearExtension::from_evaluations_ext_vec(2, vec![ - -E::from(17), - E::from(127), - -E::from(19), - E::from(131), + -E::from_canonical_u64(17), + E::from_canonical_u64(127), + -E::from_canonical_u64(19), + E::from_canonical_u64(131), ]); let result1 = poly.fix_high_variables(&partial_point[1..]); assert_eq!(result1, expected1); - let expected2 = - DenseMultilinearExtension::from_evaluations_ext_vec(1, vec![-E::from(23), E::from(139)]); + let expected2 = DenseMultilinearExtension::from_evaluations_ext_vec(1, vec![ + -E::from_canonical_u64(23), + E::from_canonical_u64(139), + ]); let result2 = poly.fix_high_variables(&partial_point); assert_eq!(result2, expected2); } @@ -92,7 +95,7 @@ fn build_eq_x_r_for_test(r: &[E]) -> ArcDenseMultilinearExten // we will need 2^num_var evaluations // First, we build array for {1 - r_i} - let one_minus_r: Vec = r.iter().map(|ri| E::ONE - ri).collect(); + let one_minus_r: Vec = r.iter().map(|ri| E::ONE - *ri).collect(); let num_var = r.len(); let mut eval = vec![]; diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index bd50d659a..36773750a 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -220,7 +220,7 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { } for i in 0..1 << self.aux_info.max_num_variables { let point = bit_decompose(i, self.aux_info.max_num_variables); - let point_fr: Vec = point.iter().map(|&x| E::from(x as u64)).collect(); + let point_fr: Vec = point.iter().map(|&x| E::from_bool(x)).collect(); println!("{} {:?}", i, self.evaluate(point_fr.as_ref())) } println!() @@ -371,8 +371,7 @@ pub fn build_eq_x_r_vec(r: &[E]) -> Vec { mod tests { use crate::virtual_poly::{build_eq_x_r_vec, build_eq_x_r_vec_sequential}; use ark_std::rand::thread_rng; - use ff::Field; - use goldilocks::GoldilocksExt2; + use ff_ext::{FromUniformBytes, GoldilocksExt2}; #[test] fn test_build_eq() { diff --git a/poseidon/benches/hashing.rs b/poseidon/benches/hashing.rs index 43a299ddc..d352f7291 100644 --- a/poseidon/benches/hashing.rs +++ b/poseidon/benches/hashing.rs @@ -1,128 +1,128 @@ -use ark_std::test_rng; -use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -use ff::Field; -use goldilocks::Goldilocks; -use plonky2::{ - field::{goldilocks_field::GoldilocksField, types::Sample}, - hash::{ - hash_types::HashOut, - hashing::PlonkyPermutation, - poseidon::{PoseidonHash as PlonkyPoseidonHash, PoseidonPermutation}, - }, - plonk::config::Hasher, -}; -use poseidon::{digest::Digest, poseidon_hash::PoseidonHash}; - -fn random_plonky_2_goldy() -> GoldilocksField { - GoldilocksField::rand() -} - -fn random_ceno_goldy() -> Goldilocks { - Goldilocks::random(&mut test_rng()) -} - -fn random_ceno_hash() -> Digest { - Digest( - vec![Goldilocks::random(&mut test_rng()); 4] - .try_into() - .unwrap(), - ) -} - -fn plonky_hash_single(a: GoldilocksField) { - let _result = black_box(PlonkyPoseidonHash::hash_or_noop(&[a])); -} - -fn ceno_hash_single(a: Goldilocks) { - let _result = black_box(PoseidonHash::hash_or_noop(&[a])); -} - -fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { - let _result = black_box(PlonkyPoseidonHash::two_to_one(left, right)); -} - -fn ceno_hash_2_to_1(left: &Digest, right: &Digest) { - let _result = black_box(PoseidonHash::two_to_one(left, right)); -} - -fn plonky_hash_many_to_1(values: &[GoldilocksField]) { - let _result = black_box(PlonkyPoseidonHash::hash_or_noop(values)); -} - -fn ceno_hash_many_to_1(values: &[Goldilocks]) { - let _result = black_box(PoseidonHash::hash_or_noop(values)); -} - -pub fn hashing_benchmark(c: &mut Criterion) { - c.bench_function("plonky hash single", |bencher| { - bencher.iter_batched( - random_plonky_2_goldy, - plonky_hash_single, - BatchSize::SmallInput, - ) - }); - - c.bench_function("plonky hash 2 to 1", |bencher| { - bencher.iter_batched( - || { - ( - HashOut::::rand(), - HashOut::::rand(), - ) - }, - |(left, right)| plonky_hash_2_to_1(left, right), - BatchSize::SmallInput, - ) - }); - - c.bench_function("plonky hash 60 to 1", |bencher| { - bencher.iter_batched( - || GoldilocksField::rand_vec(60), - |sixty_elems| plonky_hash_many_to_1(sixty_elems.as_slice()), - BatchSize::SmallInput, - ) - }); - - c.bench_function("ceno hash single", |bencher| { - bencher.iter_batched(random_ceno_goldy, ceno_hash_single, BatchSize::SmallInput) - }); - - c.bench_function("ceno hash 2 to 1", |bencher| { - bencher.iter_batched( - || (random_ceno_hash(), random_ceno_hash()), - |(left, right)| ceno_hash_2_to_1(&left, &right), - BatchSize::SmallInput, - ) - }); - - c.bench_function("ceno hash 60 to 1", |bencher| { - bencher.iter_batched( - || { - (0..60) - .map(|_| Goldilocks::random(&mut test_rng())) - .collect::>() - }, - |values| ceno_hash_many_to_1(values.as_slice()), - BatchSize::SmallInput, - ) - }); -} - -// bench permutation -pub fn permutation_benchmark(c: &mut Criterion) { - let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); - let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( - core::iter::repeat(Goldilocks::ZERO), - ); - - c.bench_function("plonky permute", |bencher| { - bencher.iter(|| plonky_permutation.permute()) - }); - - c.bench_function("ceno permute", |bencher| { - bencher.iter(|| ceno_permutation.permute()) - }); -} - -criterion_group!(benches, permutation_benchmark, hashing_benchmark); -criterion_main!(benches); +// use ark_std::test_rng; +// use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; +// use ff::Field; +// use goldilocks::Goldilocks; +// use plonky2::{ +// field::{goldilocks_field::GoldilocksField, types::Sample}, +// hash::{ +// hash_types::HashOut, +// hashing::PlonkyPermutation, +// poseidon::{PoseidonHash as PlonkyPoseidonHash, PoseidonPermutation}, +// }, +// plonk::config::Hasher, +// }; +// use poseidon::{digest::Digest, poseidon_hash::PoseidonHash}; + +// fn random_plonky_2_goldy() -> GoldilocksField { +// GoldilocksField::rand() +// } + +// fn random_ceno_goldy() -> Goldilocks { +// Goldilocks::random(&mut test_rng()) +// } + +// fn random_ceno_hash() -> Digest { +// Digest( +// vec![Goldilocks::random(&mut test_rng()); 4] +// .try_into() +// .unwrap(), +// ) +// } + +// fn plonky_hash_single(a: GoldilocksField) { +// let _result = black_box(PlonkyPoseidonHash::hash_or_noop(&[a])); +// } + +// fn ceno_hash_single(a: Goldilocks) { +// let _result = black_box(PoseidonHash::hash_or_noop(&[a])); +// } + +// fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { +// let _result = black_box(PlonkyPoseidonHash::two_to_one(left, right)); +// } + +// fn ceno_hash_2_to_1(left: &Digest, right: &Digest) { +// let _result = black_box(PoseidonHash::two_to_one(left, right)); +// } + +// fn plonky_hash_many_to_1(values: &[GoldilocksField]) { +// let _result = black_box(PlonkyPoseidonHash::hash_or_noop(values)); +// } + +// fn ceno_hash_many_to_1(values: &[Goldilocks]) { +// let _result = black_box(PoseidonHash::hash_or_noop(values)); +// } + +// pub fn hashing_benchmark(c: &mut Criterion) { +// c.bench_function("plonky hash single", |bencher| { +// bencher.iter_batched( +// random_plonky_2_goldy, +// plonky_hash_single, +// BatchSize::SmallInput, +// ) +// }); + +// c.bench_function("plonky hash 2 to 1", |bencher| { +// bencher.iter_batched( +// || { +// ( +// HashOut::::rand(), +// HashOut::::rand(), +// ) +// }, +// |(left, right)| plonky_hash_2_to_1(left, right), +// BatchSize::SmallInput, +// ) +// }); + +// c.bench_function("plonky hash 60 to 1", |bencher| { +// bencher.iter_batched( +// || GoldilocksField::rand_vec(60), +// |sixty_elems| plonky_hash_many_to_1(sixty_elems.as_slice()), +// BatchSize::SmallInput, +// ) +// }); + +// c.bench_function("ceno hash single", |bencher| { +// bencher.iter_batched(random_ceno_goldy, ceno_hash_single, BatchSize::SmallInput) +// }); + +// c.bench_function("ceno hash 2 to 1", |bencher| { +// bencher.iter_batched( +// || (random_ceno_hash(), random_ceno_hash()), +// |(left, right)| ceno_hash_2_to_1(&left, &right), +// BatchSize::SmallInput, +// ) +// }); + +// c.bench_function("ceno hash 60 to 1", |bencher| { +// bencher.iter_batched( +// || { +// (0..60) +// .map(|_| Goldilocks::random(&mut test_rng())) +// .collect::>() +// }, +// |values| ceno_hash_many_to_1(values.as_slice()), +// BatchSize::SmallInput, +// ) +// }); +// } + +// // bench permutation +// pub fn permutation_benchmark(c: &mut Criterion) { +// let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); +// let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( +// core::iter::repeat(Goldilocks::ZERO), +// ); + +// c.bench_function("plonky permute", |bencher| { +// bencher.iter(|| plonky_permutation.permute()) +// }); + +// c.bench_function("ceno permute", |bencher| { +// bencher.iter(|| ceno_permutation.permute()) +// }); +// } + +// criterion_group!(benches, permutation_benchmark, hashing_benchmark); +// criterion_main!(benches); diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs index a487d676b..a4175cccf 100644 --- a/poseidon/src/digest.rs +++ b/poseidon/src/digest.rs @@ -1,8 +1,8 @@ use crate::constants::DIGEST_WIDTH; -use goldilocks::SmallField; -use serde::{Deserialize, Serialize}; +use ff_ext::SmallField; +use serde::Serialize; -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Default, Serialize, PartialEq, Eq)] pub struct Digest(pub [F; DIGEST_WIDTH]); impl TryFrom> for Digest { diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index ee896d25a..f3736a4af 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -1,7 +1,7 @@ use crate::constants::{ ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, }; -use ff_ext::ExtensionField; +use ff_ext::SmallField; use unroll::unroll_for_loops; pub trait Poseidon: AdaptedField { @@ -247,13 +247,6 @@ pub trait AdaptedField: SmallField { fn multiply_accumulate(&self, x: Self, y: Self) -> Self; - /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. - // TODO: Should probably be unsafe. - fn from_canonical_u64(n: u64) -> Self { - debug_assert!(n < Self::ORDER); - Self::from(n) - } - /// # Safety /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs index 5844fed0d..7a3de236b 100644 --- a/poseidon/src/poseidon_goldilocks.rs +++ b/poseidon/src/poseidon_goldilocks.rs @@ -3,7 +3,7 @@ use crate::{ poseidon::{AdaptedField, Poseidon}, }; use goldilocks::EPSILON; -use p3_field::PrimeField64; +use p3_field::{FieldAlgebra, PrimeField64}; use p3_goldilocks::Goldilocks; #[cfg(target_arch = "x86_64")] use std::hint::unreachable_unchecked; @@ -228,7 +228,10 @@ impl AdaptedField for Goldilocks { fn multiply_accumulate(&self, x: Self, y: Self) -> Self { // u64 + u64 * u64 cannot overflow. - reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128)) + reduce128( + (self.as_canonical_u64() as u128) + + (x.as_canonical_u64() as u128) * (y.as_canonical_u64() as u128), + ) } } @@ -278,7 +281,7 @@ const unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { fn reduce96((x_lo, x_hi): (u64, u32)) -> Goldilocks { let t1 = x_hi as u64 * EPSILON; let t2 = unsafe { add_no_canonicalize_trashing_input(x_lo, t1) }; - Goldilocks(t2) + Goldilocks::from_canonical_u64(t2) } /// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the @@ -296,7 +299,7 @@ fn reduce128(x: u128) -> Goldilocks { } let t1 = x_hi_lo * EPSILON; let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) }; - Goldilocks(t2) + Goldilocks::from_canonical_u64(t2) } #[inline] diff --git a/poseidon/src/poseidon_hash.rs b/poseidon/src/poseidon_hash.rs index c4559248e..186f5d0e5 100644 --- a/poseidon/src/poseidon_hash.rs +++ b/poseidon/src/poseidon_hash.rs @@ -120,93 +120,93 @@ pub fn compress(x: &Digest, y: &Digest) -> Digest { Digest(perm.squeeze()[..DIGEST_WIDTH].try_into().unwrap()) } -#[cfg(test)] -mod tests { - use crate::{digest::Digest, poseidon_hash::PoseidonHash}; - use goldilocks::Goldilocks; - use plonky2::{ - field::{ - goldilocks_field::GoldilocksField, - types::{PrimeField64, Sample}, - }, - hash::{hash_types::HashOut, poseidon::PoseidonHash as PlonkyPoseidonHash}, - plonk::config::{GenericHashOut, Hasher}, - }; - use rand::{Rng, thread_rng}; - - type PlonkyFieldElements = Vec; - type CenoFieldElements = Vec; - - const N_ITERATIONS: usize = 100; - - fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { - values - .iter() - .map(|value| Goldilocks(value.to_canonical_u64())) - .collect() - } - - fn test_vector_pair(n: usize) -> (PlonkyFieldElements, CenoFieldElements) { - let plonky_elems = GoldilocksField::rand_vec(n); - let ceno_elems = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); - (plonky_elems, ceno_elems) - } - - fn random_hash_pair() -> (HashOut, Digest) { - let plonky_random_hash = HashOut::::rand(); - let ceno_equivalent_hash = Digest( - ceno_goldy_from_plonky_goldy(plonky_random_hash.elements.as_slice()) - .try_into() - .unwrap(), - ); - (plonky_random_hash, ceno_equivalent_hash) - } - - fn compare_hash_output( - plonky_hash: HashOut, - ceno_hash: Digest, - ) -> bool { - let plonky_elems = plonky_hash.to_vec(); - let plonky_in_ceno_field = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); - plonky_in_ceno_field == ceno_hash.elements() - } - - #[test] - fn compare_hash() { - let mut rng = thread_rng(); - for _ in 0..N_ITERATIONS { - let n = rng.gen_range(5..=100); - let (plonky_elems, ceno_elems) = test_vector_pair(n); - let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); - let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); - let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); - assert!(compare_hash_output(plonky_out, ceno_out)); - assert!(compare_hash_output(plonky_out, ceno_iter)); - } - } - - #[test] - fn compare_noop() { - let mut rng = thread_rng(); - for _ in 0..N_ITERATIONS { - let n = rng.gen_range(0..=4); - let (plonky_elems, ceno_elems) = test_vector_pair(n); - let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); - let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); - let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); - assert!(compare_hash_output(plonky_out, ceno_out)); - assert!(compare_hash_output(plonky_out, ceno_iter)); - } - } - - #[test] - fn compare_two_to_one() { - for _ in 0..N_ITERATIONS { - let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); - let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); - let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); - let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); - assert!(compare_hash_output(plonky_combined, ceno_combined)); - } - } -} +// #[cfg(test)] +// mod tests { +// use crate::{digest::Digest, poseidon_hash::PoseidonHash}; +// use p3_goldilocks::Goldilocks; +// use plonky2::{ +// field::{ +// goldilocks_field::GoldilocksField, +// types::{PrimeField64, Sample}, +// }, +// hash::{hash_types::HashOut, poseidon::PoseidonHash as PlonkyPoseidonHash}, +// plonk::config::{GenericHashOut, Hasher}, +// }; +// use rand::{Rng, thread_rng}; + +// type PlonkyFieldElements = Vec; +// type CenoFieldElements = Vec; + +// const N_ITERATIONS: usize = 100; + +// fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { +// values +// .iter() +// .map(|value| Goldilocks(value.to_canonical_u64())) +// .collect() +// } + +// fn test_vector_pair(n: usize) -> (PlonkyFieldElements, CenoFieldElements) { +// let plonky_elems = GoldilocksField::rand_vec(n); +// let ceno_elems = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); +// (plonky_elems, ceno_elems) +// } + +// fn random_hash_pair() -> (HashOut, Digest) { +// let plonky_random_hash = HashOut::::rand(); +// let ceno_equivalent_hash = Digest( +// ceno_goldy_from_plonky_goldy(plonky_random_hash.elements.as_slice()) +// .try_into() +// .unwrap(), +// ); +// (plonky_random_hash, ceno_equivalent_hash) +// } + +// fn compare_hash_output( +// plonky_hash: HashOut, +// ceno_hash: Digest, +// ) -> bool { +// let plonky_elems = plonky_hash.to_vec(); +// let plonky_in_ceno_field = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); +// plonky_in_ceno_field == ceno_hash.elements() +// } + +// #[test] +// fn compare_hash() { +// let mut rng = thread_rng(); +// for _ in 0..N_ITERATIONS { +// let n = rng.gen_range(5..=100); +// let (plonky_elems, ceno_elems) = test_vector_pair(n); +// let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); +// let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); +// let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); +// assert!(compare_hash_output(plonky_out, ceno_out)); +// assert!(compare_hash_output(plonky_out, ceno_iter)); +// } +// } + +// #[test] +// fn compare_noop() { +// let mut rng = thread_rng(); +// for _ in 0..N_ITERATIONS { +// let n = rng.gen_range(0..=4); +// let (plonky_elems, ceno_elems) = test_vector_pair(n); +// let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); +// let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); +// let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); +// assert!(compare_hash_output(plonky_out, ceno_out)); +// assert!(compare_hash_output(plonky_out, ceno_iter)); +// } +// } + +// #[test] +// fn compare_two_to_one() { +// for _ in 0..N_ITERATIONS { +// let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); +// let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); +// let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); +// let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); +// assert!(compare_hash_output(plonky_combined, ceno_combined)); +// } +// } +// } diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 24902d105..1c893e5c5 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -1,6 +1,5 @@ use ff::Field; use ff_ext::ExtensionField; -use goldilocks::SmallField; use poseidon::poseidon_permutation::PoseidonPermutation; use crate::{Challenge, ForkableTranscript, Transcript}; From 2f4cb0a3de5820e58890293afbba78c51f64e637 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 6 Jan 2025 21:53:46 +0800 Subject: [PATCH 03/12] transcript compile pass --- Cargo.lock | 39 +++++++++++----- Cargo.toml | 4 ++ ff_ext/src/lib.rs | 4 +- sumcheck/src/prover.rs | 4 +- transcript/Cargo.toml | 8 ++++ transcript/src/basic.rs | 86 ++++++++++++++++++++++++++++-------- transcript/src/lib.rs | 5 ++- transcript/src/statistics.rs | 27 ++++++++--- 8 files changed, 137 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f46fc40cb..cf3ab9229 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1365,10 +1365,22 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "p3-challenger" +version = "0.1.0" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" +dependencies = [ + "p3-field", + "p3-maybe-rayon", + "p3-symmetric", + "p3-util", + "tracing", +] + [[package]] name = "p3-dft" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "itertools 0.13.0", "p3-field", @@ -1381,7 +1393,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "itertools 0.13.0", "num-bigint", @@ -1398,7 +1410,7 @@ dependencies = [ [[package]] name = "p3-goldilocks" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "num-bigint", "p3-dft", @@ -1415,7 +1427,7 @@ dependencies = [ [[package]] name = "p3-matrix" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "itertools 0.13.0", "p3-field", @@ -1430,12 +1442,12 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" [[package]] name = "p3-mds" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "itertools 0.13.0", "p3-dft", @@ -1449,7 +1461,7 @@ dependencies = [ [[package]] name = "p3-poseidon" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "p3-field", "p3-mds", @@ -1460,7 +1472,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "gcd", "p3-field", @@ -1472,7 +1484,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "itertools 0.13.0", "p3-field", @@ -1482,7 +1494,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#f7eb441c94d098c510af9afa47b6743eb6b5261c" +source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" dependencies = [ "serde", ] @@ -2428,6 +2440,13 @@ dependencies = [ "ff", "ff_ext", "goldilocks", + "p3-challenger", + "p3-field", + "p3-goldilocks", + "p3-mds", + "p3-poseidon", + "p3-poseidon2", + "p3-symmetric", "poseidon", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index bd6785d88..5fef73980 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,10 @@ plonky2 = "0.2" p3-field = { git = "https://github.com/plonky3/plonky3" } p3-goldilocks = { git = "https://github.com/plonky3/plonky3" } p3-poseidon = { git = "https://github.com/plonky3/plonky3" } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-challenger = { git = "https://github.com/plonky3/plonky3" } +p3-mds = { git = "https://github.com/Plonky3/Plonky3.git" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git" } poseidon = { path = "./poseidon" } pprof2 = { version = "0.13", features = ["flamegraph"] } prettytable-rs = "^0.10" diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index cc67be770..e9ca39e0a 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -4,7 +4,7 @@ use std::{array::from_fn, iter::repeat_with}; pub use ff; use p3_field::{ - ExtensionField as P3ExtensionField, Field as P3Field, PackedValue, + ExtensionField as P3ExtensionField, Field as P3Field, PackedValue, PrimeField, extension::BinomialExtensionField, }; use p3_goldilocks::Goldilocks; @@ -109,7 +109,7 @@ pub trait ExtensionField: P3ExtensionField + FromUniformBytes { const DEGREE: usize; - type BaseField: SmallField + Ord + P3Field + FromUniformBytes; + type BaseField: SmallField + Ord + PrimeField + FromUniformBytes; fn from_bases(bases: &[Self::BaseField]) -> Self; diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index af4169d83..6302a9e6f 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -71,7 +71,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // extrapolation_aux only need to init once let extrapolation_aux = (1..max_degree) .map(|degree| { - let points = (0..1 + degree as u64).map(E::from).collect::>(); + let points = (0..1 + degree as u64) + .map(E::from_canonical_u64) + .collect::>(); let weights = barycentric_weights(&points); (points, weights) }) diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index f784b689f..36b6bd8a5 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -16,3 +16,11 @@ ff_ext = { path = "../ff_ext" } goldilocks.workspace = true poseidon.workspace = true serde.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-poseidon.workspace = true +p3-poseidon2.workspace = true +p3-challenger.workspace = true +p3-mds.workspace = true +p3-symmetric.workspace = true + diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 1c893e5c5..0fb07ed21 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -1,33 +1,75 @@ -use ff::Field; -use ff_ext::ExtensionField; -use poseidon::poseidon_permutation::PoseidonPermutation; +use std::array; use crate::{Challenge, ForkableTranscript, Transcript}; +use ff_ext::{ExtensionField, SmallField}; +use p3_field::{FieldAlgebra, PrimeField}; +use p3_mds::MdsPermutation; +use p3_poseidon::Poseidon; +use p3_symmetric::Permutation; -#[derive(Copy, Clone)] -pub struct BasicTranscript { - permutation: PoseidonPermutation, +#[derive(Clone)] +pub struct BasicTranscript { + // TODO generalized to accept general permutation + poseidon: Poseidon, + state: [E::BaseField; WIDTH], } -impl BasicTranscript { +impl + BasicTranscript +where + Mds: MdsPermutation + Default, +{ /// Create a new IOP transcript. pub fn new(label: &'static [u8]) -> Self { - let mut permutation = PoseidonPermutation::new(core::iter::repeat(E::BaseField::ZERO)); + let mds = Mds::default(); + + // TODO: Should be calculated for the particular field, width and ALPHA. + let half_num_full_rounds = 4; + let num_partial_rounds = 22; + + let num_rounds = 2 * half_num_full_rounds + num_partial_rounds; + let num_constants = WIDTH * num_rounds; + let constants = vec![E::BaseField::ZERO; num_constants]; + + let poseidon = Poseidon::::new( + half_num_full_rounds, + num_partial_rounds, + constants, + mds, + ); + let input: [E::BaseField; WIDTH] = array::from_fn(|_| E::BaseField::ZERO); let label_f = E::BaseField::bytes_to_field_elements(label); - permutation.set_from_slice(label_f.as_slice(), 0); - permutation.permute(); - Self { permutation } + let mut new = BasicTranscript:: { + poseidon, + state: input, + }; + new.set_from_slice(label_f.as_slice(), 0); + new.poseidon.permute_mut(&mut new.state); + new } -} -impl Transcript for BasicTranscript { - fn append_field_elements(&mut self, elements: &[E::BaseField]) { - self.permutation.set_from_slice(elements, 0); - self.permutation.permute(); + /// Set state element `i` to be `elts[i] for i = + /// start_idx..start_idx + n` where `n = min(elts.len(), + /// WIDTH-start_idx)`. Panics if `start_idx > SPONGE_WIDTH`. + fn set_from_slice(&mut self, elts: &[E::BaseField], start_idx: usize) { + let begin = start_idx; + let end = start_idx + elts.len(); + self.state[begin..end].copy_from_slice(elts) } +} +impl Transcript + for BasicTranscript +where + Mds: MdsPermutation + Default, +{ fn append_field_element_ext(&mut self, element: &E) { - self.append_field_elements(element.as_bases()) + self.append_field_elements(element.as_bases()); + } + + fn append_field_elements(&mut self, elements: &[E::BaseField]) { + self.set_from_slice(elements, 0); + self.poseidon.permute_mut(&mut self.state); } fn read_challenge(&mut self) -> Challenge { @@ -37,7 +79,7 @@ impl Transcript for BasicTranscript { // We select `from_base` here to make it more clear that // we only use the first 2 fields here to construct the // challenge as an extension field element. - let elements = E::from_bases(&self.permutation.squeeze()[..2]); + let elements = E::from_bases(&self.state[..8][..2]); Challenge { elements } } @@ -59,4 +101,10 @@ impl Transcript for BasicTranscript { } } -impl ForkableTranscript for BasicTranscript {} +impl ForkableTranscript + for BasicTranscript +where + E::BaseField: FieldAlgebra + PrimeField, + Mds: MdsPermutation + Default, +{ +} diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index 8d01dd366..47a2767b9 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -7,6 +7,8 @@ pub mod basic; mod statistics; pub mod syncronized; pub use basic::BasicTranscript; +use ff_ext::SmallField; +use p3_field::FieldAlgebra; pub use statistics::{BasicTranscriptWithStat, StatisticRecorder}; pub use syncronized::TranscriptSyncronized; @@ -16,7 +18,6 @@ pub struct Challenge { } use ff_ext::ExtensionField; -use goldilocks::SmallField; /// The Transcript trait pub trait Transcript { /// Append a slice of base field elemets to the transcript. @@ -98,7 +99,7 @@ pub trait ForkableTranscript: Transcript + Sized + Clone { (0..n) .map(|i| { let mut fork = self.clone(); - fork.append_field_element(&(i as u64).into()); + fork.append_field_element(&E::BaseField::from_canonical_u64(i as u64)); fork }) .collect() diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index 113f4aa1c..f6a7205fd 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -1,5 +1,6 @@ use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript}; use ff_ext::ExtensionField; +use p3_mds::MdsPermutation; use std::cell::RefCell; #[derive(Debug, Default)] @@ -10,21 +11,30 @@ pub struct Statistic { pub type StatisticRecorder = RefCell; #[derive(Clone)] -pub struct BasicTranscriptWithStat<'a, E: ExtensionField> { - inner: BasicTranscript, +pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds, const WIDTH: usize, const ALPHA: u64> +{ + inner: BasicTranscript, stat: &'a StatisticRecorder, } -impl<'a, E: ExtensionField> BasicTranscriptWithStat<'a, E> { +impl<'a, E: ExtensionField, Mds, const WIDTH: usize, const ALPHA: u64> + BasicTranscriptWithStat<'a, E, Mds, WIDTH, ALPHA> +where + Mds: MdsPermutation + Default, +{ pub fn new(stat: &'a StatisticRecorder, label: &'static [u8]) -> Self { Self { - inner: BasicTranscript::<_>::new(label), + inner: BasicTranscript::<_, _, _, _>::new(label), stat, } } } -impl Transcript for BasicTranscriptWithStat<'_, E> { +impl Transcript + for BasicTranscriptWithStat<'_, E, Mds, WIDTH, ALPHA> +where + Mds: MdsPermutation + Default, +{ fn append_field_elements(&mut self, elements: &[E::BaseField]) { self.stat.borrow_mut().field_appended_num += 1; self.inner.append_field_elements(elements) @@ -56,4 +66,9 @@ impl Transcript for BasicTranscriptWithStat<'_, E> { } } -impl ForkableTranscript for BasicTranscriptWithStat<'_, E> {} +impl ForkableTranscript + for BasicTranscriptWithStat<'_, E, Mds, WIDTH, ALPHA> +where + Mds: MdsPermutation + Default, +{ +} From 0673fd2e57481cacd205c5caeecdb4fcfba3682b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 6 Jan 2025 23:03:10 +0800 Subject: [PATCH 04/12] sumcheck util to p3 field --- Cargo.lock | 2 +- sumcheck/Cargo.toml | 1 + sumcheck/src/prover.rs | 10 +++++--- sumcheck/src/structs.rs | 6 ++--- sumcheck/src/util.rs | 57 +++++++++++++++++++---------------------- transcript/Cargo.toml | 1 - 6 files changed, 37 insertions(+), 40 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf3ab9229..7d9bea908 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2177,6 +2177,7 @@ dependencies = [ "goldilocks", "itertools 0.13.0", "multilinear_extensions", + "p3-field", "rayon", "serde", "tracing", @@ -2447,7 +2448,6 @@ dependencies = [ "p3-poseidon", "p3-poseidon2", "p3-symmetric", - "poseidon", "serde", ] diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 5d747be06..71c223136 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -18,6 +18,7 @@ itertools.workspace = true rayon.workspace = true serde.workspace = true tracing.workspace = true +p3-field.workspace = true crossbeam-channel.workspace = true multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 6302a9e6f..edfd96ccd 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -542,7 +542,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; - let at = E::from((products.len() + 1 + i) as u64); + let at = E::from_canonical_u64((products.len() + 1 + i) as u64); serial_extrapolate(points, weights, &sum, &at) }) .collect::>(); @@ -684,7 +684,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { poly: polynomial, extrapolation_aux: (1..max_degree) .map(|degree| { - let points = (0..1 + degree as u64).map(E::from).collect::>(); + let points = (0..1 + degree as u64) + .map(E::from_canonical_u64) + .collect::>(); let weights = barycentric_weights(&points); (points, weights) }) @@ -891,14 +893,14 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { _ => unimplemented!("do not support degree > 3"), }; exit_span!(span); - sum.iter_mut().for_each(|sum| *sum *= coefficient); + sum.iter_mut().for_each(|sum| *sum *= *coefficient); let span = entered_span!("extrapolation"); let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) .into_par_iter() .map(|i| { let (points, weights) = &self.extrapolation_aux[products.len() - 1]; - let at = E::from((products.len() + 1 + i) as u64); + let at = E::from_canonical_u64((products.len() + 1 + i) as u64); extrapolate(points, weights, &sum, &at) }) .collect::>(); diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index 2316d79aa..a2da024a2 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -1,12 +1,12 @@ use ff_ext::ExtensionField; use multilinear_extensions::virtual_poly::VirtualPolynomial; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use transcript::Challenge; /// An IOP proof is a collections of /// - messages from prover to verifier at each round through the interactive protocol. /// - a point that is generated by the transcript for evaluation -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] pub struct IOPProof { pub point: Vec, pub proofs: Vec>, @@ -19,7 +19,7 @@ impl IOPProof { /// A message from the prover to the verifier at a given round /// is a list of evaluations. -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] pub struct IOPProverMessage { pub(crate) evaluations: Vec, } diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index 4fa84f1e2..056bc9068 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -7,16 +7,16 @@ use std::{ }; use ark_std::{end_timer, start_timer}; -use ff::PrimeField; use ff_ext::ExtensionField; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, virtual_poly::VirtualPolynomial, }; +use p3_field::Field; use rayon::{prelude::ParallelIterator, slice::ParallelSliceMut}; use crate::structs::IOPProverState; -pub fn barycentric_weights(points: &[F]) -> Vec { +pub fn barycentric_weights(points: &[F]) -> Vec { let mut weights = points .iter() .enumerate() @@ -25,7 +25,7 @@ pub fn barycentric_weights(points: &[F]) -> Vec { .iter() .enumerate() .filter(|&(i, _)| (i != j)) - .map(|(_, point_i)| *point_j - point_i) + .map(|(_, point_i)| *point_j - *point_i) .reduce(|acc, value| acc * value) .unwrap_or(F::ONE) }) @@ -35,17 +35,17 @@ pub fn barycentric_weights(points: &[F]) -> Vec { } // Computes the inverse of each field element in a vector {v_i} using a parallelized batch inversion. -pub fn batch_inversion(v: &mut [F]) { +pub fn batch_inversion(v: &mut [F]) { batch_inversion_and_mul(v, &F::ONE); } // Computes the inverse of each field element in a vector {v_i} sequentially (serial version). -pub fn serial_batch_inversion(v: &mut [F]) { +pub fn serial_batch_inversion(v: &mut [F]) { serial_batch_inversion_and_mul(v, &F::ONE) } // Given a vector of field elements {v_i}, compute the vector {coeff * v_i^(-1)} -pub fn batch_inversion_and_mul(v: &mut [F], coeff: &F) { +pub fn batch_inversion_and_mul(v: &mut [F], coeff: &F) { // Divide the vector v evenly between all available cores let min_elements_per_thread = 1; let num_cpus_available = rayon::current_num_threads(); @@ -60,7 +60,7 @@ pub fn batch_inversion_and_mul(v: &mut [F], coeff: &F) { /// Given a vector of field elements {v_i}, compute the vector {coeff * v_i^(-1)}. /// This method is explicitly single-threaded. -fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { +fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { // Montgomery’s Trick and Fast Implementation of Masked AES // Genelle, Prouff and Quisquater // Section 3.2 @@ -70,16 +70,16 @@ fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { // First pass: compute [a, ab, abc, ...] let mut prod = Vec::with_capacity(v.len()); let mut tmp = F::ONE; - for f in v.iter().filter(|f| !f.is_zero_vartime()) { - tmp.mul_assign(f); + for f in v.iter().filter(|f| !f.is_zero()) { + tmp.mul_assign(*f); prod.push(tmp); } // Invert `tmp`. - tmp = tmp.invert().unwrap(); // Guaranteed to be nonzero. + tmp = tmp.try_inverse().unwrap(); // Guaranteed to be nonzero. // Multiply product by coeff, so all inverses will be scaled by coeff - tmp *= coeff; + tmp *= *coeff; // Second pass: iterate backwards to compute inverses for (f, s) in v @@ -87,7 +87,7 @@ fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { // Backwards .rev() // Ignore normalized elements - .filter(|f| !f.is_zero_vartime()) + .filter(|f| !f.is_zero()) // Backwards, skip last element, fill in one for last term. .zip(prod.into_iter().rev().skip(1).chain(Some(F::ONE))) { @@ -98,27 +98,22 @@ fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { } } -pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { +pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { inner_extrapolate::(points, weights, evals, at) } -pub(crate) fn serial_extrapolate( - points: &[F], - weights: &[F], - evals: &[F], - at: &F, -) -> F { +pub(crate) fn serial_extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { inner_extrapolate::(points, weights, evals, at) } -fn inner_extrapolate( +fn inner_extrapolate( points: &[F], weights: &[F], evals: &[F], at: &F, ) -> F { let (coeffs, sum_inv) = { - let mut coeffs = points.iter().map(|point| *at - point).collect::>(); + let mut coeffs = points.iter().map(|point| *at - *point).collect::>(); if IS_PARALLEL { batch_inversion(&mut coeffs); } else { @@ -126,16 +121,16 @@ fn inner_extrapolate( } let mut sum = F::ZERO; coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { - *coeff *= weight; + *coeff *= *weight; sum += *coeff }); - let sum_inv = sum.invert().unwrap_or(F::ZERO); + let sum_inv = sum.try_inverse().unwrap_or(F::ZERO); (coeffs, sum_inv) }; coeffs .iter() .zip(evals) - .map(|(coeff, eval)| *coeff * eval) + .map(|(coeff, eval)| *coeff * *eval) .sum::() * sum_inv } @@ -150,7 +145,7 @@ fn inner_extrapolate( /// negligible compared to field operations. /// TODO: The quadratic term can be removed by precomputing the lagrange /// coefficients. -pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { +pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { let start = start_timer!(|| "sum check interpolate uni poly opt"); let len = p_i.len(); @@ -160,7 +155,7 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { // `prod = \prod_{j} (eval_at - j)` for e in 1..len { - let tmp = eval_at - F::from(e as u64); + let tmp = eval_at - F::from_canonical_u64(e as u64); evals.push(tmp); prod *= tmp; } @@ -187,12 +182,12 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { let mut denom_down = F::ONE; for i in (0..len).rev() { - res += p_i[i] * prod * denom_down * (denom_up * evals[i]).invert().unwrap(); + res += p_i[i] * prod * denom_down * (denom_up * evals[i]).inverse(); // compute denom for the next step is current_denom * (len-i)/i if i != 0 { - denom_up *= -F::from((len - i) as u64); - denom_down *= F::from(i as u64); + denom_up *= -F::from_canonical_u64((len - i) as u64); + denom_down *= F::from_canonical_u64(i as u64); } } end_timer!(start); @@ -201,10 +196,10 @@ pub(crate) fn interpolate_uni_poly(p_i: &[F], eval_at: F) -> F { /// compute the factorial(a) = 1 * 2 * ... * a #[inline] -fn field_factorial(a: usize) -> F { +fn field_factorial(a: usize) -> F { let mut res = F::ONE; for i in 2..=a { - res *= F::from(i as u64); + res *= F::from_canonical_u64(i as u64); } res } diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index 36b6bd8a5..acbdabe26 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -14,7 +14,6 @@ crossbeam-channel.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true -poseidon.workspace = true serde.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true From 09c20004e038a510a2a6cbdc3b570ed86897fc06 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 10:59:54 +0800 Subject: [PATCH 05/12] sumcheck crate almost compile pass --- multilinear_extensions/src/mle.rs | 9 +++-- sumcheck/src/prover.rs | 58 ++++++++++++++++--------------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 39de4fc3f..3f9c80dcc 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -1105,6 +1105,9 @@ macro_rules! op_mle_product_3 { _ => op_mle_product_3!(@internal |$f1, $f2, $f3| $op, |$bb_out| $op_bb_out), } }; + (|$f1:ident, $f2:ident, $f3:ident| $op:expr) => { + op_mle_product_3!(|$f1, $f2, $f3| $op, |out| out), + }; (@internal |$f1:ident, $f2:ident, $f3:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { match (&$f1.evaluations(), &$f2.evaluations(), &$f3.evaluations()) { ( @@ -1119,21 +1122,21 @@ macro_rules! op_mle_product_3 { $crate::mle::FieldType::Base(f2_vec), $crate::mle::FieldType::Base(f3_vec), ) => { - op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |out| out) } ( $crate::mle::FieldType::Ext(f1_vec), $crate::mle::FieldType::Ext(f2_vec), $crate::mle::FieldType::Ext(f3_vec), ) => { - op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |out| out) } ( $crate::mle::FieldType::Ext(f1_vec), $crate::mle::FieldType::Ext(f2_vec), $crate::mle::FieldType::Base(f3_vec), ) => { - op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |$bb_out| $op_bb_out) + op_mle3_range!($f1, $f2, $f3, f1_vec, f2_vec, f3_vec, $op, |out| out) } // ... add more canonial case if missing (a, b, c) => unreachable!( diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index edfd96ccd..38bf2bdfc 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -1,4 +1,4 @@ -use std::{array, mem, sync::Arc}; +use std::{mem, sync::Arc}; use ark_std::{end_timer, start_timer}; use crossbeam_channel::bounded; @@ -26,6 +26,7 @@ use crate::{ merge_sumcheck_polys, serial_extrapolate, }, }; +use p3_field::FieldAlgebra; impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// Given a virtual polynomial, generate an IOP proof. @@ -438,11 +439,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let res = (0..largest_even_below(f.len())) .step_by(2) .rev() - .fold(AdditiveArray::<_, 2>(array::from_fn(|_| 0.into())), |mut acc, b| { - acc.0[0] += f[b]; - acc.0[1] += f[b+1]; - acc - }); + .map(|b| { + AdditiveArray([ + f[b], + f[b+1], + ]) + }).sum::>(); let res = if f.len() == 1 { AdditiveArray::<_, 2>([f[0]; 2]) } else { @@ -450,12 +452,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) } .to_vec() } @@ -466,15 +468,15 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { ); commutative_op_mle_pair!( |f, g| { - let res = (0..largest_even_below(f.len())).step_by(2).rev().fold( - AdditiveArray::<_, 3>(array::from_fn(|_| 0.into())), - |mut acc, b| { - acc.0[0] += f[b] * g[b]; - acc.0[1] += f[b + 1] * g[b + 1]; - acc.0[2] += - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); - acc - }); + let res = (0..largest_even_below(f.len())).step_by(2).rev().map(|b| { + AdditiveArray([ + f[b] * g[b], + f[b + 1] * g[b + 1], + (f[b + 1] + f[b + 1] - f[b]) + * (g[b + 1] + g[b + 1] - g[b]), + ]) + }) + .sum::>(); let res = if f.len() == 1 { AdditiveArray::<_, 3>([f[0] * g[0]; 3]) } else { @@ -482,12 +484,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) ) .to_vec() } @@ -524,19 +526,19 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) ) .to_vec() } _ => unimplemented!("do not support degree > 3"), }; exit_span!(span); - sum.iter_mut().for_each(|sum| *sum *= coefficient); + sum.iter_mut().for_each(|sum| *sum *= *coefficient); let span = entered_span!("extrapolation"); let extrapolation = (0..self.poly.aux_info.max_degree - products.len()) @@ -804,12 +806,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) } .to_vec() } @@ -840,12 +842,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) ) .to_vec() } @@ -881,12 +883,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let num_vars_multiplicity = self.poly.aux_info.max_num_variables - (ceil_log2(f1.len()).max(1) + self.round - 1); if num_vars_multiplicity > 0 { - AdditiveArray(res.0.map(|e| e * E::BaseField::from(1 << num_vars_multiplicity))) + AdditiveArray(res.0.map(|e| e * E::BaseField::from_canonical_u64(1 << num_vars_multiplicity))) } else { res } }, - |sum| AdditiveArray(sum.0.map(E::from)) + |sum| AdditiveArray(sum.0.map(E::from_base)) ) .to_vec() } From f01f06af592002a25536f6b50eb69b9510942adb Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 11:28:17 +0800 Subject: [PATCH 06/12] clean up ff and in-house goldilock from mle/sumcheck crate --- Cargo.lock | 6 +-- multilinear_extensions/Cargo.toml | 2 - multilinear_extensions/src/mle.rs | 3 +- multilinear_extensions/src/virtual_poly.rs | 4 +- sumcheck/Cargo.toml | 4 +- sumcheck/benches/devirgo_sumcheck.rs | 11 ++-- sumcheck/src/test.rs | 63 +++++++++++++--------- transcript/src/basic.rs | 29 +++++----- transcript/src/statistics.rs | 23 ++++---- 9 files changed, 75 insertions(+), 70 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7d9bea908..45dd98e20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1155,9 +1155,7 @@ version = "0.1.0" dependencies = [ "ark-std", "env_logger", - "ff", "ff_ext", - "goldilocks", "itertools 0.13.0", "log", "p3-field", @@ -2172,12 +2170,12 @@ dependencies = [ "ark-std", "criterion", "crossbeam-channel", - "ff", "ff_ext", - "goldilocks", "itertools 0.13.0", "multilinear_extensions", "p3-field", + "p3-goldilocks", + "p3-mds", "rayon", "serde", "tracing", diff --git a/multilinear_extensions/Cargo.toml b/multilinear_extensions/Cargo.toml index a4abdb6ea..f6269c85b 100644 --- a/multilinear_extensions/Cargo.toml +++ b/multilinear_extensions/Cargo.toml @@ -11,9 +11,7 @@ version.workspace = true [dependencies] ark-std.workspace = true -ff.workspace = true ff_ext = { path = "../ff_ext" } -goldilocks.workspace = true itertools.workspace = true rayon.workspace = true serde.workspace = true diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 3f9c80dcc..8a78ea7fa 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -3,9 +3,8 @@ use std::{any::TypeId, borrow::Cow, mem, sync::Arc}; use crate::{op_mle, util::ceil_log2}; use ark_std::{end_timer, rand::RngCore, start_timer}; use core::hash::Hash; -use ff::Field; use ff_ext::{ExtensionField, FromUniformBytes}; -use p3_field::FieldAlgebra; +use p3_field::{Field, FieldAlgebra}; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 36773750a..957e39354 100644 --- a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -5,9 +5,9 @@ use crate::{ util::{bit_decompose, create_uninit_vec, max_usable_threads}, }; use ark_std::{end_timer, rand::Rng, start_timer}; -use ff::PrimeField; use ff_ext::ExtensionField; use itertools::Itertools; +use p3_field::Field; use rayon::{ iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, slice::ParallelSliceMut, @@ -228,7 +228,7 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { } /// Evaluate eq polynomial. -pub fn eq_eval(x: &[F], y: &[F]) -> F { +pub fn eq_eval(x: &[F], y: &[F]) -> F { assert_eq!(x.len(), y.len(), "x and y have different length"); let start = start_timer!(|| "eq_eval"); diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 71c223136..47ceb0c16 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -11,9 +11,7 @@ version.workspace = true [dependencies] ark-std.workspace = true -ff.workspace = true ff_ext = { path = "../ff_ext" } -goldilocks.workspace = true itertools.workspace = true rayon.workspace = true serde.workspace = true @@ -26,6 +24,8 @@ transcript = { path = "../transcript" } [dev-dependencies] criterion.workspace = true +p3-goldilocks.workspace = true +p3-mds.workspace = true [[bench]] harness = false diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 7fb919cb7..5b7acda25 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -5,11 +5,11 @@ use std::{array, time::Duration}; use ark_std::test_rng; use criterion::*; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, GoldilocksExt2}; use itertools::Itertools; +use p3_goldilocks::MdsMatrixGoldilocks; use sumcheck::{structs::IOPProverState, util::ceil_log2}; -use goldilocks::GoldilocksExt2; use multilinear_extensions::{ mle::DenseMultilinearExtension, op_mle, @@ -95,6 +95,7 @@ fn prepare_input<'a, E: ExtensionField>( }) }) .iter() + .cloned() .sum::(); (asserted_sum, virtual_poly_v1, virtual_poly_v2) @@ -102,6 +103,7 @@ fn prepare_input<'a, E: ExtensionField>( fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + type Mds = MdsMatrixGoldilocks; for nv in NV { // expand more input size once runtime is acceptable @@ -115,7 +117,7 @@ fn sumcheck_fn(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let mut prover_transcript = Transcript::::new(b"test"); + let mut prover_transcript = Transcript::::new(b"test"); let (_, virtual_poly, _) = { prepare_input(nv) }; let instant = std::time::Instant::now(); @@ -138,6 +140,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; + type Mds = MdsMatrixGoldilocks; let threads = max_usable_threads(); for nv in NV { @@ -152,7 +155,7 @@ fn devirgo_sumcheck_fn(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let mut prover_transcript = Transcript::::new(b"test"); + let mut prover_transcript = Transcript::::new(b"test"); let (_, _, virtual_poly_splitted) = { prepare_input(nv) }; let instant = std::time::Instant::now(); diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 6b2d4a2c3..64184fcb4 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -1,10 +1,11 @@ use std::sync::Arc; use ark_std::{rand::RngCore, test_rng}; -use ff::Field; -use ff_ext::ExtensionField; -use goldilocks::GoldilocksExt2; +use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::virtual_poly::VirtualPolynomial; +use p3_field::FieldAlgebra; +use p3_goldilocks::MdsMatrixGoldilocks; +use p3_mds::MdsPermutation; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; @@ -12,16 +13,19 @@ use crate::{ structs::{IOPProverState, IOPVerifierState}, util::interpolate_uni_poly, }; +use ff_ext::FromUniformBytes; // TODO add more tests related to various num_vars combination after PR #162 -fn test_sumcheck( +fn test_sumcheck( nv: usize, num_multiplicands_range: (usize, usize), num_products: usize, -) { +) where + Mds: MdsPermutation + Default, +{ let mut rng = test_rng(); - let mut transcript = BasicTranscript::new(b"test"); + let mut transcript = BasicTranscript::::new(b"test"); let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); @@ -29,7 +33,7 @@ fn test_sumcheck( #[allow(deprecated)] let (proof, _) = IOPProverState::::prove_parallel(poly.clone(), &mut transcript); - let mut transcript = BasicTranscript::new(b"test"); + let mut transcript = BasicTranscript::::new(b"test"); let subclaim = IOPVerifierState::::verify(asserted_sum, &proof, &poly_info, &mut transcript); assert!( poly.evaluate( @@ -44,11 +48,13 @@ fn test_sumcheck( ); } -fn test_sumcheck_internal( +fn test_sumcheck_internal( nv: usize, num_multiplicands_range: (usize, usize), num_products: usize, -) { +) where + Mds: MdsPermutation + Default, +{ let mut rng = test_rng(); let (poly, asserted_sum) = VirtualPolynomial::::random(nv, num_multiplicands_range, num_products, &mut rng); @@ -58,7 +64,7 @@ fn test_sumcheck_internal( let mut verifier_state = IOPVerifierState::verifier_init(&poly_info); let mut challenge = None; - let mut transcript = BasicTranscript::new(b"test"); + let mut transcript = BasicTranscript::::new(b"test"); transcript.append_message(b"initializing transcript for testing"); @@ -103,31 +109,37 @@ fn test_sumcheck_internal( #[test] #[ignore = "temporarily not supporting degree > 2"] fn test_trivial_polynomial() { - test_trivial_polynomial_helper::(); + test_trivial_polynomial_helper::(); } -fn test_trivial_polynomial_helper() { +fn test_trivial_polynomial_helper() +where + Mds: MdsPermutation + Default, +{ let nv = 1; let num_multiplicands_range = (4, 13); let num_products = 5; - test_sumcheck::(nv, num_multiplicands_range, num_products); - test_sumcheck_internal::(nv, num_multiplicands_range, num_products); + test_sumcheck::(nv, num_multiplicands_range, num_products); + test_sumcheck_internal::(nv, num_multiplicands_range, num_products); } #[test] #[ignore = "temporarily not supporting degree > 2"] fn test_normal_polynomial() { - test_normal_polynomial_helper::(); + test_normal_polynomial_helper::(); } -fn test_normal_polynomial_helper() { +fn test_normal_polynomial_helper() +where + Mds: MdsPermutation + Default, +{ let nv = 12; let num_multiplicands_range = (4, 9); let num_products = 5; - test_sumcheck::(nv, num_multiplicands_range, num_products); - test_sumcheck_internal::(nv, num_multiplicands_range, num_products); + test_sumcheck::(nv, num_multiplicands_range, num_products); + test_sumcheck_internal::(nv, num_multiplicands_range, num_products); } // #[test] @@ -142,12 +154,15 @@ fn test_normal_polynomial_helper() { #[test] fn test_extract_sum() { - test_extract_sum_helper::(); + test_extract_sum_helper::(); } -fn test_extract_sum_helper() { +fn test_extract_sum_helper() +where + Mds: MdsPermutation + Default, +{ let mut rng = test_rng(); - let mut transcript = BasicTranscript::::new(b"test"); + let mut transcript = BasicTranscript::::new(b"test"); let (poly, asserted_sum) = VirtualPolynomial::::random(8, (2, 3), 3, &mut rng); #[allow(deprecated)] let (proof, _) = IOPProverState::::prove_parallel(poly, &mut transcript); @@ -183,7 +198,7 @@ fn test_interpolation() { // test a polynomial with 20 known points, i.e., with degree 19 let poly = DensePolynomial::rand(20 - 1, &mut prng); let evals = (0..20) - .map(|i| poly.evaluate(&GoldilocksExt2::from(i))) + .map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64))) .collect::>(); let query = GoldilocksExt2::random(&mut prng); @@ -192,7 +207,7 @@ fn test_interpolation() { // test a polynomial with 33 known points, i.e., with degree 32 let poly = DensePolynomial::rand(33 - 1, &mut prng); let evals = (0..33) - .map(|i| poly.evaluate(&GoldilocksExt2::from(i))) + .map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64))) .collect::>(); let query = GoldilocksExt2::random(&mut prng); @@ -201,7 +216,7 @@ fn test_interpolation() { // test a polynomial with 64 known points, i.e., with degree 63 let poly = DensePolynomial::rand(64 - 1, &mut prng); let evals = (0..64) - .map(|i| poly.evaluate(&GoldilocksExt2::from(i))) + .map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64))) .collect::>(); let query = GoldilocksExt2::random(&mut prng); diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 0fb07ed21..022d8051b 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -8,16 +8,15 @@ use p3_poseidon::Poseidon; use p3_symmetric::Permutation; #[derive(Clone)] -pub struct BasicTranscript { +pub struct BasicTranscript { // TODO generalized to accept general permutation - poseidon: Poseidon, - state: [E::BaseField; WIDTH], + poseidon: Poseidon, + state: [E::BaseField; 8], } -impl - BasicTranscript +impl BasicTranscript where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { /// Create a new IOP transcript. pub fn new(label: &'static [u8]) -> Self { @@ -28,18 +27,18 @@ where let num_partial_rounds = 22; let num_rounds = 2 * half_num_full_rounds + num_partial_rounds; - let num_constants = WIDTH * num_rounds; + let num_constants = 8 * num_rounds; let constants = vec![E::BaseField::ZERO; num_constants]; - let poseidon = Poseidon::::new( + let poseidon = Poseidon::::new( half_num_full_rounds, num_partial_rounds, constants, mds, ); - let input: [E::BaseField; WIDTH] = array::from_fn(|_| E::BaseField::ZERO); + let input: [E::BaseField; 8] = array::from_fn(|_| E::BaseField::ZERO); let label_f = E::BaseField::bytes_to_field_elements(label); - let mut new = BasicTranscript:: { + let mut new = BasicTranscript:: { poseidon, state: input, }; @@ -58,10 +57,9 @@ where } } -impl Transcript - for BasicTranscript +impl Transcript for BasicTranscript where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { fn append_field_element_ext(&mut self, element: &E) { self.append_field_elements(element.as_bases()); @@ -101,10 +99,9 @@ where } } -impl ForkableTranscript - for BasicTranscript +impl ForkableTranscript for BasicTranscript where E::BaseField: FieldAlgebra + PrimeField, - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { } diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index f6a7205fd..2e28c313a 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -11,29 +11,26 @@ pub struct Statistic { pub type StatisticRecorder = RefCell; #[derive(Clone)] -pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds, const WIDTH: usize, const ALPHA: u64> -{ - inner: BasicTranscript, +pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> { + inner: BasicTranscript, stat: &'a StatisticRecorder, } -impl<'a, E: ExtensionField, Mds, const WIDTH: usize, const ALPHA: u64> - BasicTranscriptWithStat<'a, E, Mds, WIDTH, ALPHA> +impl<'a, E: ExtensionField, Mds> BasicTranscriptWithStat<'a, E, Mds> where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { pub fn new(stat: &'a StatisticRecorder, label: &'static [u8]) -> Self { Self { - inner: BasicTranscript::<_, _, _, _>::new(label), + inner: BasicTranscript::<_, _>::new(label), stat, } } } -impl Transcript - for BasicTranscriptWithStat<'_, E, Mds, WIDTH, ALPHA> +impl Transcript for BasicTranscriptWithStat<'_, E, Mds> where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { fn append_field_elements(&mut self, elements: &[E::BaseField]) { self.stat.borrow_mut().field_appended_num += 1; @@ -66,9 +63,7 @@ where } } -impl ForkableTranscript - for BasicTranscriptWithStat<'_, E, Mds, WIDTH, ALPHA> -where - Mds: MdsPermutation + Default, +impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> where + Mds: MdsPermutation + Default { } From 942266b75e3f57cd98e36291aca6251a62745705 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 11:57:40 +0800 Subject: [PATCH 07/12] cleanup --- Cargo.lock | 18 -- Cargo.toml | 2 - ff_ext/src/lib.rs | 22 +-- poseidon/Cargo.toml | 4 - poseidon/benches/hashing.rs | 256 ++++++++++++++-------------- poseidon/src/digest.rs | 6 +- poseidon/src/poseidon.rs | 9 +- poseidon/src/poseidon_goldilocks.rs | 15 +- poseidon/src/poseidon_hash.rs | 180 +++++++++---------- transcript/Cargo.toml | 2 - 10 files changed, 235 insertions(+), 279 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a3f5e2a4f..e1a9dbc8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1364,18 +1364,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "p3-challenger" -version = "0.1.0" -source = "git+https://github.com/plonky3/plonky3#b0591e9b82d58d10f86359875b5d5fa96433b4cf" -dependencies = [ - "p3-field", - "p3-maybe-rayon", - "p3-symmetric", - "p3-util", - "tracing", -] - [[package]] name = "p3-dft" version = "0.1.0" @@ -1645,11 +1633,7 @@ dependencies = [ "ark-std", "criterion", "ff", - "ff_ext", "goldilocks", - "p3-field", - "p3-goldilocks", - "p3-poseidon", "plonky2", "rand", "serde", @@ -2446,12 +2430,10 @@ dependencies = [ "ff", "ff_ext", "goldilocks", - "p3-challenger", "p3-field", "p3-goldilocks", "p3-mds", "p3-poseidon", - "p3-poseidon2", "p3-symmetric", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index 5fef73980..9e686e612 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,8 +39,6 @@ plonky2 = "0.2" p3-field = { git = "https://github.com/plonky3/plonky3" } p3-goldilocks = { git = "https://github.com/plonky3/plonky3" } p3-poseidon = { git = "https://github.com/plonky3/plonky3" } -p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git" } -p3-challenger = { git = "https://github.com/plonky3/plonky3" } p3-mds = { git = "https://github.com/Plonky3/Plonky3.git" } p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git" } poseidon = { path = "./poseidon" } diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index e9ca39e0a..d8bba3c86 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -89,24 +89,7 @@ pub trait SmallField: Serialize + P3Field { fn to_noncanonical_u64(&self) -> u64; } -pub trait ExtensionField: P3ExtensionField + FromUniformBytes -// + FromUniformBytes<64> -// + From -// + Add -// + Sub -// + Mul -// // + for<'a> Add<&'a Self::BaseField, Output = Self> -// + for<'a> Sub<&'a Self::BaseField, Output = Self> -// + for<'a> Mul<&'a Self::BaseField, Output = Self> -// + AddAssign -// + SubAssign -// + MulAssign -// + for<'a> AddAssign<&'a Self::BaseField> -// + for<'a> SubAssign<&'a Self::BaseField> -// + for<'a> MulAssign<&'a Self::BaseField> -// + Ord -// + std::hash::Hash -{ +pub trait ExtensionField: P3ExtensionField + FromUniformBytes { const DEGREE: usize; type BaseField: SmallField + Ord + PrimeField + FromUniformBytes; @@ -174,7 +157,6 @@ mod impl_goldilocks { fn from_bases(bases: &[Goldilocks]) -> Self { debug_assert_eq!(bases.len(), 2); Self::from_base_slice(bases) - // Self([bases[0], bases[1]]) } fn as_bases(&self) -> &[Goldilocks] { @@ -183,7 +165,6 @@ mod impl_goldilocks { /// Convert limbs into self fn from_limbs(limbs: &[Self::BaseField]) -> Self { - // Self([limbs[0], limbs[1]]) Self::from_base_slice(&limbs[0..2]) } @@ -192,7 +173,6 @@ mod impl_goldilocks { .iter() .map(|v: &Self::BaseField| v.as_canonical_u64()) .collect() - // ::to_canonical_u64_vec(self) } } } diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index 69709a28f..eff0f50b7 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -15,10 +15,6 @@ ff.workspace = true goldilocks.workspace = true serde.workspace = true unroll = "0.1" -ff_ext = { path = "../ff_ext" } -p3-field.workspace = true -p3-goldilocks.workspace = true -p3-poseidon.workspace = true [dev-dependencies] ark-std.workspace = true diff --git a/poseidon/benches/hashing.rs b/poseidon/benches/hashing.rs index d352f7291..43a299ddc 100644 --- a/poseidon/benches/hashing.rs +++ b/poseidon/benches/hashing.rs @@ -1,128 +1,128 @@ -// use ark_std::test_rng; -// use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -// use ff::Field; -// use goldilocks::Goldilocks; -// use plonky2::{ -// field::{goldilocks_field::GoldilocksField, types::Sample}, -// hash::{ -// hash_types::HashOut, -// hashing::PlonkyPermutation, -// poseidon::{PoseidonHash as PlonkyPoseidonHash, PoseidonPermutation}, -// }, -// plonk::config::Hasher, -// }; -// use poseidon::{digest::Digest, poseidon_hash::PoseidonHash}; - -// fn random_plonky_2_goldy() -> GoldilocksField { -// GoldilocksField::rand() -// } - -// fn random_ceno_goldy() -> Goldilocks { -// Goldilocks::random(&mut test_rng()) -// } - -// fn random_ceno_hash() -> Digest { -// Digest( -// vec![Goldilocks::random(&mut test_rng()); 4] -// .try_into() -// .unwrap(), -// ) -// } - -// fn plonky_hash_single(a: GoldilocksField) { -// let _result = black_box(PlonkyPoseidonHash::hash_or_noop(&[a])); -// } - -// fn ceno_hash_single(a: Goldilocks) { -// let _result = black_box(PoseidonHash::hash_or_noop(&[a])); -// } - -// fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { -// let _result = black_box(PlonkyPoseidonHash::two_to_one(left, right)); -// } - -// fn ceno_hash_2_to_1(left: &Digest, right: &Digest) { -// let _result = black_box(PoseidonHash::two_to_one(left, right)); -// } - -// fn plonky_hash_many_to_1(values: &[GoldilocksField]) { -// let _result = black_box(PlonkyPoseidonHash::hash_or_noop(values)); -// } - -// fn ceno_hash_many_to_1(values: &[Goldilocks]) { -// let _result = black_box(PoseidonHash::hash_or_noop(values)); -// } - -// pub fn hashing_benchmark(c: &mut Criterion) { -// c.bench_function("plonky hash single", |bencher| { -// bencher.iter_batched( -// random_plonky_2_goldy, -// plonky_hash_single, -// BatchSize::SmallInput, -// ) -// }); - -// c.bench_function("plonky hash 2 to 1", |bencher| { -// bencher.iter_batched( -// || { -// ( -// HashOut::::rand(), -// HashOut::::rand(), -// ) -// }, -// |(left, right)| plonky_hash_2_to_1(left, right), -// BatchSize::SmallInput, -// ) -// }); - -// c.bench_function("plonky hash 60 to 1", |bencher| { -// bencher.iter_batched( -// || GoldilocksField::rand_vec(60), -// |sixty_elems| plonky_hash_many_to_1(sixty_elems.as_slice()), -// BatchSize::SmallInput, -// ) -// }); - -// c.bench_function("ceno hash single", |bencher| { -// bencher.iter_batched(random_ceno_goldy, ceno_hash_single, BatchSize::SmallInput) -// }); - -// c.bench_function("ceno hash 2 to 1", |bencher| { -// bencher.iter_batched( -// || (random_ceno_hash(), random_ceno_hash()), -// |(left, right)| ceno_hash_2_to_1(&left, &right), -// BatchSize::SmallInput, -// ) -// }); - -// c.bench_function("ceno hash 60 to 1", |bencher| { -// bencher.iter_batched( -// || { -// (0..60) -// .map(|_| Goldilocks::random(&mut test_rng())) -// .collect::>() -// }, -// |values| ceno_hash_many_to_1(values.as_slice()), -// BatchSize::SmallInput, -// ) -// }); -// } - -// // bench permutation -// pub fn permutation_benchmark(c: &mut Criterion) { -// let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); -// let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( -// core::iter::repeat(Goldilocks::ZERO), -// ); - -// c.bench_function("plonky permute", |bencher| { -// bencher.iter(|| plonky_permutation.permute()) -// }); - -// c.bench_function("ceno permute", |bencher| { -// bencher.iter(|| ceno_permutation.permute()) -// }); -// } - -// criterion_group!(benches, permutation_benchmark, hashing_benchmark); -// criterion_main!(benches); +use ark_std::test_rng; +use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; +use ff::Field; +use goldilocks::Goldilocks; +use plonky2::{ + field::{goldilocks_field::GoldilocksField, types::Sample}, + hash::{ + hash_types::HashOut, + hashing::PlonkyPermutation, + poseidon::{PoseidonHash as PlonkyPoseidonHash, PoseidonPermutation}, + }, + plonk::config::Hasher, +}; +use poseidon::{digest::Digest, poseidon_hash::PoseidonHash}; + +fn random_plonky_2_goldy() -> GoldilocksField { + GoldilocksField::rand() +} + +fn random_ceno_goldy() -> Goldilocks { + Goldilocks::random(&mut test_rng()) +} + +fn random_ceno_hash() -> Digest { + Digest( + vec![Goldilocks::random(&mut test_rng()); 4] + .try_into() + .unwrap(), + ) +} + +fn plonky_hash_single(a: GoldilocksField) { + let _result = black_box(PlonkyPoseidonHash::hash_or_noop(&[a])); +} + +fn ceno_hash_single(a: Goldilocks) { + let _result = black_box(PoseidonHash::hash_or_noop(&[a])); +} + +fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { + let _result = black_box(PlonkyPoseidonHash::two_to_one(left, right)); +} + +fn ceno_hash_2_to_1(left: &Digest, right: &Digest) { + let _result = black_box(PoseidonHash::two_to_one(left, right)); +} + +fn plonky_hash_many_to_1(values: &[GoldilocksField]) { + let _result = black_box(PlonkyPoseidonHash::hash_or_noop(values)); +} + +fn ceno_hash_many_to_1(values: &[Goldilocks]) { + let _result = black_box(PoseidonHash::hash_or_noop(values)); +} + +pub fn hashing_benchmark(c: &mut Criterion) { + c.bench_function("plonky hash single", |bencher| { + bencher.iter_batched( + random_plonky_2_goldy, + plonky_hash_single, + BatchSize::SmallInput, + ) + }); + + c.bench_function("plonky hash 2 to 1", |bencher| { + bencher.iter_batched( + || { + ( + HashOut::::rand(), + HashOut::::rand(), + ) + }, + |(left, right)| plonky_hash_2_to_1(left, right), + BatchSize::SmallInput, + ) + }); + + c.bench_function("plonky hash 60 to 1", |bencher| { + bencher.iter_batched( + || GoldilocksField::rand_vec(60), + |sixty_elems| plonky_hash_many_to_1(sixty_elems.as_slice()), + BatchSize::SmallInput, + ) + }); + + c.bench_function("ceno hash single", |bencher| { + bencher.iter_batched(random_ceno_goldy, ceno_hash_single, BatchSize::SmallInput) + }); + + c.bench_function("ceno hash 2 to 1", |bencher| { + bencher.iter_batched( + || (random_ceno_hash(), random_ceno_hash()), + |(left, right)| ceno_hash_2_to_1(&left, &right), + BatchSize::SmallInput, + ) + }); + + c.bench_function("ceno hash 60 to 1", |bencher| { + bencher.iter_batched( + || { + (0..60) + .map(|_| Goldilocks::random(&mut test_rng())) + .collect::>() + }, + |values| ceno_hash_many_to_1(values.as_slice()), + BatchSize::SmallInput, + ) + }); +} + +// bench permutation +pub fn permutation_benchmark(c: &mut Criterion) { + let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); + let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( + core::iter::repeat(Goldilocks::ZERO), + ); + + c.bench_function("plonky permute", |bencher| { + bencher.iter(|| plonky_permutation.permute()) + }); + + c.bench_function("ceno permute", |bencher| { + bencher.iter(|| ceno_permutation.permute()) + }); +} + +criterion_group!(benches, permutation_benchmark, hashing_benchmark); +criterion_main!(benches); diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs index a4175cccf..a487d676b 100644 --- a/poseidon/src/digest.rs +++ b/poseidon/src/digest.rs @@ -1,8 +1,8 @@ use crate::constants::DIGEST_WIDTH; -use ff_ext::SmallField; -use serde::Serialize; +use goldilocks::SmallField; +use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Default, Serialize, PartialEq, Eq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct Digest(pub [F; DIGEST_WIDTH]); impl TryFrom> for Digest { diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index f3736a4af..ed5d76d14 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -1,7 +1,7 @@ use crate::constants::{ ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, }; -use ff_ext::SmallField; +use goldilocks::SmallField; use unroll::unroll_for_loops; pub trait Poseidon: AdaptedField { @@ -247,6 +247,13 @@ pub trait AdaptedField: SmallField { fn multiply_accumulate(&self, x: Self, y: Self) -> Self; + /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. + // TODO: Should probably be unsafe. + fn from_canonical_u64(n: u64) -> Self { + debug_assert!(n < Self::ORDER); + Self::from(n) + } + /// # Safety /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs index 7a3de236b..eaab6fcd0 100644 --- a/poseidon/src/poseidon_goldilocks.rs +++ b/poseidon/src/poseidon_goldilocks.rs @@ -2,9 +2,7 @@ use crate::{ constants::N_PARTIAL_ROUNDS, poseidon::{AdaptedField, Poseidon}, }; -use goldilocks::EPSILON; -use p3_field::{FieldAlgebra, PrimeField64}; -use p3_goldilocks::Goldilocks; +use goldilocks::{EPSILON, Goldilocks, SmallField}; #[cfg(target_arch = "x86_64")] use std::hint::unreachable_unchecked; @@ -216,7 +214,7 @@ impl Poseidon for Goldilocks { } impl AdaptedField for Goldilocks { - const ORDER: u64 = Goldilocks::ORDER_U64; + const ORDER: u64 = Goldilocks::MODULUS_U64; fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self { reduce96((n_lo, n_hi)) @@ -228,10 +226,7 @@ impl AdaptedField for Goldilocks { fn multiply_accumulate(&self, x: Self, y: Self) -> Self { // u64 + u64 * u64 cannot overflow. - reduce128( - (self.as_canonical_u64() as u128) - + (x.as_canonical_u64() as u128) * (y.as_canonical_u64() as u128), - ) + reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128)) } } @@ -281,7 +276,7 @@ const unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { fn reduce96((x_lo, x_hi): (u64, u32)) -> Goldilocks { let t1 = x_hi as u64 * EPSILON; let t2 = unsafe { add_no_canonicalize_trashing_input(x_lo, t1) }; - Goldilocks::from_canonical_u64(t2) + Goldilocks(t2) } /// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the @@ -299,7 +294,7 @@ fn reduce128(x: u128) -> Goldilocks { } let t1 = x_hi_lo * EPSILON; let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) }; - Goldilocks::from_canonical_u64(t2) + Goldilocks(t2) } #[inline] diff --git a/poseidon/src/poseidon_hash.rs b/poseidon/src/poseidon_hash.rs index 186f5d0e5..c4559248e 100644 --- a/poseidon/src/poseidon_hash.rs +++ b/poseidon/src/poseidon_hash.rs @@ -120,93 +120,93 @@ pub fn compress(x: &Digest, y: &Digest) -> Digest { Digest(perm.squeeze()[..DIGEST_WIDTH].try_into().unwrap()) } -// #[cfg(test)] -// mod tests { -// use crate::{digest::Digest, poseidon_hash::PoseidonHash}; -// use p3_goldilocks::Goldilocks; -// use plonky2::{ -// field::{ -// goldilocks_field::GoldilocksField, -// types::{PrimeField64, Sample}, -// }, -// hash::{hash_types::HashOut, poseidon::PoseidonHash as PlonkyPoseidonHash}, -// plonk::config::{GenericHashOut, Hasher}, -// }; -// use rand::{Rng, thread_rng}; - -// type PlonkyFieldElements = Vec; -// type CenoFieldElements = Vec; - -// const N_ITERATIONS: usize = 100; - -// fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { -// values -// .iter() -// .map(|value| Goldilocks(value.to_canonical_u64())) -// .collect() -// } - -// fn test_vector_pair(n: usize) -> (PlonkyFieldElements, CenoFieldElements) { -// let plonky_elems = GoldilocksField::rand_vec(n); -// let ceno_elems = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); -// (plonky_elems, ceno_elems) -// } - -// fn random_hash_pair() -> (HashOut, Digest) { -// let plonky_random_hash = HashOut::::rand(); -// let ceno_equivalent_hash = Digest( -// ceno_goldy_from_plonky_goldy(plonky_random_hash.elements.as_slice()) -// .try_into() -// .unwrap(), -// ); -// (plonky_random_hash, ceno_equivalent_hash) -// } - -// fn compare_hash_output( -// plonky_hash: HashOut, -// ceno_hash: Digest, -// ) -> bool { -// let plonky_elems = plonky_hash.to_vec(); -// let plonky_in_ceno_field = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); -// plonky_in_ceno_field == ceno_hash.elements() -// } - -// #[test] -// fn compare_hash() { -// let mut rng = thread_rng(); -// for _ in 0..N_ITERATIONS { -// let n = rng.gen_range(5..=100); -// let (plonky_elems, ceno_elems) = test_vector_pair(n); -// let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); -// let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); -// let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); -// assert!(compare_hash_output(plonky_out, ceno_out)); -// assert!(compare_hash_output(plonky_out, ceno_iter)); -// } -// } - -// #[test] -// fn compare_noop() { -// let mut rng = thread_rng(); -// for _ in 0..N_ITERATIONS { -// let n = rng.gen_range(0..=4); -// let (plonky_elems, ceno_elems) = test_vector_pair(n); -// let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); -// let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); -// let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); -// assert!(compare_hash_output(plonky_out, ceno_out)); -// assert!(compare_hash_output(plonky_out, ceno_iter)); -// } -// } - -// #[test] -// fn compare_two_to_one() { -// for _ in 0..N_ITERATIONS { -// let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); -// let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); -// let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); -// let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); -// assert!(compare_hash_output(plonky_combined, ceno_combined)); -// } -// } -// } +#[cfg(test)] +mod tests { + use crate::{digest::Digest, poseidon_hash::PoseidonHash}; + use goldilocks::Goldilocks; + use plonky2::{ + field::{ + goldilocks_field::GoldilocksField, + types::{PrimeField64, Sample}, + }, + hash::{hash_types::HashOut, poseidon::PoseidonHash as PlonkyPoseidonHash}, + plonk::config::{GenericHashOut, Hasher}, + }; + use rand::{Rng, thread_rng}; + + type PlonkyFieldElements = Vec; + type CenoFieldElements = Vec; + + const N_ITERATIONS: usize = 100; + + fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { + values + .iter() + .map(|value| Goldilocks(value.to_canonical_u64())) + .collect() + } + + fn test_vector_pair(n: usize) -> (PlonkyFieldElements, CenoFieldElements) { + let plonky_elems = GoldilocksField::rand_vec(n); + let ceno_elems = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); + (plonky_elems, ceno_elems) + } + + fn random_hash_pair() -> (HashOut, Digest) { + let plonky_random_hash = HashOut::::rand(); + let ceno_equivalent_hash = Digest( + ceno_goldy_from_plonky_goldy(plonky_random_hash.elements.as_slice()) + .try_into() + .unwrap(), + ); + (plonky_random_hash, ceno_equivalent_hash) + } + + fn compare_hash_output( + plonky_hash: HashOut, + ceno_hash: Digest, + ) -> bool { + let plonky_elems = plonky_hash.to_vec(); + let plonky_in_ceno_field = ceno_goldy_from_plonky_goldy(plonky_elems.as_slice()); + plonky_in_ceno_field == ceno_hash.elements() + } + + #[test] + fn compare_hash() { + let mut rng = thread_rng(); + for _ in 0..N_ITERATIONS { + let n = rng.gen_range(5..=100); + let (plonky_elems, ceno_elems) = test_vector_pair(n); + let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); + let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); + let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); + assert!(compare_hash_output(plonky_out, ceno_out)); + assert!(compare_hash_output(plonky_out, ceno_iter)); + } + } + + #[test] + fn compare_noop() { + let mut rng = thread_rng(); + for _ in 0..N_ITERATIONS { + let n = rng.gen_range(0..=4); + let (plonky_elems, ceno_elems) = test_vector_pair(n); + let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); + let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); + let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); + assert!(compare_hash_output(plonky_out, ceno_out)); + assert!(compare_hash_output(plonky_out, ceno_iter)); + } + } + + #[test] + fn compare_two_to_one() { + for _ in 0..N_ITERATIONS { + let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); + let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); + let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); + let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); + assert!(compare_hash_output(plonky_combined, ceno_combined)); + } + } +} diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index acbdabe26..84ab7ed30 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -18,8 +18,6 @@ serde.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true p3-poseidon.workspace = true -p3-poseidon2.workspace = true -p3-challenger.workspace = true p3-mds.workspace = true p3-symmetric.workspace = true From 1c3a5b6953b78a6797c1ada15a944cb915074289 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 12:02:48 +0800 Subject: [PATCH 08/12] cleanup hardcode --- transcript/src/basic.rs | 18 +++++++++++------- transcript/src/statistics.rs | 8 ++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 022d8051b..2cac83bcb 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -7,16 +7,20 @@ use p3_mds::MdsPermutation; use p3_poseidon::Poseidon; use p3_symmetric::Permutation; +// follow https://github.com/Plonky3/Plonky3/blob/main/poseidon/benches/poseidon.rs#L22 +pub(crate) const WIDTH: usize = 8; +pub(crate) const ALPHA: u64 = 7; + #[derive(Clone)] pub struct BasicTranscript { // TODO generalized to accept general permutation - poseidon: Poseidon, - state: [E::BaseField; 8], + poseidon: Poseidon, + state: [E::BaseField; WIDTH], } impl BasicTranscript where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { /// Create a new IOP transcript. pub fn new(label: &'static [u8]) -> Self { @@ -27,7 +31,7 @@ where let num_partial_rounds = 22; let num_rounds = 2 * half_num_full_rounds + num_partial_rounds; - let num_constants = 8 * num_rounds; + let num_constants = WIDTH * num_rounds; let constants = vec![E::BaseField::ZERO; num_constants]; let poseidon = Poseidon::::new( @@ -36,7 +40,7 @@ where constants, mds, ); - let input: [E::BaseField; 8] = array::from_fn(|_| E::BaseField::ZERO); + let input = array::from_fn(|_| E::BaseField::ZERO); let label_f = E::BaseField::bytes_to_field_elements(label); let mut new = BasicTranscript:: { poseidon, @@ -59,7 +63,7 @@ where impl Transcript for BasicTranscript where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { fn append_field_element_ext(&mut self, element: &E) { self.append_field_elements(element.as_bases()); @@ -102,6 +106,6 @@ where impl ForkableTranscript for BasicTranscript where E::BaseField: FieldAlgebra + PrimeField, - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { } diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index 2e28c313a..6a8439956 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -1,4 +1,4 @@ -use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript}; +use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript, basic::WIDTH}; use ff_ext::ExtensionField; use p3_mds::MdsPermutation; use std::cell::RefCell; @@ -18,7 +18,7 @@ pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> { impl<'a, E: ExtensionField, Mds> BasicTranscriptWithStat<'a, E, Mds> where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { pub fn new(stat: &'a StatisticRecorder, label: &'static [u8]) -> Self { Self { @@ -30,7 +30,7 @@ where impl Transcript for BasicTranscriptWithStat<'_, E, Mds> where - Mds: MdsPermutation + Default, + Mds: MdsPermutation + Default, { fn append_field_elements(&mut self, elements: &[E::BaseField]) { self.stat.borrow_mut().field_appended_num += 1; @@ -64,6 +64,6 @@ where } impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> where - Mds: MdsPermutation + Default + Mds: MdsPermutation + Default { } From f0079fbaf4c58fb3db6a29d17f9517a5e67597e9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 18:02:46 +0800 Subject: [PATCH 09/12] use poseidon in right way --- Cargo.lock | 8 + mpcs/src/util/hash.rs | 4 +- mpcs/src/util/matrix.rs | 123 +++++++++++++ poseidon/Cargo.toml | 6 + poseidon/src/constants.rs | 2 +- poseidon/src/digest.rs | 6 +- poseidon/src/lib.rs | 1 + poseidon/src/poseidon.rs | 254 +-------------------------- poseidon/src/poseidon_goldilocks.rs | 138 +-------------- poseidon/src/poseidon_hash.rs | 86 ++++++--- poseidon/src/poseidon_permutation.rs | 32 +++- sumcheck/Cargo.toml | 1 + sumcheck/src/test.rs | 16 +- transcript/Cargo.toml | 2 +- transcript/src/basic.rs | 83 +++------ transcript/src/lib.rs | 1 - transcript/src/statistics.rs | 20 ++- 17 files changed, 291 insertions(+), 492 deletions(-) create mode 100644 mpcs/src/util/matrix.rs diff --git a/Cargo.lock b/Cargo.lock index e1a9dbc8d..6be119d21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1633,7 +1633,13 @@ dependencies = [ "ark-std", "criterion", "ff", + "ff_ext", "goldilocks", + "p3-field", + "p3-goldilocks", + "p3-mds", + "p3-poseidon", + "p3-symmetric", "plonky2", "rand", "serde", @@ -2167,6 +2173,7 @@ dependencies = [ "p3-field", "p3-goldilocks", "p3-mds", + "poseidon", "rayon", "serde", "tracing", @@ -2435,6 +2442,7 @@ dependencies = [ "p3-mds", "p3-poseidon", "p3-symmetric", + "poseidon", "serde", ] diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 499a053b5..09b1da635 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -5,7 +5,7 @@ use poseidon::poseidon_hash::PoseidonHash; use transcript::Transcript; pub use poseidon::digest::Digest; -use poseidon::poseidon::Poseidon; +use poseidon::poseidon::PoseidonField; pub fn write_digest_to_transcript( digest: &Digest, @@ -44,6 +44,6 @@ pub fn hash_two_leaves_batch_base( hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { +pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { PoseidonHash::two_to_one(a, b) } diff --git a/mpcs/src/util/matrix.rs b/mpcs/src/util/matrix.rs new file mode 100644 index 000000000..7860c3a6c --- /dev/null +++ b/mpcs/src/util/matrix.rs @@ -0,0 +1,123 @@ +use std::{ + marker::PhantomData, + ops::Index, + slice::{Chunks, ChunksMut}, + sync::Arc, +}; + +use ff::Field; +use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; + +use super::next_pow2_instance_padding; + +#[derive(Clone)] +pub enum InstancePaddingStrategy { + // Pads with default values of underlying type + // Usually zero, but check carefully + Default, + // Pads by repeating last row + RepeatLast, + // Custom strategy consists of a closure + // `pad(i, j) = padding value for cell at row i, column j` + // pad should be able to cross thread boundaries + Custom(Arc u64 + Send + Sync>), +} + +/// TODO replace with plonky3 RowMajorMatrix https://github.com/Plonky3/Plonky3/blob/784b7dd1fa87c1202e63350cc8182d7c5327a7af/matrix/src/dense.rs#L26 +#[derive(Clone)] +pub struct RowMajorMatrix> { + // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance + values: V, + num_col: usize, + padding_strategy: InstancePaddingStrategy, + _phantom: PhantomData, +} + +impl> RowMajorMatrix { + pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self { + RowMajorMatrix { + values: (0..num_rows * num_col) + .into_par_iter() + .map(|_| T::default()) + .collect(), + num_col, + padding_strategy, + _phantom: PhantomData, + } + } + + pub fn num_cols(&self) -> usize { + self.num_col + } + + pub fn num_padding_instances(&self) -> usize { + next_pow2_instance_padding(self.num_instances()) - self.num_instances() + } + + pub fn num_instances(&self) -> usize { + self.values.len() / self.num_col + } + + pub fn iter_rows(&self) -> Chunks { + self.values.chunks(self.num_col) + } + + pub fn iter_mut(&mut self) -> ChunksMut { + self.values.chunks_mut(self.num_col) + } + + pub fn par_iter_mut(&mut self) -> rayon::slice::ChunksMut { + self.values.par_chunks_mut(self.num_col) + } + + pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut { + self.values.par_chunks_mut(num_rows * self.num_col) + } + + // Returns column number `column`, padded appropriately according to the stored strategy + pub fn column_padded(&self, column: usize) -> Vec { + let num_instances = self.num_instances(); + let num_padding_instances = self.num_padding_instances(); + + let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| { + match &self.padding_strategy { + InstancePaddingStrategy::Custom(fun) => T::from(fun(i as u64, column as u64)), + InstancePaddingStrategy::RepeatLast if num_instances > 0 => { + self[num_instances - 1][column] + } + _ => T::default(), + } + }); + + self.values + .iter() + .skip(column) + .step_by(self.num_col) + .copied() + .chain(padding_iter) + .collect::>() + } +} + +impl> RowMajorMatrix { + pub fn into_mles>( + self, + ) -> Vec> { + (0..self.num_col) + .into_par_iter() + .map(|i| self.column_padded(i).into_mle()) + .collect() + } +} + +impl Index for RowMajorMatrix { + type Output = [F]; + + fn index(&self, idx: usize) -> &Self::Output { + &self.values[self.num_col * idx..][..self.num_col] + } +} diff --git a/poseidon/Cargo.toml b/poseidon/Cargo.toml index eff0f50b7..e4d33dc96 100644 --- a/poseidon/Cargo.toml +++ b/poseidon/Cargo.toml @@ -14,7 +14,13 @@ criterion.workspace = true ff.workspace = true goldilocks.workspace = true serde.workspace = true +ff_ext = { path = "../ff_ext" } unroll = "0.1" +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-poseidon.workspace = true +p3-mds.workspace = true +p3-symmetric.workspace = true [dev-dependencies] ark-std.workspace = true diff --git a/poseidon/src/constants.rs b/poseidon/src/constants.rs index db170d100..e983d3d0a 100644 --- a/poseidon/src/constants.rs +++ b/poseidon/src/constants.rs @@ -2,7 +2,7 @@ pub(crate) const DIGEST_WIDTH: usize = 4; pub(crate) const SPONGE_RATE: usize = 8; pub(crate) const SPONGE_CAPACITY: usize = 4; -pub(crate) const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; +pub const SPONGE_WIDTH: usize = SPONGE_RATE + SPONGE_CAPACITY; // The number of full rounds and partial rounds is given by the // calc_round_numbers.py script. They happen to be the same for both diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs index a487d676b..a4175cccf 100644 --- a/poseidon/src/digest.rs +++ b/poseidon/src/digest.rs @@ -1,8 +1,8 @@ use crate::constants::DIGEST_WIDTH; -use goldilocks::SmallField; -use serde::{Deserialize, Serialize}; +use ff_ext::SmallField; +use serde::Serialize; -#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Clone, Debug, Default, Serialize, PartialEq, Eq)] pub struct Digest(pub [F; DIGEST_WIDTH]); impl TryFrom> for Digest { diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index 17db28f72..22c6be7ad 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -2,6 +2,7 @@ extern crate core; pub(crate) mod constants; +pub use constants::SPONGE_WIDTH; pub mod digest; pub mod poseidon; mod poseidon_goldilocks; diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index ed5d76d14..87cf4ce8e 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -1,12 +1,13 @@ -use crate::constants::{ - ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, -}; -use goldilocks::SmallField; -use unroll::unroll_for_loops; +use ff_ext::SmallField; +use p3_field::PrimeField; -pub trait Poseidon: AdaptedField { +use crate::constants::{N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH}; + +pub trait PoseidonField: SmallField + PrimeField { // Total number of round constants required: width of the input // times number of rounds. + + const SPONGE_WIDTH: usize = SPONGE_WIDTH; const N_ROUND_CONSTANTS: usize = SPONGE_WIDTH * N_ROUNDS; // The MDS matrix we use is C + D, where C is the circulant matrix whose first @@ -22,245 +23,4 @@ pub trait Poseidon: AdaptedField { const FAST_PARTIAL_ROUND_VS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; const FAST_PARTIAL_ROUND_W_HATS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; SPONGE_WIDTH - 1]; SPONGE_WIDTH - 1]; - - #[inline] - fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { - let mut state = input; - let mut round_ctr = 0; - - Self::full_rounds(&mut state, &mut round_ctr); - Self::partial_rounds(&mut state, &mut round_ctr); - Self::full_rounds(&mut state, &mut round_ctr); - debug_assert_eq!(round_ctr, N_ROUNDS); - - state - } - - #[inline] - fn full_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { - for _ in 0..HALF_N_FULL_ROUNDS { - Self::constant_layer(state, *round_ctr); - Self::sbox_layer(state); - *state = Self::mds_layer(state); - *round_ctr += 1; - } - } - - #[inline] - fn partial_rounds(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { - Self::partial_first_constant_layer(state); - *state = Self::mds_partial_layer_init(state); - - for i in 0..N_PARTIAL_ROUNDS { - state[0] = Self::sbox_monomial(state[0]); - unsafe { - state[0] = state[0].add_canonical_u64(Self::FAST_PARTIAL_ROUND_CONSTANTS[i]); - } - *state = Self::mds_partial_layer_fast(state, i); - } - *round_ctr += N_PARTIAL_ROUNDS; - } - - #[inline(always)] - #[unroll_for_loops] - fn constant_layer(state: &mut [Self; SPONGE_WIDTH], round_ctr: usize) { - for i in 0..12 { - if i < SPONGE_WIDTH { - let round_constant = ALL_ROUND_CONSTANTS[i + SPONGE_WIDTH * round_ctr]; - unsafe { - state[i] = state[i].add_canonical_u64(round_constant); - } - } - } - } - - #[inline(always)] - #[unroll_for_loops] - fn sbox_layer(state: &mut [Self; SPONGE_WIDTH]) { - for i in 0..12 { - if i < SPONGE_WIDTH { - state[i] = Self::sbox_monomial(state[i]); - } - } - } - - #[inline(always)] - #[unroll_for_loops] - fn mds_layer(state_: &[Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { - let mut result = [Self::ZERO; SPONGE_WIDTH]; - - let mut state = [0u64; SPONGE_WIDTH]; - for r in 0..SPONGE_WIDTH { - state[r] = state_[r].to_noncanonical_u64(); - } - - // This is a hacky way of fully unrolling the loop. - for r in 0..12 { - if r < SPONGE_WIDTH { - let sum = Self::mds_row_shf(r, &state); - let sum_lo = sum as u64; - let sum_hi = (sum >> 64) as u32; - result[r] = Self::from_noncanonical_u96(sum_lo, sum_hi); - } - } - - result - } - - #[inline(always)] - #[unroll_for_loops] - fn partial_first_constant_layer(state: &mut [Self; SPONGE_WIDTH]) { - for i in 0..12 { - if i < SPONGE_WIDTH { - state[i] += Self::from_canonical_u64(Self::FAST_PARTIAL_FIRST_ROUND_CONSTANT[i]); - } - } - } - - #[inline(always)] - #[unroll_for_loops] - fn mds_partial_layer_init(state: &[Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { - let mut result = [Self::ZERO; SPONGE_WIDTH]; - - // Initial matrix has first row/column = [1, 0, ..., 0]; - - // c = 0 - result[0] = state[0]; - - for r in 1..12 { - if r < SPONGE_WIDTH { - for c in 1..12 { - if c < SPONGE_WIDTH { - // NB: FAST_PARTIAL_ROUND_INITIAL_MATRIX is stored in - // row-major order so that this dot product is cache - // friendly. - let t = Self::from_canonical_u64( - Self::FAST_PARTIAL_ROUND_INITIAL_MATRIX[r - 1][c - 1], - ); - result[c] += state[r] * t; - } - } - } - } - result - } - - #[inline(always)] - fn sbox_monomial(x: Self) -> Self { - // Observed a performance improvement by using x*x rather than x.square(). - // In Plonky2, where this function originates, operations might be over an algebraic extension field. - // Specialized square functions could leverage the field's structure for potential savings. - // Adding this note in case future generalizations or optimizations are considered. - - // x |--> x^7 - let x2 = x * x; - let x4 = x2 * x2; - let x3 = x * x2; - x3 * x4 - } - - /// Computes s*A where s is the state row vector and A is the matrix - /// - /// [ M_00 | v ] - /// [ ------+--- ] - /// [ w_hat | Id ] - /// - /// M_00 is a scalar, v is 1x(t-1), w_hat is (t-1)x1 and Id is the - /// (t-1)x(t-1) identity matrix. - #[inline(always)] - #[unroll_for_loops] - fn mds_partial_layer_fast(state: &[Self; SPONGE_WIDTH], r: usize) -> [Self; SPONGE_WIDTH] { - // Set d = [M_00 | w^] dot [state] - - let mut d_sum = (0u128, 0u32); // u160 accumulator - for i in 1..12 { - if i < SPONGE_WIDTH { - let t = Self::FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; - let si = state[i].to_noncanonical_u64() as u128; - d_sum = add_u160_u128(d_sum, si * t); - } - } - let s0 = state[0].to_noncanonical_u64() as u128; - let mds0to0 = (Self::MDS_MATRIX_CIRC[0] + Self::MDS_MATRIX_DIAG[0]) as u128; - d_sum = add_u160_u128(d_sum, s0 * mds0to0); - let d = reduce_u160::(d_sum); - - // result = [d] concat [state[0] * v + state[shift up by 1]] - let mut result = [Self::ZERO; SPONGE_WIDTH]; - result[0] = d; - for i in 1..12 { - if i < SPONGE_WIDTH { - let t = Self::from_canonical_u64(Self::FAST_PARTIAL_ROUND_VS[r][i - 1]); - result[i] = state[i].multiply_accumulate(state[0], t); - } - } - result - } - - #[inline(always)] - #[unroll_for_loops] - fn mds_row_shf(r: usize, v: &[u64; SPONGE_WIDTH]) -> u128 { - debug_assert!(r < SPONGE_WIDTH); - // The values of `MDS_MATRIX_CIRC` and `MDS_MATRIX_DIAG` are - // known to be small, so we can accumulate all the products for - // each row and reduce just once at the end (done by the - // caller). - - // NB: Unrolling this, calculating each term independently, and - // summing at the end, didn't improve performance for me. - let mut res = 0u128; - - // This is a hacky way of fully unrolling the loop. - for i in 0..12 { - if i < SPONGE_WIDTH { - res += (v[(i + r) % SPONGE_WIDTH] as u128) * (Self::MDS_MATRIX_CIRC[i] as u128); - } - } - res += (v[r] as u128) * (Self::MDS_MATRIX_DIAG[r] as u128); - - res - } -} - -#[inline(always)] -const fn add_u160_u128((x_lo, x_hi): (u128, u32), y: u128) -> (u128, u32) { - let (res_lo, over) = x_lo.overflowing_add(y); - let res_hi = x_hi + (over as u32); - (res_lo, res_hi) -} - -#[inline(always)] -fn reduce_u160((n_lo, n_hi): (u128, u32)) -> F { - let n_lo_hi = (n_lo >> 64) as u64; - let n_lo_lo = n_lo as u64; - let reduced_hi: u64 = F::from_noncanonical_u96(n_lo_hi, n_hi).to_noncanonical_u64(); - let reduced128: u128 = ((reduced_hi as u128) << 64) + (n_lo_lo as u128); - F::from_noncanonical_u128(reduced128) -} - -pub trait AdaptedField: SmallField { - const ORDER: u64; - - fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self; - - fn from_noncanonical_u128(n: u128) -> Self; - - fn multiply_accumulate(&self, x: Self, y: Self) -> Self; - - /// Returns `n`. Assumes that `n` is already in canonical form, i.e. `n < Self::order()`. - // TODO: Should probably be unsafe. - fn from_canonical_u64(n: u64) -> Self { - debug_assert!(n < Self::ORDER); - Self::from(n) - } - - /// # Safety - /// Equivalent to *self + Self::from_canonical_u64(rhs), but may be cheaper. The caller must - /// ensure that 0 <= rhs < Self::ORDER. The function may return incorrect results if this - /// precondition is not met. It is marked unsafe for this reason. - #[inline] - unsafe fn add_canonical_u64(&self, rhs: u64) -> Self { - // Default implementation. - *self + Self::from_canonical_u64(rhs) - } } diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs index eaab6fcd0..5ef553123 100644 --- a/poseidon/src/poseidon_goldilocks.rs +++ b/poseidon/src/poseidon_goldilocks.rs @@ -1,13 +1,8 @@ -use crate::{ - constants::N_PARTIAL_ROUNDS, - poseidon::{AdaptedField, Poseidon}, -}; -use goldilocks::{EPSILON, Goldilocks, SmallField}; -#[cfg(target_arch = "x86_64")] -use std::hint::unreachable_unchecked; +use crate::{constants::N_PARTIAL_ROUNDS, poseidon::PoseidonField}; +use p3_goldilocks::Goldilocks; #[rustfmt::skip] -impl Poseidon for Goldilocks { +impl PoseidonField for Goldilocks { // The MDS matrix we use is C + D, where C is the circulant matrix whose first row is given by // `MDS_MATRIX_CIRC`, and D is the diagonal matrix whose diagonal is given by `MDS_MATRIX_DIAG`. // @@ -210,131 +205,4 @@ impl Poseidon for Goldilocks { 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b, ], ]; - -} - -impl AdaptedField for Goldilocks { - const ORDER: u64 = Goldilocks::MODULUS_U64; - - fn from_noncanonical_u96(n_lo: u64, n_hi: u32) -> Self { - reduce96((n_lo, n_hi)) - } - - fn from_noncanonical_u128(n: u128) -> Self { - reduce128(n) - } - - fn multiply_accumulate(&self, x: Self, y: Self) -> Self { - // u64 + u64 * u64 cannot overflow. - reduce128((self.0 as u128) + (x.0 as u128) * (y.0 as u128)) - } -} - -/// Fast addition modulo ORDER for x86-64. -/// This function is marked unsafe for the following reasons: -/// - It is only correct if x + y < 2**64 + ORDER = 0x1ffffffff00000001. -/// - It is only faster in some circumstances. In particular, on x86 it overwrites both inputs in -/// the registers, so its use is not recommended when either input will be used again. -#[inline(always)] -#[cfg(target_arch = "x86_64")] -unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { - let res_wrapped: u64; - let adjustment: u64; - core::arch::asm!( - "add {0}, {1}", - // Trick. The carry flag is set iff the addition overflowed. - // sbb x, y does x := x - y - CF. In our case, x and y are both {1:e}, so it simply does - // {1:e} := 0xffffffff on overflow and {1:e} := 0 otherwise. {1:e} is the low 32 bits of - // {1}; the high 32-bits are zeroed on write. In the end, we end up with 0xffffffff in {1} - // on overflow; this happens be EPSILON. - // Note that the CPU does not realize that the result of sbb x, x does not actually depend - // on x. We must write the result to a register that we know to be ready. We have a - // dependency on {1} anyway, so let's use it. - "sbb {1:e}, {1:e}", - inlateout(reg) x => res_wrapped, - inlateout(reg) y => adjustment, - options(pure, nomem, nostack), - ); - assume(x != 0 || (res_wrapped == y && adjustment == 0)); - assume(y != 0 || (res_wrapped == x && adjustment == 0)); - // Add EPSILON == subtract ORDER. - // Cannot overflow unless the assumption if x + y < 2**64 + ORDER is incorrect. - res_wrapped + adjustment -} - -#[inline(always)] -#[cfg(not(target_arch = "x86_64"))] -const unsafe fn add_no_canonicalize_trashing_input(x: u64, y: u64) -> u64 { - let (res_wrapped, carry) = x.overflowing_add(y); - // Below cannot overflow unless the assumption if x + y < 2**64 + ORDER is incorrect. - res_wrapped + EPSILON * (carry as u64) -} - -/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the -/// field order and `2^64`. -#[inline] -fn reduce96((x_lo, x_hi): (u64, u32)) -> Goldilocks { - let t1 = x_hi as u64 * EPSILON; - let t2 = unsafe { add_no_canonicalize_trashing_input(x_lo, t1) }; - Goldilocks(t2) -} - -/// Reduces to a 64-bit value. The result might not be in canonical form; it could be in between the -/// field order and `2^64`. -#[inline] -fn reduce128(x: u128) -> Goldilocks { - let (x_lo, x_hi) = split(x); // This is a no-op - let x_hi_hi = x_hi >> 32; - let x_hi_lo = x_hi & EPSILON; - - let (mut t0, borrow) = x_lo.overflowing_sub(x_hi_hi); - if borrow { - branch_hint(); // A borrow is exceedingly rare. It is faster to branch. - t0 -= EPSILON; // Cannot underflow. - } - let t1 = x_hi_lo * EPSILON; - let t2 = unsafe { add_no_canonicalize_trashing_input(t0, t1) }; - Goldilocks(t2) -} - -#[inline] -const fn split(x: u128) -> (u64, u64) { - (x as u64, (x >> 64) as u64) -} - -#[inline(always)] -#[cfg(target_arch = "x86_64")] -pub fn assume(p: bool) { - debug_assert!(p); - if !p { - unsafe { - unreachable_unchecked(); - } - } -} - -/// Try to force Rust to emit a branch. Example: -/// if x > 2 { -/// y = foo(); -/// branch_hint(); -/// } else { -/// y = bar(); -/// } -/// This function has no semantics. It is a hint only. -#[inline(always)] -pub fn branch_hint() { - // NOTE: These are the currently supported assembly architectures. See the - // [nightly reference](https://doc.rust-lang.org/nightly/reference/inline-assembly.html) for - // the most up-to-date list. - #[cfg(any( - target_arch = "aarch64", - target_arch = "arm", - target_arch = "riscv32", - target_arch = "riscv64", - target_arch = "x86", - target_arch = "x86_64", - ))] - unsafe { - core::arch::asm!("", options(nomem, nostack, preserves_flags)); - } } diff --git a/poseidon/src/poseidon_hash.rs b/poseidon/src/poseidon_hash.rs index c4559248e..ceac268ff 100644 --- a/poseidon/src/poseidon_hash.rs +++ b/poseidon/src/poseidon_hash.rs @@ -1,28 +1,38 @@ +use std::marker::PhantomData; + +use p3_mds::MdsPermutation; + use crate::{ - constants::{DIGEST_WIDTH, SPONGE_RATE}, + constants::{DIGEST_WIDTH, SPONGE_RATE, SPONGE_WIDTH}, digest::Digest, - poseidon::Poseidon, + poseidon::PoseidonField, poseidon_permutation::PoseidonPermutation, }; -pub struct PoseidonHash; +pub struct PoseidonHash { + _phantom: PhantomData<(F, Mds)>, +} -impl PoseidonHash { - pub fn two_to_one(left: &Digest, right: &Digest) -> Digest { - compress(left, right) +impl PoseidonHash +where + Mds: MdsPermutation + Default, +{ + pub fn two_to_one(left: &Digest, right: &Digest) -> Digest + where + Mds: MdsPermutation + Default, + { + compress::(left, right) } - pub fn hash_or_noop(inputs: &[F]) -> Digest { + pub fn hash_or_noop(inputs: &[F]) -> Digest { if inputs.len() <= DIGEST_WIDTH { Digest::from_partial(inputs) } else { - hash_n_to_hash_no_pad(inputs) + hash_n_to_hash_no_pad::(inputs) } } - pub fn hash_or_noop_iter<'a, F: Poseidon, I: Iterator>( - mut input_iter: I, - ) -> Digest { + pub fn hash_or_noop_iter<'a, I: Iterator>(mut input_iter: I) -> Digest { let mut initial_elements = Vec::with_capacity(DIGEST_WIDTH); for _ in 0..DIGEST_WIDTH + 1 { @@ -42,15 +52,18 @@ impl PoseidonHash { ) } else { let iter = initial_elements.into_iter().chain(input_iter); - hash_n_to_m_no_pad_iter(iter, DIGEST_WIDTH) + hash_n_to_m_no_pad_iter::<'_, F, _, Mds>(iter, DIGEST_WIDTH) .try_into() .unwrap() } } } -pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec { - let mut perm = PoseidonPermutation::new(core::iter::repeat(F::ZERO)); +pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec +where + Mds: MdsPermutation + Default, +{ + let mut perm = PoseidonPermutation::::new(core::iter::repeat(F::ZERO)); // Absorb all input chunks. for input_chunk in inputs.chunks(SPONGE_RATE) { @@ -74,11 +87,14 @@ pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec< } } -pub fn hash_n_to_m_no_pad_iter<'a, F: Poseidon, I: Iterator>( +pub fn hash_n_to_m_no_pad_iter<'a, F: PoseidonField, I: Iterator, Mds>( mut input_iter: I, num_outputs: usize, -) -> Vec { - let mut perm = PoseidonPermutation::new(core::iter::repeat(F::ZERO)); +) -> Vec +where + Mds: MdsPermutation + Default, +{ + let mut perm = PoseidonPermutation::::new(core::iter::repeat(F::ZERO)); // Absorb all input chunks. loop { @@ -106,12 +122,20 @@ pub fn hash_n_to_m_no_pad_iter<'a, F: Poseidon, I: Iterator>( } } -pub fn hash_n_to_hash_no_pad(inputs: &[F]) -> Digest { - hash_n_to_m_no_pad(inputs, DIGEST_WIDTH).try_into().unwrap() +pub fn hash_n_to_hash_no_pad(inputs: &[F]) -> Digest +where + Mds: MdsPermutation + Default, +{ + hash_n_to_m_no_pad::(inputs, DIGEST_WIDTH) + .try_into() + .unwrap() } -pub fn compress(x: &Digest, y: &Digest) -> Digest { - let mut perm = PoseidonPermutation::new(core::iter::repeat(F::ZERO)); +pub fn compress(x: &Digest, y: &Digest) -> Digest +where + Mds: MdsPermutation + Default, +{ + let mut perm = PoseidonPermutation::::new(core::iter::repeat(F::ZERO)); perm.set_from_slice(x.elements(), 0); perm.set_from_slice(y.elements(), DIGEST_WIDTH); @@ -123,7 +147,8 @@ pub fn compress(x: &Digest, y: &Digest) -> Digest { #[cfg(test)] mod tests { use crate::{digest::Digest, poseidon_hash::PoseidonHash}; - use goldilocks::Goldilocks; + use p3_field::FieldAlgebra; + use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use plonky2::{ field::{ goldilocks_field::GoldilocksField, @@ -142,7 +167,7 @@ mod tests { fn ceno_goldy_from_plonky_goldy(values: &[GoldilocksField]) -> Vec { values .iter() - .map(|value| Goldilocks(value.to_canonical_u64())) + .map(|value| Goldilocks::from_canonical_u64(value.to_canonical_u64())) .collect() } @@ -178,8 +203,10 @@ mod tests { let n = rng.gen_range(5..=100); let (plonky_elems, ceno_elems) = test_vector_pair(n); let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); - let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); - let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); + let ceno_out = + PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); + let ceno_iter = + PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); assert!(compare_hash_output(plonky_out, ceno_out)); assert!(compare_hash_output(plonky_out, ceno_iter)); } @@ -192,8 +219,10 @@ mod tests { let n = rng.gen_range(0..=4); let (plonky_elems, ceno_elems) = test_vector_pair(n); let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); - let ceno_out = PoseidonHash::hash_or_noop(ceno_elems.as_slice()); - let ceno_iter = PoseidonHash::hash_or_noop_iter(ceno_elems.iter()); + let ceno_out = + PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); + let ceno_iter = + PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); assert!(compare_hash_output(plonky_out, ceno_out)); assert!(compare_hash_output(plonky_out, ceno_iter)); } @@ -205,7 +234,8 @@ mod tests { let (plonky_hash_a, ceno_hash_a) = random_hash_pair(); let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); - let ceno_combined = PoseidonHash::two_to_one(&ceno_hash_a, &ceno_hash_b); + let ceno_combined = + PoseidonHash::<_, MdsMatrixGoldilocks>::two_to_one(&ceno_hash_a, &ceno_hash_b); assert!(compare_hash_output(plonky_combined, ceno_combined)); } } diff --git a/poseidon/src/poseidon_permutation.rs b/poseidon/src/poseidon_permutation.rs index a5f8db021..bf2143c1e 100644 --- a/poseidon/src/poseidon_permutation.rs +++ b/poseidon/src/poseidon_permutation.rs @@ -1,14 +1,28 @@ +use p3_field::PrimeField; +use p3_mds::MdsPermutation; +use p3_poseidon::Poseidon; + use crate::{ - constants::{SPONGE_RATE, SPONGE_WIDTH}, - poseidon::Poseidon, + constants::{ + ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, SPONGE_RATE, SPONGE_WIDTH, + }, + poseidon::PoseidonField, }; +use p3_symmetric::Permutation; + +// follow https://github.com/Plonky3/Plonky3/blob/main/poseidon/benches/poseidon.rs#L22 +pub(crate) const ALPHA: u64 = 7; -#[derive(Copy, Clone)] -pub struct PoseidonPermutation { +#[derive(Clone)] +pub struct PoseidonPermutation { + poseidon: Poseidon, state: [T; SPONGE_WIDTH], } -impl PoseidonPermutation { +impl PoseidonPermutation +where + Mds: MdsPermutation + Default, +{ /// Initialises internal state with values from `iter` until /// `iter` is exhausted or `SPONGE_WIDTH` values have been /// received; remaining state (if any) initialised with @@ -18,6 +32,12 @@ impl PoseidonPermutation { /// or similar. pub fn new>(elts: I) -> Self { let mut perm = Self { + poseidon: Poseidon::::new( + HALF_N_FULL_ROUNDS, + N_PARTIAL_ROUNDS, + ALL_ROUND_CONSTANTS.map(T::from_canonical_u64).to_vec(), + Mds::default(), + ), state: [T::default(); SPONGE_WIDTH], }; perm.set_from_iter(elts, 0); @@ -43,7 +63,7 @@ impl PoseidonPermutation { /// Apply permutation to internal state pub fn permute(&mut self) { - self.state = T::poseidon(self.state); + self.poseidon.permute_mut(&mut self.state); } /// Return a slice of `RATE` elements diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 47ceb0c16..58496a3ad 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -26,6 +26,7 @@ transcript = { path = "../transcript" } criterion.workspace = true p3-goldilocks.workspace = true p3-mds.workspace = true +poseidon.workspace = true [[bench]] harness = false diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 64184fcb4..1e456f6fd 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -6,6 +6,7 @@ use multilinear_extensions::virtual_poly::VirtualPolynomial; use p3_field::FieldAlgebra; use p3_goldilocks::MdsMatrixGoldilocks; use p3_mds::MdsPermutation; +use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField}; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; @@ -22,7 +23,8 @@ fn test_sumcheck( num_multiplicands_range: (usize, usize), num_products: usize, ) where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { let mut rng = test_rng(); let mut transcript = BasicTranscript::::new(b"test"); @@ -53,7 +55,8 @@ fn test_sumcheck_internal( num_multiplicands_range: (usize, usize), num_products: usize, ) where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { let mut rng = test_rng(); let (poly, asserted_sum) = @@ -114,7 +117,8 @@ fn test_trivial_polynomial() { fn test_trivial_polynomial_helper() where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { let nv = 1; let num_multiplicands_range = (4, 13); @@ -132,7 +136,8 @@ fn test_normal_polynomial() { fn test_normal_polynomial_helper() where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { let nv = 12; let num_multiplicands_range = (4, 9); @@ -159,7 +164,8 @@ fn test_extract_sum() { fn test_extract_sum_helper() where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { let mut rng = test_rng(); let mut transcript = BasicTranscript::::new(b"test"); diff --git a/transcript/Cargo.toml b/transcript/Cargo.toml index 84ab7ed30..06a19346e 100644 --- a/transcript/Cargo.toml +++ b/transcript/Cargo.toml @@ -14,10 +14,10 @@ crossbeam-channel.workspace = true ff.workspace = true ff_ext = { path = "../ff_ext" } goldilocks.workspace = true +poseidon.workspace = true serde.workspace = true p3-field.workspace = true p3-goldilocks.workspace = true p3-poseidon.workspace = true p3-mds.workspace = true p3-symmetric.workspace = true - diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 2cac83bcb..61b69c253 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -1,77 +1,46 @@ -use std::array; - -use crate::{Challenge, ForkableTranscript, Transcript}; -use ff_ext::{ExtensionField, SmallField}; -use p3_field::{FieldAlgebra, PrimeField}; +use ff_ext::ExtensionField; use p3_mds::MdsPermutation; -use p3_poseidon::Poseidon; -use p3_symmetric::Permutation; +use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField, poseidon_permutation::PoseidonPermutation}; -// follow https://github.com/Plonky3/Plonky3/blob/main/poseidon/benches/poseidon.rs#L22 -pub(crate) const WIDTH: usize = 8; -pub(crate) const ALPHA: u64 = 7; +use crate::{Challenge, ForkableTranscript, Transcript}; +use ff_ext::SmallField; +use p3_field::FieldAlgebra; #[derive(Clone)] -pub struct BasicTranscript { - // TODO generalized to accept general permutation - poseidon: Poseidon, - state: [E::BaseField; WIDTH], +pub struct BasicTranscript +where + E::BaseField: PoseidonField, +{ + permutation: PoseidonPermutation, } impl BasicTranscript where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { /// Create a new IOP transcript. pub fn new(label: &'static [u8]) -> Self { - let mds = Mds::default(); - - // TODO: Should be calculated for the particular field, width and ALPHA. - let half_num_full_rounds = 4; - let num_partial_rounds = 22; - - let num_rounds = 2 * half_num_full_rounds + num_partial_rounds; - let num_constants = WIDTH * num_rounds; - let constants = vec![E::BaseField::ZERO; num_constants]; - - let poseidon = Poseidon::::new( - half_num_full_rounds, - num_partial_rounds, - constants, - mds, - ); - let input = array::from_fn(|_| E::BaseField::ZERO); + let mut permutation = PoseidonPermutation::new(core::iter::repeat(E::BaseField::ZERO)); let label_f = E::BaseField::bytes_to_field_elements(label); - let mut new = BasicTranscript:: { - poseidon, - state: input, - }; - new.set_from_slice(label_f.as_slice(), 0); - new.poseidon.permute_mut(&mut new.state); - new - } - - /// Set state element `i` to be `elts[i] for i = - /// start_idx..start_idx + n` where `n = min(elts.len(), - /// WIDTH-start_idx)`. Panics if `start_idx > SPONGE_WIDTH`. - fn set_from_slice(&mut self, elts: &[E::BaseField], start_idx: usize) { - let begin = start_idx; - let end = start_idx + elts.len(); - self.state[begin..end].copy_from_slice(elts) + permutation.set_from_slice(label_f.as_slice(), 0); + permutation.permute(); + Self { permutation } } } impl Transcript for BasicTranscript where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { - fn append_field_element_ext(&mut self, element: &E) { - self.append_field_elements(element.as_bases()); + fn append_field_elements(&mut self, elements: &[E::BaseField]) { + self.permutation.set_from_slice(elements, 0); + self.permutation.permute(); } - fn append_field_elements(&mut self, elements: &[E::BaseField]) { - self.set_from_slice(elements, 0); - self.poseidon.permute_mut(&mut self.state); + fn append_field_element_ext(&mut self, element: &E) { + self.append_field_elements(element.as_bases()) } fn read_challenge(&mut self) -> Challenge { @@ -81,7 +50,7 @@ where // We select `from_base` here to make it more clear that // we only use the first 2 fields here to construct the // challenge as an extension field element. - let elements = E::from_bases(&self.state[..8][..2]); + let elements = E::from_bases(&self.permutation.squeeze()[..2]); Challenge { elements } } @@ -105,7 +74,7 @@ where impl ForkableTranscript for BasicTranscript where - E::BaseField: FieldAlgebra + PrimeField, - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { } diff --git a/transcript/src/lib.rs b/transcript/src/lib.rs index 47a2767b9..d3f6ff57e 100644 --- a/transcript/src/lib.rs +++ b/transcript/src/lib.rs @@ -11,7 +11,6 @@ use ff_ext::SmallField; use p3_field::FieldAlgebra; pub use statistics::{BasicTranscriptWithStat, StatisticRecorder}; pub use syncronized::TranscriptSyncronized; - #[derive(Default, Copy, Clone, Eq, PartialEq, Debug)] pub struct Challenge { pub elements: F, diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index 6a8439956..6811ffefa 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -1,6 +1,7 @@ -use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript, basic::WIDTH}; +use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript}; use ff_ext::ExtensionField; use p3_mds::MdsPermutation; +use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField}; use std::cell::RefCell; #[derive(Debug, Default)] @@ -11,14 +12,18 @@ pub struct Statistic { pub type StatisticRecorder = RefCell; #[derive(Clone)] -pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> { +pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> +where + E::BaseField: PoseidonField, +{ inner: BasicTranscript, stat: &'a StatisticRecorder, } impl<'a, E: ExtensionField, Mds> BasicTranscriptWithStat<'a, E, Mds> where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { pub fn new(stat: &'a StatisticRecorder, label: &'static [u8]) -> Self { Self { @@ -30,7 +35,8 @@ where impl Transcript for BasicTranscriptWithStat<'_, E, Mds> where - Mds: MdsPermutation + Default, + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { fn append_field_elements(&mut self, elements: &[E::BaseField]) { self.stat.borrow_mut().field_appended_num += 1; @@ -63,7 +69,9 @@ where } } -impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> where - Mds: MdsPermutation + Default +impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> +where + E::BaseField: PoseidonField, + Mds: MdsPermutation + Default, { } From 00e4bb04f8445744c1f23e0b9ddc8cf933971c87 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 21:43:26 +0800 Subject: [PATCH 10/12] rollback deserialize change with serde bound --- mpcs/src/util/matrix.rs | 123 ------------------------------ multilinear_extensions/src/mle.rs | 8 +- poseidon/src/digest.rs | 5 +- sumcheck/src/structs.rs | 8 +- 4 files changed, 13 insertions(+), 131 deletions(-) delete mode 100644 mpcs/src/util/matrix.rs diff --git a/mpcs/src/util/matrix.rs b/mpcs/src/util/matrix.rs deleted file mode 100644 index 7860c3a6c..000000000 --- a/mpcs/src/util/matrix.rs +++ /dev/null @@ -1,123 +0,0 @@ -use std::{ - marker::PhantomData, - ops::Index, - slice::{Chunks, ChunksMut}, - sync::Arc, -}; - -use ff::Field; -use multilinear_extensions::mle::{DenseMultilinearExtension, IntoMLE}; -use rayon::{ - iter::{IntoParallelIterator, ParallelIterator}, - slice::ParallelSliceMut, -}; - -use super::next_pow2_instance_padding; - -#[derive(Clone)] -pub enum InstancePaddingStrategy { - // Pads with default values of underlying type - // Usually zero, but check carefully - Default, - // Pads by repeating last row - RepeatLast, - // Custom strategy consists of a closure - // `pad(i, j) = padding value for cell at row i, column j` - // pad should be able to cross thread boundaries - Custom(Arc u64 + Send + Sync>), -} - -/// TODO replace with plonky3 RowMajorMatrix https://github.com/Plonky3/Plonky3/blob/784b7dd1fa87c1202e63350cc8182d7c5327a7af/matrix/src/dense.rs#L26 -#[derive(Clone)] -pub struct RowMajorMatrix> { - // represent 2D in 1D linear memory and avoid double indirection by Vec> to improve performance - values: V, - num_col: usize, - padding_strategy: InstancePaddingStrategy, - _phantom: PhantomData, -} - -impl> RowMajorMatrix { - pub fn new(num_rows: usize, num_col: usize, padding_strategy: InstancePaddingStrategy) -> Self { - RowMajorMatrix { - values: (0..num_rows * num_col) - .into_par_iter() - .map(|_| T::default()) - .collect(), - num_col, - padding_strategy, - _phantom: PhantomData, - } - } - - pub fn num_cols(&self) -> usize { - self.num_col - } - - pub fn num_padding_instances(&self) -> usize { - next_pow2_instance_padding(self.num_instances()) - self.num_instances() - } - - pub fn num_instances(&self) -> usize { - self.values.len() / self.num_col - } - - pub fn iter_rows(&self) -> Chunks { - self.values.chunks(self.num_col) - } - - pub fn iter_mut(&mut self) -> ChunksMut { - self.values.chunks_mut(self.num_col) - } - - pub fn par_iter_mut(&mut self) -> rayon::slice::ChunksMut { - self.values.par_chunks_mut(self.num_col) - } - - pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut { - self.values.par_chunks_mut(num_rows * self.num_col) - } - - // Returns column number `column`, padded appropriately according to the stored strategy - pub fn column_padded(&self, column: usize) -> Vec { - let num_instances = self.num_instances(); - let num_padding_instances = self.num_padding_instances(); - - let padding_iter = (num_instances..num_instances + num_padding_instances).map(|i| { - match &self.padding_strategy { - InstancePaddingStrategy::Custom(fun) => T::from(fun(i as u64, column as u64)), - InstancePaddingStrategy::RepeatLast if num_instances > 0 => { - self[num_instances - 1][column] - } - _ => T::default(), - } - }); - - self.values - .iter() - .skip(column) - .step_by(self.num_col) - .copied() - .chain(padding_iter) - .collect::>() - } -} - -impl> RowMajorMatrix { - pub fn into_mles>( - self, - ) -> Vec> { - (0..self.num_col) - .into_par_iter() - .map(|i| self.column_padded(i).into_mle()) - .collect() - } -} - -impl Index for RowMajorMatrix { - type Output = [F]; - - fn index(&self, idx: usize) -> &Self::Output { - &self.values[self.num_col * idx..][..self.num_col] - } -} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 8a78ea7fa..706d9a46a 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -8,7 +8,7 @@ use p3_field::{Field, FieldAlgebra}; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; pub trait MultilinearExtension: Send + Sync { @@ -122,8 +122,9 @@ impl> IntoMLEs { Base(#[serde(skip)] Vec), @@ -159,7 +160,8 @@ impl FieldType { } /// Stores a multilinear polynomial in dense evaluation form. -#[derive(Clone, PartialEq, Eq, Default, Debug, Serialize)] +#[derive(Clone, PartialEq, Eq, Default, Debug, Serialize, Deserialize)] +#[serde(bound = "")] pub struct DenseMultilinearExtension { /// The evaluation over {0,1}^`num_vars` pub evaluations: FieldType, diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs index a4175cccf..8680aa0f1 100644 --- a/poseidon/src/digest.rs +++ b/poseidon/src/digest.rs @@ -1,8 +1,9 @@ use crate::constants::DIGEST_WIDTH; use ff_ext::SmallField; -use serde::Serialize; +use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Default, Serialize, PartialEq, Eq)] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +#[serde(bound = "")] pub struct Digest(pub [F; DIGEST_WIDTH]); impl TryFrom> for Digest { diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index a2da024a2..02ac86646 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -1,12 +1,13 @@ use ff_ext::ExtensionField; use multilinear_extensions::virtual_poly::VirtualPolynomial; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use transcript::Challenge; /// An IOP proof is a collections of /// - messages from prover to verifier at each round through the interactive protocol. /// - a point that is generated by the transcript for evaluation -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] pub struct IOPProof { pub point: Vec, pub proofs: Vec>, @@ -19,7 +20,8 @@ impl IOPProof { /// A message from the prover to the verifier at a given round /// is a list of evaluations. -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] pub struct IOPProverMessage { pub(crate) evaluations: Vec, } From 1126d58846c40e2a01d93518fb0c41215cb1fb78 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 8 Jan 2025 15:34:07 +0800 Subject: [PATCH 11/12] implement P2MDS to match poseidon with plonky2 result --- mpcs/src/util/hash.rs | 4 +- poseidon/benches/hashing.rs | 20 ++- poseidon/src/digest.rs | 8 +- poseidon/src/lib.rs | 5 +- poseidon/src/plonky2_goldilock_mds.rs | 57 +++++++ poseidon/src/poseidon.rs | 26 ---- poseidon/src/poseidon_goldilocks.rs | 208 -------------------------- poseidon/src/poseidon_hash.rs | 26 ++-- poseidon/src/poseidon_permutation.rs | 11 +- sumcheck/src/test.rs | 7 +- transcript/src/basic.rs | 15 +- transcript/src/statistics.rs | 15 +- 12 files changed, 104 insertions(+), 298 deletions(-) create mode 100644 poseidon/src/plonky2_goldilock_mds.rs delete mode 100644 poseidon/src/poseidon.rs delete mode 100644 poseidon/src/poseidon_goldilocks.rs diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 09b1da635..2dc4c8bf4 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -5,7 +5,7 @@ use poseidon::poseidon_hash::PoseidonHash; use transcript::Transcript; pub use poseidon::digest::Digest; -use poseidon::poseidon::PoseidonField; +use poseidon::poseidon::PrimeField; pub fn write_digest_to_transcript( digest: &Digest, @@ -44,6 +44,6 @@ pub fn hash_two_leaves_batch_base( hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { +pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { PoseidonHash::two_to_one(a, b) } diff --git a/poseidon/benches/hashing.rs b/poseidon/benches/hashing.rs index 43a299ddc..ab58560a4 100644 --- a/poseidon/benches/hashing.rs +++ b/poseidon/benches/hashing.rs @@ -1,7 +1,8 @@ use ark_std::test_rng; use criterion::{BatchSize, Criterion, black_box, criterion_group, criterion_main}; -use ff::Field; -use goldilocks::Goldilocks; +use ff_ext::FromUniformBytes; +use p3_field::FieldAlgebra; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use plonky2::{ field::{goldilocks_field::GoldilocksField, types::Sample}, hash::{ @@ -34,7 +35,7 @@ fn plonky_hash_single(a: GoldilocksField) { } fn ceno_hash_single(a: Goldilocks) { - let _result = black_box(PoseidonHash::hash_or_noop(&[a])); + let _result = black_box(PoseidonHash::::hash_or_noop(&[a])); } fn plonky_hash_2_to_1(left: HashOut, right: HashOut) { @@ -42,7 +43,9 @@ fn plonky_hash_2_to_1(left: HashOut, right: HashOut, right: &Digest) { - let _result = black_box(PoseidonHash::two_to_one(left, right)); + let _result = black_box(PoseidonHash::::two_to_one( + left, right, + )); } fn plonky_hash_many_to_1(values: &[GoldilocksField]) { @@ -50,7 +53,7 @@ fn plonky_hash_many_to_1(values: &[GoldilocksField]) { } fn ceno_hash_many_to_1(values: &[Goldilocks]) { - let _result = black_box(PoseidonHash::hash_or_noop(values)); + let _result = black_box(PoseidonHash::::hash_or_noop(values)); } pub fn hashing_benchmark(c: &mut Criterion) { @@ -111,9 +114,10 @@ pub fn hashing_benchmark(c: &mut Criterion) { // bench permutation pub fn permutation_benchmark(c: &mut Criterion) { let mut plonky_permutation = PoseidonPermutation::new(core::iter::repeat(GoldilocksField(0))); - let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::new( - core::iter::repeat(Goldilocks::ZERO), - ); + let mut ceno_permutation = poseidon::poseidon_permutation::PoseidonPermutation::< + Goldilocks, + MdsMatrixGoldilocks, + >::new(core::iter::repeat(Goldilocks::ZERO)); c.bench_function("plonky permute", |bencher| { bencher.iter(|| plonky_permutation.permute()) diff --git a/poseidon/src/digest.rs b/poseidon/src/digest.rs index 8680aa0f1..846757e31 100644 --- a/poseidon/src/digest.rs +++ b/poseidon/src/digest.rs @@ -1,12 +1,12 @@ use crate::constants::DIGEST_WIDTH; -use ff_ext::SmallField; +use p3_field::PrimeField; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] #[serde(bound = "")] -pub struct Digest(pub [F; DIGEST_WIDTH]); +pub struct Digest(pub [F; DIGEST_WIDTH]); -impl TryFrom> for Digest { +impl TryFrom> for Digest { type Error = String; fn try_from(values: Vec) -> Result { @@ -20,7 +20,7 @@ impl TryFrom> for Digest { } } -impl Digest { +impl Digest { pub(crate) fn from_partial(inputs: &[F]) -> Self { let mut elements = [F::ZERO; DIGEST_WIDTH]; elements[0..inputs.len()].copy_from_slice(inputs); diff --git a/poseidon/src/lib.rs b/poseidon/src/lib.rs index 22c6be7ad..49d3cedf2 100644 --- a/poseidon/src/lib.rs +++ b/poseidon/src/lib.rs @@ -2,9 +2,10 @@ extern crate core; pub(crate) mod constants; +pub(crate) mod plonky2_goldilock_mds; pub use constants::SPONGE_WIDTH; +#[cfg(test)] +pub use plonky2_goldilock_mds::P2MdsMatrixGoldilocks; pub mod digest; -pub mod poseidon; -mod poseidon_goldilocks; pub mod poseidon_hash; pub mod poseidon_permutation; diff --git a/poseidon/src/plonky2_goldilock_mds.rs b/poseidon/src/plonky2_goldilock_mds.rs new file mode 100644 index 000000000..017ab593b --- /dev/null +++ b/poseidon/src/plonky2_goldilock_mds.rs @@ -0,0 +1,57 @@ +//! this is just for compatible with plonky2 poseidon result, refer from plonky3 commit +//! https://github.com/Plonky3/Plonky3/commit/13ad333f3c74e5986df161dd7189eac3fe73e520 +//! once upgrade to poseidon2 we can remove this functionality +use p3_field::FieldAlgebra; +use p3_goldilocks::Goldilocks; +use p3_mds::MdsPermutation; +use p3_symmetric::Permutation; +use unroll::unroll_for_loops; + +use crate::SPONGE_WIDTH; +#[derive(Clone)] + +pub struct P2MdsMatrixGoldilocks { + pub matrix: [[Goldilocks; SPONGE_WIDTH]; SPONGE_WIDTH], +} + +impl Permutation<[Goldilocks; SPONGE_WIDTH]> for P2MdsMatrixGoldilocks { + #[unroll_for_loops] + #[allow(clippy::needless_range_loop)] + fn permute(&self, input: [Goldilocks; SPONGE_WIDTH]) -> [Goldilocks; SPONGE_WIDTH] { + let mut output = [Goldilocks::ZERO; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + for j in 0..SPONGE_WIDTH { + output[i] += self.matrix[i][j] * input[j]; + } + } + output + } + fn permute_mut(&self, input: &mut [Goldilocks; 12]) { + *input = self.permute(*input); + } +} + +impl MdsPermutation for P2MdsMatrixGoldilocks {} + +impl P2MdsMatrixGoldilocks { + const CIRC: [u64; SPONGE_WIDTH] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; + const DIAG: [u64; SPONGE_WIDTH] = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; +} + +impl Default for P2MdsMatrixGoldilocks { + #[allow(clippy::needless_range_loop)] + fn default() -> Self { + let mut matrix = [[Goldilocks::ZERO; SPONGE_WIDTH]; SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + for j in 0..SPONGE_WIDTH { + matrix[i][j] = Goldilocks::from_canonical_u64( + Self::CIRC[(SPONGE_WIDTH + j - i) % SPONGE_WIDTH], + ); + if i == j { + matrix[i][j] += Goldilocks::from_canonical_u64(Self::DIAG[i]); + } + } + } + Self { matrix } + } +} diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs deleted file mode 100644 index 87cf4ce8e..000000000 --- a/poseidon/src/poseidon.rs +++ /dev/null @@ -1,26 +0,0 @@ -use ff_ext::SmallField; -use p3_field::PrimeField; - -use crate::constants::{N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH}; - -pub trait PoseidonField: SmallField + PrimeField { - // Total number of round constants required: width of the input - // times number of rounds. - - const SPONGE_WIDTH: usize = SPONGE_WIDTH; - const N_ROUND_CONSTANTS: usize = SPONGE_WIDTH * N_ROUNDS; - - // The MDS matrix we use is C + D, where C is the circulant matrix whose first - // row is given by `MDS_MATRIX_CIRC`, and D is the diagonal matrix whose - // diagonal is given by `MDS_MATRIX_DIAG`. - const MDS_MATRIX_CIRC: [u64; SPONGE_WIDTH]; - const MDS_MATRIX_DIAG: [u64; SPONGE_WIDTH]; - - // Precomputed constants for the fast Poseidon calculation. See - // the paper. - const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; SPONGE_WIDTH]; - const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS]; - const FAST_PARTIAL_ROUND_VS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; - const FAST_PARTIAL_ROUND_W_HATS: [[u64; SPONGE_WIDTH - 1]; N_PARTIAL_ROUNDS]; - const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; SPONGE_WIDTH - 1]; SPONGE_WIDTH - 1]; -} diff --git a/poseidon/src/poseidon_goldilocks.rs b/poseidon/src/poseidon_goldilocks.rs deleted file mode 100644 index 5ef553123..000000000 --- a/poseidon/src/poseidon_goldilocks.rs +++ /dev/null @@ -1,208 +0,0 @@ -use crate::{constants::N_PARTIAL_ROUNDS, poseidon::PoseidonField}; -use p3_goldilocks::Goldilocks; - -#[rustfmt::skip] -impl PoseidonField for Goldilocks { - // The MDS matrix we use is C + D, where C is the circulant matrix whose first row is given by - // `MDS_MATRIX_CIRC`, and D is the diagonal matrix whose diagonal is given by `MDS_MATRIX_DIAG`. - // - // WARNING: If the MDS matrix is changed, then the following - // constants need to be updated accordingly: - // - FAST_PARTIAL_ROUND_CONSTANTS - // - FAST_PARTIAL_ROUND_VS - // - FAST_PARTIAL_ROUND_W_HATS - // - FAST_PARTIAL_ROUND_INITIAL_MATRIX - const MDS_MATRIX_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; - const MDS_MATRIX_DIAG: [u64; 12] = [8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - - const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; 12] = [ - 0x3cc3f892184df408, 0xe993fd841e7e97f1, 0xf2831d3575f0f3af, 0xd2500e0a350994ca, - 0xc5571f35d7288633, 0x91d89c5184109a02, 0xf37f925d04e5667b, 0x2d6e448371955a69, - 0x740ef19ce01398a1, 0x694d24c0752fdf45, 0x60936af96ee2f148, 0xc33448feadc78f0c, - ]; - - const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ - 0x74cb2e819ae421ab, 0xd2559d2370e7f663, 0x62bf78acf843d17c, 0xd5ab7b67e14d1fb4, - 0xb9fe2ae6e0969bdc, 0xe33fdf79f92a10e8, 0x0ea2bb4c2b25989b, 0xca9121fbf9d38f06, - 0xbdd9b0aa81f58fa4, 0x83079fa4ecf20d7e, 0x650b838edfcc4ad3, 0x77180c88583c76ac, - 0xaf8c20753143a180, 0xb8ccfe9989a39175, 0x954a1729f60cc9c5, 0xdeb5b550c4dca53b, - 0xf01bb0b00f77011e, 0xa1ebb404b676afd9, 0x860b6e1597a0173e, 0x308bb65a036acbce, - 0x1aca78f31c97c876, 0x0, - ]; - - const FAST_PARTIAL_ROUND_VS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] =[ - [0x94877900674181c3, 0xc6c67cc37a2a2bbd, 0xd667c2055387940f, 0x0ba63a63e94b5ff0, - 0x99460cc41b8f079f, 0x7ff02375ed524bb3, 0xea0870b47a8caf0e, 0xabcad82633b7bc9d, - 0x3b8d135261052241, 0xfb4515f5e5b0d539, 0x3ee8011c2b37f77c, ], - [0x0adef3740e71c726, 0xa37bf67c6f986559, 0xc6b16f7ed4fa1b00, 0x6a065da88d8bfc3c, - 0x4cabc0916844b46f, 0x407faac0f02e78d1, 0x07a786d9cf0852cf, 0x42433fb6949a629a, - 0x891682a147ce43b0, 0x26cfd58e7b003b55, 0x2bbf0ed7b657acb3, ], - [0x481ac7746b159c67, 0xe367de32f108e278, 0x73f260087ad28bec, 0x5cfc82216bc1bdca, - 0xcaccc870a2663a0e, 0xdb69cd7b4298c45d, 0x7bc9e0c57243e62d, 0x3cc51c5d368693ae, - 0x366b4e8cc068895b, 0x2bd18715cdabbca4, 0xa752061c4f33b8cf, ], - [0xb22d2432b72d5098, 0x9e18a487f44d2fe4, 0x4b39e14ce22abd3c, 0x9e77fde2eb315e0d, - 0xca5e0385fe67014d, 0x0c2cb99bf1b6bddb, 0x99ec1cd2a4460bfe, 0x8577a815a2ff843f, - 0x7d80a6b4fd6518a5, 0xeb6c67123eab62cb, 0x8f7851650eca21a5, ], - [0x11ba9a1b81718c2a, 0x9f7d798a3323410c, 0xa821855c8c1cf5e5, 0x535e8d6fac0031b2, - 0x404e7c751b634320, 0xa729353f6e55d354, 0x4db97d92e58bb831, 0xb53926c27897bf7d, - 0x965040d52fe115c5, 0x9565fa41ebd31fd7, 0xaae4438c877ea8f4, ], - [0x37f4e36af6073c6e, 0x4edc0918210800e9, 0xc44998e99eae4188, 0x9f4310d05d068338, - 0x9ec7fe4350680f29, 0xc5b2c1fdc0b50874, 0xa01920c5ef8b2ebe, 0x59fa6f8bd91d58ba, - 0x8bfc9eb89b515a82, 0xbe86a7a2555ae775, 0xcbb8bbaa3810babf, ], - [0x577f9a9e7ee3f9c2, 0x88c522b949ace7b1, 0x82f07007c8b72106, 0x8283d37c6675b50e, - 0x98b074d9bbac1123, 0x75c56fb7758317c1, 0xfed24e206052bc72, 0x26d7c3d1bc07dae5, - 0xf88c5e441e28dbb4, 0x4fe27f9f96615270, 0x514d4ba49c2b14fe, ], - [0xf02a3ac068ee110b, 0x0a3630dafb8ae2d7, 0xce0dc874eaf9b55c, 0x9a95f6cff5b55c7e, - 0x626d76abfed00c7b, 0xa0c1cf1251c204ad, 0xdaebd3006321052c, 0x3d4bd48b625a8065, - 0x7f1e584e071f6ed2, 0x720574f0501caed3, 0xe3260ba93d23540a, ], - [0xab1cbd41d8c1e335, 0x9322ed4c0bc2df01, 0x51c3c0983d4284e5, 0x94178e291145c231, - 0xfd0f1a973d6b2085, 0xd427ad96e2b39719, 0x8a52437fecaac06b, 0xdc20ee4b8c4c9a80, - 0xa2c98e9549da2100, 0x1603fe12613db5b6, 0x0e174929433c5505, ], - [0x3d4eab2b8ef5f796, 0xcfff421583896e22, 0x4143cb32d39ac3d9, 0x22365051b78a5b65, - 0x6f7fd010d027c9b6, 0xd9dd36fba77522ab, 0xa44cf1cb33e37165, 0x3fc83d3038c86417, - 0xc4588d418e88d270, 0xce1320f10ab80fe2, 0xdb5eadbbec18de5d, ], - [0x1183dfce7c454afd, 0x21cea4aa3d3ed949, 0x0fce6f70303f2304, 0x19557d34b55551be, - 0x4c56f689afc5bbc9, 0xa1e920844334f944, 0xbad66d423d2ec861, 0xf318c785dc9e0479, - 0x99e2032e765ddd81, 0x400ccc9906d66f45, 0xe1197454db2e0dd9, ], - [0x84d1ecc4d53d2ff1, 0xd8af8b9ceb4e11b6, 0x335856bb527b52f4, 0xc756f17fb59be595, - 0xc0654e4ea5553a78, 0x9e9a46b61f2ea942, 0x14fc8b5b3b809127, 0xd7009f0f103be413, - 0x3e0ee7b7a9fb4601, 0xa74e888922085ed7, 0xe80a7cde3d4ac526, ], - [0x238aa6daa612186d, 0x9137a5c630bad4b4, 0xc7db3817870c5eda, 0x217e4f04e5718dc9, - 0xcae814e2817bd99d, 0xe3292e7ab770a8ba, 0x7bb36ef70b6b9482, 0x3c7835fb85bca2d3, - 0xfe2cdf8ee3c25e86, 0x61b3915ad7274b20, 0xeab75ca7c918e4ef, ], - [0xd6e15ffc055e154e, 0xec67881f381a32bf, 0xfbb1196092bf409c, 0xdc9d2e07830ba226, - 0x0698ef3245ff7988, 0x194fae2974f8b576, 0x7a5d9bea6ca4910e, 0x7aebfea95ccdd1c9, - 0xf9bd38a67d5f0e86, 0xfa65539de65492d8, 0xf0dfcbe7653ff787, ], - [0x0bd87ad390420258, 0x0ad8617bca9e33c8, 0x0c00ad377a1e2666, 0x0ac6fc58b3f0518f, - 0x0c0cc8a892cc4173, 0x0c210accb117bc21, 0x0b73630dbb46ca18, 0x0c8be4920cbd4a54, - 0x0bfe877a21be1690, 0x0ae790559b0ded81, 0x0bf50db2f8d6ce31, ], - [0x000cf29427ff7c58, 0x000bd9b3cf49eec8, 0x000d1dc8aa81fb26, 0x000bc792d5c394ef, - 0x000d2ae0b2266453, 0x000d413f12c496c1, 0x000c84128cfed618, 0x000db5ebd48fc0d4, - 0x000d1b77326dcb90, 0x000beb0ccc145421, 0x000d10e5b22b11d1, ], - [0x00000e24c99adad8, 0x00000cf389ed4bc8, 0x00000e580cbf6966, 0x00000cde5fd7e04f, - 0x00000e63628041b3, 0x00000e7e81a87361, 0x00000dabe78f6d98, 0x00000efb14cac554, - 0x00000e5574743b10, 0x00000d05709f42c1, 0x00000e4690c96af1, ], - [0x0000000f7157bc98, 0x0000000e3006d948, 0x0000000fa65811e6, 0x0000000e0d127e2f, - 0x0000000fc18bfe53, 0x0000000fd002d901, 0x0000000eed6461d8, 0x0000001068562754, - 0x0000000fa0236f50, 0x0000000e3af13ee1, 0x0000000fa460f6d1, ], - [0x0000000011131738, 0x000000000f56d588, 0x0000000011050f86, 0x000000000f848f4f, - 0x00000000111527d3, 0x00000000114369a1, 0x00000000106f2f38, 0x0000000011e2ca94, - 0x00000000110a29f0, 0x000000000fa9f5c1, 0x0000000010f625d1, ], - [0x000000000011f718, 0x000000000010b6c8, 0x0000000000134a96, 0x000000000010cf7f, - 0x0000000000124d03, 0x000000000013f8a1, 0x0000000000117c58, 0x0000000000132c94, - 0x0000000000134fc0, 0x000000000010a091, 0x0000000000128961, ], - [0x0000000000001300, 0x0000000000001750, 0x000000000000114e, 0x000000000000131f, - 0x000000000000167b, 0x0000000000001371, 0x0000000000001230, 0x000000000000182c, - 0x0000000000001368, 0x0000000000000f31, 0x00000000000015c9, ], - [0x0000000000000014, 0x0000000000000022, 0x0000000000000012, 0x0000000000000027, - 0x000000000000000d, 0x000000000000000d, 0x000000000000001c, 0x0000000000000002, - 0x0000000000000010, 0x0000000000000029, 0x000000000000000f, ], - ]; - - const FAST_PARTIAL_ROUND_W_HATS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] = [ - [0x3d999c961b7c63b0, 0x814e82efcd172529, 0x2421e5d236704588, 0x887af7d4dd482328, - 0xa5e9c291f6119b27, 0xbdc52b2676a4b4aa, 0x64832009d29bcf57, 0x09c4155174a552cc, - 0x463f9ee03d290810, 0xc810936e64982542, 0x043b1c289f7bc3ac, ], - [0x673655aae8be5a8b, 0xd510fe714f39fa10, 0x2c68a099b51c9e73, 0xa667bfa9aa96999d, - 0x4d67e72f063e2108, 0xf84dde3e6acda179, 0x40f9cc8c08f80981, 0x5ead032050097142, - 0x6591b02092d671bb, 0x00e18c71963dd1b7, 0x8a21bcd24a14218a, ], - [0x202800f4addbdc87, 0xe4b5bdb1cc3504ff, 0xbe32b32a825596e7, 0x8e0f68c5dc223b9a, - 0x58022d9e1c256ce3, 0x584d29227aa073ac, 0x8b9352ad04bef9e7, 0xaead42a3f445ecbf, - 0x3c667a1d833a3cca, 0xda6f61838efa1ffe, 0xe8f749470bd7c446, ], - [0xc5b85bab9e5b3869, 0x45245258aec51cf7, 0x16e6b8e68b931830, 0xe2ae0f051418112c, - 0x0470e26a0093a65b, 0x6bef71973a8146ed, 0x119265be51812daf, 0xb0be7356254bea2e, - 0x8584defff7589bd7, 0x3c5fe4aeb1fb52ba, 0x9e7cd88acf543a5e, ], - [0x179be4bba87f0a8c, 0xacf63d95d8887355, 0x6696670196b0074f, 0xd99ddf1fe75085f9, - 0xc2597881fef0283b, 0xcf48395ee6c54f14, 0x15226a8e4cd8d3b6, 0xc053297389af5d3b, - 0x2c08893f0d1580e2, 0x0ed3cbcff6fcc5ba, 0xc82f510ecf81f6d0, ], - [0x94b06183acb715cc, 0x500392ed0d431137, 0x861cc95ad5c86323, 0x05830a443f86c4ac, - 0x3b68225874a20a7c, 0x10b3309838e236fb, 0x9b77fc8bcd559e2c, 0xbdecf5e0cb9cb213, - 0x30276f1221ace5fa, 0x7935dd342764a144, 0xeac6db520bb03708, ], - [0x7186a80551025f8f, 0x622247557e9b5371, 0xc4cbe326d1ad9742, 0x55f1523ac6a23ea2, - 0xa13dfe77a3d52f53, 0xe30750b6301c0452, 0x08bd488070a3a32b, 0xcd800caef5b72ae3, - 0x83329c90f04233ce, 0xb5b99e6664a0a3ee, 0x6b0731849e200a7f, ], - [0xec3fabc192b01799, 0x382b38cee8ee5375, 0x3bfb6c3f0e616572, 0x514abd0cf6c7bc86, - 0x47521b1361dcc546, 0x178093843f863d14, 0xad1003c5d28918e7, 0x738450e42495bc81, - 0xaf947c59af5e4047, 0x4653fb0685084ef2, 0x057fde2062ae35bf, ], - [0xe376678d843ce55e, 0x66f3860d7514e7fc, 0x7817f3dfff8b4ffa, 0x3929624a9def725b, - 0x0126ca37f215a80a, 0xfce2f5d02762a303, 0x1bc927375febbad7, 0x85b481e5243f60bf, - 0x2d3c5f42a39c91a0, 0x0811719919351ae8, 0xf669de0add993131, ], - [0x7de38bae084da92d, 0x5b848442237e8a9b, 0xf6c705da84d57310, 0x31e6a4bdb6a49017, - 0x889489706e5c5c0f, 0x0e4a205459692a1b, 0xbac3fa75ee26f299, 0x5f5894f4057d755e, - 0xb0dc3ecd724bb076, 0x5e34d8554a6452ba, 0x04f78fd8c1fdcc5f, ], - [0x4dd19c38779512ea, 0xdb79ba02704620e9, 0x92a29a3675a5d2be, 0xd5177029fe495166, - 0xd32b3298a13330c1, 0x251c4a3eb2c5f8fd, 0xe1c48b26e0d98825, 0x3301d3362a4ffccb, - 0x09bb6c88de8cd178, 0xdc05b676564f538a, 0x60192d883e473fee, ], - [0x16b9774801ac44a0, 0x3cb8411e786d3c8e, 0xa86e9cf505072491, 0x0178928152e109ae, - 0x5317b905a6e1ab7b, 0xda20b3be7f53d59f, 0xcb97dedecebee9ad, 0x4bd545218c59f58d, - 0x77dc8d856c05a44a, 0x87948589e4f243fd, 0x7e5217af969952c2, ], - [0xbc58987d06a84e4d, 0x0b5d420244c9cae3, 0xa3c4711b938c02c0, 0x3aace640a3e03990, - 0x865a0f3249aacd8a, 0x8d00b2a7dbed06c7, 0x6eacb905beb7e2f8, 0x045322b216ec3ec7, - 0xeb9de00d594828e6, 0x088c5f20df9e5c26, 0xf555f4112b19781f, ], - [0xa8cedbff1813d3a7, 0x50dcaee0fd27d164, 0xf1cb02417e23bd82, 0xfaf322786e2abe8b, - 0x937a4315beb5d9b6, 0x1b18992921a11d85, 0x7d66c4368b3c497b, 0x0e7946317a6b4e99, - 0xbe4430134182978b, 0x3771e82493ab262d, 0xa671690d8095ce82, ], - [0xb035585f6e929d9d, 0xba1579c7e219b954, 0xcb201cf846db4ba3, 0x287bf9177372cf45, - 0xa350e4f61147d0a6, 0xd5d0ecfb50bcff99, 0x2e166aa6c776ed21, 0xe1e66c991990e282, - 0x662b329b01e7bb38, 0x8aa674b36144d9a9, 0xcbabf78f97f95e65, ], - [0xeec24b15a06b53fe, 0xc8a7aa07c5633533, 0xefe9c6fa4311ad51, 0xb9173f13977109a1, - 0x69ce43c9cc94aedc, 0xecf623c9cd118815, 0x28625def198c33c7, 0xccfc5f7de5c3636a, - 0xf5e6c40f1621c299, 0xcec0e58c34cb64b1, 0xa868ea113387939f, ], - [0xd8dddbdc5ce4ef45, 0xacfc51de8131458c, 0x146bb3c0fe499ac0, 0x9e65309f15943903, - 0x80d0ad980773aa70, 0xf97817d4ddbf0607, 0xe4626620a75ba276, 0x0dfdc7fd6fc74f66, - 0xf464864ad6f2bb93, 0x02d55e52a5d44414, 0xdd8de62487c40925, ], - [0xc15acf44759545a3, 0xcbfdcf39869719d4, 0x33f62042e2f80225, 0x2599c5ead81d8fa3, - 0x0b306cb6c1d7c8d0, 0x658c80d3df3729b1, 0xe8d1b2b21b41429c, 0xa1b67f09d4b3ccb8, - 0x0e1adf8b84437180, 0x0d593a5e584af47b, 0xa023d94c56e151c7, ], - [0x49026cc3a4afc5a6, 0xe06dff00ab25b91b, 0x0ab38c561e8850ff, 0x92c3c8275e105eeb, - 0xb65256e546889bd0, 0x3c0468236ea142f6, 0xee61766b889e18f2, 0xa206f41b12c30415, - 0x02fe9d756c9f12d1, 0xe9633210630cbf12, 0x1ffea9fe85a0b0b1, ], - [0x81d1ae8cc50240f3, 0xf4c77a079a4607d7, 0xed446b2315e3efc1, 0x0b0a6b70915178c3, - 0xb11ff3e089f15d9a, 0x1d4dba0b7ae9cc18, 0x65d74e2f43b48d05, 0xa2df8c6b8ae0804a, - 0xa4e6f0a8c33348a6, 0xc0a26efc7be5669b, 0xa6b6582c547d0d60, ], - [0x84afc741f1c13213, 0x2f8f43734fc906f3, 0xde682d72da0a02d9, 0x0bb005236adb9ef2, - 0x5bdf35c10a8b5624, 0x0739a8a343950010, 0x52f515f44785cfbc, 0xcbaf4e5d82856c60, - 0xac9ea09074e3e150, 0x8f0fa011a2035fb0, 0x1a37905d8450904a, ], - [0x3abeb80def61cc85, 0x9d19c9dd4eac4133, 0x075a652d9641a985, 0x9daf69ae1b67e667, - 0x364f71da77920a18, 0x50bd769f745c95b1, 0xf223d1180dbbf3fc, 0x2f885e584e04aa99, - 0xb69a0fa70aea684a, 0x09584acaa6e062a0, 0x0bc051640145b19b, ], - ]; - - // NB: This is in ROW-major order to support cache-friendly pre-multiplication. - const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12 - 1]; 12 - 1] = [ - [0x80772dc2645b280b, 0xdc927721da922cf8, 0xc1978156516879ad, 0x90e80c591f48b603, - 0x3a2432625475e3ae, 0x00a2d4321cca94fe, 0x77736f524010c932, 0x904d3f2804a36c54, - 0xbf9b39e28a16f354, 0x3a1ded54a6cd058b, 0x42392870da5737cf, ], - [0xe796d293a47a64cb, 0xb124c33152a2421a, 0x0ee5dc0ce131268a, 0xa9032a52f930fae6, - 0x7e33ca8c814280de, 0xad11180f69a8c29e, 0xc75ac6d5b5a10ff3, 0xf0674a8dc5a387ec, - 0xb36d43120eaa5e2b, 0x6f232aab4b533a25, 0x3a1ded54a6cd058b, ], - [0xdcedab70f40718ba, 0x14a4a64da0b2668f, 0x4715b8e5ab34653b, 0x1e8916a99c93a88e, - 0xbba4b5d86b9a3b2c, 0xe76649f9bd5d5c2e, 0xaf8e2518a1ece54d, 0xdcda1344cdca873f, - 0xcd080204256088e5, 0xb36d43120eaa5e2b, 0xbf9b39e28a16f354, ], - [0xf4a437f2888ae909, 0xc537d44dc2875403, 0x7f68007619fd8ba9, 0xa4911db6a32612da, - 0x2f7e9aade3fdaec1, 0xe7ffd578da4ea43d, 0x43a608e7afa6b5c2, 0xca46546aa99e1575, - 0xdcda1344cdca873f, 0xf0674a8dc5a387ec, 0x904d3f2804a36c54, ], - [0xf97abba0dffb6c50, 0x5e40f0c9bb82aab5, 0x5996a80497e24a6b, 0x07084430a7307c9a, - 0xad2f570a5b8545aa, 0xab7f81fef4274770, 0xcb81f535cf98c9e9, 0x43a608e7afa6b5c2, - 0xaf8e2518a1ece54d, 0xc75ac6d5b5a10ff3, 0x77736f524010c932, ], - [0x7f8e41e0b0a6cdff, 0x4b1ba8d40afca97d, 0x623708f28fca70e8, 0xbf150dc4914d380f, - 0xc26a083554767106, 0x753b8b1126665c22, 0xab7f81fef4274770, 0xe7ffd578da4ea43d, - 0xe76649f9bd5d5c2e, 0xad11180f69a8c29e, 0x00a2d4321cca94fe, ], - [0x726af914971c1374, 0x1d7f8a2cce1a9d00, 0x18737784700c75cd, 0x7fb45d605dd82838, - 0x862361aeab0f9b6e, 0xc26a083554767106, 0xad2f570a5b8545aa, 0x2f7e9aade3fdaec1, - 0xbba4b5d86b9a3b2c, 0x7e33ca8c814280de, 0x3a2432625475e3ae, ], - [0x64dd936da878404d, 0x4db9a2ead2bd7262, 0xbe2e19f6d07f1a83, 0x02290fe23c20351a, - 0x7fb45d605dd82838, 0xbf150dc4914d380f, 0x07084430a7307c9a, 0xa4911db6a32612da, - 0x1e8916a99c93a88e, 0xa9032a52f930fae6, 0x90e80c591f48b603, ], - [0x85418a9fef8a9890, 0xd8a2eb7ef5e707ad, 0xbfe85ababed2d882, 0xbe2e19f6d07f1a83, - 0x18737784700c75cd, 0x623708f28fca70e8, 0x5996a80497e24a6b, 0x7f68007619fd8ba9, - 0x4715b8e5ab34653b, 0x0ee5dc0ce131268a, 0xc1978156516879ad, ], - [0x156048ee7a738154, 0x91f7562377e81df5, 0xd8a2eb7ef5e707ad, 0x4db9a2ead2bd7262, - 0x1d7f8a2cce1a9d00, 0x4b1ba8d40afca97d, 0x5e40f0c9bb82aab5, 0xc537d44dc2875403, - 0x14a4a64da0b2668f, 0xb124c33152a2421a, 0xdc927721da922cf8, ], - [0xd841e8ef9dde8ba0, 0x156048ee7a738154, 0x85418a9fef8a9890, 0x64dd936da878404d, - 0x726af914971c1374, 0x7f8e41e0b0a6cdff, 0xf97abba0dffb6c50, 0xf4a437f2888ae909, - 0xdcedab70f40718ba, 0xe796d293a47a64cb, 0x80772dc2645b280b, ], - ]; - -} diff --git a/poseidon/src/poseidon_hash.rs b/poseidon/src/poseidon_hash.rs index ceac268ff..0beaa9732 100644 --- a/poseidon/src/poseidon_hash.rs +++ b/poseidon/src/poseidon_hash.rs @@ -1,11 +1,11 @@ use std::marker::PhantomData; +use p3_field::PrimeField; use p3_mds::MdsPermutation; use crate::{ constants::{DIGEST_WIDTH, SPONGE_RATE, SPONGE_WIDTH}, digest::Digest, - poseidon::PoseidonField, poseidon_permutation::PoseidonPermutation, }; @@ -13,7 +13,7 @@ pub struct PoseidonHash { _phantom: PhantomData<(F, Mds)>, } -impl PoseidonHash +impl PoseidonHash where Mds: MdsPermutation + Default, { @@ -59,7 +59,7 @@ where } } -pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec +pub fn hash_n_to_m_no_pad(inputs: &[F], num_outputs: usize) -> Vec where Mds: MdsPermutation + Default, { @@ -87,7 +87,7 @@ where } } -pub fn hash_n_to_m_no_pad_iter<'a, F: PoseidonField, I: Iterator, Mds>( +pub fn hash_n_to_m_no_pad_iter<'a, F: PrimeField, I: Iterator, Mds>( mut input_iter: I, num_outputs: usize, ) -> Vec @@ -122,7 +122,7 @@ where } } -pub fn hash_n_to_hash_no_pad(inputs: &[F]) -> Digest +pub fn hash_n_to_hash_no_pad(inputs: &[F]) -> Digest where Mds: MdsPermutation + Default, { @@ -131,7 +131,7 @@ where .unwrap() } -pub fn compress(x: &Digest, y: &Digest) -> Digest +pub fn compress(x: &Digest, y: &Digest) -> Digest where Mds: MdsPermutation + Default, { @@ -146,9 +146,9 @@ where #[cfg(test)] mod tests { - use crate::{digest::Digest, poseidon_hash::PoseidonHash}; + use crate::{P2MdsMatrixGoldilocks, digest::Digest, poseidon_hash::PoseidonHash}; use p3_field::FieldAlgebra; - use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; + use p3_goldilocks::Goldilocks; use plonky2::{ field::{ goldilocks_field::GoldilocksField, @@ -204,9 +204,9 @@ mod tests { let (plonky_elems, ceno_elems) = test_vector_pair(n); let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); let ceno_out = - PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); + PoseidonHash::<_, P2MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); let ceno_iter = - PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); + PoseidonHash::<_, P2MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); assert!(compare_hash_output(plonky_out, ceno_out)); assert!(compare_hash_output(plonky_out, ceno_iter)); } @@ -220,9 +220,9 @@ mod tests { let (plonky_elems, ceno_elems) = test_vector_pair(n); let plonky_out = PlonkyPoseidonHash::hash_or_noop(plonky_elems.as_slice()); let ceno_out = - PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); + PoseidonHash::<_, P2MdsMatrixGoldilocks>::hash_or_noop(ceno_elems.as_slice()); let ceno_iter = - PoseidonHash::<_, MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); + PoseidonHash::<_, P2MdsMatrixGoldilocks>::hash_or_noop_iter(ceno_elems.iter()); assert!(compare_hash_output(plonky_out, ceno_out)); assert!(compare_hash_output(plonky_out, ceno_iter)); } @@ -235,7 +235,7 @@ mod tests { let (plonky_hash_b, ceno_hash_b) = random_hash_pair(); let plonky_combined = PlonkyPoseidonHash::two_to_one(plonky_hash_a, plonky_hash_b); let ceno_combined = - PoseidonHash::<_, MdsMatrixGoldilocks>::two_to_one(&ceno_hash_a, &ceno_hash_b); + PoseidonHash::<_, P2MdsMatrixGoldilocks>::two_to_one(&ceno_hash_a, &ceno_hash_b); assert!(compare_hash_output(plonky_combined, ceno_combined)); } } diff --git a/poseidon/src/poseidon_permutation.rs b/poseidon/src/poseidon_permutation.rs index bf2143c1e..319a49821 100644 --- a/poseidon/src/poseidon_permutation.rs +++ b/poseidon/src/poseidon_permutation.rs @@ -2,11 +2,8 @@ use p3_field::PrimeField; use p3_mds::MdsPermutation; use p3_poseidon::Poseidon; -use crate::{ - constants::{ - ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, SPONGE_RATE, SPONGE_WIDTH, - }, - poseidon::PoseidonField, +use crate::constants::{ + ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, SPONGE_RATE, SPONGE_WIDTH, }; use p3_symmetric::Permutation; @@ -14,12 +11,12 @@ use p3_symmetric::Permutation; pub(crate) const ALPHA: u64 = 7; #[derive(Clone)] -pub struct PoseidonPermutation { +pub struct PoseidonPermutation { poseidon: Poseidon, state: [T; SPONGE_WIDTH], } -impl PoseidonPermutation +impl PoseidonPermutation where Mds: MdsPermutation + Default, { diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 1e456f6fd..ad944820c 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -6,7 +6,7 @@ use multilinear_extensions::virtual_poly::VirtualPolynomial; use p3_field::FieldAlgebra; use p3_goldilocks::MdsMatrixGoldilocks; use p3_mds::MdsPermutation; -use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField}; +use poseidon::SPONGE_WIDTH; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; @@ -23,7 +23,6 @@ fn test_sumcheck( num_multiplicands_range: (usize, usize), num_products: usize, ) where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { let mut rng = test_rng(); @@ -55,7 +54,6 @@ fn test_sumcheck_internal( num_multiplicands_range: (usize, usize), num_products: usize, ) where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { let mut rng = test_rng(); @@ -117,7 +115,6 @@ fn test_trivial_polynomial() { fn test_trivial_polynomial_helper() where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { let nv = 1; @@ -136,7 +133,6 @@ fn test_normal_polynomial() { fn test_normal_polynomial_helper() where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { let nv = 12; @@ -164,7 +160,6 @@ fn test_extract_sum() { fn test_extract_sum_helper() where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { let mut rng = test_rng(); diff --git a/transcript/src/basic.rs b/transcript/src/basic.rs index 61b69c253..c0bd2126e 100644 --- a/transcript/src/basic.rs +++ b/transcript/src/basic.rs @@ -1,22 +1,18 @@ use ff_ext::ExtensionField; use p3_mds::MdsPermutation; -use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField, poseidon_permutation::PoseidonPermutation}; +use poseidon::{SPONGE_WIDTH, poseidon_permutation::PoseidonPermutation}; use crate::{Challenge, ForkableTranscript, Transcript}; use ff_ext::SmallField; use p3_field::FieldAlgebra; #[derive(Clone)] -pub struct BasicTranscript -where - E::BaseField: PoseidonField, -{ +pub struct BasicTranscript { permutation: PoseidonPermutation, } impl BasicTranscript where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { /// Create a new IOP transcript. @@ -31,7 +27,6 @@ where impl Transcript for BasicTranscript where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { fn append_field_elements(&mut self, elements: &[E::BaseField]) { @@ -72,9 +67,7 @@ where } } -impl ForkableTranscript for BasicTranscript -where - E::BaseField: PoseidonField, - Mds: MdsPermutation + Default, +impl ForkableTranscript for BasicTranscript where + Mds: MdsPermutation + Default { } diff --git a/transcript/src/statistics.rs b/transcript/src/statistics.rs index 6811ffefa..4552bea7e 100644 --- a/transcript/src/statistics.rs +++ b/transcript/src/statistics.rs @@ -1,7 +1,7 @@ use crate::{BasicTranscript, Challenge, ForkableTranscript, Transcript}; use ff_ext::ExtensionField; use p3_mds::MdsPermutation; -use poseidon::{SPONGE_WIDTH, poseidon::PoseidonField}; +use poseidon::SPONGE_WIDTH; use std::cell::RefCell; #[derive(Debug, Default)] @@ -12,17 +12,13 @@ pub struct Statistic { pub type StatisticRecorder = RefCell; #[derive(Clone)] -pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> -where - E::BaseField: PoseidonField, -{ +pub struct BasicTranscriptWithStat<'a, E: ExtensionField, Mds> { inner: BasicTranscript, stat: &'a StatisticRecorder, } impl<'a, E: ExtensionField, Mds> BasicTranscriptWithStat<'a, E, Mds> where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { pub fn new(stat: &'a StatisticRecorder, label: &'static [u8]) -> Self { @@ -35,7 +31,6 @@ where impl Transcript for BasicTranscriptWithStat<'_, E, Mds> where - E::BaseField: PoseidonField, Mds: MdsPermutation + Default, { fn append_field_elements(&mut self, elements: &[E::BaseField]) { @@ -69,9 +64,7 @@ where } } -impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> -where - E::BaseField: PoseidonField, - Mds: MdsPermutation + Default, +impl ForkableTranscript for BasicTranscriptWithStat<'_, E, Mds> where + Mds: MdsPermutation + Default { } From e189724082a64f958094e077ef8492d485604566 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 7 Jan 2025 13:15:09 +0800 Subject: [PATCH 12/12] mpcs to plonky3 field --- Cargo.lock | 4 + ff_ext/src/lib.rs | 15 +- mpcs/Cargo.toml | 4 + mpcs/benches/basecode.rs | 7 +- mpcs/benches/basefold.rs | 36 ++-- mpcs/benches/fft.rs | 7 +- mpcs/benches/hashing.rs | 9 +- mpcs/benches/interpolate.rs | 3 +- mpcs/benches/rscode.rs | 6 +- mpcs/benches/utils.rs | 3 +- mpcs/src/basefold.rs | 151 +++++++++------- mpcs/src/basefold/commit_phase.rs | 41 +++-- mpcs/src/basefold/encoding.rs | 4 +- mpcs/src/basefold/encoding/basecode.rs | 14 +- mpcs/src/basefold/encoding/rs.rs | 125 ++++++-------- mpcs/src/basefold/query_phase.rs | 230 ++++++++++++++++--------- mpcs/src/basefold/structure.rs | 69 +++++--- mpcs/src/basefold/sumcheck.rs | 17 +- mpcs/src/lib.rs | 38 ++-- mpcs/src/sum_check.rs | 17 +- mpcs/src/sum_check/classic.rs | 63 ++++--- mpcs/src/sum_check/classic/coeff.rs | 27 +-- mpcs/src/util.rs | 33 ++-- mpcs/src/util/arithmetic.rs | 94 ++-------- mpcs/src/util/arithmetic/hypercube.rs | 2 +- mpcs/src/util/expression.rs | 3 +- mpcs/src/util/hash.rs | 57 +++--- mpcs/src/util/merkle_tree.rs | 115 ++++++++----- sumcheck/src/test.rs | 2 +- 29 files changed, 653 insertions(+), 543 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6be119d21..c80f87249 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1141,6 +1141,10 @@ dependencies = [ "multilinear_extensions", "num-bigint", "num-integer", + "p3-field", + "p3-goldilocks", + "p3-mds", + "p3-symmetric", "plonky2", "poseidon", "rand", diff --git a/ff_ext/src/lib.rs b/ff_ext/src/lib.rs index d8bba3c86..9d1018253 100644 --- a/ff_ext/src/lib.rs +++ b/ff_ext/src/lib.rs @@ -4,13 +4,12 @@ use std::{array::from_fn, iter::repeat_with}; pub use ff; use p3_field::{ - ExtensionField as P3ExtensionField, Field as P3Field, PackedValue, PrimeField, + ExtensionField as P3ExtensionField, Field as P3Field, PackedValue, PrimeField, TwoAdicField, extension::BinomialExtensionField, }; use p3_goldilocks::Goldilocks; use rand_core::RngCore; use serde::Serialize; - pub type GoldilocksExt2 = BinomialExtensionField; fn array_try_from_uniform_bytes< @@ -92,7 +91,7 @@ pub trait SmallField: Serialize + P3Field { pub trait ExtensionField: P3ExtensionField + FromUniformBytes { const DEGREE: usize; - type BaseField: SmallField + Ord + PrimeField + FromUniformBytes; + type BaseField: SmallField + Ord + PrimeField + FromUniformBytes + TwoAdicField; fn from_bases(bases: &[Self::BaseField]) -> Self; @@ -176,3 +175,13 @@ mod impl_goldilocks { } } } + +#[cfg(test)] +mod test { + use p3_field::TwoAdicField; + use p3_goldilocks::Goldilocks; + #[test] + fn test() { + println!("{:?}", Goldilocks::two_adic_generator(21)); + } +} diff --git a/mpcs/Cargo.toml b/mpcs/Cargo.toml index f977328cc..dbdedfdfd 100644 --- a/mpcs/Cargo.toml +++ b/mpcs/Cargo.toml @@ -25,6 +25,10 @@ num-bigint = "0.4" num-integer = "0.1" plonky2.workspace = true poseidon.workspace = true +p3-field.workspace = true +p3-goldilocks.workspace = true +p3-mds.workspace = true +p3-symmetric.workspace = true rand.workspace = true rand_chacha.workspace = true rayon = { workspace = true, optional = true } diff --git a/mpcs/benches/basecode.rs b/mpcs/benches/basecode.rs index 9ef1896f6..ab2d245b1 100644 --- a/mpcs/benches/basecode.rs +++ b/mpcs/benches/basecode.rs @@ -1,9 +1,8 @@ use std::time::Duration; use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; +use ff_ext::GoldilocksExt2; use itertools::Itertools; use mpcs::{ Basefold, BasefoldBasecodeParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme, @@ -13,12 +12,14 @@ use mpcs::{ }, }; +use ff_ext::FromUniformBytes; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use p3_goldilocks::MdsMatrixGoldilocks; use rand::{SeedableRng, rngs::OsRng}; use rand_chacha::ChaCha8Rng; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -type Pcs = Basefold; +type Pcs = Basefold; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; diff --git a/mpcs/benches/basefold.rs b/mpcs/benches/basefold.rs index 64b9a5460..1e287fab4 100644 --- a/mpcs/benches/basefold.rs +++ b/mpcs/benches/basefold.rs @@ -1,8 +1,7 @@ use std::time::Duration; use criterion::*; -use ff_ext::ExtensionField; -use goldilocks::GoldilocksExt2; +use ff_ext::{ExtensionField, GoldilocksExt2}; use itertools::{Itertools, chain}; use mpcs::{ @@ -18,11 +17,12 @@ use multilinear_extensions::{ mle::{DenseMultilinearExtension, MultilinearExtension}, virtual_poly::ArcMultilinearExtension, }; +use p3_goldilocks::MdsMatrixGoldilocks; use transcript::{BasicTranscript, Transcript}; -type PcsGoldilocksRSCode = Basefold; -type PcsGoldilocksBasecode = Basefold; -type T = BasicTranscript; +type PcsGoldilocksRSCode = Basefold; +type PcsGoldilocksBasecode = Basefold; +type T = BasicTranscript; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; @@ -73,12 +73,12 @@ fn bench_commit_open_verify_goldilocks>( let point = get_point_from_challenge(num_vars, &mut transcript); let eval = poly.evaluate(point.as_slice()); transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript; + let transcript_for_bench = transcript.clone(); let proof = Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); group.bench_function(BenchmarkId::new("open", format!("{}", num_vars)), |b| { b.iter_batched( - || transcript_for_bench, + || transcript_for_bench.clone(), |mut transcript| { Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); }, @@ -91,11 +91,11 @@ fn bench_commit_open_verify_goldilocks>( Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); - let transcript_for_bench = transcript; + let transcript_for_bench = transcript.clone(); Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); group.bench_function(BenchmarkId::new("verify", format!("{}", num_vars)), |b| { b.iter_batched( - || transcript_for_bench, + || transcript_for_bench.clone(), |mut transcript| { Pcs::verify(&vp, &comm, &point, &eval, &proof, &mut transcript).unwrap(); }, @@ -163,7 +163,7 @@ fn bench_batch_commit_open_verify_goldilocks> .collect_vec(); let values: Vec = evals.iter().map(Evaluation::value).copied().collect(); transcript.append_field_element_exts(values.as_slice()); - let transcript_for_bench = transcript; + let transcript_for_bench = transcript.clone(); let proof = Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); @@ -171,7 +171,7 @@ fn bench_batch_commit_open_verify_goldilocks> BenchmarkId::new("batch_open", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( - || transcript_for_bench, + || transcript_for_bench.clone(), |mut transcript| { Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript) .unwrap(); @@ -203,7 +203,7 @@ fn bench_batch_commit_open_verify_goldilocks> .collect::>(); transcript.append_field_element_exts(values.as_slice()); - let backup_transcript = transcript; + let backup_transcript = transcript.clone(); Pcs::batch_verify(&vp, &comms, &points, &evals, &proof, &mut transcript).unwrap(); @@ -211,7 +211,7 @@ fn bench_batch_commit_open_verify_goldilocks> BenchmarkId::new("batch_verify", format!("{}-{}", num_vars, batch_size)), |b| { b.iter_batched( - || backup_transcript, + || backup_transcript.clone(), |mut transcript| { Pcs::batch_verify( &vp, @@ -261,7 +261,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_exts(&evals); - let backup_transcript = transcript; + let backup_transcript = transcript.clone(); Pcs::simple_batch_verify(&vp, &comm, &point, &evals, &proof, &mut transcript).unwrap(); @@ -305,7 +305,7 @@ fn bench_simple_batch_commit_open_verify_goldilocks( &mut poly.evaluations, - Goldilocks::MULTIPLICATIVE_GENERATOR, + Goldilocks::GENERATOR, 0, &root_table, ); diff --git a/mpcs/benches/hashing.rs b/mpcs/benches/hashing.rs index 818cd7b6c..7d1059144 100644 --- a/mpcs/benches/hashing.rs +++ b/mpcs/benches/hashing.rs @@ -1,8 +1,9 @@ use ark_std::test_rng; use criterion::{Criterion, criterion_group, criterion_main}; -use ff::Field; -use goldilocks::Goldilocks; + +use ff_ext::FromUniformBytes; use mpcs::util::hash::{Digest, hash_two_digests}; +use p3_goldilocks::{Goldilocks, MdsMatrixGoldilocks}; use poseidon::poseidon_hash::PoseidonHash; fn random_ceno_goldy() -> Goldilocks { @@ -12,12 +13,12 @@ pub fn criterion_benchmark(c: &mut Criterion) { let left = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); let right = Digest(vec![random_ceno_goldy(); 4].try_into().unwrap()); c.bench_function("ceno hash 2 to 1", |bencher| { - bencher.iter(|| hash_two_digests(&left, &right)) + bencher.iter(|| hash_two_digests::(&left, &right)) }); let values = (0..60).map(|_| random_ceno_goldy()).collect::>(); c.bench_function("ceno hash 60 to 1", |bencher| { - bencher.iter(|| PoseidonHash::hash_or_noop(&values)) + bencher.iter(|| PoseidonHash::::hash_or_noop(&values)) }); } diff --git a/mpcs/benches/interpolate.rs b/mpcs/benches/interpolate.rs index 79b5ab805..60e69672d 100644 --- a/mpcs/benches/interpolate.rs +++ b/mpcs/benches/interpolate.rs @@ -1,9 +1,8 @@ use std::time::Duration; use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; +use ff_ext::{FromUniformBytes, GoldilocksExt2}; use itertools::Itertools; use mpcs::util::arithmetic::interpolate_field_type_over_boolean_hypercube; diff --git a/mpcs/benches/rscode.rs b/mpcs/benches/rscode.rs index 2d284d177..a7ed08ce6 100644 --- a/mpcs/benches/rscode.rs +++ b/mpcs/benches/rscode.rs @@ -1,9 +1,8 @@ use std::time::Duration; use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; +use ff_ext::{FromUniformBytes, GoldilocksExt2}; use itertools::Itertools; use mpcs::{ Basefold, BasefoldRSParams, BasefoldSpec, EncodingScheme, PolynomialCommitmentScheme, @@ -14,11 +13,12 @@ use mpcs::{ }; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; +use p3_goldilocks::MdsMatrixGoldilocks; use rand::{SeedableRng, rngs::OsRng}; use rand_chacha::ChaCha8Rng; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -type Pcs = Basefold; +type Pcs = Basefold; type E = GoldilocksExt2; const NUM_SAMPLES: usize = 10; diff --git a/mpcs/benches/utils.rs b/mpcs/benches/utils.rs index 93ac567a3..68a86f65e 100644 --- a/mpcs/benches/utils.rs +++ b/mpcs/benches/utils.rs @@ -1,9 +1,8 @@ use std::time::Duration; use criterion::*; -use ff::Field; -use goldilocks::GoldilocksExt2; +use ff_ext::{FromUniformBytes, GoldilocksExt2}; use mpcs::{one_level_eval_hc, one_level_interp_hc}; use rand::rngs::OsRng; diff --git a/mpcs/src/basefold.rs b/mpcs/src/basefold.rs index 6204ed038..6f6e91127 100644 --- a/mpcs/src/basefold.rs +++ b/mpcs/src/basefold.rs @@ -28,13 +28,15 @@ pub use encoding::{ }; use ff_ext::ExtensionField; use multilinear_extensions::mle::MultilinearExtension; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use query_phase::{ BatchedQueriesResultWithMerklePath, QueriesResultWithMerklePath, SimpleBatchQueriesResultWithMerklePath, batch_prover_query_phase, batch_verifier_query_phase, prover_query_phase, simple_batch_prover_query_phase, simple_batch_verifier_query_phase, verifier_query_phase, }; -use std::{borrow::BorrowMut, ops::Deref}; +use std::{borrow::BorrowMut, fmt::Debug, ops::Deref}; pub use structure::BasefoldSpec; use structure::{BasefoldProof, ProofQueriesResultWithMerklePath}; use transcript::Transcript; @@ -79,7 +81,7 @@ enum PolyEvalsCodeword { TooBig(usize), } -impl> Basefold +impl, Mds> Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, @@ -265,18 +267,20 @@ where /// positions are (i >> k) and (i >> k) XOR 1. /// (c) The verifier checks that the folding has been correctly computed /// at these positions. -impl> PolynomialCommitmentScheme for Basefold +impl, Mds> PolynomialCommitmentScheme + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default + Debug, { type Param = BasefoldParams; type ProverParam = BasefoldProverParams; type VerifierParam = BasefoldVerifierParams; - type CommitmentWithWitness = BasefoldCommitmentWithWitness; + type CommitmentWithWitness = BasefoldCommitmentWithWitness; type Commitment = BasefoldCommitment; type CommitmentChunk = Digest; - type Proof = BasefoldProof; + type Proof = BasefoldProof; fn setup(poly_size: usize) -> Result { let pp = >::setup(log2_strict(poly_size)); @@ -323,7 +327,7 @@ where // (2) The encoding of the coefficient vector (need an interpolation) let ret = match Self::get_poly_bh_evals_and_codeword(pp, poly) { PolyEvalsCodeword::Normal((bh_evals, codeword)) => { - let codeword_tree = MerkleTree::::from_leaves(codeword); + let codeword_tree = MerkleTree::::from_leaves(codeword); // All these values are stored in the `CommitmentWithWitness` because // they are useful in opening, and we don't want to recompute them. @@ -336,7 +340,7 @@ where }) } PolyEvalsCodeword::TooSmall(evals) => { - let codeword_tree = MerkleTree::::from_leaves(evals.clone()); + let codeword_tree = MerkleTree::::from_leaves(evals.clone()); // All these values are stored in the `CommitmentWithWitness` because // they are useful in opening, and we don't want to recompute them. @@ -412,7 +416,7 @@ where } }) .collect::<(Vec<_>, Vec<_>)>(); - let codeword_tree = MerkleTree::::from_batch_leaves(codewords); + let codeword_tree = MerkleTree::::from_batch_leaves(codewords); Self::CommitmentWithWitness { codeword_tree, polynomials_bh_evals: bh_evals, @@ -432,7 +436,7 @@ where } }) .collect::>(); - let codeword_tree = MerkleTree::::from_batch_leaves(bh_evals.clone()); + let codeword_tree = MerkleTree::::from_batch_leaves(bh_evals.clone()); Self::CommitmentWithWitness { codeword_tree, polynomials_bh_evals: bh_evals, @@ -494,7 +498,7 @@ where // part, the prover needs to prepare the answers to the // queries, so the prover needs the oracles and the Merkle // trees built over them. - let (trees, commit_phase_proof) = commit_phase::( + let (trees, commit_phase_proof) = commit_phase::( &pp.encoding_params, point, comm, @@ -594,7 +598,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -645,8 +649,8 @@ where inner_product( &poly_iter_ext(poly).collect_vec(), build_eq_x_r_vec(point).iter(), - ) * scalar - * E::from(1 << (num_vars - poly.num_vars)) + ) * *scalar + * E::from_canonical_u64(1 << (num_vars - poly.num_vars)) // When this polynomial is smaller, it will be repeatedly summed over the cosets of the hypercube }) .sum::(); @@ -719,7 +723,7 @@ where let point = challenges; - let (trees, commit_phase_proof) = batch_commit_phase::( + let (trees, commit_phase_proof) = batch_commit_phase::( &pp.encoding_params, &point, comms, @@ -815,7 +819,7 @@ where // The remaining tasks for the prover is to prove that // sum_i coeffs[i] poly_evals[i] is equal to // the new target sum, where coeffs is computed as follows - let (trees, commit_phase_proof) = simple_batch_commit_phase::( + let (trees, commit_phase_proof) = simple_batch_commit_phase::( &pp.encoding_params, point, &eq_xt, @@ -864,7 +868,7 @@ where if proof.is_trivial() { let trivial_proof = &proof.trivial_proof; - let merkle_tree = MerkleTree::from_batch_leaves(trivial_proof.clone()); + let merkle_tree = MerkleTree::::from_batch_leaves(trivial_proof.clone()); if comm.root() == merkle_tree.root() { return Ok(()); } else { @@ -919,7 +923,7 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - verifier_query_phase::( + verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -977,7 +981,7 @@ where evals.iter().map(Evaluation::value), &evals .iter() - .map(|eval| E::from(1 << (num_vars - points[eval.point()].len()))) + .map(|eval| E::from_canonical_u64(1 << (num_vars - points[eval.point()].len()))) .collect_vec(), &poly_iter_ext(&eq_xt).take(evals.len()).collect_vec(), ); @@ -1044,7 +1048,7 @@ where ); eq.par_iter_mut().for_each(|e| *e *= coeff); - batch_verifier_query_phase::( + batch_verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -1079,7 +1083,7 @@ where if proof.is_trivial() { let trivial_proof = &proof.trivial_proof; - let merkle_tree = MerkleTree::from_batch_leaves(trivial_proof.clone()); + let merkle_tree = MerkleTree::::from_batch_leaves(trivial_proof.clone()); if comm.root() == merkle_tree.root() { return Ok(()); } else { @@ -1144,7 +1148,7 @@ where let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]); eq.par_iter_mut().for_each(|e| *e *= coeff); - simple_batch_verifier_query_phase::( + simple_batch_verifier_query_phase::( queries.as_slice(), &vp.encoding_params, query_result_with_merkle_path, @@ -1165,15 +1169,20 @@ where } } -impl> NoninteractivePCS for Basefold +impl, Mds> NoninteractivePCS + for Basefold where E: Serialize + DeserializeOwned, E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default + Debug, { } #[cfg(test)] mod test { + use ff_ext::GoldilocksExt2; + use p3_goldilocks::MdsMatrixGoldilocks; + use crate::{ basefold::Basefold, test_util::{ @@ -1181,83 +1190,93 @@ mod test { run_commit_open_verify, run_simple_batch_commit_open_verify, }, }; - use goldilocks::GoldilocksExt2; use super::{BasefoldRSParams, structure::BasefoldBasecodeParams}; - type PcsGoldilocksRSCode = Basefold; - type PcsGoldilocksBaseCode = Basefold; + type PcsGoldilocksRSCode = Basefold; + type PcsGoldilocksBaseCode = + Basefold; #[test] fn commit_open_verify_goldilocks() { for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Challenge is over extension field, poly over the base field - run_commit_open_verify::(gen_rand_poly, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(gen_rand_poly, 4, 6); - // Challenge is over extension field, poly over the base field - run_commit_open_verify::(gen_rand_poly, 10, 11); - // Test trivial proof with small num vars - run_commit_open_verify::(gen_rand_poly, 4, 6); - } - } - - #[test] - fn simple_batch_commit_open_verify_goldilocks() { - for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - 1, - ); - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 10, 11, - 4, ); // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 4, 6, - 4, ); - // Both challenge and poly are over base field - run_simple_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - 1, - ); - run_simple_batch_commit_open_verify::( + // Challenge is over extension field, poly over the base field + run_commit_open_verify::( gen_rand_poly, 10, 11, - 4, ); // Test trivial proof with small num vars - run_simple_batch_commit_open_verify::( + run_commit_open_verify::( gen_rand_poly, 4, 6, - 4, ); } } + #[test] + fn simple_batch_commit_open_verify_goldilocks() { + for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 1); + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 4); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 4, 6, 4); + // Both challenge and poly are over base field + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 1); + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11, 4); + // Test trivial proof with small num vars + run_simple_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksRSCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 4, 6, 4); + } + } + #[test] fn batch_commit_open_verify() { for gen_rand_poly in [gen_rand_poly_base, gen_rand_poly_ext] { // Both challenge and poly are over base field - run_batch_commit_open_verify::( - gen_rand_poly, - 10, - 11, - ); - run_batch_commit_open_verify::( + run_batch_commit_open_verify::< + GoldilocksExt2, + PcsGoldilocksBaseCode, + MdsMatrixGoldilocks, + >(gen_rand_poly, 10, 11); + run_batch_commit_open_verify::( gen_rand_poly, 10, 11, diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 55a6acea5..f61424fda 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -16,6 +16,8 @@ use crate::util::{ use ark_std::{end_timer, start_timer}; use ff_ext::ExtensionField; use itertools::Itertools; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -30,16 +32,17 @@ use rayon::prelude::{ use super::structure::BasefoldCommitmentWithWitness; // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) -pub fn commit_phase>( +pub fn commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], - comm: &BasefoldCommitmentWithWitness, + comm: &BasefoldCommitmentWithWitness, transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Commit phase"); #[cfg(feature = "sanity-check")] @@ -98,7 +101,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -116,8 +119,8 @@ where // Then the oracle will be used to fold to the next oracle in the next // round. After that, this oracle is free to be moved to build the // complete Merkle tree. - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root.clone()); @@ -176,17 +179,18 @@ where // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) #[allow(clippy::too_many_arguments)] -pub fn batch_commit_phase>( +pub fn batch_commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], - comms: &[BasefoldCommitmentWithWitness], + comms: &[BasefoldCommitmentWithWitness], transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, coeffs: &[E], -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Batch Commit phase"); assert_eq!(point.len(), num_vars); @@ -266,7 +270,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -277,8 +281,8 @@ where last_sumcheck_message = sum_check_challenge_round(&mut eq, &mut sum_of_all_evals_for_sumcheck, challenge); sumcheck_messages.push(last_sumcheck_message.clone()); - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root); @@ -346,17 +350,18 @@ where // outputs (trees, sumcheck_oracles, oracles, bh_evals, eq, eval) #[allow(clippy::too_many_arguments)] -pub fn simple_batch_commit_phase>( +pub fn simple_batch_commit_phase, Mds>( pp: &>::ProverParameters, point: &[E], batch_coeffs: &[E], - comm: &BasefoldCommitmentWithWitness, + comm: &BasefoldCommitmentWithWitness, transcript: &mut impl Transcript, num_vars: usize, num_rounds: usize, -) -> (Vec>, BasefoldCommitPhaseProof) +) -> (Vec>, BasefoldCommitPhaseProof) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Simple batch commit phase"); assert_eq!(point.len(), num_vars); @@ -416,7 +421,7 @@ where ); if i > 0 { - let running_tree = MerkleTree::::from_inner_leaves( + let running_tree = MerkleTree::::from_inner_leaves( running_tree_inner, FieldType::Ext(running_oracle), ); @@ -426,8 +431,8 @@ where if i < num_rounds - 1 { last_sumcheck_message = sum_check_challenge_round(&mut eq, &mut running_evals, challenge); - running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); - let running_root = MerkleTree::::root_from_inner(&running_tree_inner); + running_tree_inner = MerkleTree::::compute_inner_ext(&new_running_oracle); + let running_root = MerkleTree::::root_from_inner(&running_tree_inner); write_digest_to_transcript(&running_root, transcript); roots.push(running_root); running_oracle = new_running_oracle; diff --git a/mpcs/src/basefold/encoding.rs b/mpcs/src/basefold/encoding.rs index 410d35970..6c3a03d2f 100644 --- a/mpcs/src/basefold/encoding.rs +++ b/mpcs/src/basefold/encoding.rs @@ -173,7 +173,9 @@ pub(crate) mod test_util { pub fn test_codeword_folding>() { let num_vars = 12; - let poly: Vec = (0..(1 << num_vars)).map(|i| E::from(i)).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(|i| E::from_canonical_u64(i)) + .collect(); let mut poly = FieldType::Ext(poly); let pp: Code::PublicParameters = Code::setup(num_vars); diff --git a/mpcs/src/basefold/encoding/basecode.rs b/mpcs/src/basefold/encoding/basecode.rs index 9fbee84f1..e753879b8 100644 --- a/mpcs/src/basefold/encoding/basecode.rs +++ b/mpcs/src/basefold/encoding/basecode.rs @@ -10,10 +10,10 @@ use crate::{ }; use aes::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use ark_std::{end_timer, start_timer}; -use ff::{BatchInvert, Field, PrimeField}; use ff_ext::ExtensionField; use generic_array::GenericArray; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, batch_multiplicative_inverse}; use rand::SeedableRng; use rayon::prelude::{ParallelIterator, ParallelSlice, ParallelSliceMut}; @@ -216,7 +216,7 @@ where let x0: E::BaseField = query_root_table_from_rng_aes::(level, index, &mut cipher); let x1 = -x0; - let w = (x1 - x0).invert().unwrap(); + let w = (x1 - x0).try_inverse().unwrap(); (E::from(x0), E::from(x1), E::from(w)) } @@ -351,13 +351,13 @@ pub fn get_table_aes( assert_eq!(flat_table.len(), 1 << lg_n); // Multiply -2 to every element to get the weights. Now weights = { -2x } - let mut weights: Vec = flat_table + let weights: Vec = flat_table .par_iter() .map(|el| E::BaseField::ZERO - *el - *el) .collect(); // Then invert all the elements. Now weights = { -1/2x } - BatchInvert::batch_invert(&mut weights); + let weights = batch_multiplicative_inverse(&weights); // Zip x and -1/2x together. The result is the list { (x, -1/2x) } // What is this -1/2x? It is used in linear interpolation over the domain (x, -x), which @@ -399,13 +399,13 @@ pub fn query_root_table_from_rng_aes( } let pos = ((level_offset + (reverse_bits(index, level) as u128)) - * ((E::BaseField::NUM_BITS as usize).next_power_of_two() as u128)) + * (E::BaseField::bits().next_power_of_two() as u128)) .checked_div(8) .unwrap(); cipher.seek(pos); - let bytes = (E::BaseField::NUM_BITS as usize).next_power_of_two() / 8; + let bytes = (E::BaseField::bits()).next_power_of_two() / 8; let mut dest: Vec = vec![0u8; bytes]; cipher.apply_keystream(&mut dest); @@ -417,7 +417,7 @@ mod tests { use crate::basefold::encoding::test_util::test_codeword_folding; use super::*; - use goldilocks::GoldilocksExt2; + use ff_ext::GoldilocksExt2; use multilinear_extensions::mle::DenseMultilinearExtension; #[test] diff --git a/mpcs/src/basefold/encoding/rs.rs b/mpcs/src/basefold/encoding/rs.rs index 2bcac0826..fef22bc3f 100644 --- a/mpcs/src/basefold/encoding/rs.rs +++ b/mpcs/src/basefold/encoding/rs.rs @@ -7,9 +7,9 @@ use crate::{ vec_mut, }; use ark_std::{end_timer, start_timer}; -use ff::{Field, PrimeField}; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::{Field, FieldAlgebra, PrimeField, TwoAdicField}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -28,11 +28,11 @@ pub trait RSCodeSpec: std::fmt::Debug + Clone { /// The FFT codes in this file are borrowed and adapted from Plonky2. type FftRootTable = Vec>; -pub fn fft_root_table(lg_n: usize) -> FftRootTable { +pub fn fft_root_table(lg_n: usize) -> FftRootTable { // bases[i] = g^2^i, for i = 0, ..., lg_n - 1 // Note that the end of bases is g^{n/2} = -1 let mut bases = Vec::with_capacity(lg_n); - let mut base = F::ROOT_OF_UNITY.pow([(1 << (F::S - lg_n as u32)) as u64]); + let mut base = F::two_adic_generator(lg_n); bases.push(base); for _ in 1..lg_n { base = base.square(); // base = g^2^_ @@ -71,9 +71,8 @@ fn ifft( let n = poly.len(); let lg_n = log2_strict(n); let n_inv = (E::BaseField::ONE + E::BaseField::ONE) - .invert() - .unwrap() - .pow([lg_n as u64]); + .inverse() + .exp_u64(lg_n as u64); fft(poly, zero_factor, root_table); @@ -275,7 +274,9 @@ where fn setup(max_message_size_log: usize) -> Self::PublicParameters { RSCodeParameters { - fft_root_table: fft_root_table(max_message_size_log + Spec::get_rate_log()), + fft_root_table: fft_root_table::( + max_message_size_log + Spec::get_rate_log(), + ), } } @@ -310,13 +311,13 @@ where } let mut gamma_powers = Vec::with_capacity(max_message_size_log); let mut gamma_powers_inv = Vec::with_capacity(max_message_size_log); - gamma_powers.push(E::BaseField::MULTIPLICATIVE_GENERATOR); - gamma_powers_inv.push(E::BaseField::MULTIPLICATIVE_GENERATOR.invert().unwrap()); + gamma_powers.push(E::BaseField::GENERATOR); + gamma_powers_inv.push(E::BaseField::GENERATOR.inverse()); for i in 1..max_message_size_log + Spec::get_rate_log() { gamma_powers.push(gamma_powers[i - 1].square()); gamma_powers_inv.push(gamma_powers_inv[i - 1].square()); } - let inv_of_two = E::BaseField::from(2).invert().unwrap(); + let inv_of_two = E::BaseField::from_canonical_u64(2).inverse(); gamma_powers_inv.iter_mut().for_each(|x| *x *= inv_of_two); pp.fft_root_table .truncate(max_message_size_log + Spec::get_rate_log()); @@ -427,7 +428,7 @@ where } else { // In this case, the level-th row of fft root table of the verifier // only stores the first 2^(level+1)-th roots of unity. - vp.fft_root_table[level][0].pow([index as u64]) + vp.fft_root_table[level][0].exp_u64(index as u64) } * vp.gamma_powers[vp.full_message_size_log + Spec::get_rate_log() - level - 1]; let x1 = -x0; // The weight is 1/(x1-x0) = -1/(2x0) @@ -447,7 +448,7 @@ where } else { // In this case, this level of fft root table of the verifier // only stores the first 2^(level+1)-th root of unity. - vp.fft_root_table[level][0].pow([(1 << (level + 1)) - index as u64]) + vp.fft_root_table[level][0].exp_u64((1 << (level + 1)) - index as u64) }; (E::from(x0), E::from(x1), E::from(w)) } @@ -490,10 +491,9 @@ impl RSCode { // of size n * rate. // When the input message size is not n, but n/2^k, then the domain is // gamma^2^k H. - let k = 1 << (full_message_size_log - lg_m); coset_fft( &mut ret, - E::BaseField::MULTIPLICATIVE_GENERATOR.pow([k]), + E::BaseField::GENERATOR.exp_power_of_2(full_message_size_log - lg_m), Spec::get_rate_log(), fft_root_table, ); @@ -511,13 +511,11 @@ impl RSCode { let index = reverse_bits(index, level); // x0 is the index-th 2^(level+1)-th root of unity, multiplied by // the shift factor at level+1, which is gamma^2^(full_codeword_log_n - level - 1). - let x0 = E::BaseField::ROOT_OF_UNITY - .pow([1 << (E::BaseField::S - (level as u32 + 1))]) - .pow([index as u64]) - * E::BaseField::MULTIPLICATIVE_GENERATOR - .pow([1 << (full_message_size_log + Spec::get_rate_log() - level - 1)]); + let x0 = E::BaseField::two_adic_generator(level + 1).exp_u64(index as u64) + * E::BaseField::GENERATOR + .exp_power_of_2(full_message_size_log + Spec::get_rate_log() - level - 1); let x1 = -x0; - let w = (x1 - x0).invert().unwrap(); + let w = (x1 - x0).inverse(); (E::from(x0), E::from(x1), E::from(w)) } } @@ -527,7 +525,7 @@ fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> let timer = start_timer!(|| "Encode RSCode"); let message_size = poly.len(); let domain_size_bit = log2_strict(message_size * rate); - let root = E::BaseField::ROOT_OF_UNITY.pow([1 << (E::BaseField::S - domain_size_bit as u32)]); + let root = E::BaseField::two_adic_generator(domain_size_bit); // The domain is shift * H where H is the multiplicative subgroup of size // message_size * rate. let mut domain = Vec::::with_capacity(message_size * rate); @@ -546,19 +544,24 @@ fn naive_fft(poly: &[E], rate: usize, shift: E::BaseField) -> #[cfg(test)] mod tests { + use ff_ext::GoldilocksExt2; + use p3_goldilocks::Goldilocks; + use crate::{ basefold::encoding::test_util::test_codeword_folding, util::{field_type_index_ext, plonky2_util::reverse_index_bits_in_place_field_type}, }; + use ff_ext::FromUniformBytes; use super::*; - use goldilocks::{Goldilocks, GoldilocksExt2}; #[test] fn test_naive_fft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly2 = FieldType::Ext(poly.clone()); let naive = naive_fft::(&poly, 1, Goldilocks::ONE); @@ -583,15 +586,10 @@ mod tests { .collect(); let mut poly2 = FieldType::Ext(poly.clone()); - let naive = naive_fft::(&poly, 1, Goldilocks::MULTIPLICATIVE_GENERATOR); + let naive = naive_fft::(&poly, 1, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - 0, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, 0, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -613,19 +611,10 @@ mod tests { poly2.as_mut_slice()[..poly.len()].copy_from_slice(poly.as_slice()); let mut poly2 = FieldType::Ext(poly2.clone()); - let naive = naive_fft::( - &poly, - 1 << rate_bits, - Goldilocks::MULTIPLICATIVE_GENERATOR, - ); + let naive = naive_fft::(&poly, 1 << rate_bits, Goldilocks::GENERATOR); let root_table = fft_root_table(num_vars + rate_bits); - coset_fft::( - &mut poly2, - Goldilocks::MULTIPLICATIVE_GENERATOR, - rate_bits, - &root_table, - ); + coset_fft::(&mut poly2, Goldilocks::GENERATOR, rate_bits, &root_table); let poly2 = match poly2 { FieldType::Ext(coeffs) => coeffs, @@ -638,7 +627,9 @@ mod tests { fn test_ifft() { let num_vars = 5; - let poly: Vec = (0..(1 << num_vars)).map(GoldilocksExt2::from).collect(); + let poly: Vec = (0..(1 << num_vars)) + .map(GoldilocksExt2::from_canonical_u64) + .collect(); let mut poly = FieldType::Ext(poly); let original = poly.clone(); @@ -686,14 +677,14 @@ mod tests { pub fn test_colinearity() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); let (pp, _) = Code::trim(pp, num_vars).unwrap(); let mut codeword = Code::encode(&pp, &poly); reverse_index_bits_in_place_field_type(&mut codeword); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let codeword = match codeword { FieldType::Ext(coeffs) => coeffs, @@ -712,8 +703,8 @@ mod tests { // which is equivalent to // (x0-challenge)*(b[1]-a) = (x1-challenge)*(b[0]-a) assert_eq!( - (x0 - challenge) * (b[1] - a), - (x1 - challenge) * (b[0] - a), + (x0 - challenge) * (b[1] - *a), + (x1 - challenge) * (b[0] - *a), "failed for i = {}", i ); @@ -724,7 +715,7 @@ mod tests { pub fn test_low_degree() { let num_vars = 10; - let poly: Vec = (0..(1 << num_vars)).map(E::from).collect(); + let poly: Vec = (0..(1 << num_vars)).map(E::from_canonical_u64).collect(); let poly = FieldType::Ext(poly); let pp = >::setup(num_vars); @@ -777,10 +768,11 @@ mod tests { assert_eq!(left_right_diff[0], c0 - c_mid); reverse_index_bits_in_place(&mut left_right_diff); assert_eq!(left_right_diff[1], c1 - c_mid1); - let root_of_unity_inv = F::ROOT_OF_UNITY_INV - .pow([1 << (F::S as usize - log2_strict(left_right_diff.len()) - 1)]); + let root_of_unity_inv = F::two_adic_generator(F::TWO_ADICITY) + .inverse() + .exp_power_of_2(F::TWO_ADICITY - log2_strict(left_right_diff.len()) - 1); for (i, coeff) in left_right_diff.iter_mut().enumerate() { - *coeff *= root_of_unity_inv.pow([i as u64]); + *coeff *= root_of_unity_inv.exp_u64(i as u64); } assert_eq!(left_right_diff[0], c0 - c_mid); assert_eq!(left_right_diff[1], (c1 - c_mid1) * root_of_unity_inv); @@ -789,7 +781,7 @@ mod tests { "check low degree of (left-right)*omega^(-i)", ); - let challenge = E::from(2); + let challenge = E::from_canonical_u64(2); let folded_codeword = Code::fold_bitreversed_codeword(&pp, &codeword, challenge); let c_fold = folded_codeword[0]; let c_fold1 = folded_codeword[folded_codeword.len() >> 1]; @@ -800,7 +792,7 @@ mod tests { // The top level folding coefficient should have shift factor gamma let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 0); - assert_eq!(folding_coeffs.0, E::from(F::MULTIPLICATIVE_GENERATOR)); + assert_eq!(folding_coeffs.0, E::from(F::GENERATOR)); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( (folding_coeffs.1 - folding_coeffs.0) * folding_coeffs.2, @@ -815,28 +807,25 @@ mod tests { // So the folded value should be equal to // (gamma^{-1} * alpha * (c0 - c_mid) + (c0 + c_mid)) / 2 assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * (c0 - c_mid) + (c0 + c_mid) * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * (c0 - c_mid) + (c0 + c_mid) * F::GENERATOR ); assert_eq!( - c_fold * F::MULTIPLICATIVE_GENERATOR * F::from(2), - challenge * left_right_diff[0] + left_right_sum[0] * F::MULTIPLICATIVE_GENERATOR + c_fold * F::GENERATOR * F::from_canonical_u64(2), + challenge * left_right_diff[0] + left_right_sum[0] * F::GENERATOR ); assert_eq!( - c_fold * F::from(2), - challenge * left_right_diff[0] * F::MULTIPLICATIVE_GENERATOR.invert().unwrap() - + left_right_sum[0] + c_fold * F::from_canonical_u64(2), + challenge * left_right_diff[0] * F::GENERATOR.inverse() + left_right_sum[0] ); let folding_coeffs = Code::prover_folding_coeffs(&pp, log2_strict(codeword.len()) - 1, 1); - let root_of_unity = - F::ROOT_OF_UNITY.pow([1 << (F::S as usize - log2_strict(codeword.len()))]); - assert_eq!(root_of_unity.pow([codeword.len() as u64]), F::ONE); - assert_eq!(root_of_unity.pow([(codeword.len() >> 1) as u64]), -F::ONE); + let root_of_unity = F::two_adic_generator(log2_strict(codeword.len())); + assert_eq!(root_of_unity.exp_u64(codeword.len() as u64), F::ONE); + assert_eq!(root_of_unity.exp_u64((codeword.len() >> 1) as u64), -F::ONE); assert_eq!( folding_coeffs.0, - E::from(F::MULTIPLICATIVE_GENERATOR) - * E::from(root_of_unity).pow([(codeword.len() >> 2) as u64]) + E::from(F::GENERATOR) * E::from(root_of_unity).exp_u64((codeword.len() >> 2) as u64) ); assert_eq!(folding_coeffs.0 + folding_coeffs.1, E::ZERO); assert_eq!( @@ -849,14 +838,14 @@ mod tests { // The coefficients are respectively 1/2 and gamma^{-1}/2 * alpha. // In another word, the folded codeword multipled by 2 is the linear // combination by coeffs: 1 and gamma^{-1} * alpha - let gamma_inv = F::MULTIPLICATIVE_GENERATOR.invert().unwrap(); + let gamma_inv = F::GENERATOR.inverse(); let b = challenge * gamma_inv; let folded_codeword_vec = match &folded_codeword { FieldType::Ext(coeffs) => coeffs.clone(), _ => panic!("Wrong field type"), }; assert_eq!( - c_fold * F::from(2), + c_fold * F::from_canonical_u64(2), left_right_diff[0] * b + left_right_sum[0] ); for (i, (c, (diff, sum))) in folded_codeword_vec @@ -864,7 +853,7 @@ mod tests { .zip(left_right_diff.iter().zip(left_right_sum.iter())) .enumerate() { - assert_eq!(*c + c, *sum + b * diff, "failed for i = {}", i); + assert_eq!(*c + *c, *sum + b * *diff, "failed for i = {}", i); } check_low_degree(&folded_codeword, "low degree check for folded"); diff --git a/mpcs/src/basefold/query_phase.rs b/mpcs/src/basefold/query_phase.rs index 9ec15d36b..a72fefb83 100644 --- a/mpcs/src/basefold/query_phase.rs +++ b/mpcs/src/basefold/query_phase.rs @@ -12,6 +12,8 @@ use ark_std::{end_timer, start_timer}; use core::fmt::Debug; use ff_ext::ExtensionField; use itertools::Itertools; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -28,14 +30,15 @@ use super::{ structure::{BasefoldCommitment, BasefoldCommitmentWithWitness, BasefoldSpec}, }; -pub fn prover_query_phase( +pub fn prover_query_phase( transcript: &mut impl Transcript, - comm: &BasefoldCommitmentWithWitness, - trees: &[MerkleTree], + comm: &BasefoldCommitmentWithWitness, + trees: &[MerkleTree], num_verifier_queries: usize, ) -> QueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -57,22 +60,23 @@ where .map(|x_index| { ( *x_index, - basefold_get_query::(&comm.get_codewords()[0], trees, *x_index), + basefold_get_query::(&comm.get_codewords()[0], trees, *x_index), ) }) .collect(), } } -pub fn batch_prover_query_phase( +pub fn batch_prover_query_phase( transcript: &mut impl Transcript, codeword_size: usize, - comms: &[BasefoldCommitmentWithWitness], - trees: &[MerkleTree], + comms: &[BasefoldCommitmentWithWitness], + trees: &[MerkleTree], num_verifier_queries: usize, ) -> BatchedQueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -94,21 +98,22 @@ where .map(|x_index| { ( *x_index, - batch_basefold_get_query::(comms, trees, codeword_size, *x_index), + batch_basefold_get_query::(comms, trees, codeword_size, *x_index), ) }) .collect(), } } -pub fn simple_batch_prover_query_phase( +pub fn simple_batch_prover_query_phase( transcript: &mut impl Transcript, - comm: &BasefoldCommitmentWithWitness, - trees: &[MerkleTree], + comm: &BasefoldCommitmentWithWitness, + trees: &[MerkleTree], num_verifier_queries: usize, ) -> SimpleBatchQueriesResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let queries: Vec<_> = (0..num_verifier_queries) .map(|_| { @@ -130,7 +135,11 @@ where .map(|x_index| { ( *x_index, - simple_batch_basefold_get_query::(comm.get_codewords(), trees, *x_index), + simple_batch_basefold_get_query::( + comm.get_codewords(), + trees, + *x_index, + ), ) }) .collect(), @@ -138,10 +147,10 @@ where } #[allow(clippy::too_many_arguments)] -pub fn verifier_query_phase>( +pub fn verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &QueriesResultWithMerklePath, + queries: &QueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], num_rounds: usize, @@ -153,6 +162,7 @@ pub fn verifier_query_phase>( eval: &E, ) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Verifier query phase"); @@ -210,10 +220,10 @@ pub fn verifier_query_phase>( } #[allow(clippy::too_many_arguments)] -pub fn batch_verifier_query_phase>( +pub fn batch_verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &BatchedQueriesResultWithMerklePath, + queries: &BatchedQueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], num_rounds: usize, @@ -226,6 +236,7 @@ pub fn batch_verifier_query_phase>( eval: &E, ) where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let timer = start_timer!(|| "Verifier batch query phase"); let encode_timer = start_timer!(|| "Encode final codeword"); @@ -286,10 +297,10 @@ pub fn batch_verifier_query_phase>( } #[allow(clippy::too_many_arguments)] -pub fn simple_batch_verifier_query_phase>( +pub fn simple_batch_verifier_query_phase, Mds>( indices: &[usize], vp: &>::VerifierParameters, - queries: &SimpleBatchQueriesResultWithMerklePath, + queries: &SimpleBatchQueriesResultWithMerklePath, sum_check_messages: &[Vec], fold_challenges: &[E], batch_coeffs: &[E], @@ -302,6 +313,7 @@ pub fn simple_batch_verifier_query_phase + Default, { let timer = start_timer!(|| "Verifier query phase"); @@ -364,13 +376,14 @@ pub fn simple_batch_verifier_query_phase( +fn basefold_get_query( poly_codeword: &FieldType, - trees: &[MerkleTree], + trees: &[MerkleTree], x_index: usize, ) -> SingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut index = x_index; let p1 = index | 1; @@ -410,14 +423,15 @@ where } } -fn batch_basefold_get_query( - comms: &[BasefoldCommitmentWithWitness], - trees: &[MerkleTree], +fn batch_basefold_get_query( + comms: &[BasefoldCommitmentWithWitness], + trees: &[MerkleTree], codeword_size: usize, x_index: usize, ) -> BatchedSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut oracle_list_queries = Vec::with_capacity(trees.len()); @@ -465,13 +479,14 @@ where } } -fn simple_batch_basefold_get_query( +fn simple_batch_basefold_get_query( poly_codewords: &[FieldType], - trees: &[MerkleTree], + trees: &[MerkleTree], x_index: usize, ) -> SimpleBatchSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { let mut index = x_index; let p1 = index | 1; @@ -529,6 +544,7 @@ where #[derive(Debug, Copy, Clone, Serialize, Deserialize)] enum CodewordPointPair { + #[serde(bound = "")] Ext(E, E), Base(E::BaseField, E::BaseField), } @@ -543,6 +559,7 @@ impl CodewordPointPair { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] enum SimpleBatchLeavesPair where E::BaseField: Serialize + DeserializeOwned, @@ -588,6 +605,7 @@ where } #[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct CodewordSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -630,17 +648,20 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct CodewordSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { query: CodewordSingleQueryResult, - merkle_path: MerklePathWithoutLeafOrRoot, + merkle_path: MerklePathWithoutLeafOrRoot, } -impl CodewordSingleQueryResultWithMerklePath +impl CodewordSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn check_merkle_path(&self, root: &Digest) { // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); @@ -659,6 +680,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct OracleListQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -667,6 +689,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct CommitmentsQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -675,24 +698,29 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct OracleListQueryResultWithMerklePath +#[serde(bound = "")] +struct OracleListQueryResultWithMerklePath where - E::BaseField: Serialize + DeserializeOwned, + E::BaseField: Serialize, + Mds: MdsPermutation + Default, { - inner: Vec>, + inner: Vec>, } #[derive(Debug, Clone, Serialize, Deserialize)] -struct CommitmentsQueryResultWithMerklePath +#[serde(bound = "")] +struct CommitmentsQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec>, + inner: Vec>, } -impl ListQueryResult for OracleListQueryResult +impl ListQueryResult for OracleListQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec> { &self.inner @@ -703,9 +731,10 @@ where } } -impl ListQueryResult for CommitmentsQueryResult +impl ListQueryResult for CommitmentsQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec> { &self.inner @@ -716,35 +745,40 @@ where } } -impl ListQueryResultWithMerklePath for OracleListQueryResultWithMerklePath +impl ListQueryResultWithMerklePath + for OracleListQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn get_inner(&self) -> &Vec> { + fn get_inner(&self) -> &Vec> { &self.inner } - fn new(inner: Vec>) -> Self { + fn new(inner: Vec>) -> Self { Self { inner } } } -impl ListQueryResultWithMerklePath for CommitmentsQueryResultWithMerklePath +impl ListQueryResultWithMerklePath + for CommitmentsQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn get_inner(&self) -> &Vec> { + fn get_inner(&self) -> &Vec> { &self.inner } - fn new(inner: Vec>) -> Self { + fn new(inner: Vec>) -> Self { Self { inner } } } -trait ListQueryResult +trait ListQueryResult where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn get_inner(&self) -> &Vec>; @@ -752,8 +786,8 @@ where fn merkle_path( &self, - path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, - ) -> Vec> { + path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, + ) -> Vec> { let ret = self .get_inner() .iter() @@ -764,17 +798,18 @@ where } } -trait ListQueryResultWithMerklePath: Sized +trait ListQueryResultWithMerklePath: Sized where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn new(inner: Vec>) -> Self; + fn new(inner: Vec>) -> Self; - fn get_inner(&self) -> &Vec>; + fn get_inner(&self) -> &Vec>; - fn from_query_and_trees>( + fn from_query_and_trees>( query_result: LQR, - path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, + path: impl Fn(usize, usize) -> MerklePathWithoutLeafOrRoot, ) -> Self { Self::new( query_result @@ -804,6 +839,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -813,22 +849,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitment_query: CodewordSingleQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: CodewordSingleQueryResultWithMerklePath, } -impl SingleQueryResultWithMerklePath +impl SingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_single_query_result( single_query_result: SingleQueryResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { assert!(commitment.codeword_tree.height() > 0); Self { @@ -909,16 +948,19 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct QueriesResultWithMerklePath +#[serde(bound = "")] +pub struct QueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, SingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SingleQueryResultWithMerklePath)>, } -impl QueriesResultWithMerklePath +impl QueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn empty() -> Self { Self { inner: vec![] } @@ -926,8 +968,8 @@ where pub fn from_query_result( query_result: QueriesResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { inner: query_result @@ -978,6 +1020,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct BatchedSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -987,22 +1030,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct BatchedSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct BatchedSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitments_query: CommitmentsQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitments_query: CommitmentsQueryResultWithMerklePath, } -impl BatchedSingleQueryResultWithMerklePath +impl BatchedSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_batched_single_query_result( batched_single_query_result: BatchedSingleQueryResult, - oracle_trees: &[MerkleTree], - commitments: &[BasefoldCommitmentWithWitness], + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithWitness], ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -1133,21 +1179,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BatchedQueriesResultWithMerklePath +#[serde(bound = "")] +pub struct BatchedQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, BatchedSingleQueryResultWithMerklePath)>, + inner: Vec<(usize, BatchedSingleQueryResultWithMerklePath)>, } -impl BatchedQueriesResultWithMerklePath +impl BatchedQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_batched_query_result( batched_query_result: BatchedQueriesResult, - oracle_trees: &[MerkleTree], - commitments: &[BasefoldCommitmentWithWitness], + oracle_trees: &[MerkleTree], + commitments: &[BasefoldCommitmentWithWitness], ) -> Self { Self { inner: batched_query_result @@ -1202,6 +1251,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SimpleBatchCommitmentSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -1246,17 +1296,20 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SimpleBatchCommitmentSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SimpleBatchCommitmentSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { query: SimpleBatchCommitmentSingleQueryResult, - merkle_path: MerklePathWithoutLeafOrRoot, + merkle_path: MerklePathWithoutLeafOrRoot, } -impl SimpleBatchCommitmentSingleQueryResultWithMerklePath +impl SimpleBatchCommitmentSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn check_merkle_path(&self, root: &Digest) { // let timer = start_timer!(|| "CodewordSingleQuery::Check Merkle Path"); @@ -1283,6 +1336,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] struct SimpleBatchSingleQueryResult where E::BaseField: Serialize + DeserializeOwned, @@ -1292,22 +1346,25 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -struct SimpleBatchSingleQueryResultWithMerklePath +#[serde(bound = "")] +struct SimpleBatchSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - oracle_query: OracleListQueryResultWithMerklePath, - commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, + oracle_query: OracleListQueryResultWithMerklePath, + commitment_query: SimpleBatchCommitmentSingleQueryResultWithMerklePath, } -impl SimpleBatchSingleQueryResultWithMerklePath +impl SimpleBatchSingleQueryResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_single_query_result( single_query_result: SimpleBatchSingleQueryResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { oracle_query: OracleListQueryResultWithMerklePath::from_query_and_trees( @@ -1389,21 +1446,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SimpleBatchQueriesResultWithMerklePath +#[serde(bound = "")] +pub struct SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, + inner: Vec<(usize, SimpleBatchSingleQueryResultWithMerklePath)>, } -impl SimpleBatchQueriesResultWithMerklePath +impl SimpleBatchQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn from_query_result( query_result: SimpleBatchQueriesResult, - oracle_trees: &[MerkleTree], - commitment: &BasefoldCommitmentWithWitness, + oracle_trees: &[MerkleTree], + commitment: &BasefoldCommitmentWithWitness, ) -> Self { Self { inner: query_result diff --git a/mpcs/src/basefold/structure.rs b/mpcs/src/basefold/structure.rs index 547dd7c51..b89ac7c7e 100644 --- a/mpcs/src/basefold/structure.rs +++ b/mpcs/src/basefold/structure.rs @@ -5,6 +5,8 @@ use crate::{ use core::fmt::Debug; use ff_ext::ExtensionField; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Deserialize, Serialize, Serializer, de::DeserializeOwned}; use multilinear_extensions::mle::FieldType; @@ -59,20 +61,22 @@ pub struct BasefoldVerifierParams> { /// A polynomial commitment together with all the data (e.g., the codeword, and Merkle tree) /// used to generate this commitment and for assistant in opening #[derive(Clone, Debug, Default)] -pub struct BasefoldCommitmentWithWitness +pub struct BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - pub(crate) codeword_tree: MerkleTree, + pub(crate) codeword_tree: MerkleTree, pub(crate) polynomials_bh_evals: Vec>, pub(crate) num_vars: usize, pub(crate) is_base: bool, pub(crate) num_polys: usize, } -impl BasefoldCommitmentWithWitness +impl BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn to_commitment(&self) -> BasefoldCommitment { BasefoldCommitment::new( @@ -132,20 +136,22 @@ where } } -impl From> for Digest +impl From> for Digest where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn from(val: BasefoldCommitmentWithWitness) -> Self { + fn from(val: BasefoldCommitmentWithWitness) -> Self { val.get_root_as() } } -impl From<&BasefoldCommitmentWithWitness> for BasefoldCommitment +impl From<&BasefoldCommitmentWithWitness> for BasefoldCommitment where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - fn from(val: &BasefoldCommitmentWithWitness) -> Self { + fn from(val: &BasefoldCommitmentWithWitness) -> Self { val.to_commitment() } } @@ -193,9 +199,10 @@ where } } -impl PartialEq for BasefoldCommitmentWithWitness +impl PartialEq for BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn eq(&self, other: &Self) -> bool { self.get_codewords().eq(other.get_codewords()) @@ -203,8 +210,10 @@ where } } -impl Eq for BasefoldCommitmentWithWitness where - E::BaseField: Serialize + DeserializeOwned +impl Eq for BasefoldCommitmentWithWitness +where + E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { } @@ -245,9 +254,9 @@ where } #[derive(Debug)] -pub struct Basefold>(PhantomData<(E, Spec)>); +pub struct Basefold, Mds>(PhantomData<(E, Spec, Mds)>); -impl> Serialize for Basefold { +impl, Mds> Serialize for Basefold { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -256,9 +265,9 @@ impl> Serialize for Basefold { } } -pub type BasefoldDefault = Basefold; +pub type BasefoldDefault = Basefold; -impl> Clone for Basefold { +impl, Mds> Clone for Basefold { fn clone(&self) -> Self { Self(PhantomData) } @@ -274,9 +283,10 @@ where } } -impl AsRef<[Digest]> for BasefoldCommitmentWithWitness +impl AsRef<[Digest]> for BasefoldCommitmentWithWitness where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn as_ref(&self) -> &[Digest] { let root = self.get_root_ref(); @@ -285,34 +295,37 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ProofQueriesResultWithMerklePath +#[serde(bound = "")] +pub enum ProofQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - Single(QueriesResultWithMerklePath), - Batched(BatchedQueriesResultWithMerklePath), - SimpleBatched(SimpleBatchQueriesResultWithMerklePath), + Single(QueriesResultWithMerklePath), + Batched(BatchedQueriesResultWithMerklePath), + SimpleBatched(SimpleBatchQueriesResultWithMerklePath), } -impl ProofQueriesResultWithMerklePath +impl ProofQueriesResultWithMerklePath where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { - pub fn as_single(&self) -> &QueriesResultWithMerklePath { + pub fn as_single(&self) -> &QueriesResultWithMerklePath { match self { Self::Single(x) => x, _ => panic!("Not a single query result"), } } - pub fn as_batched(&self) -> &BatchedQueriesResultWithMerklePath { + pub fn as_batched(&self) -> &BatchedQueriesResultWithMerklePath { match self { Self::Batched(x) => x, _ => panic!("Not a batched query result"), } } - pub fn as_simple_batched(&self) -> &SimpleBatchQueriesResultWithMerklePath { + pub fn as_simple_batched(&self) -> &SimpleBatchQueriesResultWithMerklePath { match self { Self::SimpleBatched(x) => x, _ => panic!("Not a simple batched query result"), @@ -321,21 +334,24 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BasefoldProof +#[serde(bound = "")] +pub struct BasefoldProof where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub(crate) sumcheck_messages: Vec>, pub(crate) roots: Vec>, pub(crate) final_message: Vec, - pub(crate) query_result_with_merkle_path: ProofQueriesResultWithMerklePath, + pub(crate) query_result_with_merkle_path: ProofQueriesResultWithMerklePath, pub(crate) sumcheck_proof: Option>>, pub(crate) trivial_proof: Vec>, } -impl BasefoldProof +impl BasefoldProof where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn trivial(evals: Vec>) -> Self { Self { @@ -356,6 +372,7 @@ where } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] pub struct BasefoldCommitPhaseProof where E::BaseField: Serialize + DeserializeOwned, diff --git a/mpcs/src/basefold/sumcheck.rs b/mpcs/src/basefold/sumcheck.rs index ede813e21..6cefc31b1 100644 --- a/mpcs/src/basefold/sumcheck.rs +++ b/mpcs/src/basefold/sumcheck.rs @@ -1,6 +1,6 @@ -use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::Field; use rayon::prelude::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut, @@ -101,9 +101,9 @@ fn parallel_pi(evals: &[F], eq: &[F]) -> Vec { } }); - coeffs[0] = firsts.par_iter().sum(); - coeffs[1] = seconds.par_iter().sum(); - coeffs[2] = thirds.par_iter().sum(); + coeffs[0] = firsts.par_iter().copied().sum(); + coeffs[1] = seconds.par_iter().copied().sum(); + coeffs[2] = thirds.par_iter().copied().sum(); coeffs } @@ -136,9 +136,9 @@ fn parallel_pi_base(evals: &[E::BaseField], eq: &[E]) -> Vec< } }); - coeffs[0] = firsts.par_iter().sum(); - coeffs[1] = seconds.par_iter().sum(); - coeffs[2] = thirds.par_iter().sum(); + coeffs[0] = firsts.par_iter().copied().sum(); + coeffs[1] = seconds.par_iter().copied().sum(); + coeffs[2] = thirds.par_iter().copied().sum(); coeffs } @@ -169,8 +169,7 @@ pub fn sum_check_last_round(eq: &mut Vec, bh_values: &mut Vec, c #[cfg(test)] mod tests { - use ff::Field; - use goldilocks::Goldilocks; + use p3_goldilocks::Goldilocks; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; diff --git a/mpcs/src/lib.rs b/mpcs/src/lib.rs index fcfd1ba69..7c1938b88 100644 --- a/mpcs/src/lib.rs +++ b/mpcs/src/lib.rs @@ -2,6 +2,8 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use serde::{Serialize, de::DeserializeOwned}; use std::fmt::Debug; use transcript::{BasicTranscript, Transcript}; @@ -221,10 +223,11 @@ pub trait PolynomialCommitmentScheme: Clone + Debug { ) -> Result<(), Error>; } -pub trait NoninteractivePCS: +pub trait NoninteractivePCS: PolynomialCommitmentScheme> where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { fn ni_open( pp: &Self::ProverParam, @@ -233,7 +236,7 @@ where point: &[E], eval: &E, ) -> Result { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::open(pp, poly, comm, point, eval, &mut transcript) } @@ -244,7 +247,7 @@ where points: &[Vec], evals: &[Evaluation], ) -> Result { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_open(pp, polys, comms, points, evals, &mut transcript) } @@ -255,7 +258,7 @@ where eval: &E, proof: &Self::Proof, ) -> Result<(), Error> { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::verify(vp, comm, point, eval, proof, &mut transcript) } @@ -269,7 +272,7 @@ where where Self::Commitment: 'a, { - let mut transcript = BasicTranscript::::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Self::batch_verify(vp, comms, points, evals, proof, &mut transcript) } } @@ -379,6 +382,10 @@ pub mod test_util { use multilinear_extensions::{ mle::MultilinearExtension, virtual_poly::ArcMultilinearExtension, }; + #[cfg(test)] + use p3_mds::MdsPermutation; + #[cfg(test)] + use poseidon::SPONGE_WIDTH; use rand::rngs::OsRng; #[cfg(test)] use transcript::BasicTranscript; @@ -445,19 +452,20 @@ pub mod test_util { } #[cfg(test)] - pub fn run_commit_open_verify( + pub fn run_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where Pcs: PolynomialCommitmentScheme, + Mds: MdsPermutation + Default, { for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); // Commit and open let (comm, eval, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let poly = gen_rand_poly(num_vars); let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); @@ -473,7 +481,7 @@ pub mod test_util { }; // Verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); transcript.append_field_element_ext(&eval); @@ -486,12 +494,13 @@ pub mod test_util { } #[cfg(test)] - pub fn run_batch_commit_open_verify( + pub fn run_batch_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, ) where E: ExtensionField, + Mds: MdsPermutation + Default, Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { @@ -508,7 +517,7 @@ pub mod test_util { .collect_vec(); let (comms, evals, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let polys = gen_rand_polys(|i| num_vars - (i >> 1), batch_size, gen_rand_poly); let comms = @@ -539,7 +548,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let comms = comms .iter() .map(|comm| { @@ -567,20 +576,21 @@ pub mod test_util { } #[cfg(test)] - pub(super) fn run_simple_batch_commit_open_verify( + pub(super) fn run_simple_batch_commit_open_verify( gen_rand_poly: fn(usize) -> DenseMultilinearExtension, num_vars_start: usize, num_vars_end: usize, batch_size: usize, ) where E: ExtensionField, + Mds: MdsPermutation + Default, Pcs: PolynomialCommitmentScheme, { for num_vars in num_vars_start..num_vars_end { let (pp, vp) = setup_pcs::(num_vars); let (comm, evals, proof, challenge) = { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); let polys = gen_rand_polys(|_| num_vars, batch_size, gen_rand_poly); let comm = Pcs::batch_commit_and_write(&pp, polys.as_slice(), &mut transcript).unwrap(); @@ -604,7 +614,7 @@ pub mod test_util { }; // Batch verify { - let mut transcript = BasicTranscript::new(b"BaseFold"); + let mut transcript = BasicTranscript::::new(b"BaseFold"); Pcs::write_commitment(&comm, &mut transcript).unwrap(); let point = get_point_from_challenge(num_vars, &mut transcript); diff --git a/mpcs/src/sum_check.rs b/mpcs/src/sum_check.rs index f2fcf0e47..7406025ca 100644 --- a/mpcs/src/sum_check.rs +++ b/mpcs/src/sum_check.rs @@ -9,10 +9,10 @@ use crate::{ use std::{collections::HashMap, fmt::Debug}; use classic::{ClassicSumCheckRoundMessage, SumcheckProof}; -use ff::PrimeField; use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::DenseMultilinearExtension; +use p3_field::Field; use serde::{Serialize, de::DeserializeOwned}; use transcript::Transcript; @@ -113,27 +113,30 @@ pub fn evaluate( ) } -pub fn lagrange_eval(x: &[F], b: usize) -> F { +pub fn lagrange_eval(x: &[F], b: usize) -> F { assert!(!x.is_empty()); product(x.iter().enumerate().map( |(idx, x_i)| { - if b.nth_bit(idx) { *x_i } else { F::ONE - x_i } + if b.nth_bit(idx) { *x_i } else { F::ONE - *x_i } }, )) } -pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { +pub fn eq_xy_eval(x: &[F], y: &[F]) -> F { assert!(!x.is_empty()); assert_eq!(x.len(), y.len()); product( x.iter() .zip(y) - .map(|(x_i, y_i)| (*x_i * y_i).double() + F::ONE - x_i - y_i), + .map(|(x_i, y_i)| (*x_i * *y_i).double() + F::ONE - *x_i - *y_i), ) } -fn identity_eval(x: &[F]) -> F { - inner_product(x, &powers(F::from(2)).take(x.len()).collect_vec()) +fn identity_eval(x: &[F]) -> F { + inner_product( + x, + &powers(F::from_canonical_u64(2)).take(x.len()).collect_vec(), + ) } diff --git a/mpcs/src/sum_check/classic.rs b/mpcs/src/sum_check/classic.rs index f99d832df..ea7bfc92f 100644 --- a/mpcs/src/sum_check/classic.rs +++ b/mpcs/src/sum_check/classic.rs @@ -9,7 +9,6 @@ use crate::{ }, }; use ark_std::{end_timer, start_timer}; -use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; use num_integer::Integer; @@ -24,6 +23,7 @@ use multilinear_extensions::{ pub(crate) use coeff::Coefficients; pub use coeff::CoefficientsProver; +use p3_field::FieldAlgebra; #[derive(Debug)] pub struct ProverState<'a, E: ExtensionField> { @@ -99,12 +99,12 @@ impl<'a, E: ExtensionField> ProverState<'a, E> { fn next_round(&mut self, sum: E, challenge: &E) { self.sum = sum; - self.identity += E::from(1 << self.round) * challenge; + self.identity += E::from_canonical_u64(1 << self.round) * *challenge; self.lagranges.values_mut().for_each(|(b, value)| { if b.is_even() { - *value *= &(E::ONE - challenge); + *value *= E::ONE - *challenge; } else { - *value *= challenge; + *value *= *challenge; } *b >>= 1; }); @@ -324,51 +324,58 @@ mod tests { use transcript::BasicTranscript; use super::*; - use goldilocks::{Goldilocks as Fr, GoldilocksExt2 as E}; + use ff_ext::GoldilocksExt2 as E; + use p3_goldilocks::{Goldilocks as Fr, MdsMatrixGoldilocks}; #[test] fn test_sum_check_protocol() { let polys = [ DenseMultilinearExtension::::from_evaluations_vec(2, vec![ - Fr::from(1), - Fr::from(2), - Fr::from(3), - Fr::from(4), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(2), + Fr::from_canonical_u64(3), + Fr::from_canonical_u64(4), ]), DenseMultilinearExtension::from_evaluations_vec(2, vec![ - Fr::from(0), - Fr::from(1), - Fr::from(1), - Fr::from(0), + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(1), + Fr::from_canonical_u64(0), ]), - DenseMultilinearExtension::from_evaluations_vec(1, vec![Fr::from(0), Fr::from(1)]), + DenseMultilinearExtension::from_evaluations_vec(1, vec![ + Fr::from_canonical_u64(0), + Fr::from_canonical_u64(1), + ]), + ]; + let points = vec![ + vec![E::from_canonical_u64(1), E::from_canonical_u64(2)], + vec![E::from_canonical_u64(1)], ]; - let points = vec![vec![E::from(1), E::from(2)], vec![E::from(1)]]; let expression = Expression::::eq_xy(0) * Expression::Polynomial(Query::new(0, Rotation::cur())) - * E::from(Fr::from(2)) + * E::from(Fr::from_canonical_u64(2)) + Expression::::eq_xy(0) * Expression::Polynomial(Query::new(1, Rotation::cur())) - * E::from(Fr::from(3)) + * E::from(Fr::from_canonical_u64(3)) + Expression::::eq_xy(1) * Expression::Polynomial(Query::new(2, Rotation::cur())) - * E::from(Fr::from(4)); + * E::from(Fr::from_canonical_u64(4)); let virtual_poly = VirtualPolynomial::::new(&expression, polys.iter(), &[], points.as_slice()); let sum = inner_product( &poly_iter_ext(&polys[0]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(2) + ) * Fr::from_canonical_u64(2) + inner_product( &poly_iter_ext(&polys[1]).collect_vec(), &build_eq_x_r_vec(&points[0]), - ) * Fr::from(3) + ) * Fr::from_canonical_u64(3) + inner_product( &poly_iter_ext(&polys[2]).collect_vec(), &build_eq_x_r_vec(&points[1]), - ) * Fr::from(4) - * Fr::from(2); // The third polynomial is summed twice because the hypercube is larger - let mut transcript = BasicTranscript::::new(b"sumcheck"); + ) * Fr::from_canonical_u64(4) + * Fr::from_canonical_u64(2); // The third polynomial is summed twice because the hypercube is larger + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (challenges, evals, proof) = > as SumCheck>::prove( &(), @@ -383,7 +390,7 @@ mod tests { assert_eq!(polys[1].evaluate(&challenges), evals[1]); assert_eq!(polys[2].evaluate(&challenges[..1]), evals[2]); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); let (new_sum, verifier_challenges) = > as SumCheck< E, @@ -395,12 +402,12 @@ mod tests { assert_eq!(verifier_challenges, challenges); assert_eq!( new_sum, - evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(2) - + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from(3) - + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from(4) + evals[0] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(2) + + evals[1] * eq_xy_eval(&points[0], &challenges[..2]) * Fr::from_canonical_u64(3) + + evals[2] * eq_xy_eval(&points[1], &challenges[..1]) * Fr::from_canonical_u64(4) ); - let mut transcript = BasicTranscript::::new(b"sumcheck"); + let mut transcript = BasicTranscript::::new(b"sumcheck"); > as SumCheck>::verify( &(), diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 10d5c1c20..a571deedd 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -32,6 +32,7 @@ macro_rules! zip_self { } #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] pub struct Coefficients(FieldType); impl ClassicSumCheckRoundMessage for Coefficients { @@ -49,7 +50,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[0] + self[..].iter().sum::() + self[0] + self[..].iter().copied().sum::() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E { @@ -60,7 +61,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { impl<'rhs, E: ExtensionField> AddAssign<&'rhs E> for Coefficients { fn add_assign(&mut self, rhs: &'rhs E) { match &mut self.0 { - FieldType::Ext(coeffs) => coeffs[0] += rhs, + FieldType::Ext(coeffs) => coeffs[0] += *rhs, FieldType::Base(_) => panic!("Cannot add extension element to base coefficients"), FieldType::Unreachable => unreachable!(), } @@ -74,11 +75,11 @@ impl<'rhs, E: ExtensionField> AddAssign<(&'rhs E, &'rhs Coefficients)> for Co if scalar == &E::ONE { lhs.iter_mut() .zip(rhs.iter()) - .for_each(|(lhs, rhs)| *lhs += rhs) + .for_each(|(lhs, rhs)| *lhs += *rhs) } else if scalar != &E::ZERO { lhs.iter_mut() .zip(rhs.iter()) - .for_each(|(lhs, rhs)| *lhs += &(*scalar * rhs)) + .for_each(|(lhs, rhs)| *lhs += *scalar * *rhs) } } _ => panic!("Cannot add base coefficients to extension coefficients"), @@ -116,7 +117,7 @@ impl CoefficientsProver { result.iter_mut().enumerate().for_each(|(i, v)| { *v += poly_index_ext(lhs, i % lhs.evaluations.len()) * poly_index_ext(rhs, i % rhs.evaluations.len()) - * scalar; + * *scalar; }) } _ => unimplemented!(), @@ -162,7 +163,7 @@ impl ClassicSumCheckProver for CoefficientsProver { outputs.extend( products .iter() - .map(|(scalar, polys)| (constant * scalar, polys.clone())), + .map(|(scalar, polys)| (constant * *scalar, polys.clone())), ) } } @@ -170,7 +171,7 @@ impl ClassicSumCheckProver for CoefficientsProver { lhs_products.iter().cartesian_product(rhs_products.iter()) { outputs.push(( - *lhs_scalar * rhs_scalar, + *lhs_scalar * *rhs_scalar, iter::empty() .chain(lhs_polys) .chain(rhs_polys) @@ -182,7 +183,7 @@ impl ClassicSumCheckProver for CoefficientsProver { }, &|(constant, mut products), rhs| { products.iter_mut().for_each(|(lhs, _)| { - *lhs *= &rhs; + *lhs *= rhs; }); (constant * rhs, products) }, @@ -194,7 +195,7 @@ impl ClassicSumCheckProver for CoefficientsProver { // Initialize h(X) to zero let mut coeffs = Coefficients(FieldType::Ext(vec![E::ZERO; state.expression.degree() + 1])); // First, sum the constant over the hypercube and add to h(X) - coeffs += &(E::from(state.size() as u64) * self.0); + coeffs += &(E::from_canonical_u64(state.size() as u64) * self.0); // Next, for every product of polynomials, where each product is assumed to be exactly 2 // put this into h(X). if self.1.iter().all(|(_, products)| products.len() == 2) { @@ -219,7 +220,7 @@ impl ClassicSumCheckProver for CoefficientsProver { } fn sum(&self, state: &ProverState) -> E { - self.evals(state).iter().sum() + self.evals(state).iter().copied().sum() } } @@ -267,10 +268,10 @@ impl CoefficientsProver { .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { let coeff_0 = lhs_0 * rhs_0; let coeff_2 = (lhs_1 - lhs_0) * (rhs_1 - rhs_0); - coeffs[0] += &coeff_0; - coeffs[2] += &coeff_2; + coeffs[0] += coeff_0; + coeffs[2] += coeff_2; if !LAZY { - coeffs[1] += &(lhs_1 * rhs_1 - coeff_0 - coeff_2); + coeffs[1] += lhs_1 * rhs_1 - coeff_0 - coeff_2; } }); }; diff --git a/mpcs/src/util.rs b/mpcs/src/util.rs index 7688b53ec..9a1b0f3d8 100644 --- a/mpcs/src/util.rs +++ b/mpcs/src/util.rs @@ -3,14 +3,13 @@ pub mod expression; pub mod hash; pub mod parallel; pub mod plonky2_util; -use ff::{Field, PrimeField}; -use ff_ext::ExtensionField; -use goldilocks::SmallField; +use ff_ext::{ExtensionField, SmallField}; use itertools::{Either, Itertools, izip}; use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; pub mod merkle_tree; use crate::{Error, util::parallel::parallelize}; +use p3_field::{FieldAlgebra, PrimeField}; pub use plonky2_util::log2_strict; pub fn ext_to_usize(x: &E) -> usize { @@ -23,7 +22,7 @@ pub fn base_to_usize(x: &E::BaseField) -> usize { } pub fn u32_to_field(x: u32) -> E::BaseField { - E::BaseField::from(x as u64) + E::BaseField::from_canonical_u32(x) } pub trait BitIndex { @@ -38,7 +37,7 @@ impl BitIndex for usize { /// How many bytes are required to store n field elements? pub fn num_of_bytes(n: usize) -> usize { - (F::NUM_BITS as usize).next_power_of_two() * n / 8 + F::bits().next_power_of_two() * n / 8 } macro_rules! impl_index { @@ -118,8 +117,8 @@ pub fn field_type_index_mul_base( scalar: &E::BaseField, ) { match poly { - FieldType::Ext(coeffs) => coeffs[index] *= scalar, - FieldType::Base(coeffs) => coeffs[index] *= scalar, + FieldType::Ext(coeffs) => coeffs[index] *= *scalar, + FieldType::Base(coeffs) => coeffs[index] *= *scalar, _ => unreachable!(), } } @@ -194,13 +193,13 @@ pub fn multiply_poly(poly: &mut DenseMultilinearExtension, match &mut poly.evaluations { FieldType::Ext(coeffs) => { for coeff in coeffs.iter_mut() { - *coeff *= scalar; + *coeff *= *scalar; } } FieldType::Base(coeffs) => { *poly = DenseMultilinearExtension::::from_evaluations_ext_vec( poly.num_vars, - coeffs.iter().map(|x| E::from(*x) * scalar).collect(), + coeffs.iter().map(|x| E::from(*x) * *scalar).collect(), ); } _ => unreachable!(), @@ -320,11 +319,12 @@ pub fn ext_try_into_base(x: &E) -> Result(mut rng: impl RngCore) -> [F; N] { + pub fn rand_array(mut rng: impl RngCore) -> [F; N] { array::from_fn(|_| F::random(&mut rng)) } - pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { + pub fn rand_vec(n: usize, mut rng: impl RngCore) -> Vec { iter::repeat_with(|| F::random(&mut rng)).take(n).collect() } #[test] pub fn test_field_transform() { - assert_eq!(F::from(2) * F::from(3), F::from(6)); + assert_eq!( + F::from_canonical_u64(2) * F::from_canonical_u64(3), + F::from_canonical_u64(6) + ); assert_eq!(base_to_usize::(&u32_to_field::(1u32)), 1); assert_eq!(base_to_usize::(&u32_to_field::(10u32)), 10); } diff --git a/mpcs/src/util/arithmetic.rs b/mpcs/src/util/arithmetic.rs index 609f65455..34837b60e 100644 --- a/mpcs/src/util/arithmetic.rs +++ b/mpcs/src/util/arithmetic.rs @@ -1,18 +1,16 @@ -use ff::{BatchInvert, Field, PrimeField}; - use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; use num_integer::Integer; +use p3_field::Field; use std::{borrow::Borrow, iter}; mod bh; mod hypercube; pub use bh::BooleanHypercube; -pub use bitvec::field::BitField; pub use hypercube::{ interpolate_field_type_over_boolean_hypercube, interpolate_over_boolean_hypercube, }; -use num_bigint::BigUint; +use p3_field::FieldAlgebra; use itertools::Itertools; @@ -29,7 +27,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { let coeff_vec: Vec<&F> = coeffs.iter().rev().collect(); let mut acc = F::ZERO; for c in coeff_vec { - acc = acc * x + c; + acc = acc * *x + *c; } acc // 2 @@ -40,7 +38,7 @@ pub fn horner(coeffs: &[F], x: &F) -> F { pub fn horner_base(coeffs: &[E::BaseField], x: &E) -> E { let mut acc = E::ZERO; for c in coeffs.iter().rev() { - acc = acc * x + E::from(*c); + acc = acc * *x + E::from(*c); } acc // 2 @@ -52,11 +50,11 @@ pub fn steps(start: F) -> impl Iterator { } pub fn steps_by(start: F, step: F) -> impl Iterator { - iter::successors(Some(start), move |state| Some(step + state)) + iter::successors(Some(start), move |state| Some(step + *state)) } pub fn powers(scalar: F) -> impl Iterator { - iter::successors(Some(F::ONE), move |power| Some(scalar * power)) + iter::successors(Some(F::ONE), move |power| Some(scalar * *power)) } pub fn squares(scalar: F) -> impl Iterator { @@ -66,13 +64,13 @@ pub fn squares(scalar: F) -> impl Iterator { pub fn product(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ONE, |acc, value| acc * value.borrow()) + .fold(F::ONE, |acc, value| acc * *value.borrow()) } pub fn sum(values: impl IntoIterator>) -> F { values .into_iter() - .fold(F::ZERO, |acc, value| acc + value.borrow()) + .fold(F::ZERO, |acc, value| acc + *value.borrow()) } pub fn inner_product<'a, 'b, F: Field>( @@ -81,7 +79,7 @@ pub fn inner_product<'a, 'b, F: Field>( ) -> F { lhs.into_iter() .zip_eq(rhs) - .map(|(lhs, rhs)| *lhs * rhs) + .map(|(lhs, rhs)| *lhs * *rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() } @@ -94,83 +92,15 @@ pub fn inner_product_three<'a, 'b, 'c, F: Field>( a.into_iter() .zip_eq(b) .zip_eq(c) - .map(|((a, b), c)| *a * b * c) + .map(|((a, b), c)| *a * *b * *c) .reduce(|acc, product| acc + product) .unwrap_or_default() } -pub fn barycentric_weights(points: &[F]) -> Vec { - let mut weights = points - .iter() - .enumerate() - .map(|(j, point_j)| { - points - .iter() - .enumerate() - .filter(|&(i, _point_i)| (i != j)) - .map(|(_i, point_i)| *point_j - point_i) - .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) - }) - .collect_vec(); - weights.iter_mut().batch_invert(); - weights -} - -pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F], x: &F) -> F { - let (coeffs, sum_inv) = { - let mut coeffs = points.iter().map(|point| *x - point).collect_vec(); - coeffs.iter_mut().batch_invert(); - coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { - *coeff *= weight; - }); - let sum_inv = coeffs.iter().fold(F::ZERO, |sum, coeff| sum + coeff); - (coeffs, sum_inv.invert().unwrap()) - }; - inner_product(&coeffs, evals) * sum_inv -} - -pub fn modulus() -> BigUint { - BigUint::from_bytes_le((-F::ONE).to_repr().as_ref()) + 1u64 -} - pub fn fe_from_bool(value: bool) -> F { if value { F::ONE } else { F::ZERO } } -pub fn fe_mod_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { - fe_from_le_bytes((BigUint::from_bytes_le(bytes.as_ref()) % modulus::()).to_bytes_le()) -} - -pub fn fe_truncated_from_le_bytes(bytes: impl AsRef<[u8]>, num_bits: usize) -> F { - let mut big = BigUint::from_bytes_le(bytes.as_ref()); - (num_bits as u64..big.bits()).for_each(|idx| big.set_bit(idx, false)); - fe_from_le_bytes(big.to_bytes_le()) -} - -pub fn fe_from_le_bytes(bytes: impl AsRef<[u8]>) -> F { - let bytes = bytes.as_ref(); - let mut repr = F::Repr::default(); - assert!(bytes.len() <= repr.as_ref().len()); - repr.as_mut()[..bytes.len()].copy_from_slice(bytes); - F::from_repr(repr).unwrap() -} - -pub fn fe_to_fe(fe: F1) -> F2 { - debug_assert!(BigUint::from_bytes_le(fe.to_repr().as_ref()) < modulus::()); - let mut repr = F2::Repr::default(); - repr.as_mut().copy_from_slice(fe.to_repr().as_ref()); - F2::from_repr(repr).unwrap() -} - -pub fn fe_truncated(fe: F, num_bits: usize) -> F { - let (num_bytes, num_bits_last_byte) = div_rem(num_bits, 8); - let mut repr = fe.to_repr(); - repr.as_mut()[num_bytes + 1..].fill(0); - repr.as_mut()[num_bytes] &= (1 << num_bits_last_byte) - 1; - F::from_repr(repr).unwrap() -} - pub fn usize_from_bits_le(bits: &[bool]) -> usize { bits.iter() .rev() @@ -215,7 +145,7 @@ pub fn interpolate2(points: [(F, F); 2], x: F) -> F { let (a0, a1) = points[0]; let (b0, b1) = points[1]; assert_ne!(a0, b0); - a1 + (x - a0) * (b1 - a1) * (b0 - a0).invert().unwrap() + a1 + (x - a0) * (b1 - a1) * (b0 - a0).inverse() } pub fn degree_2_zero_plus_one(poly: &[F]) -> F { @@ -229,7 +159,7 @@ pub fn degree_2_eval(poly: &[F], point: F) -> F { pub fn base_from_raw_bytes(bytes: &[u8]) -> E::BaseField { let mut res = E::BaseField::ZERO; bytes.iter().for_each(|b| { - res += E::BaseField::from(u64::from(*b)); + res += E::BaseField::from_canonical_u8(*b); }); res } diff --git a/mpcs/src/util/arithmetic/hypercube.rs b/mpcs/src/util/arithmetic/hypercube.rs index fc4edc248..19e42372e 100644 --- a/mpcs/src/util/arithmetic/hypercube.rs +++ b/mpcs/src/util/arithmetic/hypercube.rs @@ -1,6 +1,6 @@ -use ff::Field; use ff_ext::ExtensionField; use multilinear_extensions::mle::FieldType; +use p3_field::Field; use rayon::prelude::{ParallelIterator, ParallelSliceMut}; use crate::util::log2_strict; diff --git a/mpcs/src/util/expression.rs b/mpcs/src/util/expression.rs index 97298371c..bc77edcb0 100644 --- a/mpcs/src/util/expression.rs +++ b/mpcs/src/util/expression.rs @@ -1,5 +1,4 @@ use crate::util::{Deserialize, Itertools, Serialize, izip}; -use ff::Field; use std::{ collections::BTreeSet, fmt::Debug, @@ -8,6 +7,8 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +use p3_field::Field; + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Rotation(pub i32); diff --git a/mpcs/src/util/hash.rs b/mpcs/src/util/hash.rs index 2dc4c8bf4..7140129d2 100644 --- a/mpcs/src/util/hash.rs +++ b/mpcs/src/util/hash.rs @@ -1,11 +1,11 @@ -use ff_ext::ExtensionField; -use goldilocks::SmallField; -use poseidon::poseidon_hash::PoseidonHash; +use ff_ext::{ExtensionField, SmallField}; +use p3_field::PrimeField; +use p3_mds::MdsPermutation; +use poseidon::{SPONGE_WIDTH, poseidon_hash::PoseidonHash}; use transcript::Transcript; pub use poseidon::digest::Digest; -use poseidon::poseidon::PrimeField; pub fn write_digest_to_transcript( digest: &Digest, @@ -17,33 +17,50 @@ pub fn write_digest_to_transcript( .for_each(|x| transcript.append_field_element(x)); } -pub fn hash_two_leaves_ext(a: &E, b: &E) -> Digest { +pub fn hash_two_leaves_ext(a: &E, b: &E) -> Digest +where + Mds: MdsPermutation + Default, +{ let input = [a.as_bases(), b.as_bases()].concat(); - PoseidonHash::hash_or_noop(&input) + PoseidonHash::::hash_or_noop(&input) } -pub fn hash_two_leaves_base( +pub fn hash_two_leaves_base( a: &E::BaseField, b: &E::BaseField, -) -> Digest { - PoseidonHash::hash_or_noop(&[*a, *b]) +) -> Digest +where + Mds: MdsPermutation + Default, +{ + PoseidonHash::::hash_or_noop(&[*a, *b]) } -pub fn hash_two_leaves_batch_ext(a: &[E], b: &[E]) -> Digest { - let a_m_to_1_hash = PoseidonHash::hash_or_noop_iter(a.iter().flat_map(|v| v.as_bases())); - let b_m_to_1_hash = PoseidonHash::hash_or_noop_iter(b.iter().flat_map(|v| v.as_bases())); - hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) +pub fn hash_two_leaves_batch_ext(a: &[E], b: &[E]) -> Digest +where + Mds: MdsPermutation + Default, +{ + let a_m_to_1_hash = + PoseidonHash::::hash_or_noop_iter(a.iter().flat_map(|v| v.as_bases())); + let b_m_to_1_hash = + PoseidonHash::::hash_or_noop_iter(b.iter().flat_map(|v| v.as_bases())); + hash_two_digests::(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_leaves_batch_base( +pub fn hash_two_leaves_batch_base( a: &[E::BaseField], b: &[E::BaseField], -) -> Digest { - let a_m_to_1_hash = PoseidonHash::hash_or_noop_iter(a.iter()); - let b_m_to_1_hash = PoseidonHash::hash_or_noop_iter(b.iter()); - hash_two_digests(&a_m_to_1_hash, &b_m_to_1_hash) +) -> Digest +where + Mds: MdsPermutation + Default, +{ + let a_m_to_1_hash = PoseidonHash::::hash_or_noop_iter(a.iter()); + let b_m_to_1_hash = PoseidonHash::::hash_or_noop_iter(b.iter()); + hash_two_digests::(&a_m_to_1_hash, &b_m_to_1_hash) } -pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest { - PoseidonHash::two_to_one(a, b) +pub fn hash_two_digests(a: &Digest, b: &Digest) -> Digest +where + Mds: MdsPermutation + Default, +{ + PoseidonHash::::two_to_one(a, b) } diff --git a/mpcs/src/util/merkle_tree.rs b/mpcs/src/util/merkle_tree.rs index d24840496..e26954c84 100644 --- a/mpcs/src/util/merkle_tree.rs +++ b/mpcs/src/util/merkle_tree.rs @@ -1,6 +1,10 @@ +use std::marker::PhantomData; + use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::mle::FieldType; +use p3_mds::MdsPermutation; +use poseidon::SPONGE_WIDTH; use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator, @@ -24,28 +28,30 @@ use super::hash::write_digest_to_transcript; #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(bound(serialize = "E: Serialize", deserialize = "E: DeserializeOwned"))] -pub struct MerkleTree +pub struct MerkleTree where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>>, leaves: Vec>, + _phantom: PhantomData, } -impl MerkleTree +impl MerkleTree where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn compute_inner(leaves: &FieldType) -> Vec>> { - merkelize::(&[leaves]) + merkelize::(&[leaves]) } pub fn compute_inner_base(leaves: &[E::BaseField]) -> Vec>> { - merkelize_base::(&[leaves]) + merkelize_base::(&[leaves]) } pub fn compute_inner_ext(leaves: &[E]) -> Vec>> { - merkelize_ext::(&[leaves]) + merkelize_ext::(&[leaves]) } pub fn root_from_inner(inner: &[Vec>]) -> Digest { @@ -56,6 +62,7 @@ where Self { inner, leaves: vec![leaves], + _phantom: PhantomData, } } @@ -63,13 +70,15 @@ where Self { inner: Self::compute_inner(&leaves), leaves: vec![leaves], + _phantom: PhantomData, } } pub fn from_batch_leaves(leaves: Vec>) -> Self { Self { - inner: merkelize::(&leaves.iter().collect_vec()), + inner: merkelize::(&leaves.iter().collect_vec()), leaves, + _phantom: PhantomData, } } @@ -139,9 +148,9 @@ where pub fn merkle_path_without_leaf_sibling_or_root( &self, leaf_index: usize, - ) -> MerklePathWithoutLeafOrRoot { + ) -> MerklePathWithoutLeafOrRoot { assert!(leaf_index < self.size().1); - MerklePathWithoutLeafOrRoot::::new( + MerklePathWithoutLeafOrRoot::::new( self.inner .iter() .take(self.height() - 1) @@ -155,19 +164,24 @@ where } #[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct MerklePathWithoutLeafOrRoot +pub struct MerklePathWithoutLeafOrRoot where E::BaseField: Serialize + DeserializeOwned, { inner: Vec>, + _phantom: PhantomData, } -impl MerklePathWithoutLeafOrRoot +impl MerklePathWithoutLeafOrRoot where E::BaseField: Serialize + DeserializeOwned, + Mds: MdsPermutation + Default, { pub fn new(inner: Vec>) -> Self { - Self { inner } + Self { + inner, + _phantom: PhantomData, + } } pub fn is_empty(&self) -> bool { @@ -195,7 +209,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root::( + authenticate_merkle_path_root::( &self.inner, FieldType::Ext(vec![left, right]), index, @@ -210,7 +224,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root::( + authenticate_merkle_path_root::( &self.inner, FieldType::Base(vec![left, right]), index, @@ -225,7 +239,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root_batch::( + authenticate_merkle_path_root_batch::( &self.inner, FieldType::Ext(left), FieldType::Ext(right), @@ -241,7 +255,7 @@ where index: usize, root: &Digest, ) { - authenticate_merkle_path_root_batch::( + authenticate_merkle_path_root_batch::( &self.inner, FieldType::Base(left), FieldType::Base(right), @@ -253,7 +267,10 @@ where /// Merkle tree construction /// TODO: Support merkelizing mixed-type values -fn merkelize(values: &[&FieldType]) -> Vec>> { +fn merkelize(values: &[&FieldType]) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -267,10 +284,10 @@ fn merkelize(values: &[&FieldType]) -> Vec { - hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1]) + hash_two_leaves_base::(&values[i << 1], &values[(i << 1) + 1]) } FieldType::Ext(values) => { - hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1]) + hash_two_leaves_ext::(&values[i << 1], &values[(i << 1) + 1]) } FieldType::Unreachable => unreachable!(), }; @@ -278,7 +295,7 @@ fn merkelize(values: &[&FieldType]) -> Vec hash_two_leaves_batch_base::( + FieldType::Base(_) => hash_two_leaves_batch_base::( values .iter() .map(|values| field_type_index_base(values, i << 1)) @@ -290,7 +307,7 @@ fn merkelize(values: &[&FieldType]) -> Vec hash_two_leaves_batch_ext::( + FieldType::Ext(_) => hash_two_leaves_batch_ext::( values .iter() .map(|values| field_type_index_ext(values, i << 1)) @@ -312,7 +329,7 @@ fn merkelize(values: &[&FieldType]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -321,7 +338,12 @@ fn merkelize(values: &[&FieldType]) -> Vec(values: &[&[E::BaseField]]) -> Vec>> { +fn merkelize_base( + values: &[&[E::BaseField]], +) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -333,11 +355,11 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec> 1]; if values.len() == 1 { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_base::(&values[0][i << 1], &values[0][(i << 1) + 1]); + *hash = hash_two_leaves_base::(&values[0][i << 1], &values[0][(i << 1) + 1]); }); } else { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_batch_base::( + *hash = hash_two_leaves_batch_base::( values .iter() .map(|values| values[i << 1]) @@ -357,7 +379,7 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -366,7 +388,10 @@ fn merkelize_base(values: &[&[E::BaseField]]) -> Vec(values: &[&[E]]) -> Vec>> { +fn merkelize_ext(values: &[&[E]]) -> Vec>> +where + Mds: MdsPermutation + Default, +{ #[cfg(feature = "sanity-check")] for i in 0..(values.len() - 1) { assert_eq!(values[i].len(), values[i + 1].len()); @@ -378,11 +403,11 @@ fn merkelize_ext(values: &[&[E]]) -> Vec> 1]; if values.len() == 1 { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_ext::(&values[0][i << 1], &values[0][(i << 1) + 1]); + *hash = hash_two_leaves_ext::(&values[0][i << 1], &values[0][(i << 1) + 1]); }); } else { hashes.par_iter_mut().enumerate().for_each(|(i, hash)| { - *hash = hash_two_leaves_batch_ext::( + *hash = hash_two_leaves_batch_ext::( values .iter() .map(|values| values[i << 1]) @@ -402,7 +427,7 @@ fn merkelize_ext(values: &[&[E]]) -> Vec(&ys[0], &ys[1])) .collect::>(); tree.push(oracle); @@ -411,17 +436,19 @@ fn merkelize_ext(values: &[&[E]]) -> Vec( +fn authenticate_merkle_path_root( path: &[Digest], leaves: FieldType, x_index: usize, root: &Digest, -) { +) where + Mds: MdsPermutation + Default, +{ let mut x_index = x_index; assert_eq!(leaves.len(), 2); let mut hash = match leaves { - FieldType::Base(leaves) => hash_two_leaves_base::(&leaves[0], &leaves[1]), - FieldType::Ext(leaves) => hash_two_leaves_ext(&leaves[0], &leaves[1]), + FieldType::Base(leaves) => hash_two_leaves_base::(&leaves[0], &leaves[1]), + FieldType::Ext(leaves) => hash_two_leaves_ext::(&leaves[0], &leaves[1]), FieldType::Unreachable => unreachable!(), }; @@ -429,40 +456,42 @@ fn authenticate_merkle_path_root( x_index >>= 1; for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, path_i) + hash_two_digests::(&hash, path_i) } else { - hash_two_digests(path_i, &hash) + hash_two_digests::(path_i, &hash) }; x_index >>= 1; } assert_eq!(&hash, root); } -fn authenticate_merkle_path_root_batch( +fn authenticate_merkle_path_root_batch( path: &[Digest], left: FieldType, right: FieldType, x_index: usize, root: &Digest, -) { +) where + Mds: MdsPermutation + Default, +{ let mut x_index = x_index; let mut hash = if left.len() > 1 { match (left, right) { (FieldType::Base(left), FieldType::Base(right)) => { - hash_two_leaves_batch_base::(&left, &right) + hash_two_leaves_batch_base::(&left, &right) } (FieldType::Ext(left), FieldType::Ext(right)) => { - hash_two_leaves_batch_ext::(&left, &right) + hash_two_leaves_batch_ext::(&left, &right) } _ => unreachable!(), } } else { match (left, right) { (FieldType::Base(left), FieldType::Base(right)) => { - hash_two_leaves_base::(&left[0], &right[0]) + hash_two_leaves_base::(&left[0], &right[0]) } (FieldType::Ext(left), FieldType::Ext(right)) => { - hash_two_leaves_ext::(&left[0], &right[0]) + hash_two_leaves_ext::(&left[0], &right[0]) } _ => unreachable!(), } @@ -472,9 +501,9 @@ fn authenticate_merkle_path_root_batch( x_index >>= 1; for path_i in path.iter() { hash = if x_index & 1 == 0 { - hash_two_digests(&hash, path_i) + hash_two_digests::(&hash, path_i) } else { - hash_two_digests(path_i, &hash) + hash_two_digests::(path_i, &hash) }; x_index >>= 1; } diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index ad944820c..c6ae022b9 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use ark_std::{rand::RngCore, test_rng}; use ff_ext::{ExtensionField, GoldilocksExt2}; use multilinear_extensions::virtual_poly::VirtualPolynomial; -use p3_field::FieldAlgebra; +use p3_field::{Field, FieldAlgebra}; use p3_goldilocks::MdsMatrixGoldilocks; use p3_mds::MdsPermutation; use poseidon::SPONGE_WIDTH;