mirror of
https://github.com/corda/corda.git
synced 2025-03-11 15:04:14 +00:00
Introduce MessageIdentifier and related tests
This commit is contained in:
parent
8ed3dc1150
commit
64e7fdd83a
@ -24,6 +24,8 @@ import net.corda.testing.internal.TestingNamedCacheFactory
|
||||
import net.corda.testing.internal.configureDatabase
|
||||
import net.corda.coretesting.internal.rigorousMock
|
||||
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
|
||||
import net.corda.testing.node.internal.MOCK_VERSION_INFO
|
||||
import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException
|
||||
@ -35,7 +37,9 @@ import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import org.junit.rules.TemporaryFolder
|
||||
import rx.subjects.PublishSubject
|
||||
import java.math.BigInteger
|
||||
import java.net.ServerSocket
|
||||
import java.time.Clock
|
||||
import java.util.concurrent.BlockingQueue
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import java.util.concurrent.TimeUnit.MILLISECONDS
|
||||
@ -47,6 +51,7 @@ import kotlin.test.assertTrue
|
||||
class ArtemisMessagingTest {
|
||||
companion object {
|
||||
const val TOPIC = "platform.self"
|
||||
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
|
||||
}
|
||||
|
||||
@Rule
|
||||
@ -142,7 +147,7 @@ class ArtemisMessagingTest {
|
||||
@Test(timeout=300_000)
|
||||
fun `client should be able to send message to itself`() {
|
||||
val (messagingClient, receivedMessages) = createAndStartClientAndServer()
|
||||
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray())
|
||||
val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
messagingClient.send(message, messagingClient.myAddress)
|
||||
|
||||
val actual: Message = receivedMessages.take()
|
||||
@ -153,14 +158,14 @@ class ArtemisMessagingTest {
|
||||
@Test(timeout=300_000)
|
||||
fun `client should fail if message exceed maxMessageSize limit`() {
|
||||
val (messagingClient, receivedMessages) = createAndStartClientAndServer()
|
||||
val message = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE))
|
||||
val message = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
messagingClient.send(message, messagingClient.myAddress)
|
||||
|
||||
val actual: Message = receivedMessages.take()
|
||||
assertTrue(ByteArray(MAX_MESSAGE_SIZE).contentEquals(actual.data.bytes))
|
||||
assertNull(receivedMessages.poll(200, MILLISECONDS))
|
||||
|
||||
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE + 1))
|
||||
val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE + 1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
assertThatThrownBy {
|
||||
messagingClient.send(tooLagerMessage, messagingClient.myAddress)
|
||||
}.isInstanceOf(IllegalArgumentException::class.java)
|
||||
@ -172,14 +177,14 @@ class ArtemisMessagingTest {
|
||||
@Test(timeout=300_000)
|
||||
fun `server should not process if incoming message exceed maxMessageSize limit`() {
|
||||
val (messagingClient, receivedMessages) = createAndStartClientAndServer(clientMaxMessageSize = 100_000, serverMaxMessageSize = 50_000)
|
||||
val message = messagingClient.createMessage(TOPIC, data = ByteArray(50_000))
|
||||
val message = messagingClient.createMessage(TOPIC, ByteArray(50_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
messagingClient.send(message, messagingClient.myAddress)
|
||||
|
||||
val actual: Message = receivedMessages.take()
|
||||
assertTrue(ByteArray(50_000).contentEquals(actual.data.bytes))
|
||||
assertNull(receivedMessages.poll(200, MILLISECONDS))
|
||||
|
||||
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(100_000))
|
||||
val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(100_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
assertThatThrownBy {
|
||||
messagingClient.send(tooLagerMessage, messagingClient.myAddress)
|
||||
}.isInstanceOf(ActiveMQConnectionTimedOutException::class.java)
|
||||
@ -189,7 +194,7 @@ class ArtemisMessagingTest {
|
||||
@Test(timeout=300_000)
|
||||
fun `platform version is included in the message`() {
|
||||
val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3)
|
||||
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray())
|
||||
val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
messagingClient.send(message, messagingClient.myAddress)
|
||||
|
||||
val received = receivedMessages.take()
|
||||
|
@ -10,9 +10,13 @@ import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.seconds
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.MessagingService
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.messaging.send
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.testing.driver.DriverDSL
|
||||
import net.corda.testing.driver.DriverParameters
|
||||
import net.corda.testing.driver.InProcess
|
||||
@ -23,12 +27,15 @@ import net.corda.testing.node.NotarySpec
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
import java.math.BigInteger
|
||||
import java.time.Clock
|
||||
import java.util.*
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
class P2PMessagingTest {
|
||||
private companion object {
|
||||
val DISTRIBUTED_SERVICE_NAME = CordaX500Name("DistributedService", "London", "GB")
|
||||
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
@ -72,7 +79,7 @@ class P2PMessagingTest {
|
||||
private fun InProcess.respondWith(message: Any) {
|
||||
internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
|
||||
val request = netMessage.data.deserialize<TestRequest>()
|
||||
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes)
|
||||
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
internalServices.networkService.send(response, request.replyTo)
|
||||
handler.afterDatabaseTransaction()
|
||||
}
|
||||
@ -83,7 +90,7 @@ class P2PMessagingTest {
|
||||
internalServices.networkService.runOnNextMessage("test.response") { netMessage ->
|
||||
response.set(netMessage.data.deserialize())
|
||||
}
|
||||
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target)
|
||||
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
return response
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,93 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import java.lang.IllegalStateException
|
||||
import java.math.BigInteger
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* This represents the unique identifier for every message.
|
||||
* It's composed of multiple segments.
|
||||
*
|
||||
* @property messageType the type of the message.
|
||||
* @property shardIdentifier an identifier that can be used to partition messages into groups for sharding purposes.
|
||||
* This is supposed to have the same value for messages that correspond to the same business-level flow. It is
|
||||
* @property sessionIdentifier the identifier of the session this message belongs to. This corresponds to the identifier of the session on the receiving side.
|
||||
* @property sessionSequenceNumber the sequence number of the message inside the session. This can be used to handle out-of-order delivery.
|
||||
* @property timestamp the time when the message was requested to be sent.
|
||||
* This is expected to remain the same across replays of the same message and represent the moment in time when this message was initially scheduled to be sent.
|
||||
*/
|
||||
data class MessageIdentifier(
|
||||
val messageType: MessageType,
|
||||
val shardIdentifier: String,
|
||||
val sessionIdentifier: SessionId,
|
||||
val sessionSequenceNumber: Int,
|
||||
val timestamp: Instant
|
||||
) {
|
||||
init {
|
||||
require(shardIdentifier.length == 8) { "Shard identifier needs to be 8 characters long, but it was $shardIdentifier" }
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val LONG_SIZE_IN_HEX = 16 // 64 / 4
|
||||
const val SESSION_ID_SIZE_IN_HEX = SessionId.MAX_BIT_SIZE / 4
|
||||
|
||||
fun parse(id: String): MessageIdentifier {
|
||||
val prefix = id.substring(0, 2)
|
||||
val messageType = prefixToMessageType(prefix)
|
||||
val timestamp = java.lang.Long.parseUnsignedLong(id.substring(3, 19), 16)
|
||||
val shardIdentifier = id.substring(20, 28)
|
||||
val sessionId = BigInteger(id.substring(29, 61), 16)
|
||||
val sessionSequenceNumber = Integer.parseInt(id.substring(62), 16)
|
||||
return MessageIdentifier(messageType, shardIdentifier, SessionId(sessionId), sessionSequenceNumber, Instant.ofEpochMilli(timestamp))
|
||||
}
|
||||
|
||||
private fun messageTypeToPrefix(messageType: MessageType): String {
|
||||
return when(messageType) {
|
||||
MessageType.SESSION_INIT -> "XI"
|
||||
MessageType.SESSION_CONFIRM -> "XC"
|
||||
MessageType.SESSION_REJECT -> "XR"
|
||||
MessageType.DATA_MESSAGE -> "XD"
|
||||
MessageType.SESSION_END -> "XE"
|
||||
MessageType.SESSION_ERROR -> "XX"
|
||||
}
|
||||
}
|
||||
|
||||
private fun prefixToMessageType(prefix: String): MessageType {
|
||||
return when(prefix) {
|
||||
"XI" -> MessageType.SESSION_INIT
|
||||
"XC" -> MessageType.SESSION_CONFIRM
|
||||
"XR" -> MessageType.SESSION_REJECT
|
||||
"XD" -> MessageType.DATA_MESSAGE
|
||||
"XE" -> MessageType.SESSION_END
|
||||
"XX" -> MessageType.SESSION_ERROR
|
||||
else -> throw IllegalStateException("Invalid prefix: $prefix")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
val prefix = messageTypeToPrefix(messageType)
|
||||
val encodedSessionIdentifier = String.format("%1$0${SESSION_ID_SIZE_IN_HEX}X", sessionIdentifier.value)
|
||||
val encodedSequenceNumber = Integer.toHexString(sessionSequenceNumber).toUpperCase()
|
||||
val encodedTimestamp = String.format("%1$0${LONG_SIZE_IN_HEX}X", timestamp.toEpochMilli())
|
||||
return "$prefix-$encodedTimestamp-$shardIdentifier-$encodedSessionIdentifier-$encodedSequenceNumber"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fun generateShardId(flowIdentifier: String): String {
|
||||
return SecureHash.sha256(flowIdentifier).prefixChars(8)
|
||||
}
|
||||
|
||||
/**
|
||||
* A unique identifier for a sender that might be different across restarts.
|
||||
* It is used to help identify when messages are being sent continuously without errors or message are sent after the sender recovered from an error.
|
||||
*/
|
||||
typealias SenderUUID = String
|
||||
/**
|
||||
* A global sequence number for all the messages sent by a sender.
|
||||
*/
|
||||
typealias SenderSequenceNumber = Long
|
@ -1,7 +1,6 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.crypto.newSecureRandom
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.messaging.MessageRecipients
|
||||
import net.corda.core.messaging.SingleMessageRecipient
|
||||
@ -9,9 +8,8 @@ import net.corda.core.node.services.PartyInfo
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.ExternalEvent
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.nodeapi.internal.lifecycle.ServiceLifecycleSupport
|
||||
import java.time.Instant
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
@ -32,7 +30,7 @@ interface MessagingService : ServiceLifecycleSupport {
|
||||
* A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence
|
||||
* number for message de-duplication at the recipient.
|
||||
*/
|
||||
val ourSenderUUID: String
|
||||
val ourSenderUUID: SenderUUID
|
||||
|
||||
/**
|
||||
* The provided function will be invoked for each received message whose topic and session matches. The callback
|
||||
@ -92,15 +90,25 @@ interface MessagingService : ServiceLifecycleSupport {
|
||||
@Suspendable
|
||||
fun sendAll(addressedMessages: List<AddressedMessage>)
|
||||
|
||||
/**
|
||||
* Signal that a session has ended to the messaging layer, so that any necessary cleanup is performed.
|
||||
*
|
||||
* @param sessionId the identifier of the session that ended.
|
||||
* @param senderUUID the sender UUID of the last message seen in the session or null if there was no sender UUID in that message.
|
||||
* @param senderSequenceNumber the sender sequence number of the last message seen in the session or null if there was no sender sequence number in that message.
|
||||
*/
|
||||
@Suspendable
|
||||
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
|
||||
|
||||
/**
|
||||
* Returns an initialised [Message] with the current time, etc, already filled in.
|
||||
*
|
||||
* @param topic identifier for the topic the message is sent to.
|
||||
* @param data the payload for the message.
|
||||
* @param deduplicationId optional message deduplication ID including sender identifier.
|
||||
* @param deduplicationInfo optional message deduplication information.
|
||||
* @param additionalHeaders optional additional message headers.
|
||||
*/
|
||||
fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()): Message
|
||||
fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message
|
||||
|
||||
/** Given information about either a specific node or a service returns its corresponding address */
|
||||
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
|
||||
@ -109,7 +117,7 @@ interface MessagingService : ServiceLifecycleSupport {
|
||||
val myAddress: SingleMessageRecipient
|
||||
}
|
||||
|
||||
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to)
|
||||
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationInfo, additionalHeaders), to)
|
||||
|
||||
interface MessageHandlerRegistration
|
||||
|
||||
@ -128,7 +136,7 @@ interface Message {
|
||||
val topic: String
|
||||
val data: ByteSequence
|
||||
val debugTimestamp: Instant
|
||||
val uniqueMessageId: DeduplicationId
|
||||
val uniqueMessageId: MessageIdentifier
|
||||
val senderUUID: String?
|
||||
val additionalHeaders: Map<String, String>
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ class MessagingExecutor(
|
||||
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
|
||||
writeBodyBufferBytes(message.data.bytes)
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
|
||||
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
|
||||
// If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut.
|
||||
if (ourSenderUUID == message.senderUUID) {
|
||||
putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID))
|
||||
|
@ -1,56 +1,79 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.NamedCacheFactory
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.utilities.AppendOnlyPersistentMap
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
|
||||
import net.corda.nodeapi.internal.persistence.currentDBSession
|
||||
import java.math.BigInteger
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import javax.persistence.Column
|
||||
import javax.persistence.Entity
|
||||
import javax.persistence.Id
|
||||
import javax.persistence.Table
|
||||
|
||||
/**
|
||||
* Encapsulate the de-duplication logic.
|
||||
* This component is responsible for determining whether session-init messages are duplicates and it also keeps track of information related to
|
||||
* sessions that can be used for this purpose.
|
||||
*/
|
||||
class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val database: CordaPersistence) {
|
||||
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
}
|
||||
|
||||
// A temporary in-memory set of deduplication IDs and associated high water mark details.
|
||||
// When we receive a message we don't persist the ID immediately,
|
||||
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may
|
||||
// redeliver messages to the same consumer if they weren't ACKed.
|
||||
private val beingProcessedMessages = ConcurrentHashMap<DeduplicationId, MessageMeta>()
|
||||
private val processedMessages = createProcessedMessages(cacheFactory)
|
||||
private val beingProcessedMessages = ConcurrentHashMap<MessageIdentifier, MessageMeta>()
|
||||
|
||||
private fun createProcessedMessages(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<DeduplicationId, MessageMeta, ProcessedMessage, String> {
|
||||
/**
|
||||
* This table holds data *only* for sessions that have been initiated from a counterparty (e.g. ones we have received session-init messages from).
|
||||
* This is because any other messages apart from session-init messages are deduplicated by the state machine.
|
||||
*/
|
||||
private val sessionData = createSessionDataMap(cacheFactory)
|
||||
|
||||
private fun createSessionDataMap(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<SessionId, MessageMeta, SessionData, BigInteger> {
|
||||
return AppendOnlyPersistentMap(
|
||||
cacheFactory = cacheFactory,
|
||||
name = "P2PMessageDeduplicator_processedMessages",
|
||||
toPersistentEntityKey = { it.toString },
|
||||
fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) },
|
||||
toPersistentEntity = { key: DeduplicationId, value: MessageMeta ->
|
||||
ProcessedMessage().apply {
|
||||
id = key.toString
|
||||
insertionTime = value.insertionTime
|
||||
hash = value.senderHash
|
||||
seqNo = value.senderSeqNo
|
||||
name = "P2PMessageDeduplicator_sessionData",
|
||||
toPersistentEntityKey = { it.value },
|
||||
fromPersistentEntity = { Pair(SessionId(it.sessionId), MessageMeta(it.generationTime, it.senderHash, it.firstSenderSeqNo, it.lastSenderSeqNo)) },
|
||||
toPersistentEntity = { key: SessionId, value: MessageMeta ->
|
||||
SessionData().apply {
|
||||
sessionId = key.value
|
||||
generationTime = value.generationTime
|
||||
senderHash = value.senderHash
|
||||
firstSenderSeqNo = value.firstSenderSeqNo
|
||||
lastSenderSeqNo = value.lastSenderSeqNo
|
||||
}
|
||||
},
|
||||
persistentEntityClass = ProcessedMessage::class.java
|
||||
persistentEntityClass = SessionData::class.java
|
||||
)
|
||||
}
|
||||
|
||||
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages }
|
||||
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId.sessionIdentifier in sessionData }
|
||||
|
||||
// We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache.
|
||||
private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString()
|
||||
|
||||
/**
|
||||
* Determines whether a session-init message is a duplicate.
|
||||
* This is achieved by checking whether this message is currently being processed or if the associated session has already been created in the past.
|
||||
* This method should be invoked only with session-init messages, otherwise it will fail with an [IllegalArgumentException].
|
||||
*
|
||||
* @return true if we have seen this message before.
|
||||
*/
|
||||
fun isDuplicate(msg: ReceivedMessage): Boolean {
|
||||
fun isDuplicateSessionInit(msg: ReceivedMessage): Boolean {
|
||||
require(msg.isSessionInit) { "Message ${msg.uniqueMessageId} was not a session-init message." }
|
||||
|
||||
if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) {
|
||||
return true
|
||||
}
|
||||
@ -65,44 +88,86 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
|
||||
val receivedSenderSeqNo = msg.senderSeqNo
|
||||
// We don't want a mix of nulls and values so we ensure that here.
|
||||
val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null
|
||||
val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
|
||||
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo)
|
||||
val firstSenderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
|
||||
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(msg.uniqueMessageId.timestamp, senderHash, firstSenderSeqNo, null)
|
||||
}
|
||||
|
||||
/**
|
||||
* Called inside a DB transaction to persist [deduplicationId].
|
||||
*/
|
||||
fun persistDeduplicationId(deduplicationId: DeduplicationId) {
|
||||
processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!!
|
||||
fun persistDeduplicationId(deduplicationId: MessageIdentifier) {
|
||||
sessionData[deduplicationId.sessionIdentifier] = beingProcessedMessages[deduplicationId]!!
|
||||
}
|
||||
|
||||
/**
|
||||
* Called after the DB transaction persisting [deduplicationId] committed.
|
||||
* Any subsequent redelivery will be deduplicated using the DB.
|
||||
*/
|
||||
fun signalMessageProcessFinish(deduplicationId: DeduplicationId) {
|
||||
fun signalMessageProcessFinish(deduplicationId: MessageIdentifier) {
|
||||
beingProcessedMessages.remove(deduplicationId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Called inside a DB transaction to update entry for corresponding session.
|
||||
* The parameters [senderUUID] and [senderSequenceNumber] correspond to the last message seen from this session before it ended.
|
||||
* If [senderUUID] is not null, then [senderSequenceNumber] is also expected to not be null.
|
||||
*/
|
||||
@Suspendable
|
||||
fun signalSessionEnd(sessionId: SessionId, senderUUID: String?, senderSequenceNumber: Long?) {
|
||||
if (senderSequenceNumber != null && senderUUID != null) {
|
||||
val existingEntry = sessionData[sessionId]
|
||||
if (existingEntry != null) {
|
||||
val newEntry = existingEntry.copy(lastSenderSeqNo = senderSequenceNumber)
|
||||
sessionData.addOrUpdate(sessionId, newEntry) { k, v ->
|
||||
update(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun update(key: SessionId, value: MessageMeta): Boolean {
|
||||
val session = currentDBSession()
|
||||
val criteriaBuilder = session.criteriaBuilder
|
||||
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(SessionData::class.java)
|
||||
val queryRoot = criteriaUpdate.from(SessionData::class.java)
|
||||
criteriaUpdate.set(SessionData::lastSenderSeqNo.name, value.lastSenderSeqNo)
|
||||
criteriaUpdate.where(criteriaBuilder.equal(queryRoot.get<BigInteger>(SessionData::sessionId.name), key.value))
|
||||
val update = session.createQuery(criteriaUpdate)
|
||||
val rowsUpdated = update.executeUpdate()
|
||||
return rowsUpdated != 0
|
||||
}
|
||||
|
||||
@Entity
|
||||
@Suppress("MagicNumber") // database column width
|
||||
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
|
||||
class ProcessedMessage(
|
||||
@Table(name = "${NODE_DATABASE_PREFIX}session_data")
|
||||
class SessionData (
|
||||
@Id
|
||||
@Column(name = "message_id", length = 64, nullable = false)
|
||||
var id: String = "",
|
||||
@Column(name = "session_id", nullable = false)
|
||||
var sessionId: BigInteger = BigInteger.ZERO,
|
||||
|
||||
@Column(name = "insertion_time", nullable = false)
|
||||
var insertionTime: Instant = Instant.now(),
|
||||
/**
|
||||
* The time the corresponding session-init message was originally generated on the sender side.
|
||||
*/
|
||||
@Column(name = "init_generation_time", nullable = false)
|
||||
var generationTime: Instant = Instant.now(),
|
||||
|
||||
@Column(name = "sender", length = 64, nullable = true)
|
||||
var hash: String? = "",
|
||||
@Column(name = "sender_hash", length = 64, nullable = true)
|
||||
var senderHash: String? = "",
|
||||
|
||||
@Column(name = "sequence_number", nullable = true)
|
||||
var seqNo: Long? = null
|
||||
/**
|
||||
* The sender sequence number of the first message seen in a session.
|
||||
*/
|
||||
@Column(name = "init_sequence_number", nullable = true)
|
||||
var firstSenderSeqNo: Long? = null,
|
||||
|
||||
/**
|
||||
* The sender sequence number of the last message seen in a session before it was closed/terminated.
|
||||
*/
|
||||
@Column(name = "last_sequence_number", nullable = true)
|
||||
var lastSenderSeqNo: Long? = null
|
||||
)
|
||||
|
||||
private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?)
|
||||
private data class MessageMeta(val generationTime: Instant, val senderHash: String?, val firstSenderSeqNo: SenderSequenceNumber?, val lastSenderSeqNo: SenderSequenceNumber?)
|
||||
|
||||
private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean)
|
||||
}
|
||||
|
@ -25,9 +25,9 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer
|
||||
import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex
|
||||
import net.corda.node.services.api.NetworkMapCacheInternal
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.ExternalEvent
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.errorAndTerminate
|
||||
import net.corda.nodeapi.internal.ArtemisMessagingComponent
|
||||
@ -101,8 +101,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
|
||||
private class NodeClientMessage(override val topic: String,
|
||||
override val data: ByteSequence,
|
||||
override val uniqueMessageId: DeduplicationId,
|
||||
override val senderUUID: String?,
|
||||
override val uniqueMessageId: MessageIdentifier,
|
||||
override val senderUUID: SenderUUID?,
|
||||
override val additionalHeaders: Map<String, String>) : Message {
|
||||
override val debugTimestamp: Instant = Instant.now()
|
||||
override fun toString() = "$topic#${String(data.bytes)}"
|
||||
@ -371,7 +371,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
|
||||
val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) }
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) }
|
||||
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { MessageIdentifier.parse(message.getStringProperty(it)) }
|
||||
val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID)
|
||||
val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null
|
||||
val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE
|
||||
@ -392,8 +392,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
private class ArtemisReceivedMessage(override val topic: String,
|
||||
override val peer: CordaX500Name,
|
||||
override val platformVersion: Int,
|
||||
override val uniqueMessageId: DeduplicationId,
|
||||
override val senderUUID: String?,
|
||||
override val uniqueMessageId: MessageIdentifier,
|
||||
override val senderUUID: SenderUUID?,
|
||||
override val senderSeqNo: Long?,
|
||||
override val isSessionInit: Boolean,
|
||||
private val message: ClientMessage) : ReceivedMessage {
|
||||
@ -405,13 +405,18 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
|
||||
internal fun deliver(artemisMessage: ClientMessage) {
|
||||
artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
|
||||
if (!deduplicator.isDuplicate(cordaMessage)) {
|
||||
if (cordaMessage.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
|
||||
if (!deduplicator.isDuplicateSessionInit(cordaMessage)) {
|
||||
deduplicator.signalMessageProcessStart(cordaMessage)
|
||||
deliver(cordaMessage, artemisMessage)
|
||||
} else {
|
||||
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" }
|
||||
log.debug { "Discarding duplicate session-init message with identifier: ${cordaMessage.uniqueMessageId}, senderUUID: ${cordaMessage.senderUUID}, senderSeqNo: ${cordaMessage.senderSeqNo}" }
|
||||
messagingExecutor!!.acknowledge(artemisMessage)
|
||||
}
|
||||
} else {
|
||||
// non session-init messages are directly handed to the state machine, which is responsible for performing deduplication.
|
||||
deliver(cordaMessage, artemisMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -420,7 +425,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
val deliverTo = handlers[msg.topic]
|
||||
if (deliverTo != null) {
|
||||
try {
|
||||
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg))
|
||||
if (msg.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
|
||||
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForSessionInitMessages(artemisMessage, msg))
|
||||
} else {
|
||||
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForRegularMessages(artemisMessage, msg))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
|
||||
}
|
||||
@ -429,11 +438,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
|
||||
private inner class MessageDeduplicationHandlerForSessionInitMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
|
||||
override val externalCause: ExternalEvent
|
||||
get() = this
|
||||
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
|
||||
override val deduplicationHandler: MessageDeduplicationHandler
|
||||
override val deduplicationHandler: MessageDeduplicationHandlerForSessionInitMessages
|
||||
get() = this
|
||||
|
||||
override fun insideDatabaseTransaction() {
|
||||
@ -450,6 +459,27 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
private inner class MessageDeduplicationHandlerForRegularMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
|
||||
override val externalCause: ExternalEvent
|
||||
get() = this
|
||||
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
|
||||
override val deduplicationHandler: MessageDeduplicationHandlerForRegularMessages
|
||||
get() = this
|
||||
|
||||
/**
|
||||
* Nothing to do, since deduplication information is kept in the state machine.
|
||||
*/
|
||||
override fun insideDatabaseTransaction() {}
|
||||
|
||||
override fun afterDatabaseTransaction() {
|
||||
messagingExecutor!!.acknowledge(artemisMessage)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates shutdown: if called from a thread that isn't controlled by the executor passed to the constructor
|
||||
* then this will block until all in-flight messages have finished being handled and acknowledged. If called
|
||||
@ -520,6 +550,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
|
||||
deduplicator.signalSessionEnd(sessionId, senderUUID, senderSequenceNumber)
|
||||
}
|
||||
|
||||
override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
|
||||
return if (address == myAddress) {
|
||||
// If we are sending to ourselves then route the message directly to our P2P queue.
|
||||
@ -586,8 +621,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
|
||||
handlers.remove(registration.topic)
|
||||
}
|
||||
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
|
||||
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders)
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
|
||||
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, deduplicationInfo.senderUUID, additionalHeaders)
|
||||
}
|
||||
|
||||
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
|
||||
|
@ -0,0 +1,11 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
/**
|
||||
* This is a combination of a message's unique identifier along with a unique identifier for the sender of the message.
|
||||
* The former can be used independently for deduplication purposes when receiving a message, but enriching it with the latter helps us
|
||||
* optimise some paths and perform smarter deduplication logic per sender.
|
||||
*
|
||||
* The [senderUUID] property might be null if the flow is trying to replay messages and doesn't want an optimisation to ignore the message identifier
|
||||
* because it could lead to false negatives (messages that are deemed duplicates, but are not).
|
||||
*/
|
||||
data class SenderDeduplicationInfo(val messageIdentifier: MessageIdentifier, val senderUUID: String?)
|
@ -513,7 +513,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
|
||||
|
||||
private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? {
|
||||
return if (this is SessionState.Initiated) {
|
||||
ActiveSession(peerParty, sessionId, receivedMessages, peerFlowInfo, peerSinkSessionId)
|
||||
ActiveSession(peerParty, sessionId, receivedMessages.values.toList(), peerFlowInfo, peerSinkSessionId)
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet()
|
||||
BasicHSMKeyManagementService.PersistentKey::class.java,
|
||||
NodeSchedulerService.PersistentScheduledState::class.java,
|
||||
NodeAttachmentService.DBAttachment::class.java,
|
||||
P2PMessageDeduplicator.ProcessedMessage::class.java,
|
||||
P2PMessageDeduplicator.SessionData::class.java,
|
||||
PersistentIdentityService.PersistentPublicKeyHashToCertificate::class.java,
|
||||
PersistentIdentityService.PersistentPublicKeyHashToParty::class.java,
|
||||
PersistentIdentityService.PersistentHashToPublicKey::class.java,
|
||||
|
@ -6,6 +6,10 @@ import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowAsyncOperation
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.messaging.SenderSequenceNumber
|
||||
import net.corda.node.services.messaging.SenderUUID
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
|
||||
@ -25,7 +29,7 @@ sealed class Action {
|
||||
data class SendInitial(
|
||||
val destination: Destination,
|
||||
val initialise: InitialSessionMessage,
|
||||
val deduplicationId: SenderDeduplicationId
|
||||
val deduplicationInfo: SenderDeduplicationInfo
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
@ -34,7 +38,7 @@ sealed class Action {
|
||||
data class SendExisting(
|
||||
val peerParty: Party,
|
||||
val message: ExistingSessionMessage,
|
||||
val deduplicationId: SenderDeduplicationId
|
||||
val deduplicationInfo: SenderDeduplicationInfo
|
||||
) : Action()
|
||||
|
||||
/**
|
||||
@ -95,12 +99,11 @@ sealed class Action {
|
||||
data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
|
||||
|
||||
/**
|
||||
* Propagate [errorMessages] to [sessions].
|
||||
* @param sessions a map from source session IDs to initiated sessions.
|
||||
* Propagate the specified error messages to the specified sessions.
|
||||
* @param errorsPerSession a map containing the error messages to be sent per session along with their identifiers.
|
||||
*/
|
||||
data class PropagateErrors(
|
||||
val errorMessages: List<ErrorSessionMessage>,
|
||||
val sessions: List<SessionState.Initiated>,
|
||||
val errorsPerSession: Map<SessionState.Initiated, List<Pair<MessageIdentifier, ErrorSessionMessage>>>,
|
||||
val senderUUID: String?
|
||||
) : Action()
|
||||
|
||||
@ -114,6 +117,11 @@ sealed class Action {
|
||||
*/
|
||||
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
|
||||
|
||||
/**
|
||||
* Signal sessions ended at the messaging layer.
|
||||
*/
|
||||
data class SignalSessionsHasEnded(val terminatedSessions: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>): Action()
|
||||
|
||||
/**
|
||||
* Signal that the flow corresponding to [flowId] is considered started.
|
||||
*/
|
||||
|
@ -10,6 +10,7 @@ import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.trace
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.nodeapi.internal.persistence.contextDatabase
|
||||
import net.corda.nodeapi.internal.persistence.contextTransaction
|
||||
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
|
||||
@ -58,6 +59,7 @@ internal class ActionExecutorImpl(
|
||||
is Action.AddSessionBinding -> executeAddSessionBinding(action)
|
||||
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
|
||||
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
|
||||
is Action.SignalSessionsHasEnded -> executeSignalSessionsHasEnded(action)
|
||||
is Action.RemoveFlow -> executeRemoveFlow(action)
|
||||
is Action.CreateTransaction -> executeCreateTransaction()
|
||||
is Action.RollbackTransaction -> executeRollbackTransaction()
|
||||
@ -132,16 +134,17 @@ internal class ActionExecutorImpl(
|
||||
|
||||
@Suspendable
|
||||
private fun executePropagateErrors(action: Action.PropagateErrors) {
|
||||
action.errorMessages.forEach { (exception) ->
|
||||
val errors = action.errorsPerSession.values.flatMap { it.map { it.second } }.distinct()
|
||||
errors.forEach { errorSessionMessage ->
|
||||
val exception = errorSessionMessage.flowException
|
||||
log.warn("Propagating error", exception)
|
||||
}
|
||||
for (sessionState in action.sessions) {
|
||||
action.errorsPerSession.forEach { (sessionState, sessionErrors) ->
|
||||
// Don't propagate errors to the originating session
|
||||
for (errorMessage in action.errorMessages) {
|
||||
for ((id, msg) in sessionErrors) {
|
||||
val sinkSessionId = sessionState.peerSinkSessionId
|
||||
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
|
||||
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
|
||||
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))
|
||||
val errorMsg = ExistingSessionMessage(sinkSessionId, msg)
|
||||
flowMessaging.sendSessionMessage(sessionState.peerParty, errorMsg, SenderDeduplicationInfo(id, action.senderUUID))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -162,18 +165,18 @@ internal class ActionExecutorImpl(
|
||||
|
||||
@Suspendable
|
||||
private fun executeSendInitial(action: Action.SendInitial) {
|
||||
flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationId)
|
||||
flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationInfo)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSendExisting(action: Action.SendExisting) {
|
||||
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId)
|
||||
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationInfo)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSendMultiple(action: Action.SendMultiple) {
|
||||
val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationId) } +
|
||||
action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationId) }
|
||||
val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationInfo) } +
|
||||
action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationInfo) }
|
||||
flowMessaging.sendSessionMessages(messages)
|
||||
}
|
||||
|
||||
@ -192,6 +195,13 @@ internal class ActionExecutorImpl(
|
||||
stateMachineManager.signalFlowHasStarted(action.flowId)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeSignalSessionsHasEnded(action: Action.SignalSessionsHasEnded) {
|
||||
action.terminatedSessions.forEach { (sessionId, senderData) ->
|
||||
flowMessaging.sessionEnded(sessionId, senderData.first, senderData.second)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun executeRemoveFlow(action: Action.RemoveFlow) {
|
||||
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)
|
||||
|
@ -1,53 +0,0 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
* A deduplication ID of a flow message.
|
||||
*/
|
||||
data class DeduplicationId(val toString: String) {
|
||||
companion object {
|
||||
/**
|
||||
* Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
|
||||
* unless we persist the ID somehow.
|
||||
*/
|
||||
fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")
|
||||
|
||||
/**
|
||||
* Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
|
||||
* creating IDs in case the message-generating flow logic is replayed on hard failure.
|
||||
*
|
||||
* A normal deduplication ID consists of:
|
||||
* 1. A deduplication seed set per session. This is the initiator's session ID, with a prefix for initiator
|
||||
* or initiated.
|
||||
* 2. The number of *clean* suspends since the start of the flow.
|
||||
* 3. An optional additional index, for cases where several messages are sent as part of the state transition.
|
||||
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive
|
||||
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the
|
||||
* message-id map to change, which means deduplication will not happen correctly.
|
||||
*/
|
||||
fun createForNormal(checkpoint: Checkpoint, index: Int, session: SessionState): DeduplicationId {
|
||||
return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.checkpointState.numberOfSuspends}-$index")
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
|
||||
* IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
|
||||
* dirtiness state to be thrown away for resumption.
|
||||
*
|
||||
* An error deduplication ID consists of:
|
||||
* 1. The error's ID. This is a unique value per "source" of error and is propagated.
|
||||
* See [net.corda.core.flows.IdentifiableException].
|
||||
* 2. The recipient's session ID.
|
||||
*/
|
||||
fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
|
||||
return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be
|
||||
* null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID.
|
||||
*/
|
||||
data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?)
|
@ -8,6 +8,9 @@ import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderSequenceNumber
|
||||
import net.corda.node.services.messaging.SenderUUID
|
||||
import java.util.UUID
|
||||
|
||||
/**
|
||||
@ -34,7 +37,10 @@ sealed class Event {
|
||||
data class DeliverSessionMessage(
|
||||
val sessionMessage: ExistingSessionMessage,
|
||||
override val deduplicationHandler: DeduplicationHandler,
|
||||
val sender: Party
|
||||
val sender: Party,
|
||||
val messageIdentifier: MessageIdentifier,
|
||||
val senderUUID: SenderUUID?,
|
||||
val senderSequenceNumber: SenderSequenceNumber?
|
||||
) : Event(), GeneratedByExternalEvent
|
||||
|
||||
/**
|
||||
|
@ -155,7 +155,8 @@ class FlowCreator(
|
||||
frozenFlowLogic,
|
||||
ourIdentity,
|
||||
flowCorDappVersion,
|
||||
flowLogic.isEnabledTimedFlow()
|
||||
flowLogic.isEnabledTimedFlow(),
|
||||
serviceHub.clock.instant()
|
||||
).getOrThrow()
|
||||
|
||||
val state = createStateMachineState(
|
||||
@ -253,6 +254,7 @@ class FlowCreator(
|
||||
return StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
closedSessionsPendingToBeSignalled = emptyMap(),
|
||||
isFlowResumed = false,
|
||||
future = null,
|
||||
isWaitingForFuture = false,
|
||||
|
@ -15,6 +15,9 @@ import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessagingService
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.messaging.SenderSequenceNumber
|
||||
import net.corda.node.services.messaging.SenderUUID
|
||||
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
|
||||
import java.io.NotSerializableException
|
||||
|
||||
@ -23,21 +26,24 @@ import java.io.NotSerializableException
|
||||
*/
|
||||
interface FlowMessaging {
|
||||
/**
|
||||
* Send [message] to [destination] using [deduplicationId].
|
||||
* Send [message] to [destination] using [deduplicationInfo].
|
||||
*/
|
||||
@Suspendable
|
||||
fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId)
|
||||
fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo)
|
||||
|
||||
@Suspendable
|
||||
fun sendSessionMessages(messageData: List<Message>)
|
||||
|
||||
@Suspendable
|
||||
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
|
||||
|
||||
/**
|
||||
* Start the messaging using the [onMessage] message handler.
|
||||
*/
|
||||
fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
|
||||
}
|
||||
|
||||
data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupId: SenderDeduplicationId)
|
||||
data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupInfo: SenderDeduplicationInfo)
|
||||
|
||||
/**
|
||||
* Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
|
||||
@ -56,18 +62,23 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId) {
|
||||
val addressedMessage = createMessage(destination, message, deduplicationId)
|
||||
override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo) {
|
||||
val addressedMessage = createMessage(destination, message, deduplicationInfo)
|
||||
serviceHub.networkService.send(addressedMessage.message, addressedMessage.target, addressedMessage.sequenceKey)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun sendSessionMessages(messageData: List<Message>) {
|
||||
val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupId) }
|
||||
val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupInfo) }
|
||||
serviceHub.networkService.sendAll(addressedMessages)
|
||||
}
|
||||
|
||||
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId): MessagingService.AddressedMessage {
|
||||
@Suspendable
|
||||
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
|
||||
serviceHub.networkService.sessionEnded(sessionId, senderUUID, senderSequenceNumber)
|
||||
}
|
||||
|
||||
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationInfo): MessagingService.AddressedMessage {
|
||||
// We assume that the destination type has already been checked by initiateFlow.
|
||||
// Destination may point to a stale well-known identity due to key rotation, so always resolve actual identity via IdentityService.
|
||||
val party = requireNotNull(serviceHub.identityService.wellKnownPartyFromAnonymous(destination as AbstractParty)) {
|
||||
@ -80,6 +91,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
|
||||
}
|
||||
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
|
||||
val partyInfo = requireNotNull(serviceHub.networkMapCache.getPartyInfo(party)) { "Don't know about ${party.description()}" }
|
||||
|
||||
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
|
||||
val sequenceKey = when (message) {
|
||||
is InitialSessionMessage -> message.initiatorSessionId
|
||||
|
@ -179,7 +179,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
val stateMachine = transientValues.stateMachine
|
||||
val oldState = transientState
|
||||
val actionExecutor = transientValues.actionExecutor
|
||||
val transition = stateMachine.transition(event, oldState)
|
||||
val transition = stateMachine.transition(event, oldState, serviceHub.clock.instant())
|
||||
val (continuation, newState) = transitionExecutor.executeTransition(
|
||||
this,
|
||||
oldState,
|
||||
|
@ -4,6 +4,7 @@ import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import java.math.BigInteger
|
||||
import java.security.SecureRandom
|
||||
|
||||
/**
|
||||
@ -21,9 +22,28 @@ import java.security.SecureRandom
|
||||
sealed class SessionMessage
|
||||
|
||||
@CordaSerializable
|
||||
data class SessionId(val toLong: Long) {
|
||||
data class SessionId(val value: BigInteger) {
|
||||
init {
|
||||
require(value.signum() >= 0) { "Session identifier cannot be a negative number, but it was $value" }
|
||||
require(value.bitLength() <= MAX_BIT_SIZE) { "The size of a session identifier cannot exceed $MAX_BIT_SIZE bits, but it was $value" }
|
||||
}
|
||||
|
||||
/**
|
||||
* This calculates the initiated session ID assuming this is the initiating session ID.
|
||||
* This is the next larger number in the range [0, 2^[MAX_BIT_SIZE]] with wrap around the largest number in the interval.
|
||||
*/
|
||||
fun calculateInitiatedSessionId(): SessionId {
|
||||
return if (this.value == LARGEST_SESSION_ID)
|
||||
SessionId(BigInteger.ZERO)
|
||||
else
|
||||
SessionId(this.value.plus(BigInteger.ONE))
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong())
|
||||
const val MAX_BIT_SIZE = 128
|
||||
val LARGEST_SESSION_ID = BigInteger.valueOf(2).pow(MAX_BIT_SIZE).minus(BigInteger.ONE)
|
||||
|
||||
fun createRandom(secureRandom: SecureRandom) = SessionId(BigInteger(MAX_BIT_SIZE, secureRandom))
|
||||
}
|
||||
}
|
||||
|
||||
@ -118,3 +138,29 @@ data class RejectSessionMessage(val message: String, val errorId: Long) : Existi
|
||||
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
|
||||
*/
|
||||
object EndSessionMessage : ExistingSessionMessagePayload()
|
||||
|
||||
enum class MessageType {
|
||||
SESSION_INIT,
|
||||
SESSION_CONFIRM,
|
||||
SESSION_REJECT,
|
||||
DATA_MESSAGE,
|
||||
SESSION_END,
|
||||
SESSION_ERROR;
|
||||
|
||||
companion object {
|
||||
fun inferFromMessage(message: SessionMessage): MessageType {
|
||||
return when (message) {
|
||||
is InitialSessionMessage -> SESSION_INIT
|
||||
is ExistingSessionMessage -> {
|
||||
when(message.payload) {
|
||||
is ConfirmSessionMessage -> SESSION_CONFIRM
|
||||
is RejectSessionMessage -> SESSION_REJECT
|
||||
is DataSessionMessage -> DATA_MESSAGE
|
||||
is EndSessionMessage -> SESSION_END
|
||||
is ErrorSessionMessage -> SESSION_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -37,6 +37,10 @@ import net.corda.node.internal.InitiatedFlowFactory
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.messaging.SenderSequenceNumber
|
||||
import net.corda.node.services.messaging.SenderUUID
|
||||
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine
|
||||
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
|
||||
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
|
||||
@ -703,7 +707,7 @@ internal class SingleThreadedStateMachineManager(
|
||||
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
|
||||
if (sender != null) {
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event)
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event, event.receivedMessage.uniqueMessageId, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
|
||||
is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event)
|
||||
}
|
||||
} else {
|
||||
@ -716,7 +720,10 @@ internal class SingleThreadedStateMachineManager(
|
||||
private fun onExistingSessionMessage(
|
||||
sessionMessage: ExistingSessionMessage,
|
||||
sender: Party,
|
||||
externalEvent: ExternalEvent.ExternalMessageEvent
|
||||
externalEvent: ExternalEvent.ExternalMessageEvent,
|
||||
messageIdentifier: MessageIdentifier,
|
||||
senderUUID: SenderUUID?,
|
||||
senderSequenceNumber: SenderSequenceNumber?
|
||||
) {
|
||||
try {
|
||||
val deduplicationHandler = externalEvent.deduplicationHandler
|
||||
@ -734,7 +741,7 @@ internal class SingleThreadedStateMachineManager(
|
||||
logger.info("Cannot find flow corresponding to session ID - $recipientId.")
|
||||
}
|
||||
} else {
|
||||
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)
|
||||
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender, messageIdentifier, senderUUID, senderSequenceNumber)
|
||||
innerState.withLock {
|
||||
flows[flowId]?.run { fiber.scheduleEvent(event) }
|
||||
// If flow is not running add it to the list of external events to be processed if/when the flow resumes.
|
||||
@ -751,7 +758,7 @@ internal class SingleThreadedStateMachineManager(
|
||||
private fun onSessionInit(sessionMessage: InitialSessionMessage, sender: Party, event: ExternalEvent.ExternalMessageEvent) {
|
||||
try {
|
||||
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
|
||||
val initiatedSessionId = SessionId.createRandom(secureRandom)
|
||||
val initiatedSessionId = event.receivedMessage.uniqueMessageId.sessionIdentifier
|
||||
val senderSession = FlowSessionImpl(sender, sender, initiatedSessionId)
|
||||
val flowLogic = initiatedFlowFactory.createFlow(senderSession)
|
||||
val initiatedFlowInfo = when (initiatedFlowFactory) {
|
||||
@ -763,9 +770,8 @@ internal class SingleThreadedStateMachineManager(
|
||||
is InitiatedFlowFactory.CorDapp -> null
|
||||
}
|
||||
startInitiatedFlow(
|
||||
event.flowId,
|
||||
event,
|
||||
flowLogic,
|
||||
event.deduplicationHandler,
|
||||
senderSession,
|
||||
initiatedSessionId,
|
||||
sessionMessage,
|
||||
@ -800,24 +806,24 @@ internal class SingleThreadedStateMachineManager(
|
||||
|
||||
@Suppress("LongParameterList")
|
||||
private fun <A> startInitiatedFlow(
|
||||
flowId: StateMachineRunId,
|
||||
event: ExternalEvent.ExternalMessageEvent,
|
||||
flowLogic: FlowLogic<A>,
|
||||
initiatingMessageDeduplicationHandler: DeduplicationHandler,
|
||||
peerSession: FlowSessionImpl,
|
||||
initiatedSessionId: SessionId,
|
||||
initiatingMessage: InitialSessionMessage,
|
||||
senderCoreFlowVersion: Int?,
|
||||
initiatedFlowInfo: FlowInfo
|
||||
) {
|
||||
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo)
|
||||
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo,
|
||||
event.receivedMessage.uniqueMessageId.shardIdentifier, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
|
||||
val ourIdentity = ourFirstIdentity
|
||||
startFlowInternal(
|
||||
flowId,
|
||||
event.flowId,
|
||||
InvocationContext.peer(peerSession.counterparty.name),
|
||||
flowLogic,
|
||||
flowStart,
|
||||
ourIdentity,
|
||||
initiatingMessageDeduplicationHandler
|
||||
event.deduplicationHandler
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,8 @@ import net.corda.core.utilities.debug
|
||||
import net.corda.core.utilities.minutes
|
||||
import net.corda.core.utilities.seconds
|
||||
import net.corda.node.services.FinalityHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import org.hibernate.exception.ConstraintViolationException
|
||||
import rx.subjects.PublishSubject
|
||||
import java.io.Closeable
|
||||
@ -169,7 +171,9 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging,
|
||||
|
||||
log.info("Sending session initiation error back to $sender", error)
|
||||
|
||||
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID))
|
||||
val messageType = MessageType.inferFromMessage(replyError)
|
||||
val messageIdentifier = MessageIdentifier(messageType, event.receivedMessage.uniqueMessageId.shardIdentifier, sessionMessage.initiatorSessionId, 0, event.receivedMessage.uniqueMessageId.timestamp)
|
||||
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationInfo(messageIdentifier, ourSenderUUID))
|
||||
event.deduplicationHandler.afterDatabaseTransaction()
|
||||
}
|
||||
|
||||
|
@ -22,11 +22,16 @@ import net.corda.core.serialization.internal.CheckpointSerializationContext
|
||||
import net.corda.core.serialization.internal.checkpointDeserialize
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderSequenceNumber
|
||||
import net.corda.node.services.messaging.SenderUUID
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.lang.IllegalStateException
|
||||
import java.security.Principal
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.Future
|
||||
import java.util.concurrent.Semaphore
|
||||
import kotlin.math.max
|
||||
|
||||
/**
|
||||
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
|
||||
@ -35,6 +40,7 @@ import java.util.concurrent.Semaphore
|
||||
* @param checkpoint the persisted part of the state.
|
||||
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
|
||||
* @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
|
||||
* @param closedSessionsPendingToBeSignalled the sessions that have been closed and need to be signalled to the messaging layer on the next checkpoint (along with some metadata).
|
||||
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
|
||||
* to make [Event.DoRemainingWork] idempotent.
|
||||
* @param isWaitingForFuture true if the flow is waiting for the completion of a future triggered by one of the statemachine's actions
|
||||
@ -61,6 +67,7 @@ data class StateMachineState(
|
||||
val checkpoint: Checkpoint,
|
||||
val flowLogic: FlowLogic<*>,
|
||||
val pendingDeduplicationHandlers: List<DeduplicationHandler>,
|
||||
val closedSessionsPendingToBeSignalled: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>,
|
||||
val isFlowResumed: Boolean,
|
||||
val isWaitingForFuture: Boolean,
|
||||
var future: Future<*>?,
|
||||
@ -123,7 +130,8 @@ data class Checkpoint(
|
||||
frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
|
||||
ourIdentity: Party,
|
||||
subFlowVersion: SubFlowVersion,
|
||||
isEnabledTimedFlow: Boolean
|
||||
isEnabledTimedFlow: Boolean,
|
||||
timestamp: Instant
|
||||
): Try<Checkpoint> {
|
||||
return SubFlow.create(flowLogicClass, subFlowVersion, isEnabledTimedFlow).map { topLevelSubFlow ->
|
||||
Checkpoint(
|
||||
@ -135,7 +143,8 @@ data class Checkpoint(
|
||||
listOf(topLevelSubFlow),
|
||||
numberOfSuspends = 0,
|
||||
// We set this to 1 here to avoid an extra copy and increment in UnstartedFlowTransition.createInitialCheckpoint
|
||||
numberOfCommits = 1
|
||||
numberOfCommits = 1,
|
||||
suspensionTime = timestamp
|
||||
),
|
||||
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
|
||||
errorState = ErrorState.Clean
|
||||
@ -235,6 +244,7 @@ data class Checkpoint(
|
||||
}
|
||||
|
||||
/**
|
||||
<<<<<<< HEAD
|
||||
* @param invocationContext The initiator of the flow.
|
||||
* @param ourIdentity The identity the flow is run as.
|
||||
* @param sessions Map of source session ID to session state.
|
||||
@ -242,6 +252,7 @@ data class Checkpoint(
|
||||
* @param subFlowStack The stack of currently executing subflows.
|
||||
* @param numberOfSuspends The number of flow suspends due to IO API calls.
|
||||
* @param numberOfCommits The number of times this checkpoint has been persisted.
|
||||
* @param suspensionTime the time of the last suspension. This is supposed to be used as a stable timestamp in case of replays.
|
||||
*/
|
||||
@CordaSerializable
|
||||
data class CheckpointState(
|
||||
@ -251,7 +262,8 @@ data class CheckpointState(
|
||||
val sessionsToBeClosed: Set<SessionId>,
|
||||
val subFlowStack: List<SubFlow>,
|
||||
val numberOfSuspends: Int,
|
||||
val numberOfCommits: Int
|
||||
val numberOfCommits: Int,
|
||||
val suspensionTime: Instant
|
||||
)
|
||||
|
||||
/**
|
||||
@ -262,44 +274,162 @@ sealed class SessionState {
|
||||
abstract val deduplicationSeed: String
|
||||
|
||||
/**
|
||||
* We haven't yet sent the initialisation message
|
||||
* the sender UUID last seen in this session, if there was one.
|
||||
*/
|
||||
abstract val lastSenderUUID: SenderUUID?
|
||||
|
||||
/**
|
||||
* the sender sequence number last seen in this session, if there was one.
|
||||
*/
|
||||
abstract val lastSenderSeqNo: SenderSequenceNumber?
|
||||
|
||||
/**
|
||||
* the messages that have been received and are pending processing indexed by their sequence number.
|
||||
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
|
||||
* * [DataSessionMessage]
|
||||
* * [ErrorSessionMessage]
|
||||
* * [EndSessionMessage]
|
||||
*/
|
||||
abstract val receivedMessages: Map<Int, ExistingSessionMessagePayload>
|
||||
|
||||
/**
|
||||
* Returns a new session state with the specified messages added to the list of received messages.
|
||||
*/
|
||||
fun addReceivedMessages(message: ExistingSessionMessagePayload, messageIdentifier: MessageIdentifier, senderUUID: String?, senderSequenceNumber: Long?): SessionState {
|
||||
val newReceivedMessages = receivedMessages.plus(messageIdentifier.sessionSequenceNumber to message)
|
||||
val (newLastSenderUUID, newLastSenderSeqNo) = calculateSenderInfo(lastSenderUUID, lastSenderSeqNo, senderUUID, senderSequenceNumber)
|
||||
return when(this) {
|
||||
is Uninitiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
|
||||
is Initiating -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
|
||||
is Initiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
|
||||
}
|
||||
}
|
||||
|
||||
private fun calculateSenderInfo(currentSender: String?, currentSenderSeqNo: Long?, msgSender: String?, msgSenderSeqNo: Long?): Pair<String?, Long?> {
|
||||
return if (msgSender != null && msgSenderSeqNo != null) {
|
||||
if (currentSenderSeqNo != null)
|
||||
Pair(msgSender, max(msgSenderSeqNo, currentSenderSeqNo))
|
||||
else
|
||||
Pair(msgSender, msgSenderSeqNo)
|
||||
} else {
|
||||
Pair(currentSender, currentSenderSeqNo)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* We haven't yet sent the initialisation message.
|
||||
* This really means that the flow is in a state before sending the initialisation message,
|
||||
* but in reality it could have sent it before and fail before reaching the next checkpoint, thus ending up replaying from the last checkpoint.
|
||||
*
|
||||
* @param hasBeenAcknowledged whether a positive response to a session initiation has already been received and the associated confirmation message, if so.
|
||||
* @param hasBeenRejected whether a negative response to a session initiation has already been received and the associated rejection message, if so.
|
||||
*/
|
||||
data class Uninitiated(
|
||||
val destination: Destination,
|
||||
val initiatingSubFlow: SubFlow.Initiating,
|
||||
val sourceSessionId: SessionId,
|
||||
val additionalEntropy: Long
|
||||
val additionalEntropy: Long,
|
||||
val hasBeenAcknowledged: Pair<Party, ConfirmSessionMessage>?,
|
||||
val hasBeenRejected: RejectSessionMessage?,
|
||||
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
|
||||
override val lastSenderUUID: String?,
|
||||
override val lastSenderSeqNo: Long?
|
||||
) : SessionState() {
|
||||
override val deduplicationSeed: String get() = "R-${sourceSessionId.toLong}-$additionalEntropy"
|
||||
override val deduplicationSeed: String get() = "R-${sourceSessionId.value}-$additionalEntropy"
|
||||
}
|
||||
|
||||
/**
|
||||
* We have sent the initialisation message but have not yet received a confirmation.
|
||||
* @property bufferedMessages the messages that have been buffered to be sent after the session is confirmed from the other side.
|
||||
* @property rejectionError if non-null the initiation failed.
|
||||
* @property nextSendingSeqNumber the sequence number of the next message to be sent.
|
||||
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
|
||||
*/
|
||||
data class Initiating(
|
||||
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
|
||||
val bufferedMessages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>,
|
||||
val rejectionError: FlowError?,
|
||||
override val deduplicationSeed: String
|
||||
) : SessionState()
|
||||
override val deduplicationSeed: String,
|
||||
val nextSendingSeqNumber: Int,
|
||||
val shardId: String,
|
||||
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
|
||||
override val lastSenderUUID: String?,
|
||||
override val lastSenderSeqNo: Long?
|
||||
) : SessionState() {
|
||||
|
||||
/**
|
||||
* Buffers an outgoing message to be sent when ready.
|
||||
* Returns the new form of the state
|
||||
*/
|
||||
fun bufferMessage(messageIdentifier: MessageIdentifier, messagePayload: ExistingSessionMessagePayload): SessionState {
|
||||
return this.copy(bufferedMessages = bufferedMessages + Pair(messageIdentifier, messagePayload), nextSendingSeqNumber = nextSendingSeqNumber + 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* A batched form of [bufferMessage].
|
||||
*/
|
||||
fun bufferMessages(messages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>): SessionState {
|
||||
return this.copy(bufferedMessages = bufferedMessages + messages, nextSendingSeqNumber = nextSendingSeqNumber + messages.size)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* We have received a confirmation, the peer party and session id is resolved.
|
||||
* @property receivedMessages the messages that have been received and are pending processing.
|
||||
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
|
||||
* * [DataSessionMessage]
|
||||
* * [ErrorSessionMessage]
|
||||
* * [EndSessionMessage]
|
||||
* @property otherSideErrored whether the session has received an error from the other side.
|
||||
* @property nextSendingSeqNumber the sequence number that corresponds to the next message to be sent.
|
||||
* @property lastProcessedSeqNumber the sequence number of the last message that has been processed.
|
||||
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
|
||||
*/
|
||||
data class Initiated(
|
||||
val peerParty: Party,
|
||||
val peerFlowInfo: FlowInfo,
|
||||
val receivedMessages: List<ExistingSessionMessagePayload>,
|
||||
val otherSideErrored: Boolean,
|
||||
val peerSinkSessionId: SessionId,
|
||||
override val deduplicationSeed: String
|
||||
) : SessionState()
|
||||
override val deduplicationSeed: String,
|
||||
val nextSendingSeqNumber: Int,
|
||||
val lastProcessedSeqNumber: Int,
|
||||
val shardId: String,
|
||||
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
|
||||
override val lastSenderUUID: String?,
|
||||
override val lastSenderSeqNo: Long?
|
||||
) : SessionState() {
|
||||
|
||||
/**
|
||||
* Indicates whether this message has already been processed.
|
||||
*/
|
||||
fun isDuplicate(messageIdentifier: MessageIdentifier): Boolean {
|
||||
return messageIdentifier.sessionSequenceNumber <= lastProcessedSeqNumber
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates whether the session has an error message pending from the other side.
|
||||
*/
|
||||
fun hasErrored(): Boolean {
|
||||
return hasNextMessageArrived() && receivedMessages[lastProcessedSeqNumber + 1] is ErrorSessionMessage
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates whether the next expected message has arrived.
|
||||
*/
|
||||
fun hasNextMessageArrived(): Boolean {
|
||||
return receivedMessages.containsKey(lastProcessedSeqNumber + 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the next message to be processed and the new session state.
|
||||
* If you want to check first whether the next message has arrived, call [hasNextMessageArrived]
|
||||
*
|
||||
* @throws [IllegalArgumentException] if the next hasn't arrived.
|
||||
*/
|
||||
fun extractMessage(): Pair<ExistingSessionMessagePayload, Initiated> {
|
||||
if (!hasNextMessageArrived()) {
|
||||
throw IllegalArgumentException("Tried to extract a message that hasn't arrived yet.")
|
||||
}
|
||||
|
||||
val message = receivedMessages[lastProcessedSeqNumber + 1]!!
|
||||
val newState = this.copy(receivedMessages = receivedMessages.minus(lastProcessedSeqNumber + 1), lastProcessedSeqNumber = lastProcessedSeqNumber + 1)
|
||||
return Pair(message, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typealias SessionMap = Map<SessionId, SessionState>
|
||||
@ -321,7 +451,10 @@ sealed class FlowStart {
|
||||
val initiatedSessionId: SessionId,
|
||||
val initiatingMessage: InitialSessionMessage,
|
||||
val senderCoreFlowVersion: Int?,
|
||||
val initiatedFlowInfo: FlowInfo
|
||||
val initiatedFlowInfo: FlowInfo,
|
||||
val shardIdentifier: String,
|
||||
val senderUUID: String?,
|
||||
val senderSequenceNumber: Long?
|
||||
) : FlowStart() { override fun toString() = "Initiated" }
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
|
||||
import net.corda.core.flows.UnexpectedFlowEndException
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.core.utilities.trace
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.ConfirmSessionMessage
|
||||
import net.corda.node.services.statemachine.DataSessionMessage
|
||||
@ -13,7 +14,7 @@ import net.corda.node.services.statemachine.ExistingSessionMessage
|
||||
import net.corda.node.services.statemachine.FlowError
|
||||
import net.corda.node.services.statemachine.FlowState
|
||||
import net.corda.node.services.statemachine.RejectSessionMessage
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.statemachine.SessionState
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
|
||||
@ -87,56 +88,78 @@ class DeliverSessionMessageTransition(
|
||||
val initiatedSession = SessionState.Initiated(
|
||||
peerParty = event.sender,
|
||||
peerFlowInfo = message.initiatedFlowInfo,
|
||||
receivedMessages = emptyList(),
|
||||
receivedMessages = emptyMap(),
|
||||
peerSinkSessionId = message.initiatedSessionId,
|
||||
deduplicationSeed = sessionState.deduplicationSeed,
|
||||
otherSideErrored = false
|
||||
otherSideErrored = false,
|
||||
nextSendingSeqNumber = sessionState.nextSendingSeqNumber,
|
||||
lastProcessedSeqNumber = 0,
|
||||
shardId = sessionState.shardId,
|
||||
lastSenderUUID = event.senderUUID,
|
||||
lastSenderSeqNo = event.senderSequenceNumber
|
||||
)
|
||||
val newCheckpoint = currentState.checkpoint.addSession(
|
||||
event.sessionMessage.recipientSessionId to initiatedSession
|
||||
)
|
||||
// Send messages that were buffered pending confirmation of session.
|
||||
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
|
||||
val sendActions = sessionState.bufferedMessages.map { (messageId, bufferedMessage) ->
|
||||
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
|
||||
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
|
||||
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationInfo(messageId, startingState.senderUUID))
|
||||
}
|
||||
actions.addAll(sendActions)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
is SessionState.Initiated -> {
|
||||
log.trace { "Discarding duplicate confirmation for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
|
||||
}
|
||||
is SessionState.Uninitiated -> {
|
||||
val newSessionState = sessionState.copy(hasBeenAcknowledged = Pair(event.sender, message))
|
||||
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
|
||||
// We received a data message. The corresponding session must be Initiated.
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
// Buffer the message in the session's receivedMessages buffer.
|
||||
val newSessionState = sessionState.copy(
|
||||
receivedMessages = sessionState.receivedMessages + message
|
||||
)
|
||||
|
||||
if (!sessionState.isDuplicate(event.messageIdentifier)) {
|
||||
val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.addSession(
|
||||
event.sessionMessage.recipientSessionId to newSessionState
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
)
|
||||
} else {
|
||||
log.trace { "Discarding duplicate data message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
|
||||
}
|
||||
}
|
||||
is SessionState.Initiating, is SessionState.Uninitiated -> {
|
||||
val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
|
||||
val sequenceNumber = event.messageIdentifier.sessionSequenceNumber
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
|
||||
if (sequenceNumber > sessionState.lastProcessedSeqNumber) {
|
||||
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.addSession(sessionId to newSessionState)
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
)
|
||||
} else {
|
||||
log.trace { "Discarding duplicate error message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
|
||||
}
|
||||
}
|
||||
is SessionState.Initiating, is SessionState.Uninitiated -> {
|
||||
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
)
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,42 +168,42 @@ class DeliverSessionMessageTransition(
|
||||
return when (sessionState) {
|
||||
is SessionState.Initiating -> {
|
||||
if (sessionState.rejectionError != null) {
|
||||
// Double reject
|
||||
freshErrorTransition(UnexpectedEventInState())
|
||||
log.trace { "Discarding duplicate session rejection message for session ${event.sessionMessage.recipientSessionId}" }
|
||||
} else {
|
||||
val checkpoint = currentState.checkpoint
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val flowError = FlowError(payload.errorId, exception)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = checkpoint.addSession(sessionId to sessionState.copy(rejectionError = flowError))
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(rejectionError = flowError))
|
||||
)
|
||||
}
|
||||
}
|
||||
else -> freshErrorTransition(UnexpectedEventInState())
|
||||
is SessionState.Uninitiated -> {
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(hasBeenRejected = payload))
|
||||
)
|
||||
}
|
||||
is SessionState.Initiated -> {
|
||||
freshErrorTransition(UnexpectedEventInState("A session rejection message was received for an already established session ${event.messageIdentifier.sessionIdentifier}."))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.endMessageTransition(payload: EndSessionMessage) {
|
||||
val flowState = currentState.checkpoint.flowState
|
||||
// flow must have already been started when session end messages are being delivered.
|
||||
if (flowState !is FlowState.Started)
|
||||
return freshErrorTransition(UnexpectedEventInState())
|
||||
|
||||
val sessionId = event.sessionMessage.recipientSessionId
|
||||
val sessions = currentState.checkpoint.checkpointState.sessions
|
||||
// a check has already been performed to confirm the session exists for this message before this method is invoked.
|
||||
val sessionState = sessions[sessionId]!!
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val flowState = currentState.checkpoint.flowState
|
||||
// flow must have already been started when session end messages are being delivered.
|
||||
if (flowState !is FlowState.Started)
|
||||
return freshErrorTransition(UnexpectedEventInState())
|
||||
|
||||
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
|
||||
is SessionState.Initiated, is SessionState.Initiating, is SessionState.Uninitiated -> {
|
||||
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
|
||||
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
|
||||
.addSessionsToBeClosed(setOf(event.sessionMessage.recipientSessionId))
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
}
|
||||
else -> {
|
||||
freshErrorTransition(PrematureSessionEndException(event.sessionMessage.recipientSessionId))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.statemachine.*
|
||||
|
||||
/**
|
||||
@ -40,16 +41,28 @@ class ErrorFlowTransition(
|
||||
return builder {
|
||||
// If we're errored and propagating do the actual propagation and update the index.
|
||||
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
|
||||
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(
|
||||
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
|
||||
startingState.checkpoint.checkpointState.sessions,
|
||||
errorMessages
|
||||
)
|
||||
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
|
||||
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
|
||||
var currentSeqNumber = sessionState.nextSendingSeqNumber
|
||||
val errorsWithId = errorMessages.map { errorMsg ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSeqNumber++
|
||||
Pair(messageIdentifier, errorMsg)
|
||||
}.toList()
|
||||
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
|
||||
Pair(sessionState, errorsWithId)
|
||||
}.toMap()
|
||||
|
||||
val newCheckpoint = startingState.checkpoint.copy(
|
||||
errorState = errorState.copy(propagatedIndex = allErrors.size),
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions)
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
|
||||
)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)
|
||||
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
|
||||
}
|
||||
|
||||
// If we're errored but not propagating keep processing events.
|
||||
@ -81,16 +94,27 @@ class ErrorFlowTransition(
|
||||
isCheckpointUpdate = currentState.isAnyCheckpointPersisted
|
||||
)
|
||||
}
|
||||
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
|
||||
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
|
||||
}.toMap()
|
||||
|
||||
actions += Action.CreateTransaction
|
||||
actions += removeOrPersistCheckpoint
|
||||
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
|
||||
actions += Action.ReleaseSoftLocks(context.id.uuid)
|
||||
actions += Action.CommitTransaction(currentState)
|
||||
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.RemoveSessionBindings(startingState.checkpoint.checkpointState.sessions.keys)
|
||||
actions += Action.RemoveFlow(context.id, FlowRemovalReason.ErrorFinish(allErrors), currentState)
|
||||
|
||||
currentState = currentState.copy(
|
||||
checkpoint = newCheckpoint,
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
closedSessionsPendingToBeSignalled = emptyMap(),
|
||||
isRemoved = true
|
||||
)
|
||||
|
||||
FlowContinuation.Abort
|
||||
} else {
|
||||
// Otherwise keep processing events. This branch happens when there are some outstanding initiating
|
||||
@ -112,31 +136,37 @@ class ErrorFlowTransition(
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer error messages in Initiating sessions, return the initialised ones.
|
||||
/**
|
||||
* Buffers errors message for initiating states and filters the initiated states.
|
||||
* Returns a pair that consists of:
|
||||
* - a map containing the initiated states as filtered from the ones provided as input.
|
||||
* - a map containing the new state of all the sessions.
|
||||
*/
|
||||
private fun bufferErrorMessagesInInitiatingSessions(
|
||||
sessions: Map<SessionId, SessionState>,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessionStates = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
|
||||
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
|
||||
val errorMessagesWithDeduplication = errorMessages.map {
|
||||
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
|
||||
var currentSequenceNumber = sessionState.nextSendingSeqNumber
|
||||
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSequenceNumber++
|
||||
messageIdentifier to errorMessage
|
||||
}
|
||||
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
|
||||
sessionState.bufferMessages(errorMessagesWithDeduplication)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
}
|
||||
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
|
||||
val initiatedSessions = sessions.values.mapNotNull { session ->
|
||||
if (session is SessionState.Initiated && !session.otherSideErrored) {
|
||||
session
|
||||
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
|
||||
sessionId to sessionState
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
return Pair(initiatedSessions, newSessions)
|
||||
}.toMap()
|
||||
return Pair(initiatedSessions, newSessionStates)
|
||||
}
|
||||
}
|
||||
|
@ -2,14 +2,15 @@ package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.KilledFlowException
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.ErrorSessionMessage
|
||||
import net.corda.node.services.statemachine.Event
|
||||
import net.corda.node.services.statemachine.FlowError
|
||||
import net.corda.node.services.statemachine.FlowRemovalReason
|
||||
import net.corda.node.services.statemachine.FlowState
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.services.statemachine.SessionState
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
@ -27,24 +28,37 @@ class KilledFlowTransition(
|
||||
val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError)
|
||||
val errorMessages = listOf(killedFlowErrorMessage)
|
||||
|
||||
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(
|
||||
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
|
||||
startingState.checkpoint.checkpointState.sessions,
|
||||
errorMessages
|
||||
)
|
||||
|
||||
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
|
||||
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
|
||||
var currentSeqNumber = sessionState.nextSendingSeqNumber
|
||||
val errorsWithId = errorMessages.map { errorMsg ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSeqNumber++
|
||||
Pair(messageIdentifier, errorMsg)
|
||||
}.toList()
|
||||
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
|
||||
Pair(sessionState, errorsWithId)
|
||||
}.toMap()
|
||||
|
||||
val newCheckpoint = startingState.checkpoint.copy(
|
||||
status = Checkpoint.FlowStatus.KILLED,
|
||||
flowState = FlowState.Finished,
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions)
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
|
||||
)
|
||||
|
||||
currentState = currentState.copy(
|
||||
checkpoint = newCheckpoint,
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
closedSessionsPendingToBeSignalled = emptyMap(),
|
||||
isRemoved = true
|
||||
)
|
||||
|
||||
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)
|
||||
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
|
||||
|
||||
if (!startingState.isFlowResumed) {
|
||||
actions += Action.CreateTransaction
|
||||
@ -59,7 +73,12 @@ class KilledFlowTransition(
|
||||
actions += Action.AddFlowException(context.id, killedFlowError.exception)
|
||||
}
|
||||
|
||||
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
|
||||
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
|
||||
}.toMap()
|
||||
|
||||
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
|
||||
actions += Action.ReleaseSoftLocks(context.id.uuid)
|
||||
actions += Action.CommitTransaction(currentState)
|
||||
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
|
||||
@ -91,32 +110,37 @@ class KilledFlowTransition(
|
||||
}
|
||||
}
|
||||
|
||||
// Purposely left the same as [bufferErrorMessagesInInitiatingSessions] in [ErrorFlowTransition] so that it can be refactored
|
||||
// Buffer error messages in Initiating sessions, return the initialised ones.
|
||||
/**
|
||||
* Buffers errors message for initiating states and filters the initiated states.
|
||||
* Returns a pair that consists of:
|
||||
* - a map containing the initiated states as filtered from the ones provided as input.
|
||||
* - a map containing the new state of all the sessions.
|
||||
*/
|
||||
private fun bufferErrorMessagesInInitiatingSessions(
|
||||
sessions: Map<SessionId, SessionState>,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
|
||||
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
|
||||
val errorMessagesWithDeduplication = errorMessages.map {
|
||||
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
|
||||
var currentSequenceNumber = sessionState.nextSendingSeqNumber
|
||||
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSequenceNumber++
|
||||
messageIdentifier to errorMessage
|
||||
}
|
||||
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
|
||||
sessionState.bufferMessages(errorMessagesWithDeduplication)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
}
|
||||
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
|
||||
val initiatedSessions = sessions.values.mapNotNull { session ->
|
||||
if (session is SessionState.Initiated && !session.otherSideErrored) {
|
||||
session
|
||||
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
|
||||
sessionId to sessionState
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
}.toMap()
|
||||
return Pair(initiatedSessions, newSessions)
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,9 @@ import net.corda.core.internal.FlowIORequest
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.toNonEmptySet
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.messaging.generateShardId
|
||||
import net.corda.node.services.statemachine.*
|
||||
import org.slf4j.Logger
|
||||
import kotlin.collections.LinkedHashMap
|
||||
@ -177,14 +180,19 @@ class StartedFlowTransition(
|
||||
}
|
||||
|
||||
if (existingSessionsToRemove.isNotEmpty()) {
|
||||
val sendEndMessageActions = existingSessionsToRemove.values.mapIndexed { index, state ->
|
||||
val sinkSessionId = (state as SessionState.Initiated).peerSinkSessionId
|
||||
val sendEndMessageActions = existingSessionsToRemove.map { (_, sessionState) ->
|
||||
val sinkSessionId = (sessionState as SessionState.Initiated).peerSinkSessionId
|
||||
val message = ExistingSessionMessage(sinkSessionId, EndSessionMessage)
|
||||
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
|
||||
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
|
||||
val messageType = MessageType.inferFromMessage(message)
|
||||
val messageIdentifier = MessageIdentifier(messageType, generateShardId(context.id.toString()), sinkSessionId, sessionState.nextSendingSeqNumber, currentState.checkpoint.checkpointState.suspensionTime)
|
||||
Action.SendExisting(sessionState.peerParty, message, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID))
|
||||
}
|
||||
val signalSessionsEndMap = existingSessionsToRemove.map { (sessionId, _) ->
|
||||
val sessionState = currentState.checkpoint.checkpointState.sessions[sessionId]!!
|
||||
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
|
||||
}.toMap()
|
||||
|
||||
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys))
|
||||
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys), closedSessionsPendingToBeSignalled = currentState.closedSessionsPendingToBeSignalled + signalSessionsEndMap)
|
||||
actions.add(Action.RemoveSessionBindings(sessionIdsToRemove))
|
||||
actions.add(Action.SendMultiple(emptyList(), sendEndMessageActions))
|
||||
}
|
||||
@ -239,23 +247,23 @@ class StartedFlowTransition(
|
||||
|
||||
@Suppress("ComplexMethod", "NestedBlockDepth")
|
||||
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
|
||||
val newSessionMessages = LinkedHashMap(sessions)
|
||||
val newSessionStates = LinkedHashMap(sessions)
|
||||
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
|
||||
var someNotFound = false
|
||||
for (sessionId in sessionIds) {
|
||||
val sessionState = sessions[sessionId]
|
||||
when (sessionState) {
|
||||
is SessionState.Initiated -> {
|
||||
val messages = sessionState.receivedMessages
|
||||
if (messages.isEmpty()) {
|
||||
if (!sessionState.hasNextMessageArrived()) {
|
||||
someNotFound = true
|
||||
} else {
|
||||
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
|
||||
val (message, newState) = sessionState.extractMessage()
|
||||
newSessionStates[sessionId] = newState
|
||||
// at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message.
|
||||
resultMessages[sessionId] = if (messages[0] is EndSessionMessage) {
|
||||
resultMessages[sessionId] = if (message is EndSessionMessage) {
|
||||
throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?")
|
||||
} else {
|
||||
(messages[0] as DataSessionMessage).payload
|
||||
(message as DataSessionMessage).payload
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -267,14 +275,13 @@ class StartedFlowTransition(
|
||||
return if (someNotFound) {
|
||||
return null
|
||||
} else {
|
||||
PollResult(resultMessages, newSessionMessages)
|
||||
PollResult(resultMessages, newSessionStates)
|
||||
}
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
|
||||
val checkpoint = startingState.checkpoint
|
||||
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.checkpointState.sessions)
|
||||
var index = 0
|
||||
for (sourceSessionId in sourceSessions) {
|
||||
val sessionState = checkpoint.checkpointState.sessions[sourceSessionId]
|
||||
if (sessionState == null) {
|
||||
@ -283,14 +290,22 @@ class StartedFlowTransition(
|
||||
if (sessionState !is SessionState.Uninitiated) {
|
||||
continue
|
||||
}
|
||||
val shardId = generateShardId(context.id.toString())
|
||||
val counterpartySessionId = sourceSessionId.calculateInitiatedSessionId()
|
||||
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null)
|
||||
val newSessionState = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null,
|
||||
deduplicationSeed = sessionState.deduplicationSeed
|
||||
deduplicationSeed = sessionState.deduplicationSeed,
|
||||
nextSendingSeqNumber = 1,
|
||||
shardId = shardId,
|
||||
receivedMessages = emptyMap(),
|
||||
lastSenderUUID = null,
|
||||
lastSenderSeqNo = null
|
||||
)
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, newSessionState)
|
||||
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
|
||||
val messageType = MessageType.inferFromMessage(initialMessage)
|
||||
val messageIdentifier = MessageIdentifier(messageType, shardId, counterpartySessionId, 0, checkpoint.checkpointState.suspensionTime)
|
||||
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID)))
|
||||
newSessions[sourceSessionId] = newSessionState
|
||||
}
|
||||
currentState = currentState.copy(checkpoint = checkpoint.setSessions(sessions = newSessions))
|
||||
@ -313,37 +328,60 @@ class StartedFlowTransition(
|
||||
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
|
||||
val checkpoint = startingState.checkpoint
|
||||
val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions)
|
||||
var index = 0
|
||||
|
||||
val messagesByType = sourceSessionIdToMessage.toList()
|
||||
.map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) }
|
||||
.groupBy { it.second::class }
|
||||
|
||||
val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.map { (sourceSessionId, sessionState, message) ->
|
||||
val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.mapNotNull { (sourceSessionId, sessionState, message) ->
|
||||
val uninitiatedSessionState = sessionState as SessionState.Uninitiated
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState)
|
||||
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
|
||||
val shardId = generateShardId(context.id.toString())
|
||||
if (sessionState.hasBeenAcknowledged != null) {
|
||||
newSessions[sourceSessionId] = SessionState.Initiated(
|
||||
peerParty = sessionState.hasBeenAcknowledged.first,
|
||||
peerFlowInfo = sessionState.hasBeenAcknowledged.second.initiatedFlowInfo,
|
||||
receivedMessages = emptyMap(),
|
||||
otherSideErrored = false,
|
||||
peerSinkSessionId = sessionState.hasBeenAcknowledged.second.initiatedSessionId,
|
||||
deduplicationSeed = sessionState.deduplicationSeed,
|
||||
nextSendingSeqNumber = 1,
|
||||
lastProcessedSeqNumber = 0,
|
||||
shardId = shardId,
|
||||
lastSenderUUID = null,
|
||||
lastSenderSeqNo = null
|
||||
)
|
||||
null
|
||||
} else {
|
||||
newSessions[sourceSessionId] = SessionState.Initiating(
|
||||
bufferedMessages = emptyList(),
|
||||
rejectionError = null,
|
||||
deduplicationSeed = uninitiatedSessionState.deduplicationSeed
|
||||
deduplicationSeed = sessionState.deduplicationSeed,
|
||||
nextSendingSeqNumber = 1,
|
||||
shardId = shardId,
|
||||
receivedMessages = emptyMap(),
|
||||
lastSenderUUID = null,
|
||||
lastSenderSeqNo = null
|
||||
)
|
||||
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
|
||||
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
|
||||
val messageType = MessageType.inferFromMessage(initialMessage)
|
||||
val messageIdentifier = MessageIdentifier(messageType, shardId, sourceSessionId.calculateInitiatedSessionId(), 0, checkpoint.checkpointState.suspensionTime)
|
||||
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
|
||||
}
|
||||
} ?: emptyList()
|
||||
messagesByType[SessionState.Initiating::class]?.forEach { (sourceSessionId, sessionState, message) ->
|
||||
val initiatingSessionState = sessionState as SessionState.Initiating
|
||||
val sessionMessage = DataSessionMessage(message)
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatingSessionState)
|
||||
val newBufferedMessages = initiatingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
|
||||
newSessions[sourceSessionId] = initiatingSessionState.copy(bufferedMessages = newBufferedMessages)
|
||||
val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
|
||||
newSessions[sourceSessionId] = initiatingSessionState.bufferMessage(messageIdentifier, sessionMessage)
|
||||
}
|
||||
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(_, sessionState, message) ->
|
||||
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(sourceSessionId, sessionState, message) ->
|
||||
val initiatedSessionState = sessionState as SessionState.Initiated
|
||||
val sessionMessage = DataSessionMessage(message)
|
||||
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatedSessionState)
|
||||
val sinkSessionId = initiatedSessionState.peerSinkSessionId
|
||||
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
|
||||
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
|
||||
val messageType = MessageType.inferFromMessage(existingMessage)
|
||||
val messageIdentifier = MessageIdentifier(messageType, sessionState.shardId, sessionState.peerSinkSessionId, sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
|
||||
newSessions[sourceSessionId] = initiatedSessionState.copy(nextSendingSeqNumber = initiatedSessionState.nextSendingSeqNumber + 1)
|
||||
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
|
||||
} ?: emptyList()
|
||||
|
||||
if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) {
|
||||
@ -372,11 +410,10 @@ class StartedFlowTransition(
|
||||
}
|
||||
}
|
||||
is SessionState.Initiated -> {
|
||||
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) {
|
||||
val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage
|
||||
val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty)
|
||||
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true)
|
||||
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState)
|
||||
if (sessionState.hasErrored()) {
|
||||
val (message, newSessionState) = sessionState.extractMessage()
|
||||
val exception = convertErrorMessageToException(message as ErrorSessionMessage, sessionState.peerParty)
|
||||
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState.copy(otherSideErrored = true))
|
||||
newState = startingState.copy(checkpoint = newCheckpoint)
|
||||
listOf(exception)
|
||||
} else {
|
||||
@ -541,8 +578,8 @@ class StartedFlowTransition(
|
||||
|
||||
private fun findSessionsToBeTerminated(startingState: StateMachineState): SessionMap {
|
||||
return startingState.checkpoint.checkpointState.sessionsToBeClosed.mapNotNull { sessionId ->
|
||||
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!! as SessionState.Initiated
|
||||
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is EndSessionMessage) {
|
||||
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!!
|
||||
if (sessionState is SessionState.Initiated && sessionState.receivedMessages.containsKey(sessionState.lastProcessedSeqNumber + 1) && sessionState.receivedMessages[sessionState.lastProcessedSeqNumber + 1] is EndSessionMessage) {
|
||||
sessionId to sessionState
|
||||
} else {
|
||||
null
|
||||
|
@ -4,12 +4,13 @@ import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.node.services.statemachine.Event
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
import java.security.SecureRandom
|
||||
import java.time.Instant
|
||||
|
||||
class StateMachine(
|
||||
val id: StateMachineRunId,
|
||||
val secureRandom: SecureRandom
|
||||
) {
|
||||
fun transition(event: Event, state: StateMachineState): TransitionResult {
|
||||
return TopLevelTransition(TransitionContext(id, secureRandom), state, event).transition()
|
||||
fun transition(event: Event, state: StateMachineState, time: Instant): TransitionResult {
|
||||
return TopLevelTransition(TransitionContext(id, secureRandom, time), state, event).transition()
|
||||
}
|
||||
}
|
||||
|
@ -7,9 +7,9 @@ import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.EndSessionMessage
|
||||
import net.corda.node.services.statemachine.ErrorState
|
||||
import net.corda.node.services.statemachine.Event
|
||||
@ -19,7 +19,8 @@ import net.corda.node.services.statemachine.FlowRemovalReason
|
||||
import net.corda.node.services.statemachine.FlowSessionImpl
|
||||
import net.corda.node.services.statemachine.FlowState
|
||||
import net.corda.node.services.statemachine.InitialSessionMessage
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.services.statemachine.SessionMessage
|
||||
import net.corda.node.services.statemachine.SessionState
|
||||
@ -197,7 +198,8 @@ class TopLevelTransition(
|
||||
checkpointState.invocationContext
|
||||
},
|
||||
numberOfSuspends = checkpointState.numberOfSuspends + 1,
|
||||
numberOfCommits = checkpointState.numberOfCommits + 1
|
||||
numberOfCommits = checkpointState.numberOfCommits + 1,
|
||||
suspensionTime = context.time
|
||||
)
|
||||
copy(
|
||||
flowState = FlowState.Started(event.ioRequest, event.fiber),
|
||||
@ -217,11 +219,13 @@ class TopLevelTransition(
|
||||
currentState = startingState.copy(
|
||||
checkpoint = newCheckpoint,
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
closedSessionsPendingToBeSignalled = emptyMap(),
|
||||
isFlowResumed = false,
|
||||
isAnyCheckpointPersisted = true
|
||||
)
|
||||
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
|
||||
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
|
||||
actions += Action.CommitTransaction(currentState)
|
||||
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.ScheduleEvent(Event.DoRemainingWork)
|
||||
@ -240,12 +244,14 @@ class TopLevelTransition(
|
||||
checkpoint = checkpoint.copy(
|
||||
checkpointState = checkpoint.checkpointState.copy(
|
||||
numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1,
|
||||
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1
|
||||
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1,
|
||||
suspensionTime = context.time
|
||||
),
|
||||
flowState = FlowState.Finished,
|
||||
result = event.returnValue,
|
||||
status = Checkpoint.FlowStatus.COMPLETED
|
||||
),
|
||||
).removeSessions(checkpoint.checkpointState.sessions.keys),
|
||||
closedSessionsPendingToBeSignalled = emptyMap(),
|
||||
pendingDeduplicationHandlers = emptyList(),
|
||||
isFlowResumed = false,
|
||||
isRemoved = true
|
||||
@ -263,7 +269,12 @@ class TopLevelTransition(
|
||||
)
|
||||
}
|
||||
|
||||
val signalSessionsEndMap = startingState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
|
||||
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
|
||||
}.toMap()
|
||||
|
||||
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
|
||||
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
|
||||
actions += Action.ReleaseSoftLocks(event.softLocksId)
|
||||
actions += Action.CommitTransaction(currentState)
|
||||
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
|
||||
@ -284,11 +295,12 @@ class TopLevelTransition(
|
||||
}
|
||||
|
||||
private fun TransitionBuilder.sendEndMessages() {
|
||||
val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state ->
|
||||
val sendEndMessageActions = startingState.checkpoint.checkpointState.sessions.map { (sessionId, state) ->
|
||||
if (state is SessionState.Initiated) {
|
||||
val message = ExistingSessionMessage(state.peerSinkSessionId, EndSessionMessage)
|
||||
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
|
||||
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
|
||||
val messageType = MessageType.inferFromMessage(message)
|
||||
val messageIdentifier = MessageIdentifier(messageType, state.shardId, state.peerSinkSessionId, state.nextSendingSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
Action.SendExisting(state.peerParty, message, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
|
||||
} else {
|
||||
null
|
||||
}
|
||||
@ -306,7 +318,7 @@ class TopLevelTransition(
|
||||
}
|
||||
val sourceSessionId = SessionId.createRandom(context.secureRandom)
|
||||
val sessionImpl = FlowSessionImpl(event.destination, event.wellKnownParty, sourceSessionId)
|
||||
val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong()))
|
||||
val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong(), null, null, emptyMap(), null, null))
|
||||
currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions))
|
||||
actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
|
||||
FlowContinuation.Resume(sessionImpl)
|
||||
@ -361,10 +373,12 @@ class TopLevelTransition(
|
||||
numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1
|
||||
)
|
||||
),
|
||||
pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents
|
||||
pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents,
|
||||
closedSessionsPendingToBeSignalled = emptyMap()
|
||||
)
|
||||
actions += Action.CreateTransaction
|
||||
actions += Action.PersistDeduplicationFacts(flowStartEvents)
|
||||
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
|
||||
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
|
||||
actions += Action.CommitTransaction(currentState)
|
||||
actions += Action.AcknowledgeMessages(flowStartEvents)
|
||||
@ -401,4 +415,5 @@ class TopLevelTransition(
|
||||
FlowContinuation.Abort
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
import java.security.SecureRandom
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
* An interface used to separate out different parts of the state machine transition function.
|
||||
@ -28,5 +29,6 @@ interface Transition {
|
||||
|
||||
class TransitionContext(
|
||||
val id: StateMachineRunId,
|
||||
val secureRandom: SecureRandom
|
||||
val secureRandom: SecureRandom,
|
||||
val time: Instant
|
||||
)
|
||||
|
@ -80,6 +80,6 @@ class TransitionBuilder(val context: TransitionContext, initialState: StateMachi
|
||||
}
|
||||
|
||||
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
|
||||
class UnexpectedEventInState : IllegalStateException("Unexpected event")
|
||||
class UnexpectedEventInState(message: String = "") : IllegalStateException("An unexpected event happened. $message")
|
||||
class PrematureSessionCloseException(sessionId: SessionId): IllegalStateException("The following session was closed before it was initialised: $sessionId")
|
||||
class PrematureSessionEndException(sessionId: SessionId): IllegalStateException("A premature session end message was received before the session was initialised: $sessionId")
|
@ -1,14 +1,15 @@
|
||||
package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.ConfirmSessionMessage
|
||||
import net.corda.node.services.statemachine.DataSessionMessage
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.ExistingSessionMessage
|
||||
import net.corda.node.services.statemachine.FlowStart
|
||||
import net.corda.node.services.statemachine.FlowState
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionState
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
|
||||
@ -50,25 +51,28 @@ class UnstartedFlowTransition(
|
||||
appName = initiatingMessage.appName
|
||||
),
|
||||
receivedMessages = if (initiatingMessage.firstPayload == null) {
|
||||
emptyList()
|
||||
emptyMap()
|
||||
} else {
|
||||
listOf(DataSessionMessage(initiatingMessage.firstPayload))
|
||||
mapOf(0 to DataSessionMessage(initiatingMessage.firstPayload))
|
||||
},
|
||||
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}",
|
||||
otherSideErrored = false
|
||||
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.value}-${initiatingMessage.initiationEntropy}",
|
||||
otherSideErrored = false,
|
||||
nextSendingSeqNumber = 1,
|
||||
lastProcessedSeqNumber = if (initiatingMessage.firstPayload == null) {
|
||||
0
|
||||
} else {
|
||||
-1
|
||||
},
|
||||
shardId = flowStart.shardIdentifier,
|
||||
lastSenderUUID = flowStart.senderUUID,
|
||||
lastSenderSeqNo = flowStart.senderSequenceNumber
|
||||
)
|
||||
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
|
||||
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
|
||||
currentState = currentState.copy(
|
||||
checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState))
|
||||
)
|
||||
actions.add(
|
||||
Action.SendExisting(
|
||||
flowStart.peerSession.counterparty,
|
||||
sessionMessage,
|
||||
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0, initiatedState), currentState.senderUUID)
|
||||
)
|
||||
)
|
||||
val messageType = MessageType.inferFromMessage(sessionMessage)
|
||||
val messageIdentifier = MessageIdentifier(messageType, flowStart.shardIdentifier, initiatingMessage.initiatorSessionId, 0, currentState.checkpoint.checkpointState.suspensionTime)
|
||||
currentState = currentState.copy(checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState)))
|
||||
actions.add(Action.SendExisting(flowStart.peerSession.counterparty, sessionMessage, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID)))
|
||||
}
|
||||
|
||||
// Create initial checkpoint and acknowledge triggering messages.
|
||||
|
@ -55,7 +55,7 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi
|
||||
name == "FlowDrainingMode_nodeProperties" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "ContractUpgradeService_upgrades" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "PersistentUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "P2PMessageDeduplicator_processedMessages" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "P2PMessageDeduplicator_sessionData" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "DeduplicationChecker_watermark" -> caffeine
|
||||
name == "BFTNonValidatingNotaryService_transactions" -> caffeine.maximumSize(defaultCacheSize)
|
||||
name == "RaftUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)
|
||||
|
@ -34,5 +34,7 @@
|
||||
<include file="migration/node-core.changelog-v19.xml"/>
|
||||
<include file="migration/node-core.changelog-v19-postgres.xml"/>
|
||||
<include file="migration/node-core.changelog-v19-keys.xml"/>
|
||||
<include file="migration/node-core.changelog-v20.xml"/>
|
||||
<include file="migration/node-core.changelog-v21.xml"/>
|
||||
|
||||
</databaseChangeLog>
|
||||
|
@ -0,0 +1,28 @@
|
||||
<?xml version="1.1" encoding="UTF-8" standalone="no"?>
|
||||
<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd"
|
||||
logicalFilePath="migration/node-services.changelog-init.xml">
|
||||
|
||||
<changeSet author="R3.Corda" id="add_session_data_table">
|
||||
<createTable tableName="node_session_data">
|
||||
<column name="session_id" type="NUMBER(128)">
|
||||
<constraints nullable="false"/>
|
||||
</column>
|
||||
<column name="init_generation_time" type="timestamp">
|
||||
<constraints nullable="false"/>
|
||||
</column>
|
||||
<column name="sender_hash" type="NVARCHAR(64)">
|
||||
<constraints nullable="true"/>
|
||||
</column>
|
||||
<column name="init_sequence_number" type="BIGINT">
|
||||
<constraints nullable="true"/>
|
||||
</column>
|
||||
<column name="last_sequence_number" type="BIGINT">
|
||||
<constraints nullable="true"/>
|
||||
</column>
|
||||
</createTable>
|
||||
<addPrimaryKey columnNames="session_id" constraintName="node_session_data_pk" tableName="node_session_data"/>
|
||||
</changeSet>
|
||||
|
||||
</databaseChangeLog>
|
@ -0,0 +1,47 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.math.BigInteger
|
||||
import java.time.Instant
|
||||
|
||||
class MessageIdentifierTest {
|
||||
|
||||
private val shardIdentifier = "XXXXXXXX"
|
||||
private val sessionIdentifier = SessionId(BigInteger.valueOf(14))
|
||||
private val sessionSequenceNumber = 1
|
||||
private val timestamp = Instant.ofEpochMilli(100)
|
||||
|
||||
private val messageIdString = "XI-0000000000000064-XXXXXXXX-0000000000000000000000000000000E-1"
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `can parse message identifier from string value`() {
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_INIT, shardIdentifier, sessionIdentifier, sessionSequenceNumber, timestamp)
|
||||
|
||||
assertThat(messageIdentifier.toString()).isEqualTo(messageIdString)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `can convert message identifier object to string value`() {
|
||||
val messageIdentifierString = messageIdString
|
||||
|
||||
val messageIdentifier = MessageIdentifier.parse(messageIdentifierString)
|
||||
assertThat(messageIdentifier.messageType).isInstanceOf(MessageType.SESSION_INIT::class.java)
|
||||
assertThat(messageIdentifier.shardIdentifier).isEqualTo(shardIdentifier)
|
||||
assertThat(messageIdentifier.sessionIdentifier).isEqualTo(sessionIdentifier)
|
||||
assertThat(messageIdentifier.sessionSequenceNumber).isEqualTo(1)
|
||||
assertThat(messageIdentifier.timestamp).isEqualTo(timestamp)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `shard identifier needs to be 8 characters long`() {
|
||||
Assertions.assertThatThrownBy { MessageIdentifier(MessageType.SESSION_INIT, "XX", sessionIdentifier, 1, timestamp) }
|
||||
.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessage("Shard identifier needs to be 8 characters long, but it was XX")
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package net.corda.node.services.messaging
|
||||
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.Test
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.math.BigInteger
|
||||
|
||||
class SessionIdTest {
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `session identifier cannot be negative`() {
|
||||
assertThatThrownBy { SessionId(BigInteger.valueOf(-1)) }
|
||||
.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessage("Session identifier cannot be a negative number, but it was -1")
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `session identifier needs to be a number that can be represented in maximum 128 bits`() {
|
||||
val largestSessionIdentifierValue = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
|
||||
val largestValidSessionId = SessionId(largestSessionIdentifierValue)
|
||||
|
||||
assertThatThrownBy { SessionId(largestSessionIdentifierValue.plus(BigInteger.ONE)) }
|
||||
.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessage("The size of a session identifier cannot exceed 128 bits, but it was 340282366920938463463374607431768211456")
|
||||
}
|
||||
|
||||
}
|
@ -873,7 +873,8 @@ class DBCheckpointStorageTests {
|
||||
frozenLogic,
|
||||
ALICE,
|
||||
SubFlowVersion.CoreFlow(version),
|
||||
false
|
||||
false,
|
||||
Clock.systemUTC().instant()
|
||||
)
|
||||
.getOrThrow()
|
||||
return id to checkpoint
|
||||
|
@ -194,7 +194,7 @@ class CheckpointDumperImplTest {
|
||||
override fun call() {}
|
||||
}
|
||||
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
|
||||
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false)
|
||||
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false, Clock.systemUTC().instant())
|
||||
.getOrThrow()
|
||||
return id to checkpoint
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import net.corda.client.rpc.notUsed
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.contracts.ContractState
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.Destination
|
||||
import net.corda.core.flows.FinalityFlow
|
||||
import net.corda.core.flows.FlowException
|
||||
@ -42,6 +41,8 @@ import net.corda.core.utilities.ProgressTracker.Change
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.seconds
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.persistence.CheckpointPerformanceRecorder
|
||||
import net.corda.node.services.persistence.DBCheckpointStorage
|
||||
import net.corda.node.services.persistence.checkpoints
|
||||
@ -80,6 +81,8 @@ import org.junit.Before
|
||||
import org.junit.Test
|
||||
import rx.Notification
|
||||
import rx.Observable
|
||||
import java.math.BigInteger
|
||||
import java.security.SecureRandom
|
||||
import java.sql.SQLTransientConnectionException
|
||||
import java.time.Clock
|
||||
import java.time.Duration
|
||||
@ -571,7 +574,7 @@ class FlowFrameworkTests {
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `session init with unknown class is sent to the flow hospital, from where we then drop it`() {
|
||||
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob)
|
||||
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, "not.a.real.Class", 1, "", null), bob)
|
||||
mockNet.runNetwork()
|
||||
assertThat(receivedSessionMessages).hasSize(1) // Only the session-init is expected as the session-reject is blocked by the flow hospital
|
||||
val medicalRecords = bobNode.smm.flowHospital.track().apply { updates.notUsed() }.snapshot
|
||||
@ -587,7 +590,7 @@ class FlowFrameworkTests {
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `non-flow class in session init`() {
|
||||
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob)
|
||||
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, String::class.java.name, 1, "", null), bob)
|
||||
mockNet.runNetwork()
|
||||
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
|
||||
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
|
||||
@ -897,7 +900,7 @@ class FlowFrameworkTests {
|
||||
}
|
||||
//region Helpers
|
||||
|
||||
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
|
||||
private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
|
||||
|
||||
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
|
||||
assertThat(receivedSessionMessages).containsExactly(*expected)
|
||||
@ -1039,21 +1042,21 @@ class FlowFrameworkTests {
|
||||
}
|
||||
|
||||
internal fun sessionConfirm(flowVersion: Int = 1) =
|
||||
ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
|
||||
ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ConfirmSessionMessage(SessionId(BigInteger.valueOf(0)), FlowInfo(flowVersion, "")))
|
||||
|
||||
internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
|
||||
return smm.findStateMachines(P::class.java).single()
|
||||
}
|
||||
|
||||
private fun sanitise(message: SessionMessage) = when (message) {
|
||||
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "")
|
||||
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(BigInteger.valueOf(0)), initiationEntropy = 0, appName = "")
|
||||
is ExistingSessionMessage -> {
|
||||
val payload = message.payload
|
||||
message.copy(
|
||||
recipientSessionId = SessionId(0),
|
||||
recipientSessionId = SessionId(BigInteger.valueOf(0)),
|
||||
payload = when (payload) {
|
||||
is ConfirmSessionMessage -> payload.copy(
|
||||
initiatedSessionId = SessionId(0),
|
||||
initiatedSessionId = SessionId(BigInteger.valueOf(0)),
|
||||
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
|
||||
)
|
||||
is ErrorSessionMessage -> payload.copy(
|
||||
@ -1076,7 +1079,8 @@ internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Sessio
|
||||
internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
|
||||
services.networkService.apply {
|
||||
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
|
||||
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
|
||||
val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
|
||||
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes, SenderDeduplicationInfo(messageIdentifier, null), emptyMap()), address)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1087,7 +1091,7 @@ inline fun <reified T> DatabaseTransaction.findRecordsFromDatabase(): List<T> {
|
||||
}
|
||||
|
||||
internal fun errorMessage(errorResponse: FlowException? = null) =
|
||||
ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
|
||||
ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ErrorSessionMessage(errorResponse, 0))
|
||||
|
||||
internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
|
||||
internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer =
|
||||
@ -1103,10 +1107,10 @@ internal data class SessionTransfer(val from: Int, val message: SessionMessage,
|
||||
}
|
||||
|
||||
internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
|
||||
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
|
||||
return InitialSessionMessage(SessionId(BigInteger.valueOf(0)), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
|
||||
}
|
||||
|
||||
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
|
||||
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), DataSessionMessage(payload.serialize()))
|
||||
|
||||
@InitiatingFlow
|
||||
internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
|
||||
|
@ -18,6 +18,7 @@ import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import rx.Observable
|
||||
import java.math.BigInteger
|
||||
import java.util.*
|
||||
|
||||
class FlowFrameworkTripartyTests {
|
||||
@ -168,7 +169,7 @@ class FlowFrameworkTripartyTests {
|
||||
)
|
||||
}
|
||||
|
||||
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
|
||||
private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
|
||||
|
||||
private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
|
||||
val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }
|
||||
|
@ -0,0 +1,42 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import org.assertj.core.api.Assertions.*
|
||||
import org.junit.Test
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.math.BigInteger
|
||||
|
||||
class SessionIdTest {
|
||||
|
||||
companion object {
|
||||
private val LARGEST_SESSION_ID_VALUE = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `session id must be positive and representable in 128 bits`() {
|
||||
assertThatThrownBy { SessionId(BigInteger.ZERO.minus(BigInteger.ONE)) }
|
||||
.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessage("Session identifier cannot be a negative number, but it was -1")
|
||||
|
||||
assertThatThrownBy { SessionId(LARGEST_SESSION_ID_VALUE.plus(BigInteger.ONE)) }
|
||||
.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessageContaining("The size of a session identifier cannot exceed 128 bits, but it was")
|
||||
|
||||
val correctSessionId = SessionId(LARGEST_SESSION_ID_VALUE)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `initiated session id is calculated properly`() {
|
||||
val sessionId = SessionId(BigInteger.ONE)
|
||||
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
|
||||
assertThat(initiatedSessionId.value.toLong()).isEqualTo(2)
|
||||
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `calculation of initiated session id wraps around`() {
|
||||
val sessionId = SessionId(LARGEST_SESSION_ID_VALUE)
|
||||
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
|
||||
assertThat(initiatedSessionId.value.toLong()).isEqualTo(0)
|
||||
}
|
||||
|
||||
}
|
@ -2,7 +2,7 @@ package net.corda.testing.node.internal
|
||||
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.node.services.messaging.Message
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import java.time.Instant
|
||||
|
||||
/**
|
||||
@ -10,7 +10,7 @@ import java.time.Instant
|
||||
*/
|
||||
data class InMemoryMessage(override val topic: String,
|
||||
override val data: ByteSequence,
|
||||
override val uniqueMessageId: DeduplicationId,
|
||||
override val uniqueMessageId: MessageIdentifier,
|
||||
override val debugTimestamp: Instant = Instant.now(),
|
||||
override val senderUUID: String? = null) : Message {
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.corda.testing.node.internal
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
@ -13,9 +14,9 @@ import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.node.services.messaging.*
|
||||
import net.corda.node.services.statemachine.DeduplicationId
|
||||
import net.corda.node.services.statemachine.ExternalEvent
|
||||
import net.corda.node.services.statemachine.SenderDeduplicationId
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.nodeapi.internal.lifecycle.ServiceStateHelper
|
||||
import net.corda.nodeapi.internal.lifecycle.ServiceStateSupport
|
||||
@ -46,9 +47,9 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
|
||||
}
|
||||
|
||||
private val state = ThreadBox(InnerState())
|
||||
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())
|
||||
private val processedMessages: MutableSet<MessageIdentifier> = Collections.synchronizedSet(HashSet())
|
||||
|
||||
override val ourSenderUUID: String = UUID.randomUUID().toString()
|
||||
override val ourSenderUUID: SenderUUID = UUID.randomUUID().toString()
|
||||
|
||||
private var _myAddress: InMemoryMessagingNetwork.PeerHandle? = null
|
||||
override val myAddress: InMemoryMessagingNetwork.PeerHandle get() = checkNotNull(_myAddress) { "Not started" }
|
||||
@ -167,6 +168,11 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
|
||||
// nothing to do here.
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
backgroundThread?.let {
|
||||
it.interrupt()
|
||||
@ -178,8 +184,8 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
|
||||
}
|
||||
|
||||
/** Returns the given (topic & session, data) pair as a newly created message object. */
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
|
||||
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID)
|
||||
override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
|
||||
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, senderUUID = deduplicationInfo.senderUUID)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -269,7 +275,7 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
|
||||
private data class InMemoryReceivedMessage(override val topic: String,
|
||||
override val data: ByteSequence,
|
||||
override val platformVersion: Int,
|
||||
override val uniqueMessageId: DeduplicationId,
|
||||
override val uniqueMessageId: MessageIdentifier,
|
||||
override val debugTimestamp: Instant,
|
||||
override val peer: CordaX500Name,
|
||||
override val senderUUID: String? = null,
|
||||
|
@ -4,14 +4,24 @@ import net.corda.core.messaging.AllPossibleRecipients
|
||||
import net.corda.core.serialization.internal.effectiveSerializationEnv
|
||||
import net.corda.node.services.messaging.Message
|
||||
import net.corda.coretesting.internal.rigorousMock
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.SenderDeduplicationInfo
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
import java.math.BigInteger
|
||||
import java.time.Clock
|
||||
import java.util.*
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class InternalMockNetworkTests {
|
||||
companion object {
|
||||
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
|
||||
}
|
||||
|
||||
lateinit var mockNet: InternalMockNetwork
|
||||
|
||||
@After
|
||||
@ -39,7 +49,7 @@ class InternalMockNetworkTests {
|
||||
}
|
||||
|
||||
// Node 1 sends a message and it should end up in finalDelivery, after we run the network
|
||||
node1.network.send(node1.network.createMessage("test.topic", data = bits), node2.network.myAddress)
|
||||
node1.network.send(node1.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), node2.network.myAddress)
|
||||
|
||||
mockNet.runNetwork(rounds = 1)
|
||||
|
||||
@ -58,7 +68,7 @@ class InternalMockNetworkTests {
|
||||
|
||||
var counter = 0
|
||||
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
|
||||
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
|
||||
node1.network.send(node2.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), rigorousMock<AllPossibleRecipients>())
|
||||
mockNet.runNetwork(rounds = 1)
|
||||
assertEquals(3, counter)
|
||||
}
|
||||
@ -79,8 +89,8 @@ class InternalMockNetworkTests {
|
||||
received++
|
||||
}
|
||||
|
||||
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1))
|
||||
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1))
|
||||
val invalidMessage = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
val validMessage = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
node2.network.send(invalidMessage, node1.network.myAddress)
|
||||
mockNet.runNetwork()
|
||||
assertEquals(0, received)
|
||||
@ -91,8 +101,8 @@ class InternalMockNetworkTests {
|
||||
|
||||
// Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so
|
||||
// this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops
|
||||
val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1))
|
||||
val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1))
|
||||
val invalidMessage2 = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
val validMessage2 = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
|
||||
node2.network.send(invalidMessage2, node1.network.myAddress)
|
||||
node2.network.send(validMessage2, node1.network.myAddress)
|
||||
mockNet.runNetwork()
|
||||
|
Loading…
x
Reference in New Issue
Block a user