Skip to content

Commit

Permalink
refactor: trait StateTrie (#542)
Browse files Browse the repository at this point in the history
* mark: 0xaatif/typed-smt2

* refactor: remove dead code

* refactor: StateTrie::reporting_remove

* refactor: StateTrie::trim_to

* refactor: typed_mpt::Error -> anyhow::Error

* refactor: StateTrie -> StateMpt

* wip

* refactor: StateMpt::iter -> (H256, AccountRlp)

* wip

* chore: remove investigate

* chore: sort and prune deps
  • Loading branch information
0xaatif authored Aug 28, 2024
1 parent 85e72f2 commit 00c759a
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 201 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions trace_decoder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,17 @@ strum = { version = "0.26.3", features = ["derive"] }
thiserror = { workspace = true }
u4 = { workspace = true }
winnow = { workspace = true }
zk_evm_common = {workspace = true}
zk_evm_common = { workspace = true }

[dev-dependencies]
alloy = { workspace = true }
criterion = { workspace = true }
plonky2_maybe_rayon = { workspace = true }
pretty_env_logger = { workspace = true }
serde_json = { workspace = true }
prover = { workspace = true }
serde_path_to_error = { workspace = true }
plonky2_maybe_rayon = { workspace = true }
alloy = { workspace = true }
rstest = "0.21.0"

serde_json = { workspace = true }
serde_path_to_error = { workspace = true }

[[bench]]
name = "block_processing"
Expand Down
128 changes: 51 additions & 77 deletions trace_decoder/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use mpt_trie::{
nibbles::Nibbles,
partial_trie::{HashedPartialTrie, PartialTrie as _},
special_query::path_for_query,
trie_ops::TrieOpError,
utils::{IntoTrieKey as _, TriePath},
};

Expand All @@ -25,14 +26,14 @@ use crate::{
NodesUsedByTxn, ProcessedBlockTrace, ProcessedTxnInfo, StateWrite, TxnMetaState,
},
typed_mpt::{ReceiptTrie, StateTrie, StorageTrie, TransactionTrie, TrieKey},
OtherBlockData, PartialTriePreImages,
OtherBlockData, PartialTriePreImages, TryIntoExt as TryIntoBounds,
};

