ENT-6142 Use ArrayList for SessionState structures (#4169)

Prevent some serialization errors that occur due to serialization
and deserialization of `ArrayList$SubList` found inside the
`SessionState` data structures.

To prevent this, an explicit `ArrayList` is used rather than a `List`.

Overload the `List` operator functions so that `+` returns an
`ArrayList` instead of a `List`.

Create `toArrayList` for a few conversions.
This commit is contained in:
Dan Newton 2021-01-18 14:45:00 +00:00 committed by Dan Newton
parent c79ad972d0
commit 88172b630d
7 changed files with 50 additions and 13 deletions

View File

@ -281,7 +281,7 @@ sealed class SessionState {
* @property rejectionError if non-null the initiation failed. * @property rejectionError if non-null the initiation failed.
*/ */
data class Initiating( data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>, val bufferedMessages: ArrayList<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
val rejectionError: FlowError?, val rejectionError: FlowError?,
override val deduplicationSeed: String override val deduplicationSeed: String
) : SessionState() ) : SessionState()
@ -298,7 +298,7 @@ sealed class SessionState {
data class Initiated( data class Initiated(
val peerParty: Party, val peerParty: Party,
val peerFlowInfo: FlowInfo, val peerFlowInfo: FlowInfo,
val receivedMessages: List<ExistingSessionMessagePayload>, val receivedMessages: ArrayList<ExistingSessionMessagePayload>,
val otherSideErrored: Boolean, val otherSideErrored: Boolean,
val peerSinkSessionId: SessionId, val peerSinkSessionId: SessionId,
override val deduplicationSeed: String override val deduplicationSeed: String

View File

@ -87,7 +87,7 @@ class DeliverSessionMessageTransition(
val initiatedSession = SessionState.Initiated( val initiatedSession = SessionState.Initiated(
peerParty = event.sender, peerParty = event.sender,
peerFlowInfo = message.initiatedFlowInfo, peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(), receivedMessages = arrayListOf(),
peerSinkSessionId = message.initiatedSessionId, peerSinkSessionId = message.initiatedSessionId,
deduplicationSeed = sessionState.deduplicationSeed, deduplicationSeed = sessionState.deduplicationSeed,
otherSideErrored = false otherSideErrored = false

View File

@ -121,9 +121,9 @@ class ErrorFlowTransition(
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will // *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness. // be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map { val errorMessagesWithDeduplication: ArrayList<Pair<DeduplicationId, ExistingSessionMessagePayload>> = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it DeduplicationId.createForError(it.errorId, sourceSessionId) to it
} }.toArrayList()
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
} else { } else {
sessionState sessionState

View File

@ -7,12 +7,14 @@ import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ErrorSessionMessage import net.corda.node.services.statemachine.ErrorSessionMessage
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.ExistingSessionMessagePayload
import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowRemovalReason import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionState import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState import net.corda.node.services.statemachine.StateMachineState
import java.util.ArrayList
class KilledFlowTransition( class KilledFlowTransition(
override val context: TransitionContext, override val context: TransitionContext,
@ -101,9 +103,9 @@ class KilledFlowTransition(
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) { if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
// *prepend* the error messages in order to error the other sessions ASAP. The other messages will // *prepend* the error messages in order to error the other sessions ASAP. The other messages will
// be delivered all the same, they just won't trigger flow resumption because of dirtiness. // be delivered all the same, they just won't trigger flow resumption because of dirtiness.
val errorMessagesWithDeduplication = errorMessages.map { val errorMessagesWithDeduplication: ArrayList<Pair<DeduplicationId, ExistingSessionMessagePayload>> = errorMessages.map {
DeduplicationId.createForError(it.errorId, sourceSessionId) to it DeduplicationId.createForError(it.errorId, sourceSessionId) to it
} }.toArrayList()
sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages) sessionState.copy(bufferedMessages = errorMessagesWithDeduplication + sessionState.bufferedMessages)
} else { } else {
sessionState sessionState

View File

@ -250,7 +250,7 @@ class StartedFlowTransition(
if (messages.isEmpty()) { if (messages.isEmpty()) {
someNotFound = true someNotFound = true
} else { } else {
newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toList()) newSessionMessages[sessionId] = sessionState.copy(receivedMessages = messages.subList(1, messages.size).toArrayList())
// at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message. // at this point, we've already checked for errors and session ends, so it's guaranteed that the first message will be a data message.
resultMessages[sessionId] = if (messages[0] is EndSessionMessage) { resultMessages[sessionId] = if (messages[0] is EndSessionMessage) {
throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?") throw UnexpectedFlowEndException("Received session end message instead of a data session message. Mismatched send and receive?")
@ -285,7 +285,7 @@ class StartedFlowTransition(
} }
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null) val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null)
val newSessionState = SessionState.Initiating( val newSessionState = SessionState.Initiating(
bufferedMessages = emptyList(), bufferedMessages = arrayListOf(),
rejectionError = null, rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed deduplicationSeed = sessionState.deduplicationSeed
) )
@ -324,7 +324,7 @@ class StartedFlowTransition(
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState) val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, sessionState)
val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message) val initialMessage = createInitialSessionMessage(uninitiatedSessionState.initiatingSubFlow, sourceSessionId, uninitiatedSessionState.additionalEntropy, message)
newSessions[sourceSessionId] = SessionState.Initiating( newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(), bufferedMessages = arrayListOf(),
rejectionError = null, rejectionError = null,
deduplicationSeed = uninitiatedSessionState.deduplicationSeed deduplicationSeed = uninitiatedSessionState.deduplicationSeed
) )
@ -375,7 +375,10 @@ class StartedFlowTransition(
if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) { if (sessionState.receivedMessages.isNotEmpty() && sessionState.receivedMessages.first() is ErrorSessionMessage) {
val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage val errorMessage = sessionState.receivedMessages.first() as ErrorSessionMessage
val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty) val exception = convertErrorMessageToException(errorMessage, sessionState.peerParty)
val newSessionState = sessionState.copy(receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size), otherSideErrored = true) val newSessionState = sessionState.copy(
receivedMessages = sessionState.receivedMessages.subList(1, sessionState.receivedMessages.size).toArrayList(),
otherSideErrored = true
)
val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState) val newCheckpoint = startingState.checkpoint.addSession(sessionId to newSessionState)
newState = startingState.copy(checkpoint = newCheckpoint) newState = startingState.copy(checkpoint = newCheckpoint)
listOf(exception) listOf(exception)

View File

@ -24,6 +24,37 @@ interface Transition {
val continuation = build(builder) val continuation = build(builder)
return TransitionResult(builder.currentState, builder.actions, continuation) return TransitionResult(builder.currentState, builder.actions, continuation)
} }
/**
* Add [element] to the [ArrayList] and return the list.
*
* Copy of [List.plus] that returns an [ArrayList] instead.
*/
operator fun <T> ArrayList<T>.plus(element: T) : ArrayList<T> {
val result = ArrayList<T>(size + 1)
result.addAll(this)
result.add(element)
return result
}
/**
* Add [elements] to the [ArrayList] and return the list.
*
* Copy of [List.plus] that returns an [ArrayList] instead.
*/
operator fun <T> ArrayList<T>.plus(elements: Collection<T>) : ArrayList<T> {
val result = ArrayList<T>(this.size + elements.size)
result.addAll(this)
result.addAll(elements)
return result
}
/**
* Convert the [List] into an [ArrayList].
*/
fun <T> List<T>.toArrayList() : ArrayList<T> {
return ArrayList(this)
}
} }
class TransitionContext( class TransitionContext(

View File

@ -6,6 +6,7 @@ import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.ExistingSessionMessagePayload
import net.corda.node.services.statemachine.FlowStart import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SenderDeduplicationId
@ -50,9 +51,9 @@ class UnstartedFlowTransition(
appName = initiatingMessage.appName appName = initiatingMessage.appName
), ),
receivedMessages = if (initiatingMessage.firstPayload == null) { receivedMessages = if (initiatingMessage.firstPayload == null) {
emptyList() arrayListOf()
} else { } else {
listOf(DataSessionMessage(initiatingMessage.firstPayload)) arrayListOf<ExistingSessionMessagePayload>(DataSessionMessage(initiatingMessage.firstPayload))
}, },
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}", deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}",
otherSideErrored = false otherSideErrored = false