Skip to content

Commit

Permalink
chore: Update README and camelCase all variables (#10)
Browse files Browse the repository at this point in the history
* Update README and camelCase all variables

* chore: minor nits
  • Loading branch information
kamuikatsurgi authored Oct 4, 2024
1 parent 1cf989b commit 7b03986
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 89 deletions.
18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
# Merkle Multiproof Inputs Generation For Forge
> [!WARNING]
> Note that this library has not had a security review yet, is not gas-efficient, and should only be used for testing.
This is a simple library to generate Merkle Multiproof inputs for OpenZeppelin's `MerkleProof` library. It is written in Solidity and can be used in Forge framework. The library has been tested to work with at most 10k of arbitrary leaves and arbitrary indices of any size.
# Merkle MultiProof and SingleProof Inputs Generation For Forge

Note that this library is not gas-efficient and should be used for testing purposes only.
This simple library generates Merkle MultiProof and SingleProof inputs for OpenZeppelin's [`MerkleProof (v5.0.0)`](https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/utils/cryptography/MerkleProof.sol) library. It is written in Solidity and can be used in the Forge framework. The library has been tested to work with at most 10k arbitrary leaves and arbitrary indices of any size.

## Usage

See `MerkleGen.t.sol` for a sample. First, deploy `MerkleGen` contract. Then, prepare the leaves and indices. Finally, call `gen()` to generate the necessary inputs.
See [`MerkleGen.t.sol`](./test/MerkleGen.t.sol) for example:
- Import the `MerkleGen` library.
- Prepare the leaves and indices.
- Call `generateMultiproof` to generate the MultiProof or `generateSingleProof` to generate the SingleProof inputs.

A mock prover contract is provided in `Prover.sol` which forwards all inputs to OpenZeppelin's `MerkleProof.multiProofVerifyCalldata()`.
A wrapper Prover library is provided in `Prover.sol`, which forwards all the inputs to OpenZeppelin's `MerkleProof` library.

## How to run tests

The following command will run a normal test and a (pretty long) fuzzing test of 100k runs. To modify the number of runs, change the parameter in `foundry.toml`.
The following command will run a standard test and a (pretty long) fuzzing test of 100k runs. To modify the number of runs, change the `runs` parameter in `foundry.toml`.

```
forge test
```

## License

MIT
MIT
160 changes: 78 additions & 82 deletions src/MerkleGen.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,71 +18,71 @@ library MerkleGen {
* @notice Generates a Merkle MultiProof for the selected leaves.
* @dev Constructs the necessary proof components and verifies the Merkle root.
* @dev The computed root must match the actual root of the Merkle tree.
* @param hashed_leaves The array of hashed leaves in the Merkle tree.
* @param selected_indexes The indices of the leaves to include in the proof.
* @param hashedLeaves The array of hashed leaves in the Merkle tree.
* @param selectedIndexes The indices of the leaves to include in the proof.
* @return Sibling hashes required for the proof.
* @return Flags indicating the source of each proof hash.
* @return Merkle root of the tree.
*/
function generateMultiproof(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes)
function generateMultiproof(bytes32[] memory hashedLeaves, uint256[] memory selectedIndexes)
public
pure
returns (bytes32[] memory, bool[] memory, bytes32)
{
bytes32[] memory layer = hashed_leaves.copy();
bytes32[] memory layer = hashedLeaves.copy();
// Append with the same leaf if odd number of leaves
if (layer.length % 2 == 1) {
layer = layer.append(layer[layer.length - 1]);
}
// Create a two dimensional array
bytes32[][] memory layers = new bytes32[][](1);
layers[0] = layer;
bytes32[] memory parent_layer;
bytes32[] memory parentLayer;
while (layer.length > 1) {
parent_layer = _computeParentLayer(layer);
layers = layers.append(parent_layer);
layer = parent_layer;
parentLayer = _computeParentLayer(layer);
layers = layers.append(parentLayer);
layer = parentLayer;
}

bytes32[] memory proof_hashes;
bool[] memory proof_source_flags;
uint256[] memory indices = selected_indexes.copy();
bytes32[] memory proofHashes;
bool[] memory proofSourceFlags;
uint256[] memory indices = selectedIndexes.copy();

bytes32[] memory subproof;
bool[] memory source_flags;
bytes32[] memory subProof;
bool[] memory sourceFlags;
for (uint256 i = 0; i < layers.length - 1; i++) {
// Exclude the last layer because it's the root
layer = layers[i];
(indices, subproof, source_flags) = _proveSingleLayer(layer, indices);
proof_hashes = proof_hashes.extend(subproof);
proof_source_flags = proof_source_flags.extend(source_flags);
(indices, subProof, sourceFlags) = _proveSingleLayer(layer, indices);
proofHashes = proofHashes.extend(subProof);
proofSourceFlags = proofSourceFlags.extend(sourceFlags);
}

// Get leaves in hashed_leaves that are in selected_indexes
bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length);
for (uint256 i = 0; i < selected_indexes.length; i++) {
indexed_leaves[i] = hashed_leaves[selected_indexes[i]];
bytes32[] memory indexed_leaves = new bytes32[](selectedIndexes.length);
for (uint256 i = 0; i < selectedIndexes.length; i++) {
indexed_leaves[i] = hashedLeaves[selectedIndexes[i]];
}

bytes32 root = _verifyComputeRoot(indexed_leaves, proof_hashes, proof_source_flags);
bytes32 root = _verifyComputeRoot(indexed_leaves, proofHashes, proofSourceFlags);

// Check if computed root is the same as the root of the tree
require(root == layers[layers.length - 1][0], "Invalid root");

// Convert proof_source_flags to bits and uint256
uint256 proof_flag_bits = 0;
bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length);
for (uint256 i = 0; i < proof_source_flags.length; i++) {
if (proof_source_flags[i] == SOURCE_FROM_HASHES) {
proof_flag_bits_bool[i] = true;
proof_flag_bits = proof_flag_bits | (1 << i);
// Convert proofSourceFlags to bits and uint256
uint256 proofFlagBits = 0;
bool[] memory proofFlagBitsBool = new bool[](proofSourceFlags.length);
for (uint256 i = 0; i < proofSourceFlags.length; i++) {
if (proofSourceFlags[i] == SOURCE_FROM_HASHES) {
proofFlagBitsBool[i] = true;
proofFlagBits = proofFlagBits | (1 << i);
} else {
proof_flag_bits_bool[i] = false;
proof_flag_bits = proof_flag_bits | (0 << i);
proofFlagBitsBool[i] = false;
proofFlagBits = proofFlagBits | (0 << i);
}
}

return (proof_hashes, proof_flag_bits_bool, root);
return (proofHashes, proofFlagBitsBool, root);
}

/**
Expand Down Expand Up @@ -141,13 +141,13 @@ library MerkleGen {
layer = layer.append(layer[layer.length - 1]);
}

bytes32[] memory parent_layer;
bytes32[] memory parentLayer;

for (uint256 i = 0; i < layer.length; i += 2) {
parent_layer = parent_layer.append(_hashLeafPairs(layer[i], layer[i + 1]));
parentLayer = parentLayer.append(_hashLeafPairs(layer[i], layer[i + 1]));
}

return parent_layer;
return parentLayer;
}

/**
Expand Down Expand Up @@ -182,36 +182,36 @@ library MerkleGen {
pure
returns (uint256[] memory, bytes32[] memory, bool[] memory)
{
uint256[] memory auth_indices;
uint256[] memory next_indices;
bool[] memory source_flags;
uint256[] memory authIndices;
uint256[] memory nextIndices;
bool[] memory sourceFlags;
uint256 j = 0;

while (j < indices.length) {
uint256 x = indices[j];
next_indices = next_indices.append(_getParentIndex(x));
nextIndices = nextIndices.append(_getParentIndex(x));

if (((j + 1) < indices.length) && (indices[j + 1] == _getSiblingIndex(x))) {
j += 1;
source_flags = source_flags.append(SOURCE_FROM_HASHES);
sourceFlags = sourceFlags.append(SOURCE_FROM_HASHES);
} else {
auth_indices = auth_indices.append(_getSiblingIndex(x));
source_flags = source_flags.append(SOURCE_FROM_PROOF);
authIndices = authIndices.append(_getSiblingIndex(x));
sourceFlags = sourceFlags.append(SOURCE_FROM_PROOF);
}
j += 1;
}

bytes32[] memory subProof = new bytes32[](auth_indices.length);
for (uint256 i = 0; i < auth_indices.length; i++) {
bytes32[] memory subProof = new bytes32[](authIndices.length);
for (uint256 i = 0; i < authIndices.length; i++) {
// Here, if the index is out of bounds, we use the last element of the layer
if (layer.length - 1 < auth_indices[i]) {
subProof[i] = layer[auth_indices[i] - 1];
if (layer.length - 1 < authIndices[i]) {
subProof[i] = layer[authIndices[i] - 1];
} else {
subProof[i] = layer[auth_indices[i]];
subProof[i] = layer[authIndices[i]];
}
}

return (next_indices, subProof, source_flags);
return (nextIndices, subProof, sourceFlags);
}

/**
Expand All @@ -236,65 +236,65 @@ library MerkleGen {
* @dev The total number of hashes must equal the number of source flags plus one.
* @dev The number of proof hashes must match the number of `SOURCE_FROM_PROOF` flags.
* @param leaves Selected leaves to be included in the proof.
* @param proof_hashes Sibling hashes extracted from the proof.
* @param proof_source_flags Flags indicating the source of each proof hash.
* @param proofHashes Sibling hashes extracted from the proof.
* @param proofSourceFlags Flags indicating the source of each proof hash.
* @return Computed Merkle root.
*/
function _verifyComputeRoot(
bytes32[] memory leaves,
bytes32[] memory proof_hashes,
bool[] memory proof_source_flags
) internal pure returns (bytes32) {
uint256 total_hashes = leaves.length + proof_hashes.length - 1;
require(total_hashes == proof_source_flags.length, "MerkleGen: Invalid total hashes.");
function _verifyComputeRoot(bytes32[] memory leaves, bytes32[] memory proofHashes, bool[] memory proofSourceFlags)
internal
pure
returns (bytes32)
{
uint256 totalHashes = leaves.length + proofHashes.length - 1;
require(totalHashes == proofSourceFlags.length, "MerkleGen: Invalid total hashes.");
require(
_helperCount(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length,
_helperCount(proofSourceFlags, SOURCE_FROM_PROOF) == proofHashes.length,
"MerkleGen: Invalid number of proof hashes."
);

bytes32[] memory hashes = new bytes32[](total_hashes);
bytes32[] memory hashes = new bytes32[](totalHashes);
// Fill hashes with leaves[0]
for (uint256 i = 0; i < total_hashes; i++) {
for (uint256 i = 0; i < totalHashes; i++) {
hashes[i] = leaves[0];
}
// Variables
uint256 leaf_pos = 0;
uint256 hash_pos = 0;
uint256 proof_pos = 0;
uint256 leafPos = 0;
uint256 hashPos = 0;
uint256 proofPos = 0;

for (uint256 i = 0; i < total_hashes; i++) {
for (uint256 i = 0; i < totalHashes; i++) {
bytes32 a;
bytes32 b;

// Select a
if (proof_source_flags[i] == SOURCE_FROM_HASHES) {
if (leaf_pos < leaves.length) {
a = leaves[leaf_pos];
leaf_pos += 1;
if (proofSourceFlags[i] == SOURCE_FROM_HASHES) {
if (leafPos < leaves.length) {
a = leaves[leafPos];
leafPos += 1;
} else {
a = hashes[hash_pos];
hash_pos += 1;
a = hashes[hashPos];
hashPos += 1;
}
} else if (proof_source_flags[i] == SOURCE_FROM_PROOF) {
a = proof_hashes[proof_pos];
proof_pos += 1;
} else if (proofSourceFlags[i] == SOURCE_FROM_PROOF) {
a = proofHashes[proofPos];
proofPos += 1;
}

// Select b
if (leaf_pos < leaves.length) {
b = leaves[leaf_pos];
leaf_pos += 1;
if (leafPos < leaves.length) {
b = leaves[leafPos];
leafPos += 1;
} else {
b = hashes[hash_pos];
hash_pos += 1;
b = hashes[hashPos];
hashPos += 1;
}

// Compute hash
hashes[i] = _hashLeafPairs(a, b);
}

if (total_hashes > 0) {
return hashes[total_hashes - 1];
if (totalHashes > 0) {
return hashes[totalHashes - 1];
} else {
return leaves[0];
}
Expand Down Expand Up @@ -347,8 +347,6 @@ library MerkleGen {
* @return The root hash of the Merkle tree.
*/
function _getRoot(bytes32[] memory leaves) internal pure returns (bytes32) {
require(leaves.length > 1, "MerkleGen: Data should be greater than 1.");

bytes32[] memory tree = _buildTree(leaves);

return tree[0];
Expand All @@ -362,8 +360,6 @@ library MerkleGen {
* @return An array of sibling hashes forming the Merkle proof for the leaf at the specified index.
*/
function _getProof(bytes32[] memory leaves, uint256 index) internal pure returns (bytes32[] memory) {
require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1.");

bytes32[] memory tree = _buildTree(leaves);

uint256 proofLength = _log2CeilBitMagic(leaves.length);
Expand Down

0 comments on commit 7b03986

Please sign in to comment.