CORDA-3850 Add a per flow lock ()

Add a lock to `StateMachineState`, allowing every flow to lock
themselves when performing a transition or when an external thread (such
as `killFlow`) tries to interact with a flow from occurring at the same
time.

Doing this prevents race-conditions where the external threads mutate
the database or the flow's state causing an in-flight transition to
fail.

A `Semaphore` is used to acquire and release the lock. A `ReentrantLock`
is not used as it is possible for a flow to suspend while locked, and
resume on a different thread. This causes a `ReentrantLock` to fail when
releasing the lock because the thread doing so is not the thread holding
the lock. `Semaphore`s can be used across threads, therefore bypassing
this issue.

The lock is copied across when a flow is retried. This is to prevent
another thread from interacting with a flow just after it has been
retried. Without copying the lock, the external thread would acquire the
old lock and execute, while the fiber thread acquires the new lock and
also executes.
This commit is contained in:
Dan Newton 2020-08-06 09:51:42 +01:00 committed by GitHub
parent fd374bfc6d
commit a73dad00e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 91 additions and 40 deletions

View File

@ -25,6 +25,7 @@ import net.corda.node.utilities.isEnabledTimedFlow
import net.corda.nodeapi.internal.persistence.CordaPersistence
import org.apache.activemq.artemis.utils.ReusableLatch
import java.security.SecureRandom
import java.util.concurrent.Semaphore
class Flow<A>(val fiber: FlowStateMachineImpl<A>, val resultFuture: OpenFuture<Any?>)
@ -71,22 +72,23 @@ class FlowCreator(
fun createFlowFromCheckpoint(
runId: StateMachineRunId,
oldCheckpoint: Checkpoint,
reloadCheckpointAfterSuspendCount: Int? = null
reloadCheckpointAfterSuspendCount: Int? = null,
lock: Semaphore = Semaphore(1)
): Flow<*>? {
val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null
val resultFuture = openFuture<Any?>()
fiber.logic.stateMachine = fiber
verifyFlowLogicIsSuspendable(fiber.logic)
val state = createStateMachineState(
fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = createStateMachineState(
checkpoint = checkpoint,
fiber = fiber,
anyCheckpointPersisted = true,
reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount
?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null
?: if (reloadCheckpointAfterSuspend) checkpoint.checkpointState.numberOfSuspends else null,
lock = lock
)
fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = state
return Flow(fiber, resultFuture)
}
@ -125,6 +127,7 @@ class FlowCreator(
fiber = flowStateMachineImpl,
anyCheckpointPersisted = existingCheckpoint != null,
reloadCheckpointAfterSuspendCount = if (reloadCheckpointAfterSuspend) 0 else null,
lock = Semaphore(1),
deduplicationHandler = deduplicationHandler,
senderUUID = senderUUID
)
@ -196,6 +199,7 @@ class FlowCreator(
fiber: FlowStateMachineImpl<*>,
anyCheckpointPersisted: Boolean,
reloadCheckpointAfterSuspendCount: Int?,
lock: Semaphore,
deduplicationHandler: DeduplicationHandler? = null,
senderUUID: String? = null
): StateMachineState {
@ -211,7 +215,8 @@ class FlowCreator(
isKilled = false,
flowLogic = fiber.logic,
senderUUID = senderUUID,
reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount
reloadCheckpointAfterSuspendCount = reloadCheckpointAfterSuspendCount,
lock = lock
)
}
}

View File

@ -155,6 +155,16 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
internal val softLockedStates = mutableSetOf<StateRef>()
internal inline fun <RESULT> withFlowLock(block: FlowStateMachineImpl<R>.() -> RESULT): RESULT {
transientState.lock.acquire()
return try {
block(this)
} finally {
transientState.lock.release()
}
}
/**
* Processes an event by creating the associated transition and executing it using the given executor.
* Try to avoid using this directly, instead use [processEventsUntilFlowIsResumed] or [processEventImmediately]
@ -162,20 +172,23 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/
@Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
setLoggingContext()
val stateMachine = transientValues.stateMachine
val oldState = transientState
val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
// Ensure that the next state that is being written to the transient state maintains the [isKilled] flag
// This condition can be met if a flow is killed during [TransitionExecutor.executeTransition]
if (oldState.isKilled && !newState.isKilled) {
newState.isKilled = true
return withFlowLock {
setLoggingContext()
val stateMachine = transientValues.stateMachine
val oldState = transientState
val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(
this,
oldState,
event,
transition,
actionExecutor
)
transientState = newState
setLoggingContext()
continuation
}
transientState = newState
setLoggingContext()
return continuation
}
/**

View File

@ -19,6 +19,7 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.mapError
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.mapNotNull
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationContext
@ -72,6 +73,14 @@ internal class SingleThreadedStateMachineManager(
) : StateMachineManager, StateMachineManagerInternal {
companion object {
private val logger = contextLogger()
private val VALID_KILL_FLOW_STATUSES = setOf(
Checkpoint.FlowStatus.RUNNABLE,
Checkpoint.FlowStatus.FAILED,
Checkpoint.FlowStatus.COMPLETED,
Checkpoint.FlowStatus.HOSPITALIZED,
Checkpoint.FlowStatus.PAUSED
)
}
private val innerState = StateMachineInnerStateImpl()
@ -102,6 +111,26 @@ internal class SingleThreadedStateMachineManager(
private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished")
private inline fun <R> Flow<R>.withFlowLock(
validStatuses: Set<Checkpoint.FlowStatus>,
block: FlowStateMachineImpl<R>.() -> Boolean
): Boolean {
if (!fiber.hasValidStatus(validStatuses)) return false
return fiber.withFlowLock {
// Get the flow again, in case another thread removed it from the map
innerState.withLock {
flows[id]?.run {
if (!fiber.hasValidStatus(validStatuses)) return false
block(uncheckedCast(this.fiber))
}
} ?: false
}
}
private fun FlowStateMachineImpl<*>.hasValidStatus(validStatuses: Set<Checkpoint.FlowStatus>): Boolean {
return transientState.checkpoint.status in validStatuses
}
/**
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
* which may change across restarts.
@ -239,9 +268,9 @@ internal class SingleThreadedStateMachineManager(
}
override fun killFlow(id: StateMachineRunId): Boolean {
val killFlowResult = innerState.withLock {
val flow = flows[id]
if (flow != null) {
val flow = innerState.withLock { flows[id] }
val killFlowResult = if (flow != null) {
flow.withFlowLock(VALID_KILL_FLOW_STATUSES) {
logger.info("Killing flow $id known to this node.")
// The checkpoint and soft locks are removed here instead of relying on the processing of the next event after setting
// the killed flag. This is to ensure a flow can be removed from the database, even if it is stuck in a infinite loop.
@ -249,24 +278,19 @@ internal class SingleThreadedStateMachineManager(
checkpointStorage.removeCheckpoint(id)
serviceHub.vaultService.softLockRelease(id.uuid)
}
// the same code is NOT done in remove flow when an error occurs
// what is the point of this latch?
unfinishedFibers.countDown()
val state = flow.fiber.transientState
state.isKilled = true
flow.fiber.scheduleEvent(Event.DoRemainingWork)
flow.fiber.transientState = flow.fiber.transientState.copy(isKilled = true)
scheduleEvent(Event.DoRemainingWork)
true
} else {
// It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) }
}
}
return if (killFlowResult) {
true
} else {
flowHospital.dropSessionInit(id)
// It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) }
}
return killFlowResult || flowHospital.dropSessionInit(id)
}
private fun markAllFlowsAsPaused() {
@ -390,7 +414,12 @@ internal class SingleThreadedStateMachineManager(
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return
// Resurrect flow
flowCreator.createFlowFromCheckpoint(flowId, checkpoint, currentState.reloadCheckpointAfterSuspendCount) ?: return
flowCreator.createFlowFromCheckpoint(
flowId,
checkpoint,
currentState.reloadCheckpointAfterSuspendCount,
currentState.lock
) ?: return
} else {
// Just flow initiation message
null

View File

@ -22,6 +22,7 @@ import net.corda.node.services.messaging.DeduplicationHandler
import java.lang.IllegalStateException
import java.time.Instant
import java.util.concurrent.Future
import java.util.concurrent.Semaphore
/**
* The state of the state machine, capturing the state of a flow. It consists of two parts, an *immutable* part that is
@ -41,9 +42,12 @@ import java.util.concurrent.Future
* @param isRemoved true if the flow has been removed from the state machine manager. This is used to avoid any further
* work.
* @param isKilled true if the flow has been marked as killed. This is used to cause a flow to move to a killed flow transition no matter
* what event it is set to process next. [isKilled] is a `var` and set as [Volatile] to prevent concurrency errors that can occur if a flow
* is killed during the middle of a state transition.
* what event it is set to process next.
* @param senderUUID the identifier of the sending state machine or null if this flow is resumed from a checkpoint so that it does not participate in de-duplication high-water-marking.
* @param reloadCheckpointAfterSuspendCount The number of times a flow has been reloaded (not retried). This is [null] when
* [NodeConfiguration.reloadCheckpointAfterSuspendCount] is not enabled.
* @param lock The flow's lock, used to prevent the flow performing a transition while being interacted with from external threads, and
* vise-versa.
*/
// TODO perhaps add a read-only environment to the state machine for things that don't change over time?
// TODO evaluate persistent datastructure libraries to replace the inefficient copying we currently do.
@ -57,10 +61,10 @@ data class StateMachineState(
val isAnyCheckpointPersisted: Boolean,
val isStartIdempotent: Boolean,
val isRemoved: Boolean,
@Volatile
var isKilled: Boolean,
val isKilled: Boolean,
val senderUUID: String?,
val reloadCheckpointAfterSuspendCount: Int?
val reloadCheckpointAfterSuspendCount: Int?,
val lock: Semaphore
) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized")