From a73dad00e2feb672147ff17e97f84fd699097e12 Mon Sep 17 00:00:00 2001 From: Dan Newton Date: Thu, 6 Aug 2020 09:51:42 +0100 Subject: [PATCH] CORDA-3850 Add a per flow lock (#6437) Add a lock to `StateMachineState`, allowing a flow to be locked both while it performs a transition and while an external thread (such as one calling `killFlow`) interacts with it, preventing the two from occurring at the same time. Doing this prevents race conditions where external threads mutate the database or the flow's state, causing an in-flight transition to fail. A `Semaphore` is used to acquire and release the lock. A `ReentrantLock` is not used as it is possible for a flow to suspend while locked, and resume on a different thread. This causes a `ReentrantLock` to fail when releasing the lock because the thread doing so is not the thread holding the lock. `Semaphore`s can be used across threads, therefore bypassing this issue. The lock is copied across when a flow is retried. This is to prevent another thread from interacting with a flow just after it has been retried. Without copying the lock, the external thread would acquire the old lock and execute, while the fiber thread acquires the new lock and also executes. 
--- .../node/services/statemachine/FlowCreator.kt | 17 ++++-- .../statemachine/FlowStateMachineImpl.kt | 39 ++++++++---- .../SingleThreadedStateMachineManager.kt | 61 ++++++++++++++----- .../statemachine/StateMachineState.kt | 14 +++-- 4 files changed, 91 insertions(+), 40 deletions(-) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt index 8e388b8d35..ae0affd110 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt @@ -25,6 +25,7 @@ import net.corda.node.utilities.isEnabledTimedFlow import net.corda.nodeapi.internal.persistence.CordaPersistence import org.apache.activemq.artemis.utils.ReusableLatch import java.security.SecureRandom +import java.util.concurrent.Semaphore class Flow(val fiber: FlowStateMachineImpl, val resultFuture: OpenFuture) @@ -71,22 +72,23 @@ class FlowCreator( fun createFlowFromCheckpoint( runId: StateMachineRunId, oldCheckpoint: Checkpoint, - reloadCheckpointAfterSuspendCount: Int? = null + reloadCheckpointAfterSuspendCount: Int? = null, + lock: Semaphore = Semaphore(1) ): Flow<*>? 
{ val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE) val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null val resultFuture = openFuture() fiber.logic.stateMachine = fiber verifyFlowLogicIsSuspendable(fiber.logic) - val state = createStateMachineState( + fiber.transientValues = createTransientValues(runId, resultFuture) + fiber.transientState = createStateMachineState( checkpoint = checkpoint, fiber = fiber, anyCheckpointPersisted = true, reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount - ?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null + ?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null, + lock = lock ) - fiber.transientValues = createTransientValues(runId, resultFuture) - fiber.transientState = state return Flow(fiber, resultFuture) } @@ -125,6 +127,7 @@ class FlowCreator( fiber = flowStateMachineImpl, anyCheckpointPersisted = existingCheckpoint != null, reloadCheckpointAfterSuspendCount = if (reloadCheckpointAfterSuspend) 0 else null, + lock = Semaphore(1), deduplicationHandler = deduplicationHandler, senderUUID = senderUUID ) @@ -196,6 +199,7 @@ class FlowCreator( fiber: FlowStateMachineImpl<*>, anyCheckpointPersisted: Boolean, reloadCheckpointAfterSuspendCount: Int?, + lock: Semaphore, deduplicationHandler: DeduplicationHandler? = null, senderUUID: String? 
= null ): StateMachineState { @@ -211,7 +215,8 @@ class FlowCreator( isKilled = false, flowLogic = fiber.logic, senderUUID = senderUUID, - reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount + reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount, + lock = lock ) } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 6b0ad10698..b157b0d575 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -155,6 +155,16 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, internal val softLockedStates = mutableSetOf() + internal inline fun withFlowLock(block: FlowStateMachineImpl.() -> RESULT): RESULT { + transientState.lock.acquire() + return try { + block(this) + } finally { + transientState.lock.release() + } + } + + /** * Processes an event by creating the associated transition and executing it using the given executor. 
* Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately] @@ -162,20 +172,23 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, */ @Suspendable private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation { - setLoggingContext() - val stateMachine = transientValues.stateMachine - val oldState = transientState - val actionExecutor = transientValues.actionExecutor - val transition = stateMachine.transition(event, oldState) - val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor) - // Ensure that the next state that is being written to the transient state maintains the [isKilled] flag - // This condition can be met if a flow is killed during [TransitionExecutor.executeTransition] - if (oldState.isKilled && !newState.isKilled) { - newState.isKilled = true + return withFlowLock { + setLoggingContext() + val stateMachine = transientValues.stateMachine + val oldState = transientState + val actionExecutor = transientValues.actionExecutor + val transition = stateMachine.transition(event, oldState) + val (continuation, newState) = transitionExecutor.executeTransition( + this, + oldState, + event, + transition, + actionExecutor + ) + transientState = newState + setLoggingContext() + continuation } - transientState = newState - setLoggingContext() - return continuation } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index d1aa9ee6aa..738b94e1ab 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -19,6 +19,7 @@ import net.corda.core.internal.concurrent.map import 
net.corda.core.internal.concurrent.mapError import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.mapNotNull +import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.CheckpointSerializationContext @@ -72,6 +73,14 @@ internal class SingleThreadedStateMachineManager( ) : StateMachineManager, StateMachineManagerInternal { companion object { private val logger = contextLogger() + + private val VALID_KILL_FLOW_STATUSES = setOf( + Checkpoint.FlowStatus.RUNNABLE, + Checkpoint.FlowStatus.FAILED, + Checkpoint.FlowStatus.COMPLETED, + Checkpoint.FlowStatus.HOSPITALIZED, + Checkpoint.FlowStatus.PAUSED + ) } private val innerState = StateMachineInnerStateImpl() @@ -102,6 +111,26 @@ internal class SingleThreadedStateMachineManager( private val totalStartedFlows = metrics.counter("Flows.Started") private val totalFinishedFlows = metrics.counter("Flows.Finished") + private inline fun Flow.withFlowLock( + validStatuses: Set, + block: FlowStateMachineImpl.() -> Boolean + ): Boolean { + if (!fiber.hasValidStatus(validStatuses)) return false + return fiber.withFlowLock { + // Get the flow again, in case another thread removed it from the map + innerState.withLock { + flows[id]?.run { + if (!fiber.hasValidStatus(validStatuses)) return false + block(uncheckedCast(this.fiber)) + } + } ?: false + } + } + + private fun FlowStateMachineImpl<*>.hasValidStatus(validStatuses: Set): Boolean { + return transientState.checkpoint.status in validStatuses + } + /** * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number * which may change across restarts. 
@@ -239,9 +268,9 @@ internal class SingleThreadedStateMachineManager( } override fun killFlow(id: StateMachineRunId): Boolean { - val killFlowResult = innerState.withLock { - val flow = flows[id] - if (flow != null) { + val flow = innerState.withLock { flows[id] } + val killFlowResult = if (flow != null) { + flow.withFlowLock(VALID_KILL_FLOW_STATUSES) { logger.info("Killing flow $id known to this node.") // The checkpoint and soft locks are removed here instead of relying on the processing of the next event after setting // the killed flag. This is to ensure a flow can be removed from the database, even if it is stuck in a infinite loop. @@ -249,24 +278,19 @@ internal class SingleThreadedStateMachineManager( checkpointStorage.removeCheckpoint(id) serviceHub.vaultService.softLockRelease(id.uuid) } - // the same code is NOT done in remove flow when an error occurs - // what is the point of this latch? + unfinishedFibers.countDown() - val state = flow.fiber.transientState - state.isKilled = true - flow.fiber.scheduleEvent(Event.DoRemainingWork) + flow.fiber.transientState = flow.fiber.transientState.copy(isKilled = true) + scheduleEvent(Event.DoRemainingWork) true - } else { - // It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists. - database.transaction { checkpointStorage.removeCheckpoint(id) } } - } - return if (killFlowResult) { - true } else { - flowHospital.dropSessionInit(id) + // It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists. 
+ database.transaction { checkpointStorage.removeCheckpoint(id) } } + + return killFlowResult || flowHospital.dropSessionInit(id) } private fun markAllFlowsAsPaused() { @@ -390,7 +414,12 @@ internal class SingleThreadedStateMachineManager( val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return // Resurrect flow - flowCreator.createFlowFromCheckpoint(flowId, checkpoint, currentState.reloadCheckpointAfterSuspendCount) ?: return + flowCreator.createFlowFromCheckpoint( + flowId, + checkpoint, + currentState.reloadCheckpointAfterSuspendCount, + currentState.lock + ) ?: return } else { // Just flow initiation message null diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index 5d8326b668..c94e38187a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -22,6 +22,7 @@ import net.corda.node.services.messaging.DeduplicationHandler import java.lang.IllegalStateException import java.time.Instant import java.util.concurrent.Future +import java.util.concurrent.Semaphore /** * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is @@ -41,9 +42,12 @@ import java.util.concurrent.Future * @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further * work. * @param isKilled true if the flow has been marked as killed. This is used to cause a flow to move to a killed flow transition no matter - * what event it is set to process next. [isKilled] is a `var` and set as [Volatile] to prevent concurrency errors that can occur if a flow - * is killed during the middle of a state transition. + * what event it is set to process next. 
* @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking. + * @param reloadCheckpointAfterSuspendCount The number of times a flow has been reloaded (not retried). This is [null] when + * [NodeConfiguration.reloadCheckpointAfterSuspendCount] is not enabled. + * @param lock The flow's lock, used to prevent the flow performing a transition while being interacted with from external threads, and + * vice versa. */ // TODO perhaps add a read-only environment to the state machine for things that don't change over time? // TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do. @@ -57,10 +61,10 @@ data class StateMachineState( val isAnyCheckpointPersisted: Boolean, val isStartIdempotent: Boolean, val isRemoved: Boolean, - @Volatile - var isKilled: Boolean, + val isKilled: Boolean, val senderUUID: String?, - val reloadCheckpointAfterSuspendCount: Int? + val reloadCheckpointAfterSuspendCount: Int?, + val lock: Semaphore ) : KryoSerializable { override fun write(kryo: Kryo?, output: Output?) { throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized")