mirror of
https://github.com/corda/corda.git
synced 2025-04-16 07:27:17 +00:00
CORDA-3490 Add option to start node without starting checkpointed flows (#6136)
Added a command-line option, `--pause-all-flows`, to the node to control this. This option causes all checkpoints to be set to status PAUSED when the state machine starts up (in StartMode.Safe mode). Changed the state machine so that PAUSED checkpoints are loaded into memory (the checkpoint is deserialised but the flow state is left serialised) but not started. Messages from peers are queued whilst the flow is paused and are processed once the flow is resumed.
This commit is contained in:
parent
356172c370
commit
eb52de1b40
@ -9,4 +9,4 @@ package net.corda.common.logging
|
||||
* (originally added to source control for ease of use)
|
||||
*/
|
||||
|
||||
internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT"
|
||||
internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT"
|
@ -1,5 +1,6 @@
|
||||
package net.corda.core.internal.messaging
|
||||
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.internal.AttachmentTrustInfo
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
|
||||
@ -13,4 +14,11 @@ interface InternalCordaRPCOps : CordaRPCOps {
|
||||
|
||||
/** Get all attachment trust information */
|
||||
val attachmentTrustInfos: List<AttachmentTrustInfo>
|
||||
|
||||
/**
|
||||
* Resume a paused flow.
|
||||
*
|
||||
* @return whether the flow was successfully resumed.
|
||||
*/
|
||||
fun unPauseFlow(id: StateMachineRunId): Boolean
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.flows.InitiatedBy
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.flows.StartableByRPC
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.messaging.InternalCordaRPCOps
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.services.Permissions
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.driver.DriverParameters
|
||||
import net.corda.testing.driver.NodeParameters
|
||||
import net.corda.testing.driver.driver
|
||||
import net.corda.testing.node.User
|
||||
import org.junit.Test
|
||||
import java.time.Duration
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNotNull
|
||||
|
||||
class FlowPausingTest {
|
||||
|
||||
companion object {
    /** Number of heartbeat messages exchanged between the initiating and the responding flow. */
    const val TOTAL_MESSAGES = 100

    /** Delay in milliseconds used when polling for the initiated flow and when sizing the wait for messages. */
    const val SLEEP_BETWEEN_MESSAGES_MS = 10L
}
|
||||
|
||||
/**
 * Checks that session messages sent to a paused flow are queued and processed once the flow is resumed.
 *
 * ALICE starts a [HeartbeatFlow] towards BOB. BOB is stopped while the flow is in progress and restarted with
 * `smmStartMode` set to `Safe`; the initiated flow is then resumed over RPC and ALICE's flow is expected to
 * return `true`.
 */
@Test(timeout = 300_000)
fun `Paused flows can recieve session messages`() {
    val rpcUser = User("demo", "demo", setOf(Permissions.startFlow<HardRestartTest.Ping>(), Permissions.all()))
    driver(DriverParameters(startNodesInProcess = true, inMemoryDB = false)) {
        val alice = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = listOf(rpcUser))).getOrThrow()
        val bob = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = listOf(rpcUser)))
        val startedBob = bob.getOrThrow()
        val aliceFlow = alice.rpc.startFlow(::HeartbeatFlow, startedBob.nodeInfo.legalIdentities[0])
        // We wait here for the initiated flow to start running on BOB.
        // assertNotNull returns its argument as a non-null value, so no `!!` is needed further down.
        val initiatedFlowId = assertNotNull(
                startedBob.rpc.waitForFlowToStart(150),
                "The initiated flow did not start on BOB in time"
        )
        /* We shut down BOB; we want this to happen before BOB has finished receiving all of the heartbeats.
           This is a race, but if BOB finishes too quickly then we will fail to unpause the initiated flow running on BOB later
           and this test will fail. */
        startedBob.stop()
        // Start BOB back up in Safe mode. This means no flows will run but BOB should receive messages and queue these up.
        val restartedBob = startNode(NodeParameters(
                providedName = BOB_NAME,
                rpcUsers = listOf(rpcUser),
                customOverrides = mapOf("smmStartMode" to "Safe"))).getOrThrow()

        // Sleep for long enough so BOB has time to receive all the messages.
        // All messages in this period should be queued up and replayed when the flow is unpaused.
        Thread.sleep(TOTAL_MESSAGES * SLEEP_BETWEEN_MESSAGES_MS)
        // ALICE should not have finished yet as the HeartbeatResponderFlow should not have sent the final message back (as it is paused).
        assertEquals(false, aliceFlow.returnValue.isDone)
        assertEquals(true, (restartedBob.rpc as InternalCordaRPCOps).unPauseFlow(initiatedFlowId))

        assertEquals(true, aliceFlow.returnValue.getOrThrow())
        alice.stop()
        restartedBob.stop()
    }
}
|
||||
|
||||
/**
 * Polls the state machine snapshot until it contains exactly one flow and returns that flow's id.
 *
 * Sleeps [SLEEP_BETWEEN_MESSAGES_MS] milliseconds after every unsuccessful attempt.
 *
 * @param maxTrys the maximum number of polling attempts.
 * @return the id of the single running flow, or `null` if none was seen within [maxTrys] attempts.
 */
fun CordaRPCOps.waitForFlowToStart(maxTrys: Int): StateMachineRunId? {
    repeat(maxTrys) {
        val runningFlow = stateMachinesSnapshot().singleOrNull()
        if (runningFlow != null) {
            return runningFlow.id
        }
        Thread.sleep(SLEEP_BETWEEN_MESSAGES_MS)
    }
    return null
}
|
||||
|
||||
/**
 * Sends [TOTAL_MESSAGES] consecutive sequence numbers to [otherParty], sleeping between sends, and then
 * returns the [Boolean] received back from the counterparty.
 */
@StartableByRPC
@InitiatingFlow
class HeartbeatFlow(private val otherParty: Party): FlowLogic<Boolean>() {
    var sequenceNumber = 0

    @Suspendable
    override fun call(): Boolean {
        val session = initiateFlow(otherParty)
        for (i in 1..TOTAL_MESSAGES) {
            session.send(sequenceNumber++)
            // Use the shared constant rather than a literal so the pacing matches the wait the test derives from it.
            sleep(Duration.ofMillis(SLEEP_BETWEEN_MESSAGES_MS))
        }
        // The counterparty replies once with its verdict on the sequence it received.
        return session.receive<Boolean>().unwrap { data -> data }
    }
}
|
||||
|
||||
/**
 * Responder for [HeartbeatFlow]: receives [TOTAL_MESSAGES] integers and replies with `true` only if each one
 * equalled the locally tracked [sequenceNumber] at the time it was received.
 */
@InitiatedBy(HeartbeatFlow::class)
class HeartbeatResponderFlow(val session: FlowSession): FlowLogic<Unit>() {
    var sequenceNumber : Int = 0

