Add leaves filtering and building trees with different leaves types.

This commit is contained in:
Katarzyna Streich 2016-10-13 18:12:45 +01:00
parent 2ce21842f8
commit ea826988af

View File

@ -1,6 +1,9 @@
package com.r3corda.core.transactions
import com.r3corda.core.contracts.Command
import com.r3corda.core.contracts.ContractState
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.contracts.TransactionState
import com.r3corda.core.crypto.MerkleTreeException
import com.r3corda.core.crypto.PartialMerkleTree
import com.r3corda.core.crypto.SecureHash
@ -16,100 +19,128 @@ import java.util.*
* If a row in a tree has an odd number of elements - the final hash is hashed with itself.
*/
fun SecureHash.hashConcat(other: SecureHash) = (this.bits + other.bits).sha256()
//Todo It's just mess, move it to wtx
fun WireTransaction.buildFilteredTransaction(filterFuns: FilterFuns): MerkleTransaction{
return MerkleTransaction.buildMerkleTransaction(this, filterFuns)
}
fun WireTransaction.calculateLeavesHashes(): List<SecureHash>{
val resultHashes = ArrayList<SecureHash>()
val entries = listOf(inputs, outputs, attachments, commands)
entries.forEach { it.sortedBy { x-> x.hashCode() }.mapTo(resultHashes, { xy -> serializedHash(x) }) }
return resultHashes
}
fun SecureHash.hashConcat(other: SecureHash) = (this.bits + other.bits).sha256()
fun <T: Any> serializedHash(x: T) = x.serialize().hash
/**
* Builds the tree from bottom to top. Takes as an argument list of leaves hashes.
*/
tailrec fun getMerkleRoot(
lastHashList: List<SecureHash>): SecureHash{
if(lastHashList.size < 1)
throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.")
if(lastHashList.size == 1) {
return lastHashList[0]
}
else{
val newLevelHashes: MutableList<SecureHash> = ArrayList()
var i = 0
while(i < lastHashList.size){
val left = lastHashList[i]
//If there is an odd number of elements, the last element is hashed with itself.
val right = lastHashList[Math.min(i+1, lastHashList.size - 1)]
val combined = left.hashConcat(right)
newLevelHashes.add(combined)
i+=2
}
return getMerkleRoot(newLevelHashes)
}
}
/**
* Class that holds filtered leaves for a partial Merkle transaction. We assume mixed leaves types.
*/
class FilteredLeaves(
val inputs: List<StateRef>,
val outputs: List<TransactionState<ContractState>>,
val attachments: List<SecureHash>,
val commands: List<Command>
){
fun getFilteredHashes(): List<SecureHash>{
val resultHashes = ArrayList<SecureHash>()
val entries = listOf(inputs, outputs, attachments, commands)
entries.forEach { it.mapTo(resultHashes, { x -> serializedHash(x) }) }
return resultHashes
}
}
open class FilterFuns(val filterInputs: (StateRef) -> Boolean = { false },
val filterOutputs: (TransactionState<ContractState>) -> Boolean = { false },
val filterAttachments: (SecureHash) -> Boolean = { false },
val filterCommands: (Command) -> Boolean = { false }){
fun <T: Any> genericFilter(elem: T): Boolean{
return when (elem) {
is StateRef -> filterInputs(elem)
is TransactionState<*> -> filterOutputs(elem)
is SecureHash -> filterAttachments(elem)
is Command -> filterCommands(elem)
else -> throw IllegalArgumentException("Wrong argument type: ${elem.javaClass}")
}
}
}
/**
* Class representing merkleized filtered transaction.
* filteredLeaves - are the leaves included in a filtered transaction.
* partialMerkleTree - Merkle branch needed to verify that filtered transaction.
*/
class MerkleTransaction(
val merkleRoot: SecureHash, //todo that should be in a wire tx? not with PMT and filtered commands
val filteredCommands : List<Command>, //todo + <Command> do we want to also filter something else than commands?
val filteredLeaves: FilteredLeaves,
val partialMerkleTree: PartialMerkleTree
){
companion object {
fun buildMerkleTransaction(wtx: WireTransaction, filterFunction: (Command) -> Boolean): MerkleTransaction {
val merkleTree: List<SecureHash> = buildMerkleTree(wtx) //todo change
val merkleRoot = merkleTree.last()
val allLeavesHashes: MutableList<SecureHash> = ArrayList()
//todo naive version with inputs, outputs, attachemets each as one block
getTransactionBlocks(wtx).mapTo(allLeavesHashes, { it.sha256() })
val filteredCommands: MutableList<Command> = ArrayList()
/**
* Construction of filtered transaction with Partial Merkle Tree, takes WireTransaction and filtering functions
* for inputs, outputs, attachments, commands.
*/
fun buildMerkleTransaction(wtx: WireTransaction,
filterFuns: FilterFuns
): MerkleTransaction {
val includeLeaves: MutableList<Boolean> = ArrayList()
filterLeaves(filterFunction, wtx, includeLeaves, filteredCommands)
val filteredInputs: MutableList<StateRef> = ArrayList()
val filteredOutputs: MutableList<TransactionState<ContractState>> = ArrayList()
val filteredAttachments: MutableList<SecureHash> = ArrayList()
val filteredCommands: MutableList<Command> = ArrayList()
val pmt = PartialMerkleTree.build(includeLeaves, allLeavesHashes)
return MerkleTransaction(merkleRoot, filteredCommands, pmt)
}
//todo type -> only on in, out, att, cmd
private fun filterLeaves(filterFunction: (Command) -> Boolean,
wtx: WireTransaction,
includeLeaves: MutableList<Boolean>,
filteredCommands: MutableList<Command> ){
//todo glued together for now
val tmpArr = arrayListOf(false, false, false)
includeLeaves.addAll(tmpArr)
val orderedCmds = wtx.commands.sortedBy { it.toString() }
orderedCmds.forEach { //todo should go on all in/outputs etc.
val include = filterFunction(it)
if(include) filteredCommands.add(it)
//It's a little evil, I needed a way of building at once few lists.
fun <T: Any> filterLeaves(el: T, destination: MutableList<T>){
val include = filterFuns.genericFilter(el)
if (include) destination.add(el)
includeLeaves.add(include)
}
}
/**
* Function that splits the transaction into serialized blocks.
* Blocks: inputs, outputs, attachments, commands.
*/
private fun getTransactionBlocks(wtx: WireTransaction) : MutableList<ByteArray> {
val blocks: MutableList<ByteArray> = ArrayList()
val toBlockList = listOf(wtx.inputs, wtx.outputs, wtx.attachments)
val orderedCmds = wtx.commands.sortedBy { it.toString() }
toBlockList.mapTo(blocks, { it.serialize().bits } )
blocks.addAll(orderedCmds.map { it.serialize().bits })
return blocks
}
/**
* Start building a Merkle tree from the transaction.
* Calls helper tailrecursive function with an accumulator and initial hashedBlocks.
*/
fun buildMerkleTree(wtx: WireTransaction): MutableList<SecureHash>{
val blocks = getTransactionBlocks(wtx)
val hashedBlocks: MutableList<SecureHash> = ArrayList()
blocks.mapTo(hashedBlocks, { it.sha256() })
val merkleTreeList = ArrayList<SecureHash>()
merkleTreeList.addAll(hashedBlocks)
buildMerkleTree(merkleTreeList, hashedBlocks)
return merkleTreeList
}
//TODO Ordering by hashCode
wtx.inputs.sortedBy { it.hashCode() }.forEach { filterLeaves(it, filteredInputs) }
wtx.outputs.sortedBy { it.hashCode() }.forEach { filterLeaves(it, filteredOutputs) }
wtx.attachments.sortedBy { it.hashCode() }.forEach { filterLeaves(it, filteredAttachments) }
wtx.commands.sortedBy { it.hashCode() }.forEach { filterLeaves(it, filteredCommands) }
tailrec fun buildMerkleTree(
resultHashes: MutableList<SecureHash>,
lastHashList: List<SecureHash>){
if(lastHashList.size <= 1) {
return
}
else{
val newLevelHashes: MutableList<SecureHash> = ArrayList()
var i = 0
while(i < lastHashList.size){
val left = lastHashList[i]
//If there is an odd number of elements, the last element is hashed with itself.
val right = lastHashList[Math.min(i+1, lastHashList.size - 1)]
val combined = left.hashConcat(right)
resultHashes.add(combined)
newLevelHashes.add(combined)
i+=2
}
buildMerkleTree(resultHashes, newLevelHashes)
}
val filteredLeaves = FilteredLeaves(filteredInputs, filteredOutputs, filteredAttachments, filteredCommands)
val pmt = PartialMerkleTree.build(includeLeaves, wtx.allLeavesHashes)
return MerkleTransaction(filteredLeaves, pmt)
}
}
//todo exception
fun verify():Boolean{
val hashes: List<SecureHash> = filteredCommands.map { it.serialize().sha256() }
/**
* Runs verification of Partial Merkle Branch with provided merkleRoot.
*/
fun verify(merkleRoot: SecureHash):Boolean{
val hashes: List<SecureHash> = filteredLeaves.getFilteredHashes()
if(hashes.size == 0)
throw MerkleTreeException("Transaction without included leaves.")
return partialMerkleTree.verify(hashes, merkleRoot)