Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the number of segments #170

Merged
merged 2 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion emulator/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ pub struct InstrumentedState {
/// writer for stderr
stderr_writer: Box<dyn Write>,

pre_segment_id: u32,
pub pre_segment_id: u32,
pre_pc: u32,
pre_image_id: [u8; 32],
pre_hash_root: [u8; 32],
Expand Down
3 changes: 2 additions & 1 deletion emulator/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn split_prog_into_segs(
seg_path: &str,
block_path: &str,
seg_size: usize,
) -> (usize, Box<State>) {
) -> (usize, usize, Box<State>) {
let mut instrumented_state = InstrumentedState::new(state, block_path.to_string());
std::fs::create_dir_all(seg_path).unwrap();
let new_writer = |_: &str| -> Option<std::fs::File> { None };
Expand All @@ -50,6 +50,7 @@ pub fn split_prog_into_segs(
instrumented_state.dump_memory();
(
instrumented_state.state.total_step as usize,
instrumented_state.pre_segment_id as usize,
instrumented_state.state,
)
}
79 changes: 26 additions & 53 deletions prover/examples/zkmips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,9 @@ fn split_segments() {
let _ = split_prog_into_segs(state, &seg_path, &block_path, seg_size);
}

fn prove_single_seg_common(
seg_file: &str,
basedir: &str,
block: &str,
file: &str,
seg_size: usize,
) {
fn prove_single_seg_common(seg_file: &str, basedir: &str, block: &str, file: &str) {
let seg_reader = BufReader::new(File::open(seg_file).unwrap());
let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size);
let kernel = segment_kernel(basedir, block, file, seg_reader);

const D: usize = 2;
type C = PoseidonGoldilocksConfig;
Expand Down Expand Up @@ -82,7 +76,6 @@ fn prove_multi_seg_common(
basedir: &str,
block: &str,
file: &str,
seg_size: usize,
seg_file_number: usize,
seg_start_id: usize,
) -> anyhow::Result<()> {
Expand All @@ -107,7 +100,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, seg_start_id);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (mut agg_proof, mut updated_agg_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
Expand All @@ -123,7 +116,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input = segment_kernel(basedir, block, file, seg_reader);
timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
Expand Down Expand Up @@ -158,7 +151,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1));
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (root_proof_first, first_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
Expand All @@ -169,7 +162,7 @@ fn prove_multi_seg_common(
let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader, seg_size);
let input = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
Expand Down Expand Up @@ -270,17 +263,18 @@ fn prove_sha2_rust() {
log::info!("private input value: {:X?}", private_input);
state.add_input_stream(&private_input);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<[u8; 32]>();
log::info!("public value: {:X?}", value);
log::info!("public value: {} in hex", hex::encode(value));

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}

fn prove_sha2_go() {
Expand Down Expand Up @@ -312,17 +306,17 @@ fn prove_sha2_go() {
);
log::info!("public input: {:X?}", data);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<Data>();
log::info!("public value: {:X?}", value);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}

prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
}

fn prove_revm() {
Expand All @@ -340,18 +334,13 @@ fn prove_revm() {
// load input
state.add_input_stream(&data);

let (total_steps, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}
let (_total_steps, seg_num, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size);

if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "", total_steps)
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
}

Expand Down Expand Up @@ -423,21 +412,16 @@ fn prove_add_example() {
);
log::info!("public input: {:X?}", data);

let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);
let (_total_steps, seg_num, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size);

let value = state.read_public_values::<Data>();
log::info!("public value: {:X?}", value);

let mut seg_num = 1usize;
if seg_size != 0 {
seg_num = (total_steps + seg_size - 1).div_ceil(seg_size);
}

if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
prove_single_seg_common(&seg_file, "", "", "", total_steps)
prove_single_seg_common(&seg_file, "", "", "")
} else {
prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap()
prove_multi_seg_common(&seg_path, "", "", "", seg_num, 0).unwrap()
}
}

Expand All @@ -461,23 +445,12 @@ fn prove_segments() {
let seg_num = seg_num.parse::<_>().unwrap_or(1usize);
let seg_start_id = env::var("SEG_START_ID").unwrap_or("0".to_string());
let seg_start_id = seg_start_id.parse::<_>().unwrap_or(0usize);
let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}"));
let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS);

if seg_num == 1 {
let seg_file = format!("{seg_dir}/{}", seg_start_id);
prove_single_seg_common(&seg_file, &basedir, &block, &file, seg_size)
prove_single_seg_common(&seg_file, &basedir, &block, &file)
} else {
prove_multi_seg_common(
&seg_dir,
&basedir,
&block,
&file,
seg_size,
seg_num,
seg_start_id,
)
.unwrap()
prove_multi_seg_common(&seg_dir, &basedir, &block, &file, seg_num, seg_start_id).unwrap()
}
}

Expand Down
16 changes: 1 addition & 15 deletions prover/src/cpu/kernel/assembler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,19 @@ pub struct Kernel {
// should be preprocessed after loading code
pub(crate) global_labels: HashMap<String, usize>,
pub blockpath: String,
pub steps: usize,
}

pub const MAX_MEM: u32 = 0x80000000;

pub fn segment_kernel<T: Read>(
basedir: &str,
block: &str,
file: &str,
seg_reader: T,
steps: usize,
) -> Kernel {
pub fn segment_kernel<T: Read>(basedir: &str, block: &str, file: &str, seg_reader: T) -> Kernel {
let p: Program = Program::load_segment(seg_reader).unwrap();
let blockpath = get_block_path(basedir, block, file);

let mut final_step = steps;
if p.step != 0 {
assert!(p.step <= steps);
final_step = p.step;
}

Kernel {
program: p,
ordered_labels: vec![],
global_labels: HashMap::new(),
blockpath,
steps: final_step,
}
}

Expand Down
2 changes: 1 addition & 1 deletion prover/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn generate_traces<F: RichField + Extendable<D>, const D: usize>(
// Decode the trace record
// 1. Decode instruction and fill in cpu columns
// 2. Decode memory and fill in memory columns
let mut state = GenerationState::<F>::new(kernel.steps, kernel).unwrap();
let mut state = GenerationState::<F>::new(kernel.program.step, kernel).unwrap();
generate_bootstrap_kernel::<F>(&mut state, kernel);

timed!(timing, "simulate CPU", simulate_cpu(&mut state, kernel)?);
Expand Down
Loading