Introduce MessageIdentifier and related tests

This commit is contained in:
Dimos Raptis 2020-08-14 11:44:07 +01:00
parent 8ed3dc1150
commit 64e7fdd83a
43 changed files with 1058 additions and 355 deletions

View File

@ -24,6 +24,8 @@ import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase import net.corda.testing.internal.configureDatabase
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.internal.MOCK_VERSION_INFO import net.corda.testing.node.internal.MOCK_VERSION_INFO
import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException
@ -35,7 +37,9 @@ import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.math.BigInteger
import java.net.ServerSocket import java.net.ServerSocket
import java.time.Clock
import java.util.concurrent.BlockingQueue import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit.MILLISECONDS import java.util.concurrent.TimeUnit.MILLISECONDS
@ -47,6 +51,7 @@ import kotlin.test.assertTrue
class ArtemisMessagingTest { class ArtemisMessagingTest {
companion object { companion object {
const val TOPIC = "platform.self" const val TOPIC = "platform.self"
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
} }
@Rule @Rule
@ -142,7 +147,7 @@ class ArtemisMessagingTest {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `client should be able to send message to itself`() { fun `client should be able to send message to itself`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer() val (messagingClient, receivedMessages) = createAndStartClientAndServer()
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray()) val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress) messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take() val actual: Message = receivedMessages.take()
@ -153,14 +158,14 @@ class ArtemisMessagingTest {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `client should fail if message exceed maxMessageSize limit`() { fun `client should fail if message exceed maxMessageSize limit`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer() val (messagingClient, receivedMessages) = createAndStartClientAndServer()
val message = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE)) val message = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress) messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take() val actual: Message = receivedMessages.take()
assertTrue(ByteArray(MAX_MESSAGE_SIZE).contentEquals(actual.data.bytes)) assertTrue(ByteArray(MAX_MESSAGE_SIZE).contentEquals(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS)) assertNull(receivedMessages.poll(200, MILLISECONDS))
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE + 1)) val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE + 1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
assertThatThrownBy { assertThatThrownBy {
messagingClient.send(tooLagerMessage, messagingClient.myAddress) messagingClient.send(tooLagerMessage, messagingClient.myAddress)
}.isInstanceOf(IllegalArgumentException::class.java) }.isInstanceOf(IllegalArgumentException::class.java)
@ -172,14 +177,14 @@ class ArtemisMessagingTest {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `server should not process if incoming message exceed maxMessageSize limit`() { fun `server should not process if incoming message exceed maxMessageSize limit`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(clientMaxMessageSize = 100_000, serverMaxMessageSize = 50_000) val (messagingClient, receivedMessages) = createAndStartClientAndServer(clientMaxMessageSize = 100_000, serverMaxMessageSize = 50_000)
val message = messagingClient.createMessage(TOPIC, data = ByteArray(50_000)) val message = messagingClient.createMessage(TOPIC, ByteArray(50_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress) messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take() val actual: Message = receivedMessages.take()
assertTrue(ByteArray(50_000).contentEquals(actual.data.bytes)) assertTrue(ByteArray(50_000).contentEquals(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS)) assertNull(receivedMessages.poll(200, MILLISECONDS))
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(100_000)) val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(100_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
assertThatThrownBy { assertThatThrownBy {
messagingClient.send(tooLagerMessage, messagingClient.myAddress) messagingClient.send(tooLagerMessage, messagingClient.myAddress)
}.isInstanceOf(ActiveMQConnectionTimedOutException::class.java) }.isInstanceOf(ActiveMQConnectionTimedOutException::class.java)
@ -189,7 +194,7 @@ class ArtemisMessagingTest {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `platform version is included in the message`() { fun `platform version is included in the message`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3) val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3)
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray()) val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress) messagingClient.send(message, messagingClient.myAddress)
val received = receivedMessages.take() val received = receivedMessages.take()

View File

@ -10,9 +10,13 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.send import net.corda.node.services.messaging.send
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.testing.driver.DriverDSL import net.corda.testing.driver.DriverDSL
import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.InProcess import net.corda.testing.driver.InProcess
@ -23,12 +27,15 @@ import net.corda.testing.node.NotarySpec
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Ignore import org.junit.Ignore
import org.junit.Test import org.junit.Test
import java.math.BigInteger
import java.time.Clock
import java.util.* import java.util.*
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
class P2PMessagingTest { class P2PMessagingTest {
private companion object { private companion object {
val DISTRIBUTED_SERVICE_NAME = CordaX500Name("DistributedService", "London", "GB") val DISTRIBUTED_SERVICE_NAME = CordaX500Name("DistributedService", "London", "GB")
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
} }
@Test(timeout=300_000) @Test(timeout=300_000)
@ -72,7 +79,7 @@ class P2PMessagingTest {
private fun InProcess.respondWith(message: Any) { private fun InProcess.respondWith(message: Any) {
internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler -> internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
val request = netMessage.data.deserialize<TestRequest>() val request = netMessage.data.deserialize<TestRequest>()
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes) val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
internalServices.networkService.send(response, request.replyTo) internalServices.networkService.send(response, request.replyTo)
handler.afterDatabaseTransaction() handler.afterDatabaseTransaction()
} }
@ -83,7 +90,7 @@ class P2PMessagingTest {
internalServices.networkService.runOnNextMessage("test.response") { netMessage -> internalServices.networkService.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize()) response.set(netMessage.data.deserialize())
} }
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target) internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
return response return response
} }

View File

@ -0,0 +1,93 @@
package net.corda.node.services.messaging
import net.corda.core.crypto.SecureHash
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import java.lang.IllegalStateException
import java.math.BigInteger
import java.time.Instant
/**
* This represents the unique identifier for every message.
* It's composed of multiple segments.
*
* @property messageType the type of the message.
* @property shardIdentifier an identifier that can be used to partition messages into groups for sharding purposes.
* This is supposed to have the same value for messages that correspond to the same business-level flow. It is
* @property sessionIdentifier the identifier of the session this message belongs to. This corresponds to the identifier of the session on the receiving side.
* @property sessionSequenceNumber the sequence number of the message inside the session. This can be used to handle out-of-order delivery.
* @property timestamp the time when the message was requested to be sent.
* This is expected to remain the same across replays of the same message and represent the moment in time when this message was initially scheduled to be sent.
*/
data class MessageIdentifier(
val messageType: MessageType,
val shardIdentifier: String,
val sessionIdentifier: SessionId,
val sessionSequenceNumber: Int,
val timestamp: Instant
) {
init {
require(shardIdentifier.length == 8) { "Shard identifier needs to be 8 characters long, but it was $shardIdentifier" }
}
companion object {
const val LONG_SIZE_IN_HEX = 16 // 64 / 4
const val SESSION_ID_SIZE_IN_HEX = SessionId.MAX_BIT_SIZE / 4
fun parse(id: String): MessageIdentifier {
val prefix = id.substring(0, 2)
val messageType = prefixToMessageType(prefix)
val timestamp = java.lang.Long.parseUnsignedLong(id.substring(3, 19), 16)
val shardIdentifier = id.substring(20, 28)
val sessionId = BigInteger(id.substring(29, 61), 16)
val sessionSequenceNumber = Integer.parseInt(id.substring(62), 16)
return MessageIdentifier(messageType, shardIdentifier, SessionId(sessionId), sessionSequenceNumber, Instant.ofEpochMilli(timestamp))
}
private fun messageTypeToPrefix(messageType: MessageType): String {
return when(messageType) {
MessageType.SESSION_INIT -> "XI"
MessageType.SESSION_CONFIRM -> "XC"
MessageType.SESSION_REJECT -> "XR"
MessageType.DATA_MESSAGE -> "XD"
MessageType.SESSION_END -> "XE"
MessageType.SESSION_ERROR -> "XX"
}
}
private fun prefixToMessageType(prefix: String): MessageType {
return when(prefix) {
"XI" -> MessageType.SESSION_INIT
"XC" -> MessageType.SESSION_CONFIRM
"XR" -> MessageType.SESSION_REJECT
"XD" -> MessageType.DATA_MESSAGE
"XE" -> MessageType.SESSION_END
"XX" -> MessageType.SESSION_ERROR
else -> throw IllegalStateException("Invalid prefix: $prefix")
}
}
}
override fun toString(): String {
val prefix = messageTypeToPrefix(messageType)
val encodedSessionIdentifier = String.format("%1$0${SESSION_ID_SIZE_IN_HEX}X", sessionIdentifier.value)
val encodedSequenceNumber = Integer.toHexString(sessionSequenceNumber).toUpperCase()
val encodedTimestamp = String.format("%1$0${LONG_SIZE_IN_HEX}X", timestamp.toEpochMilli())
return "$prefix-$encodedTimestamp-$shardIdentifier-$encodedSessionIdentifier-$encodedSequenceNumber"
}
}
fun generateShardId(flowIdentifier: String): String {
return SecureHash.sha256(flowIdentifier).prefixChars(8)
}
/**
* A unique identifier for a sender that might be different across restarts.
* It is used to help identify when messages are being sent continuously without errors or message are sent after the sender recovered from an error.
*/
typealias SenderUUID = String
/**
* A global sequence number for all the messages sent by a sender.
*/
typealias SenderSequenceNumber = Long

View File

@ -1,7 +1,6 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
@ -9,9 +8,8 @@ import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SessionId
import net.corda.nodeapi.internal.lifecycle.ServiceLifecycleSupport import net.corda.nodeapi.internal.lifecycle.ServiceLifecycleSupport
import java.time.Instant import java.time.Instant
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
@ -32,7 +30,7 @@ interface MessagingService : ServiceLifecycleSupport {
* A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence * A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence
* number for message de-duplication at the recipient. * number for message de-duplication at the recipient.
*/ */
val ourSenderUUID: String val ourSenderUUID: SenderUUID
/** /**
* The provided function will be invoked for each received message whose topic and session matches. The callback * The provided function will be invoked for each received message whose topic and session matches. The callback
@ -92,15 +90,25 @@ interface MessagingService : ServiceLifecycleSupport {
@Suspendable @Suspendable
fun sendAll(addressedMessages: List<AddressedMessage>) fun sendAll(addressedMessages: List<AddressedMessage>)
/**
* Signal that a session has ended to the messaging layer, so that any necessary cleanup is performed.
*
* @param sessionId the identifier of the session that ended.
* @param senderUUID the sender UUID of the last message seen in the session or null if there was no sender UUID in that message.
* @param senderSequenceNumber the sender sequence number of the last message seen in the session or null if there was no sender sequence number in that message.
*/
@Suspendable
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
/** /**
* Returns an initialised [Message] with the current time, etc, already filled in. * Returns an initialised [Message] with the current time, etc, already filled in.
* *
* @param topic identifier for the topic the message is sent to. * @param topic identifier for the topic the message is sent to.
* @param data the payload for the message. * @param data the payload for the message.
* @param deduplicationId optional message deduplication ID including sender identifier. * @param deduplicationInfo optional message deduplication information.
* @param additionalHeaders optional additional message headers. * @param additionalHeaders optional additional message headers.
*/ */
fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()): Message fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message
/** Given information about either a specific node or a service returns its corresponding address */ /** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -109,7 +117,7 @@ interface MessagingService : ServiceLifecycleSupport {
val myAddress: SingleMessageRecipient val myAddress: SingleMessageRecipient
} }
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to) fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationInfo, additionalHeaders), to)
interface MessageHandlerRegistration interface MessageHandlerRegistration
@ -128,7 +136,7 @@ interface Message {
val topic: String val topic: String
val data: ByteSequence val data: ByteSequence
val debugTimestamp: Instant val debugTimestamp: Instant
val uniqueMessageId: DeduplicationId val uniqueMessageId: MessageIdentifier
val senderUUID: String? val senderUUID: String?
val additionalHeaders: Map<String, String> val additionalHeaders: Map<String, String>
} }

View File

@ -75,7 +75,7 @@ class MessagingExecutor(
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic)) putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes) writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString)) putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
// If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut. // If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut.
if (ourSenderUUID == message.senderUUID) { if (ourSenderUUID == message.senderUUID) {
putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID)) putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID))

View File

@ -1,56 +1,79 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.NamedCacheFactory import net.corda.core.internal.NamedCacheFactory
import net.corda.node.services.statemachine.DeduplicationId import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession
import java.math.BigInteger
import java.time.Instant import java.time.Instant
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column import javax.persistence.Column
import javax.persistence.Entity import javax.persistence.Entity
import javax.persistence.Id import javax.persistence.Id
import javax.persistence.Table
/** /**
* Encapsulate the de-duplication logic. * This component is responsible for determining whether session-init messages are duplicates and it also keeps track of information related to
* sessions that can be used for this purpose.
*/ */
class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val database: CordaPersistence) { class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val database: CordaPersistence) {
companion object {
private val logger = contextLogger()
}
// A temporary in-memory set of deduplication IDs and associated high water mark details. // A temporary in-memory set of deduplication IDs and associated high water mark details.
// When we receive a message we don't persist the ID immediately, // When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may // so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may
// redeliver messages to the same consumer if they weren't ACKed. // redeliver messages to the same consumer if they weren't ACKed.
private val beingProcessedMessages = ConcurrentHashMap<DeduplicationId, MessageMeta>() private val beingProcessedMessages = ConcurrentHashMap<MessageIdentifier, MessageMeta>()
private val processedMessages = createProcessedMessages(cacheFactory)
private fun createProcessedMessages(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<DeduplicationId, MessageMeta, ProcessedMessage, String> { /**
* This table holds data *only* for sessions that have been initiated from a counterparty (e.g. ones we have received session-init messages from).
* This is because any other messages apart from session-init messages are deduplicated by the state machine.
*/
private val sessionData = createSessionDataMap(cacheFactory)
private fun createSessionDataMap(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<SessionId, MessageMeta, SessionData, BigInteger> {
return AppendOnlyPersistentMap( return AppendOnlyPersistentMap(
cacheFactory = cacheFactory, cacheFactory = cacheFactory,
name = "P2PMessageDeduplicator_processedMessages", name = "P2PMessageDeduplicator_sessionData",
toPersistentEntityKey = { it.toString }, toPersistentEntityKey = { it.value },
fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) }, fromPersistentEntity = { Pair(SessionId(it.sessionId), MessageMeta(it.generationTime, it.senderHash, it.firstSenderSeqNo, it.lastSenderSeqNo)) },
toPersistentEntity = { key: DeduplicationId, value: MessageMeta -> toPersistentEntity = { key: SessionId, value: MessageMeta ->
ProcessedMessage().apply { SessionData().apply {
id = key.toString sessionId = key.value
insertionTime = value.insertionTime generationTime = value.generationTime
hash = value.senderHash senderHash = value.senderHash
seqNo = value.senderSeqNo firstSenderSeqNo = value.firstSenderSeqNo
lastSenderSeqNo = value.lastSenderSeqNo
} }
}, },
persistentEntityClass = ProcessedMessage::class.java persistentEntityClass = SessionData::class.java
) )
} }
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages } private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId.sessionIdentifier in sessionData }
// We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache. // We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache.
private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString() private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString()
/** /**
* Determines whether a session-init message is a duplicate.
* This is achieved by checking whether this message is currently being processed or if the associated session has already been created in the past.
* This method should be invoked only with session-init messages, otherwise it will fail with an [IllegalArgumentException].
*
* @return true if we have seen this message before. * @return true if we have seen this message before.
*/ */
fun isDuplicate(msg: ReceivedMessage): Boolean { fun isDuplicateSessionInit(msg: ReceivedMessage): Boolean {
require(msg.isSessionInit) { "Message ${msg.uniqueMessageId} was not a session-init message." }
if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) { if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) {
return true return true
} }
@ -65,44 +88,86 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
val receivedSenderSeqNo = msg.senderSeqNo val receivedSenderSeqNo = msg.senderSeqNo
// We don't want a mix of nulls and values so we ensure that here. // We don't want a mix of nulls and values so we ensure that here.
val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null
val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null val firstSenderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo) beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(msg.uniqueMessageId.timestamp, senderHash, firstSenderSeqNo, null)
} }
/** /**
* Called inside a DB transaction to persist [deduplicationId]. * Called inside a DB transaction to persist [deduplicationId].
*/ */
fun persistDeduplicationId(deduplicationId: DeduplicationId) { fun persistDeduplicationId(deduplicationId: MessageIdentifier) {
processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!! sessionData[deduplicationId.sessionIdentifier] = beingProcessedMessages[deduplicationId]!!
} }
/** /**
* Called after the DB transaction persisting [deduplicationId] committed. * Called after the DB transaction persisting [deduplicationId] committed.
* Any subsequent redelivery will be deduplicated using the DB. * Any subsequent redelivery will be deduplicated using the DB.
*/ */
fun signalMessageProcessFinish(deduplicationId: DeduplicationId) { fun signalMessageProcessFinish(deduplicationId: MessageIdentifier) {
beingProcessedMessages.remove(deduplicationId) beingProcessedMessages.remove(deduplicationId)
} }
/**
* Called inside a DB transaction to update entry for corresponding session.
* The parameters [senderUUID] and [senderSequenceNumber] correspond to the last message seen from this session before it ended.
* If [senderUUID] is not null, then [senderSequenceNumber] is also expected to not be null.
*/
@Suspendable
fun signalSessionEnd(sessionId: SessionId, senderUUID: String?, senderSequenceNumber: Long?) {
if (senderSequenceNumber != null && senderUUID != null) {
val existingEntry = sessionData[sessionId]
if (existingEntry != null) {
val newEntry = existingEntry.copy(lastSenderSeqNo = senderSequenceNumber)
sessionData.addOrUpdate(sessionId, newEntry) { k, v ->
update(k, v)
}
}
}
}
private fun update(key: SessionId, value: MessageMeta): Boolean {
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(SessionData::class.java)
val queryRoot = criteriaUpdate.from(SessionData::class.java)
criteriaUpdate.set(SessionData::lastSenderSeqNo.name, value.lastSenderSeqNo)
criteriaUpdate.where(criteriaBuilder.equal(queryRoot.get<BigInteger>(SessionData::sessionId.name), key.value))
val update = session.createQuery(criteriaUpdate)
val rowsUpdated = update.executeUpdate()
return rowsUpdated != 0
}
@Entity @Entity
@Suppress("MagicNumber") // database column width @Suppress("MagicNumber") // database column width
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") @Table(name = "${NODE_DATABASE_PREFIX}session_data")
class ProcessedMessage( class SessionData (
@Id @Id
@Column(name = "message_id", length = 64, nullable = false) @Column(name = "session_id", nullable = false)
var id: String = "", var sessionId: BigInteger = BigInteger.ZERO,
@Column(name = "insertion_time", nullable = false) /**
var insertionTime: Instant = Instant.now(), * The time the corresponding session-init message was originally generated on the sender side.
*/
@Column(name = "init_generation_time", nullable = false)
var generationTime: Instant = Instant.now(),
@Column(name = "sender", length = 64, nullable = true) @Column(name = "sender_hash", length = 64, nullable = true)
var hash: String? = "", var senderHash: String? = "",
@Column(name = "sequence_number", nullable = true) /**
var seqNo: Long? = null * The sender sequence number of the first message seen in a session.
*/
@Column(name = "init_sequence_number", nullable = true)
var firstSenderSeqNo: Long? = null,
/**
* The sender sequence number of the last message seen in a session before it was closed/terminated.
*/
@Column(name = "last_sequence_number", nullable = true)
var lastSenderSeqNo: Long? = null
) )
private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?) private data class MessageMeta(val generationTime: Instant, val senderHash: String?, val firstSenderSeqNo: SenderSequenceNumber?, val lastSenderSeqNo: SenderSequenceNumber?)
private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean) private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean)
} }

View File

@ -25,9 +25,9 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer
import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex
import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.errorAndTerminate import net.corda.node.utilities.errorAndTerminate
import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.ArtemisMessagingComponent
@ -101,8 +101,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
private class NodeClientMessage(override val topic: String, private class NodeClientMessage(override val topic: String,
override val data: ByteSequence, override val data: ByteSequence,
override val uniqueMessageId: DeduplicationId, override val uniqueMessageId: MessageIdentifier,
override val senderUUID: String?, override val senderUUID: SenderUUID?,
override val additionalHeaders: Map<String, String>) : Message { override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now() override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}" override fun toString() = "$topic#${String(data.bytes)}"
@ -371,7 +371,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" } val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) } val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) } val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { MessageIdentifier.parse(message.getStringProperty(it)) }
val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID) val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID)
val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null
val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE
@ -392,8 +392,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
private class ArtemisReceivedMessage(override val topic: String, private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name, override val peer: CordaX500Name,
override val platformVersion: Int, override val platformVersion: Int,
override val uniqueMessageId: DeduplicationId, override val uniqueMessageId: MessageIdentifier,
override val senderUUID: String?, override val senderUUID: SenderUUID?,
override val senderSeqNo: Long?, override val senderSeqNo: Long?,
override val isSessionInit: Boolean, override val isSessionInit: Boolean,
private val message: ClientMessage) : ReceivedMessage { private val message: ClientMessage) : ReceivedMessage {
@ -405,12 +405,17 @@ class P2PMessagingClient(val config: NodeConfiguration,
internal fun deliver(artemisMessage: ClientMessage) { internal fun deliver(artemisMessage: ClientMessage) {
artemisToCordaMessage(artemisMessage)?.let { cordaMessage -> artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
if (!deduplicator.isDuplicate(cordaMessage)) { if (cordaMessage.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
deduplicator.signalMessageProcessStart(cordaMessage) if (!deduplicator.isDuplicateSessionInit(cordaMessage)) {
deliver(cordaMessage, artemisMessage) deduplicator.signalMessageProcessStart(cordaMessage)
deliver(cordaMessage, artemisMessage)
} else {
log.debug { "Discarding duplicate session-init message with identifier: ${cordaMessage.uniqueMessageId}, senderUUID: ${cordaMessage.senderUUID}, senderSeqNo: ${cordaMessage.senderSeqNo}" }
messagingExecutor!!.acknowledge(artemisMessage)
}
} else { } else {
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" } // non session-init messages are directly handed to the state machine, which is responsible for performing deduplication.
messagingExecutor!!.acknowledge(artemisMessage) deliver(cordaMessage, artemisMessage)
} }
} }
} }
@ -420,7 +425,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
val deliverTo = handlers[msg.topic] val deliverTo = handlers[msg.topic]
if (deliverTo != null) { if (deliverTo != null) {
try { try {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg)) if (msg.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForSessionInitMessages(artemisMessage, msg))
} else {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForRegularMessages(artemisMessage, msg))
}
} catch (e: Exception) { } catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e) log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
} }
@ -429,11 +438,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
} }
} }
private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent { private inner class MessageDeduplicationHandlerForSessionInitMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent override val externalCause: ExternalEvent
get() = this get() = this
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() } override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
override val deduplicationHandler: MessageDeduplicationHandler override val deduplicationHandler: MessageDeduplicationHandlerForSessionInitMessages
get() = this get() = this
override fun insideDatabaseTransaction() { override fun insideDatabaseTransaction() {
@ -450,6 +459,27 @@ class P2PMessagingClient(val config: NodeConfiguration,
} }
} }
private inner class MessageDeduplicationHandlerForRegularMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
override val deduplicationHandler: MessageDeduplicationHandlerForRegularMessages
get() = this
/**
* Nothing to do, since deduplication information is kept in the state machine.
*/
override fun insideDatabaseTransaction() {}
override fun afterDatabaseTransaction() {
messagingExecutor!!.acknowledge(artemisMessage)
}
override fun toString(): String {
return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})"
}
}
/** /**
* Initiates shutdown: if called from a thread that isn't controlled by the executor passed to the constructor * Initiates shutdown: if called from a thread that isn't controlled by the executor passed to the constructor
* then this will block until all in-flight messages have finished being handled and acknowledged. If called * then this will block until all in-flight messages have finished being handled and acknowledged. If called
@ -520,6 +550,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
} }
} }
@Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
deduplicator.signalSessionEnd(sessionId, senderUUID, senderSequenceNumber)
}
override fun resolveTargetToArtemisQueue(address: MessageRecipients): String { override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
return if (address == myAddress) { return if (address == myAddress) {
// If we are sending to ourselves then route the message directly to our P2P queue. // If we are sending to ourselves then route the message directly to our P2P queue.
@ -586,8 +621,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
handlers.remove(registration.topic) handlers.remove(registration.topic)
} }
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message { override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders) return NodeClientMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, deduplicationInfo.senderUUID, additionalHeaders)
} }
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {

View File

@ -0,0 +1,11 @@
package net.corda.node.services.messaging
/**
* This is a combination of a message's unique identifier along with a unique identifier for the sender of the message.
* The former can be used independently for deduplication purposes when receiving a message, but enriching it with the latter helps us
* optimise some paths and perform smarter deduplication logic per sender.
*
* The [senderUUID] property might be null if the flow is trying to replay messages and doesn't want an optimisation to ignore the message identifier
* because it could lead to false negatives (messages that are deemed duplicates, but are not).
*/
data class SenderDeduplicationInfo(val messageIdentifier: MessageIdentifier, val senderUUID: String?)

View File

@ -513,7 +513,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? { private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? {
return if (this is SessionState.Initiated) { return if (this is SessionState.Initiated) {
ActiveSession(peerParty, sessionId, receivedMessages, peerFlowInfo, peerSinkSessionId) ActiveSession(peerParty, sessionId, receivedMessages.values.toList(), peerFlowInfo, peerSinkSessionId)
} else { } else {
null null
} }

View File

@ -43,7 +43,7 @@ class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet()
BasicHSMKeyManagementService.PersistentKey::class.java, BasicHSMKeyManagementService.PersistentKey::class.java,
NodeSchedulerService.PersistentScheduledState::class.java, NodeSchedulerService.PersistentScheduledState::class.java,
NodeAttachmentService.DBAttachment::class.java, NodeAttachmentService.DBAttachment::class.java,
P2PMessageDeduplicator.ProcessedMessage::class.java, P2PMessageDeduplicator.SessionData::class.java,
PersistentIdentityService.PersistentPublicKeyHashToCertificate::class.java, PersistentIdentityService.PersistentPublicKeyHashToCertificate::class.java,
PersistentIdentityService.PersistentPublicKeyHashToParty::class.java, PersistentIdentityService.PersistentPublicKeyHashToParty::class.java,
PersistentIdentityService.PersistentHashToPublicKey::class.java, PersistentIdentityService.PersistentHashToPublicKey::class.java,

View File

@ -6,6 +6,10 @@ import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FlowAsyncOperation import net.corda.core.internal.FlowAsyncOperation
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
@ -25,7 +29,7 @@ sealed class Action {
data class SendInitial( data class SendInitial(
val destination: Destination, val destination: Destination,
val initialise: InitialSessionMessage, val initialise: InitialSessionMessage,
val deduplicationId: SenderDeduplicationId val deduplicationInfo: SenderDeduplicationInfo
) : Action() ) : Action()
/** /**
@ -34,7 +38,7 @@ sealed class Action {
data class SendExisting( data class SendExisting(
val peerParty: Party, val peerParty: Party,
val message: ExistingSessionMessage, val message: ExistingSessionMessage,
val deduplicationId: SenderDeduplicationId val deduplicationInfo: SenderDeduplicationInfo
) : Action() ) : Action()
/** /**
@ -95,12 +99,11 @@ sealed class Action {
data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action() data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/** /**
* Propagate [errorMessages] to [sessions]. * Propagate the specified error messages to the specified sessions.
* @param sessions a map from source session IDs to initiated sessions. * @param errorsPerSession a map containing the error messages to be sent per session along with their identifiers.
*/ */
data class PropagateErrors( data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>, val errorsPerSession: Map<SessionState.Initiated, List<Pair<MessageIdentifier, ErrorSessionMessage>>>,
val sessions: List<SessionState.Initiated>,
val senderUUID: String? val senderUUID: String?
) : Action() ) : Action()
@ -114,6 +117,11 @@ sealed class Action {
*/ */
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action() data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
/**
* Signal sessions ended at the messaging layer.
*/
data class SignalSessionsHasEnded(val terminatedSessions: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>): Action()
/** /**
* Signal that the flow corresponding to [flowId] is considered started. * Signal that the flow corresponding to [flowId] is considered started.
*/ */

View File

@ -10,6 +10,7 @@ import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.contextTransaction import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
@ -58,6 +59,7 @@ internal class ActionExecutorImpl(
is Action.AddSessionBinding -> executeAddSessionBinding(action) is Action.AddSessionBinding -> executeAddSessionBinding(action)
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action) is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action) is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
is Action.SignalSessionsHasEnded -> executeSignalSessionsHasEnded(action)
is Action.RemoveFlow -> executeRemoveFlow(action) is Action.RemoveFlow -> executeRemoveFlow(action)
is Action.CreateTransaction -> executeCreateTransaction() is Action.CreateTransaction -> executeCreateTransaction()
is Action.RollbackTransaction -> executeRollbackTransaction() is Action.RollbackTransaction -> executeRollbackTransaction()
@ -132,16 +134,17 @@ internal class ActionExecutorImpl(
@Suspendable @Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) { private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { (exception) -> val errors = action.errorsPerSession.values.flatMap { it.map { it.second } }.distinct()
errors.forEach { errorSessionMessage ->
val exception = errorSessionMessage.flowException
log.warn("Propagating error", exception) log.warn("Propagating error", exception)
} }
for (sessionState in action.sessions) { action.errorsPerSession.forEach { (sessionState, sessionErrors) ->
// Don't propagate errors to the originating session // Don't propagate errors to the originating session
for (errorMessage in action.errorMessages) { for ((id, msg) in sessionErrors) {
val sinkSessionId = sessionState.peerSinkSessionId val sinkSessionId = sessionState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) val errorMsg = ExistingSessionMessage(sinkSessionId, msg)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) flowMessaging.sendSessionMessage(sessionState.peerParty, errorMsg, SenderDeduplicationInfo(id, action.senderUUID))
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))
} }
} }
} }
@ -162,18 +165,18 @@ internal class ActionExecutorImpl(
@Suspendable @Suspendable
private fun executeSendInitial(action: Action.SendInitial) { private fun executeSendInitial(action: Action.SendInitial) {
flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationId) flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationInfo)
} }
@Suspendable @Suspendable
private fun executeSendExisting(action: Action.SendExisting) { private fun executeSendExisting(action: Action.SendExisting) {
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId) flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationInfo)
} }
@Suspendable @Suspendable
private fun executeSendMultiple(action: Action.SendMultiple) { private fun executeSendMultiple(action: Action.SendMultiple) {
val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationId) } + val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationInfo) } +
action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationId) } action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationInfo) }
flowMessaging.sendSessionMessages(messages) flowMessaging.sendSessionMessages(messages)
} }
@ -192,6 +195,13 @@ internal class ActionExecutorImpl(
stateMachineManager.signalFlowHasStarted(action.flowId) stateMachineManager.signalFlowHasStarted(action.flowId)
} }
@Suspendable
private fun executeSignalSessionsHasEnded(action: Action.SignalSessionsHasEnded) {
action.terminatedSessions.forEach { (sessionId, senderData) ->
flowMessaging.sessionEnded(sessionId, senderData.first, senderData.second)
}
}
@Suspendable @Suspendable
private fun executeRemoveFlow(action: Action.RemoveFlow) { private fun executeRemoveFlow(action: Action.RemoveFlow) {
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState) stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)

View File

@ -1,53 +0,0 @@
package net.corda.node.services.statemachine
import java.security.SecureRandom
/**
* A deduplication ID of a flow message.
*/
data class DeduplicationId(val toString: String) {
companion object {
/**
* Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
* unless we persist the ID somehow.
*/
fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")
/**
* Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
* creating IDs in case the message-generating flow logic is replayed on hard failure.
*
* A normal deduplication ID consists of:
* 1. A deduplication seed set per session. This is the initiator's session ID, with a prefix for initiator
* or initiated.
* 2. The number of *clean* suspends since the start of the flow.
* 3. An optional additional index, for cases where several messages are sent as part of the state transition.
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the
* message-id map to change, which means deduplication will not happen correctly.
*/
fun createForNormal(checkpoint: Checkpoint, index: Int, session: SessionState): DeduplicationId {
return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.checkpointState.numberOfSuspends}-$index")
}
/**
* Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
* IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
* dirtiness state to be thrown away for resumption.
*
* An error deduplication ID consists of:
* 1. The error's ID. This is a unique value per "source" of error and is propagated.
* See [net.corda.core.flows.IdentifiableException].
* 2. The recipient's session ID.
*/
fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
}
}
}
/**
* Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be
* null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID.
*/
data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?)

View File

@ -8,6 +8,9 @@ import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.util.UUID import java.util.UUID
/** /**
@ -34,7 +37,10 @@ sealed class Event {
data class DeliverSessionMessage( data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage, val sessionMessage: ExistingSessionMessage,
override val deduplicationHandler: DeduplicationHandler, override val deduplicationHandler: DeduplicationHandler,
val sender: Party val sender: Party,
val messageIdentifier: MessageIdentifier,
val senderUUID: SenderUUID?,
val senderSequenceNumber: SenderSequenceNumber?
) : Event(), GeneratedByExternalEvent ) : Event(), GeneratedByExternalEvent
/** /**

View File

@ -155,7 +155,8 @@ class FlowCreator(
frozenFlowLogic, frozenFlowLogic,
ourIdentity, ourIdentity,
flowCorDappVersion, flowCorDappVersion,
flowLogic.isEnabledTimedFlow() flowLogic.isEnabledTimedFlow(),
serviceHub.clock.instant()
).getOrThrow() ).getOrThrow()
val state = createStateMachineState( val state = createStateMachineState(
@ -253,6 +254,7 @@ class FlowCreator(
return StateMachineState( return StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isFlowResumed = false, isFlowResumed = false,
future = null, future = null,
isWaitingForFuture = false, isWaitingForFuture = false,

View File

@ -15,6 +15,9 @@ import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import java.io.NotSerializableException import java.io.NotSerializableException
@ -23,21 +26,24 @@ import java.io.NotSerializableException
*/ */
interface FlowMessaging { interface FlowMessaging {
/** /**
* Send [message] to [destination] using [deduplicationId]. * Send [message] to [destination] using [deduplicationInfo].
*/ */
@Suspendable @Suspendable
fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId) fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo)
@Suspendable @Suspendable
fun sendSessionMessages(messageData: List<Message>) fun sendSessionMessages(messageData: List<Message>)
@Suspendable
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
/** /**
* Start the messaging using the [onMessage] message handler. * Start the messaging using the [onMessage] message handler.
*/ */
fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit) fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
} }
data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupId: SenderDeduplicationId) data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupInfo: SenderDeduplicationInfo)
/** /**
* Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing. * Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
@ -56,18 +62,23 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
} }
@Suspendable @Suspendable
override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId) { override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo) {
val addressedMessage = createMessage(destination, message, deduplicationId) val addressedMessage = createMessage(destination, message, deduplicationInfo)
serviceHub.networkService.send(addressedMessage.message, addressedMessage.target, addressedMessage.sequenceKey) serviceHub.networkService.send(addressedMessage.message, addressedMessage.target, addressedMessage.sequenceKey)
} }
@Suspendable @Suspendable
override fun sendSessionMessages(messageData: List<Message>) { override fun sendSessionMessages(messageData: List<Message>) {
val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupId) } val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupInfo) }
serviceHub.networkService.sendAll(addressedMessages) serviceHub.networkService.sendAll(addressedMessages)
} }
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId): MessagingService.AddressedMessage { @Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
serviceHub.networkService.sessionEnded(sessionId, senderUUID, senderSequenceNumber)
}
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationInfo): MessagingService.AddressedMessage {
// We assume that the destination type has already been checked by initiateFlow. // We assume that the destination type has already been checked by initiateFlow.
// Destination may point to a stale well-known identity due to key rotation, so always resolve actual identity via IdentityService. // Destination may point to a stale well-known identity due to key rotation, so always resolve actual identity via IdentityService.
val party = requireNotNull(serviceHub.identityService.wellKnownPartyFromAnonymous(destination as AbstractParty)) { val party = requireNotNull(serviceHub.identityService.wellKnownPartyFromAnonymous(destination as AbstractParty)) {
@ -80,6 +91,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
} }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party)) val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
val partyInfo = requireNotNull(serviceHub.networkMapCache.getPartyInfo(party)) { "Don't know about ${party.description()}" } val partyInfo = requireNotNull(serviceHub.networkMapCache.getPartyInfo(party)) { "Don't know about ${party.description()}" }
val address = serviceHub.networkService.getAddressOfParty(partyInfo) val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val sequenceKey = when (message) { val sequenceKey = when (message) {
is InitialSessionMessage -> message.initiatorSessionId is InitialSessionMessage -> message.initiatorSessionId

View File

@ -179,7 +179,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val stateMachine = transientValues.stateMachine val stateMachine = transientValues.stateMachine
val oldState = transientState val oldState = transientState
val actionExecutor = transientValues.actionExecutor val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState) val transition = stateMachine.transition(event, oldState, serviceHub.clock.instant())
val (continuation, newState) = transitionExecutor.executeTransition( val (continuation, newState) = transitionExecutor.executeTransition(
this, this,
oldState, oldState,

View File

@ -4,6 +4,7 @@ import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import java.math.BigInteger
import java.security.SecureRandom import java.security.SecureRandom
/** /**
@ -21,9 +22,28 @@ import java.security.SecureRandom
sealed class SessionMessage sealed class SessionMessage
@CordaSerializable @CordaSerializable
data class SessionId(val toLong: Long) { data class SessionId(val value: BigInteger) {
init {
require(value.signum() >= 0) { "Session identifier cannot be a negative number, but it was $value" }
require(value.bitLength() <= MAX_BIT_SIZE) { "The size of a session identifier cannot exceed $MAX_BIT_SIZE bits, but it was $value" }
}
/**
* This calculates the initiated session ID assuming this is the initiating session ID.
* This is the next larger number in the range [0, 2^[MAX_BIT_SIZE]] with wrap around the largest number in the interval.
*/
fun calculateInitiatedSessionId(): SessionId {
return if (this.value == LARGEST_SESSION_ID)
SessionId(BigInteger.ZERO)
else
SessionId(this.value.plus(BigInteger.ONE))
}
companion object { companion object {
fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong()) const val MAX_BIT_SIZE = 128
val LARGEST_SESSION_ID = BigInteger.valueOf(2).pow(MAX_BIT_SIZE).minus(BigInteger.ONE)
fun createRandom(secureRandom: SecureRandom) = SessionId(BigInteger(MAX_BIT_SIZE, secureRandom))
} }
} }
@ -118,3 +138,29 @@ data class RejectSessionMessage(val message: String, val errorId: Long) : Existi
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished. * protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
*/ */
object EndSessionMessage : ExistingSessionMessagePayload() object EndSessionMessage : ExistingSessionMessagePayload()
enum class MessageType {
SESSION_INIT,
SESSION_CONFIRM,
SESSION_REJECT,
DATA_MESSAGE,
SESSION_END,
SESSION_ERROR;
companion object {
fun inferFromMessage(message: SessionMessage): MessageType {
return when (message) {
is InitialSessionMessage -> SESSION_INIT
is ExistingSessionMessage -> {
when(message.payload) {
is ConfirmSessionMessage -> SESSION_CONFIRM
is RejectSessionMessage -> SESSION_REJECT
is DataSessionMessage -> DATA_MESSAGE
is EndSessionMessage -> SESSION_END
is ErrorSessionMessage -> SESSION_ERROR
}
}
}
}
}
}

