Skip to content

Commit

Permalink
use smt batch insert method to accelerate BuildBlockInfoTree process (e…
Browse files Browse the repository at this point in the history
…rigontech#744)

Co-authored-by: Valentin Staykov <[email protected]>
  • Loading branch information
louisliu2048 and V-Staykov authored Jul 10, 2024
1 parent 20a9440 commit 3321a6b
Show file tree
Hide file tree
Showing 3 changed files with 291 additions and 3 deletions.
190 changes: 187 additions & 3 deletions smt/pkg/blockinfo/block_info.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package blockinfo

import (
"context"
"fmt"
"math/big"

Expand Down Expand Up @@ -47,6 +48,8 @@ func BuildBlockInfoTree(
)
var err error
var logIndex int64 = 0
var keys []*utils.NodeKey
var vals []*utils.NodeValue8
for i, txInfo := range *transactionInfos {
receipt := txInfo.Receipt
t := txInfo.Tx
Expand All @@ -67,19 +70,28 @@ func BuildBlockInfoTree(

log.Trace("info-tree-tx", "block", blockNumber, "idx", i, "hash", l2TxHash.String())

_, err = infoTree.SetBlockTx(&l2TxHash, i, receipt, logIndex, receipt.CumulativeGasUsed, txInfo.EffectiveGasPrice)
genKeys, genVals, err := infoTree.GenerateBlockTxKeysVals(&l2TxHash, i, receipt, logIndex, receipt.CumulativeGasUsed, txInfo.EffectiveGasPrice)
if err != nil {
return nil, err
}
keys = append(keys, genKeys...)
vals = append(vals, genVals...)

logIndex += int64(len(receipt.Logs))
}

root, err := infoTree.SetBlockGasUsed(blockGasUsed)
key, val, err := generateBlockGasUsed(blockGasUsed)
if err != nil {
return nil, err
}
keys = append(keys, key)
vals = append(vals, val)

rootHash := common.BigToHash(root)
root, err := infoTree.smt.InsertBatch(context.Background(), "", keys, vals, nil, nil)
if err != nil {
return nil, err
}
rootHash := common.BigToHash(root.NewRootScalar.ToBigInt())

log.Trace("info-tree-root", "block", blockNumber, "root", rootHash.String())

Expand Down Expand Up @@ -381,3 +393,175 @@ func setL1BlockHash(smt *smt.SMT, blockHash *common.Hash) (*big.Int, error) {

return resp.NewRootScalar.ToBigInt(), nil
}

func bigInt2NodeVal8(val *big.Int) (*utils.NodeValue8, error) {
x := utils.ScalarToArrayBig(val)
v, err := utils.NodeValue8FromBigIntArray(x)
if err != nil {
return nil, err
}

return v, nil
}

func generateL2TxHash(txIndex *big.Int, l2TxHash *big.Int) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyTxHash(txIndex)
if err != nil {
return nil, nil, err
}
val, err := bigInt2NodeVal8(l2TxHash)
if err != nil {
return nil, nil, err
}

return &key, val, nil
}

func generateTxStatus(txIndex *big.Int, status *big.Int) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyTxStatus(txIndex)
if err != nil {
return nil, nil, err
}
val, err := bigInt2NodeVal8(status)
if err != nil {
return nil, nil, err
}

return &key, val, nil
}

func generateCumulativeGasUsed(txIndex, cumulativeGasUsed *big.Int) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyCumulativeGasUsed(txIndex)
if err != nil {
return nil, nil, err
}
val, err := bigInt2NodeVal8(cumulativeGasUsed)
if err != nil {
return nil, nil, err
}
return &key, val, nil
}

func generateTxLog(txIndex *big.Int, logIndex *big.Int, log *big.Int) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyTxLogs(txIndex, logIndex)
if err != nil {
return nil, nil, err
}
val, err := bigInt2NodeVal8(log)
if err != nil {
return nil, nil, err
}

return &key, val, nil
}

func generateTxEffectivePercentage(txIndex, effectivePercentage *big.Int) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyEffectivePercentage(txIndex)
if err != nil {
return nil, nil, err
}
val, err := bigInt2NodeVal8(effectivePercentage)
if err != nil {
return nil, nil, err
}

return &key, val, nil
}

func generateBlockGasUsed(gasUsed uint64) (*utils.NodeKey, *utils.NodeValue8, error) {
key, err := KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamGasUsed))
if err != nil {
return nil, nil, err
}
gasUsedBig := big.NewInt(0).SetUint64(gasUsed)
val, err := bigInt2NodeVal8(gasUsedBig)
if err != nil {
return nil, nil, err
}

return &key, val, nil
}

