CORDA-3899 Refactor flow's transient fields (#6441)

Refactor `FlowStateMachineImpl.transientValues` and
`FlowStateMachineImpl.transientState` to stop the fields from exposing
the fact that they are nullable.

This is done by having private backing fields `transientValuesReference`
and `transientStateReference` that can be null. The nullability is still
needed due to serialisation and deserialisation of flow fibers. The
fields are transient and therefore will be null when reloaded from the
database.

Getters and setters hide the private field, allowing a non-null field to
be returned.

Outside of `FlowCreator`, there is no point at which the transient fields
can be null. Therefore the non-null assertions made in the getters are
valid.

Add custom Kryo serialisation and deserialisation to `TransientValues`
and `StateMachineState` that throw an `IllegalStateException`, to ensure
that neither object is ever serialised or deserialised by Kryo.
This commit is contained in:
Dan Newton 2020-07-22 16:19:20 +01:00 committed by GitHub
parent 8ee070953a
commit a41152edf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 118 additions and 92 deletions

View File

@ -69,11 +69,11 @@ class FlowCreator(
val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null
val resultFuture = openFuture<Any?>()
fiber.transientValues = TransientReference(createTransientValues(runId, resultFuture))
fiber.logic.stateMachine = fiber
verifyFlowLogicIsSuspendable(fiber.logic)
val state = createStateMachineState(checkpoint, fiber, true)
fiber.transientState = TransientReference(state)
fiber.transientValues = createTransientValues(runId, resultFuture)
fiber.transientState = state
return Flow(fiber, resultFuture)
}
@ -91,7 +91,7 @@ class FlowCreator(
// have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowStateMachineImpl.transientValues = createTransientValues(flowId, resultFuture)
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext)
val flowCorDappVersion = FlowStateMachineImpl.createSubFlowVersion(
@ -113,7 +113,7 @@ class FlowCreator(
existingCheckpoint != null,
deduplicationHandler,
senderUUID)
flowStateMachineImpl.transientState = TransientReference(state)
flowStateMachineImpl.transientState = state
return Flow(flowStateMachineImpl, resultFuture)
}

View File

@ -39,18 +39,14 @@ class FlowDefaultUncaughtExceptionHandler(
val id = fiber.id
if (!fiber.resultFuture.isDone) {
fiber.transientState.let { state ->
if (state != null) {
fiber.logger.warn("Forcing flow $id into overnight observation")
flowHospital.forceIntoOvernightObservation(state.value, listOf(throwable))
val hospitalizedCheckpoint = state.value.checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED)
val hospitalizedState = state.value.copy(checkpoint = hospitalizedCheckpoint)
fiber.transientState = TransientReference(hospitalizedState)
} else {
fiber.logger.warn("The fiber's transient state is not set, cannot force flow $id into in-memory overnight observation, status will still be updated in database")
}
fiber.logger.warn("Forcing flow $id into overnight observation")
flowHospital.forceIntoOvernightObservation(state, listOf(throwable))
val hospitalizedCheckpoint = state.checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED)
val hospitalizedState = state.copy(checkpoint = hospitalizedCheckpoint)
fiber.transientState = hospitalizedState
}
scheduledExecutor.schedule({ setFlowToHospitalizedRescheduleOnFailure(id) }, 0, TimeUnit.SECONDS)
}
scheduledExecutor.schedule({ setFlowToHospitalizedRescheduleOnFailure(id) }, 0, TimeUnit.SECONDS)
}
@Suppress("TooGenericExceptionCaught")

View File

@ -96,12 +96,12 @@ internal class FlowMonitor(
private fun FlowStateMachineImpl<*>.ioRequest() = (snapshot().checkpoint.flowState as? FlowState.Started)?.flowIORequest
private fun FlowStateMachineImpl<*>.ongoingDuration(now: Instant): Duration {
return transientState?.value?.checkpoint?.timestamp?.let { Duration.between(it, now) } ?: Duration.ZERO
return transientState.checkpoint.timestamp.let { Duration.between(it, now) } ?: Duration.ZERO
}
private fun FlowStateMachineImpl<*>.isSuspended() = !snapshot().isFlowResumed
private fun FlowStateMachineImpl<*>.isStarted() = transientState?.value?.checkpoint?.flowState is FlowState.Started
private fun FlowStateMachineImpl<*>.isStarted() = transientState.checkpoint.flowState is FlowState.Started
private operator fun StaffedFlowHospital.contains(flow: FlowStateMachine<*>) = contains(flow.id)
}

