CORDA-3490 Add option to start node without starting checkpointed flows (#6136)

Added command-line option: `--pause-all-flows` to the Node to control this.
This mode causes all RUNNABLE and HOSPITALIZED checkpoints to be set to status PAUSED when the
state machine starts up (in StartMode.Safe mode).

Changed the state machine so that PAUSED checkpoints are loaded into
memory (the checkpoint is deserialised but the flow state is left serialised)
but not started.

Messages from peers are queued whilst the flow is paused and processed
once the flow is resumed.
This commit is contained in:
williamvigorr3 2020-05-19 16:27:41 +01:00 committed by GitHub
parent 356172c370
commit eb52de1b40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 723 additions and 220 deletions

View File

@ -9,4 +9,4 @@ package net.corda.common.logging
* (originally added to source control for ease of use)
*/
internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT"
internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT"

View File

@ -1,5 +1,6 @@
package net.corda.core.internal.messaging
import net.corda.core.flows.StateMachineRunId
import net.corda.core.internal.AttachmentTrustInfo
import net.corda.core.messaging.CordaRPCOps
@ -13,4 +14,11 @@ interface InternalCordaRPCOps : CordaRPCOps {
/** Get all attachment trust information */
val attachmentTrustInfos: List<AttachmentTrustInfo>
/**
 * Resume a paused flow.
 *
 * @param id the run id of the paused flow to resume.
 * @return whether the flow was successfully resumed.
 */
fun unPauseFlow(id: StateMachineRunId): Boolean
}

View File

@ -0,0 +1,113 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.messaging.InternalCordaRPCOps
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.User
import org.junit.Test
import java.time.Duration
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
/**
 * Driver-based test checking that a flow restored in a paused state still has its incoming session
 * messages queued, and that those messages are processed once the flow is un-paused over RPC.
 */
class FlowPausingTest {

    companion object {
        /** Number of heartbeats [HeartbeatFlow] sends before it waits for the responder's verdict. */
        const val TOTAL_MESSAGES = 100

        /** Pause between two heartbeats; also the polling interval used by [waitForFlowToStart]. */
        const val SLEEP_BETWEEN_MESSAGES_MS = 10L
    }

    @Test(timeout = 300_000)
    fun `Paused flows can recieve session messages`() {
        // Permissions.all() already covers starting any flow, so no per-flow permission is listed.
        val rpcUser = User("demo", "demo", setOf(Permissions.all()))
        driver(DriverParameters(startNodesInProcess = true, inMemoryDB = false)) {
            val alice = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = listOf(rpcUser))).getOrThrow()
            val startedBob = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = listOf(rpcUser))).getOrThrow()
            val aliceFlow = alice.rpc.startFlow(::HeartbeatFlow, startedBob.nodeInfo.legalIdentities[0])

            // Wait for the initiated flow to start running on BOB. assertNotNull returns the non-null id,
            // so no `!!` is needed further down.
            val initiatedFlowId = assertNotNull(
                    startedBob.rpc.waitForFlowToStart(150),
                    "The initiated flow did not start on BOB in time"
            )

            /* Shut BOB down before it has finished receiving all of the heartbeats.
               This is a race: if BOB finishes too quickly there is no flow left to un-pause later
               and this test will fail. */
            startedBob.stop()

            // Start BOB back up in Safe mode. No flows will run, but BOB should receive messages and queue them up.
            val restartedBob = startNode(NodeParameters(
                    providedName = BOB_NAME,
                    rpcUsers = listOf(rpcUser),
                    customOverrides = mapOf("smmStartMode" to "Safe"))).getOrThrow()

            // Sleep for long enough that BOB has time to receive all the messages.
            // All messages in this period should be queued up and replayed when the flow is un-paused.
            Thread.sleep(TOTAL_MESSAGES * SLEEP_BETWEEN_MESSAGES_MS)

            // ALICE should not have finished yet: the paused HeartbeatResponderFlow cannot have sent the final message back.
            assertEquals(false, aliceFlow.returnValue.isDone)
            assertEquals(true, (restartedBob.rpc as InternalCordaRPCOps).unPauseFlow(initiatedFlowId))
            assertEquals(true, aliceFlow.returnValue.getOrThrow())

            alice.stop()
            restartedBob.stop()
        }
    }

    /**
     * Polls the state machine snapshot until it contains exactly one flow.
     *
     * @param maxTrys maximum number of polls; [SLEEP_BETWEEN_MESSAGES_MS] is slept after each unsuccessful poll.
     * @return the id of the single flow found, or null if no poll saw exactly one flow.
     */
    fun CordaRPCOps.waitForFlowToStart(maxTrys: Int): StateMachineRunId? {
        for (i in 1..maxTrys) {
            val snapshot = this.stateMachinesSnapshot().singleOrNull()
            if (snapshot != null) {
                return snapshot.id
            }
            Thread.sleep(SLEEP_BETWEEN_MESSAGES_MS)
        }
        return null
    }

    /**
     * Sends [TOTAL_MESSAGES] consecutive sequence numbers to [otherParty], pausing
     * [SLEEP_BETWEEN_MESSAGES_MS] after each one, then returns the Boolean the counterparty replies with.
     */
    @StartableByRPC
    @InitiatingFlow
    class HeartbeatFlow(private val otherParty: Party): FlowLogic<Boolean>() {
        var sequenceNumber = 0

        @Suspendable
        override fun call(): Boolean {
            val session = initiateFlow(otherParty)
            for (i in 1..TOTAL_MESSAGES) {
                session.send(sequenceNumber++)
                // Use the shared constant so the sender's pacing matches the wait the test computes from it.
                sleep(Duration.ofMillis(SLEEP_BETWEEN_MESSAGES_MS))
            }
            return session.receive<Boolean>().unwrap { data -> data }
        }
    }

    /**
     * Receives [TOTAL_MESSAGES] integers and replies with `true` only if every one matched the expected,
     * incrementing sequence number.
     */
    @InitiatedBy(HeartbeatFlow::class)
    class HeartbeatResponderFlow(val session: FlowSession): FlowLogic<Unit>() {
        var sequenceNumber : Int = 0

        @Suspendable
        override fun call() {
            var pass = true
            for (i in 1..TOTAL_MESSAGES) {
                val receivedSequenceNumber = session.receive<Int>().unwrap { data -> data }
                if (receivedSequenceNumber != sequenceNumber) {
                    pass = false
                }
                sequenceNumber++
            }
            session.send(pass)
        }
    }
}

View File

@ -14,6 +14,7 @@ import net.corda.node.services.config.ConfigHelper
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.Valid
import net.corda.node.services.config.parseAsNodeConfiguration
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.config.UnknownConfigKeysPolicy
import picocli.CommandLine.Option
import java.nio.file.Path
@ -48,6 +49,12 @@ open class SharedNodeCmdLineOptions {
)
var devMode: Boolean? = null
// Backs the `--pause-all-flows` flag. When true, NodeCmdLineOptions adds an `smmStartMode` config override
// set to StateMachineManager.StartMode.Safe before the configuration is loaded.
@Option(
names = ["--pause-all-flows"],
description = ["Do not run any flows on startup. Sets all flows to paused, which can be unpaused via RPC."]
)
var safeMode: Boolean = false
open fun parseConfiguration(configuration: Config): Valid<NodeConfiguration> {
val option = Configuration.Options(strict = unknownConfigKeysPolicy == UnknownConfigKeysPolicy.FAIL)
return configuration.parseAsNodeConfiguration(option)
@ -186,6 +193,9 @@ open class NodeCmdLineOptions : SharedNodeCmdLineOptions() {
devMode?.let {
configOverrides += "devMode" to it
}
if (safeMode) {
configOverrides += "smmStartMode" to StateMachineManager.StartMode.Safe.toString()
}
return try {
valid(ConfigHelper.loadConfig(baseDirectory, configFile, configOverrides = ConfigFactory.parseMap(configOverrides)))
} catch (e: ConfigException) {

View File

@ -543,7 +543,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
tokenizableServices = null
verifyCheckpointsCompatible(frozenTokenizableServices)
val smmStartedFuture = smm.start(frozenTokenizableServices)
val smmStartedFuture = smm.start(frozenTokenizableServices, configuration.smmStartMode)
// Shut down the SMM so no Fibers are scheduled.
runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) }
val flowMonitor = FlowMonitor(
@ -1379,4 +1379,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
}
// Here we're using the node's RPC key store as the RPC client's trust store.
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
}
}

