diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt index becc12e116..e96b832933 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessageDeduplicator.kt @@ -62,8 +62,7 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId.sessionIdentifier in sessionData } - // We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache. - private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString() + private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.senderUUID).toString() /** * Determines whether a session-init message is a duplicate. @@ -85,10 +84,12 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa * Called the first time we encounter [deduplicationId]. */ fun signalMessageProcessStart(msg: ReceivedMessage) { + require(msg.uniqueMessageId.messageType == MessageType.SESSION_INIT) { "Message ${msg.uniqueMessageId} was not a session-init message." } + val receivedSenderUUID = msg.senderUUID val receivedSenderSeqNo = msg.senderSeqNo // We don't want a mix of nulls and values so we ensure that here. - val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null + val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer)) else null val firstSenderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(msg.uniqueMessageId.timestamp, senderHash, firstSenderSeqNo, null) } @@ -97,6 +98,8 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa * Called inside a DB transaction to persist [deduplicationId]. */ fun persistDeduplicationId(deduplicationId: MessageIdentifier) { + require(deduplicationId.messageType == MessageType.SESSION_INIT) { "Message $deduplicationId was not a session-init message." } + sessionData[deduplicationId.sessionIdentifier] = beingProcessedMessages[deduplicationId]!! } @@ -105,6 +108,8 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa * Any subsequent redelivery will be deduplicated using the DB. */ fun signalMessageProcessFinish(deduplicationId: MessageIdentifier) { + require(deduplicationId.messageType == MessageType.SESSION_INIT) { "Message $deduplicationId was not a session-init message." } + beingProcessedMessages.remove(deduplicationId) } @@ -173,5 +178,5 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa private data class MessageMeta(val generationTime: Instant, val senderHash: String?, val firstSenderSeqNo: SenderSequenceNumber?, val lastSenderSeqNo: SenderSequenceNumber?) - private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean) + private data class SenderKey(val senderUUID: String, val peer: CordaX500Name) }