View File

@ -37,6 +37,10 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
@ -703,7 +707,7 @@ internal class SingleThreadedStateMachineManager(
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event) is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event, event.receivedMessage.uniqueMessageId, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event) is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event)
} }
} else { } else {
@ -716,7 +720,10 @@ internal class SingleThreadedStateMachineManager(
private fun onExistingSessionMessage( private fun onExistingSessionMessage(
sessionMessage: ExistingSessionMessage, sessionMessage: ExistingSessionMessage,
sender: Party, sender: Party,
externalEvent: ExternalEvent.ExternalMessageEvent externalEvent: ExternalEvent.ExternalMessageEvent,
messageIdentifier: MessageIdentifier,
senderUUID: SenderUUID?,
senderSequenceNumber: SenderSequenceNumber?
) { ) {
try { try {
val deduplicationHandler = externalEvent.deduplicationHandler val deduplicationHandler = externalEvent.deduplicationHandler
@ -734,7 +741,7 @@ internal class SingleThreadedStateMachineManager(
logger.info("Cannot find flow corresponding to session ID - $recipientId.") logger.info("Cannot find flow corresponding to session ID - $recipientId.")
} }
} else { } else {
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender) val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender, messageIdentifier, senderUUID, senderSequenceNumber)
innerState.withLock { innerState.withLock {
flows[flowId]?.run { fiber.scheduleEvent(event) } flows[flowId]?.run { fiber.scheduleEvent(event) }
// If flow is not running add it to the list of external events to be processed if/when the flow resumes. // If flow is not running add it to the list of external events to be processed if/when the flow resumes.
@ -751,7 +758,7 @@ internal class SingleThreadedStateMachineManager(
private fun onSessionInit(sessionMessage: InitialSessionMessage, sender: Party, event: ExternalEvent.ExternalMessageEvent) { private fun onSessionInit(sessionMessage: InitialSessionMessage, sender: Party, event: ExternalEvent.ExternalMessageEvent) {
try { try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage) val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
val initiatedSessionId = SessionId.createRandom(secureRandom) val initiatedSessionId = event.receivedMessage.uniqueMessageId.sessionIdentifier
val senderSession = FlowSessionImpl(sender, sender, initiatedSessionId) val senderSession = FlowSessionImpl(sender, sender, initiatedSessionId)
val flowLogic = initiatedFlowFactory.createFlow(senderSession) val flowLogic = initiatedFlowFactory.createFlow(senderSession)
val initiatedFlowInfo = when (initiatedFlowFactory) { val initiatedFlowInfo = when (initiatedFlowFactory) {
@ -763,9 +770,8 @@ internal class SingleThreadedStateMachineManager(
is InitiatedFlowFactory.CorDapp -> null is InitiatedFlowFactory.CorDapp -> null
} }
startInitiatedFlow( startInitiatedFlow(
event.flowId, event,
flowLogic, flowLogic,
event.deduplicationHandler,
senderSession, senderSession,
initiatedSessionId, initiatedSessionId,
sessionMessage, sessionMessage,
@ -800,24 +806,24 @@ internal class SingleThreadedStateMachineManager(
@Suppress("LongParameterList") @Suppress("LongParameterList")
private fun <A> startInitiatedFlow( private fun <A> startInitiatedFlow(
flowId: StateMachineRunId, event: ExternalEvent.ExternalMessageEvent,
flowLogic: FlowLogic<A>, flowLogic: FlowLogic<A>,
initiatingMessageDeduplicationHandler: DeduplicationHandler,
peerSession: FlowSessionImpl, peerSession: FlowSessionImpl,
initiatedSessionId: SessionId, initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage, initiatingMessage: InitialSessionMessage,
senderCoreFlowVersion: Int?, senderCoreFlowVersion: Int?,
initiatedFlowInfo: FlowInfo initiatedFlowInfo: FlowInfo
) { ) {
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo) val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo,
event.receivedMessage.uniqueMessageId.shardIdentifier, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
val ourIdentity = ourFirstIdentity val ourIdentity = ourFirstIdentity
startFlowInternal( startFlowInternal(
flowId, event.flowId,
InvocationContext.peer(peerSession.counterparty.name), InvocationContext.peer(peerSession.counterparty.name),
flowLogic, flowLogic,
flowStart, flowStart,
ourIdentity, ourIdentity,
initiatingMessageDeduplicationHandler event.deduplicationHandler
) )
} }

View File

@ -21,6 +21,8 @@ import net.corda.core.utilities.debug
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.services.FinalityHandler import net.corda.node.services.FinalityHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import org.hibernate.exception.ConstraintViolationException import org.hibernate.exception.ConstraintViolationException
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.io.Closeable import java.io.Closeable
@ -169,7 +171,9 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging,
log.info("Sending session initiation error back to $sender", error) log.info("Sending session initiation error back to $sender", error)
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID)) val messageType = MessageType.inferFromMessage(replyError)
val messageIdentifier = MessageIdentifier(messageType, event.receivedMessage.uniqueMessageId.shardIdentifier, sessionMessage.initiatorSessionId, 0, event.receivedMessage.uniqueMessageId.timestamp)
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationInfo(messageIdentifier, ourSenderUUID))
event.deduplicationHandler.afterDatabaseTransaction() event.deduplicationHandler.afterDatabaseTransaction()
} }

View File

@ -22,11 +22,16 @@ import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.lang.IllegalArgumentException
import java.lang.IllegalStateException import java.lang.IllegalStateException
import java.security.Principal import java.security.Principal
import java.time.Instant import java.time.Instant
import java.util.concurrent.Future import java.util.concurrent.Future
import java.util.concurrent.Semaphore import java.util.concurrent.Semaphore
import kotlin.math.max
/** /**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
@ -35,6 +40,7 @@ import java.util.concurrent.Semaphore
* @param checkpoint the persisted part of the state. * @param checkpoint the persisted part of the state.
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user. * @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
* @param pendingDeduplicationHandlers the list of incomplete deduplication handlers. * @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
* @param closedSessionsPendingToBeSignalled the sessions that have been closed and need to be signalled to the messaging layer on the next checkpoint (along with some metadata).
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used * @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
* to make [Event.DoRemainingWork] idempotent. * to make [Event.DoRemainingWork] idempotent.
* @param isWaitingForFuture true if the flow is waiting for the completion of a future triggered by one of the statemachine's actions * @param isWaitingForFuture true if the flow is waiting for the completion of a future triggered by one of the statemachine's actions
@ -61,6 +67,7 @@ data class StateMachineState(
val checkpoint: Checkpoint, val checkpoint: Checkpoint,
val flowLogic: FlowLogic<*>, val flowLogic: FlowLogic<*>,
val pendingDeduplicationHandlers: List<DeduplicationHandler>, val pendingDeduplicationHandlers: List<DeduplicationHandler>,
val closedSessionsPendingToBeSignalled: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>,
val isFlowResumed: Boolean, val isFlowResumed: Boolean,
val isWaitingForFuture: Boolean, val isWaitingForFuture: Boolean,
var future: Future<*>?, var future: Future<*>?,
@ -123,7 +130,8 @@ data class Checkpoint(
frozenFlowLogic: SerializedBytes<FlowLogic<*>>, frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
ourIdentity: Party, ourIdentity: Party,
subFlowVersion: SubFlowVersion, subFlowVersion: SubFlowVersion,
isEnabledTimedFlow: Boolean isEnabledTimedFlow: Boolean,
timestamp: Instant
): Try<Checkpoint> { ): Try<Checkpoint> {
return SubFlow.create(flowLogicClass, subFlowVersion, isEnabledTimedFlow).map { topLevelSubFlow -> return SubFlow.create(flowLogicClass, subFlowVersion, isEnabledTimedFlow).map { topLevelSubFlow ->
Checkpoint( Checkpoint(
@ -135,7 +143,8 @@ data class Checkpoint(
listOf(topLevelSubFlow), listOf(topLevelSubFlow),
numberOfSuspends = 0, numberOfSuspends = 0,
// We set this to 1 here to avoid an extra copy and increment in UnstartedFlowTransition.createInitialCheckpoint // We set this to 1 here to avoid an extra copy and increment in UnstartedFlowTransition.createInitialCheckpoint
numberOfCommits = 1 numberOfCommits = 1,
suspensionTime = timestamp
), ),
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
errorState = ErrorState.Clean errorState = ErrorState.Clean
@ -235,6 +244,7 @@ data class Checkpoint(
} }
/** /**
<<<<<<< HEAD
* @param invocationContext The initiator of the flow. * @param invocationContext The initiator of the flow.
* @param ourIdentity The identity the flow is run as. * @param ourIdentity The identity the flow is run as.
* @param sessions Map of source session ID to session state. * @param sessions Map of source session ID to session state.
@ -242,6 +252,7 @@ data class Checkpoint(
* @param subFlowStack The stack of currently executing subflows. * @param subFlowStack The stack of currently executing subflows.
* @param numberOfSuspends The number of flow suspends due to IO API calls. * @param numberOfSuspends The number of flow suspends due to IO API calls.
* @param numberOfCommits The number of times this checkpoint has been persisted. * @param numberOfCommits The number of times this checkpoint has been persisted.
* @param suspensionTime the time of the last suspension. This is supposed to be used as a stable timestamp in case of replays.
*/ */
@CordaSerializable @CordaSerializable
data class CheckpointState( data class CheckpointState(
@ -251,7 +262,8 @@ data class CheckpointState(
val sessionsToBeClosed: Set<SessionId>, val sessionsToBeClosed: Set<SessionId>,
val subFlowStack: List<SubFlow>, val subFlowStack: List<SubFlow>,
val numberOfSuspends: Int, val numberOfSuspends: Int,
val numberOfCommits: Int val numberOfCommits: Int,
val suspensionTime: Instant
) )
/** /**
@ -262,44 +274,162 @@ sealed class SessionState {
abstract val deduplicationSeed: String abstract val deduplicationSeed: String
/** /**
* We haven't yet sent the initialisation message * the sender UUID last seen in this session, if there was one.
*/
abstract val lastSenderUUID: SenderUUID?
/**
* the sender sequence number last seen in this session, if there was one.
*/
abstract val lastSenderSeqNo: SenderSequenceNumber?
/**
* the messages that have been received and are pending processing indexed by their sequence number.
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
* * [DataSessionMessage]
* * [ErrorSessionMessage]
* * [EndSessionMessage]
*/
abstract val receivedMessages: Map<Int, ExistingSessionMessagePayload>
/**
* Returns a new session state with the specified messages added to the list of received messages.
*/
fun addReceivedMessages(message: ExistingSessionMessagePayload, messageIdentifier: MessageIdentifier, senderUUID: String?, senderSequenceNumber: Long?): SessionState {
val newReceivedMessages = receivedMessages.plus(messageIdentifier.sessionSequenceNumber to message)
val (newLastSenderUUID, newLastSenderSeqNo) = calculateSenderInfo(lastSenderUUID, lastSenderSeqNo, senderUUID, senderSequenceNumber)
return when(this) {
is Uninitiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
is Initiating -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
is Initiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
}
}
private fun calculateSenderInfo(currentSender: String?, currentSenderSeqNo: Long?, msgSender: String?, msgSenderSeqNo: Long?): Pair<String?, Long?> {
return if (msgSender != null && msgSenderSeqNo != null) {
if (currentSenderSeqNo != null)
Pair(msgSender, max(msgSenderSeqNo, currentSenderSeqNo))
else
Pair(msgSender, msgSenderSeqNo)
} else {
Pair(currentSender, currentSenderSeqNo)
}
}
/**
* We haven't yet sent the initialisation message.
* This really means that the flow is in a state before sending the initialisation message,
* but in reality it could have sent it before and fail before reaching the next checkpoint, thus ending up replaying from the last checkpoint.
*
* @param hasBeenAcknowledged whether a positive response to a session initiation has already been received and the associated confirmation message, if so.
* @param hasBeenRejected whether a negative response to a session initiation has already been received and the associated rejection message, if so.
*/ */
data class Uninitiated( data class Uninitiated(
val destination: Destination, val destination: Destination,
val initiatingSubFlow: SubFlow.Initiating, val initiatingSubFlow: SubFlow.Initiating,
val sourceSessionId: SessionId, val sourceSessionId: SessionId,
val additionalEntropy: Long val additionalEntropy: Long,
val hasBeenAcknowledged: Pair<Party, ConfirmSessionMessage>?,
val hasBeenRejected: RejectSessionMessage?,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() { ) : SessionState() {
override val deduplicationSeed: String get() = "R-${sourceSessionId.toLong}-$additionalEntropy" override val deduplicationSeed: String get() = "R-${sourceSessionId.value}-$additionalEntropy"
} }
/** /**
* We have sent the initialisation message but have not yet received a confirmation. * We have sent the initialisation message but have not yet received a confirmation.
* @property bufferedMessages the messages that have been buffered to be sent after the session is confirmed from the other side.
* @property rejectionError if non-null the initiation failed. * @property rejectionError if non-null the initiation failed.
* @property nextSendingSeqNumber the sequence number of the next message to be sent.
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
*/ */
data class Initiating( data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>, val bufferedMessages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>,
val rejectionError: FlowError?, val rejectionError: FlowError?,
override val deduplicationSeed: String override val deduplicationSeed: String,
) : SessionState() val nextSendingSeqNumber: Int,
val shardId: String,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() {
/**
* Buffers an outgoing message to be sent when ready.
* Returns the new form of the state
*/
fun bufferMessage(messageIdentifier: MessageIdentifier, messagePayload: ExistingSessionMessagePayload): SessionState {
return this.copy(bufferedMessages = bufferedMessages + Pair(messageIdentifier, messagePayload), nextSendingSeqNumber = nextSendingSeqNumber + 1)
}
/**
* A batched form of [bufferMessage].
*/
fun bufferMessages(messages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>): SessionState {
return this.copy(bufferedMessages = bufferedMessages + messages, nextSendingSeqNumber = nextSendingSeqNumber + messages.size)
}
}
/** /**
* We have received a confirmation, the peer party and session id is resolved. * We have received a confirmation, the peer party and session id is resolved.
* @property receivedMessages the messages that have been received and are pending processing.
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
* * [DataSessionMessage]
* * [ErrorSessionMessage]
* * [EndSessionMessage]
* @property otherSideErrored whether the session has received an error from the other side. * @property otherSideErrored whether the session has received an error from the other side.
* @property nextSendingSeqNumber the sequence number that corresponds to the next message to be sent.
* @property lastProcessedSeqNumber the sequence number of the last message that has been processed.
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
*/ */
data class Initiated( data class Initiated(
val peerParty: Party, val peerParty: Party,
val peerFlowInfo: FlowInfo, val peerFlowInfo: FlowInfo,
val receivedMessages: List<ExistingSessionMessagePayload>,
val otherSideErrored: Boolean, val otherSideErrored: Boolean,
val peerSinkSessionId: SessionId, val peerSinkSessionId: SessionId,
override val deduplicationSeed: String override val deduplicationSeed: String,
) : SessionState() val nextSendingSeqNumber: Int,
val lastProcessedSeqNumber: Int,
val shardId: String,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() {
/**
* Indicates whether this message has already been processed.
*/
fun isDuplicate(messageIdentifier: MessageIdentifier): Boolean {
return messageIdentifier.sessionSequenceNumber <= lastProcessedSeqNumber
}
/**
* Indicates whether the session has an error message pending from the other side.
*/
fun hasErrored(): Boolean {
return hasNextMessageArrived() && receivedMessages[lastProcessedSeqNumber + 1] is ErrorSessionMessage
}
/**
* Indicates whether the next expected message has arrived.
*/
fun hasNextMessageArrived(): Boolean {
return receivedMessages.containsKey(lastProcessedSeqNumber + 1)
}
/**
* Returns the next message to be processed and the new session state.
* If you want to check first whether the next message has arrived, call [hasNextMessageArrived]
*
* @throws [IllegalArgumentException] if the next hasn't arrived.
*/
fun extractMessage(): Pair<ExistingSessionMessagePayload, Initiated> {
if (!hasNextMessageArrived()) {
throw IllegalArgumentException("Tried to extract a message that hasn't arrived yet.")
}
val message = receivedMessages[lastProcessedSeqNumber + 1]!!
val newState = this.copy(receivedMessages = receivedMessages.minus(lastProcessedSeqNumber + 1), lastProcessedSeqNumber = lastProcessedSeqNumber + 1)
return Pair(message, newState)
}
}
} }
typealias SessionMap = Map<SessionId, SessionState> typealias SessionMap = Map<SessionId, SessionState>
@ -321,7 +451,10 @@ sealed class FlowStart {
val initiatedSessionId: SessionId, val initiatedSessionId: SessionId,
val initiatingMessage: InitialSessionMessage, val initiatingMessage: InitialSessionMessage,
val senderCoreFlowVersion: Int?, val senderCoreFlowVersion: Int?,
val initiatedFlowInfo: FlowInfo val initiatedFlowInfo: FlowInfo,
val shardIdentifier: String,
val senderUUID: String?,
val senderSequenceNumber: Long?
) : FlowStart() { override fun toString() = "Initiated" } ) : FlowStart() { override fun toString() = "Initiated" }
} }

View File

@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.services.statemachine.Action import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage import net.corda.node.services.statemachine.DataSessionMessage
@ -13,7 +14,7 @@ import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.RejectSessionMessage import net.corda.node.services.statemachine.RejectSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
@ -87,56 +88,78 @@ class DeliverSessionMessageTransition(
val initiatedSession = SessionState.Initiated( val initiatedSession = SessionState.Initiated(
peerParty = event.sender, peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo, peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(), receivedMessages = emptyMap(),
peerSinkSessionId = message.initiatedSessionId, peerSinkSessionId = message.initiatedSessionId,
deduplicationSeed = sessionState.deduplicationSeed, deduplicationSeed = sessionState.deduplicationSeed,
otherSideErrored = false otherSideErrored = false,
nextSendingSeqNumber = sessionState.nextSendingSeqNumber,
lastProcessedSeqNumber = 0,
shardId = sessionState.shardId,
lastSenderUUID = event.senderUUID,
lastSenderSeqNo = event.senderSequenceNumber
) )
val newCheckpoint = currentState.checkpoint.addSession( val newCheckpoint = currentState.checkpoint.addSession(
event.sessionMessage.recipientSessionId to initiatedSession event.sessionMessage.recipientSessionId to initiatedSession
) )
// Send messages that were buffered pending confirmation of session. // Send messages that were buffered pending confirmation of session.
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) -> val sendActions = sessionState.bufferedMessages.map { (messageId, bufferedMessage) ->
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage) val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)) Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationInfo(messageId, startingState.senderUUID))
} }
actions.addAll(sendActions) actions.addAll(sendActions)
currentState = currentState.copy(checkpoint = newCheckpoint) currentState = currentState.copy(checkpoint = newCheckpoint)
} }
else -> freshErrorTransition(UnexpectedEventInState()) is SessionState.Initiated -> {
log.trace { "Discarding duplicate confirmation for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
is SessionState.Uninitiated -> {
val newSessionState = sessionState.copy(hasBeenAcknowledged = Pair(event.sender, message))
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
currentState = currentState.copy(checkpoint = newCheckpoint)
}
} }
} }
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) { private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
// We received a data message. The corresponding session must be Initiated.
return when (sessionState) { return when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated -> {
// Buffer the message in the session's receivedMessages buffer. if (!sessionState.isDuplicate(event.messageIdentifier)) {
val newSessionState = sessionState.copy( val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
receivedMessages = sessionState.receivedMessages + message currentState = currentState.copy(
) checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
} else {
log.trace { "Discarding duplicate data message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
}
is SessionState.Initiating, is SessionState.Uninitiated -> {
val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession( checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
event.sessionMessage.recipientSessionId to newSessionState
)
) )
} }
else -> freshErrorTransition(UnexpectedEventInState())
} }
} }
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) { private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
val sequenceNumber = event.messageIdentifier.sessionSequenceNumber
return when (sessionState) { return when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated -> {
val checkpoint = currentState.checkpoint if (sequenceNumber > sessionState.lastProcessedSeqNumber) {
val sessionId = event.sessionMessage.recipientSessionId val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload) currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
} else {
log.trace { "Discarding duplicate error message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
}
is SessionState.Initiating, is SessionState.Uninitiated -> {
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = checkpoint.addSession(sessionId to newSessionState) checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
) )
} }
else -> freshErrorTransition(UnexpectedEventInState())
} }
} }
@ -145,42 +168,42 @@ class DeliverSessionMessageTransition(
return when (sessionState) { return when (sessionState) {
is SessionState.Initiating -> { is SessionState.Initiating -> {
if (sessionState.rejectionError != null) { if (sessionState.rejectionError != null) {
// Double reject log.trace { "Discarding duplicate session rejection message for session ${event.sessionMessage.recipientSessionId}" }
freshErrorTransition(UnexpectedEventInState())
} else { } else {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception) val flowError = FlowError(payload.errorId, exception)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = checkpoint.addSession(sessionId to sessionState.copy(rejectionError = flowError)) checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(rejectionError = flowError))
) )
} }
} }
else -> freshErrorTransition(UnexpectedEventInState()) is SessionState.Uninitiated -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(hasBeenRejected = payload))
)
}
is SessionState.Initiated -> {
freshErrorTransition(UnexpectedEventInState("A session rejection message was received for an already established session ${event.messageIdentifier.sessionIdentifier}."))
}
} }
} }
private fun TransitionBuilder.endMessageTransition(payload: EndSessionMessage) { private fun TransitionBuilder.endMessageTransition(payload: EndSessionMessage) {
val flowState = currentState.checkpoint.flowState
// flow must have already been started when session end messages are being delivered.
if (flowState !is FlowState.Started)
return freshErrorTransition(UnexpectedEventInState())
val sessionId = event.sessionMessage.recipientSessionId val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.checkpointState.sessions val sessions = currentState.checkpoint.checkpointState.sessions
// a check has already been performed to confirm the session exists for this message before this method is invoked. // a check has already been performed to confirm the session exists for this message before this method is invoked.
val sessionState = sessions[sessionId]!! val sessionState = sessions[sessionId]!!
when (sessionState) { when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated, is SessionState.Initiating, is SessionState.Uninitiated -> {
val flowState = currentState.checkpoint.flowState val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
// flow must have already been started when session end messages are being delivered.
if (flowState !is FlowState.Started)
return freshErrorTransition(UnexpectedEventInState())
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState) val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
.addSessionsToBeClosed(setOf(event.sessionMessage.recipientSessionId)) .addSessionsToBeClosed(setOf(event.sessionMessage.recipientSessionId))
currentState = currentState.copy(checkpoint = newCheckpoint) currentState = currentState.copy(checkpoint = newCheckpoint)
} }
else -> {
freshErrorTransition(PrematureSessionEndException(event.sessionMessage.recipientSessionId))
}
} }
} }