View File

@ -6,6 +6,10 @@ import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.channels.Channel
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.contracts.StateRef
@ -58,7 +62,6 @@ import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.slf4j.MDC
import java.util.concurrent.TimeUnit
import kotlin.reflect.KProperty1
class FlowPermissionException(message: String) : FlowException(message)
@ -86,52 +89,65 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
private val SERIALIZER_BLOCKER = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER").apply { isAccessible = true }.get(null)
}
override val serviceHub get() = getTransientField(TransientValues::serviceHub)
data class TransientValues(
val eventQueue: Channel<Event>,
val resultFuture: CordaFuture<Any?>,
val database: CordaPersistence,
val transitionExecutor: TransitionExecutor,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch,
val waitTimeUpdateHook: (id: StateMachineRunId, timeout: Long) -> Unit
)
val eventQueue: Channel<Event>,
val resultFuture: CordaFuture<Any?>,
val database: CordaPersistence,
val transitionExecutor: TransitionExecutor,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch,
val waitTimeUpdateHook: (id: StateMachineRunId, timeout: Long) -> Unit
) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${TransientValues::class.qualifiedName} should never be serialized")
}
internal var transientValues: TransientReference<TransientValues>? = null
internal var transientState: TransientReference<StateMachineState>? = null
/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String?
get() = transientState?.value?.senderUUID
private fun <A> getTransientField(field: KProperty1<TransientValues, A>): A {
val suppliedValues = transientValues ?: throw IllegalStateException("${field.name} wasn't supplied!")
return field.get(suppliedValues.value)
override fun read(kryo: Kryo?, input: Input?) {
throw IllegalStateException("${TransientValues::class.qualifiedName} should never be deserialized")
}
}
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
private var transientValuesReference: TransientReference<TransientValues>? = null
internal var transientValues: TransientValues
// After the flow has been created, the transient values should never be null
get() = transientValuesReference!!.value
set(values) {
check(transientValuesReference?.value == null) { "The transient values should only be set once when initialising a flow" }
transientValuesReference = TransientReference(values)
}
private var transientStateReference: TransientReference<StateMachineState>? = null
internal var transientState: StateMachineState
// After the flow has been created, the transient state should never be null
get() = transientStateReference!!.value
set(state) {
transientStateReference = TransientReference(state)
}
/**
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary.
*/
override val logger = log
override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture))
override val context: InvocationContext get() = transientState!!.value.checkpoint.checkpointState.invocationContext
override val ourIdentity: Party get() = transientState!!.value.checkpoint.checkpointState.ourIdentity
override val isKilled: Boolean get() = transientState!!.value.isKilled
override val instanceId: StateMachineInstanceId get() = StateMachineInstanceId(id, super.getId())
override val serviceHub: ServiceHubInternal get() = transientValues.serviceHub
override val stateMachine: StateMachine get() = transientValues.stateMachine
override val resultFuture: CordaFuture<R> get() = uncheckedCast(transientValues.resultFuture)
override val context: InvocationContext get() = transientState.checkpoint.checkpointState.invocationContext
override val ourIdentity: Party get() = transientState.checkpoint.checkpointState.ourIdentity
override val isKilled: Boolean get() = transientState.isKilled
/**
* What sender identifier to put on messages sent by this flow. This will either be the identifier for the current
* state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and
* the de-duplication of messages it sends should not be optimised since this could be unreliable.
*/
override val ourSenderUUID: String? get() = transientState.senderUUID
internal val softLockedStates = mutableSetOf<StateRef>()
@ -143,9 +159,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
private fun processEvent(transitionExecutor: TransitionExecutor, event: Event): FlowContinuation {
setLoggingContext()
val stateMachine = getTransientField(TransientValues::stateMachine)
val oldState = transientState!!.value
val actionExecutor = getTransientField(TransientValues::actionExecutor)
val stateMachine = transientValues.stateMachine
val oldState = transientState
val actionExecutor = transientValues.actionExecutor
val transition = stateMachine.transition(event, oldState)
val (continuation, newState) = transitionExecutor.executeTransition(this, oldState, event, transition, actionExecutor)
// Ensure that the next state that is being written to the transient state maintains the [isKilled] flag
@ -153,7 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
if (oldState.isKilled && !newState.isKilled) {
newState.isKilled = true
}
transientState = TransientReference(newState)
transientState = newState
setLoggingContext()
return continuation
}
@ -171,15 +187,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
private fun processEventsUntilFlowIsResumed(isDbTransactionOpenOnEntry: Boolean, isDbTransactionOpenOnExit: Boolean): Any? {
checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val eventQueue = getTransientField(TransientValues::eventQueue)
val transitionExecutor = transientValues.transitionExecutor
val eventQueue = transientValues.eventQueue
try {
eventLoop@ while (true) {
val nextEvent = try {
eventQueue.receive()
} catch (interrupted: InterruptedException) {
log.error("Flow interrupted while waiting for events, aborting immediately")
(transientValues?.value?.resultFuture as? OpenFuture<*>)?.setException(KilledFlowException(id))
(transientValues.resultFuture as? OpenFuture<*>)?.setException(KilledFlowException(id))
abortFiber()
}
val continuation = processEvent(transitionExecutor, nextEvent)
@ -246,7 +262,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
isDbTransactionOpenOnEntry: Boolean,
isDbTransactionOpenOnExit: Boolean): FlowContinuation {
checkDbTransaction(isDbTransactionOpenOnEntry)
val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val transitionExecutor = transientValues.transitionExecutor
val continuation = processEvent(transitionExecutor, event)
checkDbTransaction(isDbTransactionOpenOnExit)
return continuation
@ -270,7 +286,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
private fun openThreadLocalWormhole() {
val threadLocal = getTransientField(TransientValues::database).hikariPoolThreadLocal
val threadLocal = transientValues.database.hikariPoolThreadLocal
if (threadLocal != null) {
val valueFromThread = swappedOutThreadLocalValue(threadLocal)
threadLocal.set(valueFromThread)
@ -332,7 +348,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
recordDuration(startTime)
getTransientField(TransientValues::unfinishedFibers).countDown()
transientValues.unfinishedFibers.countDown()
}
@Suspendable
@ -476,7 +492,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext))
val serializationContext = TransientReference(transientValues.checkpointSerializationContext)
val transaction = extractThreadLocalTransaction()
parkAndSerialize { _, _ ->
setLoggingContext()
@ -524,13 +540,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return subFlowStack.any { IdempotentFlow::class.java.isAssignableFrom(it.flowClass) }
}
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
@Suspendable
override fun scheduleEvent(event: Event) {
getTransientField(TransientValues::eventQueue).send(event)
transientValues.eventQueue.send(event)
}
override fun snapshot(): StateMachineState {
return transientState!!.value
return transientState
}
/**
@ -538,13 +560,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
* retried.
*/
override fun updateTimedFlowTimeout(timeoutSeconds: Long) {
getTransientField(TransientValues::waitTimeUpdateHook).invoke(id, timeoutSeconds)
transientValues.waitTimeUpdateHook.invoke(id, timeoutSeconds)
}
override val stateMachine get() = getTransientField(TransientValues::stateMachine)
override val instanceId: StateMachineInstanceId get() = StateMachineInstanceId(id, super.getId())
/**
* Records the duration of this flow from call() to completion or failure.
* Note that the duration will include the time the flow spent being parked, and not just the total

View File

@ -261,14 +261,9 @@ internal class SingleThreadedStateMachineManager(
unfinishedFibers.countDown()
val state = flow.fiber.transientState
return@withLock if (state != null) {
state.value.isKilled = true
flow.fiber.scheduleEvent(Event.DoRemainingWork)
true
} else {
logger.info("Flow $id has not been initialised correctly and cannot be killed")
false
}
state.isKilled = true
flow.fiber.scheduleEvent(Event.DoRemainingWork)
true
} else {
// It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists.
database.transaction { checkpointStorage.removeCheckpoint(id) }
@ -386,7 +381,7 @@ internal class SingleThreadedStateMachineManager(
currentState.cancelFutureIfRunning()
// Get set of external events
val flowId = currentState.flowLogic.runId
val oldFlowLeftOver = innerState.withLock { flows[flowId] }?.fiber?.transientValues?.value?.eventQueue
val oldFlowLeftOver = innerState.withLock { flows[flowId] }?.fiber?.transientValues?.eventQueue
if (oldFlowLeftOver == null) {
logger.error("Unable to find flow for flow $flowId. Something is very wrong. The flow will not retry.")
return
@ -592,7 +587,7 @@ internal class SingleThreadedStateMachineManager(
): CordaFuture<FlowStateMachine<A>> {
val existingFlow = innerState.withLock { flows[flowId] }
val existingCheckpoint = if (existingFlow != null && existingFlow.fiber.transientState?.value?.isAnyCheckpointPersisted == true) {
val existingCheckpoint = if (existingFlow != null && existingFlow.fiber.transientState.isAnyCheckpointPersisted) {
// Load the flow's checkpoint
// The checkpoint will be missing if the flow failed before persisting the original checkpoint
// CORDA-3359 - Do not start/retry a flow that failed after deleting its checkpoint (the whole of the flow might replay)
@ -756,7 +751,7 @@ internal class SingleThreadedStateMachineManager(
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
private fun drainFlowEventQueue(flow: Flow<*>) {
while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
val event = flow.fiber.transientValues.eventQueue.tryReceive() ?: return
when (event) {
is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> {

View File

@ -1,5 +1,9 @@
package net.corda.node.services.statemachine
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.Destination
@ -15,6 +19,7 @@ import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.DeduplicationHandler
import java.lang.IllegalStateException
import java.time.Instant
import java.util.concurrent.Future
@ -55,7 +60,15 @@ data class StateMachineState(
@Volatile
var isKilled: Boolean,
val senderUUID: String?
)
) : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be serialized")
}
override fun read(kryo: Kryo?, input: Input?) {
throw IllegalStateException("${StateMachineState::class.qualifiedName} should never be deserialized")
}
}
/**
* @param checkpointState the state of the checkpoint

View File

@ -26,6 +26,7 @@ import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.concurrent.flatMap
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.declaredField
import net.corda.core.messaging.MessageRecipients
import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.queryBy
@ -173,9 +174,12 @@ class FlowFrameworkTests {
val flow = ReceiveFlow(bob)
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val throwingActionExecutor = SuspendThrowingActionExecutor(Exception("Thrown during suspend"),
fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
val throwingActionExecutor = SuspendThrowingActionExecutor(
Exception("Thrown during suspend"),
fiber.transientValues.actionExecutor
)
fiber.declaredField<TransientReference<FlowStateMachineImpl.TransientValues>>("transientValuesReference").value =
TransientReference(fiber.transientValues.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork()
fiber.resultFuture.getOrThrow()
assertThat(aliceNode.smm.allStateMachines).isEmpty()
@ -679,14 +683,14 @@ class FlowFrameworkTests {
SuspendingFlow.hookBeforeCheckpoint = {
val flowFiber = this as? FlowStateMachineImpl<*>
flowState = flowFiber!!.transientState!!.value.checkpoint.flowState
flowState = flowFiber!!.transientState.checkpoint.flowState
if (firstExecution) {
throw HospitalizeFlowException()
} else {
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
currentDBSession().clear() // clear session as Hibernate with fails with 'org.hibernate.NonUniqueObjectException' once it tries to save a DBFlowCheckpoint upon checkpoint
inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState!!.value.checkpoint.status
inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState.checkpoint.status
futureFiber.complete(flowFiber)
}
@ -701,7 +705,7 @@ class FlowFrameworkTests {
}
// flow is in hospital
assertTrue(flowState is FlowState.Unstarted)
val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState?.value?.checkpoint?.status
val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState.checkpoint.status
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus)
aliceNode.database.transaction {
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
@ -727,13 +731,13 @@ class FlowFrameworkTests {
SuspendingFlow.hookAfterCheckpoint = {
val flowFiber = this as? FlowStateMachineImpl<*>
flowState = flowFiber!!.transientState!!.value.checkpoint.flowState
flowState = flowFiber!!.transientState.checkpoint.flowState
if (firstExecution) {
throw HospitalizeFlowException()
} else {
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber.transientState!!.value.checkpoint.status
inMemoryCheckpointStatus = flowFiber.transientState.checkpoint.status
futureFiber.complete(flowFiber)
}
@ -820,7 +824,7 @@ class FlowFrameworkTests {
} else {
val flowFiber = this as? FlowStateMachineImpl<*>
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber!!.transientState!!.value.checkpoint.status
inMemoryCheckpointStatus = flowFiber!!.transientState.checkpoint.status
persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails
}
}