diff --git a/README.md b/README.md index 0397aa7..6b0bfb2 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,22 @@ -# Merkle Multiproof Inputs Generation For Forge +> [!WARNING] +> Note that this library has not had a security review yet, is not gas-efficient, and should only be used for testing. -This is a simple library to generate Merkle Multiproof inputs for OpenZeppelin's `MerkleProof` library. It is written in Solidity and can be used in Forge framework. The library has been tested to work with at most 10k of arbitrary leaves and arbitrary indices of any size. +# Merkle MultiProof and SingleProof Inputs Generation For Forge -Note that this library is not gas-efficient and should be used for testing purposes only. +This simple library generates Merkle MultiProof and SingleProof inputs for OpenZeppelin's [`MerkleProof (v5.0.0)`](https://github.com/OpenZeppelin/openzeppelin-contracts/blob/master/contracts/utils/cryptography/MerkleProof.sol) library. It is written in Solidity and can be used in the Forge framework. The library has been tested to work with at most 10k arbitrary leaves and arbitrary indices of any size. ## Usage -See `MerkleGen.t.sol` for a sample. First, deploy `MerkleGen` contract. Then, prepare the leaves and indices. Finally, call `gen()` to generate the necessary inputs. +See [`MerkleGen.t.sol`](./test/MerkleGen.t.sol) for example: +- Import the `MerkleGen` library. +- Prepare the leaves and indices. +- Call `generateMultiproof` to generate the MultiProof or `generateSingleProof` to generate the SingleProof inputs. -A mock prover contract is provided in `Prover.sol` which forwards all inputs to OpenZeppelin's `MerkleProof.multiProofVerifyCalldata()`. +A wrapper Prover library is provided in `Prover.sol`, which forwards all the inputs to OpenZeppelin's `MerkleProof` library. ## How to run tests -The following command will run a normal test and a (pretty long) fuzzing test of 100k runs. To modify the number of runs, change the parameter in `foundry.toml`. +The following command will run a standard test and a (pretty long) fuzzing test of 100k runs. To modify the number of runs, change the `runs` parameter in `foundry.toml`. ``` forge test @@ -20,4 +24,4 @@ forge test ## License -MIT +MIT \ No newline at end of file diff --git a/src/MerkleGen.sol b/src/MerkleGen.sol index c716e18..a5af683 100644 --- a/src/MerkleGen.sol +++ b/src/MerkleGen.sol @@ -18,18 +18,18 @@ library MerkleGen { * @notice Generates a Merkle MultiProof for the selected leaves. * @dev Constructs the necessary proof components and verifies the Merkle root. * @dev The computed root must match the actual root of the Merkle tree. - * @param hashed_leaves The array of hashed leaves in the Merkle tree. - * @param selected_indexes The indices of the leaves to include in the proof. + * @param hashedLeaves The array of hashed leaves in the Merkle tree. + * @param selectedIndexes The indices of the leaves to include in the proof. * @return Sibling hashes required for the proof. * @return Flags indicating the source of each proof hash. * @return Merkle root of the tree. */ - function generateMultiproof(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes) + function generateMultiproof(bytes32[] memory hashedLeaves, uint256[] memory selectedIndexes) public pure returns (bytes32[] memory, bool[] memory, bytes32) { - bytes32[] memory layer = hashed_leaves.copy(); + bytes32[] memory layer = hashedLeaves.copy(); // Append with the same leaf if odd number of leaves if (layer.length % 2 == 1) { layer = layer.append(layer[layer.length - 1]); @@ -37,52 +37,52 @@ library MerkleGen { // Create a two dimensional array bytes32[][] memory layers = new bytes32[][](1); layers[0] = layer; - bytes32[] memory parent_layer; + bytes32[] memory parentLayer; while (layer.length > 1) { - parent_layer = _computeParentLayer(layer); - layers = layers.append(parent_layer); - layer = parent_layer; + parentLayer = _computeParentLayer(layer); + layers = layers.append(parentLayer); + layer = parentLayer; } - bytes32[] memory proof_hashes; - bool[] memory proof_source_flags; - uint256[] memory indices = selected_indexes.copy(); + bytes32[] memory proofHashes; + bool[] memory proofSourceFlags; + uint256[] memory indices = selectedIndexes.copy(); - bytes32[] memory subproof; - bool[] memory source_flags; + bytes32[] memory subProof; + bool[] memory sourceFlags; for (uint256 i = 0; i < layers.length - 1; i++) { // Exclude the last layer because it's the root layer = layers[i]; - (indices, subproof, source_flags) = _proveSingleLayer(layer, indices); - proof_hashes = proof_hashes.extend(subproof); - proof_source_flags = proof_source_flags.extend(source_flags); + (indices, subProof, sourceFlags) = _proveSingleLayer(layer, indices); + proofHashes = proofHashes.extend(subProof); + proofSourceFlags = proofSourceFlags.extend(sourceFlags); } // Get leaves in hashed_leaves that are in selected_indexes - bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length); - for (uint256 i = 0; i < selected_indexes.length; i++) { - indexed_leaves[i] = hashed_leaves[selected_indexes[i]]; + bytes32[] memory indexed_leaves = new bytes32[](selectedIndexes.length); + for (uint256 i = 0; i < selectedIndexes.length; i++) { + indexed_leaves[i] = hashedLeaves[selectedIndexes[i]]; } - bytes32 root = _verifyComputeRoot(indexed_leaves, proof_hashes, proof_source_flags); + bytes32 root = _verifyComputeRoot(indexed_leaves, proofHashes, proofSourceFlags); // Check if computed root is the same as the root of the tree require(root == layers[layers.length - 1][0], "Invalid root"); - // Convert proof_source_flags to bits and uint256 - uint256 proof_flag_bits = 0; - bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length); - for (uint256 i = 0; i < proof_source_flags.length; i++) { - if (proof_source_flags[i] == SOURCE_FROM_HASHES) { - proof_flag_bits_bool[i] = true; - proof_flag_bits = proof_flag_bits | (1 << i); + // Convert proofSourceFlags to bits and uint256 + uint256 proofFlagBits = 0; + bool[] memory proofFlagBitsBool = new bool[](proofSourceFlags.length); + for (uint256 i = 0; i < proofSourceFlags.length; i++) { + if (proofSourceFlags[i] == SOURCE_FROM_HASHES) { + proofFlagBitsBool[i] = true; + proofFlagBits = proofFlagBits | (1 << i); } else { - proof_flag_bits_bool[i] = false; - proof_flag_bits = proof_flag_bits | (0 << i); + proofFlagBitsBool[i] = false; + proofFlagBits = proofFlagBits | (0 << i); } } - return (proof_hashes, proof_flag_bits_bool, root); + return (proofHashes, proofFlagBitsBool, root); } /** @@ -141,13 +141,13 @@ library MerkleGen { layer = layer.append(layer[layer.length - 1]); } - bytes32[] memory parent_layer; + bytes32[] memory parentLayer; for (uint256 i = 0; i < layer.length; i += 2) { - parent_layer = parent_layer.append(_hashLeafPairs(layer[i], layer[i + 1])); + parentLayer = parentLayer.append(_hashLeafPairs(layer[i], layer[i + 1])); } - return parent_layer; + return parentLayer; } /** @@ -182,36 +182,36 @@ library MerkleGen { pure returns (uint256[] memory, bytes32[] memory, bool[] memory) { - uint256[] memory auth_indices; - uint256[] memory next_indices; - bool[] memory source_flags; + uint256[] memory authIndices; + uint256[] memory nextIndices; + bool[] memory sourceFlags; uint256 j = 0; while (j < indices.length) { uint256 x = indices[j]; - next_indices = next_indices.append(_getParentIndex(x)); + nextIndices = nextIndices.append(_getParentIndex(x)); if (((j + 1) < indices.length) && (indices[j + 1] == _getSiblingIndex(x))) { j += 1; - source_flags = source_flags.append(SOURCE_FROM_HASHES); + sourceFlags = sourceFlags.append(SOURCE_FROM_HASHES); } else { - auth_indices = auth_indices.append(_getSiblingIndex(x)); - source_flags = source_flags.append(SOURCE_FROM_PROOF); + authIndices = authIndices.append(_getSiblingIndex(x)); + sourceFlags = sourceFlags.append(SOURCE_FROM_PROOF); } j += 1; } - bytes32[] memory subProof = new bytes32[](auth_indices.length); - for (uint256 i = 0; i < auth_indices.length; i++) { + bytes32[] memory subProof = new bytes32[](authIndices.length); + for (uint256 i = 0; i < authIndices.length; i++) { // Here, if the index is out of bounds, we use the last element of the layer - if (layer.length - 1 < auth_indices[i]) { - subProof[i] = layer[auth_indices[i] - 1]; + if (layer.length - 1 < authIndices[i]) { + subProof[i] = layer[authIndices[i] - 1]; } else { - subProof[i] = layer[auth_indices[i]]; + subProof[i] = layer[authIndices[i]]; } } - return (next_indices, subProof, source_flags); + return (nextIndices, subProof, sourceFlags); } /** @@ -236,65 +236,65 @@ library MerkleGen { * @dev The total number of hashes must equal the number of source flags plus one. * @dev The number of proof hashes must match the number of `SOURCE_FROM_PROOF` flags. * @param leaves Selected leaves to be included in the proof. - * @param proof_hashes Sibling hashes extracted from the proof. - * @param proof_source_flags Flags indicating the source of each proof hash. + * @param proofHashes Sibling hashes extracted from the proof. + * @param proofSourceFlags Flags indicating the source of each proof hash. * @return Computed Merkle root. */ - function _verifyComputeRoot( - bytes32[] memory leaves, - bytes32[] memory proof_hashes, - bool[] memory proof_source_flags - ) internal pure returns (bytes32) { - uint256 total_hashes = leaves.length + proof_hashes.length - 1; - require(total_hashes == proof_source_flags.length, "MerkleGen: Invalid total hashes."); + function _verifyComputeRoot(bytes32[] memory leaves, bytes32[] memory proofHashes, bool[] memory proofSourceFlags) + internal + pure + returns (bytes32) + { + uint256 totalHashes = leaves.length + proofHashes.length - 1; + require(totalHashes == proofSourceFlags.length, "MerkleGen: Invalid total hashes."); require( - _helperCount(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length, + _helperCount(proofSourceFlags, SOURCE_FROM_PROOF) == proofHashes.length, "MerkleGen: Invalid number of proof hashes." ); - bytes32[] memory hashes = new bytes32[](total_hashes); + bytes32[] memory hashes = new bytes32[](totalHashes); // Fill hashes with leaves[0] - for (uint256 i = 0; i < total_hashes; i++) { + for (uint256 i = 0; i < totalHashes; i++) { hashes[i] = leaves[0]; } // Variables - uint256 leaf_pos = 0; - uint256 hash_pos = 0; - uint256 proof_pos = 0; + uint256 leafPos = 0; + uint256 hashPos = 0; + uint256 proofPos = 0; - for (uint256 i = 0; i < total_hashes; i++) { + for (uint256 i = 0; i < totalHashes; i++) { bytes32 a; bytes32 b; // Select a - if (proof_source_flags[i] == SOURCE_FROM_HASHES) { - if (leaf_pos < leaves.length) { - a = leaves[leaf_pos]; - leaf_pos += 1; + if (proofSourceFlags[i] == SOURCE_FROM_HASHES) { + if (leafPos < leaves.length) { + a = leaves[leafPos]; + leafPos += 1; } else { - a = hashes[hash_pos]; - hash_pos += 1; + a = hashes[hashPos]; + hashPos += 1; } - } else if (proof_source_flags[i] == SOURCE_FROM_PROOF) { - a = proof_hashes[proof_pos]; - proof_pos += 1; + } else if (proofSourceFlags[i] == SOURCE_FROM_PROOF) { + a = proofHashes[proofPos]; + proofPos += 1; } // Select b - if (leaf_pos < leaves.length) { - b = leaves[leaf_pos]; - leaf_pos += 1; + if (leafPos < leaves.length) { + b = leaves[leafPos]; + leafPos += 1; } else { - b = hashes[hash_pos]; - hash_pos += 1; + b = hashes[hashPos]; + hashPos += 1; } // Compute hash hashes[i] = _hashLeafPairs(a, b); } - if (total_hashes > 0) { - return hashes[total_hashes - 1]; + if (totalHashes > 0) { + return hashes[totalHashes - 1]; } else { return leaves[0]; } @@ -347,8 +347,6 @@ library MerkleGen { * @return The root hash of the Merkle tree. */ function _getRoot(bytes32[] memory leaves) internal pure returns (bytes32) { - require(leaves.length > 1, "MerkleGen: Data should be greater than 1."); - bytes32[] memory tree = _buildTree(leaves); return tree[0]; @@ -362,8 +360,6 @@ library MerkleGen { * @return An array of sibling hashes forming the Merkle proof for the leaf at the specified index. */ function _getProof(bytes32[] memory leaves, uint256 index) internal pure returns (bytes32[] memory) { - require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1."); - bytes32[] memory tree = _buildTree(leaves); uint256 proofLength = _log2CeilBitMagic(leaves.length);