diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 991fad3d7..76a0569ce 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -17,7 +17,11 @@ pub mod utils; pub mod test; pub trait GlobalStateRegisterMachineChipOperations { - fn state_in(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError>; + fn state_in( + &mut self, + pc: impl ToExpr>, + ts: impl ToExpr>, + ) -> Result<(), ZKVMError>; fn state_out(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError>; } @@ -30,8 +34,8 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_read( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, value: RegisterExpr, ) -> Result<(Expression, AssertLtConfig), ZKVMError>; @@ -40,8 +44,8 @@ pub trait RegisterChipOperations, N: FnOnce( fn register_write( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 780914e78..46715d85e 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -164,7 +164,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let byte = self.cs.create_witin(name_fn.clone()); - self.assert_ux::<_, _, 8>(name_fn, byte.expr())?; + self.assert_ux::<_, _, 8>(name_fn, byte)?; Ok(byte) } @@ -175,7 +175,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let limb = self.cs.create_witin(name_fn.clone()); - self.assert_ux::<_, _, 16>(name_fn, limb.expr())?; + self.assert_ux::<_, _, 16>(name_fn, limb)?; Ok(limb) } @@ -191,7 +191,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { N: FnOnce() -> NR + Clone, { let wit = self.cs.create_witin(name_fn.clone()); - self.require_equal(name_fn, wit.expr(), expr)?; + self.require_equal(name_fn, wit, expr)?; Ok(wit) } @@ -199,7 +199,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn require_zero( &mut self, name_fn: N, - assert_zero_expr: Expression, + assert_zero_expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, @@ -214,8 +214,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn require_equal( &mut self, name_fn: N, - a: Expression, - b: Expression, + a: impl ToExpr>, + b: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, @@ -224,8 +224,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.namespace( || "require_equal", |cb| { - cb.cs - .require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form()) + cb.cs.require_zero( + name_fn, + a.expr().to_monomial_form() - b.expr().to_monomial_form(), + ) }, ) } @@ -241,21 +243,25 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn condition_require_equal( &mut self, name_fn: N, - cond: Expression, - target: Expression, - true_expr: Expression, - false_expr: Expression, + cond: impl ToExpr>, + target: impl ToExpr>, + true_expr: impl ToExpr>, + false_expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let cond = cond.expr(); + let target = target.expr(); + let true_expr = true_expr.expr(); + let false_expr = false_expr.expr(); // cond * (true_expr) + (1 - cond) * false_expr // => false_expr + cond * true_expr - cond * false_expr self.namespace( || "cond_require_equal", |cb| { - let cond_target = false_expr.clone() + cond.clone() * true_expr - cond * false_expr; + let cond_target = &false_expr + &cond * true_expr - cond * false_expr; cb.cs.require_zero(name_fn, target - cond_target) }, ) @@ -263,22 +269,24 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn select( &mut self, - cond: &Expression, - when_true: &Expression, - when_false: &Expression, + cond: impl ToExpr>, + when_true: impl ToExpr>, + when_false: impl ToExpr>, ) -> Expression { - cond * when_true + (1 - cond) * when_false + let cond = cond.expr(); + &cond * when_true.expr() + (1 - &cond) * when_false.expr() } pub(crate) fn assert_ux( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let expr = expr.expr(); match C { 16 => self.assert_u16(name_fn, expr), 14 => self.assert_u14(name_fn, expr), @@ -333,25 +341,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub(crate) fn assert_byte( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { - self.lk_record(name_fn, ROMType::U8, vec![expr])?; + self.lk_record(name_fn, ROMType::U8, vec![expr.expr()])?; Ok(()) } pub(crate) fn assert_bit( &mut self, name_fn: N, - expr: Expression, + expr: impl ToExpr>, ) -> Result<(), ZKVMError> where NR: Into, N: FnOnce() -> NR, { + let expr = expr.expr(); self.namespace( || "assert_bit", |cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)), @@ -362,10 +371,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn logic_u8( &mut self, rom_type: ROMType, - a: Expression, - b: Expression, - c: Expression, + a: impl ToExpr>, + b: impl ToExpr>, + c: impl ToExpr>, ) -> Result<(), ZKVMError> { + let a = a.expr(); + let b = b.expr(); + let c = c.expr(); self.lk_record(|| format!("lookup_{:?}", rom_type), rom_type, vec![a, b, c]) } @@ -402,31 +414,30 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { /// Assert that `(a < b) == c as bool`, that `a, b` are unsigned bytes, and that `c` is 0 or 1. pub fn lookup_ltu_byte( &mut self, - a: Expression, - b: Expression, - c: Expression, + a: impl ToExpr>, + b: impl ToExpr>, + c: impl ToExpr>, ) -> Result<(), ZKVMError> { self.logic_u8(ROMType::Ltu, a, b, c) } // Assert that `2^b = c` and that `b` is a 5-bit unsigned integer. pub fn lookup_pow2(&mut self, b: Expression, c: Expression) -> Result<(), ZKVMError> { - self.logic_u8(ROMType::Pow, 2.into(), b, c) + self.logic_u8(ROMType::Pow, 2, b, c) } pub(crate) fn is_equal( &mut self, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, ) -> Result<(WitIn, WitIn), ZKVMError> { + let lhs = lhs.expr(); + let rhs = rhs.expr(); let is_eq = self.create_witin(|| "is_eq"); let diff_inverse = self.create_witin(|| "diff_inverse"); - self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?; - self.require_zero( - || "is equal", - 1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs, - )?; + self.require_zero(|| "is equal", is_eq * &lhs - is_eq * &rhs)?; + self.require_zero(|| "is equal", 1 + diff_inverse * (rhs + lhs) - is_eq)?; Ok((is_eq, diff_inverse)) } diff --git a/ceno_zkvm/src/chip_handler/global_state.rs b/ceno_zkvm/src/chip_handler/global_state.rs index 27c28e166..ffeff9de7 100644 --- a/ceno_zkvm/src/chip_handler/global_state.rs +++ b/ceno_zkvm/src/chip_handler/global_state.rs @@ -1,17 +1,24 @@ use ff_ext::ExtensionField; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, structs::RAMType, + circuit_builder::CircuitBuilder, + error::ZKVMError, + expression::{Expression, ToExpr}, + structs::RAMType, }; use super::GlobalStateRegisterMachineChipOperations; impl GlobalStateRegisterMachineChipOperations for CircuitBuilder<'_, E> { - fn state_in(&mut self, pc: Expression, ts: Expression) -> Result<(), ZKVMError> { + fn state_in( + &mut self, + pc: impl ToExpr>, + ts: impl ToExpr>, + ) -> Result<(), ZKVMError> { let record: Vec> = vec![ Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)), - pc, - ts, + pc.expr(), + ts.expr(), ]; self.read_record(|| "state_in", RAMType::GlobalState, record) diff --git a/ceno_zkvm/src/chip_handler/register.rs b/ceno_zkvm/src/chip_handler/register.rs index c85060cc6..bf8a618b3 100644 --- a/ceno_zkvm/src/chip_handler/register.rs +++ b/ceno_zkvm/src/chip_handler/register.rs @@ -17,8 +17,8 @@ impl, N: FnOnce() -> NR> RegisterChipOperati fn register_read( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, value: RegisterExpr, ) -> Result<(Expression, AssertLtConfig), ZKVMError> { @@ -28,7 +28,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati vec![RAMType::Register.into()], vec![register_id.expr()], value.to_vec(), - vec![prev_ts.clone()], + vec![prev_ts.expr()], ] .concat(); // Write (a, v, t) @@ -60,8 +60,8 @@ impl, N: FnOnce() -> NR> RegisterChipOperati fn register_write( &mut self, name_fn: N, - register_id: impl ToExpr>, - prev_ts: Expression, + register_id: impl ToExpr> + std::marker::Copy, + prev_ts: impl ToExpr> + std::marker::Copy, ts: Expression, prev_values: RegisterExpr, value: RegisterExpr, @@ -73,7 +73,7 @@ impl, N: FnOnce() -> NR> RegisterChipOperati vec![RAMType::Register.into()], vec![register_id.expr()], prev_values.to_vec(), - vec![prev_ts.clone()], + vec![prev_ts.expr()], ] .concat(); // Write (a, v, t) diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 709f46839..206b18c46 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -9,7 +9,7 @@ use crate::{ ROMType, chip_handler::utils::rlc_chip_record, error::ZKVMError, - expression::{Expression, Fixed, Instance, WitIn}, + expression::{Expression, Fixed, Instance, ToExpr, WitIn}, structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId}, witness::RowMajorMatrix, }; @@ -440,8 +440,9 @@ impl ConstraintSystem { pub fn require_zero, N: FnOnce() -> NR>( &mut self, name_fn: N, - assert_zero_expr: Expression, + assert_zero_expr: impl ToExpr>, ) -> Result<(), ZKVMError> { + let assert_zero_expr = assert_zero_expr.expr(); assert!( assert_zero_expr.degree() > 0, "constant expression assert to zero ?" diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index a39f8a6a1..78ca0234b 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -5,7 +5,10 @@ use std::{ fmt::Display, iter::{Product, Sum}, mem::MaybeUninit, - ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign}, + ops::{ + Add, AddAssign, Deref, Div, Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, + SubAssign, + }, }; use ceno_emul::InsnKind; @@ -361,6 +364,35 @@ impl ShlAssign for Expression { } } +// + +impl Shr for Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self / (1_usize << rhs) + } +} + +impl Shr for &Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self.clone() >> rhs + } +} + +impl Shr for &mut Expression { + type Output = Expression; + fn shr(self, rhs: usize) -> Expression { + self.clone() >> rhs + } +} + +impl ShrAssign for Expression { + fn shr_assign(&mut self, rhs: usize) { + *self = self.clone() >> rhs; + } +} + impl Sum for Expression { fn sum>>(iter: I) -> Expression { iter.fold(Expression::ZERO, |acc, x| acc + x) @@ -588,20 +620,29 @@ macro_rules! mixed_binop_instances { }; } -mixed_binop_instances!( - Add, - add, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) -); -mixed_binop_instances!( - Sub, - sub, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) -); -mixed_binop_instances!( - Mul, - mul, - (u8, u16, u32, u64, usize, i8, i16, i32, i64, isize) +macro_rules! mixed_binop_instances_all { + ($($t:ty),*) => { + mixed_binop_instances!( + Add, + add, + ($($t),*) + ); + mixed_binop_instances!( + Sub, + sub, + ($($t),*) + ); + mixed_binop_instances!( + Mul, + mul, + ($($t),*) + ); + }; +} + +mixed_binop_instances_all!( + u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, WitIn, Fixed, Instance, &WitIn, &Fixed, + &Instance ); impl Mul for Expression { @@ -721,6 +762,32 @@ impl Mul for Expression { } } +macro_rules! div_instances { + (($($t:ty),*)) => { + $( + + impl Div<$t> for Expression { + type Output = Expression; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: $t) -> Expression { + let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64; + self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64() + } + } + + impl Div<$t> for &Expression { + type Output = Expression; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: $t) -> Expression { + let reduced = (rhs as i128).rem_euclid(E::BaseField::MODULUS_U64 as i128) as u64; + self * E::BaseField::from(reduced).invert().unwrap().to_canonical_u64() + } + } + )* + }; +} +div_instances!((u8, u16, u32, u64, usize, i8, i16, i32, i64, isize)); + #[derive(Clone, Debug, Copy)] pub struct WitIn { pub id: WitnessId, @@ -767,48 +834,55 @@ impl WitIn { pub trait ToExpr { type Output; - fn expr(&self) -> Self::Output; + fn expr(self) -> Self::Output; +} + +impl ToExpr for Expression { + type Output = Expression; + fn expr(self) -> Self::Output { + self + } +} + +impl ToExpr for &Expression { + type Output = Expression; + fn expr(self) -> Self::Output { + self.clone() + } } impl ToExpr for WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr(self) -> Self::Output { Expression::WitIn(self.id) } } impl ToExpr for &WitIn { type Output = Expression; - fn expr(&self) -> Expression { + fn expr(self) -> Self::Output { Expression::WitIn(self.id) } } impl ToExpr for Fixed { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Fixed(*self) + fn expr(self) -> Self::Output { + Expression::Fixed(self) } } impl ToExpr for &Fixed { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Fixed(**self) + fn expr(self) -> Self::Output { + Expression::Fixed(*self) } } impl ToExpr for Instance { type Output = Expression; - fn expr(&self) -> Expression { - Expression::Instance(*self) - } -} - -impl> ToExpr for F { - type Output = Expression; - fn expr(&self) -> Expression { - Expression::Constant(*self) + fn expr(self) -> Self::Output { + Expression::Instance(self) } } @@ -823,37 +897,41 @@ macro_rules! impl_from_via_ToExpr { )* }; } -impl_from_via_ToExpr!(WitIn, Fixed, Instance); +impl_from_via_ToExpr!( + WitIn, Fixed, Instance, u8, u16, u32, u64, usize, RAMType, InsnKind, i8, i16, i32, i64, isize +); impl_from_via_ToExpr!(&WitIn, &Fixed, &Instance); -// Implement From trait for unsigned types of at most 64 bits -macro_rules! impl_from_unsigned { +// Implement ToExpr trait for unsigned types of at most 64 bits +macro_rules! impl_ToExpr_unsigned { ($($t:ty),*) => { $( - impl> From<$t> for Expression { - fn from(value: $t) -> Self { - Expression::Constant(F::from(value as u64)) + impl> ToExpr for $t { + type Output = Expression; + fn expr(self) -> Self::Output { + Expression::Constant(F::from(self as u64)) } } )* }; } -impl_from_unsigned!(u8, u16, u32, u64, usize, RAMType, InsnKind); +impl_ToExpr_unsigned!(u8, u16, u32, u64, usize, RAMType, InsnKind); -// Implement From trait for signed types -macro_rules! impl_from_signed { +// Implement ToExpr trait for signed types +macro_rules! impl_ToExpr_signed { ($($t:ty),*) => { $( - impl> From<$t> for Expression { - fn from(value: $t) -> Self { - let reduced = (value as i128).rem_euclid(F::MODULUS_U64 as i128) as u64; + impl> ToExpr for $t { + type Output = Expression; + fn expr(self) -> Self::Output { + let reduced = (self as i128).rem_euclid(F::MODULUS_U64 as i128) as u64; Expression::Constant(F::from(reduced)) } } )* }; } -impl_from_signed!(i8, i16, i32, i64, isize); +impl_ToExpr_signed!(i8, i16, i32, i64, isize); impl Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/ceno_zkvm/src/gadgets/is_lt.rs b/ceno_zkvm/src/gadgets/is_lt.rs index ff0fe3989..0cb05d5ef 100644 --- a/ceno_zkvm/src/gadgets/is_lt.rs +++ b/ceno_zkvm/src/gadgets/is_lt.rs @@ -29,22 +29,16 @@ impl AssertLtConfig { >( cb: &mut CircuitBuilder, name_fn: N, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { cb.namespace( || "assert_lt", |cb| { let name = name_fn(); - let config = InnerLtConfig::construct_circuit( - cb, - name, - lhs, - rhs, - Expression::ONE, - max_num_u16_limbs, - )?; + let config = + InnerLtConfig::construct_circuit(cb, name, lhs, rhs, 1, max_num_u16_limbs)?; Ok(Self(config)) }, ) @@ -68,11 +62,23 @@ pub struct IsLtConfig { config: InnerLtConfig, } -impl IsLtConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for IsLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.is_lt.expr() } +} + +impl ToExpr for &IsLtConfig { + type Output = Expression; + fn expr(self) -> Self::Output { + (&self.is_lt).expr() + } +} + +impl IsLtConfig { pub fn construct_circuit< E: ExtensionField, NR: Into + Display + Clone, @@ -80,8 +86,8 @@ impl IsLtConfig { >( cb: &mut CircuitBuilder, name_fn: N, - lhs: Expression, - rhs: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { cb.namespace( @@ -89,16 +95,10 @@ impl IsLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_lt witin")); - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; - - let config = InnerLtConfig::construct_circuit( - cb, - name, - lhs, - rhs, - is_lt.expr(), - max_num_u16_limbs, - )?; + cb.assert_bit(|| "is_lt_bit", is_lt)?; + + let config = + InnerLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt, max_num_u16_limbs)?; Ok(Self { is_lt, config }) }, ) @@ -144,11 +144,14 @@ impl InnerLtConfig { pub fn construct_circuit + Display + Clone>( cb: &mut CircuitBuilder, name: NR, - lhs: Expression, - rhs: Expression, - is_lt_expr: Expression, + lhs: impl ToExpr>, + rhs: impl ToExpr>, + is_lt_expr: impl ToExpr>, max_num_u16_limbs: usize, ) -> Result { + let lhs = lhs.expr(); + let rhs = rhs.expr(); + let is_lt_expr = is_lt_expr.expr(); assert!(max_num_u16_limbs >= 1); let mut witin_u16 = |var_name: String| -> Result { @@ -156,7 +159,7 @@ impl InnerLtConfig { || format!("var {var_name}"), |cb| { let witin = cb.create_witin(|| var_name.to_string()); - cb.assert_ux::<_, _, 16>(|| name.clone(), witin.expr())?; + cb.assert_ux::<_, _, 16>(|| name.clone(), witin)?; Ok(witin) }, ) @@ -169,7 +172,7 @@ impl InnerLtConfig { let pows = power_sequence((1 << u16::BITS).into()); let diff_expr = izip!(&diff, pows) - .map(|(record, beta)| beta * record.expr()) + .map(|(record, beta)| beta * record) .sum::>(); let range = Self::range(max_num_u16_limbs); @@ -247,8 +250,7 @@ impl AssertSignedLtConfig { || "assert_signed_lt", |cb| { let name = name_fn(); - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, Expression::ONE)?; + let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, 1)?; Ok(Self { config }) }, ) @@ -272,11 +274,23 @@ pub struct SignedLtConfig { config: InnerSignedLtConfig, } -impl SignedLtConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for SignedLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.is_lt.expr() } +} + +impl ToExpr for &SignedLtConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + (&self.is_lt).expr() + } +} +impl SignedLtConfig { pub fn construct_circuit + Display + Clone, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, @@ -288,9 +302,8 @@ impl SignedLtConfig { |cb| { let name = name_fn(); let is_lt = cb.create_witin(|| format!("{name} is_signed_lt witin")); - cb.assert_bit(|| "is_lt_bit", is_lt.expr())?; - let config = - InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt.expr())?; + cb.assert_bit(|| "is_lt_bit", is_lt)?; + let config = InnerSignedLtConfig::construct_circuit(cb, name, lhs, rhs, is_lt)?; Ok(SignedLtConfig { is_lt, config }) }, @@ -324,15 +337,15 @@ impl InnerSignedLtConfig { name: NR, lhs: &UInt, rhs: &UInt, - is_lt_expr: Expression, + is_lt_expr: impl ToExpr>, ) -> Result { // Extract the sign bit. let is_lhs_neg = lhs.is_negative(cb)?; let is_rhs_neg = rhs.is_negative(cb)?; // Convert to field arithmetic. - let lhs_value = lhs.to_field_expr(is_lhs_neg.expr()); - let rhs_value = rhs.to_field_expr(is_rhs_neg.expr()); + let lhs_value = lhs.to_field_expr(is_lhs_neg); + let rhs_value = rhs.to_field_expr(is_rhs_neg); let config = InnerLtConfig::construct_circuit( cb, format!("{name} (lhs < rhs)"), diff --git a/ceno_zkvm/src/gadgets/is_zero.rs b/ceno_zkvm/src/gadgets/is_zero.rs index f7d749354..676693efc 100644 --- a/ceno_zkvm/src/gadgets/is_zero.rs +++ b/ceno_zkvm/src/gadgets/is_zero.rs @@ -10,16 +10,29 @@ use crate::{ set_val, }; +#[derive(Clone, Copy, Debug)] pub struct IsZeroConfig { is_zero: Option, inverse: WitIn, } -impl IsZeroConfig { - pub fn expr(&self) -> Expression { - self.is_zero.map(|wit| wit.expr()).unwrap_or(0.into()) +impl ToExpr for IsZeroConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.is_zero.map(ToExpr::expr).unwrap_or(0.into()) } +} +impl ToExpr for &IsZeroConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.is_zero.map(ToExpr::expr).unwrap_or(0.into()) + } +} + +impl IsZeroConfig { pub fn construct_circuit, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, @@ -49,14 +62,14 @@ impl IsZeroConfig { let is_zero = cb.create_witin(|| "is_zero"); // x!=0 => is_zero=0 - cb.require_zero(|| "is_zero_0", is_zero.expr() * x.clone())?; + cb.require_zero(|| "is_zero_0", is_zero * &x)?; (Some(is_zero), is_zero.expr()) }; let inverse = cb.create_witin(|| "inv"); // x==0 => is_zero=1 - cb.require_one(|| "is_zero_1", is_zero_expr + x.clone() * inverse.expr())?; + cb.require_one(|| "is_zero_1", is_zero_expr + x * inverse)?; Ok(IsZeroConfig { is_zero, inverse }) }) @@ -82,13 +95,26 @@ impl IsZeroConfig { } } +#[derive(Clone, Copy, Debug)] pub struct IsEqualConfig(IsZeroConfig); -impl IsEqualConfig { - pub fn expr(&self) -> Expression { +impl ToExpr for IsEqualConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { self.0.expr() } +} + +impl ToExpr for &IsEqualConfig { + type Output = Expression; + fn expr(self) -> Self::Output { + (&self.0).expr() + } +} + +impl IsEqualConfig { pub fn construct_circuit, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, @@ -105,9 +131,11 @@ impl IsEqualConfig { pub fn construct_non_equal, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, name_fn: N, - a: Expression, - b: Expression, + a: impl ToExpr>, + b: impl ToExpr>, ) -> Result { + let a = a.expr(); + let b = b.expr(); Ok(IsEqualConfig(IsZeroConfig::construct_non_zero( cb, name_fn, diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index d1dc8ed62..8d642955b 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -9,7 +9,7 @@ use crate::{ use ff_ext::ExtensionField; use std::{marker::PhantomData, mem::MaybeUninit}; -#[derive(Debug)] +#[derive(Copy, Clone, Debug)] pub struct SignedExtendConfig { /// most significant bit msb: WitIn, @@ -19,6 +19,22 @@ pub struct SignedExtendConfig { _marker: PhantomData, } +impl ToExpr for SignedExtendConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.msb.expr() + } +} + +impl ToExpr for &SignedExtendConfig { + type Output = Expression; + + fn expr(self) -> Self::Output { + self.msb.expr() + } +} + impl SignedExtendConfig { pub fn construct_limb( cb: &mut CircuitBuilder, @@ -34,10 +50,6 @@ impl SignedExtendConfig { Self::construct_circuit(cb, 8, val) } - pub fn expr(&self) -> Expression { - self.msb.expr() - } - fn construct_circuit( cb: &mut CircuitBuilder, n_bits: usize, @@ -47,7 +59,7 @@ impl SignedExtendConfig { let msb = cb.create_witin(|| "msb"); // require msb is boolean - cb.assert_bit(|| "msb is boolean", msb.expr())?; + cb.assert_bit(|| "msb is boolean", msb)?; // assert 2*val - msb*2^N_BITS is within range [0, 2^N_BITS) // - if val < 2^(N_BITS-1), then 2*val < 2^N_BITS, msb can only be zero. diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 2508000f2..10654abe4 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -41,7 +41,7 @@ impl Instruction for AddiInstruction { let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, Self::INST_KIND, - &imm.value(), + imm.value(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index c638d314f..1c44098bf 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -69,14 +69,9 @@ impl BInstructionConfig { ))?; // Branch program counter - let pc_offset = - branch_taken_bit.clone() * imm.expr() - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; + let pc_offset = &branch_taken_bit * imm - branch_taken_bit * PC_STEP_SIZE + PC_STEP_SIZE; let next_pc = vm_state.next_pc.unwrap(); - circuit_builder.require_equal( - || "pc_branch", - next_pc.expr(), - vm_state.pc.expr() + pc_offset, - )?; + circuit_builder.require_equal(|| "pc_branch", next_pc, vm_state.pc + pc_offset)?; Ok(BInstructionConfig { vm_state, diff --git a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs index 4826c94bf..bcf9975e1 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/beq_circuit.rs @@ -7,7 +7,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::IsEqualConfig, instructions::{ Instruction, @@ -50,7 +50,7 @@ impl Instruction for BeqCircuit { let branch_taken_bit = match I::INST_KIND { InsnKind::BEQ => equal.expr(), - InsnKind::BNE => Expression::ONE - equal.expr(), + InsnKind::BNE => 1 - equal.expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/branch/blt.rs b/ceno_zkvm/src/instructions/riscv/branch/blt.rs index c5e0798f2..f72ad04bd 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/blt.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/blt.rs @@ -6,7 +6,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::SignedLtConfig, instructions::{ Instruction, @@ -42,8 +42,8 @@ impl Instruction for BltCircuit { SignedLtConfig::construct_circuit(circuit_builder, || "rs1 is_lt.expr(), - InsnKind::BGE => Expression::ONE - is_lt.expr(), + InsnKind::BLT => (&is_lt).expr(), + InsnKind::BGE => 1 - (&is_lt).expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs index 896bf19da..a6ec83add 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/bltu.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/bltu.rs @@ -6,7 +6,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::IsLtConfig, instructions::{ Instruction, @@ -51,8 +51,8 @@ impl Instruction for BltuCircuit )?; let branch_taken_bit = match I::INST_KIND { - InsnKind::BLTU => is_lt.expr(), - InsnKind::BGEU => Expression::ONE - is_lt.expr(), + InsnKind::BLTU => (&is_lt).expr(), + InsnKind::BGEU => 1 - (&is_lt).expr(), _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index e9cb2e4ca..404250459 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -10,7 +10,7 @@ use super::{ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::ToExpr, gadgets::{IsLtConfig, IsZeroConfig}, instructions::Instruction, uint::Value, @@ -79,10 +79,10 @@ impl Instruction for ArithInstruction::TOTAL_BITS) - 1).into(), - outcome_value, + is_zero, + &outcome_value, + (1u64 << UInt::::TOTAL_BITS) - 1, + &outcome_value, )?; // remainder should be less than divisor if divisor != 0. @@ -98,8 +98,8 @@ impl Instruction for ArithInstruction::construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 70338e3d6..425edb10d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -47,11 +47,12 @@ impl Instruction for HaltInstruction { Some(EXIT_PC.into()), )?; + // let reg_: usize = ceno_emul::Platform::reg_arg0(); // read exit_code from arg0 (X10 register) let (_, lt_x10_cfg) = cb.register_read( || "read x10", - E::BaseField::from(ceno_emul::Platform::reg_arg0() as u64), - prev_x10_ts.expr(), + ceno_emul::Platform::reg_arg0(), + prev_x10_ts, ecall_cfg.ts.expr() + Tracer::SUBCYCLE_RS2, exit_code, )?; diff --git a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs index 0457979c7..0a9854c3f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_insn.rs @@ -31,10 +31,10 @@ impl EcallInstructionConfig { let pc = cb.create_witin(|| "pc"); let ts = cb.create_witin(|| "cur_ts"); - cb.state_in(pc.expr(), ts.expr())?; + cb.state_in(pc, ts)?; cb.state_out( next_pc.map_or(pc.expr() + PC_STEP_SIZE, |next_pc| next_pc), - ts.expr() + (Tracer::SUBCYCLES_PER_INSN as usize), + ts.expr() + Tracer::SUBCYCLES_PER_INSN, )?; cb.lk_fetch(&InsnRecord::new( @@ -51,8 +51,8 @@ impl EcallInstructionConfig { // read syscall_id from x5 and write return value to x5 let (_, lt_x5_cfg) = cb.register_write( || "write x5", - E::BaseField::from(Platform::reg_ecall() as u64), - prev_x5_ts.expr(), + Platform::reg_ecall(), + prev_x5_ts, ts.expr() + Tracer::SUBCYCLE_RS1, syscall_id.clone(), syscall_ret_value.map_or(syscall_id, |v| v), diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index 65beb8c5f..32f748da1 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -28,7 +28,7 @@ impl IInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, rd_written: RegisterExpr, branching: bool, @@ -49,7 +49,7 @@ impl IInstructionConfig { Some(rd.id.expr()), rs1.id.expr(), 0.into(), - imm.clone(), + imm.expr(), ))?; Ok(IInstructionConfig { vm_state, rs1, rd }) diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 6727f6628..e51ffae99 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -26,7 +26,7 @@ impl IMInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, memory_read: MemoryExpr, memory_addr: AddressExpr, @@ -49,7 +49,7 @@ impl IMInstructionConfig { Some(rd.id.expr()), rs1.id.expr(), 0.into(), - imm.clone(), + imm.expr(), ))?; Ok(IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 3f71b6b20..4ba3d7a9e 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -1,5 +1,4 @@ use ceno_emul::{StepRecord, Word}; -use ff::Field; use ff_ext::ExtensionField; use itertools::Itertools; @@ -92,7 +91,7 @@ impl ReadRS1 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RS1, rs1_read, )?; @@ -146,7 +145,7 @@ impl ReadRS2 { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RS2, rs2_read, )?; @@ -201,7 +200,7 @@ impl WriteRD { let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", id, - prev_ts.expr(), + prev_ts, cur_ts.expr() + Tracer::SUBCYCLE_RD, prev_value.register_expr(), rd_written, @@ -420,11 +419,7 @@ impl MemAddr { .sum(); // Range check the middle bits, that is the low limb excluding the low bits. - let shift_right = E::BaseField::from(1 << Self::N_LOW_BITS) - .invert() - .unwrap() - .expr(); - let mid_u14 = (&limbs[0] - low_sum) * shift_right; + let mid_u14 = (&limbs[0] - low_sum) >> Self::N_LOW_BITS; cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?; // Range check the high limb. @@ -477,6 +472,7 @@ mod test { ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, + expression::Expression, scheme::mock_prover::MockProver, witness::{LkMultiplicity, RowMajorMatrix}, }; @@ -535,9 +531,9 @@ mod test { assert_eq!(lkm[ROMType::U16 as usize].len(), 1); if is_ok { - cb.require_equal(|| "", mem_addr.expr_unaligned(), addr.into())?; - cb.require_equal(|| "", mem_addr.expr_align2(), (addr & !1).into())?; - cb.require_equal(|| "", mem_addr.expr_align4(), (addr & !3).into())?; + cb.require_equal(|| "", mem_addr.expr_unaligned(), Expression::from(addr))?; + cb.require_equal(|| "", mem_addr.expr_align2(), Expression::from(addr & !1))?; + cb.require_equal(|| "", mem_addr.expr_align4(), Expression::from(addr & !3))?; } MockProver::assert_with_expected_errors( &cb, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 0339a6b0a..b8778a7f7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -51,7 +51,7 @@ impl Instruction for JalrInstruction { let i_insn = IInstructionConfig::construct_circuit( circuit_builder, InsnKind::JALR, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), true, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index b6a8bb690..bcf05fd0f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -87,7 +87,7 @@ impl LogicConfig { let i_insn = IInstructionConfig::::construct_circuit( cb, insn_kind, - &imm.value(), + imm.value(), rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 6539c1325..79efcd81c 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -8,7 +8,6 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::StepRecord; -use ff::Field; use ff_ext::ExtensionField; use itertools::izip; use std::mem::MaybeUninit; @@ -55,7 +54,7 @@ impl MemWordChange { .iter() .enumerate() .map(|(idx, byte)| byte.expr() << (idx * 8)) - .sum(), + .sum::>(), )?; Ok(bytes) @@ -77,30 +76,29 @@ impl MemWordChange { // extract the least significant byte from u16 limb let rs2_limb_bytes = alloc_bytes(cb, "rs2_limb[0]", 1)?; - let u8_base_inv = E::BaseField::from(1 << 8).invert().unwrap(); cb.assert_ux::<_, _, 8>( || "rs2_limb[0].le_bytes[1]", - u8_base_inv.expr() * (&rs2_limbs[0] - rs2_limb_bytes[0].expr()), + (&rs2_limbs[0] - rs2_limb_bytes[0].expr()) >> 8, )?; // alloc a new witIn to cache degree 2 expression let expected_limb_change = cb.create_witin(|| "expected_limb_change"); cb.condition_require_equal( || "expected_limb_change = select(low_bits[0], rs2 - prev)", - low_bits[0].clone(), - expected_limb_change.expr(), - (rs2_limb_bytes[0].expr() - prev_limb_bytes[1].expr()) << 8, - rs2_limb_bytes[0].expr() - prev_limb_bytes[0].expr(), + &low_bits[0], + expected_limb_change, + (rs2_limb_bytes[0].expr() - prev_limb_bytes[1]) << 8, + rs2_limb_bytes[0].expr() - prev_limb_bytes[0], )?; // alloc a new witIn to cache degree 2 expression let expected_change = cb.create_witin(|| "expected_change"); cb.condition_require_equal( || "expected_change = select(low_bits[1], limb_change*2^16, limb_change)", - low_bits[1].clone(), - expected_change.expr(), + &low_bits[1], + expected_change, expected_limb_change.expr() << 16, - expected_limb_change.expr(), + expected_limb_change, )?; Ok(MemWordChange { diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index a06c17687..ccd68c9a5 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -97,7 +97,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction for LoadInstruction::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), memory_read.memory_expr(), memory_addr.expr_align4(), diff --git a/ceno_zkvm/src/instructions/riscv/memory/store.rs b/ceno_zkvm/src/instructions/riscv/memory/store.rs index d1d941b97..2b2a39fb1 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store.rs @@ -2,7 +2,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, error::ZKVMError, - expression::{ToExpr, WitIn}, + expression::WitIn, instructions::{ Instruction, riscv::{ @@ -87,7 +87,7 @@ impl Instruction circuit_builder.require_equal( || "memory_addr = rs1_read + imm", memory_addr.expr_unaligned(), - rs1_read.value() + imm.expr(), + rs1_read.value() + imm, )?; let (new_memory_value, word_change) = match I::INST_KIND { @@ -107,7 +107,7 @@ impl Instruction let s_insn = SInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rs2_read.register_expr(), memory_addr.expr_align4(), diff --git a/ceno_zkvm/src/instructions/riscv/mul.rs b/ceno_zkvm/src/instructions/riscv/mul.rs index 16a08fe41..d3ae5e8b2 100644 --- a/ceno_zkvm/src/instructions/riscv/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/mul.rs @@ -87,7 +87,7 @@ use goldilocks::SmallField; use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - expression::Expression, + expression::{Expression, ToExpr}, gadgets::{IsEqualConfig, SignedExtendConfig}, instructions::{ Instruction, @@ -198,9 +198,9 @@ impl Instruction for MulhInstructionBas let prod_low = UInt::new(|| "prod_low", circuit_builder)?; ( - rs1_signed.expr(), - rs2_signed.expr(), - rd_signed.expr(), + (&rs1_signed).expr(), + (&rs2_signed).expr(), + (&rd_signed).expr(), MulhSignDependencies::SS { rs1_signed, rs2_signed, @@ -213,12 +213,11 @@ impl Instruction for MulhInstructionBas InsnKind::MULHU => { let prod_low = UInt::new(|| "prod_low", circuit_builder)?; // constrain that rd does not represent 2^32 - 1 - let rd_avoid = Expression::::from(u32::MAX); let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_rd", rd_written.value(), - rd_avoid, + u32::MAX, )?; ( @@ -231,13 +230,12 @@ impl Instruction for MulhInstructionBas } InsnKind::MUL => { // constrain that prod_hi does not represent 2^32 - 1 - let prod_hi_avoid = Expression::::from(u32::MAX); let prod_hi = UInt::new(|| "prod_hi", circuit_builder)?; let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_prod_hi", prod_hi.value(), - prod_hi_avoid, + u32::MAX, )?; ( @@ -255,18 +253,17 @@ impl Instruction for MulhInstructionBas let prod_low = UInt::new(|| "prod_low", circuit_builder)?; // constrain that (signed) rd does not represent 2^31 - 1 - let rd_avoid = Expression::::from(i32::MAX); let constrain_rd = IsEqualConfig::construct_non_equal( circuit_builder, || "constrain_rd", - rd_signed.expr(), - rd_avoid, + &rd_signed, + i32::MAX, )?; ( - rs1_signed.expr(), + (&rs1_signed).expr(), rs2_read.value(), - rd_signed.expr(), + (&rd_signed).expr(), MulhSignDependencies::SU { rs1_signed, rd_signed, @@ -407,6 +404,20 @@ struct Signed { val: Expression, } +impl ToExpr for &Signed { + type Output = Expression; + fn expr(self) -> Self::Output { + self.val.clone() + } +} + +impl ToExpr for Signed { + type Output = Expression; + fn expr(self) -> Self::Output { + self.val + } +} + impl Signed { pub fn construct_circuit + Display + Clone, N: FnOnce() -> NR>( cb: &mut CircuitBuilder, @@ -415,7 +426,7 @@ impl Signed { ) -> Result { cb.namespace(name_fn, |cb| { let is_negative = unsigned_val.is_negative(cb)?; - let val = unsigned_val.value() - (1u64 << BIT_WIDTH) * is_negative.expr(); + let val = unsigned_val.value() - (is_negative.expr() << BIT_WIDTH); Ok(Self { is_negative, val }) }) @@ -434,10 +445,6 @@ impl Signed { )?; Ok(i32::from(val)) } - - pub fn expr(&self) -> Expression { - self.val.clone() - } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index dc133e894..4158bcd36 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -27,7 +27,7 @@ impl SInstructionConfig { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, insn_kind: InsnKind, - imm: &Expression, + imm: impl ToExpr>, rs1_read: RegisterExpr, rs2_read: RegisterExpr, memory_addr: AddressExpr, @@ -48,7 +48,7 @@ impl SInstructionConfig { None, rs1.id.expr(), rs2.id.expr(), - imm.clone(), + imm.expr(), ))?; // Memory diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 365afe786..7e0cee7ce 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -113,8 +113,8 @@ impl Instruction for ShiftLogicalInstru let (inflow, signed_extend_config) = match I::INST_KIND { InsnKind::SRA => { let signed_extend_config = rs1_read.is_negative(circuit_builder)?; - let msb_expr = signed_extend_config.expr(); - let ones = pow2_rs2_low5.expr() - Expression::ONE; + let msb_expr = (&signed_extend_config).expr(); + let ones = (&pow2_rs2_low5).expr() - 1; (msb_expr * ones, Some(signed_extend_config)) } InsnKind::SRL => (Expression::ZERO, None), diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 080a8a6ae..3f02b5603 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -106,7 +106,7 @@ impl Instruction for ShiftImmInstructio let (inflow, is_lt_config) = match I::INST_KIND { InsnKind::SRAI => { let is_rs1_neg = rs1_read.is_negative(circuit_builder)?; - let ones = imm.expr() - 1; + let ones: Expression = imm.expr() - 1; (is_rs1_neg.expr() * ones, Some(is_rs1_neg)) } InsnKind::SRLI => (Expression::ZERO, None), @@ -125,7 +125,7 @@ impl Instruction for ShiftImmInstructio let i_insn = IInstructionConfig::::construct_circuit( circuit_builder, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 80fc69874..bc171bdb8 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -65,7 +65,7 @@ impl Instruction for SetLessThanInstruc InsnKind::SLT => { let signed_lt = SignedLtConfig::construct_circuit(cb, || "rs1 < rs2", &rs1_read, &rs2_read)?; - let rd_written = UInt::from_exprs_unchecked(vec![signed_lt.expr()]); + let rd_written = UInt::from_exprs_unchecked(vec![&signed_lt]); (SetLessThanDependencies::Slt { signed_lt }, rd_written) } InsnKind::SLTU => { @@ -76,7 +76,7 @@ impl Instruction for SetLessThanInstruc rs2_read.value(), UINT_LIMBS, )?; - let rd_written = UInt::from_exprs_unchecked(vec![is_lt.expr()]); + let rd_written = UInt::from_exprs_unchecked(vec![&is_lt]); (SetLessThanDependencies::Sltu { is_lt }, rd_written) } _ => unreachable!(), diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 76894f7a0..be2315705 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -66,19 +66,21 @@ impl Instruction for SetLessThanImmInst InsnKind::SLTIU => (rs1_read.value(), None), InsnKind::SLTI => { let is_rs1_neg = rs1_read.is_negative(cb)?; - (rs1_read.to_field_expr(is_rs1_neg.expr()), Some(is_rs1_neg)) + ( + rs1_read.to_field_expr((&is_rs1_neg).expr()), + Some(is_rs1_neg), + ) } _ => unreachable!("Unsupported instruction kind {:?}", I::INST_KIND), }; - let lt = - IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm.expr(), UINT_LIMBS)?; - let rd_written = UInt::from_exprs_unchecked(vec![lt.expr()]); + let lt = IsLtConfig::construct_circuit(cb, || "rs1 < imm", value_expr, imm, UINT_LIMBS)?; + let rd_written = UInt::from_exprs_unchecked(vec![(<).expr()]); let i_insn = IInstructionConfig::::construct_circuit( cb, I::INST_KIND, - &imm.expr(), + imm, rs1_read.register_expr(), rd_written.register_expr(), false, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index b84e94148..ebaf0af28 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1369,7 +1369,7 @@ mod tests { fn construct_circuit(cb: &mut CircuitBuilder) -> Result { let a = cb.create_witin(|| "a"); let b = cb.create_witin(|| "b"); - let lt_wtns = AssertLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = AssertLtConfig::construct_circuit(cb, || "lt", a, b, 1)?; Ok(Self { a, b, lt_wtns }) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 13ef29a66..42f8630e3 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -65,7 +65,7 @@ impl Instruction for Test Result::<(), ZKVMError>::Ok(()) })?; (0..L).try_for_each(|_| { - cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; + cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id)?; Result::<(), ZKVMError>::Ok(()) })?; assert_eq!(cb.cs.lk_expressions.len(), L); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index c8ec6453a..1e8e31276 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -19,7 +19,9 @@ use rayon::{ }; use crate::{ - expression::Expression, scheme::constants::MIN_PAR_SIZE, utils::next_pow2_instance_padding, + expression::{Expression, ToExpr}, + scheme::constants::MIN_PAR_SIZE, + utils::next_pow2_instance_padding, }; /// interleaving multiple mles into mles, and num_limbs indicate number of final limbs vector @@ -350,7 +352,7 @@ pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( pub(crate) fn eval_by_expr( witnesses: &[E], challenges: &[E], - expr: &Expression, + expr: impl ToExpr>, ) -> E { eval_by_expr_with_fixed(&[], witnesses, challenges, expr) } @@ -359,9 +361,9 @@ pub(crate) fn eval_by_expr_with_fixed( fixed: &[E], witnesses: &[E], challenges: &[E], - expr: &Expression, + expr: impl ToExpr>, ) -> E { - expr.evaluate::( + expr.expr().evaluate::( &|f| fixed[f.0], &|witness_id| witnesses[witness_id as usize], &|scalar| scalar.into(), diff --git a/ceno_zkvm/src/uint.rs b/ceno_zkvm/src/uint.rs index b555682ac..73a6858e5 100644 --- a/ceno_zkvm/src/uint.rs +++ b/ceno_zkvm/src/uint.rs @@ -113,7 +113,7 @@ impl UIntLimbs { .map(|i| { let w = cb.create_witin(|| format!("limb_{i}")); if is_check { - cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w.expr())?; + cb.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), w)?; } // skip range check Ok(w) @@ -185,10 +185,10 @@ impl UIntLimbs { .map(|i| { let w = circuit_builder.create_witin(|| "wit for limb"); circuit_builder - .assert_ux::<_, _, C>(|| "range check", w.expr()) + .assert_ux::<_, _, C>(|| "range check", w) .unwrap(); circuit_builder - .require_zero(|| "create_witin_from_expr", w.expr() - &expr_limbs[i]) + .require_zero(|| "create_witin_from_expr", w - &expr_limbs[i]) .unwrap(); w }) @@ -315,9 +315,8 @@ impl UIntLimbs { chunk .iter() .zip(shift_pows.iter()) - .map(|(limb, shift)| shift * limb.expr()) - .reduce(|a, b| a + b) - .unwrap() + .map(|(&limb, shift)| shift * limb) + .sum::>() }) .collect_vec(); Ok(UIntLimbs::::from_exprs_unchecked(combined_limbs)) @@ -343,8 +342,8 @@ impl UIntLimbs { let limbs = (0..k) .map(|_| { let w = circuit_builder.create_witin(|| ""); - circuit_builder.assert_byte(|| "", w.expr()).unwrap(); - w.expr() + circuit_builder.assert_byte(|| "", w).unwrap(); + w }) .collect_vec(); let combined_limb = limbs @@ -355,19 +354,21 @@ impl UIntLimbs { .unwrap(); circuit_builder - .require_zero(|| "zero check", large_limb.expr() - combined_limb) + .require_zero(|| "zero check", large_limb - combined_limb) .unwrap(); limbs }) + .map(ToExpr::expr) .collect_vec(); UIntLimbs::::create_witin_from_exprs(circuit_builder, split_limbs) } - pub fn from_exprs_unchecked(expr_limbs: Vec>) -> Self { + pub fn from_exprs_unchecked(expr_limbs: Vec>>) -> Self { Self { limbs: UintLimb::Expression( expr_limbs .into_iter() + .map(ToExpr::expr) .chain(std::iter::repeat(Expression::ZERO)) .take(Self::NUM_LIMBS) .collect_vec(), @@ -479,10 +480,10 @@ impl UIntLimbs { )) } - pub fn to_field_expr(&self, is_neg: Expression) -> Expression { + pub fn to_field_expr(&self, is_neg: impl ToExpr>) -> Expression { // Convert two's complement representation into field arithmetic. // Example: 0xFFFF_FFFF = 2^32 - 1 --> shift --> -1 - self.value() - is_neg * (1_u64 << 32) + self.value() - (is_neg.expr() << 32) } } @@ -531,9 +532,9 @@ impl TryFrom<&[WitIn]> for UI } } -impl ToExpr for UIntLimbs { +impl ToExpr for &UIntLimbs { type Output = Vec>; - fn expr(&self) -> Vec> { + fn expr(self) -> Self::Output { match &self.limbs { UintLimb::WitIn(limbs) => limbs .iter() diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index 5e48b0319..cc6e09152 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -34,31 +34,30 @@ impl UIntLimbs { return Err(ZKVMError::CircuitError); }; carries.iter().enumerate().try_for_each(|(i, carry)| { - circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry.expr()) + circuit_builder.assert_bit(|| format!("carry_{i}_in_as_bit"), carry) })?; // perform add operation // c[i] = a[i] + b[i] + carry[i-1] - carry[i] * 2 ^ C c.limbs = UintLimb::Expression( - (self.expr()) + self.expr() .iter() .zip((*addend).iter()) .enumerate() .map(|(i, (a, b))| { let carries = c.carries.as_ref().unwrap(); - let carry = if i > 0 { carries.get(i - 1) } else { None }; + let carry = carries.get(i - 1); let next_carry = carries.get(i); - let mut limb_expr = a.clone() + b.clone(); - if carry.is_some() { - limb_expr = limb_expr.clone() + carry.unwrap().expr(); + let mut limb_expr = a + b; + if let Some(carry) = carry { + limb_expr += carry; } - if next_carry.is_some() { - limb_expr = limb_expr.clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + if let Some(next_carry) = next_carry { + limb_expr -= next_carry.expr() * Self::POW_OF_C; } - circuit_builder - .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb_expr.clone())?; + .assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), &limb_expr)?; Ok(limb_expr) }) .collect::>, ZKVMError>>()?, @@ -120,7 +119,7 @@ impl UIntLimbs { // with high limb, overall cell will be double let c_limbs: Vec = (0..num_limbs).try_fold(vec![], |mut c_limbs, i| { let limb = circuit_builder.create_witin(|| format!("limb_{i}")); - circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb.expr())?; + circuit_builder.assert_ux::<_, _, C>(|| format!("limb_{i}_in_{C}"), limb)?; c_limbs.push(limb); Result::, ZKVMError>::Ok(c_limbs) })?; @@ -140,8 +139,8 @@ impl UIntLimbs { AssertLtConfig::construct_circuit( circuit_builder, || format!("carry_{i}_in_less_than"), - carry.expr(), - (Self::MAX_DEGREE_2_MUL_CARRY_VALUE as usize).into(), + carry, + Self::MAX_DEGREE_2_MUL_CARRY_VALUE, Self::MAX_DEGREE_2_MUL_CARRY_U16_LIMB, ) }) @@ -190,16 +189,16 @@ impl UIntLimbs { // constrain each limb with carry c_limbs.iter().enumerate().try_for_each(|(i, c_limb)| { - let carry = if i > 0 { c_carries.get(i - 1) } else { None }; + let carry = c_carries.get(i - 1); let next_carry = c_carries.get(i); - result_c[i] = result_c[i].clone() - c_limb.expr(); - if carry.is_some() { - result_c[i] = result_c[i].clone() + carry.unwrap().expr(); + result_c[i] -= c_limb; + if let Some(carry) = carry { + result_c[i] += carry; } - if next_carry.is_some() { - result_c[i] = result_c[i].clone() - next_carry.unwrap().expr() * Self::POW_OF_C; + if let Some(next_carry) = next_carry { + result_c[i] -= next_carry.expr() * Self::POW_OF_C; } - circuit_builder.require_zero(|| format!("mul_zero_{i}"), result_c[i].clone())?; + circuit_builder.require_zero(|| format!("mul_zero_{i}"), &result_c[i])?; Ok::<(), ZKVMError>(()) })?; @@ -269,18 +268,15 @@ impl UIntLimbs { let n_limbs = Self::NUM_LIMBS; let (is_equal_per_limb, diff_inv_per_limb): (Vec, Vec) = izip!(&self.limbs, &rhs.limbs) - .map(|(a, b)| circuit_builder.is_equal(a.expr(), b.expr())) + .map(|(a, b)| circuit_builder.is_equal(a, b)) .collect::, ZKVMError>>()? .into_iter() .unzip(); - let sum_expr = is_equal_per_limb - .iter() - .fold(Expression::ZERO, |acc, flag| acc.clone() + flag.expr()); + let sum_expr = is_equal_per_limb.iter().map(ToExpr::expr).sum(); let sum_flag = WitIn::from_expr(|| "sum_flag", circuit_builder, sum_expr, false)?; - let (is_equal, diff_inv) = - circuit_builder.is_equal(sum_flag.expr(), Expression::from(n_limbs))?; + let (is_equal, diff_inv) = circuit_builder.is_equal(sum_flag, n_limbs)?; Ok(IsEqualConfig { is_equal_per_limb, diff_inv_per_limb, @@ -667,8 +663,8 @@ mod tests { // overflow if overflow { - let overflow = uint_c.carries.unwrap().last().unwrap().expr(); - assert_eq!(eval_by_expr(&wit, &challenges, &overflow), E::ONE); + let &overflow = uint_c.carries.unwrap().last().unwrap(); + assert_eq!(eval_by_expr(&wit, &challenges, overflow), E::ONE); } else { // non-overflow case, the len of carries should be (NUM_CELLS - 1) assert_eq!(uint_c.carries.unwrap().len(), single_wit_size - 1) diff --git a/ceno_zkvm/src/uint/logic.rs b/ceno_zkvm/src/uint/logic.rs index 024d09d73..3867b48a1 100644 --- a/ceno_zkvm/src/uint/logic.rs +++ b/ceno_zkvm/src/uint/logic.rs @@ -3,8 +3,8 @@ use itertools::izip; use super::UIntLimbs; use crate::{ - ROMType, circuit_builder::CircuitBuilder, error::ZKVMError, expression::ToExpr, - tables::OpsTable, witness::LkMultiplicity, + ROMType, circuit_builder::CircuitBuilder, error::ZKVMError, tables::OpsTable, + witness::LkMultiplicity, }; // Only implemented for u8 limbs. @@ -19,7 +19,7 @@ impl UIntLimbs { c: &Self, ) -> Result<(), ZKVMError> { for (a_byte, b_byte, c_byte) in izip!(&a.limbs, &b.limbs, &c.limbs) { - cb.logic_u8(rom_type, a_byte.expr(), b_byte.expr(), c_byte.expr())?; + cb.logic_u8(rom_type, a_byte, b_byte, c_byte)?; } Ok(()) } diff --git a/mpcs/src/sum_check/classic/coeff.rs b/mpcs/src/sum_check/classic/coeff.rs index 12f46880f..10d5c1c20 100644 --- a/mpcs/src/sum_check/classic/coeff.rs +++ b/mpcs/src/sum_check/classic/coeff.rs @@ -49,9 +49,7 @@ impl ClassicSumCheckRoundMessage for Coefficients { } fn sum(&self) -> E { - self[1..] - .iter() - .fold(self[0].double(), |acc, coeff| acc + coeff) + self[0] + self[..].iter().sum::() } fn evaluate(&self, _: &Self::Auxiliary, challenge: &E) -> E {