CORDA-3850 Add a per flow lock (#6437)

Add a lock to `StateMachineState`, allowing every flow to lock
themselves when performing a transition or when an external thread (such
as `killFlow`) tries to interact with a flow from occurring at the same
time.

Doing this prevents race-conditions where the external threads mutate
the database or the flow's state causing an in-flight transition to
fail.

A `Semaphore` is used to acquire and release the lock. A `ReentrantLock`
is not used as it is possible for a flow to suspend while locked, and
resume on a different thread. This causes a `ReentrantLock` to fail when
releasing the lock because the thread doing so is not the thread holding
the lock. `Semaphore`s can be used across threads, therefore bypassing
this issue.

The lock is copied across when a flow is retried. This is to prevent
another thread from interacting with a flow just after it has been
retried. Without copying the lock, the external thread would acquire the
old lock and execute, while the fiber thread acquires the new lock and
also executes.
This commit is contained in:
Dan Newton 2020-08-06 09:51:42 +01:00 committed by GitHub
parent fd374bfc6d
commit a73dad00e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 40 deletions

View File

@ -25,6 +25,7 @@ import net.corda.node.utilities.isEnabledTimedFlow
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import java.security.SecureRandom import java.security.SecureRandom
import java.util.concurrent.Semaphore
class Flow<A>(val fiber: FlowStateMachineImpl<A>, val resultFuture: OpenFuture<Any?>) class Flow<A>(val fiber: FlowStateMachineImpl<A>, val resultFuture: OpenFuture<Any?>)
@ -71,22 +72,23 @@ class FlowCreator(
fun createFlowFromCheckpoint( fun createFlowFromCheckpoint(
runId: StateMachineRunId, runId: StateMachineRunId,
oldCheckpoint: Checkpoint, oldCheckpoint: Checkpoint,
reloadCheckpointAfterSuspendCount: Int? = null reloadCheckpointAfterSuspendCount: Int? = null,
lock: Semaphore = Semaphore(1)
): Flow<*>? { ): Flow<*>? {
val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE) val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
fiber.logic.stateMachine = fiber fiber.logic.stateMachine = fiber
verifyFlowLogicIsSuspendable(fiber.logic) verifyFlowLogicIsSuspendable(fiber.logic)
val state = createStateMachineState( fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = createStateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
fiber = fiber, fiber = fiber,
anyCheckpointPersisted = true, anyCheckpointPersisted = true,
reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount
?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null ?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null,
lock = lock
) )
fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = state
return Flow(fiber, resultFuture) return Flow(fiber, resultFuture)
} }
@ -125,6 +127,7 @@ class FlowCreator(
fiber = flowStateMachineImpl, fiber = flowStateMachineImpl,
anyCheckpointPersisted = existingCheckpoint != null, anyCheckpointPersisted = existingCheckpoint != null,
reloadCheckpointAfterSuspendCount = if (reloadCheckpointAfterSuspend) 0 else null, reloadCheckpointAfterSuspendCount = if (reloadCheckpointAfterSuspend) 0 else null,
lock = Semaphore(1),
deduplicationHandler = deduplicationHandler, deduplicationHandler = deduplicationHandler,
senderUUID = senderUUID senderUUID = senderUUID
) )
@ -196,6 +199,7 @@ class FlowCreator(
fiber: FlowStateMachineImpl<*>, fiber: FlowStateMachineImpl<*>,
anyCheckpointPersisted: Boolean, anyCheckpointPersisted: Boolean,
reloadCheckpointAfterSuspendCount: Int?, reloadCheckpointAfterSuspendCount: Int?,
lock: Semaphore,
deduplicationHandler: DeduplicationHandler? = null, deduplicationHandler: DeduplicationHandler? = null,
senderUUID: String? = null senderUUID: String? = null
): StateMachineState { ): StateMachineState {
@ -211,7 +215,8 @@ class FlowCreator(
isKilled = false, isKilled = false,
flowLogic = fiber.logic, flowLogic = fiber.logic,
senderUUID = senderUUID, senderUUID = senderUUID,
reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount,
lock = lock
) )
} }
} }