View File

@ -1,6 +1,7 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.*
/** /**
@ -40,16 +41,28 @@ class ErrorFlowTransition(
return builder { return builder {
// If we're errored and propagating do the actual propagation and update the index. // If we're errored and propagating do the actual propagation and update the index.
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) { if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions( val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions, startingState.checkpoint.checkpointState.sessions,
errorMessages errorMessages
) )
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val newCheckpoint = startingState.checkpoint.copy( val newCheckpoint = startingState.checkpoint.copy(
errorState = errorState.copy(propagatedIndex = allErrors.size), errorState = errorState.copy(propagatedIndex = allErrors.size),
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions) checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
) )
currentState = currentState.copy(checkpoint = newCheckpoint) currentState = currentState.copy(checkpoint = newCheckpoint)
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID) actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
} }
// If we're errored but not propagating keep processing events. // If we're errored but not propagating keep processing events.
@ -81,16 +94,27 @@ class ErrorFlowTransition(
isCheckpointUpdate = currentState.isAnyCheckpointPersisted isCheckpointUpdate = currentState.isAnyCheckpointPersisted
) )
} }
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.CreateTransaction actions += Action.CreateTransaction
actions += removeOrPersistCheckpoint actions += removeOrPersistCheckpoint
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers) actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(context.id.uuid) actions += Action.ReleaseSoftLocks(context.id.uuid)
actions += Action.CommitTransaction(currentState) actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers) actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
actions += Action.RemoveSessionBindings(startingState.checkpoint.checkpointState.sessions.keys) actions += Action.RemoveSessionBindings(startingState.checkpoint.checkpointState.sessions.keys)
actions += Action.RemoveFlow(context.id, FlowRemovalReason.ErrorFinish(allErrors), currentState) actions += Action.RemoveFlow(context.id, FlowRemovalReason.ErrorFinish(allErrors), currentState)
currentState = currentState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isRemoved = true
)
FlowContinuation.Abort FlowContinuation.Abort
} else { } else {
// Otherwise keep processing events. This branch happens when there are some outstanding initiating // Otherwise keep processing events. This branch happens when there are some outstanding initiating
@ -112,31 +136,37 @@ class ErrorFlowTransition(
} }
} }
// Buffer error messages in Initiating sessions, return the initialised ones. /**
* Buffers errors message for initiating states and filters the initiated states.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions( private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>, sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage> errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> { ): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) -> val newSessionStates = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will var currentSequenceNumber = sessionState.nextSendingSeqNumber
// be delivered all the same, they just won't trigger flow resumption because of dirtiness. val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val errorMessagesWithDeduplication = errorMessages.map { val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
DeduplicationId.createForError(it.errorId, sourceSessionId) to it currentSequenceNumber++
messageIdentifier to errorMessage
} }
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) sessionState.bufferMessages(errorMessagesWithDeduplication)
} else { } else {
sessionState sessionState
} }
} }
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors. // if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session -> val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (session is SessionState.Initiated && !session.otherSideErrored) { if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
session sessionId to sessionState
} else { } else {
null null
} }
} }.toMap()
return Pair(initiatedSessions, newSessions) return Pair(initiatedSessions, newSessionStates)
} }
} }

View File

@ -2,14 +2,15 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.KilledFlowException import net.corda.core.flows.KilledFlowException
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ErrorSessionMessage import net.corda.node.services.statemachine.ErrorSessionMessage
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowRemovalReason import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
@ -27,24 +28,37 @@ class KilledFlowTransition(
val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError) val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError)
val errorMessages = listOf(killedFlowErrorMessage) val errorMessages = listOf(killedFlowErrorMessage)
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions( val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions, startingState.checkpoint.checkpointState.sessions,
errorMessages errorMessages
) )
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val newCheckpoint = startingState.checkpoint.copy( val newCheckpoint = startingState.checkpoint.copy(
status = Checkpoint.FlowStatus.KILLED, status = Checkpoint.FlowStatus.KILLED,
flowState = FlowState.Finished, flowState = FlowState.Finished,
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions) checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
) )
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = newCheckpoint, checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(), pendingDeduplicationHandlers = emptyList(),
isRemoved = true closedSessionsPendingToBeSignalled = emptyMap(),
isRemoved = true
) )
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID) actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
if (!startingState.isFlowResumed) { if (!startingState.isFlowResumed) {
actions += Action.CreateTransaction actions += Action.CreateTransaction
@ -59,7 +73,12 @@ class KilledFlowTransition(
actions += Action.AddFlowException(context.id, killedFlowError.exception) actions += Action.AddFlowException(context.id, killedFlowError.exception)
} }
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers) actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(context.id.uuid) actions += Action.ReleaseSoftLocks(context.id.uuid)
actions += Action.CommitTransaction(currentState) actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers) actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
@ -91,32 +110,37 @@ class KilledFlowTransition(
} }
} }
// Purposely left the same as [bufferErrorMessagesInInitiatingSessions] in [ErrorFlowTransition] so that it can be refactored /**
// Buffer error messages in Initiating sessions, return the initialised ones. * Buffers errors message for initiating states and filters the initiated states.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions( private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>, sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage> errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> { ): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) -> val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will var currentSequenceNumber = sessionState.nextSendingSeqNumber
// be delivered all the same, they just won't trigger flow resumption because of dirtiness. val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val errorMessagesWithDeduplication = errorMessages.map { val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
DeduplicationId.createForError(it.errorId, sourceSessionId) to it currentSequenceNumber++
messageIdentifier to errorMessage
} }
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) sessionState.bufferMessages(errorMessagesWithDeduplication)
} else { } else {
sessionState sessionState
} }
} }
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors. // if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session -> val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (session is SessionState.Initiated && !session.otherSideErrored) { if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
session sessionId to sessionState
} else { } else {
null null
} }
} }.toMap()
return Pair(initiatedSessions, newSessions) return Pair(initiatedSessions, newSessions)
} }

View File

@ -10,6 +10,9 @@ import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.generateShardId
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.*
import org.slf4j.Logger import org.slf4j.Logger
import kotlin.collections.LinkedHashMap import kotlin.collections.LinkedHashMap
@ -177,14 +180,19 @@ class StartedFlowTransition(
} }
if (existingSessionsToRemove.isNotEmpty()) { if (existingSessionsToRemove.isNotEmpty()) {
val sendEndMessageActions = existingSessionsToRemove.values.mapIndexed { index, state -> val sendEndMessageActions = existingSessionsToRemove.map { (_, sessionState) ->
val sinkSessionId = (state as SessionState.Initiated).peerSinkSessionId val sinkSessionId = (sessionState as SessionState.Initiated).peerSinkSessionId
val message = ExistingSessionMessage(sinkSessionId, EndSessionMessage) val message = ExistingSessionMessage(sinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state) val messageType = MessageType.inferFromMessage(message)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID)) val messageIdentifier = MessageIdentifier(messageType, generateShardId(context.id.toString()), sinkSessionId, sessionState.nextSendingSeqNumber, currentState.checkpoint.checkpointState.suspensionTime)
Action.SendExisting(sessionState.peerParty, message, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID))
} }
val signalSessionsEndMap = existingSessionsToRemove.map { (sessionId, _) ->
val sessionState = currentState.checkpoint.checkpointState.sessions[sessionId]!!
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys)) currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys), closedSessionsPendingToBeSignalled = currentState.closedSessionsPendingToBeSignalled + signalSessionsEndMap)
actions.add(Action.RemoveSessionBindings(sessionIdsToRemove)) actions.add(Action.RemoveSessionBindings(sessionIdsToRemove))
actions.add(Action.SendMultiple(emptyList(), sendEndMessageActions)) actions.add(Action.SendMultiple(emptyList(), sendEndMessageActions))
} }
@ -239,23 +247,23 @@ class StartedFlowTransition(
@Suppress("ComplexMethod", "NestedBlockDepth") @Suppress("ComplexMethod", "NestedBlockDepth")
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? { private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
val newSessionMessages = LinkedHashMap(sessions) val newSessionStates = LinkedHashMap(sessions)
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>() val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
var someNotFound = false var someNotFound = false
for (sessionId in sessionIds) { for (sessionId in sessionIds) {
val sessionState = sessions[sessionId] val sessionState = sessions[sessionId]
when (sessionState) { when (sessionState) {
is SessionState.Initiated -> { is SessionState.Initiated -> {
val messages = sessionState.receivedMessages if (!sessionState.hasNextMessageArrived()) {
if (messages.isEmpty()) {
someNotFound = true someNotFound = true
} else { } else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList()) val (message, newState) = sessionState.extractMessage()
newSessionStates[sessionId] = newState
// at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message. // at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message.
resultMessages[sessionId] = if (messages[0] is EndSessionMessage) { resultMessages[sessionId] = if (message is EndSessionMessage) {
throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?") throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?")
} else { } else {
(messages[0] as DataSessionMessage).payload (message as DataSessionMessage).payload
} }
} }
} }
@ -267,14 +275,13 @@ class StartedFlowTransition(
return if (someNotFound) { return if (someNotFound) {
return null return null
} else { } else {
PollResult(resultMessages, newSessionMessages) PollResult(resultMessages, newSessionStates)
} }
} }
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) { private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
val checkpoint = startingState.checkpoint val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.checkpointState.sessions) val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.checkpointState.sessions)
var index = 0
for (sourceSessionId in sourceSessions) { for (sourceSessionId in sourceSessions) {
val sessionState = checkpoint.checkpointState.sessions[sourceSessionId] val sessionState = checkpoint.checkpointState.sessions[sourceSessionId]
if (sessionState == null) { if (sessionState == null) {
@ -283,14 +290,22 @@ class StartedFlowTransition(
if (sessionState !is SessionState.Uninitiated) { if (sessionState !is SessionState.Uninitiated) {
continue continue
} }
val shardId = generateShardId(context.id.toString())
val counterpartySessionId = sourceSessionId.calculateInitiatedSessionId()
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null) val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null)
val newSessionState = SessionState.Initiating( val newSessionState = SessionState.Initiating(
bufferedMessages = emptyList(), bufferedMessages = emptyList(),
rejectionError = null, rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
shardId = shardId,
receivedMessages = emptyMap(),
lastSenderUUID = null,
lastSenderSeqNo = null
) )
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, newSessionState) val messageType = MessageType.inferFromMessage(initialMessage)
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) val messageIdentifier = MessageIdentifier(messageType, shardId, counterpartySessionId, 0, checkpoint.checkpointState.suspensionTime)
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID)))
newSessions[sourceSessionId] = newSessionState newSessions[sourceSessionId] = newSessionState
} }
currentState = currentState.copy(checkpoint = checkpoint.setSessions(sessions = newSessions)) currentState = currentState.copy(checkpoint = checkpoint.setSessions(sessions = newSessions))
@ -313,37 +328,60 @@ class StartedFlowTransition(
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) { private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
val checkpoint = startingState.checkpoint val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions) val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions)
var index = 0
val messagesByType = sourceSessionIdToMessage.toList() val messagesByType = sourceSessionIdToMessage.toList()
.map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) } .map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) }
.groupBy { it.second::class } .groupBy { it.second::class }
val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.map { (sourceSessionId, sessionState, message) -> val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.mapNotNull { (sourceSessionId, sessionState, message) ->
val uninitiatedSessionState = sessionState as SessionState.Uninitiated val uninitiatedSessionState = sessionState as SessionState.Uninitiated
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState) val shardId = generateShardId(context.id.toString())
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message) if (sessionState.hasBeenAcknowledged != null) {
newSessions[sourceSessionId] = SessionState.Initiating( newSessions[sourceSessionId] = SessionState.Initiated(
bufferedMessages = emptyList(), peerParty = sessionState.hasBeenAcknowledged.first,
rejectionError = null, peerFlowInfo = sessionState.hasBeenAcknowledged.second.initiatedFlowInfo,
deduplicationSeed = uninitiatedSessionState.deduplicationSeed receivedMessages = emptyMap(),
) otherSideErrored = false,
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)) peerSinkSessionId = sessionState.hasBeenAcknowledged.second.initiatedSessionId,
deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
lastProcessedSeqNumber = 0,
shardId = shardId,
lastSenderUUID = null,
lastSenderSeqNo = null
)
null
} else {
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
shardId = shardId,
receivedMessages = emptyMap(),
lastSenderUUID = null,
lastSenderSeqNo = null
)
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
val messageType = MessageType.inferFromMessage(initialMessage)
val messageIdentifier = MessageIdentifier(messageType, shardId, sourceSessionId.calculateInitiatedSessionId(), 0, checkpoint.checkpointState.suspensionTime)
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
}
} ?: emptyList() } ?: emptyList()
messagesByType[SessionState.Initiating::class]?.forEach { (sourceSessionId, sessionState, message) -> messagesByType[SessionState.Initiating::class]?.forEach { (sourceSessionId, sessionState, message) ->
val initiatingSessionState = sessionState as SessionState.Initiating val initiatingSessionState = sessionState as SessionState.Initiating
val sessionMessage = DataSessionMessage(message) val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatingSessionState) val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
val newBufferedMessages = initiatingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage) newSessions[sourceSessionId] = initiatingSessionState.bufferMessage(messageIdentifier, sessionMessage)
newSessions[sourceSessionId] = initiatingSessionState.copy(bufferedMessages = newBufferedMessages)
} }
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(_, sessionState, message) -> val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(sourceSessionId, sessionState, message) ->
val initiatedSessionState = sessionState as SessionState.Initiated val initiatedSessionState = sessionState as SessionState.Initiated
val sessionMessage = DataSessionMessage(message) val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatedSessionState)
val sinkSessionId = initiatedSessionState.peerSinkSessionId val sinkSessionId = initiatedSessionState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage) val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)) val messageType = MessageType.inferFromMessage(existingMessage)
val messageIdentifier = MessageIdentifier(messageType, sessionState.shardId, sessionState.peerSinkSessionId, sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
newSessions[sourceSessionId] = initiatedSessionState.copy(nextSendingSeqNumber = initiatedSessionState.nextSendingSeqNumber + 1)
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
} ?: emptyList() } ?: emptyList()
if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) { if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) {
@ -372,11 +410,10 @@ class StartedFlowTransition(
} }
} }
is SessionState.Initiated -> { is SessionState.Initiated -> {
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) { if (sessionState.hasErrored()) {
val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage val (message, newSessionState) = sessionState.extractMessage()
val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty) val exception = convertErrorMessageToException(message as ErrorSessionMessage, sessionState.peerParty)
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true) val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState.copy(otherSideErrored = true))
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState)
newState = startingState.copy(checkpoint = newCheckpoint) newState = startingState.copy(checkpoint = newCheckpoint)
listOf(exception) listOf(exception)
} else { } else {
@ -541,8 +578,8 @@ class StartedFlowTransition(
private fun findSessionsToBeTerminated(startingState: StateMachineState): SessionMap { private fun findSessionsToBeTerminated(startingState: StateMachineState): SessionMap {
return startingState.checkpoint.checkpointState.sessionsToBeClosed.mapNotNull { sessionId -> return startingState.checkpoint.checkpointState.sessionsToBeClosed.mapNotNull { sessionId ->
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!! as SessionState.Initiated val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!!
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is EndSessionMessage) { if (sessionState is SessionState.Initiated && sessionState.receivedMessages.containsKey(sessionState.lastProcessedSeqNumber + 1) && sessionState.receivedMessages[sessionState.lastProcessedSeqNumber + 1] is EndSessionMessage) {
sessionId to sessionState sessionId to sessionState
} else { } else {
null null

View File

@ -4,12 +4,13 @@ import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom import java.security.SecureRandom
import java.time.Instant
class StateMachine( class StateMachine(
val id: StateMachineRunId, val id: StateMachineRunId,
val secureRandom: SecureRandom val secureRandom: SecureRandom
) { ) {
fun transition(event: Event, state: StateMachineState): TransitionResult { fun transition(event: Event, state: StateMachineState, time: Instant): TransitionResult {
return TopLevelTransition(TransitionContext(id, secureRandom), state, event).transition() return TopLevelTransition(TransitionContext(id, secureRandom, time), state, event).transition()
} }
} }

View File

@ -7,9 +7,9 @@ import net.corda.core.serialization.deserialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.EndSessionMessage import net.corda.node.services.statemachine.EndSessionMessage
import net.corda.node.services.statemachine.ErrorState import net.corda.node.services.statemachine.ErrorState
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
@ -19,7 +19,8 @@ import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowSessionImpl import net.corda.node.services.statemachine.FlowSessionImpl
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitialSessionMessage import net.corda.node.services.statemachine.InitialSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionMessage import net.corda.node.services.statemachine.SessionMessage
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
@ -197,7 +198,8 @@ class TopLevelTransition(
checkpointState.invocationContext checkpointState.invocationContext
}, },
numberOfSuspends = checkpointState.numberOfSuspends + 1, numberOfSuspends = checkpointState.numberOfSuspends + 1,
numberOfCommits = checkpointState.numberOfCommits + 1 numberOfCommits = checkpointState.numberOfCommits + 1,
suspensionTime = context.time
) )
copy( copy(
flowState = FlowState.Started(event.ioRequest, event.fiber), flowState = FlowState.Started(event.ioRequest, event.fiber),
@ -217,11 +219,13 @@ class TopLevelTransition(
currentState = startingState.copy( currentState = startingState.copy(
checkpoint = newCheckpoint, checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(), pendingDeduplicationHandlers = emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isFlowResumed = false, isFlowResumed = false,
isAnyCheckpointPersisted = true isAnyCheckpointPersisted = true
) )
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted) actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers) actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
actions += Action.CommitTransaction(currentState) actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers) actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
actions += Action.ScheduleEvent(Event.DoRemainingWork) actions += Action.ScheduleEvent(Event.DoRemainingWork)
@ -240,12 +244,14 @@ class TopLevelTransition(
checkpoint = checkpoint.copy( checkpoint = checkpoint.copy(
checkpointState = checkpoint.checkpointState.copy( checkpointState = checkpoint.checkpointState.copy(
numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1, numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1,
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1 numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1,
suspensionTime = context.time
), ),
flowState = FlowState.Finished, flowState = FlowState.Finished,
result = event.returnValue, result = event.returnValue,
status = Checkpoint.FlowStatus.COMPLETED status = Checkpoint.FlowStatus.COMPLETED
), ).removeSessions(checkpoint.checkpointState.sessions.keys),
closedSessionsPendingToBeSignalled = emptyMap(),
pendingDeduplicationHandlers = emptyList(), pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false, isFlowResumed = false,
isRemoved = true isRemoved = true
@ -263,7 +269,12 @@ class TopLevelTransition(
) )
} }
val signalSessionsEndMap = startingState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers) actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(event.softLocksId) actions += Action.ReleaseSoftLocks(event.softLocksId)
actions += Action.CommitTransaction(currentState) actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers) actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
@ -284,11 +295,12 @@ class TopLevelTransition(
} }
private fun TransitionBuilder.sendEndMessages() { private fun TransitionBuilder.sendEndMessages() {
val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state -> val sendEndMessageActions = startingState.checkpoint.checkpointState.sessions.map { (sessionId, state) ->
if (state is SessionState.Initiated) { if (state is SessionState.Initiated) {
val message = ExistingSessionMessage(state.peerSinkSessionId, EndSessionMessage) val message = ExistingSessionMessage(state.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state) val messageType = MessageType.inferFromMessage(message)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID)) val messageIdentifier = MessageIdentifier(messageType, state.shardId, state.peerSinkSessionId, state.nextSendingSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
Action.SendExisting(state.peerParty, message, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
} else { } else {
null null
} }
@ -306,7 +318,7 @@ class TopLevelTransition(
} }
val sourceSessionId = SessionId.createRandom(context.secureRandom) val sourceSessionId = SessionId.createRandom(context.secureRandom)
val sessionImpl = FlowSessionImpl(event.destination, event.wellKnownParty, sourceSessionId) val sessionImpl = FlowSessionImpl(event.destination, event.wellKnownParty, sourceSessionId)
val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong())) val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong(), null, null, emptyMap(), null, null))
currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions)) currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions))
actions.add(Action.AddSessionBinding(context.id, sourceSessionId)) actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
FlowContinuation.Resume(sessionImpl) FlowContinuation.Resume(sessionImpl)
@ -361,10 +373,12 @@ class TopLevelTransition(
numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1 numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1
) )
), ),
pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents,
closedSessionsPendingToBeSignalled = emptyMap()
) )
actions += Action.CreateTransaction actions += Action.CreateTransaction
actions += Action.PersistDeduplicationFacts(flowStartEvents) actions += Action.PersistDeduplicationFacts(flowStartEvents)
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted) actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
actions += Action.CommitTransaction(currentState) actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(flowStartEvents) actions += Action.AcknowledgeMessages(flowStartEvents)
@ -401,4 +415,5 @@ class TopLevelTransition(
FlowContinuation.Abort FlowContinuation.Abort
} }
} }
} }

View File

@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom import java.security.SecureRandom
import java.time.Instant
/** /**
* An interface used to separate out different parts of the state machine transition function. * An interface used to separate out different parts of the state machine transition function.
@ -28,5 +29,6 @@ interface Transition {
class TransitionContext( class TransitionContext(
val id: StateMachineRunId, val id: StateMachineRunId,
val secureRandom: SecureRandom val secureRandom: SecureRandom,
val time: Instant
) )

View File

@ -80,6 +80,6 @@ class TransitionBuilder(val context: TransitionContext, initialState: StateMachi
} }
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId") class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
class UnexpectedEventInState : IllegalStateException("Unexpected event") class UnexpectedEventInState(message: String = "") : IllegalStateException("An unexpected event happened. $message")
class PrematureSessionCloseException(sessionId: SessionId): IllegalStateException("The following session was closed before it was initialised: $sessionId") class PrematureSessionCloseException(sessionId: SessionId): IllegalStateException("The following session was closed before it was initialised: $sessionId")
class PrematureSessionEndException(sessionId: SessionId): IllegalStateException("A premature session end message was received before the session was initialised: $sessionId") class PrematureSessionEndException(sessionId: SessionId): IllegalStateException("A premature session end message was received before the session was initialised: $sessionId")

View File

@ -1,14 +1,15 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowStart import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
@ -50,25 +51,28 @@ class UnstartedFlowTransition(
appName = initiatingMessage.appName appName = initiatingMessage.appName
), ),
receivedMessages = if (initiatingMessage.firstPayload == null) { receivedMessages = if (initiatingMessage.firstPayload == null) {
emptyList() emptyMap()
} else { } else {
listOf(DataSessionMessage(initiatingMessage.firstPayload)) mapOf(0 to DataSessionMessage(initiatingMessage.firstPayload))
}, },
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}", deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.value}-${initiatingMessage.initiationEntropy}",
otherSideErrored = false otherSideErrored = false,
nextSendingSeqNumber = 1,
lastProcessedSeqNumber = if (initiatingMessage.firstPayload == null) {
0
} else {
-1
},
shardId = flowStart.shardIdentifier,
lastSenderUUID = flowStart.senderUUID,
lastSenderSeqNo = flowStart.senderSequenceNumber
) )
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
currentState = currentState.copy( val messageType = MessageType.inferFromMessage(sessionMessage)
checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState)) val messageIdentifier = MessageIdentifier(messageType, flowStart.shardIdentifier, initiatingMessage.initiatorSessionId, 0, currentState.checkpoint.checkpointState.suspensionTime)
) currentState = currentState.copy(checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState)))
actions.add( actions.add(Action.SendExisting(flowStart.peerSession.counterparty, sessionMessage, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID)))
Action.SendExisting(
flowStart.peerSession.counterparty,
sessionMessage,
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0, initiatedState), currentState.senderUUID)
)
)
} }
// Create initial checkpoint and acknowledge triggering messages. // Create initial checkpoint and acknowledge triggering messages.

View File

@ -55,7 +55,7 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi
name == "FlowDrainingMode_nodeProperties" -> caffeine.maximumSize(defaultCacheSize) name == "FlowDrainingMode_nodeProperties" -> caffeine.maximumSize(defaultCacheSize)
name == "ContractUpgradeService_upgrades" -> caffeine.maximumSize(defaultCacheSize) name == "ContractUpgradeService_upgrades" -> caffeine.maximumSize(defaultCacheSize)
name == "PersistentUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize) name == "PersistentUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)
name == "P2PMessageDeduplicator_processedMessages" -> caffeine.maximumSize(defaultCacheSize) name == "P2PMessageDeduplicator_sessionData" -> caffeine.maximumSize(defaultCacheSize)
name == "DeduplicationChecker_watermark" -> caffeine name == "DeduplicationChecker_watermark" -> caffeine
name == "BFTNonValidatingNotaryService_transactions" -> caffeine.maximumSize(defaultCacheSize) name == "BFTNonValidatingNotaryService_transactions" -> caffeine.maximumSize(defaultCacheSize)
name == "RaftUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize) name == "RaftUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)

View File

@ -34,5 +34,7 @@
<include file="migration/node-core.changelog-v19.xml"/> <include file="migration/node-core.changelog-v19.xml"/>
<include file="migration/node-core.changelog-v19-postgres.xml"/> <include file="migration/node-core.changelog-v19-postgres.xml"/>
<include file="migration/node-core.changelog-v19-keys.xml"/> <include file="migration/node-core.changelog-v19-keys.xml"/>
<include file="migration/node-core.changelog-v20.xml"/>
<include file="migration/node-core.changelog-v21.xml"/>
</databaseChangeLog> </databaseChangeLog>

View File

@ -0,0 +1,28 @@
<?xml version="1.1" encoding="UTF-8" standalone="no"?>
<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd"
logicalFilePath="migration/node-services.changelog-init.xml">
<changeSet author="R3.Corda" id="add_session_data_table">
<createTable tableName="node_session_data">
<column name="session_id" type="NUMBER(128)">
<constraints nullable="false"/>
</column>
<column name="init_generation_time" type="timestamp">
<constraints nullable="false"/>
</column>
<column name="sender_hash" type="NVARCHAR(64)">
<constraints nullable="true"/>
</column>
<column name="init_sequence_number" type="BIGINT">
<constraints nullable="true"/>
</column>
<column name="last_sequence_number" type="BIGINT">
<constraints nullable="true"/>
</column>
</createTable>
<addPrimaryKey columnNames="session_id" constraintName="node_session_data_pk" tableName="node_session_data"/>
</changeSet>
</databaseChangeLog>

View File

@ -0,0 +1,47 @@
package net.corda.node.services.messaging
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
import java.time.Instant
class MessageIdentifierTest {
private val shardIdentifier = "XXXXXXXX"
private val sessionIdentifier = SessionId(BigInteger.valueOf(14))
private val sessionSequenceNumber = 1
private val timestamp = Instant.ofEpochMilli(100)
private val messageIdString = "XI-0000000000000064-XXXXXXXX-0000000000000000000000000000000E-1"
@Test(timeout=300_000)
fun `can parse message identifier from string value`() {
val messageIdentifier = MessageIdentifier(MessageType.SESSION_INIT, shardIdentifier, sessionIdentifier, sessionSequenceNumber, timestamp)
assertThat(messageIdentifier.toString()).isEqualTo(messageIdString)
}
@Test(timeout=300_000)
fun `can convert message identifier object to string value`() {
val messageIdentifierString = messageIdString
val messageIdentifier = MessageIdentifier.parse(messageIdentifierString)
assertThat(messageIdentifier.messageType).isInstanceOf(MessageType.SESSION_INIT::class.java)
assertThat(messageIdentifier.shardIdentifier).isEqualTo(shardIdentifier)
assertThat(messageIdentifier.sessionIdentifier).isEqualTo(sessionIdentifier)
assertThat(messageIdentifier.sessionSequenceNumber).isEqualTo(1)
assertThat(messageIdentifier.timestamp).isEqualTo(timestamp)
}
@Test(timeout=300_000)
fun `shard identifier needs to be 8 characters long`() {
Assertions.assertThatThrownBy { MessageIdentifier(MessageType.SESSION_INIT, "XX", sessionIdentifier, 1, timestamp) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shard identifier needs to be 8 characters long, but it was XX")
}
}

View File

@ -0,0 +1,28 @@
package net.corda.node.services.messaging
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
class SessionIdTest {
@Test(timeout=300_000)
fun `session identifier cannot be negative`() {
assertThatThrownBy { SessionId(BigInteger.valueOf(-1)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Session identifier cannot be a negative number, but it was -1")
}
@Test(timeout=300_000)
fun `session identifier needs to be a number that can be represented in maximum 128 bits`() {
val largestSessionIdentifierValue = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
val largestValidSessionId = SessionId(largestSessionIdentifierValue)
assertThatThrownBy { SessionId(largestSessionIdentifierValue.plus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("The size of a session identifier cannot exceed 128 bits, but it was 340282366920938463463374607431768211456")
}
}

View File

@ -873,7 +873,8 @@ class DBCheckpointStorageTests {
frozenLogic, frozenLogic,
ALICE, ALICE,
SubFlowVersion.CoreFlow(version), SubFlowVersion.CoreFlow(version),
false false,
Clock.systemUTC().instant()
) )
.getOrThrow() .getOrThrow()
return id to checkpoint return id to checkpoint

View File

@ -194,7 +194,7 @@ class CheckpointDumperImplTest {
override fun call() {} override fun call() {}
} }
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false) val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false, Clock.systemUTC().instant())
.getOrThrow() .getOrThrow()
return id to checkpoint return id to checkpoint
} }

View File

@ -8,7 +8,6 @@ import net.corda.client.rpc.notUsed
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.Destination import net.corda.core.flows.Destination
import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FinalityFlow
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
@ -42,6 +41,8 @@ import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.persistence.CheckpointPerformanceRecorder import net.corda.node.services.persistence.CheckpointPerformanceRecorder
import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.checkpoints import net.corda.node.services.persistence.checkpoints
@ -80,6 +81,8 @@ import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Notification import rx.Notification
import rx.Observable import rx.Observable
import java.math.BigInteger
import java.security.SecureRandom
import java.sql.SQLTransientConnectionException import java.sql.SQLTransientConnectionException
import java.time.Clock import java.time.Clock
import java.time.Duration import java.time.Duration
@ -571,7 +574,7 @@ class FlowFrameworkTests {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `session init with unknown class is sent to the flow hospital, from where we then drop it`() { fun `session init with unknown class is sent to the flow hospital, from where we then drop it`() {
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob) aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, "not.a.real.Class", 1, "", null), bob)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(1) // Only the session-init is expected as the session-reject is blocked by the flow hospital assertThat(receivedSessionMessages).hasSize(1) // Only the session-init is expected as the session-reject is blocked by the flow hospital
val medicalRecords = bobNode.smm.flowHospital.track().apply { updates.notUsed() }.snapshot val medicalRecords = bobNode.smm.flowHospital.track().apply { updates.notUsed() }.snapshot
@ -587,7 +590,7 @@ class FlowFrameworkTests {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `non-flow class in session init`() { fun `non-flow class in session init`() {
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob) aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, String::class.java.name, 1, "", null), bob)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
@ -897,7 +900,7 @@ class FlowFrameworkTests {
} }
//region Helpers //region Helpers
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(vararg expected: SessionTransfer) { private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(receivedSessionMessages).containsExactly(*expected) assertThat(receivedSessionMessages).containsExactly(*expected)
@ -1039,21 +1042,21 @@ class FlowFrameworkTests {
} }
internal fun sessionConfirm(flowVersion: Int = 1) = internal fun sessionConfirm(flowVersion: Int = 1) =
ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ConfirmSessionMessage(SessionId(BigInteger.valueOf(0)), FlowInfo(flowVersion, "")))
internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> { internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
return smm.findStateMachines(P::class.java).single() return smm.findStateMachines(P::class.java).single()
} }
private fun sanitise(message: SessionMessage) = when (message) { private fun sanitise(message: SessionMessage) = when (message) {
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "") is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(BigInteger.valueOf(0)), initiationEntropy = 0, appName = "")
is ExistingSessionMessage -> { is ExistingSessionMessage -> {
val payload = message.payload val payload = message.payload
message.copy( message.copy(
recipientSessionId = SessionId(0), recipientSessionId = SessionId(BigInteger.valueOf(0)),
payload = when (payload) { payload = when (payload) {
is ConfirmSessionMessage -> payload.copy( is ConfirmSessionMessage -> payload.copy(
initiatedSessionId = SessionId(0), initiatedSessionId = SessionId(BigInteger.valueOf(0)),
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "") initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
) )
is ErrorSessionMessage -> payload.copy( is ErrorSessionMessage -> payload.copy(
@ -1076,7 +1079,8 @@ internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Sessio
internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) { internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply { services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes, SenderDeduplicationInfo(messageIdentifier, null), emptyMap()), address)
} }
} }
@ -1087,7 +1091,7 @@ inline fun <reified T> DatabaseTransaction.findRecordsFromDatabase(): List<T> {
} }
internal fun errorMessage(errorResponse: FlowException? = null) = internal fun errorMessage(errorResponse: FlowException? = null) =
ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ErrorSessionMessage(errorResponse, 0))
internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message) internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer =
@ -1103,10 +1107,10 @@ internal data class SessionTransfer(val from: Int, val message: SessionMessage,
} }
internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) return InitialSessionMessage(SessionId(BigInteger.valueOf(0)), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
} }
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), DataSessionMessage(payload.serialize()))
@InitiatingFlow @InitiatingFlow
internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() { internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() {

View File

@ -18,6 +18,7 @@ import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import java.math.BigInteger
import java.util.* import java.util.*
class FlowFrameworkTripartyTests { class FlowFrameworkTripartyTests {
@ -168,7 +169,7 @@ class FlowFrameworkTripartyTests {
) )
} }
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> { private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress } val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }

View File

@ -0,0 +1,42 @@
package net.corda.node.services.statemachine
import org.assertj.core.api.Assertions.*
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
class SessionIdTest {
companion object {
private val LARGEST_SESSION_ID_VALUE = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
}
@Test(timeout=300_000)
fun `session id must be positive and representable in 128 bits`() {
assertThatThrownBy { SessionId(BigInteger.ZERO.minus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Session identifier cannot be a negative number, but it was -1")
assertThatThrownBy { SessionId(LARGEST_SESSION_ID_VALUE.plus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessageContaining("The size of a session identifier cannot exceed 128 bits, but it was")
val correctSessionId = SessionId(LARGEST_SESSION_ID_VALUE)
}
@Test(timeout=300_000)
fun `initiated session id is calculated properly`() {
val sessionId = SessionId(BigInteger.ONE)
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
assertThat(initiatedSessionId.value.toLong()).isEqualTo(2)
}
@Test(timeout=300_000)
fun `calculation of initiated session id wraps around`() {
val sessionId = SessionId(LARGEST_SESSION_ID_VALUE)
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
assertThat(initiatedSessionId.value.toLong()).isEqualTo(0)
}
}

View File

@ -2,7 +2,7 @@ package net.corda.testing.node.internal
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.messaging.MessageIdentifier
import java.time.Instant import java.time.Instant
/** /**
@ -10,7 +10,7 @@ import java.time.Instant
*/ */
data class InMemoryMessage(override val topic: String, data class InMemoryMessage(override val topic: String,
override val data: ByteSequence, override val data: ByteSequence,
override val uniqueMessageId: DeduplicationId, override val uniqueMessageId: MessageIdentifier,
override val debugTimestamp: Instant = Instant.now(), override val debugTimestamp: Instant = Instant.now(),
override val senderUUID: String? = null) : Message { override val senderUUID: String? = null) : Message {

View File

@ -1,5 +1,6 @@
package net.corda.testing.node.internal package net.corda.testing.node.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
@ -13,9 +14,9 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.* import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.lifecycle.ServiceStateHelper import net.corda.nodeapi.internal.lifecycle.ServiceStateHelper
import net.corda.nodeapi.internal.lifecycle.ServiceStateSupport import net.corda.nodeapi.internal.lifecycle.ServiceStateSupport
@ -46,9 +47,9 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
} }
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>()) private val processedMessages: MutableSet<MessageIdentifier> = Collections.synchronizedSet(HashSet())
override val ourSenderUUID: String = UUID.randomUUID().toString() override val ourSenderUUID: SenderUUID = UUID.randomUUID().toString()
private var _myAddress: InMemoryMessagingNetwork.PeerHandle? = null private var _myAddress: InMemoryMessagingNetwork.PeerHandle? = null
override val myAddress: InMemoryMessagingNetwork.PeerHandle get() = checkNotNull(_myAddress) { "Not started" } override val myAddress: InMemoryMessagingNetwork.PeerHandle get() = checkNotNull(_myAddress) { "Not started" }
@ -167,6 +168,11 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
} }
} }
@Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
// nothing to do here.
}
override fun close() { override fun close() {
backgroundThread?.let { backgroundThread?.let {
it.interrupt() it.interrupt()
@ -178,8 +184,8 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
} }
/** Returns the given (topic & session, data) pair as a newly created message object. */ /** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message { override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID) return InMemoryMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, senderUUID = deduplicationInfo.senderUUID)
} }
/** /**
@ -269,7 +275,7 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
private data class InMemoryReceivedMessage(override val topic: String, private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence, override val data: ByteSequence,
override val platformVersion: Int, override val platformVersion: Int,
override val uniqueMessageId: DeduplicationId, override val uniqueMessageId: MessageIdentifier,
override val debugTimestamp: Instant, override val debugTimestamp: Instant,
override val peer: CordaX500Name, override val peer: CordaX500Name,
override val senderUUID: String? = null, override val senderUUID: String? = null,

View File

@ -4,14 +4,24 @@ import net.corda.core.messaging.AllPossibleRecipients
import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.Message
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After import org.junit.After
import org.junit.Test import org.junit.Test
import java.math.BigInteger
import java.time.Clock
import java.util.* import java.util.*
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class InternalMockNetworkTests { class InternalMockNetworkTests {
companion object {
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
}
lateinit var mockNet: InternalMockNetwork lateinit var mockNet: InternalMockNetwork
@After @After
@ -39,7 +49,7 @@ class InternalMockNetworkTests {
} }
// Node 1 sends a message and it should end up in finalDelivery, after we run the network // Node 1 sends a message and it should end up in finalDelivery, after we run the network
node1.network.send(node1.network.createMessage("test.topic", data = bits), node2.network.myAddress) node1.network.send(node1.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), node2.network.myAddress)
mockNet.runNetwork(rounds = 1) mockNet.runNetwork(rounds = 1)
@ -58,7 +68,7 @@ class InternalMockNetworkTests {
var counter = 0 var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } } listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>()) node1.network.send(node2.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), rigorousMock<AllPossibleRecipients>())
mockNet.runNetwork(rounds = 1) mockNet.runNetwork(rounds = 1)
assertEquals(3, counter) assertEquals(3, counter)
} }
@ -79,8 +89,8 @@ class InternalMockNetworkTests {
received++ received++
} }
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) val invalidMessage = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) val validMessage = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
node2.network.send(invalidMessage, node1.network.myAddress) node2.network.send(invalidMessage, node1.network.myAddress)
mockNet.runNetwork() mockNet.runNetwork()
assertEquals(0, received) assertEquals(0, received)
@ -91,8 +101,8 @@ class InternalMockNetworkTests {
// Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so // Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so
// this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops // this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops
val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1)) val invalidMessage2 = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1)) val validMessage2 = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
node2.network.send(invalidMessage2, node1.network.myAddress) node2.network.send(invalidMessage2, node1.network.myAddress)
node2.network.send(validMessage2, node1.network.myAddress) node2.network.send(validMessage2, node1.network.myAddress)
mockNet.runNetwork() mockNet.runNetwork()