View File

@ -35,7 +35,7 @@ object CheckpointVerifier {
val cordappsByHash = currentCordapps.associateBy { it.jarHash }
checkpointStorage.getRunnableCheckpoints().use {
checkpointStorage.getCheckpointsToRun().use {
it.forEach { (_, serializedCheckpoint) ->
val checkpoint = try {
serializedCheckpoint.deserialize(checkpointSerializationContext)

View File

@ -169,6 +169,8 @@ internal class CordaRPCOpsImpl(
override fun killFlow(id: StateMachineRunId): Boolean = smm.killFlow(id)
override fun unPauseFlow(id: StateMachineRunId): Boolean = smm.unPauseFlow(id)
override fun stateMachinesFeed(): DataFeed<List<StateMachineInfo>, StateMachineUpdate> {
val (allStateMachines, changes) = smm.track()

View File

@ -20,6 +20,12 @@ interface CheckpointStorage {
*/
fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?)
/**
* Update all persisted checkpoints with status [Checkpoint.FlowStatus.RUNNABLE] or [Checkpoint.FlowStatus.HOSPITALIZED],
* changing the status to [Checkpoint.FlowStatus.PAUSED].
*/
fun markAllPaused()
/**
* Remove existing checkpoint from the store.
* @return whether the id matched a checkpoint that was removed.
@ -37,14 +43,23 @@ interface CheckpointStorage {
fun getCheckpoint(id: StateMachineRunId): Checkpoint.Serialized?
/**
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
* underlying database connection is closed, so any processing should happen before it is closed.
* Stream all checkpoints with statuses [statuses] from the store. If this is backed by a database the stream will be valid
* until the underlying database connection is closed, so any processing should happen before it is closed.
*/
fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
fun getCheckpoints(
statuses: Collection<Checkpoint.FlowStatus> = Checkpoint.FlowStatus.values().toSet()
): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
/**
* Stream runnable checkpoints from the store. If this is backed by a database the stream will be valid
* until the underlying database connection is closed, so any processing should happen before it is closed.
*/
fun getRunnableCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
/**
* Stream paused checkpoints from the store. If this is backed by a database the stream will be valid
* until the underlying database connection is closed, so any processing should happen before it is closed.
* This method does not fetch [Checkpoint.Serialized.serializedFlowState] to save memory.
*/
fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
}

View File

@ -11,6 +11,7 @@ import net.corda.core.internal.notary.NotaryServiceFlow
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.config.rpc.NodeRpcOptions
import net.corda.node.services.config.schema.v1.V1NodeConfigurationSpec
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.User
@ -93,6 +94,8 @@ interface NodeConfiguration : ConfigurationWithOptionsContainer {
val quasarExcludePackages: List<String>
val smmStartMode: StateMachineManager.StartMode
companion object {
// default to at least 8MB and a bit extra for larger heap sizes
val defaultTransactionCacheSize: Long = 8.MB + getAdditionalCacheMemory()

View File

@ -8,6 +8,7 @@ import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.seconds
import net.corda.node.services.config.rpc.NodeRpcOptions
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.DEV_PUB_KEY_HASHES
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
@ -84,7 +85,8 @@ data class NodeConfigurationImpl(
override val blacklistedAttachmentSigningKeys: List<String> = Defaults.blacklistedAttachmentSigningKeys,
override val configurationWithOptions: ConfigurationWithOptions,
override val flowExternalOperationThreadPoolSize: Int = Defaults.flowExternalOperationThreadPoolSize,
override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages
override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages,
override val smmStartMode : StateMachineManager.StartMode = Defaults.smmStartMode
) : NodeConfiguration {
internal object Defaults {
val jmxMonitoringHttpPort: Int? = null
@ -123,6 +125,7 @@ data class NodeConfigurationImpl(
val blacklistedAttachmentSigningKeys: List<String> = emptyList()
const val flowExternalOperationThreadPoolSize: Int = 1
val quasarExcludePackages: List<String> = emptyList()
val smmStartMode : StateMachineManager.StartMode = StateMachineManager.StartMode.ExcludingPaused
fun cordappsDirectories(baseDirectory: Path) = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT)

View File

@ -9,6 +9,7 @@ import net.corda.common.validation.internal.Validated.Companion.valid
import net.corda.node.services.config.*
import net.corda.node.services.config.NodeConfigurationImpl.Defaults
import net.corda.node.services.config.schema.parsers.*
import net.corda.node.services.statemachine.StateMachineManager
internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfiguration>("NodeConfiguration") {
private val myLegalName by string().mapValid(::toCordaX500Name)
@ -66,6 +67,7 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig
.withDefaultValue(Defaults.networkParameterAcceptanceSettings)
private val flowExternalOperationThreadPoolSize by int().optional().withDefaultValue(Defaults.flowExternalOperationThreadPoolSize)
private val quasarExcludePackages by string().list().optional().withDefaultValue(Defaults.quasarExcludePackages)
private val smmStartMode by enum(StateMachineManager.StartMode::class).optional().withDefaultValue(Defaults.smmStartMode)
@Suppress("unused")
private val custom by nestedObject().optional()
@Suppress("unused")
@ -133,7 +135,8 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig
networkParameterAcceptanceSettings = config[networkParameterAcceptanceSettings],
configurationWithOptions = ConfigurationWithOptions(configuration, Configuration.Options.defaults),
flowExternalOperationThreadPoolSize = config[flowExternalOperationThreadPoolSize],
quasarExcludePackages = config[quasarExcludePackages]
quasarExcludePackages = config[quasarExcludePackages],
smmStartMode = config[smmStartMode]
))
} catch (e: Exception) {
return when (e) {

View File

@ -59,7 +59,7 @@ class DBCheckpointStorage(
private const val MAX_FLOW_NAME_LENGTH = 128
private const val MAX_PROGRESS_STEP_LENGTH = 256
private val NOT_RUNNABLE_CHECKPOINTS = listOf(FlowStatus.COMPLETED, FlowStatus.FAILED, FlowStatus.KILLED)
private val RUNNABLE_CHECKPOINTS = setOf(FlowStatus.RUNNABLE, FlowStatus.HOSPITALIZED)
/**
* This needs to run before Hibernate is initialised.
@ -281,6 +281,15 @@ class DBCheckpointStorage(
currentDBSession().update(updateDBCheckpoint(id, checkpoint, serializedFlowState))
}
/**
 * Issues one native SQL update that sets the status of every checkpoint whose status is in
 * [RUNNABLE_CHECKPOINTS] to [FlowStatus.PAUSED]. Statuses are compared and written by enum ordinal.
 */
override fun markAllPaused() {
    val statusesToPause = RUNNABLE_CHECKPOINTS.joinToString { it.ordinal.toString() }
    val updateStatement = "Update ${NODE_DATABASE_PREFIX}checkpoints set status = ${FlowStatus.PAUSED.ordinal} " +
            "where status in ($statusesToPause)"
    currentDBSession().createNativeQuery(updateStatement).executeUpdate()
}
override fun removeCheckpoint(id: StateMachineRunId): Boolean {
// This will be changed after performance tuning
return currentDBSession().let { session ->
@ -300,33 +309,39 @@ class DBCheckpointStorage(
return getDBCheckpoint(id)?.toSerializedCheckpoint()
}
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
val session = currentDBSession()
val criteriaQuery = session.criteriaBuilder.createQuery(DBFlowCheckpoint::class.java)
val root = criteriaQuery.from(DBFlowCheckpoint::class.java)
criteriaQuery.select(root)
return session.createQuery(criteriaQuery).stream().map {
StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint()
}
}
override fun getRunnableCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
override fun getCheckpoints(statuses: Collection<FlowStatus>): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
val session = currentDBSession()
val criteriaBuilder = session.criteriaBuilder
val criteriaQuery = criteriaBuilder.createQuery(DBFlowCheckpoint::class.java)
val root = criteriaQuery.from(DBFlowCheckpoint::class.java)
criteriaQuery.select(root)
.where(criteriaBuilder.not(root.get<FlowStatus>(DBFlowCheckpoint::status.name).`in`(NOT_RUNNABLE_CHECKPOINTS)))
.where(criteriaBuilder.isTrue(root.get<FlowStatus>(DBFlowCheckpoint::status.name).`in`(statuses)))
return session.createQuery(criteriaQuery).stream().map {
StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint()
}
}
/** Streams the checkpoints whose status is in [RUNNABLE_CHECKPOINTS], delegating to [getCheckpoints]. */
override fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> =
        getCheckpoints(RUNNABLE_CHECKPOINTS)
/** Looks up the [DBFlowCheckpoint] entity keyed by the string form of [id]'s UUID in the current session. */
@VisibleForTesting
internal fun getDBCheckpoint(id: StateMachineRunId): DBFlowCheckpoint? =
        currentDBSession().find(DBFlowCheckpoint::class.java, id.uuid.toString())
/**
 * Returns the checkpoints with status [FlowStatus.PAUSED], keyed by run id.
 *
 * Only the columns needed by [DBPausedFields] are selected; the resulting [Checkpoint.Serialized] values
 * therefore carry no serialized flow state.
 */
override fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
val session = currentDBSession()
// Constructor-expression query: each row is passed straight to the DBPausedFields constructor, so the
// selected columns must stay in the same order as that constructor's parameters.
val jpqlQuery = """select new ${DBPausedFields::class.java.name}(checkpoint.id, blob.checkpoint, checkpoint.status,
checkpoint.progressStep, checkpoint.ioRequestType, checkpoint.compatible) from ${DBFlowCheckpoint::class.java.name}
checkpoint join ${DBFlowCheckpointBlob::class.java.name} blob on checkpoint.blob = blob.id where
checkpoint.status = ${FlowStatus.PAUSED.ordinal}""".trimIndent()
val query = session.createQuery(jpqlQuery, DBPausedFields::class.java)
// The stream is taken from the query's result list rather than from a database-backed stream.
return query.resultList.stream().map {
StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint()
}
}
private fun createDBCheckpoint(
id: StateMachineRunId,
checkpoint: Checkpoint,
@ -542,6 +557,29 @@ class DBCheckpointStorage(
)
}
/**
 * The subset of checkpoint columns read for a paused flow. The serialized flow state is not among
 * them, so [toSerializedCheckpoint] always yields a checkpoint without flow state.
 */
private class DBPausedFields(
        val id: String,
        val checkpoint: ByteArray = EMPTY_BYTE_ARRAY,
        val status: FlowStatus,
        val progressStep: String?,
        val ioRequestType: String?,
        val compatible: Boolean
) {
    /** Wraps these fields in a [Checkpoint.Serialized] with no flow state, no result and a clean error state. */
    fun toSerializedCheckpoint(): Checkpoint.Serialized = Checkpoint.Serialized(
            serializedCheckpointState = SerializedBytes(checkpoint),
            serializedFlowState = null,
            // Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint
            errorState = ErrorState.Clean,
            result = null,
            status = status,
            progressStep = progressStep,
            flowIoRequest = ioRequestType,
            compatible = compatible
    )
}
/** Serializes the receiver using [SerializationDefaults.STORAGE_CONTEXT]. */
private fun <T : Any> T.storageSerialize(): SerializedBytes<T> =
        serialize(context = SerializationDefaults.STORAGE_CONTEXT)

View File

@ -90,6 +90,11 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
companion object {
internal val TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd-HHmmss").withZone(UTC)
private val log = contextLogger()
private val DUMPABLE_CHECKPOINTS = setOf(
Checkpoint.FlowStatus.RUNNABLE,
Checkpoint.FlowStatus.HOSPITALIZED,
Checkpoint.FlowStatus.PAUSED
)
}
override val priority: Int = SERVICE_PRIORITY_NORMAL
@ -141,7 +146,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
try {
if (lock.getAndIncrement() == 0 && !file.exists()) {
database.transaction {
checkpointStorage.getRunnableCheckpoints().use { stream ->
checkpointStorage.getCheckpoints(DUMPABLE_CHECKPOINTS).use { stream ->
ZipOutputStream(file.outputStream()).use { zip ->
stream.forEach { (runId, serialisedCheckpoint) ->
@ -204,7 +209,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext)
fiber to fiber.logic
}
is FlowState.Completed -> {
else -> {
throw IllegalStateException("Only runnable checkpoints with their flow stack are output by the checkpoint dumper")
}
}

View File

@ -0,0 +1,191 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.channels.Channels
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.utilities.isEnabledTimedFlow
import net.corda.nodeapi.internal.persistence.CordaPersistence
import org.apache.activemq.artemis.utils.ReusableLatch
import java.security.SecureRandom
/**
 * A flow's [fiber] paired with the [resultFuture] that was handed to that fiber's transient values.
 */
class Flow<A>(
        val fiber: FlowStateMachineImpl<A>,
        val resultFuture: OpenFuture<Any?>
)
/**
 * A flow that is represented only by its [runId] and [checkpoint], without a fiber, together with the
 * session messages recorded for it so far.
 */
class NonResidentFlow(val runId: StateMachineRunId, val checkpoint: Checkpoint) {
    /** Session messages added through [addExternalEvent], in the order they were added. */
    val externalEvents: MutableList<Event.DeliverSessionMessage> = mutableListOf()

    /** Appends [message] to [externalEvents]. */
    fun addExternalEvent(message: Event.DeliverSessionMessage) {
        externalEvents += message
    }
}
/**
 * Builds [Flow] instances (a fiber plus its result future) either from a stored [Checkpoint] or from a
 * fresh [FlowLogic], wiring each fiber with its transient values and initial [StateMachineState].
 */
class FlowCreator(
        val checkpointSerializationContext: CheckpointSerializationContext,
        private val checkpointStorage: CheckpointStorage,
        val scheduler: FiberScheduler,
        val database: CordaPersistence,
        val transitionExecutor: TransitionExecutor,
        val actionExecutor: ActionExecutor,
        val secureRandom: SecureRandom,
        val serviceHub: ServiceHubInternal,
        val unfinishedFibers: ReusableLatch,
        val resetCustomTimeout: (StateMachineRunId, Long) -> Unit
) {

    companion object {
        private val logger = contextLogger()
    }

    /**
     * Creates a [Flow] for [nonResidentFlow].
     *
     * A checkpoint with status PAUSED is re-read from [checkpointStorage], because paused checkpoints are
     * held without their serialized flow state; any other checkpoint is used as it is.
     *
     * @return the created flow, or null if the checkpoint is no longer stored or no fiber could be built from it.
     */
    fun createFlowFromNonResidentFlow(nonResidentFlow: NonResidentFlow): Flow<*>? {
        val runId = nonResidentFlow.runId
        val heldCheckpoint = nonResidentFlow.checkpoint
        if (heldCheckpoint.status != Checkpoint.FlowStatus.PAUSED) {
            return createFlowFromCheckpoint(runId, heldCheckpoint)
        }
        val stored = database.transaction { checkpointStorage.getCheckpoint(runId) }
        val reloaded = stored
                ?.copy(status = Checkpoint.FlowStatus.RUNNABLE)
                ?.deserialize(checkpointSerializationContext)
                ?: return null
        return createFlowFromCheckpoint(runId, reloaded)
    }

    /**
     * Creates a [Flow] from [oldCheckpoint], treating it as RUNNABLE and as already persisted.
     *
     * @return the created flow, or null if no fiber could be obtained from the checkpoint.
     * @throws FlowException if the restored flow logic's `call()` method is not annotated `@Suspendable`.
     */
    fun createFlowFromCheckpoint(runId: StateMachineRunId, oldCheckpoint: Checkpoint): Flow<*>? {
        val runnableCheckpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
        val restoredFiber = runnableCheckpoint.getFiberFromCheckpoint(runId) ?: return null
        val resultFuture = openFuture<Any?>()
        restoredFiber.transientValues = TransientReference(createTransientValues(runId, resultFuture))
        restoredFiber.logic.stateMachine = restoredFiber
        verifyFlowLogicIsSuspendable(restoredFiber.logic)
        val initialState = createStateMachineState(runnableCheckpoint, restoredFiber, anyCheckpointPersisted = true)
        restoredFiber.transientState = TransientReference(initialState)
        return Flow(restoredFiber, resultFuture)
    }

    /**
     * Creates a [Flow] around [flowLogic].
     *
     * @param existingCheckpoint if non-null it is reused with status RUNNABLE and the state is marked as having
     * a persisted checkpoint; otherwise a new checkpoint is created from the remaining arguments.
     * @param deduplicationHandler if non-null, becomes the single pending deduplication handler of the new state.
     */
    @Suppress("LongParameterList")
    fun <A> createFlowFromLogic(
            flowId: StateMachineRunId,
            invocationContext: InvocationContext,
            flowLogic: FlowLogic<A>,
            flowStart: FlowStart,
            ourIdentity: Party,
            existingCheckpoint: Checkpoint?,
            deduplicationHandler: DeduplicationHandler?,
            senderUUID: String?
    ): Flow<A> {
        // The fiber must be attached to the FlowLogic before the logic is frozen, so that lazy properties
        // can reach the fiber (and through it the service hub) during serialization.
        val newFiber = FlowStateMachineImpl(flowId, flowLogic, scheduler)
        val resultFuture = openFuture<Any?>()
        newFiber.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
        flowLogic.stateMachine = newFiber

        val serializedLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext)
        val subFlowVersion = FlowStateMachineImpl.createSubFlowVersion(
                serviceHub.cordappProvider.getCordappForFlow(flowLogic),
                serviceHub.myInfo.platformVersion
        )

        val checkpoint = existingCheckpoint?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: Checkpoint.create(
                invocationContext,
                flowStart,
                flowLogic.javaClass,
                serializedLogic,
                ourIdentity,
                subFlowVersion,
                flowLogic.isEnabledTimedFlow()
        ).getOrThrow()

        val initialState = createStateMachineState(
                checkpoint,
                newFiber,
                existingCheckpoint != null,
                deduplicationHandler,
                senderUUID
        )
        newFiber.transientState = TransientReference(initialState)
        return Flow(newFiber, resultFuture)
    }

    /**
     * Obtains the fiber for this checkpoint: a new fiber around the deserialized flow logic when the flow is
     * unstarted, or the deserialized frozen fiber when it is started.
     *
     * Callers rely on this returning null whenever a flow cannot be created from the checkpoint, which covers
     * both deserialization failures and any other flow state.
     */
    private fun Checkpoint.getFiberFromCheckpoint(runId: StateMachineRunId): FlowStateMachineImpl<*>? {
        val state = flowState
        return when (state) {
            is FlowState.Unstarted -> {
                val logic = tryCheckpointDeserialize(state.frozenFlowLogic, runId)
                if (logic == null) null else FlowStateMachineImpl(runId, logic, scheduler)
            }
            is FlowState.Started -> tryCheckpointDeserialize(state.frozenFiber, runId)
            else -> null
        }
    }

    /** Deserializes [bytes] with the checkpoint context; any exception is logged and turned into null. */
    @Suppress("TooGenericExceptionCaught")
    private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? =
            try {
                bytes.checkpointDeserialize(context = checkpointSerializationContext)
            } catch (e: Exception) {
                logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", e)
                null
            }

    /**
     * Fails fast when the flow's `call()` method lacks `@Suspendable`.
     *
     * Quasar requires (in Java 8) that at least the call method is annotated suspendable, and it is easy to
     * forget when writing a new flow, so this check gives the user a clearer error. The Kotlin compiler can
     * generate a synthetic bridge `call` that forwards to the real one; annotations are not copied onto that
     * bridge, hence synthetic methods are excluded from the lookup.
     */
    private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
        val callMethod = logic.javaClass.methods.first { candidate ->
            candidate.name == "call" && candidate.parameterCount == 0 && !candidate.isSynthetic
        }
        if (!callMethod.isAnnotationPresent(Suspendable::class.java)) {
            throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
        }
    }

    /** Assembles the transient values for flow [id] from this creator's collaborators and a new event queue. */
    private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
        return FlowStateMachineImpl.TransientValues(
                eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
                resultFuture = resultFuture,
                database = database,
                transitionExecutor = transitionExecutor,
                actionExecutor = actionExecutor,
                stateMachine = StateMachine(id, secureRandom),
                serviceHub = serviceHub,
                checkpointSerializationContext = checkpointSerializationContext,
                unfinishedFibers = unfinishedFibers,
                waitTimeUpdateHook = { runId, newTimeout -> resetCustomTimeout(runId, newTimeout) }
        )
    }

    /**
     * Builds the initial [StateMachineState] for [fiber]: not resumed, not waiting on a future, not removed
     * and not killed, with [deduplicationHandler] (when present) as its only pending handler.
     */
    private fun createStateMachineState(
            checkpoint: Checkpoint,
            fiber: FlowStateMachineImpl<*>,
            anyCheckpointPersisted: Boolean,
            deduplicationHandler: DeduplicationHandler? = null,
            senderUUID: String? = null
    ): StateMachineState {
        val pendingHandlers = listOfNotNull(deduplicationHandler)
        return StateMachineState(
                checkpoint = checkpoint,
                pendingDeduplicationHandlers = pendingHandlers,
                isFlowResumed = false,
                future = null,
                isWaitingForFuture = false,
                isAnyCheckpointPersisted = anyCheckpointPersisted,
                isStartIdempotent = false,
                isRemoved = false,
                isKilled = false,
                flowLogic = fiber.logic,
                senderUUID = senderUUID
        )
    }
}

View File

@ -2,9 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.fibers.instrument.JavaAgent
import co.paralleluniverse.strands.channels.Channels
import com.codahale.metrics.Gauge
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.concurrent.CordaFuture
@ -24,12 +22,9 @@ import net.corda.core.internal.concurrent.mapError
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.mapNotNull
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
@ -39,13 +34,11 @@ import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.createSubFlowVersion
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
import net.corda.node.services.statemachine.interceptors.FiberDeserializationChecker
import net.corda.node.services.statemachine.interceptors.FiberDeserializationCheckingInterceptor
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
import net.corda.node.services.statemachine.interceptors.PrintingInterceptor
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.errorAndTerminate
import net.corda.node.utilities.injectOldProgressTracker
@ -61,6 +54,7 @@ import java.lang.Integer.min
import java.security.SecureRandom
import java.time.Duration
import java.util.HashSet
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
@ -90,8 +84,6 @@ class SingleThreadedStateMachineManager(
private val logger = contextLogger()
}
private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)
private data class ScheduledTimeout(
/** Will fire a [FlowTimeoutException] indicating to the flow hospital to restart the flow. */
val scheduledFuture: ScheduledFuture<*>,
@ -105,7 +97,8 @@ class SingleThreadedStateMachineManager(
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
/** True if we're shutting down, so don't resume anything. */
var stopping = false
val flows = HashMap<StateMachineRunId, Flow>()
val flows = HashMap<StateMachineRunId, Flow<*>>()
val pausedFlows = HashMap<StateMachineRunId, NonResidentFlow>()
val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>()
/** Flows scheduled to be retried if not finished within the specified timeout period. */
val timedFlows = HashMap<StateMachineRunId, ScheduledTimeout>()
@ -127,7 +120,7 @@ class SingleThreadedStateMachineManager(
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var actionExecutor: ActionExecutor? = null
private lateinit var flowCreator: FlowCreator
override val flowHospital: StaffedFlowHospital = makeFlowHospital()
private val transitionExecutor = makeTransitionExecutor()
@ -146,7 +139,7 @@ class SingleThreadedStateMachineManager(
*/
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher
override fun start(tokenizableServices: List<Any>) : CordaFuture<Unit> {
override fun start(tokenizableServices: List<Any>, startMode: StateMachineManager.StartMode): CordaFuture<Unit> {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(
@ -157,8 +150,24 @@ class SingleThreadedStateMachineManager(
)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
val actionExecutor = makeActionExecutor(checkpointSerializationContext)
fiberDeserializationChecker?.start(checkpointSerializationContext)
when (startMode) {
StateMachineManager.StartMode.ExcludingPaused -> {}
StateMachineManager.StartMode.Safe -> markAllFlowsAsPaused()
}
this.flowCreator = FlowCreator(
checkpointSerializationContext,
checkpointStorage,
scheduler,
database,
transitionExecutor,
actionExecutor,
secureRandom,
serviceHub,
unfinishedFibers,
::resetCustomTimeout)
val fibers = restoreFlowsFromCheckpoints()
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.flows.size })
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
@ -168,6 +177,17 @@ class SingleThreadedStateMachineManager(
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
}
}
val pausedFlows = restoreNonResidentFlowsFromPausedCheckpoints()
mutex.locked {
this.pausedFlows.putAll(pausedFlows)
for ((id, flow) in pausedFlows) {
val checkpoint = flow.checkpoint
for (sessionId in getFlowSessionIds(checkpoint)) {
sessionToFlow[sessionId] = id
}
}
}
return serviceHub.networkMapCache.nodeReady.map {
logger.info("Node ready, info: ${serviceHub.myInfo}")
resumeRestoredFlows(fibers)
@ -241,8 +261,7 @@ class SingleThreadedStateMachineManager(
flowLogic = flowLogic,
flowStart = FlowStart.Explicit,
ourIdentity = ourIdentity ?: ourFirstIdentity,
deduplicationHandler = deduplicationHandler,
isStartIdempotent = false
deduplicationHandler = deduplicationHandler
)
}
@ -282,6 +301,22 @@ class SingleThreadedStateMachineManager(
}
}
/**
 * Asks the checkpoint storage to mark all flows as paused.
 *
 * Called from [start] when the requested start mode is [StateMachineManager.StartMode.Safe].
 */
private fun markAllFlowsAsPaused() {
    checkpointStorage.markAllPaused()
}
/**
 * Resumes a flow that is currently held in the paused state.
 *
 * The paused entry is only removed from `pausedFlows` once a flow has been successfully rebuilt from it. If the
 * rebuild fails, the flow stays registered as paused together with the external events queued for it, so messages
 * that arrive afterwards keep being queued and the un-pause can be attempted again.
 *
 * @param id the run id of the paused flow.
 * @return `true` if the flow was added and started; `false` if no paused flow with this id is held or a flow could
 * not be created from the paused entry.
 */
override fun unPauseFlow(id: StateMachineRunId): Boolean {
    mutex.locked {
        // Look the entry up without removing it: removing first would lose the flow (and its queued events)
        // whenever the flow cannot be recreated below.
        val pausedFlow = pausedFlows[id] ?: return false
        val flow = flowCreator.createFlowFromNonResidentFlow(pausedFlow) ?: return false
        pausedFlows.remove(id)
        addAndStartFlow(flow.fiber.id, flow)
        // Hand over everything that was queued for the flow while it was paused.
        for (event in pausedFlow.externalEvents) {
            flow.fiber.scheduleEvent(event)
        }
    }
    return true
}
override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) {
val previousFlowId = sessionToFlow.put(sessionId, flowId)
if (previousFlowId != null) {
@ -352,23 +387,28 @@ class SingleThreadedStateMachineManager(
liveFibers.countUp()
}
private fun restoreFlowsFromCheckpoints(): List<Flow> {
return checkpointStorage.getRunnableCheckpoints().use {
private fun restoreFlowsFromCheckpoints(): List<Flow<*>> {
return checkpointStorage.getCheckpointsToRun().use {
it.mapNotNull { (id, serializedCheckpoint) ->
// If a flow is added before start() then don't attempt to restore it
mutex.locked { if (id in flows) return@mapNotNull null }
createFlowFromCheckpoint(
id = id,
serializedCheckpoint = serializedCheckpoint,
initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false
)
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@mapNotNull null
flowCreator.createFlowFromCheckpoint(id, checkpoint)
}.toList()
}
}
private fun resumeRestoredFlows(flows: List<Flow>) {
/**
 * Reads the paused checkpoints from the checkpoint storage and wraps each one in a [NonResidentFlow].
 *
 * @return the non-resident flows keyed by flow id; checkpoints that could not be deserialised are left out.
 */
private fun restoreNonResidentFlowsFromPausedCheckpoints(): Map<StateMachineRunId, NonResidentFlow> {
    return checkpointStorage.getPausedCheckpoints().use { pausedCheckpoints ->
        pausedCheckpoints.mapNotNull { (id, serializedCheckpoint) ->
            // A null result from deserialisation drops this checkpoint from the restored set.
            tryDeserializeCheckpoint(serializedCheckpoint, id)?.let { checkpoint ->
                id to NonResidentFlow(id, checkpoint)
            }
        }.toList().toMap()
    }
}
private fun resumeRestoredFlows(flows: List<Flow<*>>) {
for (flow in flows) {
addAndStartFlow(flow.fiber.id, flow)
}
@ -393,14 +433,10 @@ class SingleThreadedStateMachineManager(
logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
return
}
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return
// Resurrect flow
createFlowFromCheckpoint(
id = flowId,
serializedCheckpoint = serializedCheckpoint,
initialDeduplicationHandler = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false
) ?: return
flowCreator.createFlowFromCheckpoint(flowId, checkpoint) ?: return
} else {
// Just flow initiation message
null
@ -503,9 +539,13 @@ class SingleThreadedStateMachineManager(
logger.info("Cannot find flow corresponding to session ID - $recipientId.")
}
} else {
mutex.locked { flows[flowId] }?.run {
fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
} ?: logger.info("Cannot find fiber corresponding to flow ID $flowId")
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)
mutex.locked {
flows[flowId]?.run { fiber.scheduleEvent(event) }
// If flow is not running add it to the list of external events to be processed if/when the flow resumes.
?: pausedFlows[flowId]?.run { addExternalEvent(event) }
?: logger.info("Cannot find fiber corresponding to flow ID $flowId")
}
}
} catch (exception: Exception) {
logger.error("Exception while routing $sessionMessage", exception)
@ -582,8 +622,7 @@ class SingleThreadedStateMachineManager(
flowLogic,
flowStart,
ourIdentity,
initiatingMessageDeduplicationHandler,
isStartIdempotent = false
initiatingMessageDeduplicationHandler
)
}
@ -594,20 +633,9 @@ class SingleThreadedStateMachineManager(
flowLogic: FlowLogic<A>,
flowStart: FlowStart,
ourIdentity: Party,
deduplicationHandler: DeduplicationHandler?,
isStartIdempotent: Boolean
deduplicationHandler: DeduplicationHandler?
): CordaFuture<FlowStateMachine<A>> {
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
// have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
val flowAlreadyExists = mutex.locked { flows[flowId] != null }
val existingCheckpoint = if (flowAlreadyExists) {
@ -629,37 +657,15 @@ class SingleThreadedStateMachineManager(
// This is a brand new flow
null
}
val checkpoint = existingCheckpoint?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: Checkpoint.create(
invocationContext,
flowStart,
flowLogic.javaClass,
frozenFlowLogic,
ourIdentity,
flowCorDappVersion,
flowLogic.isEnabledTimedFlow()
).getOrThrow()
val flow = flowCreator.createFlowFromLogic(flowId, invocationContext, flowLogic, flowStart, ourIdentity, existingCheckpoint, deduplicationHandler, ourSenderUUID)
val startedFuture = openFuture<Unit>()
val initialState = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isWaitingForFuture = false,
future = null,
isAnyCheckpointPersisted = existingCheckpoint != null,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
isKilled = false,
flowLogic = flowLogic,
senderUUID = ourSenderUUID
)
flowStateMachineImpl.transientState = TransientReference(initialState)
mutex.locked {
startedFutures[flowId] = startedFuture
}
totalStartedFlows.inc()
addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture))
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
addAndStartFlow(flowId, flow)
return startedFuture.map { flow.fiber as FlowStateMachine<A> }
}
override fun scheduleFlowTimeout(flowId: StateMachineRunId) {
@ -739,7 +745,7 @@ class SingleThreadedStateMachineManager(
}
/** Schedules a [FlowTimeoutException] to be fired in order to restart the flow. */
private fun scheduleTimeoutException(flow: Flow, delay: Long): ScheduledFuture<*> {
private fun scheduleTimeoutException(flow: Flow<*>, delay: Long): ScheduledFuture<*> {
return with(serviceHub.configuration.flowTimeout) {
scheduledFutureExecutor.schedule({
val event = Event.Error(FlowTimeoutException())
@ -767,43 +773,6 @@ class SingleThreadedStateMachineManager(
}
}
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
//
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
// bridge, so we have to do a more complex scan here.
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
if (call.getAnnotation(Suspendable::class.java) == null) {
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
}
}
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
return FlowStateMachineImpl.TransientValues(
eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
resultFuture = resultFuture,
database = database,
transitionExecutor = transitionExecutor,
actionExecutor = actionExecutor!!,
stateMachine = StateMachine(id, secureRandom),
serviceHub = serviceHub,
checkpointSerializationContext = checkpointSerializationContext!!,
unfinishedFibers = unfinishedFibers,
waitTimeUpdateHook = { flowId, timeout -> resetCustomTimeout(flowId, timeout) }
)
}
private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? {
return try {
bytes.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (e: Exception) {
logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", e)
null
}
}
private fun tryDeserializeCheckpoint(serializedCheckpoint: Checkpoint.Serialized, flowId: StateMachineRunId): Checkpoint? {
return try {
serializedCheckpoint.deserialize(checkpointSerializationContext!!)
@ -813,68 +782,7 @@ class SingleThreadedStateMachineManager(
}
}
private fun createFlowFromCheckpoint(
id: StateMachineRunId,
serializedCheckpoint: Checkpoint.Serialized,
isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean,
initialDeduplicationHandler: DeduplicationHandler?
): Flow? {
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id)?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: return null
val resultFuture = openFuture<Any?>()
val fiber = when (checkpoint.flowState) {
is FlowState.Unstarted -> {
val logic = tryCheckpointDeserialize(checkpoint.flowState.frozenFlowLogic, id) ?: return null
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isWaitingForFuture = false,
future = null,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
isKilled = false,
flowLogic = logic,
senderUUID = null
)
val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
is FlowState.Started -> {
val fiber = tryCheckpointDeserialize(checkpoint.flowState.frozenFiber, id) ?: return null
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isWaitingForFuture = false,
future = null,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
isKilled = false,
flowLogic = fiber.logic,
senderUUID = null
)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
is FlowState.Completed -> {
return null // Places calling this function is rely on it to return null if the flow cannot be created from the checkpoint.
}
}
verifyFlowLogicIsSuspendable(fiber.logic)
return Flow(fiber, resultFuture)
}
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) {
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow<*>) {
val checkpoint = flow.fiber.snapshot().checkpoint
for (sessionId in getFlowSessionIds(checkpoint)) {
sessionToFlow[sessionId] = id
@ -899,7 +807,7 @@ class SingleThreadedStateMachineManager(
}
}
private fun startOrResume(checkpoint: Checkpoint, flow: Flow) {
private fun startOrResume(checkpoint: Checkpoint, flow: Flow<*>) {
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
flow.fiber.start()
@ -953,7 +861,7 @@ class SingleThreadedStateMachineManager(
}
private fun InnerState.removeFlowOrderly(
flow: Flow,
flow: Flow<*>,
removalReason: FlowRemovalReason.OrderlyFinish,
lastState: StateMachineState
) {
@ -969,7 +877,7 @@ class SingleThreadedStateMachineManager(
}
private fun InnerState.removeFlowError(
flow: Flow,
flow: Flow<*>,
removalReason: FlowRemovalReason.ErrorFinish,
lastState: StateMachineState
) {
@ -983,7 +891,7 @@ class SingleThreadedStateMachineManager(
}
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
private fun drainFlowEventQueue(flow: Flow) {
private fun drainFlowEventQueue(flow: Flow<*>) {
while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
when (event) {

View File

@ -30,12 +30,18 @@ import java.time.Duration
* TODO: Don't store all active flows in memory, load from the database on demand.
*/
interface StateMachineManager {
/**
 * Selects how checkpointed flows are treated when the state machine manager is started.
 * Passed to [start], where it defaults to [ExcludingPaused].
 */
enum class StartMode {
ExcludingPaused, // Resume all flows except paused flows.
Safe // Mark all flows as paused.
}
/**
* Starts the state machine manager, loading and starting the state machines in storage.
*
* @return `Future` which completes when SMM is fully started
*/
fun start(tokenizableServices: List<Any>) : CordaFuture<Unit>
fun start(tokenizableServices: List<Any>, startMode: StartMode = StartMode.ExcludingPaused) : CordaFuture<Unit>
/**
* Stops the state machine manager gracefully, waiting until all but [allowedUnsuspendedFiberCount] flows reach the
@ -80,6 +86,13 @@ interface StateMachineManager {
*/
fun killFlow(id: StateMachineRunId): Boolean
/**
 * Start a paused flow.
 *
 * @param id the run id of the paused flow to start.
 * @return whether the flow was successfully started; `false` is returned otherwise, for example when [id] does not
 * refer to a flow that is currently held as paused.
 */
fun unPauseFlow(id: StateMachineRunId): Boolean
/**
* Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow.
* The event may be replayed if a flow fails and attempts to retry.

View File

@ -170,9 +170,14 @@ data class Checkpoint(
* @return A [Checkpoint] with all its fields filled in from [Checkpoint.Serialized]
*/
fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint {
val flowState = when(status) {
FlowStatus.PAUSED -> FlowState.Paused
FlowStatus.COMPLETED -> FlowState.Completed
else -> serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext)
}
return Checkpoint(
checkpointState = serializedCheckpointState.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
flowState = serializedFlowState?.checkpointDeserialize(checkpointSerializationContext) ?: FlowState.Completed,
flowState = flowState,
errorState = errorState,
result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
status = status,
@ -307,6 +312,11 @@ sealed class FlowState {
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
}
/**
 * The flow is paused. To save memory the flow's actual [FlowState] is not held in the checkpoint; this marker
 * object is used as the flow state of a checkpoint whose status is PAUSED when it is deserialised.
 */
object Paused: FlowState()
/**
* The flow has completed. It does not have a running fiber that needs to be serialized and checkpointed.
*/

View File

@ -29,6 +29,7 @@ class DoRemainingWorkTransition(
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, flowState).transition()
is FlowState.Started -> StartedFlowTransition(context, startingState, flowState).transition()
is FlowState.Completed -> throw IllegalStateException("Cannot transition a state with completed flow state.")
is FlowState.Paused -> throw IllegalStateException("Cannot transition a state with paused flow state.")
}
}

View File

@ -32,6 +32,7 @@ class NodeStartupCliTest {
Assertions.assertThat(startup.verbose).isEqualTo(false)
Assertions.assertThat(startup.loggingLevel).isEqualTo(Level.INFO)
Assertions.assertThat(startup.cmdLineOptions.noLocalShell).isEqualTo(false)
Assertions.assertThat(startup.cmdLineOptions.safeMode).isEqualTo(false)
Assertions.assertThat(startup.cmdLineOptions.sshdServer).isEqualTo(false)
Assertions.assertThat(startup.cmdLineOptions.justGenerateNodeInfo).isEqualTo(false)
Assertions.assertThat(startup.cmdLineOptions.justGenerateRpcSslCerts).isEqualTo(false)

View File

@ -63,7 +63,7 @@ import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
internal fun CheckpointStorage.getAllIncompleteCheckpoints(): List<Checkpoint.Serialized> {
return getRunnableCheckpoints().use {
return getCheckpointsToRun().use {
it.map { it.second }.toList()
}.filter { it.status != Checkpoint.FlowStatus.COMPLETED }
}

View File

@ -5,6 +5,7 @@ import net.corda.core.context.InvocationOrigin
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.toSet
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointSerialize
@ -41,13 +42,13 @@ import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
import java.time.Clock
import java.util.ArrayList
import java.util.*
import kotlin.streams.toList
import kotlin.test.assertEquals
import kotlin.test.assertTrue
internal fun CheckpointStorage.checkpoints(): List<Checkpoint.Serialized> {
return getAllCheckpoints().use {
return getCheckpoints().use {
it.map { it.second }.toList()
}
}
@ -148,18 +149,38 @@ class DBCheckpointStorageTests {
checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
}
val completedCheckpoint = checkpoint.copy(flowState = FlowState.Completed)
val completedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED)
database.transaction {
checkpointStorage.updateCheckpoint(id, completedCheckpoint, null)
}
database.transaction {
assertEquals(
completedCheckpoint,
completedCheckpoint.copy(flowState = FlowState.Completed),
checkpointStorage.checkpoints().single().deserialize()
)
}
}
@Test(timeout = 300_000)
fun `update a checkpoint to paused`() {
    // Persist a fresh checkpoint together with its serialized flow state.
    val (flowId, original) = newCheckpoint()
    val frozenFlowState = original.serializeFlowState()
    database.transaction {
        checkpointStorage.addCheckpoint(flowId, original, frozenFlowState)
    }

    // Switch the status to PAUSED, supplying no serialized flow state with the update.
    val pausedCheckpoint = original.copy(status = Checkpoint.FlowStatus.PAUSED)
    database.transaction {
        checkpointStorage.updateCheckpoint(flowId, pausedCheckpoint, null)
    }

    // Reading it back must give the paused checkpoint with FlowState.Paused as its flow state.
    val expected = pausedCheckpoint.copy(flowState = FlowState.Paused)
    database.transaction {
        val stored = checkpointStorage.checkpoints().single().deserialize()
        assertEquals(expected, stored)
    }
}
@Test(timeout = 300_000)
fun `removing a checkpoint deletes from all checkpoint tables`() {
val exception = IllegalStateException("I am a naughty exception")
@ -641,7 +662,7 @@ class DBCheckpointStorageTests {
}
@Test(timeout = 300_000)
fun `fetch runnable checkpoints`() {
fun `Checkpoints can be fetched from the database by status`() {
val (_, checkpoint) = newCheckpoint(1)
// runnables
val runnable = checkpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
@ -650,8 +671,8 @@ class DBCheckpointStorageTests {
val completed = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED)
val failed = checkpoint.copy(status = Checkpoint.FlowStatus.FAILED)
val killed = checkpoint.copy(status = Checkpoint.FlowStatus.KILLED)
// tentative
val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) // is considered runnable
// paused
val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED)
database.transaction {
val serializedFlowState =
@ -666,7 +687,15 @@ class DBCheckpointStorageTests {
}
database.transaction {
assertEquals(3, checkpointStorage.getRunnableCheckpoints().count())
val toRunStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.HOSPITALIZED)
val pausedStatuses = setOf(Checkpoint.FlowStatus.PAUSED)
val customStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.KILLED)
val customStatuses1 = setOf(Checkpoint.FlowStatus.PAUSED, Checkpoint.FlowStatus.HOSPITALIZED, Checkpoint.FlowStatus.FAILED)
assertEquals(toRunStatuses, checkpointStorage.getCheckpointsToRun().map { it.second.status }.toSet())
assertEquals(pausedStatuses, checkpointStorage.getPausedCheckpoints().map { it.second.status }.toSet())
assertEquals(customStatuses, checkpointStorage.getCheckpoints(customStatuses).map { it.second.status }.toSet())
assertEquals(customStatuses1, checkpointStorage.getCheckpoints(customStatuses1).map { it.second.status }.toSet())
}
}
@ -749,6 +778,78 @@ class DBCheckpointStorageTests {
else -> throw IllegalStateException("Unknown line.separator")
}
@Test(timeout = 300_000)
fun `paused checkpoints can be extracted`() {
    // Store a checkpoint that is already in the PAUSED status.
    val (flowId, template) = newCheckpoint()
    val frozenFlowState = template.serializeFlowState()
    val pausedCheckpoint = template.copy(status = Checkpoint.FlowStatus.PAUSED)
    database.transaction {
        checkpointStorage.addCheckpoint(flowId, pausedCheckpoint, frozenFlowState)
    }

    database.transaction {
        val (loadedId, loaded) = checkpointStorage.getPausedCheckpoints().toList().single()
        assertEquals(flowId, loadedId)

        // Neither the result nor the flow state is extracted for a paused checkpoint.
        assertEquals(null, loaded.serializedFlowState)
        assertEquals(null, loaded.result)

        // The remaining serialized fields must match what was stored.
        assertEquals(pausedCheckpoint.status, loaded.status)
        assertEquals(pausedCheckpoint.progressStep, loaded.progressStep)
        assertEquals(pausedCheckpoint.flowIoRequest, loaded.flowIoRequest)

        // After deserialisation the checkpoint state is intact and the flow state is the Paused marker.
        val rehydrated = loaded.deserialize()
        assertEquals(pausedCheckpoint.checkpointState, rehydrated.checkpointState)
        assertEquals(FlowState.Paused, rehydrated.flowState)
    }
}
@Test(timeout = 300_000)
fun `checkpoints correctly change there status to paused`() {
    val (_, checkpoint) = newCheckpoint(1)
    // Statuses that markAllPaused() is expected to rewrite to PAUSED.
    val runnable = changeStatus(checkpoint, Checkpoint.FlowStatus.RUNNABLE)
    val hospitalized = changeStatus(checkpoint, Checkpoint.FlowStatus.HOSPITALIZED)
    // Statuses that must be left untouched.
    val completed = changeStatus(checkpoint, Checkpoint.FlowStatus.COMPLETED)
    val failed = changeStatus(checkpoint, Checkpoint.FlowStatus.FAILED)
    val killed = changeStatus(checkpoint, Checkpoint.FlowStatus.KILLED)
    // Already paused before the call.
    val paused = changeStatus(checkpoint, Checkpoint.FlowStatus.PAUSED)

    database.transaction {
        val serializedFlowState =
            checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
        // Every checkpoint shares the same serialized flow state; only id and status differ.
        for (entry in listOf(runnable, hospitalized, completed, failed, killed, paused)) {
            checkpointStorage.addCheckpoint(entry.id, entry.checkpoint, serializedFlowState)
        }
    }

    database.transaction {
        checkpointStorage.markAllPaused()
    }

    database.transaction {
        fun storedStatus(entry: IdAndCheckpoint) = checkpointStorage.getDBCheckpoint(entry.id)!!.status

        // Runnable and hospitalised checkpoints should now be paused.
        assertEquals(Checkpoint.FlowStatus.PAUSED, storedStatus(runnable))
        assertEquals(Checkpoint.FlowStatus.PAUSED, storedStatus(hospitalized))
        // All other checkpoints keep the status they were stored with.
        assertEquals(Checkpoint.FlowStatus.COMPLETED, storedStatus(completed))
        assertEquals(Checkpoint.FlowStatus.FAILED, storedStatus(failed))
        assertEquals(Checkpoint.FlowStatus.KILLED, storedStatus(killed))
        assertEquals(Checkpoint.FlowStatus.PAUSED, storedStatus(paused))
    }
}
/** Pairs a [Checkpoint] with the flow run id it is stored under in these tests. */
data class IdAndCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint)
/**
 * Builds a copy of [oldCheckpoint] carrying the given [status] and pairs it with a freshly generated random
 * flow run id.
 */
private fun changeStatus(oldCheckpoint: Checkpoint, status: Checkpoint.FlowStatus): IdAndCheckpoint =
    IdAndCheckpoint(StateMachineRunId.createRandom(), oldCheckpoint.copy(status = status))
private fun newCheckpointStorage() {
database.transaction {
checkpointStorage = DBCheckpointStorage(

View File

@ -682,14 +682,14 @@ class FlowFrameworkTests {
if (firstExecution) {
throw HospitalizeFlowException()
} else {
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState!!.value.checkpoint.status
futureFiber.complete(flowFiber)
}
}
SuspendingFlow.hookAfterCheckpoint = {
dbCheckpointStatusAfterSuspension = aliceNode.internals.checkpointStorage.getRunnableCheckpoints().toList().single()
dbCheckpointStatusAfterSuspension = aliceNode.internals.checkpointStorage.getCheckpointsToRun().toList().single()
.second.status
}
@ -701,7 +701,7 @@ class FlowFrameworkTests {
val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState?.value?.checkpoint?.status
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus)
aliceNode.database.transaction {
val checkpoint = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status)
}
// restart Node - flow will be loaded from checkpoint
@ -729,7 +729,7 @@ class FlowFrameworkTests {
if (firstExecution) {
throw HospitalizeFlowException()
} else {
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber.transientState!!.value.checkpoint.status
futureFiber.complete(flowFiber)
@ -742,7 +742,7 @@ class FlowFrameworkTests {
// flow is in hospital
assertTrue(flowState is FlowState.Started)
aliceNode.database.transaction {
val checkpoint = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status)
}
// restart Node - flow will be loaded from checkpoint
@ -812,7 +812,7 @@ class FlowFrameworkTests {
throw SQLTransientConnectionException("connection is not available")
} else {
val flowFiber = this as? FlowStateMachineImpl<*>
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
inMemoryCheckpointStatus = flowFiber!!.transientState!!.value.checkpoint.status
persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails
}

View File

@ -0,0 +1,77 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.node.services.config.NodeConfiguration
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.internal.TestStartedNode
import net.corda.testing.node.internal.startFlow
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.time.Duration
import kotlin.test.assertEquals
/**
 * Tests restarting a node with a [StateMachineManager.StartMode] that pauses its checkpointed flows.
 */
class FlowPausingTests {

    private lateinit var mockNet: InternalMockNetwork
    private lateinit var aliceNode: TestStartedNode
    private lateinit var bobNode: TestStartedNode

    @Before
    fun setUpMockNet() {
        mockNet = InternalMockNetwork()
        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
    }

    @After
    fun cleanUp() {
        mockNet.stopNodes()
    }

    /** Restarts [node] with its configuration overridden to report [smmStartMode] as the start mode. */
    private fun restartNode(node: TestStartedNode, smmStartMode: StateMachineManager.StartMode): TestStartedNode {
        val overrides = InternalMockNodeParameters(configOverrides = { conf: NodeConfiguration ->
            doReturn(smmStartMode).whenever(conf).smmStartMode
        })
        return mockNet.restartNode(node, parameters = overrides)
    }

    @Test(timeout = 300_000)
    fun `All are paused when the node is restarted in safe start mode`() {
        val flows = (1..NUMBER_OF_FLOWS).map { aliceNode.services.startFlow(CheckpointingFlow()) }

        // After a restart in Safe mode no flow should be resident in the state machine manager.
        val restartedAlice = restartNode(aliceNode, StateMachineManager.StartMode.Safe)
        assertEquals(0, restartedAlice.smm.snapshot().size)

        // Wait long enough that any flow still running would have finished.
        Thread.sleep(NUMBER_OF_FLOWS * SLEEP_TIME)

        restartedAlice.database.transaction {
            flows.forEach { flow ->
                val checkpoint = restartedAlice.internals.checkpointStorage.getCheckpoint(flow.id)
                assertEquals(Checkpoint.FlowStatus.PAUSED, checkpoint!!.status)
            }
        }
    }

    /** A flow that only sleeps for [SLEEP_TIME] milliseconds. */
    internal class CheckpointingFlow : FlowLogic<Unit>() {
        @Suspendable
        override fun call() {
            sleep(Duration.ofMillis(SLEEP_TIME))
        }
    }

    companion object {
        const val NUMBER_OF_FLOWS = 4
        const val SLEEP_TIME = 1000L
    }
}

View File

@ -638,6 +638,7 @@ private fun mockNodeConfiguration(certificatesDirectory: Path): NodeConfiguratio
doReturn(NetworkParameterAcceptanceSettings()).whenever(it).networkParameterAcceptanceSettings
doReturn(rigorousMock<ConfigurationWithOptions>()).whenever(it).configurationWithOptions
doReturn(2).whenever(it).flowExternalOperationThreadPoolSize
doReturn(StateMachineManager.StartMode.ExcludingPaused).whenever(it).smmStartMode
}
}