Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ToExpr trait instead of ad-hoc expr functions #736

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ceno_zkvm/src/chip_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ pub mod utils;
pub mod test;

pub trait GlobalStateRegisterMachineChipOperations<E: ExtensionField> {
fn state_in(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError>;
fn state_in(
&mut self,
pc: impl ToExpr<E, Output = Expression<E>>,
ts: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>;

fn state_out(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError>;
}
Expand All @@ -30,8 +34,8 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
fn register_read(
&mut self,
name_fn: N,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
register_id: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
prev_ts: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, AssertLTConfig), ZKVMError>;
Expand All @@ -40,8 +44,8 @@ pub trait RegisterChipOperations<E: ExtensionField, NR: Into<String>, N: FnOnce(
fn register_write(
&mut self,
name_fn: N,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
register_id: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
prev_ts: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
Expand Down
81 changes: 46 additions & 35 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
N: FnOnce() -> NR + Clone,
{
let byte = self.cs.create_witin(name_fn.clone());
self.assert_ux::<_, _, 8>(name_fn, byte.expr())?;
self.assert_ux::<_, _, 8>(name_fn, byte)?;

Ok(byte)
}
Expand All @@ -175,7 +175,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
N: FnOnce() -> NR + Clone,
{
let limb = self.cs.create_witin(name_fn.clone());
self.assert_ux::<_, _, 16>(name_fn, limb.expr())?;
self.assert_ux::<_, _, 16>(name_fn, limb)?;

Ok(limb)
}
Expand All @@ -191,15 +191,15 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
N: FnOnce() -> NR + Clone,
{
let wit = self.cs.create_witin(name_fn.clone());
self.require_equal(name_fn, wit.expr(), expr)?;
self.require_equal(name_fn, wit, expr)?;

Ok(wit)
}

pub fn require_zero<NR, N>(
&mut self,
name_fn: N,
assert_zero_expr: Expression<E>,
assert_zero_expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
Expand All @@ -214,8 +214,8 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn require_equal<NR, N>(
&mut self,
name_fn: N,
a: Expression<E>,
b: Expression<E>,
a: impl ToExpr<E, Output = Expression<E>>,
b: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
Expand All @@ -224,8 +224,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
self.namespace(
|| "require_equal",
|cb| {
cb.cs
.require_zero(name_fn, a.to_monomial_form() - b.to_monomial_form())
cb.cs.require_zero(
name_fn,
a.expr().to_monomial_form() - b.expr().to_monomial_form(),
)
},
)
}
Expand All @@ -241,44 +243,50 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn condition_require_equal<NR, N>(
&mut self,
name_fn: N,
cond: Expression<E>,
target: Expression<E>,
true_expr: Expression<E>,
false_expr: Expression<E>,
cond: impl ToExpr<E, Output = Expression<E>>,
target: impl ToExpr<E, Output = Expression<E>>,
true_expr: impl ToExpr<E, Output = Expression<E>>,
false_expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
let cond = cond.expr();
let target = target.expr();
let true_expr = true_expr.expr();
let false_expr = false_expr.expr();
// cond * (true_expr) + (1 - cond) * false_expr
// => false_expr + cond * true_expr - cond * false_expr
self.namespace(
|| "cond_require_equal",
|cb| {
let cond_target = false_expr.clone() + cond.clone() * true_expr - cond * false_expr;
let cond_target = &false_expr + &cond * true_expr - cond * false_expr;
cb.cs.require_zero(name_fn, target - cond_target)
},
)
}

pub fn select(
&mut self,
cond: &Expression<E>,
when_true: &Expression<E>,
when_false: &Expression<E>,
cond: impl ToExpr<E, Output = Expression<E>>,
when_true: impl ToExpr<E, Output = Expression<E>>,
when_false: impl ToExpr<E, Output = Expression<E>>,
) -> Expression<E> {
cond * when_true + (1 - cond) * when_false
let cond = cond.expr();
&cond * when_true.expr() + (1 - &cond) * when_false.expr()
}

pub(crate) fn assert_ux<NR, N, const C: usize>(
&mut self,
name_fn: N,
expr: Expression<E>,
expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
let expr = expr.expr();
match C {
16 => self.assert_u16(name_fn, expr),
14 => self.assert_u14(name_fn, expr),
Expand Down Expand Up @@ -333,25 +341,26 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub(crate) fn assert_byte<NR, N>(
&mut self,
name_fn: N,
expr: Expression<E>,
expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
self.lk_record(name_fn, ROMType::U8, vec![expr])?;
self.lk_record(name_fn, ROMType::U8, vec![expr.expr()])?;
Ok(())
}

pub(crate) fn assert_bit<NR, N>(
&mut self,
name_fn: N,
expr: Expression<E>,
expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError>
where
NR: Into<String>,
N: FnOnce() -> NR,
{
let expr = expr.expr();
self.namespace(
|| "assert_bit",
|cb| cb.cs.require_zero(name_fn, &expr * (1 - &expr)),
Expand All @@ -362,10 +371,13 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn logic_u8(
&mut self,
rom_type: ROMType,
a: Expression<E>,
b: Expression<E>,
c: Expression<E>,
a: impl ToExpr<E, Output = Expression<E>>,
b: impl ToExpr<E, Output = Expression<E>>,
c: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError> {
let a = a.expr();
let b = b.expr();
let c = c.expr();
self.lk_record(|| format!("lookup_{:?}", rom_type), rom_type, vec![a, b, c])
}

Expand Down Expand Up @@ -402,31 +414,30 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
/// Assert that `(a < b) == c as bool`, that `a, b` are unsigned bytes, and that `c` is 0 or 1.
pub fn lookup_ltu_byte(
&mut self,
a: Expression<E>,
b: Expression<E>,
c: Expression<E>,
a: impl ToExpr<E, Output = Expression<E>>,
b: impl ToExpr<E, Output = Expression<E>>,
c: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError> {
self.logic_u8(ROMType::Ltu, a, b, c)
}

// Assert that `2^b = c` and that `b` is a 5-bit unsigned integer.
pub fn lookup_pow2(&mut self, b: Expression<E>, c: Expression<E>) -> Result<(), ZKVMError> {
self.logic_u8(ROMType::Pow, 2.into(), b, c)
self.logic_u8(ROMType::Pow, 2, b, c)
}

pub(crate) fn is_equal(
&mut self,
lhs: Expression<E>,
rhs: Expression<E>,
lhs: impl ToExpr<E, Output = Expression<E>>,
rhs: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(WitIn, WitIn), ZKVMError> {
let lhs = lhs.expr();
let rhs = rhs.expr();
let is_eq = self.create_witin(|| "is_eq");
let diff_inverse = self.create_witin(|| "diff_inverse");

self.require_zero(|| "is equal", is_eq.expr() * &lhs - is_eq.expr() * &rhs)?;
self.require_zero(
|| "is equal",
1 - is_eq.expr() - diff_inverse.expr() * lhs + diff_inverse.expr() * rhs,
)?;
self.require_zero(|| "is equal", is_eq * &lhs - is_eq * &rhs)?;
self.require_zero(|| "is equal", 1 + diff_inverse * (rhs + lhs) - is_eq)?;

Ok((is_eq, diff_inverse))
}
Expand Down
15 changes: 11 additions & 4 deletions ceno_zkvm/src/chip_handler/global_state.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
use ff_ext::ExtensionField;

use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, expression::Expression, structs::RAMType,
circuit_builder::CircuitBuilder,
error::ZKVMError,
expression::{Expression, ToExpr},
structs::RAMType,
};

use super::GlobalStateRegisterMachineChipOperations;

impl<E: ExtensionField> GlobalStateRegisterMachineChipOperations<E> for CircuitBuilder<'_, E> {
fn state_in(&mut self, pc: Expression<E>, ts: Expression<E>) -> Result<(), ZKVMError> {
fn state_in(
&mut self,
pc: impl ToExpr<E, Output = Expression<E>>,
ts: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError> {
let record: Vec<Expression<E>> = vec![
Expression::Constant(E::BaseField::from(RAMType::GlobalState as u64)),
pc,
ts,
pc.expr(),
ts.expr(),
];

self.read_record(|| "state_in", RAMType::GlobalState, record)
Expand Down
12 changes: 6 additions & 6 deletions ceno_zkvm/src/chip_handler/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ impl<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOperati
fn register_read(
&mut self,
name_fn: N,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
register_id: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
prev_ts: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
ts: Expression<E>,
value: RegisterExpr<E>,
) -> Result<(Expression<E>, AssertLTConfig), ZKVMError> {
Expand All @@ -28,7 +28,7 @@ impl<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOperati
vec![RAMType::Register.into()],
vec![register_id.expr()],
value.to_vec(),
vec![prev_ts.clone()],
vec![prev_ts.expr()],
]
.concat();
// Write (a, v, t)
Expand Down Expand Up @@ -60,8 +60,8 @@ impl<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOperati
fn register_write(
&mut self,
name_fn: N,
register_id: impl ToExpr<E, Output = Expression<E>>,
prev_ts: Expression<E>,
register_id: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
prev_ts: impl ToExpr<E, Output = Expression<E>> + std::marker::Copy,
ts: Expression<E>,
prev_values: RegisterExpr<E>,
value: RegisterExpr<E>,
Expand All @@ -73,7 +73,7 @@ impl<E: ExtensionField, NR: Into<String>, N: FnOnce() -> NR> RegisterChipOperati
vec![RAMType::Register.into()],
vec![register_id.expr()],
prev_values.to_vec(),
vec![prev_ts.clone()],
vec![prev_ts.expr()],
]
.concat();
// Write (a, v, t)
Expand Down
5 changes: 3 additions & 2 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::{
ROMType,
chip_handler::utils::rlc_chip_record,
error::ZKVMError,
expression::{Expression, Fixed, Instance, WitIn},
expression::{Expression, Fixed, Instance, ToExpr, WitIn},
structs::{ProgramParams, ProvingKey, RAMType, VerifyingKey, WitnessId},
witness::RowMajorMatrix,
};
Expand Down Expand Up @@ -440,8 +440,9 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub fn require_zero<NR: Into<String>, N: FnOnce() -> NR>(
&mut self,
name_fn: N,
assert_zero_expr: Expression<E>,
assert_zero_expr: impl ToExpr<E, Output = Expression<E>>,
) -> Result<(), ZKVMError> {
let assert_zero_expr = assert_zero_expr.expr();
assert!(
assert_zero_expr.degree() > 0,
"constant expression assert to zero ?"
Expand Down
Loading