CORDA-3599 Add progress tracker information to checkpoint (#6063)

* Add progress tracker information to checkpoint

The checkpoint Datebase is updated when the statemachine suspends
with the progress trackers current step name. This is truncated if
it is longer than the Database column.

* Minor rename in statemachine for clarity
This commit is contained in:
williamvigorr3
2020-03-16 09:30:23 +00:00
committed by GitHub
parent 0174d996bd
commit 1025ee1dee
6 changed files with 55 additions and 7 deletions

View File

@ -43,6 +43,8 @@ class DBCheckpointStorage(private val checkpointPerformanceRecorder: CheckpointP
private const val HMAC_SIZE_BYTES = 16 private const val HMAC_SIZE_BYTES = 16
private const val MAX_PROGRESS_STEP_LENGTH = 256
/** /**
* This needs to run before Hibernate is initialised. * This needs to run before Hibernate is initialised.
* *
@ -342,7 +344,7 @@ class DBCheckpointStorage(private val checkpointPerformanceRecorder: CheckpointP
this.flowMetadata = entity.flowMetadata this.flowMetadata = entity.flowMetadata
this.status = checkpoint.status this.status = checkpoint.status
this.compatible = checkpoint.compatible this.compatible = checkpoint.compatible
this.progressStep = checkpoint.progressStep this.progressStep = checkpoint.progressStep?.take(MAX_PROGRESS_STEP_LENGTH)
this.ioRequestType = checkpoint.flowIoRequest this.ioRequestType = checkpoint.flowIoRequest
this.checkpointInstant = now this.checkpointInstant = now
} }

View File

@ -6,6 +6,7 @@ import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import java.util.* import java.util.*
@ -101,17 +102,20 @@ sealed class Event {
* @param ioRequest the request triggering the suspension. * @param ioRequest the request triggering the suspension.
* @param maySkipCheckpoint indicates whether the persistence may be skipped. * @param maySkipCheckpoint indicates whether the persistence may be skipped.
* @param fiber the serialised stack of the flow. * @param fiber the serialised stack of the flow.
* @param progressStep the current progress tracker step.
*/ */
data class Suspend( data class Suspend(
val ioRequest: FlowIORequest<*>, val ioRequest: FlowIORequest<*>,
val maySkipCheckpoint: Boolean, val maySkipCheckpoint: Boolean,
val fiber: SerializedBytes<FlowStateMachineImpl<*>> val fiber: SerializedBytes<FlowStateMachineImpl<*>>,
var progressStep: ProgressTracker.Step?
) : Event() { ) : Event() {
override fun toString() = override fun toString() =
"Suspend(" + "Suspend(" +
"ioRequest=$ioRequest, " + "ioRequest=$ioRequest, " +
"maySkipCheckpoint=$maySkipCheckpoint, " + "maySkipCheckpoint=$maySkipCheckpoint, " +
"fiber=${fiber.hash}, " + "fiber=${fiber.hash}, " +
"currentStep=${progressStep?.label}" +
")" ")"
} }

View File

@ -430,7 +430,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
Event.Suspend( Event.Suspend(
ioRequest = ioRequest, ioRequest = ioRequest,
maySkipCheckpoint = skipPersistingCheckpoint, maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.checkpointSerialize(context = serializationContext.value) fiber = this.checkpointSerialize(context = serializationContext.value),
progressStep = logic.progressTracker?.currentStep
) )
} catch (exception: Exception) { } catch (exception: Exception) {
Event.Error(exception) Event.Error(exception)

View File

@ -159,7 +159,8 @@ class TopLevelTransition(
checkpointState = currentState.checkpoint.checkpointState.copy( checkpointState = currentState.checkpoint.checkpointState.copy(
numberOfSuspends = currentState.checkpoint.checkpointState.numberOfSuspends + 1 numberOfSuspends = currentState.checkpoint.checkpointState.numberOfSuspends + 1
), ),
flowIoRequest = event.ioRequest::class.java.simpleName flowIoRequest = event.ioRequest::class.java.simpleName,
progressStep = event.progressStep?.label
) )
if (event.maySkipCheckpoint) { if (event.maySkipCheckpoint) {
actions.addAll(arrayOf( actions.addAll(arrayOf(

View File

@ -23,7 +23,6 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.LogHelper import net.corda.testing.internal.LogHelper
@ -467,6 +466,33 @@ class DBCheckpointStorageTests {
} }
} }
@Test(timeout = 300_000)
fun `Checkpoint truncates long progressTracker step name`() {
val maxProgressStepLength = 256
val (id, checkpoint) = newCheckpoint(1)
database.transaction {
val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
val checkpointFromStorage = checkpointStorage.getCheckpoint(id)
assertNull(checkpointFromStorage!!.progressStep)
}
val longString = """Long string Long string Long string Long string Long string Long string Long string Long string Long string
Long string Long string Long string Long string Long string Long string Long string Long string Long string Long string
Long string Long string Long string Long string Long string Long string Long string Long string Long string Long string
""".trimIndent()
database.transaction {
val newCheckpoint = checkpoint.copy(progressStep = longString)
val serializedFlowState = newCheckpoint.flowState.checkpointSerialize(
context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT
)
checkpointStorage.updateCheckpoint(id, newCheckpoint, serializedFlowState)
}
database.transaction {
val checkpointFromStorage = checkpointStorage.getCheckpoint(id)
assertEquals(longString.take(maxProgressStepLength), checkpointFromStorage!!.progressStep)
}
}
private fun newCheckpointStorage() { private fun newCheckpointStorage() {
database.transaction { database.transaction {
checkpointStorage = DBCheckpointStorage(object : CheckpointPerformanceRecorder { checkpointStorage = DBCheckpointStorage(object : CheckpointPerformanceRecorder {

View File

@ -20,8 +20,8 @@ import net.corda.core.flows.ReceiveFinalityFlow
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.DeclaredField import net.corda.core.internal.DeclaredField
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.concurrent.flatMap import net.corda.core.internal.concurrent.flatMap
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
@ -79,7 +79,6 @@ import java.util.*
import java.util.function.Predicate import java.util.function.Predicate
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.streams.toList import kotlin.streams.toList
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith import kotlin.test.assertFailsWith
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -356,6 +355,21 @@ class FlowFrameworkTests {
} }
} }
@Test(timeout = 300_000)
fun `Flow persists progress tracker in the database when the flow suspends`() {
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedReceiveFlow(it) }
val aliceFlowId = aliceNode.services.startFlow(ReceiveFlow(bob)).id
mockNet.runNetwork()
aliceNode.database.transaction {
val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoint(aliceFlowId)
assertEquals(ReceiveFlow.START_STEP.label, checkpoint!!.progressStep)
}
bobNode.database.transaction {
val checkpoints = bobNode.internals.checkpointStorage.checkpoints().single()
assertEquals(InitiatedReceiveFlow.START_STEP.label, checkpoints.progressStep)
}
}
private class ConditionalExceptionFlow(val otherPartySession: FlowSession, val sendPayload: Any) : FlowLogic<Unit>() { private class ConditionalExceptionFlow(val otherPartySession: FlowSession, val sendPayload: Any) : FlowLogic<Unit>() {
@Suspendable @Suspendable
override fun call() { override fun call() {