Build Partial Merkle Tree as recursive data structure.

This commit is contained in:
Katarzyna Streich 2016-10-17 18:04:32 +01:00
parent cd59dfd2af
commit 6af7573955

View File

@ -1,178 +1,118 @@
package com.r3corda.core.crypto
import com.r3corda.core.transactions.MerkleTree
import com.r3corda.core.transactions.hashConcat
import java.util.*
class MerkleTreeException(val reason: String) : Exception() {
class MerkleTreeException(val reason: String): Exception() {
override fun toString() = "Partial Merkle Tree exception. Reason: $reason"
}
//For convenient binary tree calculations.
fun <T: Number> log2(x: T): Double{
return Math.log(x.toDouble())/Math.log(2.0)
}
/**
* Building and verification of merkle branch. [branchHashes] - minimal set of hashes needed to check given subset of leaves.
* [includeBranch] - path telling us how tree was traversed and which hashes are included in branchHashes.
* [leavesSize] - number of all leaves in the original full Merkle tree.
* If we include l2 in a PMT. includeBranch will be equal to: [], branchHashes will be the hashes of: [] TODO examples
* Building and verification of Partial Merkle Tree.
* Partial Merkle Tree is a minimal tree needed to check that given set of leaves belongs to a full Merkle Tree.
* todo example of partial tree
*/
class PartialMerkleTree(
val branchHashes: List<SecureHash>,
val includeBranch: List<Boolean>,
val treeHeight: Int,
val leavesSize: Int
){
companion object{
private var hashIdx = 0 //Counters used in tree verification.
private var includeIdx = 0
val root: PartialTree
) {
/**
* The structure is a little different than that of Merkle Tree.
* Partial Tree by might not be a full binary tree. Leaves represent either original Merkle tree leaves
* or cut subtree node with stored hash. We differentiate between the leaves that are included in a filtered
* transaction and leaves that just keep hashes needed for calculation. Reason for this approach: during verification
* it's easier to extract hashes used as base for this tree.
*/
sealed class PartialTree() {
class IncludedLeaf(val hash: SecureHash): PartialTree()
class Leaf(val hash: SecureHash): PartialTree()
class Node(val left: PartialTree, val right: PartialTree): PartialTree()
}
companion object {
/**
* Builds new Partial Merkle Tree out of [allLeavesHashes]. [includeLeaves] is a list of Booleans that tells
* which leaves from [allLeavesHashes] to include in a partial tree.
* @param merkleRoot
* @param includeHashes
* @return Partial Merkle tree root.
*/
fun build(includeLeaves: List<Boolean>, allLeavesHashes: List<SecureHash>)
: PartialMerkleTree {
val branchHashes: MutableList<SecureHash> = ArrayList()
val includeBranch: MutableList<Boolean> = ArrayList()
val treeHeight = Math.ceil(log2(allLeavesHashes.size.toDouble())).toInt()
whichNodesInBranch(treeHeight, 0, includeLeaves, allLeavesHashes, includeBranch, branchHashes)
return PartialMerkleTree(branchHashes, includeBranch, treeHeight, allLeavesHashes.size)
fun build(merkleRoot: MerkleTree, includeHashes: List<SecureHash>): PartialMerkleTree {
val usedHashes = ArrayList<SecureHash>()
//Too much included hashes or different ones.
val tree = buildPartialTree(merkleRoot, includeHashes, usedHashes)
if(includeHashes.size != usedHashes.size)
throw MerkleTreeException("Some of the provided hashes are not in the tree.")
return PartialMerkleTree(tree.second)
}
/**
* Recursively build a tree, traversal order - preorder.
* [height] - height of the node in a tree (leaves are at 0 level).
* [position] - position of the node at a given height level (starting from 0).
* [includeBranch] - gives a path of traversal in a tree: false indicates that traversal stopped at given node
* and it's hash is stored.
* For true, algorithm continued to the subtree starting at that node (unless it reached leaves' level).
* Hashes of leaves included in that partial tree are stored - that set is checked later during verification stage.
* @param root Root of full Merkle tree which is a base for a partial one.
* @param includeHashes Hashes of leaves to be included in this partial tree.
* @param usedHashes Hashes actually used to build this partial tree.
* @return Pair, first element indicates if in a subtree there is a leaf that is included in that partial tree.
* Second element refers to that subtree.
*/
private fun whichNodesInBranch(
height: Int,
position: Int,
includeLeaves: List<Boolean>,
allLeavesHashes: List<SecureHash>,
includeBranch: MutableList<Boolean>,
resultHashes: MutableList<SecureHash>) {
val isParent = checkIsParent(includeLeaves, height, position, allLeavesHashes.size)
includeBranch.add(isParent)
if (height == 0 || !isParent) {
//Hash should be stored, don't traverse the subtree starting with that node.
//Or height == 0 and recursion reached leaf level of the tree, hash is stored.
resultHashes.add(treeHash(position, height, allLeavesHashes))
private fun buildPartialTree(
root: MerkleTree,
includeHashes: List<SecureHash>,
usedHashes: MutableList<SecureHash>
): Pair<Boolean, PartialTree> {
if (root is MerkleTree.Leaf) {
if (root.value in includeHashes) {
usedHashes.add(root.value)
return Pair(true, PartialTree.IncludedLeaf(root.value))
} else return Pair(false, PartialTree.Leaf(root.value))
} else if (root is MerkleTree.DuplicatedLeaf) {
//Duplicate leaves should be stored as normal leaves not included ones.
return Pair(false, PartialTree.Leaf(root.value))
} else if (root is MerkleTree.Node) {
val leftNode = buildPartialTree(root.left, includeHashes, usedHashes)
val rightNode = buildPartialTree(root.right, includeHashes, usedHashes)
if (leftNode.first or rightNode.first) {
//This node is on a path to some included leaves. Don't store hash.
val newTree = PartialTree.Node(leftNode.second, rightNode.second)
return Pair(true, newTree)
} else {
//This node has no included leaves below. Cut the tree here and store a hash as a Leaf.
val newTree = PartialTree.Leaf(root.value)
return Pair(false, newTree)
}
} else {
whichNodesInBranch(height - 1, position * 2, includeLeaves, allLeavesHashes, includeBranch, resultHashes)
//If the tree is not full, we don't add the rightmost hash.
if (position * 2 + 1 <= treeWidth(height-1, allLeavesHashes.size)-1) {
whichNodesInBranch(height - 1, position * 2 + 1, includeLeaves, allLeavesHashes, includeBranch, resultHashes)
}
throw MerkleTreeException("Invalid MerkleTree.")
}
}
/**
* Calculation of the node's hash using stack.
* Elements are pushed with an information about at what height they are in the tree.
*/
private fun treeHash(position: Int, height: Int, allLeavesHashes: List<SecureHash>): SecureHash {
var (startIdx, endIdx) = getNodeLeafRange(height, position, allLeavesHashes.size)
val stack = Stack<Pair<Int, SecureHash>>()
if (height == 0) { //Just return leaf's hash.
return allLeavesHashes[position]
}
//Otherwise calculate hash from lower elements.
while (true) {
val size = stack.size
//Two last elements on the stack are of the same height.
//The way we build the stack hashes assures that they are siblings in a tree.
if (size >= 2 && stack[size - 1].first == stack[size - 2].first) {
//Calculate hash of them and and push new node to the stack.
val el1 = stack.pop()
val el2 = stack.pop()
val h = el1.first
val combinedHash = el2.second.hashConcat(el1.second)
if (h + 1 == height) return combinedHash //We reached desired node.
else
stack.push(Pair(h + 1, combinedHash))
} else if (startIdx > endIdx) { //Odd numbers of elements at that level.
stack.push(stack.last()) //Need to duplicate the last element.
} else { //Add a leaf hash to the stack.
stack.push(Pair(0, allLeavesHashes[startIdx]))
startIdx++
}
}
}
//Calculates which leaves belong to the subtree starting from that node.
private fun getNodeLeafRange(height: Int, position: Int, leavesCount: Int): Pair<Int, Int> {
val offset = Math.pow(2.0, height.toDouble()).toInt()
val start = position * offset
val end = Math.min(start + offset - 1, leavesCount-1) //Not full binary tree.
return Pair(start, end)
}
//Checks if a node at given height and position is a parent of some of the leaves that are included in the transaction.
private fun checkIsParent(includeLeaves: List<Boolean>, height: Int, position: Int, leavesCount: Int): Boolean {
val (start, end) = getNodeLeafRange(height, position, leavesCount)
for (el in IntRange(start, end)) {
if (includeLeaves[el]) return true
}
return false
}
//Return tree width at given height.
private fun treeWidth(height: Int, leavesSize: Int): Double{
return Math.ceil(leavesSize/Math.pow(2.0, height.toDouble()))
}
}
/**
* Verification that leavesHashes belong to this tree. It is leaves' ordering insensitive.
* Checks if provided merkleRoot matches the one calculated from this Partial Merkle Tree.
* @param merkleRootHash
* @param hashesToCheck
*/
fun verify(leavesHashes: List<SecureHash>, merkleRoot: SecureHash): Boolean{
if(leavesSize==0) throw MerkleTreeException("PMT with zero leaves.")
includeIdx = 0
hashIdx = 0
val hashesUsed = ArrayList<SecureHash>()
val verifyRoot = verifyTree(treeHeight, 0, hashesUsed)
if(includeIdx < includeBranch.size-1 || hashIdx < branchHashes.size -1)
throw MerkleTreeException("Not all entries form PMT branch used.")
fun verify(merkleRootHash: SecureHash, hashesToCheck: List<SecureHash>): Boolean {
val usedHashes = ArrayList<SecureHash>()
val verifyRoot = verify(root, usedHashes)
//It means that we obtained more/less hashes than needed or different sets of hashes.
//Ordering insensitive.
if(leavesHashes.size != hashesUsed.size || leavesHashes.minus(hashesUsed).isNotEmpty())
if(hashesToCheck.size != usedHashes.size || hashesToCheck.minus(usedHashes).isNotEmpty())
return false
//Correctness of hashes is checked by folding the partial tree and comparing roots.
return (verifyRoot == merkleRoot)
return (verifyRoot == merkleRootHash)
}
//Traverses the tree in the same order as it was built consuming includeBranch and branchHashes.
private fun verifyTree(height: Int, position: Int, hashesUsed: MutableList<SecureHash>): SecureHash {
if(includeIdx >= includeBranch.size)
throw MerkleTreeException("Included nodes list index overflow.")
val isParent = includeBranch[includeIdx]
includeIdx++
if (height == 0 || !isParent) { //Hash included in a branch was reached.
if(hashIdx >branchHashes.size)
throw MerkleTreeException("Branch hashes index overflow.")
val hash = branchHashes[hashIdx]
hashIdx++
//It means that this leaf was included as part of original partial tree. It's hash is stored for later comparision.
if(height == 0 && isParent)
hashesUsed.add(hash)
return hash
} else { //Continue tree verification to left and right nodes and hash them together.
val left: SecureHash = verifyTree(height - 1, position * 2, hashesUsed)
val right: SecureHash = when{
position * 2 + 1 < treeWidth(height-1, leavesSize) -> verifyTree(height - 1, position * 2 + 1, hashesUsed)
else -> left
}
return left.hashConcat(right)
/**
* Recursive calculation of root of this partial tree.
* Modifies usedHashes to later check for inclusion with hashes provided.
*/
private fun verify(node: PartialTree, usedHashes: MutableList<SecureHash>): SecureHash{
if (node is PartialTree.IncludedLeaf) {
usedHashes.add(node.hash)
return node.hash
} else if (node is PartialTree.Leaf ) {
return node.hash
} else if (node is PartialTree.Node){
val leftHash = verify(node.left, usedHashes)
val rightHash = verify(node.right, usedHashes)
return leftHash.hashConcat(rightHash)
} else {
throw MerkleTreeException("Invalid node type.")
}
}
}