    @Suspendable
    override fun call() {
        var allInOrder = true
        for (messageIndex in 1..TOTAL_MESSAGES) {
            val received = session.receive<Int>().unwrap { it }
            // Compare against the expected value, then advance the counter regardless of the outcome.
            val expected = sequenceNumber++
            allInOrder = allInOrder && received == expected
        }
        session.send(allInOrder)
    }
}
|
||||
}
|
@ -14,6 +14,7 @@ import net.corda.node.services.config.ConfigHelper
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.node.services.config.Valid
|
||||
import net.corda.node.services.config.parseAsNodeConfiguration
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
import net.corda.nodeapi.internal.config.UnknownConfigKeysPolicy
|
||||
import picocli.CommandLine.Option
|
||||
import java.nio.file.Path
|
||||
@ -48,6 +49,12 @@ open class SharedNodeCmdLineOptions {
|
||||
)
|
||||
var devMode: Boolean? = null
|
||||
|
||||
@Option(
|
||||
names = ["--pause-all-flows"],
|
||||
description = ["Do not run any flows on startup. Sets all flows to paused, which can be unpaused via RPC."]
|
||||
)
|
||||
var safeMode: Boolean = false
|
||||
|
||||
open fun parseConfiguration(configuration: Config): Valid<NodeConfiguration> {
|
||||
val option = Configuration.Options(strict = unknownConfigKeysPolicy == UnknownConfigKeysPolicy.FAIL)
|
||||
return configuration.parseAsNodeConfiguration(option)
|
||||
@ -186,6 +193,9 @@ open class NodeCmdLineOptions : SharedNodeCmdLineOptions() {
|
||||
devMode?.let {
|
||||
configOverrides += "devMode" to it
|
||||
}
|
||||
if (safeMode) {
|
||||
configOverrides += "smmStartMode" to StateMachineManager.StartMode.Safe.toString()
|
||||
}
|
||||
return try {
|
||||
valid(ConfigHelper.loadConfig(baseDirectory, configFile, configOverrides = ConfigFactory.parseMap(configOverrides)))
|
||||
} catch (e: ConfigException) {
|
||||
|
@ -543,7 +543,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
tokenizableServices = null
|
||||
|
||||
verifyCheckpointsCompatible(frozenTokenizableServices)
|
||||
val smmStartedFuture = smm.start(frozenTokenizableServices)
|
||||
val smmStartedFuture = smm.start(frozenTokenizableServices, configuration.smmStartMode)
|
||||
// Shut down the SMM so no Fibers are scheduled.
|
||||
runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) }
|
||||
val flowMonitor = FlowMonitor(
|
||||
@ -1379,4 +1379,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
|
||||
}
|
||||
// Here we're using the node's RPC key store as the RPC client's trust store.
|
||||
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
|
||||
}
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ object CheckpointVerifier {
|
||||
|
||||
val cordappsByHash = currentCordapps.associateBy { it.jarHash }
|
||||
|
||||
checkpointStorage.getRunnableCheckpoints().use {
|
||||
checkpointStorage.getCheckpointsToRun().use {
|
||||
it.forEach { (_, serializedCheckpoint) ->
|
||||
val checkpoint = try {
|
||||
serializedCheckpoint.deserialize(checkpointSerializationContext)
|
||||
|
@ -169,6 +169,8 @@ internal class CordaRPCOpsImpl(
|
||||
|
||||
/** Delegates to `smm.killFlow` for the flow identified by [id] and returns its result. */
override fun killFlow(id: StateMachineRunId): Boolean = smm.killFlow(id)
|
||||
|
||||
/** Delegates to `smm.unPauseFlow` for the flow identified by [id] and returns its result. */
override fun unPauseFlow(id: StateMachineRunId): Boolean = smm.unPauseFlow(id)
|
||||
|
||||
override fun stateMachinesFeed(): DataFeed<List<StateMachineInfo>, StateMachineUpdate> {
|
||||
|
||||
val (allStateMachines, changes) = smm.track()
|
||||
|
@ -20,6 +20,12 @@ interface CheckpointStorage {
|
||||
*/
|
||||
fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?)
|
||||
|
||||
/**
|
||||
* Update all persisted checkpoints with status [Checkpoint.FlowStatus.RUNNABLE] or [Checkpoint.FlowStatus.HOSPITALIZED],
|
||||
* changing the status to [Checkpoint.FlowStatus.PAUSED].
|
||||
*/
|
||||
fun markAllPaused()
|
||||
|
||||
/**
|
||||
* Remove existing checkpoint from the store.
|
||||
* @return whether the id matched a checkpoint that was removed.
|
||||
@ -37,14 +43,23 @@ interface CheckpointStorage {
|
||||
fun getCheckpoint(id: StateMachineRunId): Checkpoint.Serialized?
|
||||
|
||||
/**
|
||||
* Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the
|
||||
* underlying database connection is closed, so any processing should happen before it is closed.
|
||||
* Stream all checkpoints with statuses [statuses] from the store. If this is backed by a database the stream will be valid
|
||||
* until the underlying database connection is closed, so any processing should happen before it is closed.
|
||||
*/
|
||||
fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
|
||||
fun getCheckpoints(
|
||||
statuses: Collection<Checkpoint.FlowStatus> = Checkpoint.FlowStatus.values().toSet()
|
||||
): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
|
||||
|
||||
/**
|
||||
* Stream runnable checkpoints from the store. If this is backed by a database the stream will be valid
|
||||
* until the underlying database connection is closed, so any processing should happen before it is closed.
|
||||
*/
|
||||
fun getRunnableCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
|
||||
fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
|
||||
|
||||
/**
|
||||
* Stream paused checkpoints from the store. If this is backed by a database the stream will be valid
|
||||
* until the underlying database connection is closed, so any processing should happen before it is closed.
|
||||
* This method does not fetch [Checkpoint.Serialized.serializedFlowState] to save memory.
|
||||
*/
|
||||
fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>>
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import net.corda.core.internal.notary.NotaryServiceFlow
|
||||
import net.corda.core.utilities.NetworkHostAndPort
|
||||
import net.corda.node.services.config.rpc.NodeRpcOptions
|
||||
import net.corda.node.services.config.schema.v1.V1NodeConfigurationSpec
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
|
||||
import net.corda.nodeapi.internal.config.MutualSslConfiguration
|
||||
import net.corda.nodeapi.internal.config.User
|
||||
@ -93,6 +94,8 @@ interface NodeConfiguration : ConfigurationWithOptionsContainer {
|
||||
|
||||
val quasarExcludePackages: List<String>
|
||||
|
||||
val smmStartMode: StateMachineManager.StartMode
|
||||
|
||||
companion object {
|
||||
// default to at least 8MB and a bit extra for larger heap sizes
|
||||
val defaultTransactionCacheSize: Long = 8.MB + getAdditionalCacheMemory()
|
||||
|
@ -8,6 +8,7 @@ import net.corda.core.utilities.NetworkHostAndPort
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import net.corda.core.utilities.seconds
|
||||
import net.corda.node.services.config.rpc.NodeRpcOptions
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
import net.corda.nodeapi.BrokerRpcSslOptions
|
||||
import net.corda.nodeapi.internal.DEV_PUB_KEY_HASHES
|
||||
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
|
||||
@ -84,7 +85,8 @@ data class NodeConfigurationImpl(
|
||||
override val blacklistedAttachmentSigningKeys: List<String> = Defaults.blacklistedAttachmentSigningKeys,
|
||||
override val configurationWithOptions: ConfigurationWithOptions,
|
||||
override val flowExternalOperationThreadPoolSize: Int = Defaults.flowExternalOperationThreadPoolSize,
|
||||
override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages
|
||||
override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages,
|
||||
override val smmStartMode : StateMachineManager.StartMode = Defaults.smmStartMode
|
||||
) : NodeConfiguration {
|
||||
internal object Defaults {
|
||||
val jmxMonitoringHttpPort: Int? = null
|
||||
@ -123,6 +125,7 @@ data class NodeConfigurationImpl(
|
||||
val blacklistedAttachmentSigningKeys: List<String> = emptyList()
|
||||
const val flowExternalOperationThreadPoolSize: Int = 1
|
||||
val quasarExcludePackages: List<String> = emptyList()
|
||||
val smmStartMode : StateMachineManager.StartMode = StateMachineManager.StartMode.ExcludingPaused
|
||||
|
||||
fun cordappsDirectories(baseDirectory: Path) = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT)
|
||||
|
||||
|
@ -9,6 +9,7 @@ import net.corda.common.validation.internal.Validated.Companion.valid
|
||||
import net.corda.node.services.config.*
|
||||
import net.corda.node.services.config.NodeConfigurationImpl.Defaults
|
||||
import net.corda.node.services.config.schema.parsers.*
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
|
||||
internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfiguration>("NodeConfiguration") {
|
||||
private val myLegalName by string().mapValid(::toCordaX500Name)
|
||||
@ -66,6 +67,7 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig
|
||||
.withDefaultValue(Defaults.networkParameterAcceptanceSettings)
|
||||
private val flowExternalOperationThreadPoolSize by int().optional().withDefaultValue(Defaults.flowExternalOperationThreadPoolSize)
|
||||
private val quasarExcludePackages by string().list().optional().withDefaultValue(Defaults.quasarExcludePackages)
|
||||
private val smmStartMode by enum(StateMachineManager.StartMode::class).optional().withDefaultValue(Defaults.smmStartMode)
|
||||
@Suppress("unused")
|
||||
private val custom by nestedObject().optional()
|
||||
@Suppress("unused")
|
||||
@ -133,7 +135,8 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig
|
||||
networkParameterAcceptanceSettings = config[networkParameterAcceptanceSettings],
|
||||
configurationWithOptions = ConfigurationWithOptions(configuration, Configuration.Options.defaults),
|
||||
flowExternalOperationThreadPoolSize = config[flowExternalOperationThreadPoolSize],
|
||||
quasarExcludePackages = config[quasarExcludePackages]
|
||||
quasarExcludePackages = config[quasarExcludePackages],
|
||||
smmStartMode = config[smmStartMode]
|
||||
))
|
||||
} catch (e: Exception) {
|
||||
return when (e) {
|
||||
|
@ -59,7 +59,7 @@ class DBCheckpointStorage(
|
||||
private const val MAX_FLOW_NAME_LENGTH = 128
|
||||
private const val MAX_PROGRESS_STEP_LENGTH = 256
|
||||
|
||||
private val NOT_RUNNABLE_CHECKPOINTS = listOf(FlowStatus.COMPLETED, FlowStatus.FAILED, FlowStatus.KILLED)
|
||||
private val RUNNABLE_CHECKPOINTS = setOf(FlowStatus.RUNNABLE, FlowStatus.HOSPITALIZED)
|
||||
|
||||
/**
|
||||
* This needs to run before Hibernate is initialised.
|
||||
@ -281,6 +281,15 @@ class DBCheckpointStorage(
|
||||
currentDBSession().update(updateDBCheckpoint(id, checkpoint, serializedFlowState))
|
||||
}
|
||||
|
||||
/**
 * Runs a single native SQL update that sets the status of every checkpoint whose status is in
 * [RUNNABLE_CHECKPOINTS] to [FlowStatus.PAUSED]. Statuses are written into the statement as enum ordinals.
 */
override fun markAllPaused() {
    val statusesToPause = RUNNABLE_CHECKPOINTS.joinToString { it.ordinal.toString() }
    val updateSql = "Update ${NODE_DATABASE_PREFIX}checkpoints set status = ${FlowStatus.PAUSED.ordinal} " +
            "where status in ($statusesToPause)"
    currentDBSession().createNativeQuery(updateSql).executeUpdate()
}
|
||||
|
||||
override fun removeCheckpoint(id: StateMachineRunId): Boolean {
|
||||
// This will be changed after performance tuning
|
||||
return currentDBSession().let { session ->
|
||||
@ -300,33 +309,39 @@ class DBCheckpointStorage(
|
||||
return getDBCheckpoint(id)?.toSerializedCheckpoint()
|
||||
}
|
||||
|
||||
override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
|
||||
val session = currentDBSession()
|
||||
val criteriaQuery = session.criteriaBuilder.createQuery(DBFlowCheckpoint::class.java)
|
||||
val root = criteriaQuery.from(DBFlowCheckpoint::class.java)
|
||||
criteriaQuery.select(root)
|
||||
return session.createQuery(criteriaQuery).stream().map {
|
||||
StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint()
|
||||
}
|
||||
}
|
||||
|
||||
override fun getRunnableCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
|
||||
override fun getCheckpoints(statuses: Collection<FlowStatus>): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
|
||||
val session = currentDBSession()
|
||||
val criteriaBuilder = session.criteriaBuilder
|
||||
val criteriaQuery = criteriaBuilder.createQuery(DBFlowCheckpoint::class.java)
|
||||
val root = criteriaQuery.from(DBFlowCheckpoint::class.java)
|
||||
criteriaQuery.select(root)
|
||||
.where(criteriaBuilder.not(root.get<FlowStatus>(DBFlowCheckpoint::status.name).`in`(NOT_RUNNABLE_CHECKPOINTS)))
|
||||
.where(criteriaBuilder.isTrue(root.get<FlowStatus>(DBFlowCheckpoint::status.name).`in`(statuses)))
|
||||
return session.createQuery(criteriaQuery).stream().map {
|
||||
StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint()
|
||||
}
|
||||
}
|
||||
|
||||
/** Streams the checkpoints whose status is in [RUNNABLE_CHECKPOINTS] by delegating to [getCheckpoints]. */
override fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> =
        getCheckpoints(RUNNABLE_CHECKPOINTS)
|
||||
|
||||
/** Looks up the [DBFlowCheckpoint] entity whose key is the string form of [id]'s UUID in the current DB session. */
@VisibleForTesting
internal fun getDBCheckpoint(id: StateMachineRunId): DBFlowCheckpoint? {
    val primaryKey = id.uuid.toString()
    return currentDBSession().find(DBFlowCheckpoint::class.java, primaryKey)
}
|
||||
|
||||
/**
 * Returns the checkpoints with status [FlowStatus.PAUSED], selected through a JPQL constructor expression into
 * [DBPausedFields]. The result list is fetched first and then exposed as a stream of run id / checkpoint pairs.
 */
override fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> {
    val pausedFieldsJpql = """select new ${DBPausedFields::class.java.name}(checkpoint.id, blob.checkpoint, checkpoint.status,
            checkpoint.progressStep, checkpoint.ioRequestType, checkpoint.compatible) from ${DBFlowCheckpoint::class.java.name}
            checkpoint join ${DBFlowCheckpointBlob::class.java.name} blob on checkpoint.blob = blob.id where
            checkpoint.status = ${FlowStatus.PAUSED.ordinal}""".trimIndent()
    val pausedRows = currentDBSession().createQuery(pausedFieldsJpql, DBPausedFields::class.java).resultList
    return pausedRows.stream().map { row ->
        StateMachineRunId(UUID.fromString(row.id)) to row.toSerializedCheckpoint()
    }
}
|
||||
|
||||
private fun createDBCheckpoint(
|
||||
id: StateMachineRunId,
|
||||
checkpoint: Checkpoint,
|
||||
@ -542,6 +557,29 @@ class DBCheckpointStorage(
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Holder for the subset of checkpoint fields that is selected for paused checkpoints; the serialized flow
 * state is not part of it.
 */
private class DBPausedFields(
        val id: String,
        val checkpoint: ByteArray = EMPTY_BYTE_ARRAY,
        val status: FlowStatus,
        val progressStep: String?,
        val ioRequestType: String?,
        val compatible: Boolean
) {
    /** Builds a [Checkpoint.Serialized] from these fields, leaving `serializedFlowState` and `result` as `null`. */
    fun toSerializedCheckpoint(): Checkpoint.Serialized = Checkpoint.Serialized(
            serializedCheckpointState = SerializedBytes(checkpoint),
            serializedFlowState = null,
            // The error state is always [Clean] here, representing that this is the last _good_ checkpoint.
            errorState = ErrorState.Clean,
            result = null,
            status = status,
            progressStep = progressStep,
            flowIoRequest = ioRequestType,
            compatible = compatible
    )
}
|
||||
|
||||
/** Serializes the receiver using [SerializationDefaults.STORAGE_CONTEXT]. */
private fun <T : Any> T.storageSerialize(): SerializedBytes<T> =
        serialize(context = SerializationDefaults.STORAGE_CONTEXT)
|
||||
|
@ -90,6 +90,11 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
|
||||
companion object {
    private val log = contextLogger()

    /** Timestamp format `yyyyMMdd-HHmmss`, rendered in UTC. */
    internal val TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd-HHmmss").withZone(UTC)

    /** Checkpoint statuses that are requested from the checkpoint storage when dumping. */
    private val DUMPABLE_CHECKPOINTS = setOf(
            Checkpoint.FlowStatus.RUNNABLE,
            Checkpoint.FlowStatus.HOSPITALIZED,
            Checkpoint.FlowStatus.PAUSED
    )
}
|
||||
|
||||
override val priority: Int = SERVICE_PRIORITY_NORMAL
|
||||
@ -141,7 +146,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
|
||||
try {
|
||||
if (lock.getAndIncrement() == 0 && !file.exists()) {
|
||||
database.transaction {
|
||||
checkpointStorage.getRunnableCheckpoints().use { stream ->
|
||||
checkpointStorage.getCheckpoints(DUMPABLE_CHECKPOINTS).use { stream ->
|
||||
ZipOutputStream(file.outputStream()).use { zip ->
|
||||
stream.forEach { (runId, serialisedCheckpoint) ->
|
||||
|
||||
@ -204,7 +209,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
|
||||
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext)
|
||||
fiber to fiber.logic
|
||||
}
|
||||
is FlowState.Completed -> {
|
||||
else -> {
|
||||
throw IllegalStateException("Only runnable checkpoints with their flow stack are output by the checkpoint dumper")
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,191 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.FiberScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.strands.channels.Channels
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.context.InvocationContext
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.internal.CheckpointSerializationContext
|
||||
import net.corda.core.serialization.internal.checkpointDeserialize
|
||||
import net.corda.core.serialization.internal.checkpointSerialize
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.statemachine.transitions.StateMachine
|
||||
import net.corda.node.utilities.isEnabledTimedFlow
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import java.security.SecureRandom
|
||||
|
||||
/** Holds a flow's state machine [fiber] together with the open future ([resultFuture]) created for its result. */
class Flow<A>(val fiber: FlowStateMachineImpl<A>, val resultFuture: OpenFuture<Any?>)
|
||||
|
||||
/**
 * A flow identified by [runId] with its [checkpoint], plus the session-message events collected for it in
 * [externalEvents].
 */
class NonResidentFlow(val runId: StateMachineRunId, val checkpoint: Checkpoint) {
    val externalEvents: MutableList<Event.DeliverSessionMessage> = mutableListOf()

    /** Appends [message] to [externalEvents]. */
    fun addExternalEvent(message: Event.DeliverSessionMessage) {
        externalEvents += message
    }
}
|
||||
|
||||
class FlowCreator(
|
||||
val checkpointSerializationContext: CheckpointSerializationContext,
|
||||
private val checkpointStorage: CheckpointStorage,
|
||||
val scheduler: FiberScheduler,
|
||||
val database: CordaPersistence,
|
||||
val transitionExecutor: TransitionExecutor,
|
||||
val actionExecutor: ActionExecutor,
|
||||
val secureRandom: SecureRandom,
|
||||
val serviceHub: ServiceHubInternal,
|
||||
val unfinishedFibers: ReusableLatch,
|
||||
val resetCustomTimeout: (StateMachineRunId, Long) -> Unit) {
|
||||
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
}
|
||||
|
||||
/**
 * Creates a [Flow] for [nonResidentFlow].
 *
 * When the held checkpoint has status PAUSED, the checkpoint is re-read from [checkpointStorage] inside a
 * database transaction and deserialized with status RUNNABLE; otherwise the held checkpoint is used as is.
 *
 * @return the created flow, or `null` if the stored checkpoint is missing or no flow could be created from it.
 */
fun createFlowFromNonResidentFlow(nonResidentFlow: NonResidentFlow): Flow<*>? {
    val runId = nonResidentFlow.runId
    val heldCheckpoint = nonResidentFlow.checkpoint
    val checkpointToResume = if (heldCheckpoint.status == Checkpoint.FlowStatus.PAUSED) {
        // For paused flows the serialized flow state was not extracted, so the checkpoint is fetched from the database again.
        val stored = database.transaction { checkpointStorage.getCheckpoint(runId) }
        stored?.copy(status = Checkpoint.FlowStatus.RUNNABLE)?.deserialize(checkpointSerializationContext) ?: return null
    } else {
        heldCheckpoint
    }
    return createFlowFromCheckpoint(runId, checkpointToResume)
}
|
||||
|
||||
/**
 * Creates a [Flow] from [oldCheckpoint], using a copy of it whose status is set to RUNNABLE.
 *
 * The fiber obtained from the checkpoint is given fresh transient values and state, and its flow logic is
 * checked for a `@Suspendable` `call` method.
 *
 * @return the created flow, or `null` if no fiber could be obtained from the checkpoint.
 */
fun createFlowFromCheckpoint(runId: StateMachineRunId, oldCheckpoint: Checkpoint): Flow<*>? {
    val runnableCheckpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
    val restoredFiber = runnableCheckpoint.getFiberFromCheckpoint(runId) ?: return null
    val flowResult = openFuture<Any?>()
    restoredFiber.transientValues = TransientReference(createTransientValues(runId, flowResult))
    restoredFiber.logic.stateMachine = restoredFiber
    verifyFlowLogicIsSuspendable(restoredFiber.logic)
    restoredFiber.transientState = TransientReference(
            createStateMachineState(runnableCheckpoint, restoredFiber, anyCheckpointPersisted = true)
    )
    return Flow(restoredFiber, flowResult)
}
|
||||
|
||||
/**
 * Creates a [Flow] around a new [FlowStateMachineImpl] for [flowLogic].
 *
 * If [existingCheckpoint] is given, a copy of it with status RUNNABLE is used; otherwise a new checkpoint is
 * created from the invocation details and the serialized flow logic.
 */
@Suppress("LongParameterList")
fun <A> createFlowFromLogic(
        flowId: StateMachineRunId,
        invocationContext: InvocationContext,
        flowLogic: FlowLogic<A>,
        flowStart: FlowStart,
        ourIdentity: Party,
        existingCheckpoint: Checkpoint?,
        deduplicationHandler: DeduplicationHandler?,
        senderUUID: String?): Flow<A> {
    // The fiber must be attached to the FlowLogic before the logic is frozen, so that lazy properties
    // can reach the fiber (and thereby the service hub).
    val fiber = FlowStateMachineImpl(flowId, flowLogic, scheduler)
    val flowResult = openFuture<Any?>()
    fiber.transientValues = TransientReference(createTransientValues(flowId, flowResult))
    flowLogic.stateMachine = fiber

    val serializedLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext)
    val subFlowVersion = FlowStateMachineImpl.createSubFlowVersion(
            serviceHub.cordappProvider.getCordappForFlow(flowLogic),
            serviceHub.myInfo.platformVersion
    )

    val startingCheckpoint = if (existingCheckpoint != null) {
        existingCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
    } else {
        Checkpoint.create(
                invocationContext,
                flowStart,
                flowLogic.javaClass,
                serializedLogic,
                ourIdentity,
                subFlowVersion,
                flowLogic.isEnabledTimedFlow()
        ).getOrThrow()
    }

    fiber.transientState = TransientReference(
            createStateMachineState(
                    startingCheckpoint,
                    fiber,
                    existingCheckpoint != null,
                    deduplicationHandler,
                    senderUUID
            )
    )
    return Flow(fiber, flowResult)
}
|
||||
|
||||
/**
 * Obtains the fiber for this checkpoint: a new [FlowStateMachineImpl] around the deserialized flow logic for an
 * unstarted flow, or the deserialized frozen fiber for a started one.
 *
 * @return the fiber, or `null` if deserialization failed or the flow state is neither unstarted nor started.
 */
private fun Checkpoint.getFiberFromCheckpoint(runId: StateMachineRunId): FlowStateMachineImpl<*>? {
    val state = flowState
    return when (state) {
        is FlowState.Unstarted -> {
            val flowLogic = tryCheckpointDeserialize(state.frozenFlowLogic, runId)
            if (flowLogic == null) null else FlowStateMachineImpl(runId, flowLogic, scheduler)
        }
        is FlowState.Started -> tryCheckpointDeserialize(state.frozenFiber, runId)
        // Callers rely on this function returning null when the flow cannot be created from the checkpoint.
        else -> null
    }
}
|
||||
|
||||
/**
 * Deserializes [bytes] with the checkpoint serialization context.
 *
 * @return the deserialized value, or `null` if an [Exception] was thrown; the failure is logged as an error
 * together with [flowId].
 */
@Suppress("TooGenericExceptionCaught")
private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? {
    try {
        return bytes.checkpointDeserialize(context = checkpointSerializationContext)
    } catch (deserializationError: Exception) {
        logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", deserializationError)
        return null
    }
}
|
||||
|
||||
/**
 * Checks that the non-synthetic, zero-argument `call` method of [logic] carries the [Suspendable] annotation.
 *
 * @throws FlowException if the annotation is missing.
 */
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
    // Quasar requires (in Java 8) that at least the call method is annotated suspendable. This is easy to
    // forget when writing a new flow, so it is checked here to give the user a better error.
    //
    // The Kotlin compiler can generate a synthetic bridge method from a single call declaration which forwards
    // to the void method and then returns Unit. Annotations are not copied onto that bridge, so synthetic
    // methods are excluded from the search.
    val callMethod = logic.javaClass.methods.first { method ->
        method.name == "call" && method.parameterCount == 0 && !method.isSynthetic
    }
    if (!callMethod.isAnnotationPresent(Suspendable::class.java)) {
        throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
    }
}
|
||||
|
||||
/**
 * Builds the [FlowStateMachineImpl.TransientValues] for the flow [id]: a new event channel and state machine,
 * the given [resultFuture], and the collaborators held by this creator.
 */
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues =
        FlowStateMachineImpl.TransientValues(
                eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
                resultFuture = resultFuture,
                database = database,
                transitionExecutor = transitionExecutor,
                actionExecutor = actionExecutor,
                stateMachine = StateMachine(id, secureRandom),
                serviceHub = serviceHub,
                checkpointSerializationContext = checkpointSerializationContext,
                unfinishedFibers = unfinishedFibers,
                // Forward wait-time updates to the callback supplied to this creator.
                waitTimeUpdateHook = { runId, newTimeout -> resetCustomTimeout(runId, newTimeout) }
        )
|
||||
|
||||
/**
 * Builds the initial [StateMachineState] for [fiber] from [checkpoint].
 *
 * [deduplicationHandler], when given, becomes the only pending deduplication handler; every boolean flag other
 * than `isAnyCheckpointPersisted` starts as `false` and no future is set.
 */
private fun createStateMachineState(
        checkpoint: Checkpoint,
        fiber: FlowStateMachineImpl<*>,
        anyCheckpointPersisted: Boolean,
        deduplicationHandler: DeduplicationHandler? = null,
        senderUUID: String? = null): StateMachineState = StateMachineState(
        checkpoint = checkpoint,
        pendingDeduplicationHandlers = listOfNotNull(deduplicationHandler),
        isFlowResumed = false,
        future = null,
        isWaitingForFuture = false,
        isAnyCheckpointPersisted = anyCheckpointPersisted,
        isStartIdempotent = false,
        isRemoved = false,
        isKilled = false,
        flowLogic = fiber.logic,
        senderUUID = senderUUID
)
|
||||
}
|
@ -2,9 +2,7 @@ package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.FiberExecutorScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.fibers.instrument.JavaAgent
|
||||
import co.paralleluniverse.strands.channels.Channels
|
||||
import com.codahale.metrics.Gauge
|
||||
import com.google.common.util.concurrent.ThreadFactoryBuilder
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
@ -24,12 +22,9 @@ import net.corda.core.internal.concurrent.mapError
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.internal.mapNotNull
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.internal.CheckpointSerializationContext
|
||||
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
|
||||
import net.corda.core.serialization.internal.checkpointDeserialize
|
||||
import net.corda.core.serialization.internal.checkpointSerialize
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.Try
|
||||
import net.corda.core.utilities.contextLogger
|
||||
@ -39,13 +34,11 @@ import net.corda.node.services.api.CheckpointStorage
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.config.shouldCheckCheckpoints
|
||||
import net.corda.node.services.messaging.DeduplicationHandler
|
||||
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.createSubFlowVersion
|
||||
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
|
||||
import net.corda.node.services.statemachine.interceptors.FiberDeserializationChecker
|
||||
import net.corda.node.services.statemachine.interceptors.FiberDeserializationCheckingInterceptor
|
||||
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
|
||||
import net.corda.node.services.statemachine.interceptors.PrintingInterceptor
|
||||
import net.corda.node.services.statemachine.transitions.StateMachine
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.errorAndTerminate
|
||||
import net.corda.node.utilities.injectOldProgressTracker
|
||||
@ -61,6 +54,7 @@ import java.lang.Integer.min
|
||||
import java.security.SecureRandom
|
||||
import java.time.Duration
|
||||
import java.util.HashSet
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
@ -90,8 +84,6 @@ class SingleThreadedStateMachineManager(
|
||||
private val logger = contextLogger()
|
||||
}
|
||||
|
||||
private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)
|
||||
|
||||
private data class ScheduledTimeout(
|
||||
/** Will fire a [FlowTimeoutException] indicating to the flow hospital to restart the flow. */
|
||||
val scheduledFuture: ScheduledFuture<*>,
|
||||
@ -105,7 +97,8 @@ class SingleThreadedStateMachineManager(
|
||||
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
|
||||
/** True if we're shutting down, so don't resume anything. */
|
||||
var stopping = false
|
||||
val flows = HashMap<StateMachineRunId, Flow>()
|
||||
val flows = HashMap<StateMachineRunId, Flow<*>>()
|
||||
val pausedFlows = HashMap<StateMachineRunId, NonResidentFlow>()
|
||||
val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>()
|
||||
/** Flows scheduled to be retried if not finished within the specified timeout period. */
|
||||
val timedFlows = HashMap<StateMachineRunId, ScheduledTimeout>()
|
||||
@ -127,7 +120,7 @@ class SingleThreadedStateMachineManager(
|
||||
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
|
||||
|
||||
private var checkpointSerializationContext: CheckpointSerializationContext? = null
|
||||
private var actionExecutor: ActionExecutor? = null
|
||||
private lateinit var flowCreator: FlowCreator
|
||||
|
||||
override val flowHospital: StaffedFlowHospital = makeFlowHospital()
|
||||
private val transitionExecutor = makeTransitionExecutor()
|
||||
@ -146,7 +139,7 @@ class SingleThreadedStateMachineManager(
|
||||
*/
|
||||
override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher
|
||||
|
||||
override fun start(tokenizableServices: List<Any>) : CordaFuture<Unit> {
|
||||
override fun start(tokenizableServices: List<Any>, startMode: StateMachineManager.StartMode): CordaFuture<Unit> {
|
||||
checkQuasarJavaAgentPresence()
|
||||
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
|
||||
CheckpointSerializeAsTokenContextImpl(
|
||||
@ -157,8 +150,24 @@ class SingleThreadedStateMachineManager(
|
||||
)
|
||||
)
|
||||
this.checkpointSerializationContext = checkpointSerializationContext
|
||||
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
|
||||
val actionExecutor = makeActionExecutor(checkpointSerializationContext)
|
||||
fiberDeserializationChecker?.start(checkpointSerializationContext)
|
||||
when (startMode) {
|
||||
StateMachineManager.StartMode.ExcludingPaused -> {}
|
||||
StateMachineManager.StartMode.Safe -> markAllFlowsAsPaused()
|
||||
}
|
||||
this.flowCreator = FlowCreator(
|
||||
checkpointSerializationContext,
|
||||
checkpointStorage,
|
||||
scheduler,
|
||||
database,
|
||||
transitionExecutor,
|
||||
actionExecutor,
|
||||
secureRandom,
|
||||
serviceHub,
|
||||
unfinishedFibers,
|
||||
::resetCustomTimeout)
|
||||
|
||||
val fibers = restoreFlowsFromCheckpoints()
|
||||
metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.flows.size })
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
@ -168,6 +177,17 @@ class SingleThreadedStateMachineManager(
|
||||
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
|
||||
}
|
||||
}
|
||||
|
||||
val pausedFlows = restoreNonResidentFlowsFromPausedCheckpoints()
|
||||
mutex.locked {
|
||||
this.pausedFlows.putAll(pausedFlows)
|
||||
for ((id, flow) in pausedFlows) {
|
||||
val checkpoint = flow.checkpoint
|
||||
for (sessionId in getFlowSessionIds(checkpoint)) {
|
||||
sessionToFlow[sessionId] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
return serviceHub.networkMapCache.nodeReady.map {
|
||||
logger.info("Node ready, info: ${serviceHub.myInfo}")
|
||||
resumeRestoredFlows(fibers)
|
||||
@ -241,8 +261,7 @@ class SingleThreadedStateMachineManager(
|
||||
flowLogic = flowLogic,
|
||||
flowStart = FlowStart.Explicit,
|
||||
ourIdentity = ourIdentity ?: ourFirstIdentity,
|
||||
deduplicationHandler = deduplicationHandler,
|
||||
isStartIdempotent = false
|
||||
deduplicationHandler = deduplicationHandler
|
||||
)
|
||||
}
|
||||
|
||||
@ -282,6 +301,22 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Delegates to `checkpointStorage.markAllPaused()`.
 *
 * Called from `start` when the start mode is `StateMachineManager.StartMode.Safe`, before any checkpoints are
 * restored.
 */
private fun markAllFlowsAsPaused() {
|
||||
return checkpointStorage.markAllPaused()
|
||||
}
|
||||
|
||||
/**
 * Turns the paused flow with the given [id] back into a running flow.
 *
 * While holding the mutex, the flow is removed from `pausedFlows`, recreated through
 * `flowCreator.createFlowFromNonResidentFlow`, registered and started via `addAndStartFlow`, and then every external
 * event that was queued on the paused flow is scheduled on the new fiber.
 *
 * @return `false` if no paused flow with this [id] is held or the flow could not be recreated; `true` otherwise.
 */
override fun unPauseFlow(id: StateMachineRunId): Boolean {
|
||||
mutex.locked {
|
||||
val pausedFlow = pausedFlows.remove(id) ?: return false
|
||||
// NOTE(review): on a null result the entry has already been removed from pausedFlows above, so the events
// queued on it are no longer reachable through that map - confirm this is intended.
val flow = flowCreator.createFlowFromNonResidentFlow(pausedFlow) ?: return false
|
||||
addAndStartFlow(flow.fiber.id, flow)
|
||||
// Replay the events that were queued on the paused flow, after the flow has been added and started.
for (event in pausedFlow.externalEvents) {
|
||||
flow.fiber.scheduleEvent(event)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) {
|
||||
val previousFlowId = sessionToFlow.put(sessionId, flowId)
|
||||
if (previousFlowId != null) {
|
||||
@ -352,23 +387,28 @@ class SingleThreadedStateMachineManager(
|
||||
liveFibers.countUp()
|
||||
}
|
||||
|
||||
private fun restoreFlowsFromCheckpoints(): List<Flow> {
|
||||
return checkpointStorage.getRunnableCheckpoints().use {
|
||||
private fun restoreFlowsFromCheckpoints(): List<Flow<*>> {
|
||||
return checkpointStorage.getCheckpointsToRun().use {
|
||||
it.mapNotNull { (id, serializedCheckpoint) ->
|
||||
// If a flow is added before start() then don't attempt to restore it
|
||||
mutex.locked { if (id in flows) return@mapNotNull null }
|
||||
createFlowFromCheckpoint(
|
||||
id = id,
|
||||
serializedCheckpoint = serializedCheckpoint,
|
||||
initialDeduplicationHandler = null,
|
||||
isAnyCheckpointPersisted = true,
|
||||
isStartIdempotent = false
|
||||
)
|
||||
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@mapNotNull null
|
||||
flowCreator.createFlowFromCheckpoint(id, checkpoint)
|
||||
}.toList()
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFlows(flows: List<Flow>) {
|
||||
/**
 * Loads the paused checkpoints from `checkpointStorage` and wraps each one in a [NonResidentFlow].
 *
 * @return the non-resident flows keyed by flow id; checkpoints for which `tryDeserializeCheckpoint` returns
 * `null` are left out.
 */
private fun restoreNonResidentFlowsFromPausedCheckpoints(): Map<StateMachineRunId, NonResidentFlow> {
|
||||
return checkpointStorage.getPausedCheckpoints().use {
|
||||
it.mapNotNull { (id, serializedCheckpoint) ->
|
||||
// Skip any checkpoint that cannot be deserialized (tryDeserializeCheckpoint returned null).
|
||||
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@mapNotNull null
|
||||
id to NonResidentFlow(id, checkpoint)
|
||||
}.toList().toMap()
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFlows(flows: List<Flow<*>>) {
|
||||
for (flow in flows) {
|
||||
addAndStartFlow(flow.fiber.id, flow)
|
||||
}
|
||||
@ -393,14 +433,10 @@ class SingleThreadedStateMachineManager(
|
||||
logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.")
|
||||
return
|
||||
}
|
||||
|
||||
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return
|
||||
// Resurrect flow
|
||||
createFlowFromCheckpoint(
|
||||
id = flowId,
|
||||
serializedCheckpoint = serializedCheckpoint,
|
||||
initialDeduplicationHandler = null,
|
||||
isAnyCheckpointPersisted = true,
|
||||
isStartIdempotent = false
|
||||
) ?: return
|
||||
flowCreator.createFlowFromCheckpoint(flowId, checkpoint) ?: return
|
||||
} else {
|
||||
// Just flow initiation message
|
||||
null
|
||||
@ -503,9 +539,13 @@ class SingleThreadedStateMachineManager(
|
||||
logger.info("Cannot find flow corresponding to session ID - $recipientId.")
|
||||
}
|
||||
} else {
|
||||
mutex.locked { flows[flowId] }?.run {
|
||||
fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
|
||||
} ?: logger.info("Cannot find fiber corresponding to flow ID $flowId")
|
||||
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)
|
||||
mutex.locked {
|
||||
flows[flowId]?.run { fiber.scheduleEvent(event) }
|
||||
// If flow is not running add it to the list of external events to be processed if/when the flow resumes.
|
||||
?: pausedFlows[flowId]?.run { addExternalEvent(event) }
|
||||
?: logger.info("Cannot find fiber corresponding to flow ID $flowId")
|
||||
}
|
||||
}
|
||||
} catch (exception: Exception) {
|
||||
logger.error("Exception while routing $sessionMessage", exception)
|
||||
@ -582,8 +622,7 @@ class SingleThreadedStateMachineManager(
|
||||
flowLogic,
|
||||
flowStart,
|
||||
ourIdentity,
|
||||
initiatingMessageDeduplicationHandler,
|
||||
isStartIdempotent = false
|
||||
initiatingMessageDeduplicationHandler
|
||||
)
|
||||
}
|
||||
|
||||
@ -594,20 +633,9 @@ class SingleThreadedStateMachineManager(
|
||||
flowLogic: FlowLogic<A>,
|
||||
flowStart: FlowStart,
|
||||
ourIdentity: Party,
|
||||
deduplicationHandler: DeduplicationHandler?,
|
||||
isStartIdempotent: Boolean
|
||||
deduplicationHandler: DeduplicationHandler?
|
||||
): CordaFuture<FlowStateMachine<A>> {
|
||||
|
||||
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
|
||||
// have access to the fiber (and thereby the service hub)
|
||||
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
|
||||
val resultFuture = openFuture<Any?>()
|
||||
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
|
||||
flowLogic.stateMachine = flowStateMachineImpl
|
||||
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
|
||||
|
||||
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
|
||||
|
||||
val flowAlreadyExists = mutex.locked { flows[flowId] != null }
|
||||
|
||||
val existingCheckpoint = if (flowAlreadyExists) {
|
||||
@ -629,37 +657,15 @@ class SingleThreadedStateMachineManager(
|
||||
// This is a brand new flow
|
||||
null
|
||||
}
|
||||
val checkpoint = existingCheckpoint?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: Checkpoint.create(
|
||||
invocationContext,
|
||||
flowStart,
|
||||
flowLogic.javaClass,
|
||||
frozenFlowLogic,
|
||||
ourIdentity,
|
||||
flowCorDappVersion,
|
||||
flowLogic.isEnabledTimedFlow()
|
||||
).getOrThrow()
|
||||
|
||||
val flow = flowCreator.createFlowFromLogic(flowId, invocationContext, flowLogic, flowStart, ourIdentity, existingCheckpoint, deduplicationHandler, ourSenderUUID)
|
||||
val startedFuture = openFuture<Unit>()
|
||||
val initialState = StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isWaitingForFuture = false,
|
||||
future = null,
|
||||
isAnyCheckpointPersisted = existingCheckpoint != null,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
isKilled = false,
|
||||
flowLogic = flowLogic,
|
||||
senderUUID = ourSenderUUID
|
||||
)
|
||||
flowStateMachineImpl.transientState = TransientReference(initialState)
|
||||
mutex.locked {
|
||||
startedFutures[flowId] = startedFuture
|
||||
}
|
||||
totalStartedFlows.inc()
|
||||
addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture))
|
||||
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
|
||||
addAndStartFlow(flowId, flow)
|
||||
return startedFuture.map { flow.fiber as FlowStateMachine<A> }
|
||||
}
|
||||
|
||||
override fun scheduleFlowTimeout(flowId: StateMachineRunId) {
|
||||
@ -739,7 +745,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
|
||||
/** Schedules a [FlowTimeoutException] to be fired in order to restart the flow. */
|
||||
private fun scheduleTimeoutException(flow: Flow, delay: Long): ScheduledFuture<*> {
|
||||
private fun scheduleTimeoutException(flow: Flow<*>, delay: Long): ScheduledFuture<*> {
|
||||
return with(serviceHub.configuration.flowTimeout) {
|
||||
scheduledFutureExecutor.schedule({
|
||||
val event = Event.Error(FlowTimeoutException())
|
||||
@ -767,43 +773,6 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Checks that the `call()` method of [logic] is annotated with [Suspendable].
 *
 * The method inspected is the first non-synthetic, zero-parameter method named `call` among
 * `logic.javaClass.methods`.
 *
 * @param logic the flow whose `call()` method is checked.
 * @throws FlowException if that method does not carry the [Suspendable] annotation.
 */
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
|
||||
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
|
||||
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
|
||||
//
|
||||
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
|
||||
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
|
||||
// bridge, so we have to do a more complex scan here.
|
||||
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
|
||||
if (call.getAnnotation(Suspendable::class.java) == null) {
|
||||
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the [FlowStateMachineImpl.TransientValues] for the flow identified by [id].
 *
 * A new event queue channel and a new [StateMachine] are created on every call. `actionExecutor` and
 * `checkpointSerializationContext` are nullable properties that are assigned in `start`; both are dereferenced
 * with `!!` here, so this function fails if it runs before they have been set.
 *
 * @param id run id of the flow; handed to the [StateMachine] created here.
 * @param resultFuture future stored in the returned values as `resultFuture`.
 * @return the assembled transient values.
 */
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
|
||||
return FlowStateMachineImpl.TransientValues(
|
||||
// NOTE(review): -1 is presumably Quasar's "unbounded" buffer size - confirm against the Channels.newChannel docs.
eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK),
|
||||
resultFuture = resultFuture,
|
||||
database = database,
|
||||
transitionExecutor = transitionExecutor,
|
||||
actionExecutor = actionExecutor!!,
|
||||
stateMachine = StateMachine(id, secureRandom),
|
||||
serviceHub = serviceHub,
|
||||
checkpointSerializationContext = checkpointSerializationContext!!,
|
||||
unfinishedFibers = unfinishedFibers,
|
||||
// Forwards wait-time updates to resetCustomTimeout.
waitTimeUpdateHook = { flowId, timeout -> resetCustomTimeout(flowId, timeout) }
|
||||
)
|
||||
}
|
||||
|
||||
/**
 * Deserializes [bytes] with the checkpoint serialization context, converting any failure into a `null` result.
 *
 * `checkpointSerializationContext` is dereferenced with `!!` inside the `try`, so a missing context is also
 * caught and reported as a deserialization failure.
 *
 * @param bytes the serialized value to deserialize.
 * @param flowId id of the owning flow; only used in the error log message.
 * @return the deserialized value, or `null` if any [Exception] was thrown (the exception is logged at error level).
 */
private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? {
|
||||
return try {
|
||||
bytes.checkpointDeserialize(context = checkpointSerializationContext!!)
|
||||
} catch (e: Exception) {
|
||||
logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private fun tryDeserializeCheckpoint(serializedCheckpoint: Checkpoint.Serialized, flowId: StateMachineRunId): Checkpoint? {
|
||||
return try {
|
||||
serializedCheckpoint.deserialize(checkpointSerializationContext!!)
|
||||
@ -813,68 +782,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Recreates a [Flow] (fiber plus result future) from a serialized checkpoint.
 *
 * The checkpoint is deserialized and its status is overwritten with `RUNNABLE`. Depending on the checkpoint's
 * flow state:
 * - `Unstarted`: the frozen flow logic is deserialized and a new [FlowStateMachineImpl] is created around it;
 * - `Started`: the frozen fiber itself is deserialized;
 * - `Completed`: no flow is created.
 *
 * In the first two cases the fiber is given fresh transient values and a fresh [StateMachineState], and the flow
 * logic is checked with `verifyFlowLogicIsSuspendable` before the flow is returned.
 *
 * @param id run id of the flow being restored.
 * @param serializedCheckpoint the stored checkpoint to restore from.
 * @param isAnyCheckpointPersisted value copied into the new state's `isAnyCheckpointPersisted`.
 * @param isStartIdempotent value copied into the new state's `isStartIdempotent`.
 * @param initialDeduplicationHandler if non-null, becomes the single pending deduplication handler.
 * @return the restored flow, or `null` if the checkpoint, flow logic or fiber could not be deserialized, or the
 * flow state is `Completed`.
 */
private fun createFlowFromCheckpoint(
|
||||
id: StateMachineRunId,
|
||||
serializedCheckpoint: Checkpoint.Serialized,
|
||||
isAnyCheckpointPersisted: Boolean,
|
||||
isStartIdempotent: Boolean,
|
||||
initialDeduplicationHandler: DeduplicationHandler?
|
||||
): Flow? {
|
||||
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id)?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: return null
|
||||
val resultFuture = openFuture<Any?>()
|
||||
val fiber = when (checkpoint.flowState) {
|
||||
is FlowState.Unstarted -> {
|
||||
val logic = tryCheckpointDeserialize(checkpoint.flowState.frozenFlowLogic, id) ?: return null
|
||||
val state = StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isWaitingForFuture = false,
|
||||
future = null,
|
||||
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
isKilled = false,
|
||||
flowLogic = logic,
|
||||
senderUUID = null
|
||||
)
|
||||
val fiber = FlowStateMachineImpl(id, logic, scheduler)
|
||||
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
|
||||
fiber.transientState = TransientReference(state)
|
||||
fiber.logic.stateMachine = fiber
|
||||
fiber
|
||||
}
|
||||
is FlowState.Started -> {
|
||||
val fiber = tryCheckpointDeserialize(checkpoint.flowState.frozenFiber, id) ?: return null
|
||||
val state = StateMachineState(
|
||||
checkpoint = checkpoint,
|
||||
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
|
||||
isFlowResumed = false,
|
||||
isWaitingForFuture = false,
|
||||
future = null,
|
||||
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
|
||||
isStartIdempotent = isStartIdempotent,
|
||||
isRemoved = false,
|
||||
isKilled = false,
|
||||
flowLogic = fiber.logic,
|
||||
senderUUID = null
|
||||
)
|
||||
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
|
||||
fiber.transientState = TransientReference(state)
|
||||
fiber.logic.stateMachine = fiber
|
||||
fiber
|
||||
}
|
||||
is FlowState.Completed -> {
|
||||
return null // Callers of this function rely on it returning null if the flow cannot be created from the checkpoint.
|
||||
}
|
||||
}
|
||||
|
||||
verifyFlowLogicIsSuspendable(fiber.logic)
|
||||
|
||||
return Flow(fiber, resultFuture)
|
||||
}
|
||||
|
||||
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) {
|
||||
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow<*>) {
|
||||
val checkpoint = flow.fiber.snapshot().checkpoint
|
||||
for (sessionId in getFlowSessionIds(checkpoint)) {
|
||||
sessionToFlow[sessionId] = id
|
||||
@ -899,7 +807,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
}
|
||||
|
||||
private fun startOrResume(checkpoint: Checkpoint, flow: Flow) {
|
||||
private fun startOrResume(checkpoint: Checkpoint, flow: Flow<*>) {
|
||||
when (checkpoint.flowState) {
|
||||
is FlowState.Unstarted -> {
|
||||
flow.fiber.start()
|
||||
@ -953,7 +861,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
|
||||
private fun InnerState.removeFlowOrderly(
|
||||
flow: Flow,
|
||||
flow: Flow<*>,
|
||||
removalReason: FlowRemovalReason.OrderlyFinish,
|
||||
lastState: StateMachineState
|
||||
) {
|
||||
@ -969,7 +877,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
|
||||
private fun InnerState.removeFlowError(
|
||||
flow: Flow,
|
||||
flow: Flow<*>,
|
||||
removalReason: FlowRemovalReason.ErrorFinish,
|
||||
lastState: StateMachineState
|
||||
) {
|
||||
@ -983,7 +891,7 @@ class SingleThreadedStateMachineManager(
|
||||
}
|
||||
|
||||
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
|
||||
private fun drainFlowEventQueue(flow: Flow) {
|
||||
private fun drainFlowEventQueue(flow: Flow<*>) {
|
||||
while (true) {
|
||||
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
|
||||
when (event) {
|
||||
|
@ -30,12 +30,18 @@ import java.time.Duration
|
||||
* TODO: Don't store all active flows in memory, load from the database on demand.
|
||||
*/
|
||||
interface StateMachineManager {
|
||||
|
||||
/**
 * Selects how [start] treats the flows already in checkpoint storage.
 *
 * NOTE(review): in the `SingleThreadedStateMachineManager` implementation, `Safe` marks every flow as paused
 * before checkpoints are restored, while `ExcludingPaused` performs no extra step - confirm for other
 * implementations.
 */
enum class StartMode {
|
||||
ExcludingPaused, // Resume all flows except paused flows.
|
||||
Safe // Mark all flows as paused.
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts the state machine manager, loading and starting the state machines in storage.
|
||||
*
|
||||
* @return `Future` which completes when SMM is fully started
|
||||
*/
|
||||
fun start(tokenizableServices: List<Any>) : CordaFuture<Unit>
|
||||
fun start(tokenizableServices: List<Any>, startMode: StartMode = StartMode.ExcludingPaused) : CordaFuture<Unit>
|
||||
|
||||
/**
|
||||
* Stops the state machine manager gracefully, waiting until all but [allowedUnsuspendedFiberCount] flows reach the
|
||||
@ -80,6 +86,13 @@ interface StateMachineManager {
|
||||
*/
|
||||
fun killFlow(id: StateMachineRunId): Boolean
|
||||
|
||||
/**
|
||||
* Start a paused flow, identified by its run [id].
|
||||
*
|
||||
* @return whether the flow was successfully started; `false` if it could not be started (for example, presumably, when no paused flow with this [id] exists - confirm against the implementation).
|
||||
*/
|
||||
fun unPauseFlow(id: StateMachineRunId): Boolean
|
||||
|
||||
/**
|
||||
* Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow.
|
||||
* The event may be replayed if a flow fails and attempts to retry.
|
||||
|
@ -170,9 +170,14 @@ data class Checkpoint(
|
||||
* @return A [Checkpoint] with all its fields filled in from [Checkpoint.Serialized]
|
||||
*/
|
||||
fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint {
|
||||
val flowState = when(status) {
|
||||
FlowStatus.PAUSED -> FlowState.Paused
|
||||
FlowStatus.COMPLETED -> FlowState.Completed
|
||||
else -> serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext)
|
||||
}
|
||||
return Checkpoint(
|
||||
checkpointState = serializedCheckpointState.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
|
||||
flowState = serializedFlowState?.checkpointDeserialize(checkpointSerializationContext) ?: FlowState.Completed,
|
||||
flowState = flowState,
|
||||
errorState = errorState,
|
||||
result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
|
||||
status = status,
|
||||
@ -307,6 +312,11 @@ sealed class FlowState {
|
||||
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
|
||||
}
|
||||
|
||||
/**
|
||||
* The flow is paused. To save memory the real flow state is not held in memory: a checkpoint whose status is PAUSED is deserialized with this marker in place of its stored flow state.
|
||||
*/
|
||||
object Paused: FlowState()
|
||||
|
||||
/**
|
||||
* The flow has completed. It does not have a running fiber that needs to be serialized and checkpointed.
|
||||
*/
|
||||
|
@ -29,6 +29,7 @@ class DoRemainingWorkTransition(
|
||||
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, flowState).transition()
|
||||
is FlowState.Started -> StartedFlowTransition(context, startingState, flowState).transition()
|
||||
is FlowState.Completed -> throw IllegalStateException("Cannot transition a state with completed flow state.")
|
||||
is FlowState.Paused -> throw IllegalStateException("Cannot transition a state with paused flow state.")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,6 +32,7 @@ class NodeStartupCliTest {
|
||||
Assertions.assertThat(startup.verbose).isEqualTo(false)
|
||||
Assertions.assertThat(startup.loggingLevel).isEqualTo(Level.INFO)
|
||||
Assertions.assertThat(startup.cmdLineOptions.noLocalShell).isEqualTo(false)
|
||||
Assertions.assertThat(startup.cmdLineOptions.safeMode).isEqualTo(false)
|
||||
Assertions.assertThat(startup.cmdLineOptions.sshdServer).isEqualTo(false)
|
||||
Assertions.assertThat(startup.cmdLineOptions.justGenerateNodeInfo).isEqualTo(false)
|
||||
Assertions.assertThat(startup.cmdLineOptions.justGenerateRpcSslCerts).isEqualTo(false)
|
||||
|
@ -63,7 +63,7 @@ import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal fun CheckpointStorage.getAllIncompleteCheckpoints(): List<Checkpoint.Serialized> {
|
||||
return getRunnableCheckpoints().use {
|
||||
return getCheckpointsToRun().use {
|
||||
it.map { it.second }.toList()
|
||||
}.filter { it.status != Checkpoint.FlowStatus.COMPLETED }
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import net.corda.core.context.InvocationOrigin
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.internal.FlowIORequest
|
||||
import net.corda.core.internal.toSet
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
|
||||
import net.corda.core.serialization.internal.checkpointSerialize
|
||||
@ -41,13 +42,13 @@ import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import java.time.Clock
|
||||
import java.util.ArrayList
|
||||
import java.util.*
|
||||
import kotlin.streams.toList
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
internal fun CheckpointStorage.checkpoints(): List<Checkpoint.Serialized> {
|
||||
return getAllCheckpoints().use {
|
||||
return getCheckpoints().use {
|
||||
it.map { it.second }.toList()
|
||||
}
|
||||
}
|
||||
@ -148,18 +149,38 @@ class DBCheckpointStorageTests {
|
||||
checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
|
||||
}
|
||||
|
||||
val completedCheckpoint = checkpoint.copy(flowState = FlowState.Completed)
|
||||
val completedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED)
|
||||
database.transaction {
|
||||
checkpointStorage.updateCheckpoint(id, completedCheckpoint, null)
|
||||
}
|
||||
database.transaction {
|
||||
assertEquals(
|
||||
completedCheckpoint,
|
||||
completedCheckpoint.copy(flowState = FlowState.Completed),
|
||||
checkpointStorage.checkpoints().single().deserialize()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Stores a checkpoint, updates it to status PAUSED while passing null as the serialized flow state, and checks
// that the stored checkpoint then deserializes to the paused checkpoint with flowState == FlowState.Paused.
@Test(timeout = 300_000)
|
||||
fun `update a checkpoint to paused`() {
|
||||
val (id, checkpoint) = newCheckpoint()
|
||||
val serializedFlowState = checkpoint.serializeFlowState()
|
||||
database.transaction {
|
||||
checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
|
||||
}
|
||||
|
||||
val pausedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED)
|
||||
database.transaction {
|
||||
checkpointStorage.updateCheckpoint(id, pausedCheckpoint, null)
|
||||
}
|
||||
database.transaction {
|
||||
assertEquals(
|
||||
pausedCheckpoint.copy(flowState = FlowState.Paused),
|
||||
checkpointStorage.checkpoints().single().deserialize()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `removing a checkpoint deletes from all checkpoint tables`() {
|
||||
val exception = IllegalStateException("I am a naughty exception")
|
||||
@ -641,7 +662,7 @@ class DBCheckpointStorageTests {
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `fetch runnable checkpoints`() {
|
||||
fun `Checkpoints can be fetched from the database by status`() {
|
||||
val (_, checkpoint) = newCheckpoint(1)
|
||||
// runnables
|
||||
val runnable = checkpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE)
|
||||
@ -650,8 +671,8 @@ class DBCheckpointStorageTests {
|
||||
val completed = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED)
|
||||
val failed = checkpoint.copy(status = Checkpoint.FlowStatus.FAILED)
|
||||
val killed = checkpoint.copy(status = Checkpoint.FlowStatus.KILLED)
|
||||
// tentative
|
||||
val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) // is considered runnable
|
||||
// paused
|
||||
val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED)
|
||||
|
||||
database.transaction {
|
||||
val serializedFlowState =
|
||||
@ -666,7 +687,15 @@ class DBCheckpointStorageTests {
|
||||
}
|
||||
|
||||
database.transaction {
|
||||
assertEquals(3, checkpointStorage.getRunnableCheckpoints().count())
|
||||
val toRunStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.HOSPITALIZED)
|
||||
val pausedStatuses = setOf(Checkpoint.FlowStatus.PAUSED)
|
||||
val customStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.KILLED)
|
||||
val customStatuses1 = setOf(Checkpoint.FlowStatus.PAUSED, Checkpoint.FlowStatus.HOSPITALIZED, Checkpoint.FlowStatus.FAILED)
|
||||
|
||||
assertEquals(toRunStatuses, checkpointStorage.getCheckpointsToRun().map { it.second.status }.toSet())
|
||||
assertEquals(pausedStatuses, checkpointStorage.getPausedCheckpoints().map { it.second.status }.toSet())
|
||||
assertEquals(customStatuses, checkpointStorage.getCheckpoints(customStatuses).map { it.second.status }.toSet())
|
||||
assertEquals(customStatuses1, checkpointStorage.getCheckpoints(customStatuses1).map { it.second.status }.toSet())
|
||||
}
|
||||
}
|
||||
|
||||
@ -749,6 +778,78 @@ class DBCheckpointStorageTests {
|
||||
else -> throw IllegalStateException("Unknown line.separator")
|
||||
}
|
||||
|
||||
// Adds a single checkpoint with status PAUSED and checks what getPausedCheckpoints() returns for it: the same id,
// no serialized flow state and no result, matching status / progress step / IO request, and a deserialized
// checkpoint whose checkpoint state matches and whose flow state is FlowState.Paused.
@Test(timeout = 300_000)
|
||||
fun `paused checkpoints can be extracted`() {
|
||||
val (id, checkpoint) = newCheckpoint()
|
||||
val serializedFlowState = checkpoint.serializeFlowState()
|
||||
val pausedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED)
|
||||
database.transaction {
|
||||
checkpointStorage.addCheckpoint(id, pausedCheckpoint, serializedFlowState)
|
||||
}
|
||||
|
||||
database.transaction {
|
||||
val (extractedId, extractedCheckpoint) = checkpointStorage.getPausedCheckpoints().toList().single()
|
||||
assertEquals(id, extractedId)
|
||||
// The result and the serialized flow state are not extracted for a paused checkpoint
|
||||
assertEquals(null, extractedCheckpoint.serializedFlowState)
|
||||
assertEquals(null, extractedCheckpoint.result)
|
||||
|
||||
assertEquals(pausedCheckpoint.status, extractedCheckpoint.status)
|
||||
assertEquals(pausedCheckpoint.progressStep, extractedCheckpoint.progressStep)
|
||||
assertEquals(pausedCheckpoint.flowIoRequest, extractedCheckpoint.flowIoRequest)
|
||||
|
||||
val deserialisedCheckpoint = extractedCheckpoint.deserialize()
|
||||
assertEquals(pausedCheckpoint.checkpointState, deserialisedCheckpoint.checkpointState)
|
||||
assertEquals(FlowState.Paused, deserialisedCheckpoint.flowState)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
fun `checkpoints correctly change there status to paused`() {
    val (_, template) = newCheckpoint(1)
    val pausedStatus = Checkpoint.FlowStatus.PAUSED

    // One checkpoint per status, each under its own random id. Declaration order matches insertion order below.
    val runnable = changeStatus(template, Checkpoint.FlowStatus.RUNNABLE)
    val hospitalized = changeStatus(template, Checkpoint.FlowStatus.HOSPITALIZED)
    val completed = changeStatus(template, Checkpoint.FlowStatus.COMPLETED)
    val failed = changeStatus(template, Checkpoint.FlowStatus.FAILED)
    val killed = changeStatus(template, Checkpoint.FlowStatus.KILLED)
    val paused = changeStatus(template, pausedStatus)

    database.transaction {
        // Every entry shares the same serialized flow state taken from the template checkpoint.
        val serializedFlowState =
                template.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)

        listOf(runnable, hospitalized, completed, failed, killed, paused).forEach { (id, checkpoint) ->
            checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
        }
    }

    database.transaction {
        checkpointStorage.markAllPaused()
    }

    // Expected status of each stored checkpoint after markAllPaused(), checked in this order.
    val expectedAfterPause = listOf(
            // Runnable and hospitalized checkpoints should have been switched to PAUSED.
            runnable to pausedStatus,
            hospitalized to pausedStatus,
            // Checkpoints in any other status should be left untouched.
            completed to Checkpoint.FlowStatus.COMPLETED,
            failed to Checkpoint.FlowStatus.FAILED,
            killed to Checkpoint.FlowStatus.KILLED,
            paused to pausedStatus
    )

    database.transaction {
        expectedAfterPause.forEach { (entry, expectedStatus) ->
            assertEquals(expectedStatus, checkpointStorage.getDBCheckpoint(entry.id)!!.status)
        }
    }
}
|
||||
|
||||
/**
 * Pairs a [Checkpoint] with the [StateMachineRunId] it is stored under in these tests.
 */
data class IdAndCheckpoint(
        val id: StateMachineRunId,
        val checkpoint: Checkpoint
)
|
||||
|
||||
/**
 * Builds a copy of [oldCheckpoint] with its status replaced by [status] and pairs it with a freshly generated
 * random run id.
 */
private fun changeStatus(oldCheckpoint: Checkpoint, status: Checkpoint.FlowStatus): IdAndCheckpoint =
        IdAndCheckpoint(
                id = StateMachineRunId.createRandom(),
                checkpoint = oldCheckpoint.copy(status = status)
        )
|
||||
|
||||
private fun newCheckpointStorage() {
|
||||
database.transaction {
|
||||
checkpointStorage = DBCheckpointStorage(
|
||||
|
@ -682,14 +682,14 @@ class FlowFrameworkTests {
|
||||
if (firstExecution) {
|
||||
throw HospitalizeFlowException()
|
||||
} else {
|
||||
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
|
||||
dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
|
||||
inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState!!.value.checkpoint.status
|
||||
|
||||
futureFiber.complete(flowFiber)
|
||||
}
|
||||
}
|
||||
SuspendingFlow.hookAfterCheckpoint = {
|
||||
dbCheckpointStatusAfterSuspension = aliceNode.internals.checkpointStorage.getRunnableCheckpoints().toList().single()
|
||||
dbCheckpointStatusAfterSuspension = aliceNode.internals.checkpointStorage.getCheckpointsToRun().toList().single()
|
||||
.second.status
|
||||
}
|
||||
|
||||
@ -701,7 +701,7 @@ class FlowFrameworkTests {
|
||||
val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState?.value?.checkpoint?.status
|
||||
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus)
|
||||
aliceNode.database.transaction {
|
||||
val checkpoint = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second
|
||||
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
|
||||
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status)
|
||||
}
|
||||
// restart Node - flow will be loaded from checkpoint
|
||||
@ -729,7 +729,7 @@ class FlowFrameworkTests {
|
||||
if (firstExecution) {
|
||||
throw HospitalizeFlowException()
|
||||
} else {
|
||||
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
|
||||
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
|
||||
inMemoryCheckpointStatus = flowFiber.transientState!!.value.checkpoint.status
|
||||
|
||||
futureFiber.complete(flowFiber)
|
||||
@ -742,7 +742,7 @@ class FlowFrameworkTests {
|
||||
// flow is in hospital
|
||||
assertTrue(flowState is FlowState.Started)
|
||||
aliceNode.database.transaction {
|
||||
val checkpoint = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second
|
||||
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second
|
||||
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status)
|
||||
}
|
||||
// restart Node - flow will be loaded from checkpoint
|
||||
@ -812,7 +812,7 @@ class FlowFrameworkTests {
|
||||
throw SQLTransientConnectionException("connection is not available")
|
||||
} else {
|
||||
val flowFiber = this as? FlowStateMachineImpl<*>
|
||||
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getAllCheckpoints().toList().single().second.status
|
||||
dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status
|
||||
inMemoryCheckpointStatus = flowFiber!!.transientState!!.value.checkpoint.status
|
||||
persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails
|
||||
}
|
||||
|
@ -0,0 +1,77 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.nhaarman.mockito_kotlin.doReturn
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.node.services.config.NodeConfiguration
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.node.internal.InternalMockNetwork
|
||||
import net.corda.testing.node.internal.InternalMockNodeParameters
|
||||
import net.corda.testing.node.internal.TestStartedNode
|
||||
import net.corda.testing.node.internal.startFlow
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.time.Duration
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
/**
 * Checks that flows checkpointed on a node are left in the PAUSED status when that node is restarted with
 * [StateMachineManager.StartMode.Safe].
 */
class FlowPausingTests {

    companion object {
        /** How many [CheckpointingFlow] instances the test starts. */
        const val NUMBER_OF_FLOWS = 4

        /** Milliseconds each [CheckpointingFlow] sleeps for; also used to size the wait in the test. */
        const val SLEEP_TIME = 1000L
    }

    private lateinit var mockNet: InternalMockNetwork
    private lateinit var aliceNode: TestStartedNode
    private lateinit var bobNode: TestStartedNode

    @Before
    fun setUpMockNet() {
        mockNet = InternalMockNetwork()
        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
    }

    @After
    fun cleanUp() {
        mockNet.stopNodes()
    }

    /**
     * Restarts [node] with its mocked configuration reporting [smmStartMode] as the state machine start mode.
     *
     * @return the restarted node.
     */
    private fun restartNode(node: TestStartedNode, smmStartMode: StateMachineManager.StartMode): TestStartedNode {
        val restartParameters = InternalMockNodeParameters(configOverrides = { conf: NodeConfiguration ->
            doReturn(smmStartMode).whenever(conf).smmStartMode
        })
        return mockNet.restartNode(node, parameters = restartParameters)
    }

    @Test(timeout = 300_000)
    fun `All are paused when the node is restarted in safe start mode`() {
        val flows: List<FlowStateMachine<Unit>> = (1..NUMBER_OF_FLOWS).map {
            aliceNode.services.startFlow(CheckpointingFlow())
        }

        // None of the flows may resume before the node restarts.
        val restartedAlice = restartNode(aliceNode, StateMachineManager.StartMode.Safe)
        assertEquals(0, restartedAlice.smm.snapshot().size)

        // Wait long enough that any flow which was (wrongly) running would have finished.
        Thread.sleep(NUMBER_OF_FLOWS * SLEEP_TIME)

        restartedAlice.database.transaction {
            flows.forEach { flow ->
                val checkpoint = restartedAlice.internals.checkpointStorage.getCheckpoint(flow.id)
                assertEquals(Checkpoint.FlowStatus.PAUSED, checkpoint!!.status)
            }
        }
    }

    /** Flow whose only action is to sleep for [SLEEP_TIME] milliseconds. */
    internal class CheckpointingFlow : FlowLogic<Unit>() {
        @Suspendable
        override fun call() {
            sleep(Duration.ofMillis(SLEEP_TIME))
        }
    }
}
|
@ -638,6 +638,7 @@ private fun mockNodeConfiguration(certificatesDirectory: Path): NodeConfiguratio
|
||||
doReturn(NetworkParameterAcceptanceSettings()).whenever(it).networkParameterAcceptanceSettings
|
||||
doReturn(rigorousMock<ConfigurationWithOptions>()).whenever(it).configurationWithOptions
|
||||
doReturn(2).whenever(it).flowExternalOperationThreadPoolSize
|
||||
doReturn(StateMachineManager.StartMode.ExcludingPaused).whenever(it).smmStartMode
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user