Introduce MessageIdentifier and related tests

This commit is contained in:
Dimos Raptis 2020-08-14 11:44:07 +01:00
parent 8ed3dc1150
commit 64e7fdd83a
43 changed files with 1058 additions and 355 deletions

View File

@ -24,6 +24,8 @@ import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.internal.MOCK_VERSION_INFO
import org.apache.activemq.artemis.api.core.ActiveMQConnectionTimedOutException
@ -35,7 +37,9 @@ import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import rx.subjects.PublishSubject
import java.math.BigInteger
import java.net.ServerSocket
import java.time.Clock
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit.MILLISECONDS
@ -47,6 +51,7 @@ import kotlin.test.assertTrue
class ArtemisMessagingTest {
companion object {
const val TOPIC = "platform.self"
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
}
@Rule
@ -142,7 +147,7 @@ class ArtemisMessagingTest {
@Test(timeout=300_000)
fun `client should be able to send message to itself`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer()
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray())
val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take()
@ -153,14 +158,14 @@ class ArtemisMessagingTest {
@Test(timeout=300_000)
fun `client should fail if message exceed maxMessageSize limit`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer()
val message = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE))
val message = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take()
assertTrue(ByteArray(MAX_MESSAGE_SIZE).contentEquals(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS))
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(MAX_MESSAGE_SIZE + 1))
val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(MAX_MESSAGE_SIZE + 1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
assertThatThrownBy {
messagingClient.send(tooLagerMessage, messagingClient.myAddress)
}.isInstanceOf(IllegalArgumentException::class.java)
@ -172,14 +177,14 @@ class ArtemisMessagingTest {
@Test(timeout=300_000)
fun `server should not process if incoming message exceed maxMessageSize limit`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(clientMaxMessageSize = 100_000, serverMaxMessageSize = 50_000)
val message = messagingClient.createMessage(TOPIC, data = ByteArray(50_000))
val message = messagingClient.createMessage(TOPIC, ByteArray(50_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take()
assertTrue(ByteArray(50_000).contentEquals(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS))
val tooLagerMessage = messagingClient.createMessage(TOPIC, data = ByteArray(100_000))
val tooLagerMessage = messagingClient.createMessage(TOPIC, ByteArray(100_000), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
assertThatThrownBy {
messagingClient.send(tooLagerMessage, messagingClient.myAddress)
}.isInstanceOf(ActiveMQConnectionTimedOutException::class.java)
@ -189,7 +194,7 @@ class ArtemisMessagingTest {
@Test(timeout=300_000)
fun `platform version is included in the message`() {
val (messagingClient, receivedMessages) = createAndStartClientAndServer(platformVersion = 3)
val message = messagingClient.createMessage(TOPIC, data = "first msg".toByteArray())
val message = messagingClient.createMessage(TOPIC, "first msg".toByteArray(), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
messagingClient.send(message, messagingClient.myAddress)
val received = receivedMessages.take()

View File

@ -10,9 +10,13 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.send
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.testing.driver.DriverDSL
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.InProcess
@ -23,12 +27,15 @@ import net.corda.testing.node.NotarySpec
import org.assertj.core.api.Assertions.assertThat
import org.junit.Ignore
import org.junit.Test
import java.math.BigInteger
import java.time.Clock
import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
class P2PMessagingTest {
private companion object {
val DISTRIBUTED_SERVICE_NAME = CordaX500Name("DistributedService", "London", "GB")
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
}
@Test(timeout=300_000)
@ -72,7 +79,7 @@ class P2PMessagingTest {
private fun InProcess.respondWith(message: Any) {
internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
val request = netMessage.data.deserialize<TestRequest>()
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes)
val response = internalServices.networkService.createMessage("test.response", message.serialize().bytes, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
internalServices.networkService.send(response, request.replyTo)
handler.afterDatabaseTransaction()
}
@ -83,7 +90,7 @@ class P2PMessagingTest {
internalServices.networkService.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize())
}
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target)
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
return response
}

View File

@ -0,0 +1,93 @@
package net.corda.node.services.messaging
import net.corda.core.crypto.SecureHash
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import java.lang.IllegalStateException
import java.math.BigInteger
import java.time.Instant
/**
* This represents the unique identifier for every message.
* It's composed of multiple segments.
*
* @property messageType the type of the message.
* @property shardIdentifier an identifier that can be used to partition messages into groups for sharding purposes.
* This is supposed to have the same value for messages that correspond to the same business-level flow. It is
* @property sessionIdentifier the identifier of the session this message belongs to. This corresponds to the identifier of the session on the receiving side.
* @property sessionSequenceNumber the sequence number of the message inside the session. This can be used to handle out-of-order delivery.
* @property timestamp the time when the message was requested to be sent.
* This is expected to remain the same across replays of the same message and represent the moment in time when this message was initially scheduled to be sent.
*/
data class MessageIdentifier(
val messageType: MessageType,
val shardIdentifier: String,
val sessionIdentifier: SessionId,
val sessionSequenceNumber: Int,
val timestamp: Instant
) {
init {
require(shardIdentifier.length == 8) { "Shard identifier needs to be 8 characters long, but it was $shardIdentifier" }
}
companion object {
const val LONG_SIZE_IN_HEX = 16 // 64 / 4
const val SESSION_ID_SIZE_IN_HEX = SessionId.MAX_BIT_SIZE / 4
fun parse(id: String): MessageIdentifier {
val prefix = id.substring(0, 2)
val messageType = prefixToMessageType(prefix)
val timestamp = java.lang.Long.parseUnsignedLong(id.substring(3, 19), 16)
val shardIdentifier = id.substring(20, 28)
val sessionId = BigInteger(id.substring(29, 61), 16)
val sessionSequenceNumber = Integer.parseInt(id.substring(62), 16)
return MessageIdentifier(messageType, shardIdentifier, SessionId(sessionId), sessionSequenceNumber, Instant.ofEpochMilli(timestamp))
}
private fun messageTypeToPrefix(messageType: MessageType): String {
return when(messageType) {
MessageType.SESSION_INIT -> "XI"
MessageType.SESSION_CONFIRM -> "XC"
MessageType.SESSION_REJECT -> "XR"
MessageType.DATA_MESSAGE -> "XD"
MessageType.SESSION_END -> "XE"
MessageType.SESSION_ERROR -> "XX"
}
}
private fun prefixToMessageType(prefix: String): MessageType {
return when(prefix) {
"XI" -> MessageType.SESSION_INIT
"XC" -> MessageType.SESSION_CONFIRM
"XR" -> MessageType.SESSION_REJECT
"XD" -> MessageType.DATA_MESSAGE
"XE" -> MessageType.SESSION_END
"XX" -> MessageType.SESSION_ERROR
else -> throw IllegalStateException("Invalid prefix: $prefix")
}
}
}
override fun toString(): String {
val prefix = messageTypeToPrefix(messageType)
val encodedSessionIdentifier = String.format("%1$0${SESSION_ID_SIZE_IN_HEX}X", sessionIdentifier.value)
val encodedSequenceNumber = Integer.toHexString(sessionSequenceNumber).toUpperCase()
val encodedTimestamp = String.format("%1$0${LONG_SIZE_IN_HEX}X", timestamp.toEpochMilli())
return "$prefix-$encodedTimestamp-$shardIdentifier-$encodedSessionIdentifier-$encodedSequenceNumber"
}
}
fun generateShardId(flowIdentifier: String): String {
return SecureHash.sha256(flowIdentifier).prefixChars(8)
}
/**
* A unique identifier for a sender that might be different across restarts.
* It is used to help identify when messages are being sent continuously without errors or message are sent after the sender recovered from an error.
*/
typealias SenderUUID = String
/**
* A global sequence number for all the messages sent by a sender.
*/
typealias SenderSequenceNumber = Long

View File

@ -1,7 +1,6 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
@ -9,9 +8,8 @@ import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionId
import net.corda.nodeapi.internal.lifecycle.ServiceLifecycleSupport
import java.time.Instant
import javax.annotation.concurrent.ThreadSafe
@ -32,7 +30,7 @@ interface MessagingService : ServiceLifecycleSupport {
* A unique identifier for this sender that changes whenever a node restarts. This is used in conjunction with a sequence
* number for message de-duplication at the recipient.
*/
val ourSenderUUID: String
val ourSenderUUID: SenderUUID
/**
* The provided function will be invoked for each received message whose topic and session matches. The callback
@ -92,15 +90,25 @@ interface MessagingService : ServiceLifecycleSupport {
@Suspendable
fun sendAll(addressedMessages: List<AddressedMessage>)
/**
* Signal that a session has ended to the messaging layer, so that any necessary cleanup is performed.
*
* @param sessionId the identifier of the session that ended.
* @param senderUUID the sender UUID of the last message seen in the session or null if there was no sender UUID in that message.
* @param senderSequenceNumber the sender sequence number of the last message seen in the session or null if there was no sender sequence number in that message.
*/
@Suspendable
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topic identifier for the topic the message is sent to.
* @param data the payload for the message.
* @param deduplicationId optional message deduplication ID including sender identifier.
* @param deduplicationInfo optional message deduplication information.
* @param additionalHeaders optional additional message headers.
*/
fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message
/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -109,7 +117,7 @@ interface MessagingService : ServiceLifecycleSupport {
val myAddress: SingleMessageRecipient
}
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to)
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationInfo, additionalHeaders), to)
interface MessageHandlerRegistration
@ -128,7 +136,7 @@ interface Message {
val topic: String
val data: ByteSequence
val debugTimestamp: Instant
val uniqueMessageId: DeduplicationId
val uniqueMessageId: MessageIdentifier
val senderUUID: String?
val additionalHeaders: Map<String, String>
}

View File

@ -75,7 +75,7 @@ class MessagingExecutor(
putStringProperty(P2PMessagingHeaders.topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
// If we are the sender (ie. we are not going through recovery of some sort), use sequence number short cut.
if (ourSenderUUID == message.senderUUID) {
putStringProperty(P2PMessagingHeaders.senderUUID, SimpleString(ourSenderUUID))

View File

@ -1,56 +1,79 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.NamedCacheFactory
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession
import java.math.BigInteger
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Table
/**
* Encapsulate the de-duplication logic.
* This component is responsible for determining whether session-init messages are duplicates and it also keeps track of information related to
* sessions that can be used for this purpose.
*/
class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val database: CordaPersistence) {
companion object {
private val logger = contextLogger()
}
// A temporary in-memory set of deduplication IDs and associated high water mark details.
// When we receive a message we don't persist the ID immediately,
// so we store the ID here in the meantime (until the persisting db tx has committed). This is because Artemis may
// redeliver messages to the same consumer if they weren't ACKed.
private val beingProcessedMessages = ConcurrentHashMap<DeduplicationId, MessageMeta>()
private val processedMessages = createProcessedMessages(cacheFactory)
private val beingProcessedMessages = ConcurrentHashMap<MessageIdentifier, MessageMeta>()
private fun createProcessedMessages(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<DeduplicationId, MessageMeta, ProcessedMessage, String> {
/**
* This table holds data *only* for sessions that have been initiated from a counterparty (e.g. ones we have received session-init messages from).
* This is because any other messages apart from session-init messages are deduplicated by the state machine.
*/
private val sessionData = createSessionDataMap(cacheFactory)
private fun createSessionDataMap(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<SessionId, MessageMeta, SessionData, BigInteger> {
return AppendOnlyPersistentMap(
cacheFactory = cacheFactory,
name = "P2PMessageDeduplicator_processedMessages",
toPersistentEntityKey = { it.toString },
fromPersistentEntity = { Pair(DeduplicationId(it.id), MessageMeta(it.insertionTime, it.hash, it.seqNo)) },
toPersistentEntity = { key: DeduplicationId, value: MessageMeta ->
ProcessedMessage().apply {
id = key.toString
insertionTime = value.insertionTime
hash = value.senderHash
seqNo = value.senderSeqNo
name = "P2PMessageDeduplicator_sessionData",
toPersistentEntityKey = { it.value },
fromPersistentEntity = { Pair(SessionId(it.sessionId), MessageMeta(it.generationTime, it.senderHash, it.firstSenderSeqNo, it.lastSenderSeqNo)) },
toPersistentEntity = { key: SessionId, value: MessageMeta ->
SessionData().apply {
sessionId = key.value
generationTime = value.generationTime
senderHash = value.senderHash
firstSenderSeqNo = value.firstSenderSeqNo
lastSenderSeqNo = value.lastSenderSeqNo
}
},
persistentEntityClass = ProcessedMessage::class.java
persistentEntityClass = SessionData::class.java
)
}
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId in processedMessages }
private fun isDuplicateInDatabase(msg: ReceivedMessage): Boolean = database.transaction { msg.uniqueMessageId.sessionIdentifier in sessionData }
// We need to incorporate the sending party, and the sessionInit flag as per the in-memory cache.
private fun senderHash(senderKey: SenderKey) = SecureHash.sha256(senderKey.peer.toString() + senderKey.isSessionInit.toString() + senderKey.senderUUID).toString()
/**
* Determines whether a session-init message is a duplicate.
* This is achieved by checking whether this message is currently being processed or if the associated session has already been created in the past.
* This method should be invoked only with session-init messages, otherwise it will fail with an [IllegalArgumentException].
*
* @return true if we have seen this message before.
*/
fun isDuplicate(msg: ReceivedMessage): Boolean {
fun isDuplicateSessionInit(msg: ReceivedMessage): Boolean {
require(msg.isSessionInit) { "Message ${msg.uniqueMessageId} was not a session-init message." }
if (beingProcessedMessages.containsKey(msg.uniqueMessageId)) {
return true
}
@ -65,44 +88,86 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
val receivedSenderSeqNo = msg.senderSeqNo
// We don't want a mix of nulls and values so we ensure that here.
val senderHash: String? = if (receivedSenderUUID != null && receivedSenderSeqNo != null) senderHash(SenderKey(receivedSenderUUID, msg.peer, msg.isSessionInit)) else null
val senderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(Instant.now(), senderHash, senderSeqNo)
val firstSenderSeqNo: Long? = if (senderHash != null) msg.senderSeqNo else null
beingProcessedMessages[msg.uniqueMessageId] = MessageMeta(msg.uniqueMessageId.timestamp, senderHash, firstSenderSeqNo, null)
}
/**
* Called inside a DB transaction to persist [deduplicationId].
*/
fun persistDeduplicationId(deduplicationId: DeduplicationId) {
processedMessages[deduplicationId] = beingProcessedMessages[deduplicationId]!!
fun persistDeduplicationId(deduplicationId: MessageIdentifier) {
sessionData[deduplicationId.sessionIdentifier] = beingProcessedMessages[deduplicationId]!!
}
/**
* Called after the DB transaction persisting [deduplicationId] committed.
* Any subsequent redelivery will be deduplicated using the DB.
*/
fun signalMessageProcessFinish(deduplicationId: DeduplicationId) {
fun signalMessageProcessFinish(deduplicationId: MessageIdentifier) {
beingProcessedMessages.remove(deduplicationId)
}
/**
* Called inside a DB transaction to update entry for corresponding session.
* The parameters [senderUUID] and [senderSequenceNumber] correspond to the last message seen from this session before it ended.
* If [senderUUID] is not null, then [senderSequenceNumber] is also expected to not be null.
*/
@Suspendable
fun signalSessionEnd(sessionId: SessionId, senderUUID: String?, senderSequenceNumber: Long?) {
if (senderSequenceNumber != null && senderUUID != null) {
val existingEntry = sessionData[sessionId]
if (existingEntry != null) {
val newEntry = existingEntry.copy(lastSenderSeqNo = senderSequenceNumber)
sessionData.addOrUpdate(sessionId, newEntry) { k, v ->
update(k, v)
}
}
}
}
private fun update(key: SessionId, value: MessageMeta): Boolean {
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(SessionData::class.java)
val queryRoot = criteriaUpdate.from(SessionData::class.java)
criteriaUpdate.set(SessionData::lastSenderSeqNo.name, value.lastSenderSeqNo)
criteriaUpdate.where(criteriaBuilder.equal(queryRoot.get<BigInteger>(SessionData::sessionId.name), key.value))
val update = session.createQuery(criteriaUpdate)
val rowsUpdated = update.executeUpdate()
return rowsUpdated != 0
}
@Entity
@Suppress("MagicNumber") // database column width
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage(
@Table(name = "${NODE_DATABASE_PREFIX}session_data")
class SessionData (
@Id
@Column(name = "message_id", length = 64, nullable = false)
var id: String = "",
@Column(name = "session_id", nullable = false)
var sessionId: BigInteger = BigInteger.ZERO,
@Column(name = "insertion_time", nullable = false)
var insertionTime: Instant = Instant.now(),
/**
* The time the corresponding session-init message was originally generated on the sender side.
*/
@Column(name = "init_generation_time", nullable = false)
var generationTime: Instant = Instant.now(),
@Column(name = "sender", length = 64, nullable = true)
var hash: String? = "",
@Column(name = "sender_hash", length = 64, nullable = true)
var senderHash: String? = "",
@Column(name = "sequence_number", nullable = true)
var seqNo: Long? = null
/**
* The sender sequence number of the first message seen in a session.
*/
@Column(name = "init_sequence_number", nullable = true)
var firstSenderSeqNo: Long? = null,
/**
* The sender sequence number of the last message seen in a session before it was closed/terminated.
*/
@Column(name = "last_sequence_number", nullable = true)
var lastSenderSeqNo: Long? = null
)
private data class MessageMeta(val insertionTime: Instant, val senderHash: String?, val senderSeqNo: Long?)
private data class MessageMeta(val generationTime: Instant, val senderHash: String?, val firstSenderSeqNo: SenderSequenceNumber?, val lastSenderSeqNo: SenderSequenceNumber?)
private data class SenderKey(val senderUUID: String, val peer: CordaX500Name, val isSessionInit: Boolean)
}

View File

@ -25,9 +25,9 @@ import net.corda.node.internal.artemis.ReactiveArtemisConsumer
import net.corda.node.internal.artemis.ReactiveArtemisConsumer.Companion.multiplex
import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.errorAndTerminate
import net.corda.nodeapi.internal.ArtemisMessagingComponent
@ -101,8 +101,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
private class NodeClientMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: DeduplicationId,
override val senderUUID: String?,
override val uniqueMessageId: MessageIdentifier,
override val senderUUID: SenderUUID?,
override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}"
@ -371,7 +371,7 @@ class P2PMessagingClient(val config: NodeConfiguration,
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(P2PMessagingHeaders.platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { DeduplicationId(message.getStringProperty(it)) }
val uniqueMessageId = message.required(HDR_DUPLICATE_DETECTION_ID) { MessageIdentifier.parse(message.getStringProperty(it)) }
val receivedSenderUUID = message.getStringProperty(P2PMessagingHeaders.senderUUID)
val receivedSenderSeqNo = if (message.containsProperty(P2PMessagingHeaders.senderSeqNo)) message.getLongProperty(P2PMessagingHeaders.senderSeqNo) else null
val isSessionInit = message.getStringProperty(P2PMessagingHeaders.Type.KEY) == P2PMessagingHeaders.Type.SESSION_INIT_VALUE
@ -392,8 +392,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name,
override val platformVersion: Int,
override val uniqueMessageId: DeduplicationId,
override val senderUUID: String?,
override val uniqueMessageId: MessageIdentifier,
override val senderUUID: SenderUUID?,
override val senderSeqNo: Long?,
override val isSessionInit: Boolean,
private val message: ClientMessage) : ReceivedMessage {
@ -405,12 +405,17 @@ class P2PMessagingClient(val config: NodeConfiguration,
internal fun deliver(artemisMessage: ClientMessage) {
artemisToCordaMessage(artemisMessage)?.let { cordaMessage ->
if (!deduplicator.isDuplicate(cordaMessage)) {
deduplicator.signalMessageProcessStart(cordaMessage)
deliver(cordaMessage, artemisMessage)
if (cordaMessage.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
if (!deduplicator.isDuplicateSessionInit(cordaMessage)) {
deduplicator.signalMessageProcessStart(cordaMessage)
deliver(cordaMessage, artemisMessage)
} else {
log.debug { "Discarding duplicate session-init message with identifier: ${cordaMessage.uniqueMessageId}, senderUUID: ${cordaMessage.senderUUID}, senderSeqNo: ${cordaMessage.senderSeqNo}" }
messagingExecutor!!.acknowledge(artemisMessage)
}
} else {
log.trace { "Discard duplicate message ${cordaMessage.uniqueMessageId} for ${cordaMessage.topic}" }
messagingExecutor!!.acknowledge(artemisMessage)
// non session-init messages are directly handed to the state machine, which is responsible for performing deduplication.
deliver(cordaMessage, artemisMessage)
}
}
}
@ -420,7 +425,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
val deliverTo = handlers[msg.topic]
if (deliverTo != null) {
try {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandler(artemisMessage, msg))
if (msg.uniqueMessageId.messageType == MessageType.SESSION_INIT) {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForSessionInitMessages(artemisMessage, msg))
} else {
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), MessageDeduplicationHandlerForRegularMessages(artemisMessage, msg))
}
} catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
@ -429,11 +438,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
}
}
private inner class MessageDeduplicationHandler(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
private inner class MessageDeduplicationHandlerForSessionInitMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
override val deduplicationHandler: MessageDeduplicationHandler
override val deduplicationHandler: MessageDeduplicationHandlerForSessionInitMessages
get() = this
override fun insideDatabaseTransaction() {
@ -450,6 +459,27 @@ class P2PMessagingClient(val config: NodeConfiguration,
}
}
private inner class MessageDeduplicationHandlerForRegularMessages(val artemisMessage: ClientMessage, override val receivedMessage: ReceivedMessage) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent {
override val externalCause: ExternalEvent
get() = this
override val flowId: StateMachineRunId by lazy { StateMachineRunId.createRandom() }
override val deduplicationHandler: MessageDeduplicationHandlerForRegularMessages
get() = this
/**
* Nothing to do, since deduplication information is kept in the state machine.
*/
override fun insideDatabaseTransaction() {}
override fun afterDatabaseTransaction() {
messagingExecutor!!.acknowledge(artemisMessage)
}
override fun toString(): String {
return "${javaClass.simpleName}(${receivedMessage.uniqueMessageId})"
}
}
/**
* Initiates shutdown: if called from a thread that isn't controlled by the executor passed to the constructor
* then this will block until all in-flight messages have finished being handled and acknowledged. If called
@ -520,6 +550,11 @@ class P2PMessagingClient(val config: NodeConfiguration,
}
}
@Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
deduplicator.signalSessionEnd(sessionId, senderUUID, senderSequenceNumber)
}
override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
return if (address == myAddress) {
// If we are sending to ourselves then route the message directly to our P2P queue.
@ -586,8 +621,8 @@ class P2PMessagingClient(val config: NodeConfiguration,
handlers.remove(registration.topic)
}
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, deduplicationId.senderUUID, additionalHeaders)
override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, deduplicationInfo.senderUUID, additionalHeaders)
}
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {

View File

@ -0,0 +1,11 @@
package net.corda.node.services.messaging
/**
* This is a combination of a message's unique identifier along with a unique identifier for the sender of the message.
* The former can be used independently for deduplication purposes when receiving a message, but enriching it with the latter helps us
* optimise some paths and perform smarter deduplication logic per sender.
*
* The [senderUUID] property might be null if the flow is trying to replay messages and doesn't want an optimisation to ignore the message identifier
* because it could lead to false negatives (messages that are deemed duplicates, but are not).
*/
data class SenderDeduplicationInfo(val messageIdentifier: MessageIdentifier, val senderUUID: String?)

View File

@ -513,7 +513,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
private fun SessionState.toActiveSession(sessionId: SessionId): ActiveSession? {
return if (this is SessionState.Initiated) {
ActiveSession(peerParty, sessionId, receivedMessages, peerFlowInfo, peerSinkSessionId)
ActiveSession(peerParty, sessionId, receivedMessages.values.toList(), peerFlowInfo, peerSinkSessionId)
} else {
null
}

View File

@ -43,7 +43,7 @@ class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet()
BasicHSMKeyManagementService.PersistentKey::class.java,
NodeSchedulerService.PersistentScheduledState::class.java,
NodeAttachmentService.DBAttachment::class.java,
P2PMessageDeduplicator.ProcessedMessage::class.java,
P2PMessageDeduplicator.SessionData::class.java,
PersistentIdentityService.PersistentPublicKeyHashToCertificate::class.java,
PersistentIdentityService.PersistentPublicKeyHashToParty::class.java,
PersistentIdentityService.PersistentHashToPublicKey::class.java,

View File

@ -6,6 +6,10 @@ import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.FlowAsyncOperation
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.time.Instant
import java.util.*
@ -25,7 +29,7 @@ sealed class Action {
data class SendInitial(
val destination: Destination,
val initialise: InitialSessionMessage,
val deduplicationId: SenderDeduplicationId
val deduplicationInfo: SenderDeduplicationInfo
) : Action()
/**
@ -34,7 +38,7 @@ sealed class Action {
data class SendExisting(
val peerParty: Party,
val message: ExistingSessionMessage,
val deduplicationId: SenderDeduplicationId
val deduplicationInfo: SenderDeduplicationInfo
) : Action()
/**
@ -95,12 +99,11 @@ sealed class Action {
data class AcknowledgeMessages(val deduplicationHandlers: List<DeduplicationHandler>) : Action()
/**
* Propagate [errorMessages] to [sessions].
* @param sessions a map from source session IDs to initiated sessions.
* Propagate the specified error messages to the specified sessions.
* @param errorsPerSession a map containing the error messages to be sent per session along with their identifiers.
*/
data class PropagateErrors(
val errorMessages: List<ErrorSessionMessage>,
val sessions: List<SessionState.Initiated>,
val errorsPerSession: Map<SessionState.Initiated, List<Pair<MessageIdentifier, ErrorSessionMessage>>>,
val senderUUID: String?
) : Action()
@ -114,6 +117,11 @@ sealed class Action {
*/
data class RemoveSessionBindings(val sessionIds: Set<SessionId>) : Action()
/**
* Signal sessions ended at the messaging layer.
*/
data class SignalSessionsHasEnded(val terminatedSessions: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>): Action()
/**
* Signal that the flow corresponding to [flowId] is considered started.
*/

View File

@ -10,6 +10,7 @@ import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
@ -58,6 +59,7 @@ internal class ActionExecutorImpl(
is Action.AddSessionBinding -> executeAddSessionBinding(action)
is Action.RemoveSessionBindings -> executeRemoveSessionBindings(action)
is Action.SignalFlowHasStarted -> executeSignalFlowHasStarted(action)
is Action.SignalSessionsHasEnded -> executeSignalSessionsHasEnded(action)
is Action.RemoveFlow -> executeRemoveFlow(action)
is Action.CreateTransaction -> executeCreateTransaction()
is Action.RollbackTransaction -> executeRollbackTransaction()
@ -132,16 +134,17 @@ internal class ActionExecutorImpl(
@Suspendable
private fun executePropagateErrors(action: Action.PropagateErrors) {
action.errorMessages.forEach { (exception) ->
val errors = action.errorsPerSession.values.flatMap { it.map { it.second } }.distinct()
errors.forEach { errorSessionMessage ->
val exception = errorSessionMessage.flowException
log.warn("Propagating error", exception)
}
for (sessionState in action.sessions) {
action.errorsPerSession.forEach { (sessionState, sessionErrors) ->
// Don't propagate errors to the originating session
for (errorMessage in action.errorMessages) {
for ((id, msg) in sessionErrors) {
val sinkSessionId = sessionState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, action.senderUUID))
val errorMsg = ExistingSessionMessage(sinkSessionId, msg)
flowMessaging.sendSessionMessage(sessionState.peerParty, errorMsg, SenderDeduplicationInfo(id, action.senderUUID))
}
}
}
@ -162,18 +165,18 @@ internal class ActionExecutorImpl(
@Suspendable
private fun executeSendInitial(action: Action.SendInitial) {
flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationId)
flowMessaging.sendSessionMessage(action.destination, action.initialise, action.deduplicationInfo)
}
@Suspendable
private fun executeSendExisting(action: Action.SendExisting) {
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId)
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationInfo)
}
@Suspendable
private fun executeSendMultiple(action: Action.SendMultiple) {
val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationId) } +
action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationId) }
val messages = action.sendInitial.map { Message(it.destination, it.initialise, it.deduplicationInfo) } +
action.sendExisting.map { Message(it.peerParty, it.message, it.deduplicationInfo) }
flowMessaging.sendSessionMessages(messages)
}
@ -192,6 +195,13 @@ internal class ActionExecutorImpl(
stateMachineManager.signalFlowHasStarted(action.flowId)
}
@Suspendable
private fun executeSignalSessionsHasEnded(action: Action.SignalSessionsHasEnded) {
action.terminatedSessions.forEach { (sessionId, senderData) ->
flowMessaging.sessionEnded(sessionId, senderData.first, senderData.second)
}
}
@Suspendable
private fun executeRemoveFlow(action: Action.RemoveFlow) {
stateMachineManager.removeFlow(action.flowId, action.removalReason, action.lastState)

View File

@ -1,53 +0,0 @@
package net.corda.node.services.statemachine
import java.security.SecureRandom
/**
* A deduplication ID of a flow message.
*/
data class DeduplicationId(val toString: String) {
companion object {
/**
* Create a random deduplication ID. Note that this isn't deterministic, which means we will never dedupe it,
* unless we persist the ID somehow.
*/
fun createRandom(random: SecureRandom) = DeduplicationId("R-${random.nextLong()}")
/**
* Create a deduplication ID for a normal clean state message. This is used to have a deterministic way of
* creating IDs in case the message-generating flow logic is replayed on hard failure.
*
* A normal deduplication ID consists of:
* 1. A deduplication seed set per session. This is the initiator's session ID, with a prefix for initiator
* or initiated.
* 2. The number of *clean* suspends since the start of the flow.
* 3. An optional additional index, for cases where several messages are sent as part of the state transition.
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the
* message-id map to change, which means deduplication will not happen correctly.
*/
fun createForNormal(checkpoint: Checkpoint, index: Int, session: SessionState): DeduplicationId {
return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.checkpointState.numberOfSuspends}-$index")
}
/**
* Create a deduplication ID for an error message. Note that these IDs live in a different namespace than normal
* IDs, as we don't want error conditions to affect the determinism of clean deduplication IDs. This allows the
* dirtiness state to be thrown away for resumption.
*
* An error deduplication ID consists of:
* 1. The error's ID. This is a unique value per "source" of error and is propagated.
* See [net.corda.core.flows.IdentifiableException].
* 2. The recipient's session ID.
*/
fun createForError(errorId: Long, recipientSessionId: SessionId): DeduplicationId {
return DeduplicationId("E-$errorId-${recipientSessionId.toLong}")
}
}
}
/**
* Represents the deduplication ID of a flow message, and the sender identifier for the flow doing the sending. The identifier might be
* null if the flow is trying to replay messages and doesn't want an optimisation to ignore the deduplication ID.
*/
data class SenderDeduplicationId(val deduplicationId: DeduplicationId, val senderUUID: String?)

View File

@ -8,6 +8,9 @@ import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.util.UUID
/**
@ -34,7 +37,10 @@ sealed class Event {
data class DeliverSessionMessage(
val sessionMessage: ExistingSessionMessage,
override val deduplicationHandler: DeduplicationHandler,
val sender: Party
val sender: Party,
val messageIdentifier: MessageIdentifier,
val senderUUID: SenderUUID?,
val senderSequenceNumber: SenderSequenceNumber?
) : Event(), GeneratedByExternalEvent
/**

View File

@ -155,7 +155,8 @@ class FlowCreator(
frozenFlowLogic,
ourIdentity,
flowCorDappVersion,
flowLogic.isEnabledTimedFlow()
flowLogic.isEnabledTimedFlow(),
serviceHub.clock.instant()
).getOrThrow()
val state = createStateMachineState(
@ -253,6 +254,7 @@ class FlowCreator(
return StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isFlowResumed = false,
future = null,
isWaitingForFuture = false,

View File

@ -15,6 +15,9 @@ import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import java.io.NotSerializableException
@ -23,21 +26,24 @@ import java.io.NotSerializableException
*/
interface FlowMessaging {
/**
* Send [message] to [destination] using [deduplicationId].
* Send [message] to [destination] using [deduplicationInfo].
*/
@Suspendable
fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId)
fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo)
@Suspendable
fun sendSessionMessages(messageData: List<Message>)
@Suspendable
fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?)
/**
* Start the messaging using the [onMessage] message handler.
*/
fun start(onMessage: (ReceivedMessage, deduplicationHandler: DeduplicationHandler) -> Unit)
}
data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupId: SenderDeduplicationId)
data class Message(val destination: Destination, val sessionMessage: SessionMessage, val dedupInfo: SenderDeduplicationInfo)
/**
* Implementation of [FlowMessaging] using a [ServiceHubInternal] to do the messaging and routing.
@ -56,18 +62,23 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
}
@Suspendable
override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId) {
val addressedMessage = createMessage(destination, message, deduplicationId)
override fun sendSessionMessage(destination: Destination, message: SessionMessage, deduplicationInfo: SenderDeduplicationInfo) {
val addressedMessage = createMessage(destination, message, deduplicationInfo)
serviceHub.networkService.send(addressedMessage.message, addressedMessage.target, addressedMessage.sequenceKey)
}
@Suspendable
override fun sendSessionMessages(messageData: List<Message>) {
val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupId) }
val addressedMessages = messageData.map { createMessage(it.destination, it.sessionMessage, it.dedupInfo) }
serviceHub.networkService.sendAll(addressedMessages)
}
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationId): MessagingService.AddressedMessage {
@Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
serviceHub.networkService.sessionEnded(sessionId, senderUUID, senderSequenceNumber)
}
private fun createMessage(destination: Destination, message: SessionMessage, deduplicationId: SenderDeduplicationInfo): MessagingService.AddressedMessage {
// We assume that the destination type has already been checked by initiateFlow.
// Destination may point to a stale well-known identity due to key rotation, so always resolve actual identity via IdentityService.
val party = requireNotNull(serviceHub.identityService.wellKnownPartyFromAnonymous(destination as AbstractParty)) {
@ -80,6 +91,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
}
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId, message.additionalHeaders(party))
val partyInfo = requireNotNull(serviceHub.networkMapCache.getPartyInfo(party)) { "Don't know about ${party.description()}" }
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val sequenceKey = when (message) {
is InitialSessionMessage -> message.initiatorSessionId

View File

@ -179,7 +179,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val stateMachine = transientValues.stateMachine
val oldState = transientState
val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState)
val transition = stateMachine.transition(event, oldState, serviceHub.clock.instant())
val (continuation, newState) = transitionExecutor.executeTransition(
this,
oldState,

View File

@ -4,6 +4,7 @@ import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import java.math.BigInteger
import java.security.SecureRandom
/**
@ -21,9 +22,28 @@ import java.security.SecureRandom
sealed class SessionMessage
@CordaSerializable
data class SessionId(val toLong: Long) {
data class SessionId(val value: BigInteger) {
init {
require(value.signum() >= 0) { "Session identifier cannot be a negative number, but it was $value" }
require(value.bitLength() <= MAX_BIT_SIZE) { "The size of a session identifier cannot exceed $MAX_BIT_SIZE bits, but it was $value" }
}
/**
* This calculates the initiated session ID assuming this is the initiating session ID.
* This is the next larger number in the range [0, 2^[MAX_BIT_SIZE]] with wrap around the largest number in the interval.
*/
fun calculateInitiatedSessionId(): SessionId {
return if (this.value == LARGEST_SESSION_ID)
SessionId(BigInteger.ZERO)
else
SessionId(this.value.plus(BigInteger.ONE))
}
companion object {
fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong())
const val MAX_BIT_SIZE = 128
val LARGEST_SESSION_ID = BigInteger.valueOf(2).pow(MAX_BIT_SIZE).minus(BigInteger.ONE)
fun createRandom(secureRandom: SecureRandom) = SessionId(BigInteger(MAX_BIT_SIZE, secureRandom))
}
}
@ -118,3 +138,29 @@ data class RejectSessionMessage(val message: String, val errorId: Long) : Existi
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
*/
object EndSessionMessage : ExistingSessionMessagePayload()
enum class MessageType {
SESSION_INIT,
SESSION_CONFIRM,
SESSION_REJECT,
DATA_MESSAGE,
SESSION_END,
SESSION_ERROR;
companion object {
fun inferFromMessage(message: SessionMessage): MessageType {
return when (message) {
is InitialSessionMessage -> SESSION_INIT
is ExistingSessionMessage -> {
when(message.payload) {
is ConfirmSessionMessage -> SESSION_CONFIRM
is RejectSessionMessage -> SESSION_REJECT
is DataSessionMessage -> DATA_MESSAGE
is EndSessionMessage -> SESSION_END
is ErrorSessionMessage -> SESSION_ERROR
}
}
}
}
}
}

View File

@ -37,6 +37,10 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
@ -703,7 +707,7 @@ internal class SingleThreadedStateMachineManager(
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event)
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event, event.receivedMessage.uniqueMessageId, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event)
}
} else {
@ -716,7 +720,10 @@ internal class SingleThreadedStateMachineManager(
private fun onExistingSessionMessage(
sessionMessage: ExistingSessionMessage,
sender: Party,
externalEvent: ExternalEvent.ExternalMessageEvent
externalEvent: ExternalEvent.ExternalMessageEvent,
messageIdentifier: MessageIdentifier,
senderUUID: SenderUUID?,
senderSequenceNumber: SenderSequenceNumber?
) {
try {
val deduplicationHandler = externalEvent.deduplicationHandler
@ -734,7 +741,7 @@ internal class SingleThreadedStateMachineManager(
logger.info("Cannot find flow corresponding to session ID - $recipientId.")
}
} else {
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender, messageIdentifier, senderUUID, senderSequenceNumber)
innerState.withLock {
flows[flowId]?.run { fiber.scheduleEvent(event) }
// If flow is not running add it to the list of external events to be processed if/when the flow resumes.
@ -751,7 +758,7 @@ internal class SingleThreadedStateMachineManager(
private fun onSessionInit(sessionMessage: InitialSessionMessage, sender: Party, event: ExternalEvent.ExternalMessageEvent) {
try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
val initiatedSessionId = SessionId.createRandom(secureRandom)
val initiatedSessionId = event.receivedMessage.uniqueMessageId.sessionIdentifier
val senderSession = FlowSessionImpl(sender, sender, initiatedSessionId)
val flowLogic = initiatedFlowFactory.createFlow(senderSession)
val initiatedFlowInfo = when (initiatedFlowFactory) {
@ -763,9 +770,8 @@ internal class SingleThreadedStateMachineManager(
is InitiatedFlowFactory.CorDapp -> null
}
startInitiatedFlow(
event.flowId,
event,
flowLogic,
event.deduplicationHandler,
senderSession,
initiatedSessionId,
sessionMessage,
@ -800,24 +806,24 @@ internal class SingleThreadedStateMachineManager(
@Suppress("LongParameterList")
private fun <A> startInitiatedFlow(
flowId: StateMachineRunId,
event: ExternalEvent.ExternalMessageEvent,
flowLogic: FlowLogic<A>,
initiatingMessageDeduplicationHandler: DeduplicationHandler,
peerSession: FlowSessionImpl,
initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage,
senderCoreFlowVersion: Int?,
initiatedFlowInfo: FlowInfo
) {
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo)
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo,
event.receivedMessage.uniqueMessageId.shardIdentifier, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
val ourIdentity = ourFirstIdentity
startFlowInternal(
flowId,
event.flowId,
InvocationContext.peer(peerSession.counterparty.name),
flowLogic,
flowStart,
ourIdentity,
initiatingMessageDeduplicationHandler
event.deduplicationHandler
)
}

View File

@ -21,6 +21,8 @@ import net.corda.core.utilities.debug
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.node.services.FinalityHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import org.hibernate.exception.ConstraintViolationException
import rx.subjects.PublishSubject
import java.io.Closeable
@ -169,7 +171,9 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging,
log.info("Sending session initiation error back to $sender", error)
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationId(DeduplicationId.createRandom(secureRandom), ourSenderUUID))
val messageType = MessageType.inferFromMessage(replyError)
val messageIdentifier = MessageIdentifier(messageType, event.receivedMessage.uniqueMessageId.shardIdentifier, sessionMessage.initiatorSessionId, 0, event.receivedMessage.uniqueMessageId.timestamp)
flowMessaging.sendSessionMessage(sender, replyError, SenderDeduplicationInfo(messageIdentifier, ourSenderUUID))
event.deduplicationHandler.afterDatabaseTransaction()
}

View File

@ -22,11 +22,16 @@ import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import java.lang.IllegalArgumentException
import java.lang.IllegalStateException
import java.security.Principal
import java.time.Instant
import java.util.concurrent.Future
import java.util.concurrent.Semaphore
import kotlin.math.max
/**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
@ -35,6 +40,7 @@ import java.util.concurrent.Semaphore
* @param checkpoint the persisted part of the state.
* @param flowLogic the [FlowLogic] associated with the flow. Note that this is mutable by the user.
* @param pendingDeduplicationHandlers the list of incomplete deduplication handlers.
* @param closedSessionsPendingToBeSignalled the sessions that have been closed and need to be signalled to the messaging layer on the next checkpoint (along with some metadata).
* @param isFlowResumed true if the control is returned (or being returned) to "user-space" flow code. This is used
* to make [Event.DoRemainingWork] idempotent.
* @param isWaitingForFuture true if the flow is waiting for the completion of a future triggered by one of the statemachine's actions
@ -61,6 +67,7 @@ data class StateMachineState(
val checkpoint: Checkpoint,
val flowLogic: FlowLogic<*>,
val pendingDeduplicationHandlers: List<DeduplicationHandler>,
val closedSessionsPendingToBeSignalled: Map<SessionId, Pair<SenderUUID?, SenderSequenceNumber?>>,
val isFlowResumed: Boolean,
val isWaitingForFuture: Boolean,
var future: Future<*>?,
@ -123,7 +130,8 @@ data class Checkpoint(
frozenFlowLogic: SerializedBytes<FlowLogic<*>>,
ourIdentity: Party,
subFlowVersion: SubFlowVersion,
isEnabledTimedFlow: Boolean
isEnabledTimedFlow: Boolean,
timestamp: Instant
): Try<Checkpoint> {
return SubFlow.create(flowLogicClass, subFlowVersion, isEnabledTimedFlow).map { topLevelSubFlow ->
Checkpoint(
@ -135,7 +143,8 @@ data class Checkpoint(
listOf(topLevelSubFlow),
numberOfSuspends = 0,
// We set this to 1 here to avoid an extra copy and increment in UnstartedFlowTransition.createInitialCheckpoint
numberOfCommits = 1
numberOfCommits = 1,
suspensionTime = timestamp
),
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
errorState = ErrorState.Clean
@ -235,6 +244,7 @@ data class Checkpoint(
}
/**
<<<<<<< HEAD
* @param invocationContext The initiator of the flow.
* @param ourIdentity The identity the flow is run as.
* @param sessions Map of source session ID to session state.
@ -242,6 +252,7 @@ data class Checkpoint(
* @param subFlowStack The stack of currently executing subflows.
* @param numberOfSuspends The number of flow suspends due to IO API calls.
* @param numberOfCommits The number of times this checkpoint has been persisted.
* @param suspensionTime the time of the last suspension. This is supposed to be used as a stable timestamp in case of replays.
*/
@CordaSerializable
data class CheckpointState(
@ -251,7 +262,8 @@ data class CheckpointState(
val sessionsToBeClosed: Set<SessionId>,
val subFlowStack: List<SubFlow>,
val numberOfSuspends: Int,
val numberOfCommits: Int
val numberOfCommits: Int,
val suspensionTime: Instant
)
/**
@ -262,44 +274,162 @@ sealed class SessionState {
abstract val deduplicationSeed: String
/**
* We haven't yet sent the initialisation message
* the sender UUID last seen in this session, if there was one.
*/
abstract val lastSenderUUID: SenderUUID?
/**
* the sender sequence number last seen in this session, if there was one.
*/
abstract val lastSenderSeqNo: SenderSequenceNumber?
/**
* the messages that have been received and are pending processing indexed by their sequence number.
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
* * [DataSessionMessage]
* * [ErrorSessionMessage]
* * [EndSessionMessage]
*/
abstract val receivedMessages: Map<Int, ExistingSessionMessagePayload>
/**
* Returns a new session state with the specified messages added to the list of received messages.
*/
fun addReceivedMessages(message: ExistingSessionMessagePayload, messageIdentifier: MessageIdentifier, senderUUID: String?, senderSequenceNumber: Long?): SessionState {
val newReceivedMessages = receivedMessages.plus(messageIdentifier.sessionSequenceNumber to message)
val (newLastSenderUUID, newLastSenderSeqNo) = calculateSenderInfo(lastSenderUUID, lastSenderSeqNo, senderUUID, senderSequenceNumber)
return when(this) {
is Uninitiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
is Initiating -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
is Initiated -> { copy(receivedMessages = newReceivedMessages, lastSenderUUID = newLastSenderUUID, lastSenderSeqNo = newLastSenderSeqNo) }
}
}
private fun calculateSenderInfo(currentSender: String?, currentSenderSeqNo: Long?, msgSender: String?, msgSenderSeqNo: Long?): Pair<String?, Long?> {
return if (msgSender != null && msgSenderSeqNo != null) {
if (currentSenderSeqNo != null)
Pair(msgSender, max(msgSenderSeqNo, currentSenderSeqNo))
else
Pair(msgSender, msgSenderSeqNo)
} else {
Pair(currentSender, currentSenderSeqNo)
}
}
/**
* We haven't yet sent the initialisation message.
* This really means that the flow is in a state before sending the initialisation message,
* but in reality it could have sent it before and fail before reaching the next checkpoint, thus ending up replaying from the last checkpoint.
*
* @param hasBeenAcknowledged whether a positive response to a session initiation has already been received and the associated confirmation message, if so.
* @param hasBeenRejected whether a negative response to a session initiation has already been received and the associated rejection message, if so.
*/
data class Uninitiated(
val destination: Destination,
val initiatingSubFlow: SubFlow.Initiating,
val sourceSessionId: SessionId,
val additionalEntropy: Long
val additionalEntropy: Long,
val hasBeenAcknowledged: Pair<Party, ConfirmSessionMessage>?,
val hasBeenRejected: RejectSessionMessage?,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() {
override val deduplicationSeed: String get() = "R-${sourceSessionId.toLong}-$additionalEntropy"
override val deduplicationSeed: String get() = "R-${sourceSessionId.value}-$additionalEntropy"
}
/**
* We have sent the initialisation message but have not yet received a confirmation.
* @property bufferedMessages the messages that have been buffered to be sent after the session is confirmed from the other side.
* @property rejectionError if non-null the initiation failed.
* @property nextSendingSeqNumber the sequence number of the next message to be sent.
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
*/
data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
val bufferedMessages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>,
val rejectionError: FlowError?,
override val deduplicationSeed: String
) : SessionState()
override val deduplicationSeed: String,
val nextSendingSeqNumber: Int,
val shardId: String,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() {
/**
* Buffers an outgoing message to be sent when ready.
* Returns the new form of the state
*/
fun bufferMessage(messageIdentifier: MessageIdentifier, messagePayload: ExistingSessionMessagePayload): SessionState {
return this.copy(bufferedMessages = bufferedMessages + Pair(messageIdentifier, messagePayload), nextSendingSeqNumber = nextSendingSeqNumber + 1)
}
/**
* A batched form of [bufferMessage].
*/
fun bufferMessages(messages: List<Pair<MessageIdentifier, ExistingSessionMessagePayload>>): SessionState {
return this.copy(bufferedMessages = bufferedMessages + messages, nextSendingSeqNumber = nextSendingSeqNumber + messages.size)
}
}
/**
* We have received a confirmation, the peer party and session id is resolved.
* @property receivedMessages the messages that have been received and are pending processing.
* this could be any [ExistingSessionMessagePayload] type in theory, but it in practice it can only be one of the following types now:
* * [DataSessionMessage]
* * [ErrorSessionMessage]
* * [EndSessionMessage]
* @property otherSideErrored whether the session has received an error from the other side.
* @property nextSendingSeqNumber the sequence number that corresponds to the next message to be sent.
* @property lastProcessedSeqNumber the sequence number of the last message that has been processed.
* @property shardId the shard ID of the associated flow to be embedded on all the messages sent from this session.
*/
data class Initiated(
val peerParty: Party,
val peerFlowInfo: FlowInfo,
val receivedMessages: List<ExistingSessionMessagePayload>,
val otherSideErrored: Boolean,
val peerSinkSessionId: SessionId,
override val deduplicationSeed: String
) : SessionState()
override val deduplicationSeed: String,
val nextSendingSeqNumber: Int,
val lastProcessedSeqNumber: Int,
val shardId: String,
override val receivedMessages: Map<Int, ExistingSessionMessagePayload>,
override val lastSenderUUID: String?,
override val lastSenderSeqNo: Long?
) : SessionState() {
/**
* Indicates whether this message has already been processed.
*/
fun isDuplicate(messageIdentifier: MessageIdentifier): Boolean {
return messageIdentifier.sessionSequenceNumber <= lastProcessedSeqNumber
}
/**
* Indicates whether the session has an error message pending from the other side.
*/
fun hasErrored(): Boolean {
return hasNextMessageArrived() && receivedMessages[lastProcessedSeqNumber + 1] is ErrorSessionMessage
}
/**
* Indicates whether the next expected message has arrived.
*/
fun hasNextMessageArrived(): Boolean {
return receivedMessages.containsKey(lastProcessedSeqNumber + 1)
}
/**
* Returns the next message to be processed and the new session state.
* If you want to check first whether the next message has arrived, call [hasNextMessageArrived]
*
* @throws [IllegalArgumentException] if the next hasn't arrived.
*/
fun extractMessage(): Pair<ExistingSessionMessagePayload, Initiated> {
if (!hasNextMessageArrived()) {
throw IllegalArgumentException("Tried to extract a message that hasn't arrived yet.")
}
val message = receivedMessages[lastProcessedSeqNumber + 1]!!
val newState = this.copy(receivedMessages = receivedMessages.minus(lastProcessedSeqNumber + 1), lastProcessedSeqNumber = lastProcessedSeqNumber + 1)
return Pair(message, newState)
}
}
}
typealias SessionMap = Map<SessionId, SessionState>
@ -321,7 +451,10 @@ sealed class FlowStart {
val initiatedSessionId: SessionId,
val initiatingMessage: InitialSessionMessage,
val senderCoreFlowVersion: Int?,
val initiatedFlowInfo: FlowInfo
val initiatedFlowInfo: FlowInfo,
val shardIdentifier: String,
val senderUUID: String?,
val senderSequenceNumber: Long?
) : FlowStart() { override fun toString() = "Initiated" }
}

View File

@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage
@ -13,7 +14,7 @@ import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.RejectSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
@ -87,56 +88,78 @@ class DeliverSessionMessageTransition(
val initiatedSession = SessionState.Initiated(
peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(),
receivedMessages = emptyMap(),
peerSinkSessionId = message.initiatedSessionId,
deduplicationSeed = sessionState.deduplicationSeed,
otherSideErrored = false
otherSideErrored = false,
nextSendingSeqNumber = sessionState.nextSendingSeqNumber,
lastProcessedSeqNumber = 0,
shardId = sessionState.shardId,
lastSenderUUID = event.senderUUID,
lastSenderSeqNo = event.senderSequenceNumber
)
val newCheckpoint = currentState.checkpoint.addSession(
event.sessionMessage.recipientSessionId to initiatedSession
)
// Send messages that were buffered pending confirmation of session.
val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) ->
val sendActions = sessionState.bufferedMessages.map { (messageId, bufferedMessage) ->
val existingMessage = ExistingSessionMessage(message.initiatedSessionId, bufferedMessage)
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
Action.SendExisting(initiatedSession.peerParty, existingMessage, SenderDeduplicationInfo(messageId, startingState.senderUUID))
}
actions.addAll(sendActions)
currentState = currentState.copy(checkpoint = newCheckpoint)
}
else -> freshErrorTransition(UnexpectedEventInState())
is SessionState.Initiated -> {
log.trace { "Discarding duplicate confirmation for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
is SessionState.Uninitiated -> {
val newSessionState = sessionState.copy(hasBeenAcknowledged = Pair(event.sender, message))
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
currentState = currentState.copy(checkpoint = newCheckpoint)
}
}
}
private fun TransitionBuilder.dataMessageTransition(sessionState: SessionState, message: DataSessionMessage) {
// We received a data message. The corresponding session must be Initiated.
return when (sessionState) {
is SessionState.Initiated -> {
// Buffer the message in the session's receivedMessages buffer.
val newSessionState = sessionState.copy(
receivedMessages = sessionState.receivedMessages + message
)
if (!sessionState.isDuplicate(event.messageIdentifier)) {
val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
} else {
log.trace { "Discarding duplicate data message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
}
is SessionState.Initiating, is SessionState.Uninitiated -> {
val newSessionState = sessionState.addReceivedMessages(message, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(
event.sessionMessage.recipientSessionId to newSessionState
)
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
private fun TransitionBuilder.errorMessageTransition(sessionState: SessionState, payload: ErrorSessionMessage) {
val sequenceNumber = event.messageIdentifier.sessionSequenceNumber
return when (sessionState) {
is SessionState.Initiated -> {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
if (sequenceNumber > sessionState.lastProcessedSeqNumber) {
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
} else {
log.trace { "Discarding duplicate error message for session ${event.sessionMessage.recipientSessionId} with ${sessionState.peerParty}" }
}
}
is SessionState.Initiating, is SessionState.Uninitiated -> {
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
currentState = currentState.copy(
checkpoint = checkpoint.addSession(sessionId to newSessionState)
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
)
}
else -> freshErrorTransition(UnexpectedEventInState())
}
}
@ -145,42 +168,42 @@ class DeliverSessionMessageTransition(
return when (sessionState) {
is SessionState.Initiating -> {
if (sessionState.rejectionError != null) {
// Double reject
freshErrorTransition(UnexpectedEventInState())
log.trace { "Discarding duplicate session rejection message for session ${event.sessionMessage.recipientSessionId}" }
} else {
val checkpoint = currentState.checkpoint
val sessionId = event.sessionMessage.recipientSessionId
val flowError = FlowError(payload.errorId, exception)
currentState = currentState.copy(
checkpoint = checkpoint.addSession(sessionId to sessionState.copy(rejectionError = flowError))
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(rejectionError = flowError))
)
}
}
else -> freshErrorTransition(UnexpectedEventInState())
is SessionState.Uninitiated -> {
currentState = currentState.copy(
checkpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to sessionState.copy(hasBeenRejected = payload))
)
}
is SessionState.Initiated -> {
freshErrorTransition(UnexpectedEventInState("A session rejection message was received for an already established session ${event.messageIdentifier.sessionIdentifier}."))
}
}
}
private fun TransitionBuilder.endMessageTransition(payload: EndSessionMessage) {
val flowState = currentState.checkpoint.flowState
// flow must have already been started when session end messages are being delivered.
if (flowState !is FlowState.Started)
return freshErrorTransition(UnexpectedEventInState())
val sessionId = event.sessionMessage.recipientSessionId
val sessions = currentState.checkpoint.checkpointState.sessions
// a check has already been performed to confirm the session exists for this message before this method is invoked.
val sessionState = sessions[sessionId]!!
when (sessionState) {
is SessionState.Initiated -> {
val flowState = currentState.checkpoint.flowState
// flow must have already been started when session end messages are being delivered.
if (flowState !is FlowState.Started)
return freshErrorTransition(UnexpectedEventInState())
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages + payload)
is SessionState.Initiated, is SessionState.Initiating, is SessionState.Uninitiated -> {
val newSessionState = sessionState.addReceivedMessages(payload, event.messageIdentifier, event.senderUUID, event.senderSequenceNumber)
val newCheckpoint = currentState.checkpoint.addSession(event.sessionMessage.recipientSessionId to newSessionState)
.addSessionsToBeClosed(setOf(event.sessionMessage.recipientSessionId))
currentState = currentState.copy(checkpoint = newCheckpoint)
}
else -> {
freshErrorTransition(PrematureSessionEndException(event.sessionMessage.recipientSessionId))
}
}
}

View File

@ -1,6 +1,7 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.*
/**
@ -40,16 +41,28 @@ class ErrorFlowTransition(
return builder {
// If we're errored and propagating do the actual propagation and update the index.
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions,
errorMessages
)
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val newCheckpoint = startingState.checkpoint.copy(
errorState = errorState.copy(propagatedIndex = allErrors.size),
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions)
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
}
// If we're errored but not propagating keep processing events.
@ -81,16 +94,27 @@ class ErrorFlowTransition(
isCheckpointUpdate = currentState.isAnyCheckpointPersisted
)
}
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.CreateTransaction
actions += removeOrPersistCheckpoint
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(context.id.uuid)
actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
actions += Action.RemoveSessionBindings(startingState.checkpoint.checkpointState.sessions.keys)
actions += Action.RemoveFlow(context.id, FlowRemovalReason.ErrorFinish(allErrors), currentState)
currentState = currentState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isRemoved = true
)
FlowContinuation.Abort
} else {
// Otherwise keep processing events. This branch happens when there are some outstanding initiating
@ -112,31 +136,37 @@ class ErrorFlowTransition(
}
}
// Buffer error messages in Initiating sessions, return the initialised ones.
/**
* Buffers errors message for initiating states and filters the initiated states.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessionStates = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
var currentSequenceNumber = sessionState.nextSendingSeqNumber
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSequenceNumber++
messageIdentifier to errorMessage
}
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
sessionState.bufferMessages(errorMessagesWithDeduplication)
} else {
sessionState
}
}
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && !session.otherSideErrored) {
session
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
sessionId to sessionState
} else {
null
}
}
return Pair(initiatedSessions, newSessions)
}.toMap()
return Pair(initiatedSessions, newSessionStates)
}
}

View File

@ -2,14 +2,15 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.core.flows.KilledFlowException
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ErrorSessionMessage
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
@ -27,24 +28,37 @@ class KilledFlowTransition(
val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError)
val errorMessages = listOf(killedFlowErrorMessage)
val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions,
errorMessages
)
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val newCheckpoint = startingState.checkpoint.copy(
status = Checkpoint.FlowStatus.KILLED,
flowState = FlowState.Finished,
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions)
status = Checkpoint.FlowStatus.KILLED,
flowState = FlowState.Finished,
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
)
currentState = currentState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
isRemoved = true
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isRemoved = true
)
actions += Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
if (!startingState.isFlowResumed) {
actions += Action.CreateTransaction
@ -59,7 +73,12 @@ class KilledFlowTransition(
actions += Action.AddFlowException(context.id, killedFlowError.exception)
}
val signalSessionsEndMap = currentState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(context.id.uuid)
actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
@ -91,32 +110,37 @@ class KilledFlowTransition(
}
}
// Purposely left the same as [bufferErrorMessagesInInitiatingSessions] in [ErrorFlowTransition] so that it can be refactored
// Buffer error messages in Initiating sessions, return the initialised ones.
/**
* Buffers errors message for initiating states and filters the initiated states.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage>
): Pair<List<SessionState.Initiated>, Map<SessionId, SessionState>> {
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it
var currentSequenceNumber = sessionState.nextSendingSeqNumber
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSequenceNumber++
messageIdentifier to errorMessage
}
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
sessionState.bufferMessages(errorMessagesWithDeduplication)
} else {
sessionState
}
}
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.values.mapNotNull { session ->
if (session is SessionState.Initiated && !session.otherSideErrored) {
session
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
sessionId to sessionState
} else {
null
}
}
}.toMap()
return Pair(initiatedSessions, newSessions)
}

View File

@ -10,6 +10,9 @@ import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.messaging.generateShardId
import net.corda.node.services.statemachine.*
import org.slf4j.Logger
import kotlin.collections.LinkedHashMap
@ -177,14 +180,19 @@ class StartedFlowTransition(
}
if (existingSessionsToRemove.isNotEmpty()) {
val sendEndMessageActions = existingSessionsToRemove.values.mapIndexed { index, state ->
val sinkSessionId = (state as SessionState.Initiated).peerSinkSessionId
val sendEndMessageActions = existingSessionsToRemove.map { (_, sessionState) ->
val sinkSessionId = (sessionState as SessionState.Initiated).peerSinkSessionId
val message = ExistingSessionMessage(sinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
val messageType = MessageType.inferFromMessage(message)
val messageIdentifier = MessageIdentifier(messageType, generateShardId(context.id.toString()), sinkSessionId, sessionState.nextSendingSeqNumber, currentState.checkpoint.checkpointState.suspensionTime)
Action.SendExisting(sessionState.peerParty, message, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID))
}
val signalSessionsEndMap = existingSessionsToRemove.map { (sessionId, _) ->
val sessionState = currentState.checkpoint.checkpointState.sessions[sessionId]!!
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys))
currentState = currentState.copy(checkpoint = currentState.checkpoint.removeSessions(existingSessionsToRemove.keys), closedSessionsPendingToBeSignalled = currentState.closedSessionsPendingToBeSignalled + signalSessionsEndMap)
actions.add(Action.RemoveSessionBindings(sessionIdsToRemove))
actions.add(Action.SendMultiple(emptyList(), sendEndMessageActions))
}
@ -239,23 +247,23 @@ class StartedFlowTransition(
@Suppress("ComplexMethod", "NestedBlockDepth")
private fun pollSessionMessages(sessions: SessionMap, sessionIds: Set<SessionId>): PollResult? {
val newSessionMessages = LinkedHashMap(sessions)
val newSessionStates = LinkedHashMap(sessions)
val resultMessages = LinkedHashMap<SessionId, SerializedBytes<Any>>()
var someNotFound = false
for (sessionId in sessionIds) {
val sessionState = sessions[sessionId]
when (sessionState) {
is SessionState.Initiated -> {
val messages = sessionState.receivedMessages
if (messages.isEmpty()) {
if (!sessionState.hasNextMessageArrived()) {
someNotFound = true
} else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList())
val (message, newState) = sessionState.extractMessage()
newSessionStates[sessionId] = newState
// at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message.
resultMessages[sessionId] = if (messages[0] is EndSessionMessage) {
resultMessages[sessionId] = if (message is EndSessionMessage) {
throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?")
} else {
(messages[0] as DataSessionMessage).payload
(message as DataSessionMessage).payload
}
}
}
@ -267,14 +275,13 @@ class StartedFlowTransition(
return if (someNotFound) {
return null
} else {
PollResult(resultMessages, newSessionMessages)
PollResult(resultMessages, newSessionStates)
}
}
private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.checkpointState.sessions)
var index = 0
for (sourceSessionId in sourceSessions) {
val sessionState = checkpoint.checkpointState.sessions[sourceSessionId]
if (sessionState == null) {
@ -283,14 +290,22 @@ class StartedFlowTransition(
if (sessionState !is SessionState.Uninitiated) {
continue
}
val shardId = generateShardId(context.id.toString())
val counterpartySessionId = sourceSessionId.calculateInitiatedSessionId()
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null)
val newSessionState = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed
deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
shardId = shardId,
receivedMessages = emptyMap(),
lastSenderUUID = null,
lastSenderSeqNo = null
)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, newSessionState)
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
val messageType = MessageType.inferFromMessage(initialMessage)
val messageIdentifier = MessageIdentifier(messageType, shardId, counterpartySessionId, 0, checkpoint.checkpointState.suspensionTime)
actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID)))
newSessions[sourceSessionId] = newSessionState
}
currentState = currentState.copy(checkpoint = checkpoint.setSessions(sessions = newSessions))
@ -313,37 +328,60 @@ class StartedFlowTransition(
private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) {
val checkpoint = startingState.checkpoint
val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions)
var index = 0
val messagesByType = sourceSessionIdToMessage.toList()
.map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) }
.groupBy { it.second::class }
val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.map { (sourceSessionId, sessionState, message) ->
val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.mapNotNull { (sourceSessionId, sessionState, message) ->
val uninitiatedSessionState = sessionState as SessionState.Uninitiated
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState)
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null,
deduplicationSeed = uninitiatedSessionState.deduplicationSeed
)
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
val shardId = generateShardId(context.id.toString())
if (sessionState.hasBeenAcknowledged != null) {
newSessions[sourceSessionId] = SessionState.Initiated(
peerParty = sessionState.hasBeenAcknowledged.first,
peerFlowInfo = sessionState.hasBeenAcknowledged.second.initiatedFlowInfo,
receivedMessages = emptyMap(),
otherSideErrored = false,
peerSinkSessionId = sessionState.hasBeenAcknowledged.second.initiatedSessionId,
deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
lastProcessedSeqNumber = 0,
shardId = shardId,
lastSenderUUID = null,
lastSenderSeqNo = null
)
null
} else {
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(),
rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed,
nextSendingSeqNumber = 1,
shardId = shardId,
receivedMessages = emptyMap(),
lastSenderUUID = null,
lastSenderSeqNo = null
)
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
val messageType = MessageType.inferFromMessage(initialMessage)
val messageIdentifier = MessageIdentifier(messageType, shardId, sourceSessionId.calculateInitiatedSessionId(), 0, checkpoint.checkpointState.suspensionTime)
Action.SendInitial(uninitiatedSessionState.destination, initialMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
}
} ?: emptyList()
messagesByType[SessionState.Initiating::class]?.forEach { (sourceSessionId, sessionState, message) ->
val initiatingSessionState = sessionState as SessionState.Initiating
val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatingSessionState)
val newBufferedMessages = initiatingSessionState.bufferedMessages + Pair(deduplicationId, sessionMessage)
newSessions[sourceSessionId] = initiatingSessionState.copy(bufferedMessages = newBufferedMessages)
val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
newSessions[sourceSessionId] = initiatingSessionState.bufferMessage(messageIdentifier, sessionMessage)
}
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(_, sessionState, message) ->
val sendExistingActions = messagesByType[SessionState.Initiated::class]?.map {(sourceSessionId, sessionState, message) ->
val initiatedSessionState = sessionState as SessionState.Initiated
val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, initiatedSessionState)
val sinkSessionId = initiatedSessionState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, sessionMessage)
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))
val messageType = MessageType.inferFromMessage(existingMessage)
val messageIdentifier = MessageIdentifier(messageType, sessionState.shardId, sessionState.peerSinkSessionId, sessionState.nextSendingSeqNumber, checkpoint.checkpointState.suspensionTime)
newSessions[sourceSessionId] = initiatedSessionState.copy(nextSendingSeqNumber = initiatedSessionState.nextSendingSeqNumber + 1)
Action.SendExisting(initiatedSessionState.peerParty, existingMessage, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
} ?: emptyList()
if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) {
@ -372,11 +410,10 @@ class StartedFlowTransition(
}
}
is SessionState.Initiated -> {
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) {
val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage
val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty)
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true)
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState)
if (sessionState.hasErrored()) {
val (message, newSessionState) = sessionState.extractMessage()
val exception = convertErrorMessageToException(message as ErrorSessionMessage, sessionState.peerParty)
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState.copy(otherSideErrored = true))
newState = startingState.copy(checkpoint = newCheckpoint)
listOf(exception)
} else {
@ -541,8 +578,8 @@ class StartedFlowTransition(
private fun findSessionsToBeTerminated(startingState: StateMachineState): SessionMap {
return startingState.checkpoint.checkpointState.sessionsToBeClosed.mapNotNull { sessionId ->
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!! as SessionState.Initiated
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is EndSessionMessage) {
val sessionState = startingState.checkpoint.checkpointState.sessions[sessionId]!!
if (sessionState is SessionState.Initiated && sessionState.receivedMessages.containsKey(sessionState.lastProcessedSeqNumber + 1) && sessionState.receivedMessages[sessionState.lastProcessedSeqNumber + 1] is EndSessionMessage) {
sessionId to sessionState
} else {
null

View File

@ -4,12 +4,13 @@ import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom
import java.time.Instant
class StateMachine(
val id: StateMachineRunId,
val secureRandom: SecureRandom
) {
fun transition(event: Event, state: StateMachineState): TransitionResult {
return TopLevelTransition(TransitionContext(id, secureRandom), state, event).transition()
fun transition(event: Event, state: StateMachineState, time: Instant): TransitionResult {
return TopLevelTransition(TransitionContext(id, secureRandom, time), state, event).transition()
}
}

View File

@ -7,9 +7,9 @@ import net.corda.core.serialization.deserialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.EndSessionMessage
import net.corda.node.services.statemachine.ErrorState
import net.corda.node.services.statemachine.Event
@ -19,7 +19,8 @@ import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowSessionImpl
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitialSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionMessage
import net.corda.node.services.statemachine.SessionState
@ -197,7 +198,8 @@ class TopLevelTransition(
checkpointState.invocationContext
},
numberOfSuspends = checkpointState.numberOfSuspends + 1,
numberOfCommits = checkpointState.numberOfCommits + 1
numberOfCommits = checkpointState.numberOfCommits + 1,
suspensionTime = context.time
)
copy(
flowState = FlowState.Started(event.ioRequest, event.fiber),
@ -217,11 +219,13 @@ class TopLevelTransition(
currentState = startingState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
closedSessionsPendingToBeSignalled = emptyMap(),
isFlowResumed = false,
isAnyCheckpointPersisted = true
)
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
actions += Action.ScheduleEvent(Event.DoRemainingWork)
@ -240,12 +244,14 @@ class TopLevelTransition(
checkpoint = checkpoint.copy(
checkpointState = checkpoint.checkpointState.copy(
numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1,
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1,
suspensionTime = context.time
),
flowState = FlowState.Finished,
result = event.returnValue,
status = Checkpoint.FlowStatus.COMPLETED
),
).removeSessions(checkpoint.checkpointState.sessions.keys),
closedSessionsPendingToBeSignalled = emptyMap(),
pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false,
isRemoved = true
@ -263,7 +269,12 @@ class TopLevelTransition(
)
}
val signalSessionsEndMap = startingState.checkpoint.checkpointState.sessions.map { (sessionId, sessionState) ->
sessionId to Pair(sessionState.lastSenderUUID, sessionState.lastSenderSeqNo)
}.toMap()
actions += Action.PersistDeduplicationFacts(startingState.pendingDeduplicationHandlers)
actions += Action.SignalSessionsHasEnded(signalSessionsEndMap)
actions += Action.ReleaseSoftLocks(event.softLocksId)
actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(startingState.pendingDeduplicationHandlers)
@ -284,11 +295,12 @@ class TopLevelTransition(
}
private fun TransitionBuilder.sendEndMessages() {
val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state ->
val sendEndMessageActions = startingState.checkpoint.checkpointState.sessions.map { (sessionId, state) ->
if (state is SessionState.Initiated) {
val message = ExistingSessionMessage(state.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
val messageType = MessageType.inferFromMessage(message)
val messageIdentifier = MessageIdentifier(messageType, state.shardId, state.peerSinkSessionId, state.nextSendingSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
Action.SendExisting(state.peerParty, message, SenderDeduplicationInfo(messageIdentifier, startingState.senderUUID))
} else {
null
}
@ -306,7 +318,7 @@ class TopLevelTransition(
}
val sourceSessionId = SessionId.createRandom(context.secureRandom)
val sessionImpl = FlowSessionImpl(event.destination, event.wellKnownParty, sourceSessionId)
val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong()))
val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong(), null, null, emptyMap(), null, null))
currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions))
actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
FlowContinuation.Resume(sessionImpl)
@ -361,10 +373,12 @@ class TopLevelTransition(
numberOfCommits = startingState.checkpoint.checkpointState.numberOfCommits + 1
)
),
pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents
pendingDeduplicationHandlers = startingState.pendingDeduplicationHandlers - flowStartEvents,
closedSessionsPendingToBeSignalled = emptyMap()
)
actions += Action.CreateTransaction
actions += Action.PersistDeduplicationFacts(flowStartEvents)
actions += Action.SignalSessionsHasEnded(startingState.closedSessionsPendingToBeSignalled)
actions += Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = startingState.isAnyCheckpointPersisted)
actions += Action.CommitTransaction(currentState)
actions += Action.AcknowledgeMessages(flowStartEvents)
@ -401,4 +415,5 @@ class TopLevelTransition(
FlowContinuation.Abort
}
}
}

View File

@ -3,6 +3,7 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.StateMachineRunId
import net.corda.node.services.statemachine.StateMachineState
import java.security.SecureRandom
import java.time.Instant
/**
* An interface used to separate out different parts of the state machine transition function.
@ -28,5 +29,6 @@ interface Transition {
class TransitionContext(
val id: StateMachineRunId,
val secureRandom: SecureRandom
val secureRandom: SecureRandom,
val time: Instant
)

View File

@ -80,6 +80,6 @@ class TransitionBuilder(val context: TransitionContext, initialState: StateMachi
}
class CannotFindSessionException(sessionId: SessionId) : IllegalStateException("Couldn't find session with id $sessionId")
class UnexpectedEventInState : IllegalStateException("Unexpected event")
class UnexpectedEventInState(message: String = "") : IllegalStateException("An unexpected event happened. $message")
class PrematureSessionCloseException(sessionId: SessionId): IllegalStateException("The following session was closed before it was initialised: $sessionId")
class PrematureSessionEndException(sessionId: SessionId): IllegalStateException("A premature session end message was received before the session was initialised: $sessionId")

View File

@ -1,14 +1,15 @@
package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
@ -50,25 +51,28 @@ class UnstartedFlowTransition(
appName = initiatingMessage.appName
),
receivedMessages = if (initiatingMessage.firstPayload == null) {
emptyList()
emptyMap()
} else {
listOf(DataSessionMessage(initiatingMessage.firstPayload))
mapOf(0 to DataSessionMessage(initiatingMessage.firstPayload))
},
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}",
otherSideErrored = false
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.value}-${initiatingMessage.initiationEntropy}",
otherSideErrored = false,
nextSendingSeqNumber = 1,
lastProcessedSeqNumber = if (initiatingMessage.firstPayload == null) {
0
} else {
-1
},
shardId = flowStart.shardIdentifier,
lastSenderUUID = flowStart.senderUUID,
lastSenderSeqNo = flowStart.senderSequenceNumber
)
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
currentState = currentState.copy(
checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState))
)
actions.add(
Action.SendExisting(
flowStart.peerSession.counterparty,
sessionMessage,
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0, initiatedState), currentState.senderUUID)
)
)
val messageType = MessageType.inferFromMessage(sessionMessage)
val messageIdentifier = MessageIdentifier(messageType, flowStart.shardIdentifier, initiatingMessage.initiatorSessionId, 0, currentState.checkpoint.checkpointState.suspensionTime)
currentState = currentState.copy(checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState)))
actions.add(Action.SendExisting(flowStart.peerSession.counterparty, sessionMessage, SenderDeduplicationInfo(messageIdentifier, currentState.senderUUID)))
}
// Create initial checkpoint and acknowledge triggering messages.

View File

@ -55,7 +55,7 @@ open class DefaultNamedCacheFactory protected constructor(private val metricRegi
name == "FlowDrainingMode_nodeProperties" -> caffeine.maximumSize(defaultCacheSize)
name == "ContractUpgradeService_upgrades" -> caffeine.maximumSize(defaultCacheSize)
name == "PersistentUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)
name == "P2PMessageDeduplicator_processedMessages" -> caffeine.maximumSize(defaultCacheSize)
name == "P2PMessageDeduplicator_sessionData" -> caffeine.maximumSize(defaultCacheSize)
name == "DeduplicationChecker_watermark" -> caffeine
name == "BFTNonValidatingNotaryService_transactions" -> caffeine.maximumSize(defaultCacheSize)
name == "RaftUniquenessProvider_transactions" -> caffeine.maximumSize(defaultCacheSize)

View File

@ -34,5 +34,7 @@
<include file="migration/node-core.changelog-v19.xml"/>
<include file="migration/node-core.changelog-v19-postgres.xml"/>
<include file="migration/node-core.changelog-v19-keys.xml"/>
<include file="migration/node-core.changelog-v20.xml"/>
<include file="migration/node-core.changelog-v21.xml"/>
</databaseChangeLog>

View File

@ -0,0 +1,28 @@
<?xml version="1.1" encoding="UTF-8" standalone="no"?>
<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd"
logicalFilePath="migration/node-services.changelog-init.xml">
<changeSet author="R3.Corda" id="add_session_data_table">
<createTable tableName="node_session_data">
<column name="session_id" type="NUMBER(128)">
<constraints nullable="false"/>
</column>
<column name="init_generation_time" type="timestamp">
<constraints nullable="false"/>
</column>
<column name="sender_hash" type="NVARCHAR(64)">
<constraints nullable="true"/>
</column>
<column name="init_sequence_number" type="BIGINT">
<constraints nullable="true"/>
</column>
<column name="last_sequence_number" type="BIGINT">
<constraints nullable="true"/>
</column>
</createTable>
<addPrimaryKey columnNames="session_id" constraintName="node_session_data_pk" tableName="node_session_data"/>
</changeSet>
</databaseChangeLog>

View File

@ -0,0 +1,47 @@
package net.corda.node.services.messaging
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
import java.time.Instant
class MessageIdentifierTest {
private val shardIdentifier = "XXXXXXXX"
private val sessionIdentifier = SessionId(BigInteger.valueOf(14))
private val sessionSequenceNumber = 1
private val timestamp = Instant.ofEpochMilli(100)
private val messageIdString = "XI-0000000000000064-XXXXXXXX-0000000000000000000000000000000E-1"
@Test(timeout=300_000)
fun `can parse message identifier from string value`() {
val messageIdentifier = MessageIdentifier(MessageType.SESSION_INIT, shardIdentifier, sessionIdentifier, sessionSequenceNumber, timestamp)
assertThat(messageIdentifier.toString()).isEqualTo(messageIdString)
}
@Test(timeout=300_000)
fun `can convert message identifier object to string value`() {
val messageIdentifierString = messageIdString
val messageIdentifier = MessageIdentifier.parse(messageIdentifierString)
assertThat(messageIdentifier.messageType).isInstanceOf(MessageType.SESSION_INIT::class.java)
assertThat(messageIdentifier.shardIdentifier).isEqualTo(shardIdentifier)
assertThat(messageIdentifier.sessionIdentifier).isEqualTo(sessionIdentifier)
assertThat(messageIdentifier.sessionSequenceNumber).isEqualTo(1)
assertThat(messageIdentifier.timestamp).isEqualTo(timestamp)
}
@Test(timeout=300_000)
fun `shard identifier needs to be 8 characters long`() {
Assertions.assertThatThrownBy { MessageIdentifier(MessageType.SESSION_INIT, "XX", sessionIdentifier, 1, timestamp) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Shard identifier needs to be 8 characters long, but it was XX")
}
}

View File

@ -0,0 +1,28 @@
package net.corda.node.services.messaging
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
class SessionIdTest {
@Test(timeout=300_000)
fun `session identifier cannot be negative`() {
assertThatThrownBy { SessionId(BigInteger.valueOf(-1)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Session identifier cannot be a negative number, but it was -1")
}
@Test(timeout=300_000)
fun `session identifier needs to be a number that can be represented in maximum 128 bits`() {
val largestSessionIdentifierValue = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
val largestValidSessionId = SessionId(largestSessionIdentifierValue)
assertThatThrownBy { SessionId(largestSessionIdentifierValue.plus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("The size of a session identifier cannot exceed 128 bits, but it was 340282366920938463463374607431768211456")
}
}

View File

@ -873,7 +873,8 @@ class DBCheckpointStorageTests {
frozenLogic,
ALICE,
SubFlowVersion.CoreFlow(version),
false
false,
Clock.systemUTC().instant()
)
.getOrThrow()
return id to checkpoint

View File

@ -194,7 +194,7 @@ class CheckpointDumperImplTest {
override fun call() {}
}
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false, Clock.systemUTC().instant())
.getOrThrow()
return id to checkpoint
}

View File

@ -8,7 +8,6 @@ import net.corda.client.rpc.notUsed
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.Destination
import net.corda.core.flows.FinalityFlow
import net.corda.core.flows.FlowException
@ -42,6 +41,8 @@ import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.core.utilities.unwrap
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.persistence.CheckpointPerformanceRecorder
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.checkpoints
@ -80,6 +81,8 @@ import org.junit.Before
import org.junit.Test
import rx.Notification
import rx.Observable
import java.math.BigInteger
import java.security.SecureRandom
import java.sql.SQLTransientConnectionException
import java.time.Clock
import java.time.Duration
@ -571,7 +574,7 @@ class FlowFrameworkTests {
@Test(timeout=300_000)
fun `session init with unknown class is sent to the flow hospital, from where we then drop it`() {
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, "not.a.real.Class", 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(1) // Only the session-init is expected as the session-reject is blocked by the flow hospital
val medicalRecords = bobNode.smm.flowHospital.track().apply { updates.notUsed() }.snapshot
@ -587,7 +590,7 @@ class FlowFrameworkTests {
@Test(timeout=300_000)
fun `non-flow class in session init`() {
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId.createRandom(SecureRandom()), 0, String::class.java.name, 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
@ -897,7 +900,7 @@ class FlowFrameworkTests {
}
//region Helpers
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(receivedSessionMessages).containsExactly(*expected)
@ -1039,21 +1042,21 @@ class FlowFrameworkTests {
}
internal fun sessionConfirm(flowVersion: Int = 1) =
ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ConfirmSessionMessage(SessionId(BigInteger.valueOf(0)), FlowInfo(flowVersion, "")))
internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
return smm.findStateMachines(P::class.java).single()
}
private fun sanitise(message: SessionMessage) = when (message) {
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "")
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(BigInteger.valueOf(0)), initiationEntropy = 0, appName = "")
is ExistingSessionMessage -> {
val payload = message.payload
message.copy(
recipientSessionId = SessionId(0),
recipientSessionId = SessionId(BigInteger.valueOf(0)),
payload = when (payload) {
is ConfirmSessionMessage -> payload.copy(
initiatedSessionId = SessionId(0),
initiatedSessionId = SessionId(BigInteger.valueOf(0)),
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
)
is ErrorSessionMessage -> payload.copy(
@ -1076,7 +1079,8 @@ internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Sessio
internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
val messageIdentifier = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes, SenderDeduplicationInfo(messageIdentifier, null), emptyMap()), address)
}
}
@ -1087,7 +1091,7 @@ inline fun <reified T> DatabaseTransaction.findRecordsFromDatabase(): List<T> {
}
internal fun errorMessage(errorResponse: FlowException? = null) =
ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), ErrorSessionMessage(errorResponse, 0))
internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer =
@ -1103,10 +1107,10 @@ internal data class SessionTransfer(val from: Int, val message: SessionMessage,
}
internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
return InitialSessionMessage(SessionId(BigInteger.valueOf(0)), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), DataSessionMessage(payload.serialize()))
@InitiatingFlow
internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() {

View File

@ -18,6 +18,7 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Observable
import java.math.BigInteger
import java.util.*
class FlowFrameworkTripartyTests {
@ -168,7 +169,7 @@ class FlowFrameworkTripartyTests {
)
}
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private val normalEnd = ExistingSessionMessage(SessionId(BigInteger.valueOf(0)), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }

View File

@ -0,0 +1,42 @@
package net.corda.node.services.statemachine
import org.assertj.core.api.Assertions.*
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
class SessionIdTest {
companion object {
private val LARGEST_SESSION_ID_VALUE = BigInteger.valueOf(2).pow(128).minus(BigInteger.ONE)
}
@Test(timeout=300_000)
fun `session id must be positive and representable in 128 bits`() {
assertThatThrownBy { SessionId(BigInteger.ZERO.minus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessage("Session identifier cannot be a negative number, but it was -1")
assertThatThrownBy { SessionId(LARGEST_SESSION_ID_VALUE.plus(BigInteger.ONE)) }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessageContaining("The size of a session identifier cannot exceed 128 bits, but it was")
val correctSessionId = SessionId(LARGEST_SESSION_ID_VALUE)
}
@Test(timeout=300_000)
fun `initiated session id is calculated properly`() {
val sessionId = SessionId(BigInteger.ONE)
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
assertThat(initiatedSessionId.value.toLong()).isEqualTo(2)
}
@Test(timeout=300_000)
fun `calculation of initiated session id wraps around`() {
val sessionId = SessionId(LARGEST_SESSION_ID_VALUE)
val initiatedSessionId = sessionId.calculateInitiatedSessionId()
assertThat(initiatedSessionId.value.toLong()).isEqualTo(0)
}
}

View File

@ -2,7 +2,7 @@ package net.corda.testing.node.internal
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.messaging.MessageIdentifier
import java.time.Instant
/**
@ -10,7 +10,7 @@ import java.time.Instant
*/
data class InMemoryMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: DeduplicationId,
override val uniqueMessageId: MessageIdentifier,
override val debugTimestamp: Instant = Instant.now(),
override val senderUUID: String? = null) : Message {

View File

@ -1,5 +1,6 @@
package net.corda.testing.node.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.PartyAndCertificate
@ -13,9 +14,9 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.contextLogger
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.*
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.SessionId
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.lifecycle.ServiceStateHelper
import net.corda.nodeapi.internal.lifecycle.ServiceStateSupport
@ -46,9 +47,9 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
}
private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<DeduplicationId> = Collections.synchronizedSet(HashSet<DeduplicationId>())
private val processedMessages: MutableSet<MessageIdentifier> = Collections.synchronizedSet(HashSet())
override val ourSenderUUID: String = UUID.randomUUID().toString()
override val ourSenderUUID: SenderUUID = UUID.randomUUID().toString()
private var _myAddress: InMemoryMessagingNetwork.PeerHandle? = null
override val myAddress: InMemoryMessagingNetwork.PeerHandle get() = checkNotNull(_myAddress) { "Not started" }
@ -167,6 +168,11 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
}
}
@Suspendable
override fun sessionEnded(sessionId: SessionId, senderUUID: SenderUUID?, senderSequenceNumber: SenderSequenceNumber?) {
// nothing to do here.
}
override fun close() {
backgroundThread?.let {
it.interrupt()
@ -178,8 +184,8 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
}
/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID)
override fun createMessage(topic: String, data: ByteArray, deduplicationInfo: SenderDeduplicationInfo, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationInfo.messageIdentifier, senderUUID = deduplicationInfo.senderUUID)
}
/**
@ -269,7 +275,7 @@ class MockNodeMessagingService(private val configuration: NodeConfiguration,
private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val platformVersion: Int,
override val uniqueMessageId: DeduplicationId,
override val uniqueMessageId: MessageIdentifier,
override val debugTimestamp: Instant,
override val peer: CordaX500Name,
override val senderUUID: String? = null,

View File

@ -4,14 +4,24 @@ import net.corda.core.messaging.AllPossibleRecipients
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.node.services.messaging.Message
import net.corda.coretesting.internal.rigorousMock
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderDeduplicationInfo
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After
import org.junit.Test
import java.math.BigInteger
import java.time.Clock
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class InternalMockNetworkTests {
companion object {
private val MESSAGE_IDENTIFIER = MessageIdentifier(MessageType.DATA_MESSAGE, "XXXXXXXX", SessionId(BigInteger.valueOf(14)), 0, Clock.systemUTC().instant())
}
lateinit var mockNet: InternalMockNetwork
@After
@ -39,7 +49,7 @@ class InternalMockNetworkTests {
}
// Node 1 sends a message and it should end up in finalDelivery, after we run the network
node1.network.send(node1.network.createMessage("test.topic", data = bits), node2.network.myAddress)
node1.network.send(node1.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), node2.network.myAddress)
mockNet.runNetwork(rounds = 1)
@ -58,7 +68,7 @@ class InternalMockNetworkTests {
var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
node1.network.send(node2.network.createMessage("test.topic", bits, SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap()), rigorousMock<AllPossibleRecipients>())
mockNet.runNetwork(rounds = 1)
assertEquals(3, counter)
}
@ -79,8 +89,8 @@ class InternalMockNetworkTests {
received++
}
val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1))
val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1))
val invalidMessage = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
val validMessage = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
node2.network.send(invalidMessage, node1.network.myAddress)
mockNet.runNetwork()
assertEquals(0, received)
@ -91,8 +101,8 @@ class InternalMockNetworkTests {
// Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so
// this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops
val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1))
val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1))
val invalidMessage2 = node2.network.createMessage("invalid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
val validMessage2 = node2.network.createMessage("valid_message", ByteArray(1), SenderDeduplicationInfo(MESSAGE_IDENTIFIER, null), emptyMap())
node2.network.send(invalidMessage2, node1.network.myAddress)
node2.network.send(validMessage2, node1.network.myAddress)
mockNet.runNetwork()