CORDA-3491 Remove the flow state when a flow finishes (#6083)

Added a new field Completed to the in-memory object FlowState.

FlowState.Completed is corresponds to flow_state=Null in the DB.

This change will save disk space.
This commit is contained in:
williamvigorr3 2020-03-30 16:56:03 +01:00 committed by GitHub
parent ce202995c5
commit 024d63147d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 134 additions and 42 deletions

View File

@ -18,7 +18,7 @@ interface CheckpointStorage {
/** /**
* Update an existing checkpoint. Will throw if there is not checkpoint for this id. * Update an existing checkpoint. Will throw if there is not checkpoint for this id.
*/ */
fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>) fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?)
/** /**
* Remove existing checkpoint from the store. * Remove existing checkpoint from the store.

View File

@ -17,7 +17,7 @@ interface CheckpointPerformanceRecorder {
/** /**
* Record performance metrics regarding the serialized size of [CheckpointState] and [FlowState] * Record performance metrics regarding the serialized size of [CheckpointState] and [FlowState]
*/ */
fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>) fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>?)
} }
class DBCheckpointPerformanceRecorder(metrics: MetricRegistry) : CheckpointPerformanceRecorder { class DBCheckpointPerformanceRecorder(metrics: MetricRegistry) : CheckpointPerformanceRecorder {
@ -44,8 +44,15 @@ class DBCheckpointPerformanceRecorder(metrics: MetricRegistry) : CheckpointPerfo
} }
} }
override fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>) { override fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>?) {
val totalSize = serializedCheckpointState.size.toLong() + serializedFlowState.size.toLong() /* For now we don't record states where the serializedFlowState is null and thus the checkpoint is in a completed state.
As this will skew the mean with lots of small checkpoints. For the moment we only measure runnable checkpoints. */
serializedFlowState?.let {
updateData(serializedCheckpointState.size.toLong() + it.size.toLong())
}
}
private fun updateData(totalSize: Long) {
checkpointingMeter.mark() checkpointingMeter.mark()
checkpointSizesThisSecond.update(totalSize) checkpointSizesThisSecond.update(totalSize)
var lastUpdateTime = lastBandwidthUpdate.get() var lastUpdateTime = lastBandwidthUpdate.get()

View File

@ -134,10 +134,9 @@ class DBCheckpointStorage(
@Column(name = "checkpoint_value", nullable = false) @Column(name = "checkpoint_value", nullable = false)
var checkpoint: ByteArray = EMPTY_BYTE_ARRAY, var checkpoint: ByteArray = EMPTY_BYTE_ARRAY,
// A future task will make this nullable
@Type(type = "corda-blob") @Type(type = "corda-blob")
@Column(name = "flow_state", nullable = false) @Column(name = "flow_state")
var flowStack: ByteArray = EMPTY_BYTE_ARRAY, var flowStack: ByteArray?,
@Column(name = "hmac") @Column(name = "hmac")
var hmac: ByteArray, var hmac: ByteArray,
@ -269,7 +268,7 @@ class DBCheckpointStorage(
currentDBSession().save(createDBCheckpoint(id, checkpoint, serializedFlowState)) currentDBSession().save(createDBCheckpoint(id, checkpoint, serializedFlowState))
} }
override fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>) { override fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?) {
currentDBSession().update(updateDBCheckpoint(id, checkpoint, serializedFlowState)) currentDBSession().update(updateDBCheckpoint(id, checkpoint, serializedFlowState))
} }
@ -368,7 +367,7 @@ class DBCheckpointStorage(
private fun updateDBCheckpoint( private fun updateDBCheckpoint(
id: StateMachineRunId, id: StateMachineRunId,
checkpoint: Checkpoint, checkpoint: Checkpoint,
serializedFlowState: SerializedBytes<FlowState> serializedFlowState: SerializedBytes<FlowState>?
): DBFlowCheckpoint { ): DBFlowCheckpoint {
val flowId = id.uuid.toString() val flowId = id.uuid.toString()
val now = clock.instant() val now = clock.instant()
@ -408,12 +407,12 @@ class DBCheckpointStorage(
private fun createDBCheckpointBlob( private fun createDBCheckpointBlob(
serializedCheckpointState: SerializedBytes<CheckpointState>, serializedCheckpointState: SerializedBytes<CheckpointState>,
serializedFlowState: SerializedBytes<FlowState>, serializedFlowState: SerializedBytes<FlowState>?,
now: Instant now: Instant
): DBFlowCheckpointBlob { ): DBFlowCheckpointBlob {
return DBFlowCheckpointBlob( return DBFlowCheckpointBlob(
checkpoint = serializedCheckpointState.bytes, checkpoint = serializedCheckpointState.bytes,
flowStack = serializedFlowState.bytes, flowStack = serializedFlowState?.bytes,
hmac = ByteArray(HMAC_SIZE_BYTES), hmac = ByteArray(HMAC_SIZE_BYTES),
persistedInstant = now persistedInstant = now
) )
@ -506,9 +505,10 @@ class DBCheckpointStorage(
} }
private fun DBFlowCheckpoint.toSerializedCheckpoint(): Checkpoint.Serialized { private fun DBFlowCheckpoint.toSerializedCheckpoint(): Checkpoint.Serialized {
val serialisedFlowState = blob.flowStack?.let { SerializedBytes<FlowState>(it) }
return Checkpoint.Serialized( return Checkpoint.Serialized(
serializedCheckpointState = SerializedBytes(blob.checkpoint), serializedCheckpointState = SerializedBytes(blob.checkpoint),
serializedFlowState = SerializedBytes(blob.flowStack), serializedFlowState = serialisedFlowState,
// Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint // Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint
errorState = ErrorState.Clean, errorState = ErrorState.Clean,
// A checkpoint with a result should not normally be loaded (it should be [null] most of the time) // A checkpoint with a result should not normally be loaded (it should be [null] most of the time)

View File

@ -204,6 +204,9 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext) val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext)
fiber to fiber.logic fiber to fiber.logic
} }
is FlowState.Completed -> {
throw IllegalStateException("Only runnable checkpoints with their flow stack are output by the checkpoint dumper")
}
} }
val flowCallStack = if (fiber != null) { val flowCallStack = if (fiber != null) {

View File

@ -95,11 +95,18 @@ class ActionExecutorImpl(
@Suspendable @Suspendable
private fun executePersistCheckpoint(action: Action.PersistCheckpoint) { private fun executePersistCheckpoint(action: Action.PersistCheckpoint) {
val checkpoint = action.checkpoint val checkpoint = action.checkpoint
val serializedFlowState = checkpoint.flowState.checkpointSerialize(checkpointSerializationContext) val flowState = checkpoint.flowState
val serializedFlowState = when(flowState) {
FlowState.Completed -> null
else -> flowState.checkpointSerialize(checkpointSerializationContext)
}
if (action.isCheckpointUpdate) { if (action.isCheckpointUpdate) {
checkpointStorage.updateCheckpoint(action.id, checkpoint, serializedFlowState) checkpointStorage.updateCheckpoint(action.id, checkpoint, serializedFlowState)
} else { } else {
checkpointStorage.addCheckpoint(action.id, checkpoint, serializedFlowState) if (flowState is FlowState.Completed) {
throw IllegalStateException("A new checkpoint cannot be created with a Completed FlowState.")
}
checkpointStorage.addCheckpoint(action.id, checkpoint, serializedFlowState!!)
} }
} }

View File

@ -780,11 +780,10 @@ class SingleThreadedStateMachineManager(
initialDeduplicationHandler: DeduplicationHandler? initialDeduplicationHandler: DeduplicationHandler?
): Flow? { ): Flow? {
val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id)?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: return null val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id)?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: return null
val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
val fiber = when (flowState) { val fiber = when (checkpoint.flowState) {
is FlowState.Unstarted -> { is FlowState.Unstarted -> {
val logic = tryCheckpointDeserialize(flowState.frozenFlowLogic, id) ?: return null val logic = tryCheckpointDeserialize(checkpoint.flowState.frozenFlowLogic, id) ?: return null
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -803,7 +802,7 @@ class SingleThreadedStateMachineManager(
fiber fiber
} }
is FlowState.Started -> { is FlowState.Started -> {
val fiber = tryCheckpointDeserialize(flowState.frozenFiber, id) ?: return null val fiber = tryCheckpointDeserialize(checkpoint.flowState.frozenFiber, id) ?: return null
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -820,6 +819,9 @@ class SingleThreadedStateMachineManager(
fiber.logic.stateMachine = fiber fiber.logic.stateMachine = fiber
fiber fiber
} }
is FlowState.Completed -> {
return null // Places calling this function is rely on it to return null if the flow cannot be created from the checkpoint.
}
} }
verifyFlowLogicIsSuspendable(fiber.logic) verifyFlowLogicIsSuspendable(fiber.logic)
@ -847,18 +849,23 @@ class SingleThreadedStateMachineManager(
val flowLogic = flow.fiber.logic val flowLogic = flow.fiber.logic
if (flowLogic.isEnabledTimedFlow()) scheduleTimeout(id) if (flowLogic.isEnabledTimedFlow()) scheduleTimeout(id)
flow.fiber.scheduleEvent(Event.DoRemainingWork) flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) { startOrResume(checkpoint, flow)
is FlowState.Unstarted -> {
flow.fiber.start()
}
is FlowState.Started -> {
Fiber.unparkDeserialized(flow.fiber, scheduler)
}
}
} }
} }
} }
private fun startOrResume(checkpoint: Checkpoint, flow: Flow) {
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
flow.fiber.start()
}
is FlowState.Started -> {
Fiber.unparkDeserialized(flow.fiber, scheduler)
}
is FlowState.Completed -> throw IllegalStateException("Cannot start (or resume) a completed flow.")
}
}
private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> { private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> {
val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated
return if (initiatedFlowStart == null) { return if (initiatedFlowStart == null) {

View File

@ -149,7 +149,7 @@ data class Checkpoint(
*/ */
data class Serialized( data class Serialized(
val serializedCheckpointState: SerializedBytes<CheckpointState>, val serializedCheckpointState: SerializedBytes<CheckpointState>,
val serializedFlowState: SerializedBytes<FlowState>, val serializedFlowState: SerializedBytes<FlowState>?,
val errorState: ErrorState, val errorState: ErrorState,
val result: SerializedBytes<Any>?, val result: SerializedBytes<Any>?,
val status: FlowStatus, val status: FlowStatus,
@ -165,7 +165,7 @@ data class Checkpoint(
fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint { fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint {
return Checkpoint( return Checkpoint(
checkpointState = serializedCheckpointState.deserialize(context = SerializationDefaults.STORAGE_CONTEXT), checkpointState = serializedCheckpointState.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
flowState = serializedFlowState.checkpointDeserialize(checkpointSerializationContext), flowState = serializedFlowState?.checkpointDeserialize(checkpointSerializationContext) ?: FlowState.Completed,
errorState = errorState, errorState = errorState,
result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT), result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
status = status, status = status,
@ -299,6 +299,12 @@ sealed class FlowState {
) : FlowState() { ) : FlowState() {
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})" override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
} }
/**
* The flow has completed. It does not have a running fiber that needs to be serialized and checkpointed.
*/
object Completed : FlowState()
} }
/** /**

View File

@ -24,10 +24,11 @@ class DoRemainingWorkTransition(
// If the flow is clean check the FlowState // If the flow is clean check the FlowState
private fun cleanTransition(): TransitionResult { private fun cleanTransition(): TransitionResult {
val checkpoint = startingState.checkpoint val flowState = startingState.checkpoint.flowState
return when (checkpoint.flowState) { return when (flowState) {
is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition() is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, flowState).transition()
is FlowState.Started -> StartedFlowTransition(context, startingState, checkpoint.flowState).transition() is FlowState.Started -> StartedFlowTransition(context, startingState, flowState).transition()
is FlowState.Completed -> throw IllegalStateException("Cannot transition a state with completed flow state.")
} }
} }

View File

@ -218,6 +218,7 @@ class TopLevelTransition(
checkpointState = checkpoint.checkpointState.copy( checkpointState = checkpoint.checkpointState.copy(
numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1 numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1
), ),
flowState = FlowState.Completed,
result = event.returnValue, result = event.returnValue,
status = Checkpoint.FlowStatus.COMPLETED status = Checkpoint.FlowStatus.COMPLETED
), ),

View File

@ -47,7 +47,7 @@
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="flow_state" type="varbinary(33554432)"> <column name="flow_state" type="varbinary(33554432)">
<constraints nullable="false"/> <constraints nullable="true"/>
</column> </column>
<column name="timestamp" type="java.sql.Types.TIMESTAMP"> <column name="timestamp" type="java.sql.Types.TIMESTAMP">
<constraints nullable="false"/> <constraints nullable="false"/>

View File

@ -47,7 +47,7 @@
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="flow_state" type="blob"> <column name="flow_state" type="blob">
<constraints nullable="false"/> <constraints nullable="true"/>
</column> </column>
<column name="timestamp" type="java.sql.Types.TIMESTAMP"> <column name="timestamp" type="java.sql.Types.TIMESTAMP">
<constraints nullable="false"/> <constraints nullable="false"/>

View File

@ -133,6 +133,26 @@ class DBCheckpointStorageTests {
} }
} }
@Test(timeout = 300_000)
fun `update a checkpoint to completed`() {
val (id, checkpoint) = newCheckpoint()
val serializedFlowState = checkpoint.serializeFlowState()
database.transaction {
checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState)
}
val completedCheckpoint = checkpoint.copy(flowState = FlowState.Completed)
database.transaction {
checkpointStorage.updateCheckpoint(id, completedCheckpoint, null)
}
database.transaction {
assertEquals(
completedCheckpoint,
checkpointStorage.checkpoints().single().deserialize()
)
}
}
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `remove checkpoint`() { fun `remove checkpoint`() {
val (id, checkpoint) = newCheckpoint() val (id, checkpoint) = newCheckpoint()
@ -530,7 +550,8 @@ class DBCheckpointStorageTests {
val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) // is considered runnable val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) // is considered runnable
database.transaction { database.transaction {
val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) val serializedFlowState =
checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), runnable, serializedFlowState) checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), runnable, serializedFlowState)
checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), hospitalized, serializedFlowState) checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), hospitalized, serializedFlowState)
@ -551,7 +572,7 @@ class DBCheckpointStorageTests {
object : CheckpointPerformanceRecorder { object : CheckpointPerformanceRecorder {
override fun record( override fun record(
serializedCheckpointState: SerializedBytes<CheckpointState>, serializedCheckpointState: SerializedBytes<CheckpointState>,
serializedFlowState: SerializedBytes<FlowState> serializedFlowState: SerializedBytes<FlowState>?
) { ) {
// do nothing // do nothing
} }

View File

@ -5,21 +5,26 @@ import com.natpryce.hamkrest.containsSubstring
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import junit.framework.TestCase.assertNull
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.createDirectories import net.corda.core.internal.createDirectories
import net.corda.core.internal.deleteIfExists import net.corda.core.internal.deleteIfExists
import net.corda.core.internal.deleteRecursively import net.corda.core.internal.deleteRecursively
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.internal.inputStream import net.corda.core.internal.inputStream
import net.corda.core.internal.isRegularFile
import net.corda.core.internal.list
import net.corda.core.internal.readFully import net.corda.core.internal.readFully
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.toNonEmptySet
import net.corda.nodeapi.internal.lifecycle.NodeServicesContext import net.corda.nodeapi.internal.lifecycle.NodeServicesContext
import net.corda.nodeapi.internal.lifecycle.NodeLifecycleEvent import net.corda.nodeapi.internal.lifecycle.NodeLifecycleEvent
import net.corda.node.internal.NodeStartup import net.corda.node.internal.NodeStartup
@ -40,10 +45,12 @@ import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import java.time.Clock import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.zip.ZipInputStream import java.util.zip.ZipInputStream
import kotlin.test.assertEquals
class CheckpointDumperImplTest { class CheckpointDumperImplTest {
@ -111,7 +118,29 @@ class CheckpointDumperImplTest {
} }
dumper.dumpCheckpoints() dumper.dumpCheckpoints()
checkDumpFile() checkDumpFile()
}
@Test(timeout=300_000)
fun `Checkpoint dumper doesn't output completed checkpoints`() {
val dumper = CheckpointDumperImpl(checkpointStorage, database, services, baseDirectory)
dumper.update(mockAfterStartEvent)
// add a checkpoint
val (id, checkpoint) = newCheckpoint()
database.transaction {
checkpointStorage.addCheckpoint(id, checkpoint, serializeFlowState(checkpoint))
}
val newCheckpoint = checkpoint.copy(
flowState = FlowState.Completed,
status = Checkpoint.FlowStatus.COMPLETED
)
database.transaction {
checkpointStorage.updateCheckpoint(id, newCheckpoint, null)
}
dumper.dumpCheckpoints()
checkDumpFileEmpty()
} }
private fun checkDumpFile() { private fun checkDumpFile() {
@ -123,6 +152,13 @@ class CheckpointDumperImplTest {
} }
} }
private fun checkDumpFileEmpty() {
ZipInputStream(file.inputStream()).use { zip ->
val entry = zip.nextEntry
assertNull(entry)
}
}
// This test will only succeed when the VM startup includes the "checkpoint-agent": // This test will only succeed when the VM startup includes the "checkpoint-agent":
// -javaagent:tools/checkpoint-agent/build/libs/checkpoint-agent.jar // -javaagent:tools/checkpoint-agent/build/libs/checkpoint-agent.jar
@Test(timeout=300_000) @Test(timeout=300_000)
@ -147,7 +183,7 @@ class CheckpointDumperImplTest {
object : CheckpointPerformanceRecorder { object : CheckpointPerformanceRecorder {
override fun record( override fun record(
serializedCheckpointState: SerializedBytes<CheckpointState>, serializedCheckpointState: SerializedBytes<CheckpointState>,
serializedFlowState: SerializedBytes<FlowState> serializedFlowState: SerializedBytes<FlowState>?
) { ) {
// do nothing // do nothing
} }

View File

@ -67,6 +67,7 @@ import org.assertj.core.api.Condition
import org.junit.After import org.junit.After
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Assert.assertNotEquals import org.junit.Assert.assertNotEquals
import org.junit.Assert.assertNotNull
import org.junit.Assert.assertNull import org.junit.Assert.assertNull
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -103,7 +104,7 @@ class FlowFrameworkTests {
object : CheckpointPerformanceRecorder { object : CheckpointPerformanceRecorder {
override fun record( override fun record(
serializedCheckpointState: SerializedBytes<CheckpointState>, serializedCheckpointState: SerializedBytes<CheckpointState>,
serializedFlowState: SerializedBytes<FlowState> serializedFlowState: SerializedBytes<FlowState>?
) { ) {
// do nothing // do nothing
} }
@ -342,13 +343,14 @@ class FlowFrameworkTests {
//We should update this test when we do the work to persists the flow result. //We should update this test when we do the work to persists the flow result.
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `Flow status is set to completed in database when the flow finishes`() { fun `Flow status is set to completed in database when the flow finishes and serialised flow state is null`() {
val terminationSignal = Semaphore(0) val terminationSignal = Semaphore(0)
val flow = aliceNode.services.startFlow(NoOpFlow( terminateUponSignal = terminationSignal)) val flow = aliceNode.services.startFlow(NoOpFlow( terminateUponSignal = terminationSignal))
mockNet.waitQuiescent() // current thread needs to wait fiber running on a different thread, has reached the blocking point mockNet.waitQuiescent() // current thread needs to wait fiber running on a different thread, has reached the blocking point
aliceNode.database.transaction { aliceNode.database.transaction {
val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id) val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id)
assertNull(checkpoint!!.result) assertNull(checkpoint!!.result)
assertNotNull(checkpoint.serializedFlowState)
assertNotEquals(Checkpoint.FlowStatus.COMPLETED, checkpoint.status) assertNotEquals(Checkpoint.FlowStatus.COMPLETED, checkpoint.status)
} }
terminationSignal.release() terminationSignal.release()
@ -356,6 +358,7 @@ class FlowFrameworkTests {
aliceNode.database.transaction { aliceNode.database.transaction {
val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id) val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id)
assertNull(checkpoint!!.result) assertNull(checkpoint!!.result)
assertNull(checkpoint.serializedFlowState)
assertEquals(Checkpoint.FlowStatus.COMPLETED, checkpoint.status) assertEquals(Checkpoint.FlowStatus.COMPLETED, checkpoint.status)
} }
} }