From fa1e2f914d338d450f97d8c513136eaeb18d1324 Mon Sep 17 00:00:00 2001 From: 0xaatif <169152398+0xaatif@users.noreply.github.com> Date: Wed, 16 Oct 2024 16:34:36 +0100 Subject: [PATCH] feat: SMT support in `trace_decoder` ignores storage (#693) --- .github/workflows/lint.yml | 12 +- Cargo.lock | 11 + Cargo.toml | 3 - trace_decoder/Cargo.toml | 3 +- trace_decoder/benches/block_processing.rs | 2 + trace_decoder/src/core.rs | 224 +++++++++--- trace_decoder/src/lib.rs | 8 +- trace_decoder/src/observer.rs | 2 +- trace_decoder/src/{typed_mpt.rs => tries.rs} | 345 ++++++++++++++---- trace_decoder/src/type1.rs | 15 +- trace_decoder/src/type2.rs | 294 ++++++++------- trace_decoder/src/wire.rs | 8 +- trace_decoder/tests/consistent-with-header.rs | 2 + trace_decoder/tests/simulate-execution.rs | 16 +- zero/Cargo.toml | 1 + zero/src/bin/rpc.rs | 2 + zero/src/bin/trie_diff.rs | 2 + zero/src/prover.rs | 16 +- 18 files changed, 681 insertions(+), 285 deletions(-) rename trace_decoder/src/{typed_mpt.rs => tries.rs} (58%) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 446d6a314..13089b3d0 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -13,8 +13,6 @@ concurrency: env: CARGO_TERM_COLOR: always - BINSTALL_NO_CONFIRM: true - RUSTDOCFLAGS: "-D warnings" jobs: clippy: @@ -42,7 +40,15 @@ jobs: steps: - uses: actions/checkout@v3 - uses: ./.github/actions/rust - - run: cargo doc --all --no-deps + - run: RUSTDOCFLAGS='-D warnings -A rustdoc::private_intra_doc_links' cargo doc --all --no-deps + # TODO(zero): https://github.com/0xPolygonZero/zk_evm/issues/718 + - run: > + RUSTDOCFLAGS='-D warnings -A rustdoc::private_intra_doc_links' cargo doc --no-deps --document-private-items + --package trace_decoder + --package compat + --package smt_trie + --package zk_evm_proc_macro + --package zk_evm_common cargo-fmt: runs-on: ubuntu-latest timeout-minutes: 5 diff --git a/Cargo.lock b/Cargo.lock index d642c2107..ea4fb8060 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1151,6 +1151,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "build-array" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ef4e2687af237b2646687e19a0643bc369878216122e46c3f1a01c56baa9d5" +dependencies = [ + "arrayvec", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -5126,6 +5135,7 @@ dependencies = [ "assert2", "bitflags 2.6.0", "bitvec", + "build-array", "bytes", "camino", "ciborium", @@ -5800,6 +5810,7 @@ dependencies = [ "anyhow", "async-stream", "axum", + "cfg-if", "clap", "compat", "directories", diff --git a/Cargo.toml b/Cargo.toml index 306530217..e47c71386 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,7 +41,6 @@ axum = "0.7.5" bitflags = "2.5.0" bitvec = "1.0.1" bytes = "1.6.0" -cargo_metadata = "0.18.1" ciborium = "0.2.2" ciborium-io = "0.2.2" clap = { version = "4.5.7", features = ["derive", "env"] } @@ -72,7 +71,6 @@ nunny = "0.2.1" once_cell = "1.19.0" paladin-core = "0.4.3" parking_lot = "0.12.3" -paste = "1.0.15" pest = "2.7.10" pest_derive = "2.7.10" pretty_env_logger = "0.5.0" @@ -94,7 +92,6 @@ syn = "2.0" thiserror = "1.0.61" tiny-keccak = "2.0.2" tokio = { version = "1.38.0", features = ["full"] } -toml = "0.8.14" tower = "0.4" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/trace_decoder/Cargo.toml b/trace_decoder/Cargo.toml index de9b389a2..5bdd24f12 100644 --- a/trace_decoder/Cargo.toml +++ b/trace_decoder/Cargo.toml @@ -15,6 +15,7 @@ alloy-compat = "0.1.0" anyhow.workspace = true bitflags.workspace = true bitvec.workspace = true +build-array = "0.1.2" bytes.workspace = true ciborium.workspace = true ciborium-io.workspace = true @@ -33,6 +34,7 @@ nunny = { workspace = true, features = ["serde"] } plonky2.workspace = true rlp.workspace = true serde.workspace = true +smt_trie.workspace = true stackstack = "0.3.0" strum = { version = "0.26.3", features = ["derive"] } thiserror.workspace = true @@ -52,7 +54,6 @@ libtest-mimic = "0.7.3" plonky2_maybe_rayon.workspace = true serde_json.workspace = true serde_path_to_error.workspace = true -smt_trie.workspace = true zero.workspace = true [features] diff --git a/trace_decoder/benches/block_processing.rs b/trace_decoder/benches/block_processing.rs index adefdae3f..4e3582e98 100644 --- a/trace_decoder/benches/block_processing.rs +++ b/trace_decoder/benches/block_processing.rs @@ -8,6 +8,7 @@ use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use trace_decoder::observer::DummyObserver; use trace_decoder::{BlockTrace, OtherBlockData}; +use zero::prover::WIRE_DISPOSITION; #[derive(Clone, Debug, serde::Deserialize)] pub struct ProverInput { @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { other_data, batch_size, &mut DummyObserver::new(), + WIRE_DISPOSITION, ) .unwrap() }, diff --git a/trace_decoder/src/core.rs b/trace_decoder/src/core.rs index c5aa890da..46495030f 100644 --- a/trace_decoder/src/core.rs +++ b/trace_decoder/src/core.rs @@ -6,6 +6,7 @@ use std::{ use alloy_compat::Compat as _; use anyhow::{anyhow, bail, ensure, Context as _}; +use either::Either; use ethereum_types::{Address, U256}; use evm_arithmetization::{ generation::{mpt::AccountRlp, TrieInputs}, @@ -18,20 +19,35 @@ use mpt_trie::partial_trie::PartialTrie as _; use nunny::NonEmpty; use zk_evm_common::gwei_to_wei; -use crate::observer::Observer; use crate::{ - typed_mpt::{ReceiptTrie, StateMpt, StateTrie, StorageTrie, TransactionTrie, TrieKey}, + observer::{DummyObserver, Observer}, + tries::StateSmt, +}; +use crate::{ + tries::{MptKey, ReceiptTrie, StateMpt, StateTrie, StorageTrie, TransactionTrie}, BlockLevelData, BlockTrace, BlockTraceTriePreImages, CombinedPreImages, ContractCodeUsage, OtherBlockData, SeparateStorageTriesPreImage, SeparateTriePreImage, SeparateTriePreImages, TxnInfo, TxnMeta, TxnTrace, }; +/// Expected trie type when parsing from binary in a [`BlockTrace`]. +/// +/// See [`crate::wire`] and [`CombinedPreImages`] for more. +#[derive(Debug)] +pub enum WireDisposition { + /// MPT + Type1, + /// SMT + Type2, +} + /// TODO(0xaatif): document this after pub fn entrypoint( trace: BlockTrace, other: OtherBlockData, batch_size_hint: usize, observer: &mut impl Observer, + wire_disposition: WireDisposition, ) -> anyhow::Result> { ensure!(batch_size_hint != 0); @@ -45,8 +61,8 @@ pub fn entrypoint( BlockTraceTriePreImages::Separate(_) => FatalMissingCode(true), BlockTraceTriePreImages::Combined(_) => FatalMissingCode(false), }; + let (state, storage, mut code) = start(trie_pre_images, wire_disposition)?; - let (state, storage, mut code) = start(trie_pre_images)?; code.extend(code_db); let OtherBlockData { @@ -66,17 +82,40 @@ pub fn entrypoint( *amt = gwei_to_wei(*amt) } - let batches = middle( - state, - storage, - batch(txn_info, batch_size_hint), - &mut code, - &b_meta, - ger_data, - withdrawals, - fatal_missing_code, - observer, - )?; + let batches = match state { + Either::Left(mpt) => Either::Left( + middle( + mpt, + storage, + batch(txn_info, batch_size_hint), + &mut code, + &b_meta, + ger_data, + withdrawals, + fatal_missing_code, + observer, + )? + .into_iter() + .map(|it| it.map(Either::Left)), + ), + Either::Right(smt) => { + Either::Right( + middle( + smt, + storage, + batch(txn_info, batch_size_hint), + &mut code, + &b_meta, + ger_data, + withdrawals, + fatal_missing_code, + &mut DummyObserver::new(), // TODO(0xaatif) + )? + .into_iter() + .map(|it| it.map(Either::Right)), + ) + } + }; let mut running_gas_used = 0; Ok(batches @@ -107,7 +146,10 @@ pub fn entrypoint( withdrawals, ger_data, tries: TrieInputs { - state_trie: state.into(), + state_trie: match state { + Either::Left(mpt) => mpt.into(), + Either::Right(_) => todo!("evm_arithmetization accepts an SMT"), + }, transactions_trie: transaction.into(), receipts_trie: receipt.into(), storage_tries: storage.into_iter().map(|(k, v)| (k, v.into())).collect(), @@ -131,11 +173,16 @@ pub fn entrypoint( /// [`HashedPartialTrie`](mpt_trie::partial_trie::HashedPartialTrie), /// or a [`wire`](crate::wire)-encoded representation of one. /// -/// Turn either of those into our [`typed_mpt`](crate::typed_mpt) -/// representations. +/// Turn either of those into our [internal representations](crate::tries). +#[allow(clippy::type_complexity)] fn start( pre_images: BlockTraceTriePreImages, -) -> anyhow::Result<(StateMpt, BTreeMap, Hash2Code)> { + wire_disposition: WireDisposition, +) -> anyhow::Result<( + Either, + BTreeMap, + Hash2Code, +)> { Ok(match pre_images { // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/401 // refactor our convoluted input types @@ -146,7 +193,7 @@ fn start( let state = state.items().try_fold( StateMpt::default(), |mut acc, (nibbles, hash_or_val)| { - let path = TrieKey::from_nibbles(nibbles); + let path = MptKey::from_nibbles(nibbles); match hash_or_val { mpt_trie::trie_ops::ValOrHash::Val(bytes) => { #[expect(deprecated)] // this is MPT specific @@ -169,7 +216,7 @@ fn start( .map(|(k, SeparateTriePreImage::Direct(v))| { v.items() .try_fold(StorageTrie::default(), |mut acc, (nibbles, hash_or_val)| { - let path = TrieKey::from_nibbles(nibbles); + let path = MptKey::from_nibbles(nibbles); match hash_or_val { mpt_trie::trie_ops::ValOrHash::Val(value) => { acc.insert(path, value)?; @@ -183,17 +230,35 @@ fn start( .map(|v| (k, v)) }) .collect::>()?; - (state, storage, Hash2Code::new()) + (Either::Left(state), storage, Hash2Code::new()) } BlockTraceTriePreImages::Combined(CombinedPreImages { compact }) => { let instructions = crate::wire::parse(&compact) .context("couldn't parse instructions from binary format")?; - let crate::type1::Frontend { - state, - storage, - code, - } = crate::type1::frontend(instructions)?; - (state, storage, code.into_iter().map(Into::into).collect()) + let (state, storage, code) = match wire_disposition { + WireDisposition::Type1 => { + let crate::type1::Frontend { + state, + storage, + code, + } = crate::type1::frontend(instructions)?; + ( + Either::Left(state), + storage, + Hash2Code::from_iter(code.into_iter().map(NonEmpty::into_vec)), + ) + } + WireDisposition::Type2 => { + let crate::type2::Frontend { trie, code } = + crate::type2::frontend(instructions)?; + ( + Either::Right(trie), + BTreeMap::new(), + Hash2Code::from_iter(code.into_iter().map(NonEmpty::into_vec)), + ) + } + }; + (state, storage, code) } }) } @@ -267,6 +332,29 @@ struct Batch { pub withdrawals: Vec<(Address, U256)>, } +impl Batch { + fn map(self, f: impl FnMut(T) -> U) -> Batch { + let Self { + first_txn_ix, + gas_used, + contract_code, + byte_code, + before, + after, + withdrawals, + } = self; + Batch { + first_txn_ix, + gas_used, + contract_code, + byte_code, + before: before.map(f), + after, + withdrawals, + } + } +} + /// [`evm_arithmetization::generation::TrieInputs`], /// generic over state trie representation. #[derive(Debug)] @@ -277,6 +365,22 @@ pub struct IntraBlockTries { pub receipt: ReceiptTrie, } +impl IntraBlockTries { + fn map(self, mut f: impl FnMut(T) -> U) -> IntraBlockTries { + let Self { + state, + storage, + transaction, + receipt, + } = self; + IntraBlockTries { + state: f(state), + storage, + transaction, + receipt, + } + } +} /// Hacky handling of possibly missing contract bytecode in `Hash2Code` inner /// map. /// Allows incomplete payloads fetched with the zero tracer to skip these @@ -303,12 +407,15 @@ fn middle( fatal_missing_code: FatalMissingCode, // called with the untrimmed tries after each batch observer: &mut impl Observer, -) -> anyhow::Result>> { +) -> anyhow::Result>> +where + StateTrieT::Key: Ord + From
, +{ // Initialise the storage tries. for (haddr, acct) in state_trie.iter() { let storage = storage_tries.entry(haddr).or_insert({ let mut it = StorageTrie::default(); - it.insert_hash(TrieKey::default(), acct.storage_root) + it.insert_hash(MptKey::default(), acct.storage_root) .expect("empty trie insert cannot fail"); it }); @@ -343,8 +450,8 @@ fn middle( // We want to perform mask the TrieInputs above, // but won't know the bounds until after the loop below, // so store that information here. - let mut storage_masks = BTreeMap::<_, BTreeSet>::new(); - let mut state_mask = BTreeSet::new(); + let mut storage_masks = BTreeMap::<_, BTreeSet>::new(); + let mut state_mask = BTreeSet::::new(); if txn_ix == 0 { do_pre_execution( @@ -440,7 +547,7 @@ fn middle( storage_written .keys() .chain(&storage_read) - .map(|it| TrieKey::from_hash(keccak_hash::keccak(it))), + .map(|it| MptKey::from_hash(keccak_hash::keccak(it))), ); if do_writes { @@ -487,7 +594,7 @@ fn middle( }; for (k, v) in storage_written { - let slot = TrieKey::from_hash(keccak_hash::keccak(k)); + let slot = MptKey::from_hash(keccak_hash::keccak(k)); match v.is_zero() { // this is actually a delete true => storage_mask.extend(storage.reporting_remove(slot)?), @@ -500,10 +607,10 @@ fn middle( } state_trie.insert_by_address(addr, acct)?; - state_mask.insert(TrieKey::from_address(addr)); + state_mask.insert(::from(addr)); } else { // Simple state access - state_mask.insert(TrieKey::from_address(addr)); + state_mask.insert(::from(addr)); } if self_destructed { @@ -526,7 +633,7 @@ fn middle( withdrawals: match loop_ix == loop_len { true => { for (addr, amt) in &withdrawals { - state_mask.insert(TrieKey::from_address(*addr)); + state_mask.insert(::from(*addr)); let mut acct = state_trie .get_by_address(*addr) .context(format!("missing address {addr:x} for withdrawal"))?; @@ -584,10 +691,13 @@ fn do_pre_execution( block: &BlockMetadata, ger_data: Option<(H256, H256)>, storage: &mut BTreeMap, - trim_storage: &mut BTreeMap>, - trim_state: &mut BTreeSet, + trim_storage: &mut BTreeMap>, + trim_state: &mut BTreeSet, state_trie: &mut StateTrieT, -) -> anyhow::Result<()> { +) -> anyhow::Result<()> +where + StateTrieT::Key: From
+ Ord, +{ // Ethereum mainnet: EIP-4788 if cfg!(feature = "eth_mainnet") { return do_beacon_hook( @@ -623,10 +733,13 @@ fn do_scalable_hook( block: &BlockMetadata, ger_data: Option<(H256, H256)>, storage: &mut BTreeMap, - trim_storage: &mut BTreeMap>, - trim_state: &mut BTreeSet, + trim_storage: &mut BTreeMap>, + trim_state: &mut BTreeSet, state_trie: &mut StateTrieT, -) -> anyhow::Result<()> { +) -> anyhow::Result<()> +where + StateTrieT::Key: From
+ Ord, +{ use evm_arithmetization::testing_utils::{ ADDRESS_SCALABLE_L2, ADDRESS_SCALABLE_L2_ADDRESS_HASHED, GLOBAL_EXIT_ROOT_ADDRESS, GLOBAL_EXIT_ROOT_ADDRESS_HASHED, GLOBAL_EXIT_ROOT_STORAGE_POS, LAST_BLOCK_STORAGE_POS, @@ -641,7 +754,7 @@ fn do_scalable_hook( .context("missing scalable contract storage trie")?; let scalable_trim = trim_storage.entry(ADDRESS_SCALABLE_L2).or_default(); - let timestamp_slot_key = TrieKey::from_slot_position(U256::from(TIMESTAMP_STORAGE_POS.1)); + let timestamp_slot_key = MptKey::from_slot_position(U256::from(TIMESTAMP_STORAGE_POS.1)); let timestamp = scalable_storage .get(×tamp_slot_key) @@ -655,7 +768,7 @@ fn do_scalable_hook( (U256::from(LAST_BLOCK_STORAGE_POS.1), block.block_number), (U256::from(TIMESTAMP_STORAGE_POS.1), timestamp), ] { - let slot = TrieKey::from_slot_position(ix); + let slot = MptKey::from_slot_position(ix); // These values are never 0. scalable_storage.insert(slot, alloy::rlp::encode(u.compat()))?; @@ -668,12 +781,12 @@ fn do_scalable_hook( let mut arr = [0; 64]; (block.block_number - 1).to_big_endian(&mut arr[0..32]); U256::from(STATE_ROOT_STORAGE_POS.1).to_big_endian(&mut arr[32..64]); - let slot = TrieKey::from_hash(keccak_hash::keccak(arr)); + let slot = MptKey::from_hash(keccak_hash::keccak(arr)); scalable_storage.insert(slot, alloy::rlp::encode(prev_block_root_hash.compat()))?; scalable_trim.insert(slot); - trim_state.insert(TrieKey::from_address(ADDRESS_SCALABLE_L2)); + trim_state.insert(::from(ADDRESS_SCALABLE_L2)); let mut scalable_acct = state_trie .get_by_address(ADDRESS_SCALABLE_L2) .context("missing scalable contract address")?; @@ -694,12 +807,12 @@ fn do_scalable_hook( let mut arr = [0; 64]; arr[0..32].copy_from_slice(&root.0); U256::from(GLOBAL_EXIT_ROOT_STORAGE_POS.1).to_big_endian(&mut arr[32..64]); - let slot = TrieKey::from_hash(keccak_hash::keccak(arr)); + let slot = MptKey::from_hash(keccak_hash::keccak(arr)); ger_storage.insert(slot, alloy::rlp::encode(l1blockhash.compat()))?; ger_trim.insert(slot); - trim_state.insert(TrieKey::from_address(GLOBAL_EXIT_ROOT_ADDRESS)); + trim_state.insert(::from(GLOBAL_EXIT_ROOT_ADDRESS)); let mut ger_acct = state_trie .get_by_address(GLOBAL_EXIT_ROOT_ADDRESS) .context("missing GER contract address")?; @@ -722,11 +835,14 @@ fn do_scalable_hook( fn do_beacon_hook( block_timestamp: U256, storage: &mut BTreeMap, - trim_storage: &mut BTreeMap>, + trim_storage: &mut BTreeMap>, parent_beacon_block_root: H256, - trim_state: &mut BTreeSet, + trim_state: &mut BTreeSet, state_trie: &mut StateTrieT, -) -> anyhow::Result<()> { +) -> anyhow::Result<()> +where + StateTrieT::Key: From
+ Ord, +{ use evm_arithmetization::testing_utils::{ BEACON_ROOTS_CONTRACT_ADDRESS, BEACON_ROOTS_CONTRACT_ADDRESS_HASHED, HISTORY_BUFFER_LENGTH, }; @@ -747,7 +863,7 @@ fn do_beacon_hook( U256::from_big_endian(parent_beacon_block_root.as_bytes()), ), ] { - let slot = TrieKey::from_slot_position(ix); + let slot = MptKey::from_slot_position(ix); beacon_trim.insert(slot); match u.is_zero() { @@ -758,7 +874,7 @@ fn do_beacon_hook( } } } - trim_state.insert(TrieKey::from_address(BEACON_ROOTS_CONTRACT_ADDRESS)); + trim_state.insert(::from(BEACON_ROOTS_CONTRACT_ADDRESS)); let mut beacon_acct = state_trie .get_by_address(BEACON_ROOTS_CONTRACT_ADDRESS) .context("missing beacon contract address")?; @@ -785,7 +901,7 @@ fn map_receipt_bytes(bytes: Vec) -> anyhow::Result> { /// If there are any txns that create contracts, then they will also /// get added here as we process the deltas. struct Hash2Code { - /// Key must always be [`hash`] of value. + /// Key must always be [`hash`](keccak_hash) of value. inner: HashMap>, } diff --git a/trace_decoder/src/lib.rs b/trace_decoder/src/lib.rs index 049472c40..057d11e89 100644 --- a/trace_decoder/src/lib.rs +++ b/trace_decoder/src/lib.rs @@ -56,16 +56,12 @@ mod interface; pub use interface::*; +mod tries; mod type1; -// TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/275 -// add backend/prod support for type 2 -#[cfg(test)] -#[allow(dead_code)] mod type2; -mod typed_mpt; mod wire; -pub use core::entrypoint; +pub use core::{entrypoint, WireDisposition}; mod core; diff --git a/trace_decoder/src/observer.rs b/trace_decoder/src/observer.rs index 320019e55..f9811e87c 100644 --- a/trace_decoder/src/observer.rs +++ b/trace_decoder/src/observer.rs @@ -4,7 +4,7 @@ use std::marker::PhantomData; use ethereum_types::{H256, U256}; use crate::core::IntraBlockTries; -use crate::typed_mpt::{ReceiptTrie, StorageTrie, TransactionTrie}; +use crate::tries::{ReceiptTrie, StorageTrie, TransactionTrie}; /// Observer API for the trace decoder. /// Observer is used to collect various debugging and metadata info diff --git a/trace_decoder/src/typed_mpt.rs b/trace_decoder/src/tries.rs similarity index 58% rename from trace_decoder/src/typed_mpt.rs rename to trace_decoder/src/tries.rs index 8baf3cf29..91add4d98 100644 --- a/trace_decoder/src/typed_mpt.rs +++ b/trace_decoder/src/tries.rs @@ -1,10 +1,12 @@ -//! Principled MPT types used in this library. +//! Principled trie types and abstractions used in this library. use core::fmt; -use std::{collections::BTreeMap, marker::PhantomData}; +use std::{cmp, collections::BTreeMap, marker::PhantomData}; +use anyhow::ensure; +use bitvec::{array::BitArray, slice::BitSlice}; use copyvec::CopyVec; -use ethereum_types::{Address, H256, U256}; +use ethereum_types::{Address, BigEndianHash as _, H256, U256}; use evm_arithmetization::generation::mpt::AccountRlp; use mpt_trie::partial_trie::{HashedPartialTrie, Node, OnOrphanedHashNode, PartialTrie as _}; use u4::{AsNibbles, U4}; @@ -30,27 +32,26 @@ impl TypedMpt { /// Insert a node which represents an out-of-band sub-trie. /// /// See [module documentation](super) for more. - fn insert_hash(&mut self, key: TrieKey, hash: H256) -> anyhow::Result<()> { + fn insert_hash(&mut self, key: MptKey, hash: H256) -> anyhow::Result<()> { self.inner.insert(key.into_nibbles(), hash)?; Ok(()) } - /// Returns an [`Error`] if the `key` crosses into a part of the trie that - /// isn't hydrated. - fn insert(&mut self, key: TrieKey, value: T) -> anyhow::Result> + /// Returns [`Err`] if the `key` crosses into a part of the trie that + /// is hashed out. + fn insert(&mut self, key: MptKey, value: T) -> anyhow::Result<()> where T: rlp::Encodable + rlp::Decodable, { - let prev = self.get(key); self.inner .insert(key.into_nibbles(), rlp::encode(&value).to_vec())?; - Ok(prev) + Ok(()) } /// Note that this returns [`None`] if `key` crosses into a part of the - /// trie that isn't hydrated. + /// trie that is hashed out. /// /// # Panics /// - If [`rlp::decode`]-ing for `T` doesn't round-trip. - fn get(&self, key: TrieKey) -> Option + fn get(&self, key: MptKey) -> Option where T: rlp::Decodable, { @@ -67,12 +68,12 @@ impl TypedMpt { self.inner.hash() } /// Note that this returns owned paths and items. - fn iter(&self) -> impl Iterator + '_ + fn iter(&self) -> impl Iterator + '_ where T: rlp::Decodable, { self.inner.keys().filter_map(|nib| { - let path = TrieKey::from_nibbles(nib); + let path = MptKey::from_nibbles(nib); Some((path, self.get(path)?)) }) } @@ -88,7 +89,7 @@ impl<'a, T> IntoIterator for &'a TypedMpt where T: rlp::Decodable, { - type Item = (TrieKey, T); + type Item = (MptKey, T); type IntoIter = Box + 'a>; fn into_iter(self) -> Self::IntoIter { Box::new(self.iter()) @@ -100,9 +101,9 @@ where /// /// Semantically equivalent to [`mpt_trie::nibbles::Nibbles`]. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] -pub struct TrieKey(CopyVec); +pub struct MptKey(CopyVec); -impl fmt::Display for TrieKey { +impl fmt::Display for MptKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for u in self.0 { f.write_fmt(format_args!("{:x}", u))? @@ -111,9 +112,9 @@ impl fmt::Display for TrieKey { } } -impl TrieKey { +impl MptKey { pub fn new(components: impl IntoIterator) -> anyhow::Result { - Ok(TrieKey(CopyVec::try_from_iter(components)?)) + Ok(MptKey(CopyVec::try_from_iter(components)?)) } pub fn into_hash_left_padded(mut self) -> H256 { for _ in 0..self.0.spare_capacity_mut().len() { @@ -136,7 +137,7 @@ impl TrieKey { } pub fn from_txn_ix(txn_ix: usize) -> Self { - TrieKey::new(AsNibbles(rlp::encode(&txn_ix))).expect( + MptKey::new(AsNibbles(rlp::encode(&txn_ix))).expect( "\ rlp of an usize goes through a u64, which is 8 bytes, which will be 9 bytes RLP'ed. @@ -170,17 +171,111 @@ impl TrieKey { } } +impl From
for MptKey { + fn from(value: Address) -> Self { + Self::from_hash(keccak_hash::keccak(value)) + } +} + #[test] -fn key_into_hash() { - assert_eq!(TrieKey::new([]).unwrap().into_hash(), None); +fn mpt_key_into_hash() { + assert_eq!(MptKey::new([]).unwrap().into_hash(), None); assert_eq!( - TrieKey::new(itertools::repeat_n(u4::u4!(0), 64)) + MptKey::new(itertools::repeat_n(u4::u4!(0), 64)) .unwrap() .into_hash(), Some(H256::zero()) ) } +/// Bounded sequence of bits, +/// used as a key for [`StateSmt`]. +/// +/// Semantically equivalent to [`smt_trie::bits::Bits`]. +#[derive(Clone, Copy)] +pub struct SmtKey { + bits: bitvec::array::BitArray<[u8; 32]>, + len: usize, +} + +impl SmtKey { + fn as_bitslice(&self) -> &BitSlice { + self.bits.as_bitslice().get(..self.len).unwrap() + } +} + +impl fmt::Debug for SmtKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(self.as_bitslice().iter().map(|it| match *it { + true => 1, + false => 0, + })) + .finish() + } +} + +impl fmt::Display for SmtKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for bit in self.as_bitslice() { + f.write_str(match *bit { + true => "1", + false => "0", + })? + } + Ok(()) + } +} + +impl SmtKey { + pub fn new(components: impl IntoIterator) -> anyhow::Result { + let mut bits = bitvec::array::BitArray::default(); + let mut len = 0; + for (ix, bit) in components.into_iter().enumerate() { + ensure!( + bits.get(ix).is_some(), + "expected at most {} components", + bits.len() + ); + bits.set(ix, bit); + len += 1 + } + Ok(Self { bits, len }) + } + + pub fn into_smt_bits(self) -> smt_trie::bits::Bits { + let mut bits = smt_trie::bits::Bits::default(); + for bit in self.as_bitslice() { + bits.push_bit(*bit) + } + bits + } +} + +impl From
for SmtKey { + fn from(addr: Address) -> Self { + let H256(bytes) = keccak_hash::keccak(addr); + Self::new(BitArray::<_>::new(bytes)).expect("SmtKey has room for 256 bits") + } +} + +impl Ord for SmtKey { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.as_bitslice().cmp(other.as_bitslice()) + } +} +impl PartialOrd for SmtKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Eq for SmtKey {} +impl PartialEq for SmtKey { + fn eq(&self, other: &Self) -> bool { + self.as_bitslice().eq(other.as_bitslice()) + } +} + /// Per-block, `txn_ix -> [u8]`. /// /// See @@ -196,10 +291,10 @@ impl TransactionTrie { pub fn insert(&mut self, txn_ix: usize, val: Vec) -> anyhow::Result>> { let prev = self .untyped - .get(TrieKey::from_txn_ix(txn_ix).into_nibbles()) + .get(MptKey::from_txn_ix(txn_ix).into_nibbles()) .map(Vec::from); self.untyped - .insert(TrieKey::from_txn_ix(txn_ix).into_nibbles(), val)?; + .insert(MptKey::from_txn_ix(txn_ix).into_nibbles(), val)?; Ok(prev) } pub fn root(&self) -> H256 { @@ -214,7 +309,7 @@ impl TransactionTrie { &self.untyped, txn_ixs .into_iter() - .map(|it| TrieKey::from_txn_ix(it).into_nibbles()), + .map(|it| MptKey::from_txn_ix(it).into_nibbles()), )?; Ok(()) } @@ -241,10 +336,10 @@ impl ReceiptTrie { pub fn insert(&mut self, txn_ix: usize, val: Vec) -> anyhow::Result>> { let prev = self .untyped - .get(TrieKey::from_txn_ix(txn_ix).into_nibbles()) + .get(MptKey::from_txn_ix(txn_ix).into_nibbles()) .map(Vec::from); self.untyped - .insert(TrieKey::from_txn_ix(txn_ix).into_nibbles(), val)?; + .insert(MptKey::from_txn_ix(txn_ix).into_nibbles(), val)?; Ok(prev) } pub fn root(&self) -> H256 { @@ -259,7 +354,7 @@ impl ReceiptTrie { &self.untyped, txn_ixs .into_iter() - .map(|it| TrieKey::from_txn_ix(it).into_nibbles()), + .map(|it| MptKey::from_txn_ix(it).into_nibbles()), )?; Ok(()) } @@ -271,18 +366,14 @@ impl From for HashedPartialTrie { } } -/// TODO(0xaatif): document this after refactoring is done https://github.com/0xPolygonZero/zk_evm/issues/275 +/// TODO(0xaatif): document this after refactoring is done pub trait StateTrie { - fn insert_by_address( - &mut self, - address: Address, - account: AccountRlp, - ) -> anyhow::Result>; - fn insert_hash_by_key(&mut self, key: TrieKey, hash: H256) -> anyhow::Result<()>; + type Key; + fn insert_by_address(&mut self, address: Address, account: AccountRlp) -> anyhow::Result<()>; fn get_by_address(&self, address: Address) -> Option; - fn reporting_remove(&mut self, address: Address) -> anyhow::Result>; - /// _Hash out_ parts of the trie that aren't in `txn_ixs`. - fn mask(&mut self, address: impl IntoIterator) -> anyhow::Result<()>; + fn reporting_remove(&mut self, address: Address) -> anyhow::Result>; + /// _Hash out_ parts of the trie that aren't in `addresses`. + fn mask(&mut self, address: impl IntoIterator) -> anyhow::Result<()>; fn iter(&self) -> impl Iterator + '_; fn root(&self) -> H256; } @@ -304,13 +395,17 @@ impl StateMpt { }, } } + /// Insert a _hashed out_ part of the trie + pub fn insert_hash_by_key(&mut self, key: MptKey, hash: H256) -> anyhow::Result<()> { + self.typed.insert_hash(key, hash) + } #[deprecated = "prefer operations on `Address` where possible, as SMT support requires this"] pub fn insert_by_hashed_address( &mut self, key: H256, account: AccountRlp, - ) -> anyhow::Result> { - self.typed.insert(TrieKey::from_hash(key), account) + ) -> anyhow::Result<()> { + self.typed.insert(MptKey::from_hash(key), account) } pub fn iter(&self) -> impl Iterator + '_ { self.typed @@ -323,34 +418,27 @@ impl StateMpt { } impl StateTrie for StateMpt { - fn insert_by_address( - &mut self, - address: Address, - account: AccountRlp, - ) -> anyhow::Result> { + type Key = MptKey; + fn insert_by_address(&mut self, address: Address, account: AccountRlp) -> anyhow::Result<()> { #[expect(deprecated)] self.insert_by_hashed_address(keccak_hash::keccak(address), account) } - /// Insert an _hashed out_ part of the trie - fn insert_hash_by_key(&mut self, key: TrieKey, hash: H256) -> anyhow::Result<()> { - self.typed.insert_hash(key, hash) - } fn get_by_address(&self, address: Address) -> Option { self.typed - .get(TrieKey::from_hash(keccak_hash::keccak(address))) + .get(MptKey::from_hash(keccak_hash::keccak(address))) } /// Delete the account at `address`, returning any remaining branch on /// collapse - fn reporting_remove(&mut self, address: Address) -> anyhow::Result> { + fn reporting_remove(&mut self, address: Address) -> anyhow::Result> { delete_node_and_report_remaining_key_if_branch_collapsed( self.typed.as_mut_hashed_partial_trie_unchecked(), - TrieKey::from_address(address), + MptKey::from_address(address), ) } - fn mask(&mut self, addresses: impl IntoIterator) -> anyhow::Result<()> { + fn mask(&mut self, addresses: impl IntoIterator) -> anyhow::Result<()> { let inner = mpt_trie::trie_subsets::create_trie_subset( self.typed.as_hashed_partial_trie(), - addresses.into_iter().map(TrieKey::into_nibbles), + addresses.into_iter().map(MptKey::into_nibbles), )?; self.typed = TypedMpt { inner, @@ -377,31 +465,30 @@ impl From for HashedPartialTrie { } } +// TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/706 +// We're covering for [`smt_trie`] in a couple of ways: +// - insertion operations aren't fallible, they just panic. +// - it documents a requirement that `set_hash` is called before `set`. +#[derive(Clone, Debug)] pub struct StateSmt { address2state: BTreeMap, - hashed_out: BTreeMap, + hashed_out: BTreeMap, } impl StateTrie for StateSmt { - fn insert_by_address( - &mut self, - address: Address, - account: AccountRlp, - ) -> anyhow::Result> { - Ok(self.address2state.insert(address, account)) - } - fn insert_hash_by_key(&mut self, key: TrieKey, hash: H256) -> anyhow::Result<()> { - self.hashed_out.insert(key, hash); + type Key = SmtKey; + fn insert_by_address(&mut self, address: Address, account: AccountRlp) -> anyhow::Result<()> { + self.address2state.insert(address, account); Ok(()) } fn get_by_address(&self, address: Address) -> Option { self.address2state.get(&address).copied() } - fn reporting_remove(&mut self, address: Address) -> anyhow::Result> { + fn reporting_remove(&mut self, address: Address) -> anyhow::Result> { self.address2state.remove(&address); Ok(None) } - fn mask(&mut self, address: impl IntoIterator) -> anyhow::Result<()> { + fn mask(&mut self, address: impl IntoIterator) -> anyhow::Result<()> { let _ = address; Ok(()) } @@ -411,7 +498,111 @@ impl StateTrie for StateSmt { .map(|(addr, acct)| (keccak_hash::keccak(addr), *acct)) } fn root(&self) -> H256 { - todo!() + conv_hash::smt2eth(self.as_smt().root) + } +} + +impl StateSmt { + pub(crate) fn new_unchecked( + address2state: BTreeMap, + hashed_out: BTreeMap, + ) -> Self { + Self { + address2state, + hashed_out, + } + } + + fn as_smt(&self) -> smt_trie::smt::Smt { + let Self { + address2state, + hashed_out, + } = self; + let mut smt = smt_trie::smt::Smt::::default(); + for (k, v) in hashed_out { + smt.set_hash(k.into_smt_bits(), conv_hash::eth2smt(*v)); + } + for ( + addr, + AccountRlp { + nonce, + balance, + storage_root, + code_hash, + }, + ) in address2state + { + smt.set(smt_trie::keys::key_nonce(*addr), *nonce); + smt.set(smt_trie::keys::key_balance(*addr), *balance); + smt.set(smt_trie::keys::key_code(*addr), code_hash.into_uint()); + smt.set( + // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/707 + // combined abstraction for state and storage + smt_trie::keys::key_storage(*addr, U256::zero()), + storage_root.into_uint(), + ); + } + smt + } +} + +mod conv_hash { + //! We [`u64::to_le_bytes`] because: + //! - Reference go code just puns the bytes: + //! - It's better to fix the endianness for correctness. + //! - Most (consumer) CPUs are little-endian. + + use std::array; + + use ethereum_types::H256; + use itertools::Itertools as _; + use plonky2::{ + field::{ + goldilocks_field::GoldilocksField, + types::{Field as _, PrimeField64}, + }, + hash::hash_types::HashOut, + }; + + /// # Panics + /// - On certain inputs if `debug_assertions` are enabled. See + /// [`GoldilocksField::from_canonical_u64`] for more. + pub fn eth2smt(H256(bytes): H256) -> smt_trie::smt::HashOut { + let mut bytes = bytes.into_iter(); + // (no unsafe, no unstable) + let ret = HashOut { + elements: array::from_fn(|_ix| { + let (a, b, c, d, e, f, g, h) = bytes.next_tuple().unwrap(); + GoldilocksField::from_canonical_u64(u64::from_le_bytes([a, b, c, d, e, f, g, h])) + }), + }; + assert_eq!(bytes.len(), 0); + ret + } + pub fn smt2eth(HashOut { elements }: smt_trie::smt::HashOut) -> H256 { + H256( + build_array::ArrayBuilder::from_iter( + elements + .iter() + .map(GoldilocksField::to_canonical_u64) + .flat_map(u64::to_le_bytes), + ) + .build_exact() + .unwrap(), + ) + } + + #[test] + fn test() { + use plonky2::field::types::Field64 as _; + let mut max = std::iter::repeat(GoldilocksField::ORDER - 1).flat_map(u64::to_le_bytes); + for h in [ + H256::zero(), + H256(array::from_fn(|ix| ix as u8)), + H256(array::from_fn(|_| max.next().unwrap())), + ] { + assert_eq!(smt2eth(eth2smt(h)), h); + } } } @@ -428,15 +619,15 @@ impl StorageTrie { untyped: HashedPartialTrie::new_with_strategy(Node::Empty, strategy), } } - pub fn get(&mut self, key: &TrieKey) -> Option<&[u8]> { + pub fn get(&mut self, key: &MptKey) -> Option<&[u8]> { self.untyped.get(key.into_nibbles()) } - pub fn insert(&mut self, key: TrieKey, value: Vec) -> anyhow::Result>> { + pub fn insert(&mut self, key: MptKey, value: Vec) -> anyhow::Result>> { let prev = self.get(&key).map(Vec::from); self.untyped.insert(key.into_nibbles(), value)?; Ok(prev) } - pub fn insert_hash(&mut self, key: TrieKey, hash: H256) -> anyhow::Result<()> { + pub fn insert_hash(&mut self, key: MptKey, hash: H256) -> anyhow::Result<()> { self.untyped.insert(key.into_nibbles(), hash)?; Ok(()) } @@ -446,17 +637,17 @@ impl StorageTrie { pub const fn as_hashed_partial_trie(&self) -> &HashedPartialTrie { &self.untyped } - pub fn reporting_remove(&mut self, key: TrieKey) -> anyhow::Result> { + pub fn reporting_remove(&mut self, key: MptKey) -> anyhow::Result> { delete_node_and_report_remaining_key_if_branch_collapsed(&mut self.untyped, key) } pub fn as_mut_hashed_partial_trie_unchecked(&mut self) -> &mut HashedPartialTrie { &mut self.untyped } /// _Hash out_ the parts of the trie that aren't in `paths`. - pub fn mask(&mut self, paths: impl IntoIterator) -> anyhow::Result<()> { + pub fn mask(&mut self, paths: impl IntoIterator) -> anyhow::Result<()> { self.untyped = mpt_trie::trie_subsets::create_trie_subset( &self.untyped, - paths.into_iter().map(TrieKey::into_nibbles), + paths.into_iter().map(MptKey::into_nibbles), )?; Ok(()) } @@ -473,18 +664,18 @@ impl From for HashedPartialTrie { /// plonky2. Returns the key to the remaining child if a collapse occurred. fn delete_node_and_report_remaining_key_if_branch_collapsed( trie: &mut HashedPartialTrie, - key: TrieKey, -) -> anyhow::Result> { + key: MptKey, +) -> anyhow::Result> { let old_trace = get_trie_trace(trie, key); trie.delete(key.into_nibbles())?; let new_trace = get_trie_trace(trie, key); Ok( node_deletion_resulted_in_a_branch_collapse(&old_trace, &new_trace) - .map(TrieKey::from_nibbles), + .map(MptKey::from_nibbles), ) } -fn get_trie_trace(trie: &HashedPartialTrie, k: TrieKey) -> mpt_trie::utils::TriePath { +fn get_trie_trace(trie: &HashedPartialTrie, k: MptKey) -> mpt_trie::utils::TriePath { mpt_trie::special_query::path_for_query(trie, k.into_nibbles(), true).collect() } diff --git a/trace_decoder/src/type1.rs b/trace_decoder/src/type1.rs index aeea0dbb6..c44beaec7 100644 --- a/trace_decoder/src/type1.rs +++ b/trace_decoder/src/type1.rs @@ -12,7 +12,7 @@ use mpt_trie::partial_trie::OnOrphanedHashNode; use nunny::NonEmpty; use u4::U4; -use crate::typed_mpt::{StateMpt, StateTrie as _, StorageTrie, TrieKey}; +use crate::tries::{MptKey, StateMpt, StorageTrie}; use crate::wire::{Instruction, SmtLeaf}; #[derive(Debug, Clone)] @@ -66,10 +66,10 @@ fn visit( Node::Hash(Hash { raw_hash }) => { frontend .state - .insert_hash_by_key(TrieKey::new(path.iter().copied())?, raw_hash.into())?; + .insert_hash_by_key(MptKey::new(path.iter().copied())?, raw_hash.into())?; } Node::Leaf(Leaf { key, value }) => { - let path = TrieKey::new(path.iter().copied().chain(key))? + let path = MptKey::new(path.iter().copied().chain(key))? .into_hash() .context("invalid depth for leaf of state trie")?; match value { @@ -106,8 +106,7 @@ fn visit( }, }; #[expect(deprecated)] // this is MPT-specific code - let clobbered = frontend.state.insert_by_hashed_address(path, account)?; - ensure!(clobbered.is_none(), "duplicate account"); + frontend.state.insert_by_hashed_address(path, account)?; } } } @@ -141,12 +140,12 @@ fn node2storagetrie(node: Node) -> anyhow::Result { ) -> anyhow::Result<()> { match node { Node::Hash(Hash { raw_hash }) => { - mpt.insert_hash(TrieKey::new(path.iter().copied())?, raw_hash.into())?; + mpt.insert_hash(MptKey::new(path.iter().copied())?, raw_hash.into())?; } Node::Leaf(Leaf { key, value }) => { match value { Either::Left(Value { raw_value }) => mpt.insert( - TrieKey::new(path.iter().copied().chain(key))?, + MptKey::new(path.iter().copied().chain(key))?, rlp::encode(&raw_value.as_slice()).to_vec(), )?, Either::Right(_) => bail!("unexpected account node in storage trie"), @@ -380,6 +379,8 @@ fn finish_stack(v: &mut Vec) -> anyhow::Result { #[test] fn test_tries() { + use crate::tries::StateTrie as _; + for (ix, case) in serde_json::from_str::>(include_str!("cases/zero_jerigon.json")) .unwrap() diff --git a/trace_decoder/src/type2.rs b/trace_decoder/src/type2.rs index dd3e45c4b..a71761533 100644 --- a/trace_decoder/src/type2.rs +++ b/trace_decoder/src/type2.rs @@ -1,35 +1,34 @@ //! Frontend for the witness format emitted by e.g [`0xPolygonHermez/cdk-erigon`](https://github.com/0xPolygonHermez/cdk-erigon/) //! Ethereum node. -use std::{ - collections::{HashMap, HashSet}, - iter, -}; +use std::collections::{BTreeMap, HashSet}; use anyhow::{bail, ensure, Context as _}; -use bitvec::vec::BitVec; -use either::Either; -use ethereum_types::BigEndianHash as _; -use itertools::{EitherOrBoth, Itertools as _}; +use ethereum_types::{Address, U256}; +use evm_arithmetization::generation::mpt::AccountRlp; +use itertools::EitherOrBoth; +use keccak_hash::H256; use nunny::NonEmpty; -use plonky2::field::types::Field; - -use crate::wire::{Instruction, SmtLeaf, SmtLeafType}; +use stackstack::Stack; -type SmtTrie = smt_trie::smt::Smt; +use crate::{ + tries::{SmtKey, StateSmt}, + wire::{Instruction, SmtLeaf, SmtLeafType}, +}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] +/// Combination of all the [`SmtLeaf::node_type`]s +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] pub struct CollatedLeaf { pub balance: Option, pub nonce: Option, - pub code_hash: Option, - pub storage_root: Option, + pub code: Option, + pub code_length: Option, + pub storage: BTreeMap, } pub struct Frontend { - pub trie: SmtTrie, + pub trie: StateSmt, pub code: HashSet>>, - pub collation: HashMap, } /// # Panics @@ -37,18 +36,13 @@ pub struct Frontend { /// NOT call this function on untrusted inputs. pub fn frontend(instructions: impl IntoIterator) -> anyhow::Result { let (node, code) = fold(instructions).context("couldn't fold smt from instructions")?; - let (trie, collation) = - node2trie(node).context("couldn't construct trie and collation from folded node")?; - Ok(Frontend { - trie, - code, - collation, - }) + let trie = node2trie(node).context("couldn't construct trie and collation from folded node")?; + Ok(Frontend { trie, code }) } /// Node in a binary (SMT) tree. /// -/// This is an intermediary type on the way to [`SmtTrie`]. +/// This is an intermediary type on the way to [`StateSmt`]. enum Node { Branch(EitherOrBoth>), Hash([u8; 32]), @@ -105,9 +99,9 @@ fn fold1(instructions: impl IntoIterator) -> anyhow::Result< Ok(Some(match mask { // note that the single-child bits are reversed... - 0b0001 => Node::Branch(EitherOrBoth::Left(get_child()?)), - 0b0010 => Node::Branch(EitherOrBoth::Right(get_child()?)), - 0b0011 => Node::Branch(EitherOrBoth::Both(get_child()?, get_child()?)), + 0b_01 => Node::Branch(EitherOrBoth::Left(get_child()?)), + 0b_10 => Node::Branch(EitherOrBoth::Right(get_child()?)), + 0b_11 => Node::Branch(EitherOrBoth::Both(get_child()?, get_child()?)), other => bail!("unexpected bit pattern in Branch mask: {:#b}", other), })) } @@ -119,113 +113,162 @@ fn fold1(instructions: impl IntoIterator) -> anyhow::Result< } } -/// Pack a [`Node`] tree into an [`SmtTrie`]. -/// Also summarizes the [`Node::Leaf`]s out-of-band. -/// -/// # Panics -/// - if the tree is too deep. -/// - if [`SmtLeaf::address`] or [`SmtLeaf::value`] are the wrong length. -/// - if [`SmtLeafType::Storage`] is the wrong length. -/// - [`SmtTrie`] panics internally. -fn node2trie( - node: Node, -) -> anyhow::Result<(SmtTrie, HashMap)> { - let mut trie = SmtTrie::default(); - - let (hashes, leaves) = - iter_leaves(node).partition_map::, Vec<_>, _, _, _>(|(path, leaf)| match leaf { - Either::Left(it) => Either::Left((path, it)), - Either::Right(it) => Either::Right(it), - }); - - for (path, hash) in hashes { - // needs to be called before `set`, below, "to avoid any issues" according - // to the smt docs. - trie.set_hash( - bits2bits(path), - smt_trie::smt::HashOut { - elements: { - let ethereum_types::U256(arr) = ethereum_types::H256(hash).into_uint(); - arr.map(smt_trie::smt::F::from_canonical_u64) +fn node2trie(node: Node) -> anyhow::Result { + let mut hashes = BTreeMap::new(); + let mut leaves = BTreeMap::new(); + visit(&mut hashes, &mut leaves, Stack::new(), node)?; + Ok(StateSmt::new_unchecked( + leaves + .into_iter() + .map( + |( + addr, + CollatedLeaf { + balance, + nonce, + // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/707 + // we shouldn't ignore these fields + code: _, + code_length: _, + storage: _, + }, + )| { + ( + addr, + AccountRlp { + nonce: nonce.unwrap_or_default(), + balance: balance.unwrap_or_default(), + storage_root: H256::zero(), + code_hash: H256::zero(), + }, + ) }, - }, - ) - } + ) + .collect(), + hashes, + )) +} - let mut collated = HashMap::::new(); - for SmtLeaf { - node_type, - address, - value, - } in leaves - { - let address = ethereum_types::Address::from_slice(&address); - let collated = collated.entry(address).or_default(); - let value = ethereum_types::U256::from_big_endian(&value); - let key = match node_type { - SmtLeafType::Balance => { - ensure!(collated.balance.is_none(), "double write of field"); - collated.balance = Some(value); - smt_trie::keys::key_balance(address) - } - SmtLeafType::Nonce => { - ensure!(collated.nonce.is_none(), "double write of field"); - collated.nonce = Some(value); - smt_trie::keys::key_nonce(address) +fn visit( + hashes: &mut BTreeMap, + leaves: &mut BTreeMap, + path: Stack, + node: Node, +) -> anyhow::Result<()> { + match node { + Node::Branch(children) => { + let (left, right) = children.left_and_right(); + if let Some(left) = left { + visit(hashes, leaves, path.pushed(false), *left)?; } - SmtLeafType::Code => { - ensure!(collated.code_hash.is_none(), "double write of field"); - collated.code_hash = Some({ - let mut it = ethereum_types::H256::zero(); - value.to_big_endian(it.as_bytes_mut()); - it - }); - smt_trie::keys::key_code(address) + if let Some(right) = right { + visit(hashes, leaves, path.pushed(true), *right)?; } - SmtLeafType::Storage(it) => { - ensure!(collated.storage_root.is_none(), "double write of field"); - // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/275 - // do we not do anything with the storage here? - smt_trie::keys::key_storage(address, ethereum_types::U256::from_big_endian(&it)) + } + Node::Hash(hash) => { + hashes.insert(SmtKey::new(path.iter().copied())?, H256(hash)); + } + Node::Leaf(SmtLeaf { + node_type, + address, + value, + }) => { + let address = Address::from_slice(&address); + let collated = leaves.entry(address).or_default(); + let value = U256::from_big_endian(&value); + macro_rules! ensure { + ($expr:expr) => { + ::anyhow::ensure!($expr, "double write of field for address {}", address) + }; } - SmtLeafType::CodeLength => smt_trie::keys::key_code_length(address), - }; - trie.set(key, value) + match node_type { + SmtLeafType::Balance => { + ensure!(collated.balance.is_none()); + collated.balance = Some(value) + } + SmtLeafType::Nonce => { + ensure!(collated.nonce.is_none()); + collated.nonce = Some(value) + } + SmtLeafType::Code => { + ensure!(collated.code.is_none()); + collated.code = Some(value) + } + SmtLeafType::Storage(slot) => { + let clobbered = collated.storage.insert(U256::from_big_endian(&slot), value); + ensure!(clobbered.is_none()) + } + SmtLeafType::CodeLength => { + ensure!(collated.code_length.is_none()); + collated.code_length = Some(value) + } + }; + } } - Ok((trie, collated)) + Ok(()) } -/// # Panics -/// - on overcapacity -fn bits2bits(ours: BitVec) -> smt_trie::bits::Bits { - let mut theirs = smt_trie::bits::Bits::empty(); - for it in ours { - theirs.push_bit(it) - } - theirs -} +#[test] +fn test_tries() { + type Smt = smt_trie::smt::Smt; + use ethereum_types::BigEndianHash as _; + use plonky2::field::types::{Field, Field64 as _}; -/// Simple, inefficient visitor of all leaves of the [`Node`] tree. -#[allow(clippy::type_complexity)] -fn iter_leaves(node: Node) -> Box)>> { - match node { - Node::Hash(it) => Box::new(iter::once((BitVec::new(), Either::Left(it)))), - Node::Branch(it) => { - let (left, right) = it.left_and_right(); - let left = left - .into_iter() - .flat_map(|it| iter_leaves(*it).update(|(path, _)| path.insert(0, false))); - let right = right - .into_iter() - .flat_map(|it| iter_leaves(*it).update(|(path, _)| path.insert(0, true))); - Box::new(left.chain(right)) + // TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/707 + // this logic should live in StateSmt, but we need to + // - abstract over state and storage tries + // - parameterize the account types + // we preserve this code as a tested record of how it _should_ + // be done. + fn node2trie(node: Node) -> anyhow::Result { + let mut trie = Smt::default(); + let mut hashes = BTreeMap::new(); + let mut leaves = BTreeMap::new(); + visit(&mut hashes, &mut leaves, Stack::new(), node)?; + for (key, hash) in hashes { + trie.set_hash( + key.into_smt_bits(), + smt_trie::smt::HashOut { + elements: { + let ethereum_types::U256(arr) = hash.into_uint(); + for u in arr { + ensure!(u < smt_trie::smt::F::ORDER); + } + arr.map(smt_trie::smt::F::from_canonical_u64) + }, + }, + ); } - Node::Leaf(it) => Box::new(iter::once((BitVec::new(), Either::Right(it)))), + for ( + addr, + CollatedLeaf { + balance, + nonce, + code, + code_length, + storage, + }, + ) in leaves + { + use smt_trie::keys::{key_balance, key_code, key_code_length, key_nonce, key_storage}; + + for (value, key_fn) in [ + (balance, key_balance as fn(_) -> _), + (nonce, key_nonce), + (code, key_code), + (code_length, key_code_length), + ] { + if let Some(value) = value { + trie.set(key_fn(addr), value); + } + } + for (slot, value) in storage { + trie.set(key_storage(addr, slot), value); + } + } + Ok(trie) } -} -#[test] -fn test_tries() { for (ix, case) in serde_json::from_str::>(include_str!("cases/hermez_cdk_erigon.json")) .unwrap() @@ -234,10 +277,11 @@ fn test_tries() { { println!("case {}", ix); let instructions = crate::wire::parse(&case.bytes).unwrap(); - let frontend = frontend(instructions).unwrap(); + let (node, _code) = fold(instructions).unwrap(); + let trie = node2trie(node).unwrap(); assert_eq!(case.expected_state_root, { let mut it = [0; 32]; - smt_trie::utils::hashout2u(frontend.trie.root).to_big_endian(&mut it); + smt_trie::utils::hashout2u(trie.root).to_big_endian(&mut it); ethereum_types::H256(it) }); } diff --git a/trace_decoder/src/wire.rs b/trace_decoder/src/wire.rs index 6f56f1e44..63dee6040 100644 --- a/trace_decoder/src/wire.rs +++ b/trace_decoder/src/wire.rs @@ -1,6 +1,6 @@ //! We support two wire formats: -//! - Type 1, based on [this specification](https://gist.github.com/mandrigin/ff7eccf30d0ef9c572bafcb0ab665cff#the-bytes-layout). -//! - Type 2, loosely based on [this specification](https://github.com/0xPolygonHermez/cdk-erigon/blob/d1d6b3c7a4c81c46fd995c1baa5c1f8069ff0348/turbo/trie/WITNESS.md) +//! - Type 1 (AKA MPT), based on [this specification](https://gist.github.com/mandrigin/ff7eccf30d0ef9c572bafcb0ab665cff#the-bytes-layout). +//! - Type 2 (AKA SMT), loosely based on [this specification](https://github.com/0xPolygonHermez/cdk-erigon/blob/d1d6b3c7a4c81c46fd995c1baa5c1f8069ff0348/turbo/trie/WITNESS.md) //! //! Fortunately, their opcodes don't conflict, so we can have a single //! [`Instruction`] type, with shared parsing logic in this module, and bail on @@ -80,6 +80,8 @@ pub enum Instruction { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/705 +// `address` and `value` should be fixed length fields pub struct SmtLeaf { pub node_type: SmtLeafType, pub address: NonEmpty>, @@ -87,6 +89,8 @@ pub struct SmtLeaf { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// TODO(0xaatif): https://github.com/0xPolygonZero/zk_evm/issues/705 +// `Storage` should contain a fixed length field pub enum SmtLeafType { Balance, Nonce, diff --git a/trace_decoder/tests/consistent-with-header.rs b/trace_decoder/tests/consistent-with-header.rs index 609fd57bb..63df41b4f 100644 --- a/trace_decoder/tests/consistent-with-header.rs +++ b/trace_decoder/tests/consistent-with-header.rs @@ -11,6 +11,7 @@ use itertools::Itertools; use libtest_mimic::{Arguments, Trial}; use mpt_trie::partial_trie::PartialTrie as _; use trace_decoder::observer::DummyObserver; +use zero::prover::WIRE_DISPOSITION; fn main() -> anyhow::Result<()> { let mut trials = vec![]; @@ -29,6 +30,7 @@ fn main() -> anyhow::Result<()> { other.clone(), batch_size, &mut DummyObserver::new(), + WIRE_DISPOSITION, ) .map_err(|e| format!("{e:?}"))?; // get the full cause chain check!(gen_inputs.len() >= 2); diff --git a/trace_decoder/tests/simulate-execution.rs b/trace_decoder/tests/simulate-execution.rs index d0476c2b7..fc7136c34 100644 --- a/trace_decoder/tests/simulate-execution.rs +++ b/trace_decoder/tests/simulate-execution.rs @@ -9,6 +9,7 @@ use common::{cases, Case}; use libtest_mimic::{Arguments, Trial}; use plonky2::field::goldilocks_field::GoldilocksField; use trace_decoder::observer::DummyObserver; +use zero::prover::WIRE_DISPOSITION; fn main() -> anyhow::Result<()> { let mut trials = vec![]; @@ -20,11 +21,16 @@ fn main() -> anyhow::Result<()> { other, } in cases()? { - let gen_inputs = - trace_decoder::entrypoint(trace, other, batch_size, &mut DummyObserver::new()) - .context(format!( - "error in `trace_decoder` for {name} at batch size {batch_size}" - ))?; + let gen_inputs = trace_decoder::entrypoint( + trace, + other, + batch_size, + &mut DummyObserver::new(), + WIRE_DISPOSITION, + ) + .context(format!( + "error in `trace_decoder` for {name} at batch size {batch_size}" + ))?; for (ix, gi) in gen_inputs.into_iter().enumerate() { trials.push(Trial::test( format!("{name}@{batch_size}/{ix}"), diff --git a/zero/Cargo.toml b/zero/Cargo.toml index 22c2a8bfb..7cbf2f351 100644 --- a/zero/Cargo.toml +++ b/zero/Cargo.toml @@ -15,6 +15,7 @@ alloy-compat = "0.1.0" anyhow.workspace = true async-stream.workspace = true axum.workspace = true +cfg-if = "1.0.0" clap = { workspace = true, features = ["derive", "string"] } compat.workspace = true directories = "5.0.1" diff --git a/zero/src/bin/rpc.rs b/zero/src/bin/rpc.rs index d49cdde5c..164751df2 100644 --- a/zero/src/bin/rpc.rs +++ b/zero/src/bin/rpc.rs @@ -14,6 +14,7 @@ use url::Url; use zero::block_interval::BlockInterval; use zero::block_interval::BlockIntervalStream; use zero::prover::BlockProverInput; +use zero::prover::WIRE_DISPOSITION; use zero::provider::CachedProvider; use zero::rpc; @@ -172,6 +173,7 @@ impl Cli { block_prover_input.other_data, batch_size, &mut DummyObserver::new(), + WIRE_DISPOSITION, )?; if let Some(index) = tx_info.transaction_index { diff --git a/zero/src/bin/trie_diff.rs b/zero/src/bin/trie_diff.rs index 4c00d2ca3..c211cc528 100644 --- a/zero/src/bin/trie_diff.rs +++ b/zero/src/bin/trie_diff.rs @@ -26,6 +26,7 @@ use regex::Regex; use trace_decoder::observer::TriesObserver; use tracing::{error, info}; use zero::ops::register; +use zero::prover::WIRE_DISPOSITION; use zero::prover::{cli::CliProverConfig, BlockProverInput, ProverConfig}; /// This binary is a debugging tool used to compare @@ -97,6 +98,7 @@ async fn main() -> Result<()> { block_prover_input.other_data.clone(), prover_config.batch_size, &mut observer, + WIRE_DISPOSITION, )?; info!( "Number of collected batch tries for block {}: {}", diff --git a/zero/src/prover.rs b/zero/src/prover.rs index 4e221709c..7cc840f02 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -25,7 +25,7 @@ use tokio::io::AsyncWriteExt; use tokio::sync::mpsc::Receiver; use tokio::sync::{oneshot, Semaphore}; use trace_decoder::observer::DummyObserver; -use trace_decoder::{BlockTrace, OtherBlockData}; +use trace_decoder::{BlockTrace, OtherBlockData, WireDisposition}; use tracing::{error, info}; use crate::fs::generate_block_proof_file_name; @@ -55,6 +55,18 @@ pub struct ProofRuntime { // batches as soon as they are generated. static PARALLEL_BLOCK_PROVING_PERMIT_POOL: Semaphore = Semaphore::const_new(0); +pub const WIRE_DISPOSITION: WireDisposition = { + cfg_if::cfg_if! { + if #[cfg(feature = "eth_mainnet")] { + WireDisposition::Type1 + } else if #[cfg(feature = "cdk_erigon")] { + WireDisposition::Type2 + } else { + compile_error!("must select a feature"); + } + } +}; + #[derive(Debug, Clone)] pub struct ProverConfig { pub batch_size: usize, @@ -101,6 +113,7 @@ impl BlockProverInput { self.other_data, batch_size, &mut DummyObserver::new(), + WIRE_DISPOSITION, )?; // Create segment proof. @@ -193,6 +206,7 @@ impl BlockProverInput { self.other_data, batch_size, &mut DummyObserver::new(), + WIRE_DISPOSITION, )?; let seg_ops = ops::SegmentProofTestOnly {