diff --git a/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt b/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt index 6d813f5418..1cdaa2cfb6 100644 --- a/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt +++ b/core/src/main/kotlin/com/r3corda/core/crypto/PartialMerkleTree.kt @@ -8,12 +8,17 @@ class MerkleTreeException(val reason: String) : Exception() { override fun toString() = "Partial Merkle Tree exception. Reason: $reason" } +//For convenient binary tree calculations. fun log2(x: T): Double{ return Math.log(x.toDouble())/Math.log(2.0) } /** - * TODO description + * Building and verification of merkle branch. [branchHashes] - minimal set of hashes needed to check given subset of leaves. + * [includeBranch] - path telling us how tree was traversed and which hashes are included in branchHashes. + * [leavesSize] - number of all leaves in the original full Merkle tree. + + * If we include l2 in a PMT. includeBranch will be equal to: [], branchHashes will be the hashes of: [] TODO examples */ class PartialMerkleTree( val branchHashes: List, @@ -22,7 +27,7 @@ class PartialMerkleTree( val leavesSize: Int ){ companion object{ - private var hashIdx = 0 + private var hashIdx = 0 //Counters used in tree verification. private var includeIdx = 0 /** @@ -45,7 +50,7 @@ class PartialMerkleTree( * [includeBranch] - gives a path of traversal in a tree: false indicates that traversal stopped at given node * and it's hash is stored. * For true, algorithm continued to the subtree starting at that node (unless it reached leaves' level). - * Hashes of leaves included in that partial tree are stored - that set is checked later durign verification stage. + * Hashes of leaves included in that partial tree are stored - that set is checked later during verification stage. */ private fun whichNodesInBranch( height: Int, @@ -73,7 +78,7 @@ class PartialMerkleTree( /** * Calculation of the node's hash using stack. * Elements are pushed with an information about at what height they are in the tree. - */ + */ private fun treeHash(position: Int, height: Int, allLeavesHashes: List): SecureHash { var (startIdx, endIdx) = getNodeLeafRange(height, position, allLeavesHashes.size) val stack = Stack>() @@ -142,24 +147,26 @@ class PartialMerkleTree( //Ordering insensitive. if(leavesHashes.size != hashesUsed.size || leavesHashes.minus(hashesUsed).isNotEmpty()) return false - return (verifyRoot == merkleRoot) //Correctness of hashes is checked by folding the partial tree. + //Correctness of hashes is checked by folding the partial tree and comparing roots. + return (verifyRoot == merkleRoot) } - //Traverses the tree in the same order as it was build consuming includeBranch and branchHashes. + //Traverses the tree in the same order as it was built consuming includeBranch and branchHashes. private fun verifyTree(height: Int, position: Int, hashesUsed: MutableList): SecureHash { if(includeIdx >= includeBranch.size) throw MerkleTreeException("Included nodes list index overflow.") val isParent = includeBranch[includeIdx] includeIdx++ - if (height == 0 || !isParent) { + if (height == 0 || !isParent) { //Hash included in a branch was reached. if(hashIdx >branchHashes.size) throw MerkleTreeException("Branch hashes index overflow.") val hash = branchHashes[hashIdx] hashIdx++ + //It means that this leaf was included as part of original partial tree. It's hash is stored for later comparision. if(height == 0 && isParent) hashesUsed.add(hash) return hash - } else { + } else { //Continue tree verification to left and right nodes and hash them together. val left: SecureHash = verifyTree(height - 1, position * 2, hashesUsed) val right: SecureHash = when{ position * 2 + 1 < treeWidth(height-1, leavesSize) -> verifyTree(height - 1, position * 2 + 1, hashesUsed) diff --git a/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt b/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt index 98b848fbca..d34a26f19d 100644 --- a/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt +++ b/core/src/main/kotlin/com/r3corda/core/transactions/MerkleTransaction.kt @@ -24,10 +24,13 @@ fun WireTransaction.buildFilteredTransaction(filterFuns: FilterFuns): MerkleTran return MerkleTransaction.buildMerkleTransaction(this, filterFuns) } +/** + * All leaves hashes are needed for calculation of transaction id and partial merkle branches. + */ fun WireTransaction.calculateLeavesHashes(): List{ val resultHashes = ArrayList() val entries = listOf(inputs, outputs, attachments, commands) - entries.forEach { it.sortedBy { x-> x.hashCode() }.mapTo(resultHashes, { xy -> serializedHash(x) }) } + entries.forEach { it.sortedBy { x-> x.hashCode() }.mapTo(resultHashes, { x -> serializedHash(x) }) } return resultHashes } @@ -35,14 +38,14 @@ fun SecureHash.hashConcat(other: SecureHash) = (this.bits + other.bits).sha256() fun serializedHash(x: T) = x.serialize().hash /** - * Builds the tree from bottom to top. Takes as an argument list of leaves hashes. + * Builds the tree from bottom to top. Takes as an argument list of leaves hashes. Used later as WireTransaction id. */ tailrec fun getMerkleRoot( lastHashList: List): SecureHash{ if(lastHashList.size < 1) throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.") if(lastHashList.size == 1) { - return lastHashList[0] + return lastHashList[0] //Root reached. } else{ val newLevelHashes: MutableList = ArrayList() @@ -76,7 +79,11 @@ class FilteredLeaves( } } -open class FilterFuns(val filterInputs: (StateRef) -> Boolean = { false }, +/** + * Holds filter functions on transactions fields. + * Functions are used to build a partial tree only out of some subset of original transaction fields. + */ +class FilterFuns(val filterInputs: (StateRef) -> Boolean = { false }, val filterOutputs: (TransactionState) -> Boolean = { false }, val filterAttachments: (SecureHash) -> Boolean = { false }, val filterCommands: (Command) -> Boolean = { false }){ diff --git a/core/src/main/kotlin/com/r3corda/core/transactions/WireTransaction.kt b/core/src/main/kotlin/com/r3corda/core/transactions/WireTransaction.kt index c40070ea8d..035a784fe7 100644 --- a/core/src/main/kotlin/com/r3corda/core/transactions/WireTransaction.kt +++ b/core/src/main/kotlin/com/r3corda/core/transactions/WireTransaction.kt @@ -41,9 +41,9 @@ class WireTransaction( // override val id: SecureHash get() = serialized.hash //todo remove - //We need cashed leaves hashed for Partial Merkle Tree calculation. + //We need cashed leaves hashes for id and Partial Merkle Tree calculation. @Volatile @Transient private var cachedLeavesHashes: List? = null - val allLeavesHashes: List get() = cachedLeavesHashes ?: calculateLeavesHashes().apply { cachedLeavesHashes } + val allLeavesHashes: List get() = cachedLeavesHashes ?: calculateLeavesHashes().apply { cachedLeavesHashes = this } //TODO There is a problem with that it's failing 4 tests. Also in few places in code, there was reference to tx.serialized.hash // instead of tx.id. diff --git a/core/src/test/kotlin/com/r3corda/core/crypto/PartialMerkleTreeTest.kt b/core/src/test/kotlin/com/r3corda/core/crypto/PartialMerkleTreeTest.kt index acc5c08b91..0eb2f18231 100644 --- a/core/src/test/kotlin/com/r3corda/core/crypto/PartialMerkleTreeTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/crypto/PartialMerkleTreeTest.kt @@ -18,13 +18,8 @@ import com.r3corda.testing.* import org.junit.Test import java.util.* import kotlin.test.assertEquals -import kotlin.test.assertFailsWith import kotlin.test.assertFalse -//todo -//verification - different root -//tests failsWith - class PartialMerkleTreeTest{ val nodes = "abcdef" val hashed: MutableList = ArrayList() @@ -72,16 +67,15 @@ class PartialMerkleTreeTest{ assertEquals(node, mr) } -// @Test -// fun `building Merkle tree odd number of nodes`(){ -// val odd = hashed.subList(0, 3) -// val h1 = hashed[0].hashConcat(hashed[1]) -// val h2 = hashed[2].hashConcat(hashed[2]) -// val mtl = MerkleTransaction.Companion.buildMerkleTree(odd) -// assertEquals(6, mtl.size) -// assertEquals(h1, mtl[3]) -// assertEquals(h2, mtl[4]) -// } + @Test + fun `building Merkle tree odd number of nodes`(){ + val odd = hashed.subList(0, 3) + val h1 = hashed[0].hashConcat(hashed[1]) + val h2 = hashed[2].hashConcat(hashed[2]) + val expected = h1.hashConcat(h2) + val root = getMerkleRoot(odd) + assertEquals(root, expected) + } @Test fun `building Merkle tree for a transaction`(){ @@ -105,7 +99,11 @@ class PartialMerkleTreeTest{ tx2.addCommand(Cash.Commands.Issue(0), ALICE_PUBKEY) tx2.addCommand(Cash.Commands.Issue(1), ALICE_PUBKEY) val wtx2 = tx2.toWireTransaction() + val mt1 = wtx1.buildFilteredTransaction(filterFuns) + val mt2 = wtx2.buildFilteredTransaction(filterFuns) assertEquals(wtx1.id, wtx2.id) + assert(mt1.verify(wtx1.id)) + assert(mt2.verify(wtx2.id)) } //Partial Merkle Tree building tests @@ -141,4 +139,13 @@ class PartialMerkleTreeTest{ val pmt = PartialMerkleTree.build(includeLeaves, hashed) assertFalse(pmt.verify(inclHashes, root)) } + + @Test + fun `verify Partial Merkle Tree - wrong root`(){ + val includeLeaves = listOf(false, false, false, true, false, true) + val inclHashes = listOf(hashed[3], hashed[5]) + val pmt = PartialMerkleTree.build(includeLeaves, hashed) + val wrongRoot = hashed[3].hashConcat(hashed[5]) + assertFalse(pmt.verify(inclHashes, wrongRoot)) + } }