Java accessible API for merkle trees

Respond to PR comment
This commit is contained in:
Matthew Nesbit 2017-06-19 17:30:11 +01:00
parent b874b3e62a
commit d2aaba2820
6 changed files with 26 additions and 19 deletions

View File

@ -11,6 +11,7 @@ import net.corda.core.serialization.p2PKryo
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.withoutReferences import net.corda.core.serialization.withoutReferences
import java.security.PublicKey import java.security.PublicKey
import java.util.function.Predicate
fun <T : Any> serializedHash(x: T): SecureHash { fun <T : Any> serializedHash(x: T): SecureHash {
return p2PKryo().run { kryo -> kryo.withoutReferences { x.serialize(kryo).hash } } return p2PKryo().run { kryo -> kryo.withoutReferences { x.serialize(kryo).hash } }
@ -116,8 +117,9 @@ class FilteredTransaction private constructor(
* @param wtx WireTransaction to be filtered. * @param wtx WireTransaction to be filtered.
* @param filtering filtering over the whole WireTransaction * @param filtering filtering over the whole WireTransaction
*/ */
@JvmStatic
fun buildMerkleTransaction(wtx: WireTransaction, fun buildMerkleTransaction(wtx: WireTransaction,
filtering: (Any) -> Boolean filtering: Predicate<Any>
): FilteredTransaction { ): FilteredTransaction {
val filteredLeaves = wtx.filterWithFun(filtering) val filteredLeaves = wtx.filterWithFun(filtering)
val merkleTree = wtx.merkleTree val merkleTree = wtx.merkleTree

View File

@ -13,6 +13,7 @@ import net.corda.core.serialization.p2PKryo
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.Emoji import net.corda.core.utilities.Emoji
import java.security.PublicKey import java.security.PublicKey
import java.util.function.Predicate
/** /**
* A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped * A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped
@ -106,7 +107,7 @@ class WireTransaction(
/** /**
* Build filtered transaction using provided filtering functions. * Build filtered transaction using provided filtering functions.
*/ */
fun buildFilteredTransaction(filtering: (Any) -> Boolean): FilteredTransaction { fun buildFilteredTransaction(filtering: Predicate<Any>): FilteredTransaction {
return FilteredTransaction.buildMerkleTransaction(this, filtering) return FilteredTransaction.buildMerkleTransaction(this, filtering)
} }
@ -120,15 +121,15 @@ class WireTransaction(
* @param filtering filtering over the whole WireTransaction * @param filtering filtering over the whole WireTransaction
* @returns FilteredLeaves used in PartialMerkleTree calculation and verification. * @returns FilteredLeaves used in PartialMerkleTree calculation and verification.
*/ */
fun filterWithFun(filtering: (Any) -> Boolean): FilteredLeaves { fun filterWithFun(filtering: Predicate<Any>): FilteredLeaves {
fun notNullFalse(elem: Any?): Any? = if (elem == null || !filtering(elem)) null else elem fun notNullFalse(elem: Any?): Any? = if (elem == null || !filtering.test(elem)) null else elem
return FilteredLeaves( return FilteredLeaves(
inputs.filter { filtering(it) }, inputs.filter { filtering.test(it) },
attachments.filter { filtering(it) }, attachments.filter { filtering.test(it) },
outputs.filter { filtering(it) }, outputs.filter { filtering.test(it) },
commands.filter { filtering(it) }, commands.filter { filtering.test(it) },
notNullFalse(notary) as Party?, notNullFalse(notary) as Party?,
mustSign.filter { filtering(it) }, mustSign.filter { filtering.test(it) },
notNullFalse(type) as TransactionType?, notNullFalse(type) as TransactionType?,
notNullFalse(timeWindow) as TimeWindow? notNullFalse(timeWindow) as TimeWindow?
) )

View File

@ -19,6 +19,7 @@ import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import java.util.function.Predicate
object NotaryFlow { object NotaryFlow {
/** /**
@ -63,7 +64,7 @@ object NotaryFlow {
val payload: Any = if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) { val payload: Any = if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) {
stx stx
} else { } else {
wtx.buildFilteredTransaction { it is StateRef || it is TimeWindow } wtx.buildFilteredTransaction(Predicate { it is StateRef || it is TimeWindow })
} }
val response = try { val response = try {

View File

@ -15,6 +15,7 @@ import net.corda.core.utilities.TEST_TX_TIME
import net.corda.testing.* import net.corda.testing.*
import org.junit.Test import org.junit.Test
import java.security.PublicKey import java.security.PublicKey
import java.util.function.Predicate
import kotlin.test.* import kotlin.test.*
class PartialMerkleTreeTest { class PartialMerkleTreeTest {
@ -104,7 +105,7 @@ class PartialMerkleTreeTest {
} }
} }
val mt = testTx.buildFilteredTransaction(::filtering) val mt = testTx.buildFilteredTransaction(Predicate(::filtering))
val leaves = mt.filteredLeaves val leaves = mt.filteredLeaves
val d = WireTransaction.deserialize(testTx.serialized) val d = WireTransaction.deserialize(testTx.serialized)
assertEquals(testTx.id, d.id) assertEquals(testTx.id, d.id)
@ -128,7 +129,7 @@ class PartialMerkleTreeTest {
@Test @Test
fun `nothing filtered`() { fun `nothing filtered`() {
val mt = testTx.buildFilteredTransaction({ false }) val mt = testTx.buildFilteredTransaction(Predicate { false })
assertTrue(mt.filteredLeaves.attachments.isEmpty()) assertTrue(mt.filteredLeaves.attachments.isEmpty())
assertTrue(mt.filteredLeaves.commands.isEmpty()) assertTrue(mt.filteredLeaves.commands.isEmpty())
assertTrue(mt.filteredLeaves.inputs.isEmpty()) assertTrue(mt.filteredLeaves.inputs.isEmpty())

View File

@ -17,6 +17,7 @@ import net.corda.irs.utilities.suggestInterestRateAnnouncementTimeWindow
import java.math.BigDecimal import java.math.BigDecimal
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import java.util.function.Predicate
// This code is unit tested in NodeInterestRates.kt // This code is unit tested in NodeInterestRates.kt
@ -62,7 +63,7 @@ open class RatesFixFlow(protected val tx: TransactionBuilder,
tx.addCommand(fix, oracle.owningKey) tx.addCommand(fix, oracle.owningKey)
beforeSigning(fix) beforeSigning(fix)
progressTracker.currentStep = SIGNING progressTracker.currentStep = SIGNING
val mtx = tx.toWireTransaction().buildFilteredTransaction({ filtering(it) }) val mtx = tx.toWireTransaction().buildFilteredTransaction(Predicate { filtering(it) })
val signature = subFlow(FixSignFlow(tx, oracle, mtx)) val signature = subFlow(FixSignFlow(tx, oracle, mtx))
tx.addSignatureUnchecked(signature) tx.addSignatureUnchecked(signature)
} }

View File

@ -36,6 +36,7 @@ import org.junit.Before
import org.junit.Test import org.junit.Test
import java.io.Closeable import java.io.Closeable
import java.math.BigDecimal import java.math.BigDecimal
import java.util.function.Predicate
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
import kotlin.test.assertFalse import kotlin.test.assertFalse
@ -142,11 +143,11 @@ class NodeInterestRatesTest {
} }
} }
val ftx1 = wtx1.buildFilteredTransaction(::filterAllOutputs) val ftx1 = wtx1.buildFilteredTransaction(Predicate(::filterAllOutputs))
assertFailsWith<IllegalArgumentException> { oracle.sign(ftx1) } assertFailsWith<IllegalArgumentException> { oracle.sign(ftx1) }
tx.addCommand(Cash.Commands.Move(), ALICE_PUBKEY) tx.addCommand(Cash.Commands.Move(), ALICE_PUBKEY)
val wtx2 = tx.toWireTransaction() val wtx2 = tx.toWireTransaction()
val ftx2 = wtx2.buildFilteredTransaction { x -> filterCmds(x) } val ftx2 = wtx2.buildFilteredTransaction(Predicate { x -> filterCmds(x) })
assertFalse(wtx1.id == wtx2.id) assertFalse(wtx1.id == wtx2.id)
assertFailsWith<IllegalArgumentException> { oracle.sign(ftx2) } assertFailsWith<IllegalArgumentException> { oracle.sign(ftx2) }
} }
@ -160,7 +161,7 @@ class NodeInterestRatesTest {
tx.addCommand(fix, oracle.identity.owningKey) tx.addCommand(fix, oracle.identity.owningKey)
// Sign successfully. // Sign successfully.
val wtx = tx.toWireTransaction() val wtx = tx.toWireTransaction()
val ftx = wtx.buildFilteredTransaction { x -> fixCmdFilter(x) } val ftx = wtx.buildFilteredTransaction(Predicate { x -> fixCmdFilter(x) })
val signature = oracle.sign(ftx) val signature = oracle.sign(ftx)
tx.checkAndAddSignature(signature) tx.checkAndAddSignature(signature)
} }
@ -174,7 +175,7 @@ class NodeInterestRatesTest {
val badFix = Fix(fixOf, "0.6789".bd) val badFix = Fix(fixOf, "0.6789".bd)
tx.addCommand(badFix, oracle.identity.owningKey) tx.addCommand(badFix, oracle.identity.owningKey)
val wtx = tx.toWireTransaction() val wtx = tx.toWireTransaction()
val ftx = wtx.buildFilteredTransaction { x -> fixCmdFilter(x) } val ftx = wtx.buildFilteredTransaction(Predicate { x -> fixCmdFilter(x) })
val e1 = assertFailsWith<NodeInterestRates.UnknownFix> { oracle.sign(ftx) } val e1 = assertFailsWith<NodeInterestRates.UnknownFix> { oracle.sign(ftx) }
assertEquals(fixOf, e1.fix) assertEquals(fixOf, e1.fix)
} }
@ -194,7 +195,7 @@ class NodeInterestRatesTest {
} }
tx.addCommand(fix, oracle.identity.owningKey) tx.addCommand(fix, oracle.identity.owningKey)
val wtx = tx.toWireTransaction() val wtx = tx.toWireTransaction()
val ftx = wtx.buildFilteredTransaction(::filtering) val ftx = wtx.buildFilteredTransaction(Predicate(::filtering))
assertFailsWith<IllegalArgumentException> { oracle.sign(ftx) } assertFailsWith<IllegalArgumentException> { oracle.sign(ftx) }
} }
} }
@ -203,7 +204,7 @@ class NodeInterestRatesTest {
fun `empty partial transaction to sign`() { fun `empty partial transaction to sign`() {
val tx = makeTX() val tx = makeTX()
val wtx = tx.toWireTransaction() val wtx = tx.toWireTransaction()
val ftx = wtx.buildFilteredTransaction({ false }) val ftx = wtx.buildFilteredTransaction(Predicate { false })
assertFailsWith<MerkleTreeException> { oracle.sign(ftx) } assertFailsWith<MerkleTreeException> { oracle.sign(ftx) }
} }