Add logic for rejecting old messages based on event horizon

This commit is contained in:
Dimos Raptis 2020-09-24 15:16:20 +01:00
parent 673f02d635
commit dd8763494f
4 changed files with 51 additions and 5 deletions

View File

@ -5,6 +5,7 @@ import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.generateKeyPair
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.node.NetworkParameters
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.services.config.FlowTimeoutConfiguration import net.corda.node.services.config.FlowTimeoutConfiguration
@ -26,6 +27,7 @@ import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.statemachine.MessageType import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.testing.common.internal.eventually
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.internal.MOCK_VERSION_INFO import net.corda.testing.node.internal.MOCK_VERSION_INFO
import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException
@ -40,6 +42,8 @@ import rx.subjects.PublishSubject
import java.math.BigInteger import java.math.BigInteger
import java.net.ServerSocket import java.net.ServerSocket
import java.time.Clock import java.time.Clock
import java.time.Duration
import java.time.Instant
import java.util.concurrent.BlockingQueue import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit.MILLISECONDS import java.util.concurrent.TimeUnit.MILLISECONDS
@ -52,6 +56,7 @@ class ArtemisMessagingTest {
companion object { companion object {
const val TOPIC = "platform.self" const val TOPIC = "platform.self"
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant()) private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
private val EVENT_HORIZON = Duration.ofDays(5)
} }
@Rule @Rule
@ -191,6 +196,23 @@ class ArtemisMessagingTest {
assertNull(receivedMessages.poll(200, MILLISECONDS)) assertNull(receivedMessages.poll(200, MILLISECONDS))
} }
@Test(timeout=300_000)
fun `server should reject messages older than the event horizon`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(clientMaxMessageSize = 100_000, serverMaxMessageSize = 50_000)
val regularMessage = messagingClient.createMessage(TOPIC, ByteArray(50_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER.copy(timestamp = Instant.now()), null), emptyMap())
val tooOldMessage = messagingClient.createMessage(TOPIC, ByteArray(50_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER.copy(timestamp = Instant.now().minus(EVENT_HORIZON)), null), emptyMap())
listOf(tooOldMessage, regularMessage).forEach { messagingClient.send(it, messagingClient.myAddress) }
val regularMsgReceived = receivedMessages.take()
assertThat(regularMsgReceived.uniqueMessageId).isEqualTo(regularMessage.uniqueMessageId)
eventually {
assertThat(messagingServer!!.totalMessagesAcknowledged()).isEqualTo(2)
}
assertThat(receivedMessages).isEmpty()
}
@Test(timeout=300_000) @Test(timeout=300_000)
fun `platform version is included in the message`() { fun `platform version is included in the message`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3) val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3)
@ -202,7 +224,8 @@ class ArtemisMessagingTest {
} }
private fun startNodeMessagingClient(maxMessageSize: Int = MAX_MESSAGE_SIZE) { private fun startNodeMessagingClient(maxMessageSize: Int = MAX_MESSAGE_SIZE) {
messagingClient!!.start(identity.public, null, maxMessageSize) val networkParams = NetworkParameters(3, emptyList(), maxMessageSize, 1_000, Instant.now(), 5, emptyMap(), EVENT_HORIZON)
messagingClient!!.start(identity.public, null, networkParams)
} }
private fun createAndStartClientAndServer(platformVersion: Int = 1, serverMaxMessageSize: Int = MAX_MESSAGE_SIZE, clientMaxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<P2PMessagingClient, BlockingQueue<ReceivedMessage>> { private fun createAndStartClientAndServer(platformVersion: Int = 1, serverMaxMessageSize: Int = MAX_MESSAGE_SIZE, clientMaxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<P2PMessagingClient, BlockingQueue<ReceivedMessage>> {

View File

@ -409,7 +409,7 @@ open class Node(configuration: NodeConfiguration,
myIdentity = nodeInfo.legalIdentities[0].owningKey, myIdentity = nodeInfo.legalIdentities[0].owningKey,
serviceIdentity = if (nodeInfo.legalIdentities.size == 1) null else nodeInfo.legalIdentities[1].owningKey, serviceIdentity = if (nodeInfo.legalIdentities.size == 1) null else nodeInfo.legalIdentities[1].owningKey,
advertisedAddress = nodeInfo.addresses[0], advertisedAddress = nodeInfo.addresses[0],
maxMessageSize = networkParameters.maxMessageSize networkParams = networkParameters
) )
} }

View File

@ -1,6 +1,7 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ThreadBox
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.internal.errors.AddressBindingException import net.corda.core.internal.errors.AddressBindingException
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
@ -89,6 +90,11 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
override val started: Boolean override val started: Boolean
get() = activeMQServer.isStarted get() = activeMQServer.isStarted
@VisibleForTesting
fun totalMessagesAcknowledged(): Long {
return activeMQServer.totalMessagesAcknowledged
}
// TODO: Maybe wrap [IOException] on a key store load error so that it's clearly splitting key store loading from // TODO: Maybe wrap [IOException] on a key store load error so that it's clearly splitting key store loading from
// Artemis IO errors // Artemis IO errors
@Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) @Throws(IOException::class, AddressBindingException::class, KeyStoreException::class)