func (b *BlockInfoTree) GenerateBlockTxKeysVals(
l2TxHash *common.Hash,
txIndex int,
receipt *ethTypes.Receipt,
logIndex int64,
cumulativeGasUsed uint64,
effectivePercentage uint8,
) ([]*utils.NodeKey, []*utils.NodeValue8, error) {
var keys []*utils.NodeKey
var vals []*utils.NodeValue8
txIndexBig := big.NewInt(int64(txIndex))

key, val, err := generateL2TxHash(txIndexBig, l2TxHash.Big())
if err != nil {
return nil, nil, err
}
keys = append(keys, key)
vals = append(vals, val)

bigStatus := big.NewInt(0).SetUint64(receipt.Status)
key, val, err = generateTxStatus(txIndexBig, bigStatus)
if err != nil {
return nil, nil, err
}
keys = append(keys, key)
vals = append(vals, val)

bigCumulativeGasUsed := big.NewInt(0).SetUint64(cumulativeGasUsed)
key, val, err = generateCumulativeGasUsed(txIndexBig, bigCumulativeGasUsed)
if err != nil {
return nil, nil, err
}
keys = append(keys, key)
vals = append(vals, val)

log.Trace("info-tree-tx-inner",
"tx-index", txIndex,
"log-index", logIndex,
"cumulativeGasUsed", cumulativeGasUsed,
"effective-percentage", effectivePercentage,
"receipt-status", receipt.Status,
)

// now encode the logs
for _, rLog := range receipt.Logs {
reducedTopics := ""
for _, topic := range rLog.Topics {
reducedTopics += fmt.Sprintf("%x", topic)
}

logToEncode := fmt.Sprintf("0x%x%s", rLog.Data, reducedTopics)

hash, err := utils.HashContractBytecode(logToEncode)
if err != nil {
return nil, nil, err
}

logEncodedBig := utils.ConvertHexToBigInt(hash)
key, val, err = generateTxLog(txIndexBig, big.NewInt(logIndex), logEncodedBig)
keys = append(keys, key)
vals = append(vals, val)

log.Trace("info-tree-tx-receipt-log",
"topics", reducedTopics,
"to-encode", logToEncode,
"log-index", logIndex,
)

// increment log index
logIndex += 1
}

// setTxEffectivePercentage
bigEffectivePercentage := big.NewInt(0).SetUint64(uint64(effectivePercentage))
key, val, err = generateTxEffectivePercentage(txIndexBig, bigEffectivePercentage)
if err != nil {
return nil, nil, err
}
keys = append(keys, key)
vals = append(vals, val)

return keys, vals, nil
}
24 changes: 24 additions & 0 deletions smt/pkg/blockinfo/block_info_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package blockinfo

import (
"context"
"math/big"
"testing"

Expand Down Expand Up @@ -210,6 +211,29 @@ func TestSetBlockTx(t *testing.T) {
}
rootHex := common.BigToHash(root).Hex()

infoTree2 := NewBlockInfoTree()
keys, vals, err := infoTree2.GenerateBlockTxKeysVals(
&test.l2TxHash,
test.txIndex,
&test.receipt,
test.logIndex,
test.cumulativeGasUsed,
test.effectivePercentage,
)
if err != nil {
t.Fatal(err)
}

root2, err2 := infoTree2.smt.InsertBatch(context.Background(), "", keys, vals, nil, nil)
if err2 != nil {
t.Fatal(err2)
}

rootHex2 := common.BigToHash(root2.NewRootScalar.ToBigInt()).Hex()
if rootHex != rootHex2 {
t.Fatalf("generate different root, raw method root is %s, new method root %s", rootHex, rootHex2)
}

if rootHex != test.finalBlockInfoRoot {
t.Fatalf("expected root %s, got %s", test.finalBlockInfoRoot, rootHex)
}
Expand Down
80 changes: 80 additions & 0 deletions smt/pkg/smt/smt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,86 @@ func TestBatchSimpleInsert(t *testing.T) {
assertSmtDbStructure(t, smtBatch, false)
}

func incrementalInsert(tree *smt.SMT, key, val []*big.Int) {
for i := range key {
k := utils.ScalarToNodeKey(key[i])
tree.InsertKA(k, val[i])
}
}

func batchInsert(tree *smt.SMT, key, val []*big.Int) {
keyPointers := []*utils.NodeKey{}
valuePointers := []*utils.NodeValue8{}

for i := range key {
k := utils.ScalarToNodeKey(key[i])
vArray := utils.ScalarToArrayBig(val[i])
v, _ := utils.NodeValue8FromBigIntArray(vArray)

keyPointers = append(keyPointers, &k)
valuePointers = append(valuePointers, v)
}
tree.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil)
}

func BenchmarkIncrementalInsert(b *testing.B) {
keys := []*big.Int{}
vals := []*big.Int{}
for i := 0; i < 1000; i++ {
rand.Seed(time.Now().UnixNano())
keys = append(keys, big.NewInt(int64(rand.Intn(10000))))

rand.Seed(time.Now().UnixNano())
vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
smtIncremental := smt.NewSMT(nil)
incrementalInsert(smtIncremental, keys, vals)
}
}

func BenchmarkBatchInsert(b *testing.B) {
keys := []*big.Int{}
vals := []*big.Int{}
for i := 0; i < 1000; i++ {
rand.Seed(time.Now().UnixNano())
keys = append(keys, big.NewInt(int64(rand.Intn(10000))))

rand.Seed(time.Now().UnixNano())
vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
smtBatch := smt.NewSMT(nil)
batchInsert(smtBatch, keys, vals)
}
}

func TestBatchSimpleInsert2(t *testing.T) {
keys := []*big.Int{}
vals := []*big.Int{}
for i := 0; i < 1000; i++ {
rand.Seed(time.Now().UnixNano())
keys = append(keys, big.NewInt(int64(rand.Intn(10000))))

rand.Seed(time.Now().UnixNano())
vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
}

smtIncremental := smt.NewSMT(nil)
incrementalInsert(smtIncremental, keys, vals)

smtBatch := smt.NewSMT(nil)
batchInsert(smtBatch, keys, vals)

smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
}

func TestBatchWitness(t *testing.T) {
keys := []utils.NodeKey{
utils.NodeKey{17822804428864912231, 4683868963463720294, 2947512351908939790, 2330225637707749973},
Expand Down

0 comments on commit 3321a6b

Please sign in to comment.