diff --git a/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt b/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt index 5a0a13cf69..6d813f5418 100644 --- a/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt +++ b/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt @@ -12,13 +12,9 @@ fun <T: Number> log2(x: T): Double{ return Math.log(x.toDouble())/Math.log(2.0) } -/* -Include branch: - * false - hash stored, no hashes below stored - * true - not stored, some hashes below stored -At leaves level, hashes of not included transaction's blocks are stored. -Tree traversal: preorder. -*/ +/** + * TODO description + */ class PartialMerkleTree( val branchHashes: List<SecureHash>, val includeBranch: List<Boolean>, @@ -26,9 +22,13 @@ class PartialMerkleTree( val leavesSize: Int ){ companion object{ - protected var hashIdx = 0 - protected var includeIdx = 0 + private var hashIdx = 0 + private var includeIdx = 0 + /** + * Builds new Partial Merkle Tree out of [allLeavesHashes]. [includeLeaves] is a list of Booleans that tells + * which leaves from [allLeavesHashes] to include in a partial tree. + */ fun build(includeLeaves: List<Boolean>, allLeavesHashes: List<SecureHash>) : PartialMerkleTree { val branchHashes: MutableList<SecureHash> = ArrayList() @@ -38,9 +38,16 @@ class PartialMerkleTree( return PartialMerkleTree(branchHashes, includeBranch, treeHeight, allLeavesHashes.size) } - //height - height of the node in the tree (leaves are 0) - //position - position of the node at a given height level (starting from 0) - fun whichNodesInBranch( + /** + * Recursively build a tree, traversal order - preorder. + * [height] - height of the node in a tree (leaves are at 0 level). + * [position] - position of the node at a given height level (starting from 0). + * [includeBranch] - gives a path of traversal in a tree: false indicates that traversal stopped at given node + * and it's hash is stored. + * For true, algorithm continued to the subtree starting at that node (unless it reached leaves' level). + * Hashes of leaves included in that partial tree are stored - that set is checked later durign verification stage. + */ + private fun whichNodesInBranch( height: Int, position: Int, includeLeaves: List<Boolean>, @@ -53,7 +60,7 @@ class PartialMerkleTree( if (height == 0 || !isParent) { //Hash should be stored, don't traverse the subtree starting with that node. //Or height == 0 and recursion reached leaf level of the tree, hash is stored. - resultHashes.add(treeHash(position, height, allLeavesHashes)) //resultHashes[height].add(treeHash) + resultHashes.add(treeHash(position, height, allLeavesHashes)) } else { whichNodesInBranch(height - 1, position * 2, includeLeaves, allLeavesHashes, includeBranch, resultHashes) //If the tree is not full, we don't add the rightmost hash. @@ -63,19 +70,21 @@ class PartialMerkleTree( } } - /* Calculation of the node's hash using stack. - Pushes to the stack elements with an information about on what height they are in the tree. + /** + * Calculation of the node's hash using stack. + * Elements are pushed with an information about at what height they are in the tree. */ - fun treeHash(position: Int, height: Int, allLeavesHashes: List<SecureHash>): SecureHash { + private fun treeHash(position: Int, height: Int, allLeavesHashes: List<SecureHash>): SecureHash { var (startIdx, endIdx) = getNodeLeafRange(height, position, allLeavesHashes.size) val stack = Stack<Pair<Int, SecureHash>>() - if (height <= 0) { //Just return leaf's hash. todo if height < 0 + if (height == 0) { //Just return leaf's hash. return allLeavesHashes[position] } - //otherwise calculate + //Otherwise calculate hash from lower elements. while (true) { val size = stack.size - //Two last elements on the stack are of the same height + //Two last elements on the stack are of the same height. + //The way we build the stack hashes assures that they are siblings in a tree. if (size >= 2 && stack[size - 1].first == stack[size - 2].first) { //Calculate hash of them and and push new node to the stack. val el1 = stack.pop() @@ -85,9 +94,9 @@ class PartialMerkleTree( if (h + 1 == height) return combinedHash //We reached desired node. else stack.push(Pair(h + 1, combinedHash)) - } else if (startIdx > endIdx) { //Odd numbers of elements at that level - stack.push(stack.last()) //Need to duplicate the last element. todo check - } else { //Add a leaf hash to the stack + } else if (startIdx > endIdx) { //Odd numbers of elements at that level. + stack.push(stack.last()) //Need to duplicate the last element. + } else { //Add a leaf hash to the stack. stack.push(Pair(0, allLeavesHashes[startIdx])) startIdx++ } @@ -95,17 +104,15 @@ class PartialMerkleTree( } //Calculates which leaves belong to the subtree starting from that node. - //todo - out of tree width - //OK - protected fun getNodeLeafRange(height: Int, position: Int, leavesCount: Int): Pair<Int, Int> { + private fun getNodeLeafRange(height: Int, position: Int, leavesCount: Int): Pair<Int, Int> { val offset = Math.pow(2.0, height.toDouble()).toInt() val start = position * offset - val end = Math.min(start + offset - 1, leavesCount-1) //Not full binary trees + val end = Math.min(start + offset - 1, leavesCount-1) //Not full binary tree. return Pair(start, end) } //Checks if a node at given height and position is a parent of some of the leaves that are included in the transaction. - protected fun checkIsParent(includeLeaves: List<Boolean>, height: Int, position: Int, leavesCount: Int): Boolean { + private fun checkIsParent(includeLeaves: List<Boolean>, height: Int, position: Int, leavesCount: Int): Boolean { val (start, end) = getNodeLeafRange(height, position, leavesCount) for (el in IntRange(start, end)) { if (includeLeaves[el]) return true @@ -113,25 +120,32 @@ class PartialMerkleTree( return false } - //OK - protected fun treeWidth(height: Int, leavesSize: Int): Double{ //return tree width at given height + //Return tree width at given height. + private fun treeWidth(height: Int, leavesSize: Int): Double{ return Math.ceil(leavesSize/Math.pow(2.0, height.toDouble())) } - } + /** + * Verification that leavesHashes belong to this tree. It is leaves' ordering insensitive. + * Checks if provided merkleRoot matches the one calculated from this Partial Merkle Tree. + */ fun verify(leavesHashes: List<SecureHash>, merkleRoot: SecureHash): Boolean{ - includeIdx = 0 //todo check that + if(leavesSize==0) throw MerkleTreeException("PMT with zero leaves.") + includeIdx = 0 hashIdx = 0 val hashesUsed = ArrayList<SecureHash>() val verifyRoot = verifyTree(treeHeight, 0, hashesUsed) - //It means that we obtained more/less hashes than needed. Or different sets of hashes. + if(includeIdx < includeBranch.size-1 || hashIdx < branchHashes.size -1) + throw MerkleTreeException("Not all entries form PMT branch used.") + //It means that we obtained more/less hashes than needed or different sets of hashes. //Ordering insensitive. if(leavesHashes.size != hashesUsed.size || leavesHashes.minus(hashesUsed).isNotEmpty()) return false - return (verifyRoot == merkleRoot) //Correctness of hashes is checked by folding the tree. + return (verifyRoot == merkleRoot) //Correctness of hashes is checked by folding the partial tree. } + //Traverses the tree in the same order as it was build consuming includeBranch and branchHashes. private fun verifyTree(height: Int, position: Int, hashesUsed: MutableList<SecureHash>): SecureHash { if(includeIdx >= includeBranch.size) throw MerkleTreeException("Included nodes list index overflow.") @@ -143,12 +157,12 @@ class PartialMerkleTree( val hash = branchHashes[hashIdx] hashIdx++ if(height == 0 && isParent) - hashesUsed.add(hash) //todo or hash into a tree + hashesUsed.add(hash) return hash } else { val left: SecureHash = verifyTree(height - 1, position * 2, hashesUsed) val right: SecureHash = when{ - position * 2 + 1 < treeWidth(height, leavesSize)-1 -> verifyTree(height - 1, position * 2 + 1, hashesUsed) + position * 2 + 1 < treeWidth(height-1, leavesSize) -> verifyTree(height - 1, position * 2 + 1, hashesUsed) else -> left } return left.hashConcat(right) diff --git a/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt b/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt index b089b4085a..ae8c062a5c 100644 --- a/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt +++ b/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt @@ -1,19 +1,20 @@ package com.r3corda.core.transactions import com.r3corda.core.contracts.Command +import com.r3corda.core.crypto.MerkleTreeException import com.r3corda.core.crypto.PartialMerkleTree import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.sha256 import com.r3corda.core.serialization.serialize import java.util.* -/* Creation and verification of a Merkle Tree for a Wire Transaction -* Tree should be the same no matter the ordering of outputs, inputs, attachments and commands. */ - -/* Transaction is split into following blocks: -inputs, outputs, commands, attachments' refs -If a row in a tree has odd number of elements - the final hash is hashed with itself. -*/ +/** + * Creation and verification of a Merkle Tree for a Wire Transaction. + * + * Tree should be the same no matter the ordering of outputs, inputs, attachments and commands. + * Transaction is split into following blocks: inputs, outputs, commands, attachments' refs. + * If a row in a tree has an odd number of elements - the final hash is hashed with itself. + */ fun SecureHash.hashConcat(other: SecureHash) = (this.bits + other.bits).sha256() @@ -50,8 +51,10 @@ class MerkleTransaction( return blocks } - /* Start building a Merkle tree from the transaction. - Calls helper tailrecursive function with an accumulator and initial hashedBlocks */ + /** + * Start building a Merkle tree from the transaction. + * Calls helper tailrecursive function with an accumulator and initial hashedBlocks. + */ fun buildMerkleTree(wtx: WireTransaction): MutableList<SecureHash>{ val blocks = getTransactionBlocks(wtx) val hashedBlocks: MutableList<SecureHash> = ArrayList() @@ -73,7 +76,7 @@ class MerkleTransaction( var i = 0 while(i < lastHashList.size){ val left = lastHashList[i] - //If there is an odd number of elements, the last element is hashed with itself + //If there is an odd number of elements, the last element is hashed with itself. val right = lastHashList[Math.min(i+1, lastHashList.size - 1)] val combined = left.hashConcat(right) resultHashes.add(combined)