/// The current state of all tries as we process txn deltas. These are mutated
/// after every txn we process in the trace.
#[derive(Clone, Debug, Default)]
struct PartialTrieState {
state: StateTrie,
struct PartialTrieState<StateTrieT> {
state: StateTrieT,
storage: HashMap<H256, StorageTrie>,
txn: TransactionTrie,
receipt: ReceiptTrie,
Expand Down Expand Up @@ -113,7 +114,7 @@ pub fn into_txn_proof_gen_ir(
/// need to update the storage of the beacon block root contract.
// See <https://eips.ethereum.org/EIPS/eip-4788>.
fn update_beacon_block_root_contract_storage(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
delta_out: &mut TrieDeltaApplicationOutput,
nodes_used: &mut NodesUsedByTxn,
block_data: &BlockMetadata,
Expand Down Expand Up @@ -207,7 +208,7 @@ fn update_beacon_block_root_contract_storage(
}

fn update_txn_and_receipt_tries(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
meta: &TxnMetaState,
txn_idx: usize,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -246,20 +247,19 @@ fn init_any_needed_empty_storage_tries<'a>(
}

fn create_minimal_partial_tries_needed_by_txn(
curr_block_tries: &PartialTrieState,
curr_block_tries: &PartialTrieState<impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>>,
nodes_used_by_txn: &NodesUsedByTxn,
txn_range: Range<usize>,
delta_application_out: TrieDeltaApplicationOutput,
) -> anyhow::Result<TrieInputs> {
let state_trie = create_minimal_state_partial_trie(
&curr_block_tries.state,
nodes_used_by_txn.state_accesses.iter().map(hash),
delta_application_out
.additional_state_trie_paths_to_not_hash
.into_iter(),
)?
.as_hashed_partial_trie()
.clone();
let mut state_trie = curr_block_tries.state.clone();
state_trie.trim_to(
nodes_used_by_txn
.state_accesses
.iter()
.map(|it| TrieKey::from_address(*it))
.chain(delta_application_out.additional_state_trie_paths_to_not_hash),
)?;

let txn_keys = txn_range.map(TrieKey::from_txn_ix);

Expand All @@ -282,15 +282,15 @@ fn create_minimal_partial_tries_needed_by_txn(
)?;

Ok(TrieInputs {
state_trie,
state_trie: state_trie.try_into()?,
transactions_trie,
receipts_trie,
storage_tries,
})
}

fn apply_deltas_to_trie_state(
trie_state: &mut PartialTrieState,
trie_state: &mut PartialTrieState<impl StateTrie>,
deltas: &NodesUsedByTxn,
meta: &[TxnMetaState],
) -> anyhow::Result<TrieDeltaApplicationOutput> {
Expand Down Expand Up @@ -360,12 +360,7 @@ fn apply_deltas_to_trie_state(

if !receipt.status {
// The transaction failed, hence any created account should be removed.
if let Some(remaining_account_key) =
delete_node_and_report_remaining_key_if_branch_collapsed(
trie_state.state.as_mut_hashed_partial_trie_unchecked(),
&TrieKey::from_hash(hash(addr)),
)?
{
if let Some(remaining_account_key) = trie_state.state.reporting_remove(*addr)? {
out.additional_state_trie_paths_to_not_hash
.push(remaining_account_key);
trie_state.storage.remove(&hash(addr));
Expand All @@ -379,12 +374,7 @@ fn apply_deltas_to_trie_state(
for addr in deltas.self_destructed_accounts.iter() {
trie_state.storage.remove(&hash(addr));

if let Some(remaining_account_key) =
delete_node_and_report_remaining_key_if_branch_collapsed(
trie_state.state.as_mut_hashed_partial_trie_unchecked(),
&TrieKey::from_hash(hash(addr)),
)?
{
if let Some(remaining_account_key) = trie_state.state.reporting_remove(*addr)? {
out.additional_state_trie_paths_to_not_hash
.push(remaining_account_key);
}
Expand All @@ -400,13 +390,14 @@ fn get_trie_trace(trie: &HashedPartialTrie, k: &Nibbles) -> TriePath {
/// If a branch collapse occurred after a delete, then we must ensure that
/// the other single child that remains also is not hashed when passed into
/// plonky2. Returns the key to the remaining child if a collapse occurred.
fn delete_node_and_report_remaining_key_if_branch_collapsed(
pub fn delete_node_and_report_remaining_key_if_branch_collapsed(
trie: &mut HashedPartialTrie,
delete_k: &TrieKey,
) -> anyhow::Result<Option<TrieKey>> {
let old_trace = get_trie_trace(trie, &delete_k.into_nibbles());
trie.delete(delete_k.into_nibbles())?;
let new_trace = get_trie_trace(trie, &delete_k.into_nibbles());
key: &TrieKey,
) -> Result<Option<TrieKey>, TrieOpError> {
let key = key.into_nibbles();
let old_trace = get_trie_trace(trie, &key);
trie.delete(key)?;
let new_trace = get_trie_trace(trie, &key);
Ok(
node_deletion_resulted_in_a_branch_collapse(&old_trace, &new_trace)
.map(TrieKey::from_nibbles),
Expand Down Expand Up @@ -441,7 +432,9 @@ fn node_deletion_resulted_in_a_branch_collapse(
/// The withdrawals are always in the final ir payload.
fn add_withdrawals_to_txns(
txn_ir: &mut [GenerationInputs],
final_trie_state: &mut PartialTrieState,
final_trie_state: &mut PartialTrieState<
impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>,
>,
mut withdrawals: Vec<(Address, U256)>,
) -> anyhow::Result<()> {
// Scale withdrawals amounts.
Expand All @@ -460,25 +453,22 @@ fn add_withdrawals_to_txns(
.expect("We cannot have an empty list of payloads.");

if last_inputs.signed_txns.is_empty() {
// This is a dummy payload, hence it does not contain yet
// state accesses to the withdrawal addresses.
let withdrawal_addrs = withdrawals_with_hashed_addrs_iter().map(|(_, h_addr, _)| h_addr);

let additional_paths = if last_inputs.txn_number_before == 0.into() {
// We need to include the beacon roots contract as this payload is at the
// start of the block execution.
vec![TrieKey::from_hash(BEACON_ROOTS_CONTRACT_ADDRESS_HASHED)]
} else {
vec![]
};

last_inputs.tries.state_trie = create_minimal_state_partial_trie(
&final_trie_state.state,
withdrawal_addrs,
additional_paths,
)?
.as_hashed_partial_trie()
.clone();
let mut state_trie = final_trie_state.state.clone();
state_trie.trim_to(
// This is a dummy payload, hence it does not contain yet
// state accesses to the withdrawal addresses.
withdrawals
.iter()
.map(|(addr, _)| *addr)
.chain(match last_inputs.txn_number_before == 0.into() {
// We need to include the beacon roots contract as this payload is at the
// start of the block execution.
true => Some(BEACON_ROOTS_CONTRACT_ADDRESS),
false => None,
})
.map(TrieKey::from_address),
)?;
last_inputs.tries.state_trie = state_trie.try_into()?;
}

update_trie_state_from_withdrawals(
Expand All @@ -487,7 +477,7 @@ fn add_withdrawals_to_txns(
)?;

last_inputs.withdrawals = withdrawals;
last_inputs.trie_roots_after.state_root = final_trie_state.state.root();
last_inputs.trie_roots_after.state_root = final_trie_state.state.clone().try_into()?.hash();

Ok(())
}
Expand All @@ -496,7 +486,7 @@ fn add_withdrawals_to_txns(
/// our local trie state.
fn update_trie_state_from_withdrawals<'a>(
withdrawals: impl IntoIterator<Item = (Address, H256, U256)> + 'a,
state: &mut StateTrie,
state: &mut impl StateTrie,
) -> anyhow::Result<()> {
for (addr, h_addr, amt) in withdrawals {
let mut acc_data = state.get_by_address(addr).context(format!(
Expand All @@ -520,7 +510,9 @@ fn process_txn_info(
txn_range: Range<usize>,
is_initial_payload: bool,
txn_info: ProcessedTxnInfo,
curr_block_tries: &mut PartialTrieState,
curr_block_tries: &mut PartialTrieState<
impl StateTrie + Clone + TryIntoBounds<HashedPartialTrie>,
>,
extra_data: &mut ExtraBlockData,
other_data: &OtherBlockData,
) -> anyhow::Result<GenerationInputs> {
Expand Down Expand Up @@ -595,7 +587,7 @@ fn process_txn_info(
* for more info). */
tries,
trie_roots_after: TrieRoots {
state_root: curr_block_tries.state.root(),
state_root: curr_block_tries.state.clone().try_into()?.hash(),
transactions_root: curr_block_tries.txn.root(),
receipts_root: curr_block_tries.receipt.root(),
},
Expand Down Expand Up @@ -645,22 +637,6 @@ impl StateWrite {
}
}

fn create_minimal_state_partial_trie(
state_trie: &StateTrie,
state_accesses: impl IntoIterator<Item = H256>,
additional_state_trie_paths_to_not_hash: impl IntoIterator<Item = TrieKey>,
) -> anyhow::Result<StateTrie> {
create_trie_subset_wrapped(
state_trie.as_hashed_partial_trie(),
state_accesses
.into_iter()
.map(TrieKey::from_hash)
.chain(additional_state_trie_paths_to_not_hash),
TrieType::State,
)
.map(StateTrie::from_hashed_partial_trie_unchecked)
}

// TODO!!!: We really need to be appending the empty storage tries to the base
// trie somewhere else! This is a big hack!
fn create_minimal_storage_partial_tries<'a>(
Expand Down Expand Up @@ -714,11 +690,9 @@ fn eth_to_gwei(eth: U256) -> U256 {
const ZERO_STORAGE_SLOT_VAL_RLPED: [u8; 1] = [128];

/// Aid for error context.
/// Covers all Ethereum trie types (see <https://ethereum.github.io/yellowpaper/paper.pdf> for details).
#[derive(Debug, strum::Display)]
#[allow(missing_docs)]
enum TrieType {
State,
Storage,
Receipt,
Txn,
Expand Down
35 changes: 22 additions & 13 deletions trace_decoder/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ use keccak_hash::H256;
use mpt_trie::partial_trie::{HashedPartialTrie, OnOrphanedHashNode};
use processed_block_trace::ProcessedTxnInfo;
use serde::{Deserialize, Serialize};
use typed_mpt::{StateTrie, StorageTrie, TrieKey};
use typed_mpt::{StateMpt, StateTrie as _, StorageTrie, TrieKey};

/// Core payload needed to generate proof for a block.
/// Additional data retrievable from the blockchain node (using standard ETH RPC
Expand Down Expand Up @@ -311,7 +311,7 @@ pub fn entrypoint(
}) => ProcessedBlockTracePreImages {
tries: PartialTriePreImages {
state: state.items().try_fold(
StateTrie::new(OnOrphanedHashNode::Reject),
StateMpt::new(OnOrphanedHashNode::Reject),
|mut acc, (nibbles, hash_or_val)| {
let path = TrieKey::from_nibbles(nibbles);
match hash_or_val {
Expand Down Expand Up @@ -367,10 +367,7 @@ pub fn entrypoint(
ProcessedBlockTracePreImages {
tries: PartialTriePreImages {
state,
storage: storage
.into_iter()
.map(|(path, trie)| (path.into_hash_left_padded(), trie))
.collect(),
storage: storage.into_iter().collect(),
},
extra_code_hash_mappings: match code.is_empty() {
true => None,
Expand All @@ -384,12 +381,7 @@ pub fn entrypoint(
}
};

let all_accounts_in_pre_images = pre_images
.tries
.state
.iter()
.map(|(addr, data)| (addr.into_hash_left_padded(), data))
.collect::<Vec<_>>();
let all_accounts_in_pre_images = pre_images.tries.state.iter().collect::<Vec<_>>();

// Note we discard any user-provided hashes.
let mut hash2code = code_db
Expand Down Expand Up @@ -449,7 +441,7 @@ pub fn entrypoint(

#[derive(Debug, Default)]
struct PartialTriePreImages {
pub state: StateTrie,
pub state: StateMpt,
pub storage: HashMap<H256, StorageTrie>,
}

Expand Down Expand Up @@ -479,6 +471,23 @@ mod hex {
}
}

trait TryIntoExt<T> {
type Error: std::error::Error + Send + Sync + 'static;
fn try_into(self) -> Result<T, Self::Error>;
}

impl<ThisT, T, E> TryIntoExt<T> for ThisT
where
ThisT: TryInto<T, Error = E>,
E: std::error::Error + Send + Sync + 'static,
{
type Error = ThisT::Error;

fn try_into(self) -> Result<T, Self::Error> {
TryInto::try_into(self)
}
}

#[cfg(test)]
#[derive(serde::Deserialize)]
struct Case {
Expand Down
2 changes: 1 addition & 1 deletion trace_decoder/src/processed_block_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use evm_arithmetization::generation::mpt::{AccountRlp, LegacyReceiptRlp};
use itertools::Itertools;
use zk_evm_common::EMPTY_TRIE_HASH;

use crate::typed_mpt::TrieKey;
use crate::typed_mpt::{StateTrie as _, TrieKey};
use crate::PartialTriePreImages;
use crate::{hash, TxnTrace};
use crate::{ContractCodeUsage, TxnInfo};
Expand Down
Loading

0 comments on commit 00c759a

Please sign in to comment.