View File

@ -10,6 +10,7 @@ import net.corda.core.internal.ThreadBox
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
@ -55,6 +56,7 @@ import rx.Observable
import rx.Subscription import rx.Subscription
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -133,6 +135,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
private var serviceIdentity: PublicKey? = null private var serviceIdentity: PublicKey? = null
private lateinit var advertisedAddress: NetworkHostAndPort private lateinit var advertisedAddress: NetworkHostAndPort
private var maxMessageSize: Int = -1 private var maxMessageSize: Int = -1
private lateinit var eventHorizon: Duration
override val myAddress: SingleMessageRecipient get() = NodeAddress(myIdentity) override val myAddress: SingleMessageRecipient get() = NodeAddress(myIdentity)
override val ourSenderUUID = UUID.randomUUID().toString() override val ourSenderUUID = UUID.randomUUID().toString()
@ -153,13 +156,14 @@ class P2PMessagingClient(val config: NodeConfiguration,
* @param serviceIdentity An optional second identity if the node is also part of a group address, for example a notary. * @param serviceIdentity An optional second identity if the node is also part of a group address, for example a notary.
* @param advertisedAddress The externally advertised version of the Artemis broker address used to construct myAddress and included * @param advertisedAddress The externally advertised version of the Artemis broker address used to construct myAddress and included
* in the network map data. * in the network map data.
* @param maxMessageSize A bound applied to the message size. * @param networkParams the network parameters when the service is started.
*/ */
fun start(myIdentity: PublicKey, serviceIdentity: PublicKey?, maxMessageSize: Int, advertisedAddress: NetworkHostAndPort = serverAddress) { fun start(myIdentity: PublicKey, serviceIdentity: PublicKey?, networkParams: NetworkParameters, advertisedAddress: NetworkHostAndPort = serverAddress) {
this.myIdentity = myIdentity this.myIdentity = myIdentity
this.serviceIdentity = serviceIdentity this.serviceIdentity = serviceIdentity
this.advertisedAddress = advertisedAddress this.advertisedAddress = advertisedAddress
this.maxMessageSize = maxMessageSize this.maxMessageSize = networkParams.maxMessageSize
this.eventHorizon = networkParams.eventHorizon
state.locked { state.locked {
started = true started = true
log.info("Connecting to message broker: $serverAddress") log.info("Connecting to message broker: $serverAddress")
@ -405,6 +409,15 @@ class P2PMessagingClient(val config: NodeConfiguration,
internal fun deliver(artemisMessage: ClientMessage) { internal fun deliver(artemisMessage: ClientMessage) {
artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
if (isTooOld(cordaMessage)) {
log.info("Discarding old message message with identifier: ${cordaMessage.uniqueMessageId}, " +
"senderUUID: ${cordaMessage.senderUUID}, " +
"senderSeqNo: ${cordaMessage.senderSeqNo}, " +
"timestamp: ${cordaMessage.uniqueMessageId.timestamp}")
messagingExecutor!!.acknowledge(artemisMessage)
return
}
if (cordaMessage.uniqueMessageId.messageType == MessageType.SESSION_INIT) { if (cordaMessage.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
if (!deduplicator.isDuplicateSessionInit(cordaMessage)) { if (!deduplicator.isDuplicateSessionInit(cordaMessage)) {
deduplicator.signalMessageProcessStart(cordaMessage) deduplicator.signalMessageProcessStart(cordaMessage)
@ -420,6 +433,10 @@ class P2PMessagingClient(val config: NodeConfiguration,
} }
} }
private fun isTooOld(msg: ReceivedMessage): Boolean {
return msg.uniqueMessageId.timestamp.isBefore(Instant.now().minus(eventHorizon))
}
private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) { private fun deliver(msg: ReceivedMessage, artemisMessage: ClientMessage) {
state.checkNotLocked() state.checkNotLocked()
val deliverTo = handlers[msg.topic] val deliverTo = handlers[msg.topic]