mirror of
https://github.com/corda/corda.git
synced 2025-03-10 22:44:20 +00:00
Refactor common code in Error and KillFlowTransition
This commit is contained in:
parent
0858b7852d
commit
d3de729390
@ -41,28 +41,14 @@ class ErrorFlowTransition(
|
||||
return builder {
|
||||
// If we're errored and propagating do the actual propagation and update the index.
|
||||
if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) {
|
||||
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
|
||||
startingState.checkpoint.checkpointState.sessions,
|
||||
errorMessages
|
||||
)
|
||||
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
|
||||
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
|
||||
var currentSeqNumber = sessionState.nextSendingSeqNumber
|
||||
val errorsWithId = errorMessages.map { errorMsg ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSeqNumber++
|
||||
Pair(messageIdentifier, errorMsg)
|
||||
}.toList()
|
||||
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
|
||||
Pair(sessionState, errorsWithId)
|
||||
}.toMap()
|
||||
val (propagateErrorsAction, newSessionStates) = sendAndBufferErrorMessages(startingState, errorMessages)
|
||||
|
||||
val newCheckpoint = startingState.checkpoint.copy(
|
||||
errorState = errorState.copy(propagatedIndex = allErrors.size),
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates)
|
||||
)
|
||||
currentState = currentState.copy(checkpoint = newCheckpoint)
|
||||
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
|
||||
actions += propagateErrorsAction
|
||||
}
|
||||
|
||||
// If we're errored but not propagating keep processing events.
|
||||
@ -136,37 +122,44 @@ class ErrorFlowTransition(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Buffers errors message for initiating states and filters the initiated states.
|
||||
* Returns a pair that consists of:
|
||||
* - a map containing the initiated states as filtered from the ones provided as input.
|
||||
* - a map containing the new state of all the sessions.
|
||||
*/
|
||||
private fun bufferErrorMessagesInInitiatingSessions(
|
||||
sessions: Map<SessionId, SessionState>,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessionStates = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
var currentSequenceNumber = sessionState.nextSendingSeqNumber
|
||||
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSequenceNumber++
|
||||
messageIdentifier to errorMessage
|
||||
companion object {
|
||||
/**
|
||||
* Buffers errors message for initiating states and creates the necessary actions to propagate errors to initiated ones.
|
||||
* Returns a pair that consists of:
|
||||
* - a [Action.PropagateErrors] action that contains the error messages to be sent to initiated sessions.
|
||||
* - a map containing the new state of all the sessions.
|
||||
*/
|
||||
fun sendAndBufferErrorMessages(
|
||||
currentState: StateMachineState,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<Action.PropagateErrors, Map<SessionId, SessionState>> {
|
||||
val errorsToPropagatePerSession = mutableMapOf<SessionState.Initiated, List<Pair<MessageIdentifier, ErrorSessionMessage>>>()
|
||||
val newSessionStates = currentState.checkpoint.checkpointState.sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
val errorsWithId = errorMessages.mapIndexed { idx, errorMsg ->
|
||||
val messageId = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(),
|
||||
sessionState.nextSendingSeqNumber+idx, currentState.checkpoint.checkpointState.suspensionTime)
|
||||
messageId to errorMsg
|
||||
}
|
||||
|
||||
sessionState.bufferMessages(errorsWithId)
|
||||
}
|
||||
// if we have already received error message from the other side, we don't propagate errors to this session.
|
||||
else if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
|
||||
val errorsWithId = errorMessages.mapIndexed { idx, errorMsg ->
|
||||
val messageId = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId,
|
||||
sessionState.nextSendingSeqNumber+idx, currentState.checkpoint.checkpointState.suspensionTime)
|
||||
messageId to errorMsg
|
||||
}.toList()
|
||||
|
||||
errorsToPropagatePerSession[sessionState] = errorsWithId
|
||||
sessionState.copy(nextSendingSeqNumber = sessionState.nextSendingSeqNumber + errorMessages.size)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
sessionState.bufferMessages(errorMessagesWithDeduplication)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
|
||||
return Pair(Action.PropagateErrors(errorsToPropagatePerSession, currentState.senderUUID), newSessionStates)
|
||||
}
|
||||
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
|
||||
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
|
||||
sessionId to sessionState
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}.toMap()
|
||||
return Pair(initiatedSessions, newSessionStates)
|
||||
}
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package net.corda.node.services.statemachine.transitions
|
||||
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.KilledFlowException
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.statemachine.Action
|
||||
import net.corda.node.services.statemachine.Checkpoint
|
||||
import net.corda.node.services.statemachine.ErrorSessionMessage
|
||||
@ -10,9 +9,6 @@ import net.corda.node.services.statemachine.Event
|
||||
import net.corda.node.services.statemachine.FlowError
|
||||
import net.corda.node.services.statemachine.FlowRemovalReason
|
||||
import net.corda.node.services.statemachine.FlowState
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.node.services.statemachine.SessionState
|
||||
import net.corda.node.services.statemachine.StateMachineState
|
||||
|
||||
class KilledFlowTransition(
|
||||
@ -28,27 +24,12 @@ class KilledFlowTransition(
|
||||
val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError)
|
||||
val errorMessages = listOf(killedFlowErrorMessage)
|
||||
|
||||
val (initiatedSessions, newSessionStates) = bufferErrorMessagesInInitiatingSessions(
|
||||
startingState.checkpoint.checkpointState.sessions,
|
||||
errorMessages
|
||||
)
|
||||
|
||||
val sessionsWithAdvancedSeqNumbers = mutableMapOf<SessionId, SessionState>()
|
||||
val errorsPerSession = initiatedSessions.map { (sessionId, sessionState) ->
|
||||
var currentSeqNumber = sessionState.nextSendingSeqNumber
|
||||
val errorsWithId = errorMessages.map { errorMsg ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sessionState.peerSinkSessionId, currentSeqNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSeqNumber++
|
||||
Pair(messageIdentifier, errorMsg)
|
||||
}.toList()
|
||||
sessionsWithAdvancedSeqNumbers[sessionId] = sessionState.copy(nextSendingSeqNumber = currentSeqNumber)
|
||||
Pair(sessionState, errorsWithId)
|
||||
}.toMap()
|
||||
val (propagateErrorsAction, newSessionStates) = ErrorFlowTransition.sendAndBufferErrorMessages(startingState, errorMessages)
|
||||
|
||||
val newCheckpoint = startingState.checkpoint.copy(
|
||||
status = Checkpoint.FlowStatus.KILLED,
|
||||
flowState = FlowState.Finished,
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates + sessionsWithAdvancedSeqNumbers)
|
||||
checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessionStates)
|
||||
)
|
||||
|
||||
currentState = currentState.copy(
|
||||
@ -58,7 +39,7 @@ class KilledFlowTransition(
|
||||
isRemoved = true
|
||||
)
|
||||
|
||||
actions += Action.PropagateErrors(errorsPerSession, startingState.senderUUID)
|
||||
actions += propagateErrorsAction
|
||||
|
||||
if (!startingState.isFlowResumed) {
|
||||
actions += Action.CreateTransaction
|
||||
@ -110,40 +91,6 @@ class KilledFlowTransition(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Buffers errors message for initiating states and filters the initiated states.
|
||||
* Returns a pair that consists of:
|
||||
* - a map containing the initiated states as filtered from the ones provided as input.
|
||||
* - a map containing the new state of all the sessions.
|
||||
*/
|
||||
private fun bufferErrorMessagesInInitiatingSessions(
|
||||
sessions: Map<SessionId, SessionState>,
|
||||
errorMessages: List<ErrorSessionMessage>
|
||||
): Pair<Map<SessionId, SessionState.Initiated>, Map<SessionId, SessionState>> {
|
||||
val newSessions = sessions.mapValues { (sourceSessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiating && sessionState.rejectionError == null) {
|
||||
var currentSequenceNumber = sessionState.nextSendingSeqNumber
|
||||
val errorMessagesWithDeduplication = errorMessages.map { errorMessage ->
|
||||
val messageIdentifier = MessageIdentifier(MessageType.SESSION_ERROR, sessionState.shardId, sourceSessionId.calculateInitiatedSessionId(), currentSequenceNumber, startingState.checkpoint.checkpointState.suspensionTime)
|
||||
currentSequenceNumber++
|
||||
messageIdentifier to errorMessage
|
||||
}
|
||||
sessionState.bufferMessages(errorMessagesWithDeduplication)
|
||||
} else {
|
||||
sessionState
|
||||
}
|
||||
}
|
||||
// if we have already received error message from the other side, we don't include that session in the list to avoid propagating errors.
|
||||
val initiatedSessions = sessions.mapNotNull { (sessionId, sessionState) ->
|
||||
if (sessionState is SessionState.Initiated && !sessionState.otherSideErrored) {
|
||||
sessionId to sessionState
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}.toMap()
|
||||
return Pair(initiatedSessions, newSessions)
|
||||
}
|
||||
|
||||
private fun createKilledRemovalReason(error: FlowError): FlowRemovalReason.ErrorFinish {
|
||||
return FlowRemovalReason.ErrorFinish(listOf(error))
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user