diff --git a/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt b/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt index df518ce36b..e2f6386bba 100644 --- a/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt +++ b/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt @@ -43,7 +43,7 @@ object TwoPartyDealProtocol { val sessionID: Long ) - class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.WithKey) + class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable) /** * Abstracted bilateral deal protocol participant that initiates communication/handshake. @@ -211,6 +211,8 @@ object TwoPartyDealProtocol { val signatures = swapSignaturesWithPrimary(stx, handshake.sessionID) logger.trace { "Got signatures from other party, verifying ... " } + + verifyCorrectNotary(stx.tx, signatures.notarySig) val fullySigned = stx + signatures.sellerSig + signatures.notarySig fullySigned.verify() @@ -255,6 +257,11 @@ object TwoPartyDealProtocol { return ptx.toSignedTransaction(checkSufficientSignatures = false) } + private fun verifyCorrectNotary(wtx: WireTransaction, sig: DigitalSignature.LegallyIdentifiable) { + val notary = serviceHub.loadState(wtx.inputs.first()).notary + check(sig.signer == notary) { "Transaction not signed by the required Notary" } + } + @Suspendable protected abstract fun validateHandshake(handshake: Handshake<U>): Handshake<U> @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake<U>): Pair<TransactionBuilder, List<PublicKey>> } diff --git a/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt b/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt index 99b47c6cfd..50632a71e3 100644 --- a/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt +++ b/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt @@ -73,7 +73,7 @@ object TwoPartyTradeProtocol { ) class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey, - val notarySig: DigitalSignature.WithKey) + val notarySig: DigitalSignature.LegallyIdentifiable) open class Seller(val otherSide: SingleMessageRecipient, val notaryNode: NodeInfo, @@ -215,6 +215,8 @@ object TwoPartyTradeProtocol { val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID) logger.trace { "Got signatures from seller, verifying ... " } + + verifyCorrectNotary(stx.tx, signatures.notarySig) val fullySigned = stx + signatures.sellerSig + signatures.notarySig fullySigned.verify() @@ -270,6 +272,11 @@ object TwoPartyTradeProtocol { return ptx.toSignedTransaction(checkSufficientSignatures = false) } + private fun verifyCorrectNotary(wtx: WireTransaction, sig: DigitalSignature.LegallyIdentifiable) { + val notary = serviceHub.loadState(wtx.inputs.first()).notary + check(sig.signer == notary) { "Transaction not signed by the required Notary" } + } + private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> { val ptx = TransactionBuilder() // Add input and output states for the movement of cash, by using the Cash contract to generate the states. diff --git a/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index 115d1b17d3..c93db34411 100644 --- a/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -163,6 +163,9 @@ class TwoPartyTradeProtocolTests { // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. assertThat(bobNode.storage.checkpointStorage.checkpoints).hasSize(1) + // TODO: remove once validated transactions are persisted to disk + val recordedTransactions = bobNode.storage.validatedTransactions + // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk. bobNode.stop() @@ -179,8 +182,11 @@ class TwoPartyTradeProtocolTests { } }, BOB.name, BOB_KEY) + // TODO: remove once validated transactions are persisted to disk + bobNode.storage.validatedTransactions.putAll(recordedTransactions) + // Find the future representing the result of this state machine again. - var bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second + val bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second // And off we go again. net.runNetwork() @@ -268,7 +274,9 @@ class TwoPartyTradeProtocolTests { RecordingMap.Get(bobsFakeCash[1].id), RecordingMap.Get(bobsFakeCash[2].id), // Alice notices that Bob's cash txns depend on a third tx she also doesn't know. She asks, Bob answers. - RecordingMap.Get(bobsFakeCash[0].id) + RecordingMap.Get(bobsFakeCash[0].id), + // Bob wants to verify that the tx has been signed by the correct Notary, which requires looking up an input state + RecordingMap.Get(bobsFakeCash[1].id) ) assertEquals(expected, records)