diff --git a/src/main/kotlin/core/Transactions.kt b/src/main/kotlin/core/Transactions.kt index fc2f0c9488..3fe837252b 100644 --- a/src/main/kotlin/core/Transactions.kt +++ b/src/main/kotlin/core/Transactions.kt @@ -57,17 +57,30 @@ import java.util.* data class WireTransaction(val inputs: List, val outputs: List, val commands: List) { - fun toLedgerTransaction(identityService: IdentityService, originalHash: SecureHash): LedgerTransaction { + + // Cache the serialised form of the transaction and its hash to give us fast access to it. + @Volatile @Transient private var cachedBits: SerializedBytes? = null + val serialized: SerializedBytes get() = cachedBits ?: serialize().apply { cachedBits = this } + val id: SecureHash get() = serialized.hash + companion object { + fun deserialize(bits: SerializedBytes): WireTransaction { + val wtx = bits.deserialize() + wtx.cachedBits = bits + return wtx + } + } + + fun toLedgerTransaction(identityService: IdentityService): LedgerTransaction { val authenticatedArgs = commands.map { val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) } AuthenticatedObject(it.pubkeys, institutions, it.data) } - return LedgerTransaction(inputs, outputs, authenticatedArgs, originalHash) + return LedgerTransaction(inputs, outputs, authenticatedArgs, id) } /** Serialises and returns this transaction as a [SignedWireTransaction] with no signatures attached. */ - fun toSignedTransaction(withSigs: List = emptyList()): SignedWireTransaction { - return SignedWireTransaction(serialize(), withSigs) + fun toSignedTransaction(withSigs: List): SignedWireTransaction { + return SignedWireTransaction(serialized, withSigs) } override fun toString(): String { @@ -85,7 +98,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes, v init { check(sigs.isNotEmpty()) } /** Lazily calculated access to the deserialised/hashed transaction data. */ - val tx: WireTransaction by lazy { txBits.deserialize() } + val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) } /** A transaction ID is the hash of the [WireTransaction]. Thus adding or removing a signature does not change it. */ val id: SecureHash get() = txBits.hash @@ -124,7 +137,7 @@ data class SignedWireTransaction(val txBits: SerializedBytes, v */ fun verifyToLedgerTransaction(identityService: IdentityService): LedgerTransaction { verify() - return tx.toLedgerTransaction(identityService, id) + return tx.toLedgerTransaction(identityService) } /** Returns the same transaction but with an additional (unchecked) signature */ @@ -279,13 +292,6 @@ data class LedgerTransaction( @Suppress("UNCHECKED_CAST") fun outRef(index: Int) = StateAndRef(outputs[index] as T, StateRef(hash, index)) - fun outRef(state: T): StateAndRef { - val i = outputs.indexOf(state) - if (i == -1) - throw IllegalArgumentException("State not found in this transaction") - return outRef(i) - } - fun toWireTransaction(): WireTransaction { val wtx = WireTransaction(inputs, outputs, commands.map { Command(it.value, it.signers) }) check(wtx.serialize().hash == hash) diff --git a/src/test/kotlin/core/TransactionGroupTests.kt b/src/test/kotlin/core/TransactionGroupTests.kt index 940eb67cf4..5c8a4570ea 100644 --- a/src/test/kotlin/core/TransactionGroupTests.kt +++ b/src/test/kotlin/core/TransactionGroupTests.kt @@ -74,8 +74,8 @@ class TransactionGroupTests { val e = assertFailsWith(TransactionConflictException::class) { verify() } - assertEquals(StateRef(t.hash, 0), e.conflictRef) - assertEquals(setOf(conflict1, conflict2), setOf(e.tx1, e.tx2)) + assertEquals(StateRef(t.id, 0), e.conflictRef) + assertEquals(setOf(conflict1, conflict2), setOf(e.tx1.toWireTransaction(), e.tx2.toWireTransaction())) } } @@ -97,9 +97,11 @@ class TransactionGroupTests { // We have to do this manually without the DSL because transactionGroup { } won't let us create a tx that // points nowhere. val ref = StateRef(SecureHash.randomSHA256(), 0) - tg.txns.add(LedgerTransaction( - listOf(ref), listOf(A_THOUSAND_POUNDS), listOf(AuthenticatedObject(listOf(BOB), emptyList(), Cash.Commands.Move())), SecureHash.randomSHA256()) - ) + tg.txns += TransactionBuilder().apply { + addInputState(ref) + addOutputState(A_THOUSAND_POUNDS) + addCommand(Cash.Commands.Move(), BOB) + }.toWireTransaction() val e = assertFailsWith(TransactionResolutionException::class) { tg.verify() diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index 88f3b96220..312cecf61e 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -88,7 +88,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { assertEquals(aliceResult.get(), bobResult.get()) - txns.add(aliceResult.get().second) + txns.add(aliceResult.get().first) verify() } } @@ -179,7 +179,7 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { // Bob is now finished and has the same transaction as Alice. val tx = bobFuture.get() - txns.add(tx.second) + txns.add(tx.first) verify() assertTrue(smm.stateMachines.isEmpty()) diff --git a/src/test/kotlin/core/testutils/TestUtils.kt b/src/test/kotlin/core/testutils/TestUtils.kt index 17a67e8e3b..5632406f43 100644 --- a/src/test/kotlin/core/testutils/TestUtils.kt +++ b/src/test/kotlin/core/testutils/TestUtils.kt @@ -12,10 +12,7 @@ package core.testutils import contracts.* import core.* -import core.crypto.DummyPublicKey -import core.crypto.NullPublicKey -import core.crypto.SecureHash -import core.crypto.generateKeyPair +import core.crypto.* import core.serialization.serialize import core.visualiser.GraphVisualiser import java.security.PublicKey @@ -44,7 +41,7 @@ val BOB = BOB_KEY.public val MEGA_CORP = Party("MegaCorp", MEGA_CORP_PUBKEY) val MINI_CORP = Party("MiniCorp", MINI_CORP_PUBKEY) -val ALL_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY) +val ALL_TEST_KEYS = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY) val TEST_KEYS_TO_CORP_MAP: Map = mapOf( MEGA_CORP_PUBKEY to MEGA_CORP, @@ -208,21 +205,14 @@ open class TransactionForTest : AbstractTransactionForTest() { fun transaction(body: TransactionForTest.() -> Unit) = TransactionForTest().apply { body() } class TransactionGroupDSL(private val stateType: Class) { - open inner class LedgerTransactionDSL : AbstractTransactionForTest() { + open inner class WireTransactionDSL : AbstractTransactionForTest() { private val inStates = ArrayList() fun input(label: String) { inStates.add(label.outputRef) } - /** - * Converts to a [LedgerTransaction] with the test institution map, and just assigns a random hash - * (i.e. pretend it was signed) - */ - fun toLedgerTransaction(): LedgerTransaction { - val wtx = WireTransaction(inStates, outStates.map { it.state }, commands) - return wtx.toLedgerTransaction(MockIdentityService, wtx.serialize().hash) - } + fun toWireTransaction() = WireTransaction(inStates, outStates.map { it.state }, commands) } val String.output: T get() = labelToOutputs[this] ?: throw IllegalArgumentException("State with label '$this' was not found") @@ -230,23 +220,23 @@ class TransactionGroupDSL(private val stateType: Class) { fun lookup(label: String) = StateAndRef(label.output as C, label.outputRef) - private inner class InternalLedgerTransactionDSL : LedgerTransactionDSL() { - fun finaliseAndInsertLabels(): LedgerTransaction { - val ltx = toLedgerTransaction() + private inner class InternalWireTransactionDSL : WireTransactionDSL() { + fun finaliseAndInsertLabels(): WireTransaction { + val wtx = toWireTransaction() for ((index, labelledState) in outStates.withIndex()) { if (labelledState.label != null) { - labelToRefs[labelledState.label] = StateRef(ltx.hash, index) + labelToRefs[labelledState.label] = StateRef(wtx.id, index) if (stateType.isInstance(labelledState.state)) { labelToOutputs[labelledState.label] = labelledState.state as T } outputsToLabels[labelledState.state] = labelledState.label } } - return ltx + return wtx } } - private val rootTxns = ArrayList() + private val rootTxns = ArrayList() private val labelToRefs = HashMap() private val labelToOutputs = HashMap() private val outputsToLabels = HashMap() @@ -257,42 +247,45 @@ class TransactionGroupDSL(private val stateType: Class) { fun transaction(vararg outputStates: LabeledOutput) { val outs = outputStates.map { it.state } val wtx = WireTransaction(emptyList(), outs, emptyList()) - val ltx = wtx.toLedgerTransaction(MockIdentityService, SecureHash.randomSHA256()) for ((index, state) in outputStates.withIndex()) { val label = state.label!! - labelToRefs[label] = StateRef(ltx.hash, index) + labelToRefs[label] = StateRef(wtx.id, index) outputsToLabels[state.state] = label labelToOutputs[label] = state.state as T } - rootTxns.add(ltx) + rootTxns.add(wtx) } @Deprecated("Does not nest ", level = DeprecationLevel.ERROR) fun roots(body: Roots.() -> Unit) {} @Deprecated("Use the vararg form of transaction inside roots", level = DeprecationLevel.ERROR) - fun transaction(body: LedgerTransactionDSL.() -> Unit) {} + fun transaction(body: WireTransactionDSL.() -> Unit) {} } fun roots(body: Roots.() -> Unit) = Roots().apply { body() } - val txns = ArrayList() - private val txnToLabelMap = HashMap() + val txns = ArrayList() + private val txnToLabelMap = HashMap() - fun transaction(label: String? = null, body: LedgerTransactionDSL.() -> Unit): LedgerTransaction { - val forTest = InternalLedgerTransactionDSL() + fun transaction(label: String? = null, body: WireTransactionDSL.() -> Unit): WireTransaction { + val forTest = InternalWireTransactionDSL() forTest.body() - val ltx = forTest.finaliseAndInsertLabels() - txns.add(ltx) + val wtx = forTest.finaliseAndInsertLabels() + txns.add(wtx) if (label != null) - txnToLabelMap[ltx] = label - return ltx + txnToLabelMap[wtx.id] = label + return wtx } - fun labelForTransaction(ltx: LedgerTransaction): String? = txnToLabelMap[ltx] + fun labelForTransaction(tx: WireTransaction): String? = txnToLabelMap[tx.id] + fun labelForTransaction(tx: LedgerTransaction): String? = txnToLabelMap[tx.hash] @Deprecated("Does not nest ", level = DeprecationLevel.ERROR) fun transactionGroup(body: TransactionGroupDSL.() -> Unit) {} - fun toTransactionGroup() = TransactionGroup(txns.map { it }.toSet(), rootTxns.toSet()) + fun toTransactionGroup() = TransactionGroup( + txns.map { it.toLedgerTransaction(MockIdentityService) }.toSet(), + rootTxns.map { it.toLedgerTransaction(MockIdentityService) }.toSet() + ) class Failed(val index: Int, cause: Throwable) : Exception("Transaction $index didn't verify", cause) @@ -302,8 +295,8 @@ class TransactionGroupDSL(private val stateType: Class) { group.verify(MockContractFactory) } catch (e: TransactionVerificationException) { // Let the developer know the index of the transaction that failed. - val ltx: LedgerTransaction = txns.find { it.hash == e.tx.origHash }!! - throw Failed(txns.indexOf(ltx) + 1, e) + val wtx: WireTransaction = txns.find { it.id == e.tx.origHash }!! + throw Failed(txns.indexOf(wtx) + 1, e) } } @@ -323,7 +316,17 @@ class TransactionGroupDSL(private val stateType: Class) { } fun signAll(): List { - return txns.map { it.toSignedTransaction(andSignWithKeys = ALL_KEYS, allowUnusedKeys = true) } + return txns.map { wtx -> + val allPubKeys = wtx.commands.flatMap { it.pubkeys }.toSet() + val bits = wtx.serialize() + require(bits == wtx.serialized) + val sigs = ArrayList() + for (key in ALL_TEST_KEYS) { + if (allPubKeys.contains(key.public)) + sigs += key.signWithECDSA(bits) + } + wtx.toSignedTransaction(sigs) + } } }