Refactor common code in Error and KillFlowTransition

This commit is contained in:
Dimos Raptis 2020-09-30 09:09:16 +01:00
parent 0858b7852d
commit d3de729390
2 changed files with 42 additions and 102 deletions

View File

@ -41,28 +41,14 @@ class ErrorFlowTransition(
return builder {
// If we're errored and propagating do the actual propagation and update the index.
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions,
errorMessages
)
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val (propagateErrorsAction, newSessionStates) = sendAndBufferErrorMessages(startingState, errorMessages)
val newCheckpoint = startingState.checkpoint.copy(
errorState = errorState.copy(propagatedIndex = allErrors.size),
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates)
)
currentState = currentState.copy(checkpoint = newCheckpoint)
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
actions += propagateErrorsAction
}
// If we're errored but not propagating keep processing events.
@ -136,37 +122,44 @@ class ErrorFlowTransition(
}
}
companion object {
/**
* Buffers errors message for initiating states and filters the initiated states.
* Buffers errors message for initiating states and creates the necessary actions to propagate errors to initiated ones.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a [Action.PropagateErrors] action that contains the error messages to be sent to initiated sessions.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
fun sendAndBufferErrorMessages(
currentState: StateMachineState,
errorMessages: List<ErrorSessionMessage>
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessionStates = sessions.mapValues { (sourceSessionId, sessionState) ->
): Pair<Action.PropagateErrors, Map<SessionId, SessionState>> {
val errorsToPropagatePerSession = mutableMapOf<SessionState.Initiated, List<Pair<MessageIdentifier, ErrorSessionMessage>>>()
val newSessionStates = currentState.checkpoint.checkpointState.sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
var currentSequenceNumber = sessionState.nextSendingSeqNumber
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSequenceNumber++
messageIdentifier to errorMessage
val errorsWithId = errorMessages.mapIndexed { idx, errorMsg ->
val messageId = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(),
sessionState.nextSendingSeqNumber+idx, currentState.checkpoint.checkpointState.suspensionTime)
messageId to errorMsg
}
sessionState.bufferMessages(errorMessagesWithDeduplication)
sessionState.bufferMessages(errorsWithId)
}
// if we have already received error message from the other side, we don't propagate errors to this session.
else if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
val errorsWithId = errorMessages.mapIndexed { idx, errorMsg ->
val messageId = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId,
sessionState.nextSendingSeqNumber+idx, currentState.checkpoint.checkpointState.suspensionTime)
messageId to errorMsg
}.toList()
errorsToPropagatePerSession[sessionState] = errorsWithId
sessionState.copy(nextSendingSeqNumber = sessionState.nextSendingSeqNumber + errorMessages.size)
} else {
sessionState
}
}
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
sessionId to sessionState
} else {
null
}
}.toMap()
return Pair(initiatedSessions, newSessionStates)
return Pair(Action.PropagateErrors(errorsToPropagatePerSession, currentState.senderUUID), newSessionStates)
}
}
}

View File

@ -2,7 +2,6 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowException
import net.corda.core.flows.KilledFlowException
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.ErrorSessionMessage
@ -10,9 +9,6 @@ import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowRemovalReason
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
class KilledFlowTransition(
@ -28,27 +24,12 @@ class KilledFlowTransition(
val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError)
val errorMessages = listOf(killedFlowErrorMessage)
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
startingState.checkpoint.checkpointState.sessions,
errorMessages
)
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
var currentSeqNumber = sessionState.nextSendingSeqNumber
val errorsWithId = errorMessages.map { errorMsg ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSeqNumber++
Pair(messageIdentifier, errorMsg)
}.toList()
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
Pair(sessionState, errorsWithId)
}.toMap()
val (propagateErrorsAction, newSessionStates) = ErrorFlowTransition.sendAndBufferErrorMessages(startingState, errorMessages)
val newCheckpoint = startingState.checkpoint.copy(
status = Checkpoint.FlowStatus.KILLED,
flowState = FlowState.Finished,
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates)
)
currentState = currentState.copy(
@ -58,7 +39,7 @@ class KilledFlowTransition(
isRemoved = true
)
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
actions += propagateErrorsAction
if (!startingState.isFlowResumed) {
actions += Action.CreateTransaction
@ -110,40 +91,6 @@ class KilledFlowTransition(
}
}
/**
* Buffers errors message for initiating states and filters the initiated states.
* Returns a pair that consists of:
* - a map containing the initiated states as filtered from the ones provided as input.
* - a map containing the new state of all the sessions.
*/
private fun bufferErrorMessagesInInitiatingSessions(
sessions: Map<SessionId, SessionState>,
errorMessages: List<ErrorSessionMessage>
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
var currentSequenceNumber = sessionState.nextSendingSeqNumber
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
currentSequenceNumber++
messageIdentifier to errorMessage
}
sessionState.bufferMessages(errorMessagesWithDeduplication)
} else {
sessionState
}
}
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
sessionId to sessionState
} else {
null
}
}.toMap()
return Pair(initiatedSessions, newSessions)
}
private fun createKilledRemovalReason(error: FlowError): FlowRemovalReason.ErrorFinish {
return FlowRemovalReason.ErrorFinish(listOf(error))
}