diff --git a/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt b/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt
index df518ce36b..e2f6386bba 100644
--- a/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt
+++ b/core/src/main/kotlin/protocols/TwoPartyDealProtocol.kt
@@ -43,7 +43,7 @@ object TwoPartyDealProtocol {
             val sessionID: Long
     )
 
-    class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.WithKey)
+    class SignaturesFromPrimary(val sellerSig: DigitalSignature.WithKey, val notarySig: DigitalSignature.LegallyIdentifiable)
 
     /**
      * Abstracted bilateral deal protocol participant that initiates communication/handshake.
@@ -211,6 +211,8 @@ object TwoPartyDealProtocol {
             val signatures = swapSignaturesWithPrimary(stx, handshake.sessionID)
 
             logger.trace { "Got signatures from other party, verifying ... " }
+
+            verifyCorrectNotary(stx.tx, signatures.notarySig)
             val fullySigned = stx + signatures.sellerSig + signatures.notarySig
             fullySigned.verify()
 
@@ -255,6 +257,11 @@ object TwoPartyDealProtocol {
             return ptx.toSignedTransaction(checkSufficientSignatures = false)
         }
 
+        private fun verifyCorrectNotary(wtx: WireTransaction, sig: DigitalSignature.LegallyIdentifiable) {
+            val notary = serviceHub.loadState(wtx.inputs.first()).notary
+            check(sig.signer == notary) { "Transaction not signed by the required Notary" }
+        }
+
         @Suspendable protected abstract fun validateHandshake(handshake: Handshake<U>): Handshake<U>
         @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake<U>): Pair<TransactionBuilder, List<PublicKey>>
     }
diff --git a/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt b/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt
index 99b47c6cfd..50632a71e3 100644
--- a/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt
+++ b/node/src/main/kotlin/protocols/TwoPartyTradeProtocol.kt
@@ -73,7 +73,7 @@ object TwoPartyTradeProtocol {
     )
 
     class SignaturesFromSeller(val sellerSig: DigitalSignature.WithKey,
-                               val notarySig: DigitalSignature.WithKey)
+                               val notarySig: DigitalSignature.LegallyIdentifiable)
 
     open class Seller(val otherSide: SingleMessageRecipient,
                       val notaryNode: NodeInfo,
@@ -215,6 +215,8 @@ object TwoPartyTradeProtocol {
             val signatures = swapSignaturesWithSeller(stx, tradeRequest.sessionID)
 
             logger.trace { "Got signatures from seller, verifying ... " }
+
+            verifyCorrectNotary(stx.tx, signatures.notarySig)
             val fullySigned = stx + signatures.sellerSig + signatures.notarySig
             fullySigned.verify()
 
@@ -270,6 +272,11 @@ object TwoPartyTradeProtocol {
             return ptx.toSignedTransaction(checkSufficientSignatures = false)
         }
 
+        private fun verifyCorrectNotary(wtx: WireTransaction, sig: DigitalSignature.LegallyIdentifiable) {
+            val notary = serviceHub.loadState(wtx.inputs.first()).notary
+            check(sig.signer == notary) { "Transaction not signed by the required Notary" }
+        }
+
         private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair<TransactionBuilder, List<PublicKey>> {
             val ptx = TransactionBuilder()
             // Add input and output states for the movement of cash, by using the Cash contract to generate the states.
diff --git a/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
index 115d1b17d3..c93db34411 100644
--- a/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
+++ b/node/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
@@ -163,6 +163,9 @@ class TwoPartyTradeProtocolTests {
             // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
             assertThat(bobNode.storage.checkpointStorage.checkpoints).hasSize(1)
 
+            // TODO: remove once validated transactions are persisted to disk
+            val recordedTransactions = bobNode.storage.validatedTransactions
+
             // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.
             bobNode.stop()
 
@@ -179,8 +182,11 @@ class TwoPartyTradeProtocolTests {
                 }
             }, BOB.name, BOB_KEY)
 
+            // TODO: remove once validated transactions are persisted to disk
+            bobNode.storage.validatedTransactions.putAll(recordedTransactions)
+
             // Find the future representing the result of this state machine again.
-            var bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second
+            val bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second
 
             // And off we go again.
             net.runNetwork()
@@ -268,7 +274,9 @@ class TwoPartyTradeProtocolTests {
                         RecordingMap.Get(bobsFakeCash[1].id),
                         RecordingMap.Get(bobsFakeCash[2].id),
                         // Alice notices that Bob's cash txns depend on a third tx she also doesn't know. She asks, Bob answers.
-                        RecordingMap.Get(bobsFakeCash[0].id)
+                        RecordingMap.Get(bobsFakeCash[0].id),
+                        // Bob wants to verify that the tx has been signed by the correct Notary, which requires looking up an input state
+                        RecordingMap.Get(bobsFakeCash[1].id)
                 )
                 assertEquals(expected, records)