Skip to content

Commit

Permalink
opt: parallelize the convesion from trace data to table (#200)
Browse files Browse the repository at this point in the history
* opt: parallelize the convesion from trace data to table

* chore: modify based on the comments

* chore: lint the code style

* chore: lint the code style

* chore: lint the code style

* opt: use into_par_iter in rayon

* chore: revert optimization
  • Loading branch information
felicityin authored Dec 29, 2024
1 parent 3e32a00 commit 3167ad8
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 107 deletions.
2 changes: 1 addition & 1 deletion emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ impl InstrumentedState {
return rt << shamt; // sll
} else if fun == 0x02 {
if (insn >> 21) & 0x1F == 1 {
return rt >> shamt | rt << (32 - shamt); // ror
return (rt >> shamt) | (rt << (32 - shamt)); // ror
} else if (insn >> 21) & 0x1F == 0 {
return rt >> shamt; // srl
}
Expand Down
6 changes: 3 additions & 3 deletions prover/src/arithmetic/arithmetic_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<F: RichField, const D: usize> ArithmeticStark<F, D> {
}
}

pub(crate) fn generate_trace(&self, operations: Vec<Operation>) -> Vec<PolynomialValues<F>> {
pub(crate) fn generate_trace(&self, operations: &Vec<Operation>) -> Vec<PolynomialValues<F>> {
// The number of rows reserved is the smallest value that's
// guaranteed to avoid a reallocation: The only ops that use
// two rows are the modular operations and DIV, so the only
Expand Down Expand Up @@ -345,7 +345,7 @@ mod tests {

let ops: Vec<Operation> = vec![add, mul, div0, div1, divu, mult0, mult1, multu];

let pols = stark.generate_trace(ops);
let pols = stark.generate_trace(&ops);

// Trace should always have NUM_ARITH_COLUMNS columns and
// min(RANGE_MAX, operations.len()) rows. In this case there
Expand Down Expand Up @@ -398,7 +398,7 @@ mod tests {
.map(|_| Operation::binary(BinaryOperator::MULT, rng.gen::<u32>(), rng.gen::<u32>()))
.collect::<Vec<_>>();

let pols = stark.generate_trace(ops);
let pols = stark.generate_trace(&ops);

// Trace should always have NUM_ARITH_COLUMNS columns and
// min(RANGE_MAX, operations.len()) rows. In this case there
Expand Down
2 changes: 1 addition & 1 deletion prover/src/arithmetic/slt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub(crate) fn generate<F: PrimeField64>(
let (diff, cy) = left_in.overflowing_sub(right_in);
let mut cy_val = cy as u32;
if (left_in & 0x80000000u32) != (right_in & 0x80000000u32) {
cy_val = 1u32 << 16 | (!cy as u32);
cy_val = (1u32 << 16) | (!cy as u32);
}

u32_to_array(&mut lv[AUX_INPUT_REGISTER_0], diff);
Expand Down
16 changes: 2 additions & 14 deletions prover/src/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2_util::ceil_div_usize;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
Expand Down Expand Up @@ -149,19 +147,9 @@ impl<F: RichField, const D: usize> LogicStark<F, D> {
&self,
operations: Vec<Operation>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(operations, min_rows)
);
let trace_polys = timed!(
timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);
trace_polys
let trace_rows = self.generate_trace_rows(operations, min_rows);
trace_rows_to_poly_values(trace_rows)
}

fn generate_trace_rows(
Expand Down
17 changes: 5 additions & 12 deletions prover/src/memory/memory_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::Field;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2::util::transpose;
use plonky2_maybe_rayon::*;

Expand Down Expand Up @@ -134,12 +132,12 @@ pub fn generate_first_change_flags_and_rc<F: RichField>(trace_rows: &mut [[F; NU
impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {
/// Generate most of the trace rows. Excludes a few columns like `COUNTER`, which are generated
/// later, after transposing to column-major form.
fn generate_trace_row_major(&self, mut memory_ops: Vec<MemoryOp>) -> Vec<[F; NUM_COLUMNS]> {
fn generate_trace_row_major(&self, memory_ops: &mut Vec<MemoryOp>) -> Vec<[F; NUM_COLUMNS]> {
// fill_gaps expects an ordered list of operations.
memory_ops.sort_by_key(MemoryOp::sorting_key);
Self::fill_gaps(&mut memory_ops);
Self::fill_gaps(memory_ops);

Self::pad_memory_ops(&mut memory_ops);
Self::pad_memory_ops(memory_ops);

// fill_gaps may have added operations at the end which break the order, so sort again.
memory_ops.sort_by_key(MemoryOp::sorting_key);
Expand Down Expand Up @@ -227,15 +225,10 @@ impl<F: RichField + Extendable<D>, const D: usize> MemoryStark<F, D> {

pub(crate) fn generate_trace(
&self,
memory_ops: Vec<MemoryOp>,
timing: &mut TimingTree,
memory_ops: &mut Vec<MemoryOp>,
) -> Vec<PolynomialValues<F>> {
// Generate most of the trace in row-major form.
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_row_major(memory_ops)
);
let trace_rows = self.generate_trace_row_major(memory_ops);
let trace_row_vecs: Vec<_> = trace_rows.into_iter().map(|row| row.to_vec()).collect();

// Transpose to column-major form.
Expand Down
28 changes: 6 additions & 22 deletions prover/src/poseidon/poseidon_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use plonky2::field::types::PrimeField64;
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::plonk::circuit_builder::CircuitBuilder;
use plonky2::timed;
use plonky2::util::timing::TimingTree;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::{Column, Filter};
Expand Down Expand Up @@ -106,7 +104,7 @@ impl<F: RichField + Extendable<D>, const D: usize> PoseidonStark<F, D> {
/// in our lookup arguments, as those are computed after transposing to column-wise form.
fn generate_trace_rows(
&self,
inputs_and_timestamps: Vec<([F; SPONGE_WIDTH], usize)>,
inputs_and_timestamps: &[([F; SPONGE_WIDTH], usize)],
min_rows: usize,
) -> Vec<[F; NUM_COLUMNS]> {
let num_rows = inputs_and_timestamps
Expand Down Expand Up @@ -148,22 +146,12 @@ impl<F: RichField + Extendable<D>, const D: usize> PoseidonStark<F, D> {

pub fn generate_trace(
&self,
inputs: Vec<([F; SPONGE_WIDTH], usize)>,
inputs: &[([F; SPONGE_WIDTH], usize)],
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
// Generate the witness, except for permuted columns in the lookup argument.
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(inputs, min_rows)
);
let trace_polys = timed!(
timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);
trace_polys
let trace_rows = self.generate_trace_rows(inputs, min_rows);
trace_rows_to_poly_values(trace_rows)
}
}

Expand Down Expand Up @@ -745,7 +733,7 @@ mod tests {
init_logger();

let input: ([F; SPONGE_WIDTH], usize) = (F::rand_array(), 0);
let rows = stark.generate_trace_rows(vec![input], 4);
let rows = stark.generate_trace_rows(&[input], 4);

let mut constraint_consumer = ConstraintConsumer::new(
vec![GoldilocksField(2), GoldilocksField(3), GoldilocksField(5)],
Expand Down Expand Up @@ -775,11 +763,7 @@ mod tests {
(0..NUM_PERMS).map(|_| (F::rand_array(), 0)).collect();

let mut timing = TimingTree::new("prove", log::Level::Debug);
let trace_poly_values = timed!(
timing,
"generate trace",
stark.generate_trace(input, 8, &mut timing)
);
let trace_poly_values = stark.generate_trace(&input, 8);

// TODO: Cloning this isn't great; consider having `from_values` accept a reference,
// or having `compute_permutation_z_polys` read trace values from the `PolynomialBatch`.
Expand Down
30 changes: 8 additions & 22 deletions prover/src/poseidon_sponge/poseidon_sponge_stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ use plonky2::field::polynomial::PolynomialValues;
use plonky2::field::types::{Field, PrimeField64};
use plonky2::hash::hash_types::RichField;
use plonky2::iop::ext_target::ExtensionTarget;
use plonky2::timed;
use plonky2::util::timing::TimingTree;

use crate::constraint_consumer::{ConstraintConsumer, RecursiveConstraintConsumer};
use crate::cross_table_lookup::{Column, Filter};
Expand Down Expand Up @@ -188,29 +186,17 @@ pub struct PoseidonSpongeStark<F, const D: usize> {
impl<F: RichField + Extendable<D>, const D: usize> PoseidonSpongeStark<F, D> {
pub(crate) fn generate_trace(
&self,
operations: Vec<PoseidonSpongeOp>,
operations: &Vec<PoseidonSpongeOp>,
min_rows: usize,
timing: &mut TimingTree,
) -> Vec<PolynomialValues<F>> {
// Generate the witness row-wise.
let trace_rows = timed!(
timing,
"generate trace rows",
self.generate_trace_rows(operations, min_rows)
);

let trace_polys = timed!(
timing,
"convert to PolynomialValues",
trace_rows_to_poly_values(trace_rows)
);

trace_polys
let trace_rows = self.generate_trace_rows(operations, min_rows);
trace_rows_to_poly_values(trace_rows)
}

fn generate_trace_rows(
&self,
operations: Vec<PoseidonSpongeOp>,
operations: &Vec<PoseidonSpongeOp>,
min_rows: usize,
) -> Vec<[F; NUM_POSEIDON_SPONGE_COLUMNS]> {
let base_len: usize = operations
Expand All @@ -228,7 +214,7 @@ impl<F: RichField + Extendable<D>, const D: usize> PoseidonSpongeStark<F, D> {
rows
}

fn generate_rows_for_op(&self, op: PoseidonSpongeOp) -> Vec<[F; NUM_POSEIDON_SPONGE_COLUMNS]> {
fn generate_rows_for_op(&self, op: &PoseidonSpongeOp) -> Vec<[F; NUM_POSEIDON_SPONGE_COLUMNS]> {
let mut rows = Vec::with_capacity(op.input.len() / POSEIDON_RATE_BYTES + 1);

let mut sponge_state = [F::ZEROS; SPONGE_WIDTH];
Expand All @@ -237,7 +223,7 @@ impl<F: RichField + Extendable<D>, const D: usize> PoseidonSpongeStark<F, D> {
let mut already_absorbed_bytes = 0;
for block in input_blocks.by_ref() {
let row = self.generate_full_input_row(
&op,
op,
already_absorbed_bytes,
sponge_state,
block.try_into().unwrap(),
Expand All @@ -252,7 +238,7 @@ impl<F: RichField + Extendable<D>, const D: usize> PoseidonSpongeStark<F, D> {

rows.push(
self.generate_final_row(
&op,
op,
already_absorbed_bytes,
sponge_state,
input_blocks.remainder(),
Expand Down Expand Up @@ -690,7 +676,7 @@ mod tests {
input,
};
let stark = S::default();
let rows = stark.generate_rows_for_op(op);
let rows = stark.generate_rows_for_op(&op);
assert_eq!(rows.len(), 1);
let last_row: &PoseidonSpongeColumnsView<F> = rows.last().unwrap().borrow();
let output = last_row
Expand Down
72 changes: 40 additions & 32 deletions prover/src/witness/traces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use plonky2::field::polynomial::PolynomialValues;
use plonky2::hash::hash_types::RichField;
use plonky2::timed;
use plonky2::util::timing::TimingTree;
use plonky2_maybe_rayon::rayon;
use std::cmp::max;

use crate::all_stark::{AllStark, MIN_TRACE_LEN, NUM_TABLES};
Expand Down Expand Up @@ -146,43 +147,50 @@ impl<T: Copy> Traces<T> {
arithmetic_ops,
cpu,
logic_ops,
memory_ops,
mut memory_ops,
poseidon_inputs,
poseidon_sponge_ops,
} = self;

let arithmetic_trace = timed!(
timing,
"generate arithmetic trace",
all_stark.arithmetic_stark.generate_trace(arithmetic_ops)
);
let cpu_rows: Vec<_> = cpu.into_iter().map(|x| x.into()).collect();
let cpu_trace = trace_rows_to_poly_values(cpu_rows);
let poseidon_trace = timed!(
timing,
"generate Poseidon trace",
all_stark
.poseidon_stark
.generate_trace(poseidon_inputs, min_rows, timing)
);
let poseidon_sponge_trace = timed!(
timing,
"generate Poseidon sponge trace",
all_stark
.poseidon_sponge_stark
.generate_trace(poseidon_sponge_ops, min_rows, timing)
);
let logic_trace = timed!(
timing,
"generate logic trace",
all_stark
.logic_stark
.generate_trace(logic_ops, min_rows, timing)
);
let memory_trace = timed!(
let mut memory_trace = vec![];
let mut arithmetic_trace = vec![];
let mut cpu_trace = vec![];
let mut poseidon_trace = vec![];
let mut poseidon_sponge_trace = vec![];
let mut logic_trace = vec![];

timed!(
timing,
"generate memory trace",
all_stark.memory_stark.generate_trace(memory_ops, timing)
"convert trace to table parallelly",
rayon::join(
|| rayon::join(
|| memory_trace = all_stark.memory_stark.generate_trace(&mut memory_ops,),
|| arithmetic_trace =
all_stark.arithmetic_stark.generate_trace(&arithmetic_ops),
),
|| {
rayon::join(
|| {
cpu_trace = trace_rows_to_poly_values(
cpu.into_iter().map(|x| x.into()).collect(),
)
},
|| {
poseidon_trace = all_stark
.poseidon_stark
.generate_trace(&poseidon_inputs, min_rows)
},
);
rayon::join(
|| {
poseidon_sponge_trace = all_stark
.poseidon_sponge_stark
.generate_trace(&poseidon_sponge_ops, min_rows)
},
|| logic_trace = all_stark.logic_stark.generate_trace(logic_ops, min_rows),
);
},
)
);

[
Expand Down

0 comments on commit 3167ad8

Please sign in to comment.