View File

@ -155,6 +155,16 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
internal val softLockedStates = mutableSetOf<StateRef>() internal val softLockedStates = mutableSetOf<StateRef>()
internal inline fun <RESULT> withFlowLock(block: FlowStateMachineImpl<R>.() -> RESULT): RESULT {
transientState.lock.acquire()
return try {
block(this)
} finally {
transientState.lock.release()
}
}
/** /**
* Processes an event by creating the associated transition and executing it using the given executor. * Processes an event by creating the associated transition and executing it using the given executor.
* Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately] * Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately]
@ -162,20 +172,23 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/ */
@Suspendable @Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation { private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
setLoggingContext() return withFlowLock {
val stateMachine = transientValues.stateMachine setLoggingContext()
val oldState = transientState val stateMachine = transientValues.stateMachine
val actionExecutor = transientValues.actionExecutor val oldState = transientState
val transition = stateMachine.transition(event, oldState) val actionExecutor = transientValues.actionExecutor
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor) val transition = stateMachine.transition(event, oldState)
// Ensure that the next state that is being written to the transient state maintains the [isKilled] flag val (continuation, newState) = transitionExecutor.executeTransition(
// This condition can be met if a flow is killed during [TransitionExecutor.executeTransition] this,
if (oldState.isKilled && !newState.isKilled) { oldState,
newState.isKilled = true event,
transition,
actionExecutor
)
transientState = newState
setLoggingContext()
continuation
} }
transientState = newState
setLoggingContext()
return continuation
} }
/** /**

View File

@ -19,6 +19,7 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.mapError import net.corda.core.internal.concurrent.mapError
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.mapNotNull import net.corda.core.internal.mapNotNull
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
@ -72,6 +73,14 @@ internal class SingleThreadedStateMachineManager(
) : StateMachineManager, StateMachineManagerInternal { ) : StateMachineManager, StateMachineManagerInternal {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
private val VALID_KILL_FLOW_STATUSES = setOf(
Checkpoint.FlowStatus.RUNNABLE,
Checkpoint.FlowStatus.FAILED,
Checkpoint.FlowStatus.COMPLETED,
Checkpoint.FlowStatus.HOSPITALIZED,
Checkpoint.FlowStatus.PAUSED
)
} }
private val innerState = StateMachineInnerStateImpl() private val innerState = StateMachineInnerStateImpl()
@ -102,6 +111,26 @@ internal class SingleThreadedStateMachineManager(
private val totalStartedFlows = metrics.counter("Flows.Started") private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished") private val totalFinishedFlows = metrics.counter("Flows.Finished")
private inline fun <R> Flow<R>.withFlowLock(
validStatuses: Set<Checkpoint.FlowStatus>,
block: FlowStateMachineImpl<R>.() -> Boolean
): Boolean {
if (!fiber.hasValidStatus(validStatuses)) return false
return fiber.withFlowLock {
// Get the flow again, in case another thread removed it from the map
innerState.withLock {
flows[id]?.run {
if (!fiber.hasValidStatus(validStatuses)) return false
block(uncheckedCast(this.fiber))
}
} ?: false
}
}
private fun FlowStateMachineImpl<*>.hasValidStatus(validStatuses: Set<Checkpoint.FlowStatus>): Boolean {
return transientState.checkpoint.status in validStatuses
}
/** /**
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
* which may change across restarts. * which may change across restarts.
@ -239,9 +268,9 @@ internal class SingleThreadedStateMachineManager(
} }
override fun killFlow(id: StateMachineRunId): Boolean { override fun killFlow(id: StateMachineRunId): Boolean {
val killFlowResult = innerState.withLock { val flow = innerState.withLock { flows[id] }
val flow = flows[id] val killFlowResult = if (flow != null) {
if (flow != null) { flow.withFlowLock(VALID_KILL_FLOW_STATUSES) {
logger.info("Killing flow $id known to this node.") logger.info("Killing flow $id known to this node.")
// The checkpoint and soft locks are removed here instead of relying on the processing of the next event after setting // The checkpoint and soft locks are removed here instead of relying on the processing of the next event after setting
// the killed flag. This is to ensure a flow can be removed from the database, even if it is stuck in a infinite loop. // the killed flag. This is to ensure a flow can be removed from the database, even if it is stuck in a infinite loop.
@ -249,24 +278,19 @@ internal class SingleThreadedStateMachineManager(
checkpointStorage.removeCheckpoint(id) checkpointStorage.removeCheckpoint(id)
serviceHub.vaultService.softLockRelease(id.uuid) serviceHub.vaultService.softLockRelease(id.uuid)
} }
// the same code is NOT done in remove flow when an error occurs
// what is the point of this latch?
unfinishedFibers.countDown() unfinishedFibers.countDown()
val state = flow.fiber.transientState flow.fiber.transientState = flow.fiber.transientState.copy(isKilled = true)
state.isKilled = true scheduleEvent(Event.DoRemainingWork)
flow.fiber.scheduleEvent(Event.DoRemainingWork)
true true
} else {
// It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) }
} }
}
return if (killFlowResult) {
true
} else { } else {
flowHospital.dropSessionInit(id) // It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) }
} }
return killFlowResult || flowHospital.dropSessionInit(id)
} }
private fun markAllFlowsAsPaused() { private fun markAllFlowsAsPaused() {
@ -390,7 +414,12 @@ internal class SingleThreadedStateMachineManager(
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return
// Resurrect flow // Resurrect flow
flowCreator.createFlowFromCheckpoint(flowId, checkpoint, currentState.reloadCheckpointAfterSuspendCount) ?: return flowCreator.createFlowFromCheckpoint(
flowId,
checkpoint,
currentState.reloadCheckpointAfterSuspendCount,
currentState.lock
) ?: return
} else { } else {
// Just flow initiation message // Just flow initiation message
null null

View File

@ -22,6 +22,7 @@ import net.corda.node.services.messaging.DeduplicationHandler
import java.lang.IllegalStateException import java.lang.IllegalStateException
import java.time.Instant import java.time.Instant
import java.util.concurrent.Future import java.util.concurrent.Future
import java.util.concurrent.Semaphore
/** /**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is * The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
@ -41,9 +42,12 @@ import java.util.concurrent.Future
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further * @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
* work. * work.
* @param isKilled true if the flow has been marked as killed. This is used to cause a flow to move to a killed flow transition no matter * @param isKilled true if the flow has been marked as killed. This is used to cause a flow to move to a killed flow transition no matter
* what event it is set to process next. [isKilled] is a `var` and set as [Volatile] to prevent concurrency errors that can occur if a flow * what event it is set to process next.
* is killed during the middle of a state transition.
* @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking. * @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking.
* @param reloadCheckpointAfterSuspendCount The number of times a flow has been reloaded (not retried). This is [null] when
* [NodeConfiguration.reloadCheckpointAfterSuspendCount] is not enabled.
* @param lock The flow's lock, used to prevent the flow performing a transition while being interacted with from external threads, and
* vise-versa.
*/ */
// TODO perhaps add a read-only environment to the state machine for things that don't change over time? // TODO perhaps add a read-only environment to the state machine for things that don't change over time?
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do. // TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
@ -57,10 +61,10 @@ data class StateMachineState(
val isAnyCheckpointPersisted: Boolean, val isAnyCheckpointPersisted: Boolean,
val isStartIdempotent: Boolean, val isStartIdempotent: Boolean,
val isRemoved: Boolean, val isRemoved: Boolean,
@Volatile val isKilled: Boolean,
var isKilled: Boolean,
val senderUUID: String?, val senderUUID: String?,
val reloadCheckpointAfterSuspendCount: Int? val reloadCheckpointAfterSuspendCount: Int?,
val lock: Semaphore
) : KryoSerializable { ) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) { override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized") throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized")