diff --git a/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt b/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt index b08dda81d7..5eb0817584 100644 --- a/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt +++ b/common/logging/src/main/kotlin/net/corda/common/logging/Constants.kt @@ -9,4 +9,4 @@ package net.corda.common.logging * (originally added to source control for ease of use) */ -internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT" +internal const val CURRENT_MAJOR_RELEASE = "4.6-SNAPSHOT" \ No newline at end of file diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt index 14a2607b26..b3cee7c1ca 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt @@ -14,6 +14,7 @@ import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.minutes import net.corda.core.utilities.seconds +import net.corda.node.services.statemachine.Checkpoint import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.CHARLIE_NAME @@ -74,10 +75,9 @@ class FlowIsKilledTest { assertEquals(11, AFlowThatWantsToDieAndKillsItsFriends.position) assertTrue(AFlowThatWantsToDieAndKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatWantsToDieAndKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) 
- val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, aliceCheckpoints) - val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, bobCheckpoints) + assertEquals(1, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -104,10 +104,9 @@ class FlowIsKilledTest { handle.returnValue.getOrThrow(1.minutes) } assertEquals(11, AFlowThatGetsMurderedByItsFriendResponder.position) - val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, aliceCheckpoints) - val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, bobCheckpoints) + assertEquals(2, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, alice.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } @@ -356,4 +355,18 @@ class FlowIsKilledTest { } } } + + @StartableByRPC + class GetNumberOfFailedCheckpointsFlow : FlowLogic<Long>() { + override fun call(): Long { + return serviceHub.jdbcSession() + .prepareStatement("select count(*) from node_checkpoints where status = ${Checkpoint.FlowStatus.FAILED.ordinal}") + .use { ps -> + ps.executeQuery().use { rs -> + rs.next() + rs.getLong(1) + } + } + } + } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt b/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt index b9f66ca423..ef90810b05 100644 --- 
a/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt +++ b/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt @@ -18,21 +18,53 @@ import java.security.Principal * @property impersonatedActor Optional impersonated actor, used for logging but not for authorisation. */ @CordaSerializable -data class InvocationContext(val origin: InvocationOrigin, val trace: Trace, val actor: Actor?, val externalTrace: Trace? = null, val impersonatedActor: Actor? = null) { +data class InvocationContext( + val origin: InvocationOrigin, + val trace: Trace, + val actor: Actor?, + val externalTrace: Trace? = null, + val impersonatedActor: Actor? = null, + val arguments: List<Any?> = emptyList() +) { + + constructor( + origin: InvocationOrigin, + trace: Trace, + actor: Actor?, + externalTrace: Trace? = null, + impersonatedActor: Actor? = null + ) : this(origin, trace, actor, externalTrace, impersonatedActor, emptyList()) + companion object { /** * Creates an [InvocationContext] with a [Trace] that defaults to a [java.util.UUID] as value and [java.time.Instant.now] timestamp. */ @DeleteForDJVM @JvmStatic - fun newInstance(origin: InvocationOrigin, trace: Trace = Trace.newInstance(), actor: Actor? = null, externalTrace: Trace? = null, impersonatedActor: Actor? = null) = InvocationContext(origin, trace, actor, externalTrace, impersonatedActor) + @JvmOverloads + @Suppress("LongParameterList") + fun newInstance( + origin: InvocationOrigin, + trace: Trace = Trace.newInstance(), + actor: Actor? = null, + externalTrace: Trace? = null, + impersonatedActor: Actor? = null, + arguments: List<Any?> = emptyList() + ) = InvocationContext(origin, trace, actor, externalTrace, impersonatedActor, arguments) /** * Creates an [InvocationContext] with [InvocationOrigin.RPC] origin. */ @DeleteForDJVM @JvmStatic - fun rpc(actor: Actor, trace: Trace = Trace.newInstance(), externalTrace: Trace? = null, impersonatedActor: Actor? 
= null): InvocationContext = newInstance(InvocationOrigin.RPC(actor), trace, actor, externalTrace, impersonatedActor) + @JvmOverloads + fun rpc( + actor: Actor, + trace: Trace = Trace.newInstance(), + externalTrace: Trace? = null, + impersonatedActor: Actor? = null, + arguments: List<Any?> = emptyList() + ): InvocationContext = newInstance(InvocationOrigin.RPC(actor), trace, actor, externalTrace, impersonatedActor, arguments) /** * Creates an [InvocationContext] with [InvocationOrigin.Peer] origin. @@ -67,6 +99,23 @@ data class InvocationContext(val origin: InvocationOrigin, val trace: Trace, val * Associated security principal. */ fun principal(): Principal = origin.principal() + + fun copy( + origin: InvocationOrigin = this.origin, + trace: Trace = this.trace, + actor: Actor? = this.actor, + externalTrace: Trace? = this.externalTrace, + impersonatedActor: Actor? = this.impersonatedActor + ): InvocationContext { + return copy( + origin = origin, + trace = trace, + actor = actor, + externalTrace = externalTrace, + impersonatedActor = impersonatedActor, + arguments = arguments + ) + } } /** diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt index 494c5099aa..0d54a4715a 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt @@ -45,6 +45,7 @@ sealed class FlowIORequest<out R : Any> { * @property shouldRetrySend specifies whether the send should be retried. * @return a map from session to received message. */ + //net.corda.core.internal.FlowIORequest.SendAndReceive data class SendAndReceive( val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>, val shouldRetrySend: Boolean @@ -80,7 +81,15 @@ sealed class FlowIORequest<out R : Any> { /** * Suspend the flow until all Initiating sessions are confirmed. 
*/ - object WaitForSessionConfirmations : FlowIORequest<Unit>() + class WaitForSessionConfirmations : FlowIORequest<Unit>() { + override fun equals(other: Any?): Boolean { + return this === other + } + + override fun hashCode(): Int { + return System.identityHashCode(this) + } + } /** * Execute the specified [operation], suspend the flow until completion. diff --git a/core/src/main/kotlin/net/corda/core/internal/messaging/InternalCordaRPCOps.kt b/core/src/main/kotlin/net/corda/core/internal/messaging/InternalCordaRPCOps.kt index 8f92d54c32..e3ab065422 100644 --- a/core/src/main/kotlin/net/corda/core/internal/messaging/InternalCordaRPCOps.kt +++ b/core/src/main/kotlin/net/corda/core/internal/messaging/InternalCordaRPCOps.kt @@ -1,5 +1,6 @@ package net.corda.core.internal.messaging +import net.corda.core.flows.StateMachineRunId import net.corda.core.internal.AttachmentTrustInfo import net.corda.core.messaging.CordaRPCOps @@ -13,4 +14,11 @@ interface InternalCordaRPCOps : CordaRPCOps { /** Get all attachment trust information */ val attachmentTrustInfos: List<AttachmentTrustInfo> + + /** + * Resume a paused flow. + * + * @return whether the flow was successfully resumed. 
+ */ + fun unPauseFlow(id: StateMachineRunId): Boolean } \ No newline at end of file diff --git a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/persistence/MissingSchemaMigrationTest.kt b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/persistence/MissingSchemaMigrationTest.kt index b81a0eaceb..a9f422497f 100644 --- a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/persistence/MissingSchemaMigrationTest.kt +++ b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/persistence/MissingSchemaMigrationTest.kt @@ -50,7 +50,7 @@ class MissingSchemaMigrationTest { fun `test that an error is thrown when forceThrowOnMissingMigration is set and a mapped schema is missing a migration`() { assertThatThrownBy { createSchemaMigration(setOf(GoodSchema), true) - .nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }) + .nodeStartup(dataSource.connection.use { DBCheckpointStorage.getCheckpointCount(it) != 0L }) }.isInstanceOf(MissingMigrationException::class.java) } @@ -58,7 +58,7 @@ class MissingSchemaMigrationTest { fun `test that an error is not thrown when forceThrowOnMissingMigration is not set and a mapped schema is missing a migration`() { assertDoesNotThrow { createSchemaMigration(setOf(GoodSchema), false) - .nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }) + .nodeStartup(dataSource.connection.use { DBCheckpointStorage.getCheckpointCount(it) != 0L }) } } @@ -67,7 +67,7 @@ class MissingSchemaMigrationTest { assertDoesNotThrow("This test failure indicates " + "a new table has been added to the node without the appropriate migration scripts being present") { createSchemaMigration(NodeSchemaService().internalSchemas(), false) - .nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }) + .nodeStartup(dataSource.connection.use { DBCheckpointStorage.getCheckpointCount(it) != 0L }) } } diff --git 
a/node/src/integration-test-slow/kotlin/net/corda/node/flows/FlowCheckpointVersionNodeStartupCheckTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/flows/FlowCheckpointVersionNodeStartupCheckTest.kt index 1e60ce62fb..d6abe718f1 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/flows/FlowCheckpointVersionNodeStartupCheckTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/flows/FlowCheckpointVersionNodeStartupCheckTest.kt @@ -17,7 +17,7 @@ import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.NodeParameters import net.corda.testing.driver.driver import net.corda.testing.node.internal.ListenProcessDeathException -import net.corda.testing.node.internal.assertCheckpoints +import net.corda.testing.node.internal.assertUncompletedCheckpoints import net.corda.testing.node.internal.enclosedCordapp import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -75,7 +75,7 @@ class FlowCheckpointVersionNodeStartupCheckTest { } private fun DriverDSL.assertBobFailsToStartWithLogMessage(logMessage: String) { - assertCheckpoints(BOB_NAME, 1) + assertUncompletedCheckpoints(BOB_NAME, 1) assertFailsWith(ListenProcessDeathException::class) { startNode(NodeParameters( diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineErrorHandlingTest.kt index d933625407..e7bda45134 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineErrorHandlingTest.kt @@ -110,9 +110,10 @@ abstract class StatemachineErrorHandlingTest { } @StartableByRPC - class GetNumberOfCheckpointsFlow : FlowLogic<Long>() { + class GetNumberOfUncompletedCheckpointsFlow : FlowLogic<Long>() { override fun call(): Long { - return 
serviceHub.jdbcSession().prepareStatement("select count(*) from node_checkpoints").use { ps -> + val sqlStatement = "select count(*) from node_checkpoints where status not in (${Checkpoint.FlowStatus.COMPLETED.ordinal})" + return serviceHub.jdbcSession().prepareStatement(sqlStatement).use { ps -> ps.executeQuery().use { rs -> rs.next() rs.getLong(1) diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineFinalityErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineFinalityErrorHandlingTest.kt index 1855aa11c3..98e199afe2 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineFinalityErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineFinalityErrorHandlingTest.kt @@ -89,9 +89,9 @@ class StatemachineFinalityErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, aliceClient.stateMachinesSnapshot().size) assertEquals(1, charlieClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 ReceiveFinalityFlow and 1 for GetNumberOfCheckpointsFlow - assertEquals(2, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -160,9 +160,9 @@ class StatemachineFinalityErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, aliceClient.stateMachinesSnapshot().size) assertEquals(1, charlieClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, 
aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for ReceiveFinalityFlow and 1 for GetNumberOfCheckpointsFlow - assertEquals(2, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -252,9 +252,9 @@ class StatemachineFinalityErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, aliceClient.stateMachinesSnapshot().size) assertEquals(0, charlieClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -349,9 +349,9 @@ class StatemachineFinalityErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, aliceClient.stateMachinesSnapshot().size) assertEquals(1, charlieClient.stateMachinesSnapshot().size) // 1 for CashIssueAndPaymentFlow and 1 for GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for ReceiveFinalityFlow and 1 for GetNumberOfCheckpointsFlow - assertEquals(2, 
charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } } \ No newline at end of file diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineGeneralErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineGeneralErrorHandlingTest.kt index 5aacac8a4a..a7a3e1cbb1 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineGeneralErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineGeneralErrorHandlingTest.kt @@ -94,7 +94,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(1, aliceClient.stateMachinesSnapshot().size) // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -172,7 +172,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -252,7 +252,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for 
GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -337,7 +337,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -426,7 +426,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(1, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -527,7 +527,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -616,7 +616,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, 
aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -714,7 +714,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(1, aliceClient.stateMachinesSnapshot().size) // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -812,7 +812,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -898,7 +898,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(1, aliceClient.stateMachinesSnapshot().size) // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -990,7 +990,7 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(1, aliceClient.stateMachinesSnapshot().size) // 1 for errored flow and 1 for 
GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -1079,9 +1079,9 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, aliceClient.stateMachinesSnapshot().size) assertEquals(0, charlieClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -1176,11 +1176,11 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, aliceClient.stateMachinesSnapshot().size) assertEquals(1, charlieClient.stateMachinesSnapshot().size) // 1 for the flow that is waiting for the errored counterparty flow to finish and 1 for GetNumberOfCheckpointsFlow - assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(2, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for GetNumberOfCheckpointsFlow // the checkpoint is not persisted since it kept failing the original checkpoint commit // the flow will recover since artemis will keep the events and replay them on node restart - assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + 
assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -1273,9 +1273,9 @@ class StatemachineGeneralErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, aliceClient.stateMachinesSnapshot().size) assertEquals(0, charlieClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, charlieClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } } \ No newline at end of file diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineKillFlowErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineKillFlowErrorHandlingTest.kt index 9319e8ae66..0d0c8f7177 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineKillFlowErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineKillFlowErrorHandlingTest.kt @@ -98,7 +98,7 @@ class StatemachineKillFlowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -185,7 +185,7 @@ class 
StatemachineKillFlowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -277,7 +277,7 @@ class StatemachineKillFlowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(1, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } diff --git a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineSubflowErrorHandlingTest.kt b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineSubflowErrorHandlingTest.kt index fd491eab97..161f3c4b39 100644 --- a/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineSubflowErrorHandlingTest.kt +++ b/node/src/integration-test-slow/kotlin/net/corda/node/services/statemachine/StatemachineSubflowErrorHandlingTest.kt @@ -128,7 +128,7 @@ class StatemachineSubflowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -230,7 +230,7 @@ class StatemachineSubflowErrorHandlingTest : 
StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -324,7 +324,7 @@ class StatemachineSubflowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } @@ -426,7 +426,7 @@ class StatemachineSubflowErrorHandlingTest : StatemachineErrorHandlingTest() { assertEquals(0, observation) assertEquals(0, aliceClient.stateMachinesSnapshot().size) // 1 for GetNumberOfCheckpointsFlow - assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, aliceClient.startFlow(StatemachineErrorHandlingTest::GetNumberOfUncompletedCheckpointsFlow).returnValue.get()) } } diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt index 04f8485059..939d755ad9 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt @@ -13,6 +13,7 @@ import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.node.services.Permissions +import net.corda.node.services.statemachine.Checkpoint import 
net.corda.node.services.statemachine.FlowTimeoutException import net.corda.node.services.statemachine.StaffedFlowHospital import net.corda.testing.core.ALICE_NAME @@ -141,8 +142,7 @@ class FlowRetryTest { .returnValue.getOrThrow(Duration.of(10, ChronoUnit.SECONDS)) } assertEquals(3, TransientConnectionFailureFlow.retryCount) - // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(2, it.proxy.startFlow(::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, it.proxy.startFlow(::GetCheckpointNumberOfStatusFlow, Checkpoint.FlowStatus.HOSPITALIZED).returnValue.get()) } } } @@ -160,8 +160,7 @@ class FlowRetryTest { .returnValue.getOrThrow(Duration.of(10, ChronoUnit.SECONDS)) } assertEquals(3, WrappedTransientConnectionFailureFlow.retryCount) - // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(2, it.proxy.startFlow(::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, it.proxy.startFlow(::GetCheckpointNumberOfStatusFlow, Checkpoint.FlowStatus.HOSPITALIZED).returnValue.get()) } } } @@ -179,8 +178,7 @@ class FlowRetryTest { it.proxy.startFlow(::GeneralExternalFailureFlow, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow() } assertEquals(0, GeneralExternalFailureFlow.retryCount) - // 1 for the errored flow kept for observation and another for GetNumberOfCheckpointsFlow - assertEquals(1, it.proxy.startFlow(::GetNumberOfCheckpointsFlow).returnValue.get()) + assertEquals(1, it.proxy.startFlow(::GetCheckpointNumberOfStatusFlow, Checkpoint.FlowStatus.FAILED).returnValue.get()) } } } @@ -457,9 +455,15 @@ class GeneralExternalFailureResponder(private val session: FlowSession) : FlowLo } @StartableByRPC -class GetNumberOfCheckpointsFlow : FlowLogic<Long>() { +class GetCheckpointNumberOfStatusFlow(private val flowStatus: Checkpoint.FlowStatus) : FlowLogic<Long>() { override fun call(): Long { - return 
serviceHub.jdbcSession().prepareStatement("select count(*) from node_checkpoints").use { ps -> + val sqlStatement = + "select count(*) " + + "from node_checkpoints " + + "where status = ${flowStatus.ordinal} " + + "and flow_id != '${runId.uuid}' " // don't count in the checkpoint of the current flow + + return serviceHub.jdbcSession().prepareStatement(sqlStatement).use { ps -> ps.executeQuery().use { rs -> rs.next() rs.getLong(1) diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt index 0dc0c0d995..eb219baf4e 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt @@ -25,6 +25,7 @@ import net.corda.core.utilities.seconds import net.corda.finance.DOLLARS import net.corda.finance.contracts.asset.Cash import net.corda.finance.flows.CashIssueFlow +import net.corda.node.services.statemachine.Checkpoint import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.CHARLIE_NAME @@ -59,8 +60,7 @@ class KillFlowTest { assertFailsWith<KilledFlowException> { handle.returnValue.getOrThrow(1.minutes) } - val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, checkpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -86,12 +86,11 @@ class KillFlowTest { AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.locks.forEach { it.value.acquire() } assertTrue(AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) 
- val aliceCheckpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, aliceCheckpoints) - val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, bobCheckpoints) - val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, charlieCheckpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -111,8 +110,7 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, checkpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -148,8 +146,7 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - val checkpoints = startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, checkpoints) + assertEquals(1, startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } @Test(timeout = 300_000) @@ -167,8 +164,7 @@ class KillFlowTest { 
} assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, checkpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -188,8 +184,7 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, checkpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -217,12 +212,11 @@ class KillFlowTest { } assertTrue(AFlowThatGetsMurderedAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) 
- val aliceCheckpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, aliceCheckpoints) - val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, bobCheckpoints) - val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, charlieCheckpoints) + assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } } @@ -251,12 +245,11 @@ class KillFlowTest { assertTrue(AFlowThatGetsMurderedByItsFriend.receivedKilledException) assertFalse(AFlowThatGetsMurderedByItsFriendResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedByItsFriendResponder.receivedKilledExceptions[CHARLIE_NAME]!!) 
- val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, aliceCheckpoints) - val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, bobCheckpoints) - val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) - assertEquals(1, charlieCheckpoints) + assertEquals(2, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, alice.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) } } @@ -589,4 +582,18 @@ class KillFlowTest { } } } + + @StartableByRPC + class GetNumberOfFailedCheckpointsFlow : FlowLogic<Long>() { + override fun call(): Long { + return serviceHub.jdbcSession() + .prepareStatement("select count(*) from node_checkpoints where status = ${Checkpoint.FlowStatus.FAILED.ordinal}") + .use { ps -> + ps.executeQuery().use { rs -> + rs.next() + rs.getLong(1) + } + } + } + } } \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/services/persistence/CordaPersistenceServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/persistence/CordaPersistenceServiceTests.kt index 7d6aa1edf8..307e67d12b 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/persistence/CordaPersistenceServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/persistence/CordaPersistenceServiceTests.kt @@ -3,11 +3,15 @@ package net.corda.node.services.persistence import co.paralleluniverse.fibers.Suspendable import 
net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC +import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.PLATFORM_VERSION import net.corda.core.messaging.startFlow import net.corda.core.node.AppServiceHub import net.corda.core.node.services.CordaService +import net.corda.core.node.services.vault.SessionScope import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.getOrThrow +import net.corda.node.services.statemachine.Checkpoint.FlowStatus import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.driver @@ -15,6 +19,8 @@ import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.node.internal.enclosedCordapp import org.junit.Test import java.sql.DriverManager +import java.time.Instant +import java.util.UUID import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -31,7 +37,9 @@ class CordaPersistenceServiceTests { assertEquals(sampleSize, count) DriverManager.getConnection("jdbc:h2:tcp://localhost:$port/node", "sa", "").use { - val resultSet = it.createStatement().executeQuery("SELECT count(*) from ${NODE_DATABASE_PREFIX}checkpoints") + val resultSet = it.createStatement().executeQuery( + "SELECT count(*) from ${NODE_DATABASE_PREFIX}checkpoints where status not in (${FlowStatus.COMPLETED.ordinal})" + ) assertTrue(resultSet.next()) val resultSize = resultSet.getInt(1) assertEquals(sampleSize, resultSize) @@ -50,16 +58,53 @@ class CordaPersistenceServiceTests { @CordaService class MultiThreadedDbLoader(private val services: AppServiceHub) : SingletonSerializeAsToken() { - fun createObjects(count: Int) : Int { + fun createObjects(count: Int): Int { (1..count).toList().parallelStream().forEach { + val now = Instant.now() services.database.transaction { - session.save(DBCheckpointStorage.DBCheckpoint().apply { - checkpointId = it.toString() - }) + val flowId 
= it.toString() + session.save( + DBCheckpointStorage.DBFlowCheckpoint( + flowId = flowId, + blob = DBCheckpointStorage.DBFlowCheckpointBlob( + flowId = flowId, + checkpoint = ByteArray(8192), + flowStack = ByteArray(8192), + hmac = ByteArray(16), + persistedInstant = now + ), + result = null, + exceptionDetails = null, + status = FlowStatus.RUNNABLE, + compatible = false, + progressStep = "", + ioRequestType = FlowIORequest.ForceCheckpoint::class.java.simpleName, + checkpointInstant = now, + flowMetadata = createMetadataRecord(flowId, now) + ) + ) } } return count } + + private fun SessionScope.createMetadataRecord(flowId: String, timestamp: Instant): DBCheckpointStorage.DBFlowMetadata { + val metadata = DBCheckpointStorage.DBFlowMetadata( + invocationId = UUID.randomUUID().toString(), + flowId = flowId, + flowName = "random.flow", + userSuppliedIdentifier = null, + startType = DBCheckpointStorage.StartReason.RPC, + launchingCordapp = "this cordapp", + platformVersion = PLATFORM_VERSION, + startedBy = "Batman", + invocationInstant = timestamp, + startInstant = timestamp, + finishInstant = null + ) + session.save(metadata) + return metadata + } } -} \ No newline at end of file +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowPausingTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowPausingTest.kt new file mode 100644 index 0000000000..c2961a8045 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowPausingTest.kt @@ -0,0 +1,113 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import 
net.corda.core.internal.messaging.InternalCordaRPCOps +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.NodeParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.User +import org.junit.Test +import java.time.Duration +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class FlowPausingTest { + + companion object { + val TOTAL_MESSAGES = 100 + val SLEEP_BETWEEN_MESSAGES_MS = 10L + } + + @Test(timeout = 300_000) + fun `Paused flows can recieve session messages`() { + val rpcUser = User("demo", "demo", setOf(Permissions.startFlow<HardRestartTest.Ping>(), Permissions.all())) + driver(DriverParameters(startNodesInProcess = true, inMemoryDB = false)) { + val alice = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = listOf(rpcUser))).getOrThrow() + val bob = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = listOf(rpcUser))) + val startedBob = bob.getOrThrow() + val aliceFlow = alice.rpc.startFlow(::HeartbeatFlow, startedBob.nodeInfo.legalIdentities[0]) + // We wait here for the initiated flow to start running on bob + val initiatedFlowId = startedBob.rpc.waitForFlowToStart(150) + assertNotNull(initiatedFlowId) + /* We shut down bob, we want this to happen before bob has finished receiving all of the heartbeats. + This is a Race but if bob finishes too quickly then we will fail to unpause the initiated flow running on BOB latter + and this test will fail.*/ + startedBob.stop() + //Start bob backup in Safe mode. This means no flows will run but BOB should receive messages and queue these up. 
+ val restartedBob = startNode(NodeParameters( + providedName = BOB_NAME, + rpcUsers = listOf(rpcUser), + customOverrides = mapOf("smmStartMode" to "Safe"))).getOrThrow() + + //Sleep for long enough so BOB has time to receive all the messages. + //All messages in this period should be queued up and replayed when the flow is unpaused. + Thread.sleep(TOTAL_MESSAGES * SLEEP_BETWEEN_MESSAGES_MS) + //ALICE should not have finished yet as the HeartbeatResponderFlow should not have sent the final message back (as it is paused). + assertEquals(false, aliceFlow.returnValue.isDone) + assertEquals(true, (restartedBob.rpc as InternalCordaRPCOps).unPauseFlow(initiatedFlowId!!)) + + assertEquals(true, aliceFlow.returnValue.getOrThrow()) + alice.stop() + restartedBob.stop() + } + } + + fun CordaRPCOps.waitForFlowToStart(maxTrys: Int): StateMachineRunId? { + for (i in 1..maxTrys) { + val snapshot = this.stateMachinesSnapshot().singleOrNull() + if (snapshot == null) { + Thread.sleep(SLEEP_BETWEEN_MESSAGES_MS) + } else { + return snapshot.id + } + } + return null + } + + @StartableByRPC + @InitiatingFlow + class HeartbeatFlow(private val otherParty: Party): FlowLogic<Boolean>() { + var sequenceNumber = 0 + @Suspendable + override fun call(): Boolean { + val session = initiateFlow(otherParty) + for (i in 1..TOTAL_MESSAGES) { + session.send(sequenceNumber++) + sleep(Duration.ofMillis(10)) + } + val success = session.receive<Boolean>().unwrap{data -> data} + return success + } + } + + @InitiatedBy(HeartbeatFlow::class) + class HeartbeatResponderFlow(val session: FlowSession): FlowLogic<Unit>() { + var sequenceNumber : Int = 0 + @Suspendable + override fun call() { + var pass = true + for (i in 1..TOTAL_MESSAGES) { + val receivedSequenceNumber = session.receive<Int>().unwrap{data -> data} + if (receivedSequenceNumber != sequenceNumber) { + pass = false + } + sequenceNumber++ + } + session.send(pass) + } + } +} \ No newline at end of file diff --git 
a/node/src/main/kotlin/net/corda/node/NodeCmdLineOptions.kt b/node/src/main/kotlin/net/corda/node/NodeCmdLineOptions.kt index d530ae4d41..f98d782ca3 100644 --- a/node/src/main/kotlin/net/corda/node/NodeCmdLineOptions.kt +++ b/node/src/main/kotlin/net/corda/node/NodeCmdLineOptions.kt @@ -14,6 +14,7 @@ import net.corda.node.services.config.ConfigHelper import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.Valid import net.corda.node.services.config.parseAsNodeConfiguration +import net.corda.node.services.statemachine.StateMachineManager import net.corda.nodeapi.internal.config.UnknownConfigKeysPolicy import picocli.CommandLine.Option import java.nio.file.Path @@ -48,6 +49,12 @@ open class SharedNodeCmdLineOptions { ) var devMode: Boolean? = null + @Option( + names = ["--pause-all-flows"], + description = ["Do not run any flows on startup. Sets all flows to paused, which can be unpaused via RPC."] + ) + var safeMode: Boolean = false + open fun parseConfiguration(configuration: Config): Valid<NodeConfiguration> { val option = Configuration.Options(strict = unknownConfigKeysPolicy == UnknownConfigKeysPolicy.FAIL) return configuration.parseAsNodeConfiguration(option) @@ -186,6 +193,9 @@ open class NodeCmdLineOptions : SharedNodeCmdLineOptions() { devMode?.let { configOverrides += "devMode" to it } + if (safeMode) { + configOverrides += "smmStartMode" to StateMachineManager.StartMode.Safe.toString() + } return try { valid(ConfigHelper.loadConfig(baseDirectory, configFile, configOverrides = ConfigFactory.parseMap(configOverrides))) } catch (e: ConfigException) { diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 9c015067e3..4c86351957 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -112,6 +112,7 @@ import 
net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.persistence.AbstractPartyDescriptor import net.corda.node.services.persistence.AbstractPartyToX500NameAsStringConverter import net.corda.node.services.persistence.AttachmentStorageInternal +import net.corda.node.services.persistence.DBCheckpointPerformanceRecorder import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBTransactionMappingStorage import net.corda.node.services.persistence.DBTransactionStorage @@ -258,7 +259,6 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, } val networkMapCache = PersistentNetworkMapCache(cacheFactory, database, identityService).tokenize() - val checkpointStorage = DBCheckpointStorage() @Suppress("LeakingThis") val transactionStorage = makeTransactionStorage(configuration.transactionCacheSizeBytes).tokenize() val networkMapClient: NetworkMapClient? = configuration.networkServices?.let { NetworkMapClient(it.networkMapURL, versionInfo) } @@ -328,6 +328,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, }) } val services = ServiceHubInternalImpl().tokenize() + val checkpointStorage = DBCheckpointStorage(DBCheckpointPerformanceRecorder(services.monitoringService.metrics), platformClock) @Suppress("LeakingThis") val smm = makeStateMachineManager() val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory) @@ -540,7 +541,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, tokenizableServices = null verifyCheckpointsCompatible(frozenTokenizableServices) - val smmStartedFuture = smm.start(frozenTokenizableServices) + val smmStartedFuture = smm.start(frozenTokenizableServices, configuration.smmStartMode) // Shut down the SMM so no Fibers are scheduled. 
runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) } val flowMonitor = FlowMonitor( @@ -1348,7 +1349,7 @@ fun CordaPersistence.startHikariPool(hikariProperties: Properties, databaseConfi try { val dataSource = DataSourceFactory.createDataSource(hikariProperties, metricRegistry = metricRegistry) val schemaMigration = SchemaMigration(schemas, dataSource, databaseConfig, cordappLoader, currentDir, ourName) - schemaMigration.nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }) + schemaMigration.nodeStartup(dataSource.connection.use { DBCheckpointStorage.getCheckpointCount(it) != 0L }) start(dataSource) } catch (ex: Exception) { when { @@ -1376,4 +1377,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl } // Here we're using the node's RPC key store as the RPC client's trust store. return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword) -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt index acde0a1a9b..5042e2e9ff 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt @@ -5,7 +5,6 @@ import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.node.ServiceHub import net.corda.core.serialization.internal.CheckpointSerializationDefaults -import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.statemachine.SubFlow import net.corda.node.services.statemachine.SubFlowVersion @@ -36,10 +35,10 @@ object CheckpointVerifier { val cordappsByHash = currentCordapps.associateBy { it.jarHash } - checkpointStorage.getAllCheckpoints().use { + 
checkpointStorage.getCheckpointsToRun().use { it.forEach { (_, serializedCheckpoint) -> val checkpoint = try { - serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext) + serializedCheckpoint.deserialize(checkpointSerializationContext) } catch (e: ClassNotFoundException) { val message = e.message if (message != null) { @@ -52,7 +51,7 @@ object CheckpointVerifier { } // For each Subflow, compare the checkpointed version to the current version. - checkpoint.subFlowStack.forEach { checkFlowCompatible(it, cordappsByHash, platformVersion) } + checkpoint.checkpointState.subFlowStack.forEach { checkFlowCompatible(it, cordappsByHash, platformVersion) } } } } diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index 571c97b82c..3ee3126c29 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -169,6 +169,8 @@ internal class CordaRPCOpsImpl( override fun killFlow(id: StateMachineRunId): Boolean = smm.killFlow(id) + override fun unPauseFlow(id: StateMachineRunId): Boolean = smm.unPauseFlow(id) + override fun stateMachinesFeed(): DataFeed<List<StateMachineInfo>, StateMachineUpdate> { val (allStateMachines, changes) = smm.track() diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index b463372909..0bac15c171 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -3,7 +3,8 @@ package net.corda.node.services.api import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes import net.corda.node.services.statemachine.Checkpoint -import java.sql.Connection +import net.corda.node.services.statemachine.CheckpointState +import 
net.corda.node.services.statemachine.FlowState import java.util.stream.Stream /** @@ -13,12 +14,20 @@ interface CheckpointStorage { /** * Add a checkpoint for a new id to the store. Will throw if there is already a checkpoint for this id */ - fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) + fun addCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>, + serializedCheckpointState: SerializedBytes<CheckpointState>) /** * Update an existing checkpoint. Will throw if there is not checkpoint for this id. */ - fun updateCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) + fun updateCheckpoint(id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?, + serializedCheckpointState: SerializedBytes<CheckpointState>) + + /** + * Update all persisted checkpoints with status [Checkpoint.FlowStatus.RUNNABLE] or [Checkpoint.FlowStatus.HOSPITALIZED], + * changing the status to [Checkpoint.FlowStatus.PAUSED]. + */ + fun markAllPaused() /** * Remove existing checkpoint from the store. @@ -28,21 +37,32 @@ interface CheckpointStorage { /** * Load an existing checkpoint from the store. - * @return the checkpoint, still in serialized form, or null if not found. - */ - fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>? - - /** - * Stream all checkpoints from the store. If this is backed by a database the stream will be valid until the - * underlying database connection is closed, so any processing should happen before it is closed. - */ - fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> - - /** - * This needs to run before Hibernate is initialised. * - * @param connection The SQL Connection. - * @return the number of checkpoints stored in the database. + * The checkpoint returned from this function will be a _clean_ checkpoint. 
No error information is loaded into the checkpoint + * even if the previous status of the checkpoint was [Checkpoint.FlowStatus.FAILED] or [Checkpoint.FlowStatus.HOSPITALIZED]. + * + * @return The checkpoint, in a partially serialized form, or null if not found. */ - fun getCheckpointCount(connection: Connection): Long -} \ No newline at end of file + fun getCheckpoint(id: StateMachineRunId): Checkpoint.Serialized? + + /** + * Stream all checkpoints with statuses [statuses] from the store. If this is backed by a database the stream will be valid + * until the underlying database connection is closed, so any processing should happen before it is closed. + */ + fun getCheckpoints( + statuses: Collection<Checkpoint.FlowStatus> = Checkpoint.FlowStatus.values().toSet() + ): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> + + /** + * Stream runnable checkpoints from the store. If this is backed by a database the stream will be valid + * until the underlying database connection is closed, so any processing should happen before it is closed. + */ + fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> + + /** + * Stream paused checkpoints from the store. If this is backed by a database the stream will be valid + * until the underlying database connection is closed, so any processing should happen before it is closed. + * This method does not fetch [Checkpoint.Serialized.serializedFlowState] to save memory. 
+ */ + fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> +} diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index f2dc3f16cb..39d2c04a93 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -11,6 +11,7 @@ import net.corda.core.internal.notary.NotaryServiceFlow import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.services.config.rpc.NodeRpcOptions import net.corda.node.services.config.schema.v1.V1NodeConfigurationSpec +import net.corda.node.services.statemachine.StateMachineManager import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.User @@ -93,6 +94,8 @@ interface NodeConfiguration : ConfigurationWithOptionsContainer { val quasarExcludePackages: List<String> + val smmStartMode: StateMachineManager.StartMode + companion object { // default to at least 8MB and a bit extra for larger heap sizes val defaultTransactionCacheSize: Long = 8.MB + getAdditionalCacheMemory() diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt index e1dcc86903..6a9128b503 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfigurationImpl.kt @@ -8,6 +8,7 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import net.corda.core.utilities.seconds import net.corda.node.services.config.rpc.NodeRpcOptions +import net.corda.node.services.statemachine.StateMachineManager import net.corda.nodeapi.BrokerRpcSslOptions import 
net.corda.nodeapi.internal.DEV_PUB_KEY_HASHES import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier @@ -84,7 +85,8 @@ data class NodeConfigurationImpl( override val blacklistedAttachmentSigningKeys: List<String> = Defaults.blacklistedAttachmentSigningKeys, override val configurationWithOptions: ConfigurationWithOptions, override val flowExternalOperationThreadPoolSize: Int = Defaults.flowExternalOperationThreadPoolSize, - override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages + override val quasarExcludePackages: List<String> = Defaults.quasarExcludePackages, + override val smmStartMode : StateMachineManager.StartMode = Defaults.smmStartMode ) : NodeConfiguration { internal object Defaults { val jmxMonitoringHttpPort: Int? = null @@ -123,6 +125,7 @@ data class NodeConfigurationImpl( val blacklistedAttachmentSigningKeys: List<String> = emptyList() const val flowExternalOperationThreadPoolSize: Int = 1 val quasarExcludePackages: List<String> = emptyList() + val smmStartMode : StateMachineManager.StartMode = StateMachineManager.StartMode.ExcludingPaused fun cordappsDirectories(baseDirectory: Path) = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT) diff --git a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt index b4c5477e14..a5bde3e836 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/schema/v1/V1NodeConfigurationSpec.kt @@ -9,6 +9,7 @@ import net.corda.common.validation.internal.Validated.Companion.valid import net.corda.node.services.config.* import net.corda.node.services.config.NodeConfigurationImpl.Defaults import net.corda.node.services.config.schema.parsers.* +import net.corda.node.services.statemachine.StateMachineManager internal object V1NodeConfigurationSpec : 
Configuration.Specification<NodeConfiguration>("NodeConfiguration") { private val myLegalName by string().mapValid(::toCordaX500Name) @@ -66,6 +67,7 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig .withDefaultValue(Defaults.networkParameterAcceptanceSettings) private val flowExternalOperationThreadPoolSize by int().optional().withDefaultValue(Defaults.flowExternalOperationThreadPoolSize) private val quasarExcludePackages by string().list().optional().withDefaultValue(Defaults.quasarExcludePackages) + private val smmStartMode by enum(StateMachineManager.StartMode::class).optional().withDefaultValue(Defaults.smmStartMode) @Suppress("unused") private val custom by nestedObject().optional() @Suppress("unused") @@ -133,7 +135,8 @@ internal object V1NodeConfigurationSpec : Configuration.Specification<NodeConfig networkParameterAcceptanceSettings = config[networkParameterAcceptanceSettings], configurationWithOptions = ConfigurationWithOptions(configuration, Configuration.Options.defaults), flowExternalOperationThreadPoolSize = config[flowExternalOperationThreadPoolSize], - quasarExcludePackages = config[quasarExcludePackages] + quasarExcludePackages = config[quasarExcludePackages], + smmStartMode = config[smmStartMode] )) } catch (e: Exception) { return when (e) { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointPerformanceRecorder.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointPerformanceRecorder.kt new file mode 100644 index 0000000000..d1d713f96a --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointPerformanceRecorder.kt @@ -0,0 +1,67 @@ +package net.corda.node.services.persistence + +import com.codahale.metrics.Gauge +import com.codahale.metrics.Histogram +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.Reservoir +import com.codahale.metrics.SlidingTimeWindowArrayReservoir +import 
com.codahale.metrics.SlidingTimeWindowReservoir +import net.corda.core.serialization.SerializedBytes +import net.corda.node.services.statemachine.CheckpointState +import net.corda.node.services.statemachine.FlowState +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong + +interface CheckpointPerformanceRecorder { + + /** + * Record performance metrics regarding the serialized size of [CheckpointState] and [FlowState] + */ + fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>?) +} + +class DBCheckpointPerformanceRecorder(metrics: MetricRegistry) : CheckpointPerformanceRecorder { + + private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") + private val checkpointSizesThisSecond = SlidingTimeWindowReservoir(1, TimeUnit.SECONDS) + private val lastBandwidthUpdate = AtomicLong(0) + private val checkpointBandwidthHist = metrics.register( + "Flows.CheckpointVolumeBytesPerSecondHist", Histogram( + SlidingTimeWindowArrayReservoir(1, TimeUnit.DAYS) + ) + ) + private val checkpointBandwidth = metrics.register( + "Flows.CheckpointVolumeBytesPerSecondCurrent", + LatchedGauge(checkpointSizesThisSecond) + ) + + /** + * This [Gauge] just reports the sum of the bytes checkpointed during the last second. + */ + private class LatchedGauge(private val reservoir: Reservoir) : Gauge<Long> { + override fun getValue(): Long { + return reservoir.snapshot.values.sum() + } + } + + override fun record(serializedCheckpointState: SerializedBytes<CheckpointState>, serializedFlowState: SerializedBytes<FlowState>?) { + /* For now we don't record states where the serializedFlowState is null and thus the checkpoint is in a completed state. + As this will skew the mean with lots of small checkpoints. For the moment we only measure runnable checkpoints. 
*/ + serializedFlowState?.let { + updateData(serializedCheckpointState.size.toLong() + it.size.toLong()) + } + } + + private fun updateData(totalSize: Long) { + checkpointingMeter.mark() + checkpointSizesThisSecond.update(totalSize) + var lastUpdateTime = lastBandwidthUpdate.get() + while (System.nanoTime() - lastUpdateTime > TimeUnit.SECONDS.toNanos(1)) { + if (lastBandwidthUpdate.compareAndSet(lastUpdateTime, System.nanoTime())) { + val checkpointVolume = checkpointSizesThisSecond.snapshot.values.sum() + checkpointBandwidthHist.update(checkpointVolume) + } + lastUpdateTime = lastBandwidthUpdate.get() + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index b1eec763f6..2dc5a0d3e9 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,97 +1,625 @@ package net.corda.node.services.persistence +import net.corda.core.context.InvocationContext +import net.corda.core.context.InvocationOrigin import net.corda.core.flows.StateMachineRunId +import net.corda.core.internal.PLATFORM_VERSION +import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.uncheckedCast +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.core.utilities.debug +import net.corda.core.serialization.serialize +import net.corda.core.utilities.contextLogger import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.statemachine.Checkpoint +import net.corda.node.services.statemachine.Checkpoint.FlowStatus +import net.corda.node.services.statemachine.CheckpointState +import net.corda.node.services.statemachine.ErrorState +import net.corda.node.services.statemachine.FlowState +import 
net.corda.node.services.statemachine.SubFlowVersion import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.currentDBSession import org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY -import org.slf4j.Logger -import org.slf4j.LoggerFactory +import org.apache.commons.lang3.exception.ExceptionUtils +import org.hibernate.annotations.Type +import java.sql.Connection +import java.sql.SQLException +import java.time.Clock +import java.time.Instant import java.util.* import java.util.stream.Stream import javax.persistence.Column import javax.persistence.Entity +import javax.persistence.FetchType import javax.persistence.Id -import org.hibernate.annotations.Type -import java.sql.Connection -import java.sql.SQLException +import javax.persistence.OneToOne +import javax.persistence.PrimaryKeyJoinColumn /** * Simple checkpoint key value storage in DB. */ -class DBCheckpointStorage : CheckpointStorage { - val log: Logger = LoggerFactory.getLogger(this::class.java) +@Suppress("TooManyFunctions") +class DBCheckpointStorage( + private val checkpointPerformanceRecorder: CheckpointPerformanceRecorder, + private val clock: Clock +) : CheckpointStorage { + + companion object { + val log = contextLogger() + + private const val HMAC_SIZE_BYTES = 16 + + @VisibleForTesting + const val MAX_STACKTRACE_LENGTH = 4000 + private const val MAX_EXC_MSG_LENGTH = 4000 + private const val MAX_EXC_TYPE_LENGTH = 256 + private const val MAX_FLOW_NAME_LENGTH = 128 + private const val MAX_PROGRESS_STEP_LENGTH = 256 + + private val RUNNABLE_CHECKPOINTS = setOf(FlowStatus.RUNNABLE, FlowStatus.HOSPITALIZED) + + /** + * This needs to run before Hibernate is initialised. + * + * No need to set up [DBCheckpointStorage] fully for this function + * + * @param connection The SQL Connection. + * @return the number of checkpoints stored in the database. 
+ */ + fun getCheckpointCount(connection: Connection): Long { + // No need to set up [DBCheckpointStorage] fully for this function + return try { + connection.prepareStatement("select count(*) from node_checkpoints").use { ps -> + ps.executeQuery().use { rs -> + rs.next() + rs.getLong(1) + } + } + } catch (e: SQLException) { + // Happens when the table was not created yet. + 0L + } + } + } + + enum class StartReason { + RPC, SERVICE, SCHEDULED, INITIATED + } @Entity @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints") - class DBCheckpoint( - @Id - @Suppress("MagicNumber") // database column width - @Column(name = "checkpoint_id", length = 64, nullable = false) - var checkpointId: String = "", + class DBFlowCheckpoint( + @Id + @Column(name = "flow_id", length = 64, nullable = false) + var flowId: String, - @Type(type = "corda-blob") - @Column(name = "checkpoint_value", nullable = false) - var checkpoint: ByteArray = EMPTY_BYTE_ARRAY + @OneToOne(fetch = FetchType.LAZY, optional = true) + @PrimaryKeyJoinColumn + var blob: DBFlowCheckpointBlob?, + + @OneToOne(fetch = FetchType.LAZY, optional = true) + @PrimaryKeyJoinColumn + var result: DBFlowResult?, + + @OneToOne(fetch = FetchType.LAZY, optional = true) + @PrimaryKeyJoinColumn + var exceptionDetails: DBFlowException?, + + @OneToOne(fetch = FetchType.LAZY) + @PrimaryKeyJoinColumn + var flowMetadata: DBFlowMetadata, + + @Column(name = "status", nullable = false) + var status: FlowStatus, + + @Column(name = "compatible", nullable = false) + var compatible: Boolean, + + @Column(name = "progress_step") + var progressStep: String?, + + @Column(name = "flow_io_request") + var ioRequestType: String?, + + @Column(name = "timestamp", nullable = false) + var checkpointInstant: Instant + ) + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoint_blobs") + class DBFlowCheckpointBlob( + @Id + @Column(name = "flow_id", length = 64, nullable = false) + var flowId: String, + + @Type(type = 
"corda-blob") + @Column(name = "checkpoint_value", nullable = false) + var checkpoint: ByteArray = EMPTY_BYTE_ARRAY, + + @Type(type = "corda-blob") + @Column(name = "flow_state") + var flowStack: ByteArray?, + + @Type(type = "corda-wrapper-binary") + @Column(name = "hmac") + var hmac: ByteArray, + + @Column(name = "timestamp") + var persistedInstant: Instant + ) + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}flow_results") + class DBFlowResult( + @Id + @Column(name = "flow_id", length = 64, nullable = false) + var flow_id: String, + + @Type(type = "corda-blob") + @Column(name = "result_value", nullable = false) + var value: ByteArray = EMPTY_BYTE_ARRAY, + + @Column(name = "timestamp") + val persistedInstant: Instant + ) + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}flow_exceptions") + class DBFlowException( + @Id + @Column(name = "flow_id", length = 64, nullable = false) + var flow_id: String, + + @Column(name = "type", nullable = false) + var type: String, + + @Column(name = "exception_message") + var message: String? = null, + + @Column(name = "stack_trace", nullable = false) + var stackTrace: String, + + @Type(type = "corda-blob") + @Column(name = "exception_value") + var value: ByteArray? 
= null, + + @Column(name = "timestamp") + val persistedInstant: Instant + ) + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}flow_metadata") + class DBFlowMetadata( + @Id + @Column(name = "flow_id", length = 64, nullable = false) + var flowId: String, + + @Column(name = "invocation_id", nullable = false) + var invocationId: String, + + @Column(name = "flow_name", nullable = false) + var flowName: String, + + @Column(name = "flow_identifier", nullable = true) + var userSuppliedIdentifier: String?, + + @Column(name = "started_type", nullable = false) + var startType: StartReason, + + @Column(name = "flow_parameters", nullable = false) + var initialParameters: ByteArray = EMPTY_BYTE_ARRAY, + + @Column(name = "cordapp_name", nullable = false) + var launchingCordapp: String, + + @Column(name = "platform_version", nullable = false) + var platformVersion: Int, + + @Column(name = "started_by", nullable = false) + var startedBy: String, + + @Column(name = "invocation_time", nullable = false) + var invocationInstant: Instant, + + @Column(name = "start_time", nullable = true) + var startInstant: Instant, + + @Column(name = "finish_time", nullable = true) + var finishInstant: Instant? 
) { - override fun toString() = "DBCheckpoint(checkpointId = ${checkpointId}, checkpointSize = ${checkpoint.size})" - } + @Suppress("ComplexMethod") + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (javaClass != other?.javaClass) return false - override fun addCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) { - currentDBSession().save(DBCheckpoint().apply { - checkpointId = id.uuid.toString() - this.checkpoint = checkpoint.bytes - log.debug { "Checkpoint $checkpointId, size=${this.checkpoint.size}" } - }) + other as DBFlowMetadata + + if (flowId != other.flowId) return false + if (invocationId != other.invocationId) return false + if (flowName != other.flowName) return false + if (userSuppliedIdentifier != other.userSuppliedIdentifier) return false + if (startType != other.startType) return false + if (!initialParameters.contentEquals(other.initialParameters)) return false + if (launchingCordapp != other.launchingCordapp) return false + if (platformVersion != other.platformVersion) return false + if (startedBy != other.startedBy) return false + if (invocationInstant != other.invocationInstant) return false + if (startInstant != other.startInstant) return false + if (finishInstant != other.finishInstant) return false + + return true + } + + override fun hashCode(): Int { + var result = flowId.hashCode() + result = 31 * result + invocationId.hashCode() + result = 31 * result + flowName.hashCode() + result = 31 * result + (userSuppliedIdentifier?.hashCode() ?: 0) + result = 31 * result + startType.hashCode() + result = 31 * result + initialParameters.contentHashCode() + result = 31 * result + launchingCordapp.hashCode() + result = 31 * result + platformVersion + result = 31 * result + startedBy.hashCode() + result = 31 * result + invocationInstant.hashCode() + result = 31 * result + startInstant.hashCode() + result = 31 * result + (finishInstant?.hashCode() ?: 0) + return result + } } - override fun 
updateCheckpoint(id: StateMachineRunId, checkpoint: SerializedBytes<Checkpoint>) { - currentDBSession().update(DBCheckpoint().apply { - checkpointId = id.uuid.toString() - this.checkpoint = checkpoint.bytes - log.debug { "Checkpoint $checkpointId, size=${this.checkpoint.size}" } - }) + override fun addCheckpoint( + id: StateMachineRunId, + checkpoint: Checkpoint, + serializedFlowState: SerializedBytes<FlowState>, + serializedCheckpointState: SerializedBytes<CheckpointState> + ) { + val now = clock.instant() + val flowId = id.uuid.toString() + + checkpointPerformanceRecorder.record(serializedCheckpointState, serializedFlowState) + + val blob = createDBCheckpointBlob( + flowId, + serializedCheckpointState, + serializedFlowState, + now + ) + + val metadata = createDBFlowMetadata(flowId, checkpoint) + + // Most fields are null as they cannot have been set when creating the initial checkpoint + val dbFlowCheckpoint = DBFlowCheckpoint( + flowId = flowId, + blob = blob, + result = null, + exceptionDetails = null, + flowMetadata = metadata, + status = checkpoint.status, + compatible = checkpoint.compatible, + progressStep = null, + ioRequestType = null, + checkpointInstant = now + ) + + currentDBSession().save(dbFlowCheckpoint) + currentDBSession().save(blob) + currentDBSession().save(metadata) } + override fun updateCheckpoint( + id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes<FlowState>?, + serializedCheckpointState: SerializedBytes<CheckpointState> + ) { + val now = clock.instant() + val flowId = id.uuid.toString() + // Do not update in DB [Checkpoint.checkpointState] or [Checkpoint.flowState] if flow failed or got hospitalized + val blob = if (checkpoint.status == FlowStatus.FAILED || checkpoint.status == FlowStatus.HOSPITALIZED) { + null + } else { + checkpointPerformanceRecorder.record(serializedCheckpointState, serializedFlowState) + createDBCheckpointBlob( + flowId, + serializedCheckpointState, + serializedFlowState, + now + ) + 
} + + //This code needs to be added back in when we want to persist the result. For now this requires the result to be @CordaSerializable. + //val result = updateDBFlowResult(entity, checkpoint, now) + val exceptionDetails = updateDBFlowException(flowId, checkpoint, now) + + val metadata = createDBFlowMetadata(flowId, checkpoint) + + val dbFlowCheckpoint = DBFlowCheckpoint( + flowId = flowId, + blob = blob, + result = null, + exceptionDetails = exceptionDetails, + flowMetadata = metadata, + status = checkpoint.status, + compatible = checkpoint.compatible, + progressStep = checkpoint.progressStep?.take(MAX_PROGRESS_STEP_LENGTH), + ioRequestType = checkpoint.flowIoRequest, + checkpointInstant = now + ) + + currentDBSession().update(dbFlowCheckpoint) + blob?.let { currentDBSession().update(it) } + if (checkpoint.isFinished()) { + metadata.finishInstant = now + currentDBSession().update(metadata) + } + } + + override fun markAllPaused() { + val session = currentDBSession() + val runnableOrdinals = RUNNABLE_CHECKPOINTS.map { "${it.ordinal}" }.joinToString { it } + val sqlQuery = "Update ${NODE_DATABASE_PREFIX}checkpoints set status = ${FlowStatus.PAUSED.ordinal} " + + "where status in ($runnableOrdinals)" + val query = session.createNativeQuery(sqlQuery) + query.executeUpdate() + } + + // DBFlowResult and DBFlowException to be integrated with rest of schema + @Suppress("MagicNumber") override fun removeCheckpoint(id: StateMachineRunId): Boolean { + var deletedRows = 0 + val flowId = id.uuid.toString() + deletedRows += deleteRow(DBFlowMetadata::class.java, DBFlowMetadata::flowId.name, flowId) + deletedRows += deleteRow(DBFlowCheckpointBlob::class.java, DBFlowCheckpointBlob::flowId.name, flowId) + deletedRows += deleteRow(DBFlowCheckpoint::class.java, DBFlowCheckpoint::flowId.name, flowId) +// resultId?.let { deletedRows += deleteRow(DBFlowResult::class.java, DBFlowResult::flow_id.name, it.toString()) } +// exceptionId?.let { deletedRows += 
deleteRow(DBFlowException::class.java, DBFlowException::flow_id.name, it.toString()) } + return deletedRows == 3 + } + + private fun <T> deleteRow(clazz: Class<T>, pk: String, value: String): Int { val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder - val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) - val root = delete.from(DBCheckpoint::class.java) - delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), id.uuid.toString())) - return session.createQuery(delete).executeUpdate() > 0 + val delete = criteriaBuilder.createCriteriaDelete(clazz) + val root = delete.from(clazz) + delete.where(criteriaBuilder.equal(root.get<String>(pk), value)) + return session.createQuery(delete).executeUpdate() } - override fun getCheckpoint(id: StateMachineRunId): SerializedBytes<Checkpoint>? { - val bytes = currentDBSession().get(DBCheckpoint::class.java, id.uuid.toString())?.checkpoint ?: return null - return SerializedBytes(bytes) + override fun getCheckpoint(id: StateMachineRunId): Checkpoint.Serialized? 
{ + return getDBCheckpoint(id)?.toSerializedCheckpoint() } - override fun getAllCheckpoints(): Stream<Pair<StateMachineRunId, SerializedBytes<Checkpoint>>> { + override fun getCheckpoints(statuses: Collection<FlowStatus>): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> { val session = currentDBSession() - val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) - val root = criteriaQuery.from(DBCheckpoint::class.java) + val criteriaBuilder = session.criteriaBuilder + val criteriaQuery = criteriaBuilder.createQuery(DBFlowCheckpoint::class.java) + val root = criteriaQuery.from(DBFlowCheckpoint::class.java) criteriaQuery.select(root) + .where(criteriaBuilder.isTrue(root.get<FlowStatus>(DBFlowCheckpoint::status.name).`in`(statuses))) return session.createQuery(criteriaQuery).stream().map { - StateMachineRunId(UUID.fromString(it.checkpointId)) to SerializedBytes<Checkpoint>(it.checkpoint) + StateMachineRunId(UUID.fromString(it.flowId)) to it.toSerializedCheckpoint() } } - override fun getCheckpointCount(connection: Connection): Long { - return try { - connection.prepareStatement("select count(*) from node_checkpoints").use { ps -> - ps.executeQuery().use { rs -> - rs.next() - rs.getLong(1) - } - } - } catch (e: SQLException) { - // Happens when the table was not created yet. - 0L + override fun getCheckpointsToRun(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> { + return getCheckpoints(RUNNABLE_CHECKPOINTS) + } + + @VisibleForTesting + internal fun getDBCheckpoint(id: StateMachineRunId): DBFlowCheckpoint? 
{ + return currentDBSession().find(DBFlowCheckpoint::class.java, id.uuid.toString()) + } + + override fun getPausedCheckpoints(): Stream<Pair<StateMachineRunId, Checkpoint.Serialized>> { + val session = currentDBSession() + val jpqlQuery = """select new ${DBPausedFields::class.java.name}(checkpoint.id, blob.checkpoint, checkpoint.status, + checkpoint.progressStep, checkpoint.ioRequestType, checkpoint.compatible) from ${DBFlowCheckpoint::class.java.name} + checkpoint join ${DBFlowCheckpointBlob::class.java.name} blob on checkpoint.blob = blob.id where + checkpoint.status = ${FlowStatus.PAUSED.ordinal}""".trimIndent() + val query = session.createQuery(jpqlQuery, DBPausedFields::class.java) + return query.resultList.stream().map { + StateMachineRunId(UUID.fromString(it.id)) to it.toSerializedCheckpoint() } } + + private fun createDBFlowMetadata(flowId: String, checkpoint: Checkpoint): DBFlowMetadata { + val context = checkpoint.checkpointState.invocationContext + val flowInfo = checkpoint.checkpointState.subFlowStack.first() + return DBFlowMetadata( + flowId = flowId, + invocationId = context.trace.invocationId.value, + // Truncate the flow name to fit into the database column + // Flow names are unlikely to be this long + flowName = flowInfo.flowClass.name.take(MAX_FLOW_NAME_LENGTH), + // will come from the context + userSuppliedIdentifier = null, + startType = context.getStartedType(), + initialParameters = context.getFlowParameters().storageSerialize().bytes, + launchingCordapp = (flowInfo.subFlowVersion as? 
SubFlowVersion.CorDappFlow)?.corDappName ?: "Core flow", + platformVersion = PLATFORM_VERSION, + startedBy = context.principal().name, + invocationInstant = context.trace.invocationId.timestamp, + startInstant = clock.instant(), + finishInstant = null + ) + } + + private fun createDBCheckpointBlob( + flowId: String, + serializedCheckpointState: SerializedBytes<CheckpointState>, + serializedFlowState: SerializedBytes<FlowState>?, + now: Instant + ): DBFlowCheckpointBlob { + return DBFlowCheckpointBlob( + flowId = flowId, + checkpoint = serializedCheckpointState.bytes, + flowStack = serializedFlowState?.bytes, + hmac = ByteArray(HMAC_SIZE_BYTES), + persistedInstant = now + ) + } + + /** + * Creates, updates or deletes the result related to the current flow/checkpoint. + * + * This is needed because updates are not cascading via Hibernate, therefore operations must be handled manually. + * + * A [DBFlowResult] is created if [DBFlowCheckpoint.result] does not exist and the [Checkpoint] has a result.. + * The existing [DBFlowResult] is updated if [DBFlowCheckpoint.result] exists and the [Checkpoint] has a result. + * The existing [DBFlowResult] is deleted if [DBFlowCheckpoint.result] exists and the [Checkpoint] has no result. + * Nothing happens if both [DBFlowCheckpoint] and [Checkpoint] do not have a result. + */ + private fun updateDBFlowResult(flowId: String, entity: DBFlowCheckpoint, checkpoint: Checkpoint, now: Instant): DBFlowResult? 
{ + val result = checkpoint.result?.let { createDBFlowResult(flowId, it, now) } + if (entity.result != null) { + if (result != null) { + result.flow_id = entity.result!!.flow_id + currentDBSession().update(result) + } else { + currentDBSession().delete(entity.result) + } + } else if (result != null) { + currentDBSession().save(result) + } + return result + } + + private fun createDBFlowResult(flowId: String, result: Any, now: Instant): DBFlowResult { + return DBFlowResult( + flow_id = flowId, + value = result.storageSerialize().bytes, + persistedInstant = now + ) + } + + /** + * Creates, updates or deletes the error related to the current flow/checkpoint. + * + * This is needed because updates are not cascading via Hibernate, therefore operations must be handled manually. + * + * A [DBFlowException] is created if [DBFlowCheckpoint.exceptionDetails] does not exist and the [Checkpoint] has an error attached to it. + * The existing [DBFlowException] is updated if [DBFlowCheckpoint.exceptionDetails] exists and the [Checkpoint] has an error. + * The existing [DBFlowException] is deleted if [DBFlowCheckpoint.exceptionDetails] exists and the [Checkpoint] has no error. + * Nothing happens if both [DBFlowCheckpoint] and [Checkpoint] are related to no errors. + */ + // DBFlowException to be integrated with rest of schema + // Add a flag notifying if an exception is already saved in the database for below logic (are we going to do this after all?) + private fun updateDBFlowException(flowId: String, checkpoint: Checkpoint, now: Instant): DBFlowException? { + val exceptionDetails = (checkpoint.errorState as? ErrorState.Errored)?.let { createDBFlowException(flowId, it, now) } +// if (checkpoint.dbExoSkeleton.dbFlowExceptionId != null) { +// if (exceptionDetails != null) { +// exceptionDetails.flow_id = checkpoint.dbExoSkeleton.dbFlowExceptionId!! 
+// currentDBSession().update(exceptionDetails) +// } else { +// val session = currentDBSession() +// val entity = session.get(DBFlowException::class.java, checkpoint.dbExoSkeleton.dbFlowExceptionId) +// session.delete(entity) +// return null +// } +// } else if (exceptionDetails != null) { +// currentDBSession().save(exceptionDetails) +// checkpoint.dbExoSkeleton.dbFlowExceptionId = exceptionDetails.flow_id +// } + return exceptionDetails + } + + private fun createDBFlowException(flowId: String, errorState: ErrorState.Errored, now: Instant): DBFlowException { + return errorState.errors.last().exception.let { + DBFlowException( + flow_id = flowId, + type = it::class.java.name.truncate(MAX_EXC_TYPE_LENGTH, true), + message = it.message?.truncate(MAX_EXC_MSG_LENGTH, false), + stackTrace = it.stackTraceToString(), + value = null, // TODO to be populated upon implementing https://r3-cev.atlassian.net/browse/CORDA-3681 + persistedInstant = now + ) + } + } + + private fun InvocationContext.getStartedType(): StartReason { + return when (origin) { + is InvocationOrigin.RPC, is InvocationOrigin.Shell -> StartReason.RPC + is InvocationOrigin.Peer -> StartReason.INITIATED + is InvocationOrigin.Service -> StartReason.SERVICE + is InvocationOrigin.Scheduled -> StartReason.SCHEDULED + } + } + + private fun InvocationContext.getFlowParameters(): List<Any?> { + // Only RPC flows have parameters which are found in index 1 + return if (arguments.isNotEmpty()) { + uncheckedCast<Any?, Array<Any?>>(arguments[1]).toList() + } else { + emptyList() + } + } + + private fun DBFlowCheckpoint.toSerializedCheckpoint(): Checkpoint.Serialized { + val serialisedFlowState = blob!!.flowStack?.let { SerializedBytes<FlowState>(it) } + return Checkpoint.Serialized( + serializedCheckpointState = SerializedBytes(blob!!.checkpoint), + serializedFlowState = serialisedFlowState, + // Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint + errorState = 
ErrorState.Clean, + // A checkpoint with a result should not normally be loaded (it should be [null] most of the time) + result = result?.let { SerializedBytes<Any>(it.value) }, + status = status, + progressStep = progressStep, + flowIoRequest = ioRequestType, + compatible = compatible + ) + } + + private class DBPausedFields( + val id: String, + val checkpoint: ByteArray = EMPTY_BYTE_ARRAY, + val status: FlowStatus, + val progressStep: String?, + val ioRequestType: String?, + val compatible: Boolean + ) { + fun toSerializedCheckpoint(): Checkpoint.Serialized { + return Checkpoint.Serialized( + serializedCheckpointState = SerializedBytes(checkpoint), + serializedFlowState = null, + // Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint + errorState = ErrorState.Clean, + result = null, + status = status, + progressStep = progressStep, + flowIoRequest = ioRequestType, + compatible = compatible + ) + } + } + + private fun <T : Any> T.storageSerialize(): SerializedBytes<T> { + return serialize(context = SerializationDefaults.STORAGE_CONTEXT) + } + + private fun Checkpoint.isFinished() = when (status) { + FlowStatus.COMPLETED, FlowStatus.KILLED, FlowStatus.FAILED -> true + else -> false + } + + private fun String.truncate(maxLength: Int, withWarnings: Boolean): String { + var str = this + if (length > maxLength) { + if (withWarnings) { + log.warn("Truncating long string before storing it into the database. 
String: $str.") + } + str = str.substring(0, maxLength) + } + return str + } + + private fun Throwable.stackTraceToString(): String { + var stackTraceStr = ExceptionUtils.getStackTrace(this) + if (stackTraceStr.length > MAX_STACKTRACE_LENGTH) { + // cut off the last line, which will be a half line + val lineBreak = System.getProperty("line.separator") + val truncateIndex = stackTraceStr.lastIndexOf(lineBreak, MAX_STACKTRACE_LENGTH - 1) + stackTraceStr = stackTraceStr.substring(0, truncateIndex + lineBreak.length) // include last line break in + } + return stackTraceStr + } } diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt index c3116f03cb..2d57f8947e 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/CheckpointDumperImpl.kt @@ -90,6 +90,11 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri companion object { internal val TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd-HHmmss").withZone(UTC) private val log = contextLogger() + private val DUMPABLE_CHECKPOINTS = setOf( + Checkpoint.FlowStatus.RUNNABLE, + Checkpoint.FlowStatus.HOSPITALIZED, + Checkpoint.FlowStatus.PAUSED + ) } override val priority: Int = SERVICE_PRIORITY_NORMAL @@ -141,7 +146,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri try { if (lock.getAndIncrement() == 0 && !file.exists()) { database.transaction { - checkpointStorage.getAllCheckpoints().use { stream -> + checkpointStorage.getCheckpoints(DUMPABLE_CHECKPOINTS).use { stream -> ZipOutputStream(file.outputStream()).use { zip -> stream.forEach { (runId, serialisedCheckpoint) -> @@ -149,8 +154,7 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri instrumentCheckpointAgent(runId) val (bytes, fileName) = try { - val checkpoint = - 
serialisedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext) + val checkpoint = serialisedCheckpoint.deserialize(checkpointSerializationContext) val json = checkpoint.toJson(runId.uuid, now) val jsonBytes = writer.writeValueAsBytes(json) jsonBytes to "${json.topLevelFlowClass.simpleName}-${runId.uuid}.json" @@ -205,13 +209,16 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext) fiber to fiber.logic } + else -> { + throw IllegalStateException("Only runnable checkpoints with their flow stack are output by the checkpoint dumper") + } } val flowCallStack = if (fiber != null) { // Poke into Quasar's stack and find the object references to the sub-flows so that we can correctly get the current progress // step for each sub-call. val stackObjects = fiber.getQuasarStack() - subFlowStack.map { it.toJson(stackObjects) } + checkpointState.subFlowStack.map { it.toJson(stackObjects) } } else { emptyList() } @@ -226,9 +233,10 @@ class CheckpointDumperImpl(private val checkpointStorage: CheckpointStorage, pri timestamp, now ), - origin = invocationContext.origin.toOrigin(), - ourIdentity = ourIdentity, - activeSessions = sessions.mapNotNull { it.value.toActiveSession(it.key) }, + origin = checkpointState.invocationContext.origin.toOrigin(), + ourIdentity = checkpointState.ourIdentity, + activeSessions = checkpointState.sessions.mapNotNull { it.value.toActiveSession(it.key) }, + // This can only ever return as [ErrorState.Clean] which causes it to become [null] errored = errorState as? 
ErrorState.Errored ) } diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/rpc/RPCServer.kt index 9a1e1474d0..63503f8f28 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/RPCServer.kt @@ -388,10 +388,11 @@ class RPCServer( val arguments = Try.on { clientToServer.serialisedArguments.deserialize<List<Any?>>(context = RPC_SERVER_CONTEXT) } - val context = artemisMessage.context(clientToServer.sessionId) - context.invocation.pushToLoggingContext() + val context: RpcAuthContext when (arguments) { is Try.Success -> { + context = artemisMessage.context(clientToServer.sessionId, arguments.value) + context.invocation.pushToLoggingContext() log.debug { "Arguments: ${arguments.value.toTypedArray().contentDeepToString()}" } rpcExecutor!!.submit { val result = invokeRpc(context, clientToServer.methodName, arguments.value) @@ -399,6 +400,8 @@ class RPCServer( } } is Try.Failure -> { + context = artemisMessage.context(clientToServer.sessionId, emptyList()) + context.invocation.pushToLoggingContext() // We failed to deserialise the arguments, route back the error log.warn("Inbound RPC failed", arguments.exception) sendReply(clientToServer.replyId, clientToServer.clientAddress, arguments) @@ -476,12 +479,12 @@ class RPCServer( observableMap.cleanUp() } - private fun ClientMessage.context(sessionId: Trace.SessionId): RpcAuthContext { + private fun ClientMessage.context(sessionId: Trace.SessionId, arguments: List<Any?>): RpcAuthContext { val trace = Trace.newInstance(sessionId = sessionId) val externalTrace = externalTrace() val rpcActor = actorFrom(this) val impersonatedActor = impersonatedActor() - return RpcAuthContext(InvocationContext.rpc(rpcActor.first, trace, externalTrace, impersonatedActor), rpcActor.second) + return RpcAuthContext(InvocationContext.rpc(rpcActor.first, trace, externalTrace, impersonatedActor, arguments), 
rpcActor.second) } private fun actorFrom(message: ClientMessage): Pair<Actor, AuthorizingSubject> { diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index a834dbb3ab..d38c6371ef 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -33,7 +33,12 @@ class NodeSchemaService(private val extraSchemas: Set<MappedSchema> = emptySet() object NodeCore object NodeCoreV1 : MappedSchema(schemaFamily = NodeCore.javaClass, version = 1, - mappedTypes = listOf(DBCheckpointStorage.DBCheckpoint::class.java, + mappedTypes = listOf(DBCheckpointStorage.DBFlowCheckpoint::class.java, + DBCheckpointStorage.DBFlowCheckpointBlob::class.java, + DBCheckpointStorage.DBFlowResult::class.java, + DBCheckpointStorage.DBFlowException::class.java, + DBCheckpointStorage.DBFlowMetadata::class.java, + DBTransactionStorage.DBTransaction::class.java, BasicHSMKeyManagementService.PersistentKey::class.java, NodeSchedulerService.PersistentScheduledState::class.java, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index f6bf4463d2..c3ddadd716 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -2,11 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable import com.codahale.metrics.Gauge -import com.codahale.metrics.Histogram -import com.codahale.metrics.MetricRegistry import com.codahale.metrics.Reservoir -import com.codahale.metrics.SlidingTimeWindowArrayReservoir -import com.codahale.metrics.SlidingTimeWindowReservoir import net.corda.core.internal.concurrent.thenMatch import 
net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext @@ -19,8 +15,6 @@ import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.contextTransaction import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import java.time.Duration -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicLong /** * This is the bottom execution engine of flow side-effects. @@ -30,8 +24,7 @@ class ActionExecutorImpl( private val checkpointStorage: CheckpointStorage, private val flowMessaging: FlowMessaging, private val stateMachineManager: StateMachineManagerInternal, - private val checkpointSerializationContext: CheckpointSerializationContext, - metrics: MetricRegistry + private val checkpointSerializationContext: CheckpointSerializationContext ) : ActionExecutor { private companion object { @@ -47,12 +40,6 @@ class ActionExecutorImpl( } } - private val checkpointingMeter = metrics.meter("Flows.Checkpointing Rate") - private val checkpointSizesThisSecond = SlidingTimeWindowReservoir(1, TimeUnit.SECONDS) - private val lastBandwidthUpdate = AtomicLong(0) - private val checkpointBandwidthHist = metrics.register("Flows.CheckpointVolumeBytesPerSecondHist", Histogram(SlidingTimeWindowArrayReservoir(1, TimeUnit.DAYS))) - private val checkpointBandwidth = metrics.register("Flows.CheckpointVolumeBytesPerSecondCurrent", LatchedGauge(checkpointSizesThisSecond)) - @Suspendable override fun executeAction(fiber: FlowFiber, action: Action) { log.trace { "Flow ${fiber.id} executing $action" } @@ -100,21 +87,22 @@ class ActionExecutorImpl( @Suspendable private fun executePersistCheckpoint(action: Action.PersistCheckpoint) { - val checkpointBytes = serializeCheckpoint(action.checkpoint) - if (action.isCheckpointUpdate) { - checkpointStorage.updateCheckpoint(action.id, checkpointBytes) - } else { - checkpointStorage.addCheckpoint(action.id, 
checkpointBytes) + val checkpoint = action.checkpoint + val flowState = checkpoint.flowState + val serializedFlowState = when(flowState) { + FlowState.Completed -> null + // upon implementing CORDA-3816: If we have errored or hospitalized then we don't need to serialize the flowState as it will not get saved in the DB + else -> flowState.checkpointSerialize(checkpointSerializationContext) } - checkpointingMeter.mark() - checkpointSizesThisSecond.update(checkpointBytes.size.toLong()) - var lastUpdateTime = lastBandwidthUpdate.get() - while (System.nanoTime() - lastUpdateTime > TimeUnit.SECONDS.toNanos(1)) { - if (lastBandwidthUpdate.compareAndSet(lastUpdateTime, System.nanoTime())) { - val checkpointVolume = checkpointSizesThisSecond.snapshot.values.sum() - checkpointBandwidthHist.update(checkpointVolume) + // upon implementing CORDA-3816: If we have errored or hospitalized then we don't need to serialize the serializedCheckpointState as it will not get saved in the DB + val serializedCheckpointState: SerializedBytes<CheckpointState> = checkpoint.checkpointState.checkpointSerialize(checkpointSerializationContext) + if (action.isCheckpointUpdate) { + checkpointStorage.updateCheckpoint(action.id, checkpoint, serializedFlowState, serializedCheckpointState) + } else { + if (flowState is FlowState.Completed) { + throw IllegalStateException("A new checkpoint cannot be created with a Completed FlowState.") } - lastUpdateTime = lastBandwidthUpdate.get() + checkpointStorage.addCheckpoint(action.id, checkpoint, serializedFlowState!!, serializedCheckpointState) } } @@ -269,10 +257,6 @@ class ActionExecutorImpl( stateMachineManager.retryFlowFromSafePoint(action.currentState) } - private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> { - return checkpoint.checkpointSerialize(context = checkpointSerializationContext) - } - private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) { stateMachineManager.cancelFlowTimeout(action.flowId) } diff 
--git a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt index 036c2d2846..aa2778ba49 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/DeduplicationId.kt @@ -27,7 +27,7 @@ data class DeduplicationId(val toString: String) { * message-id map to change, which means deduplication will not happen correctly. */ fun createForNormal(checkpoint: Checkpoint, index: Int, session: SessionState): DeduplicationId { - return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index") + return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.checkpointState.numberOfSuspends}-$index") } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt index aad376b2ff..fc80c17dfb 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -6,8 +6,9 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.ProgressTracker import net.corda.node.services.messaging.DeduplicationHandler -import java.util.* +import java.util.UUID /** * Transitions in the flow state machine are triggered by [Event]s that may originate from the flow itself or from @@ -101,17 +102,20 @@ sealed class Event { * @param ioRequest the request triggering the suspension. * @param maySkipCheckpoint indicates whether the persistence may be skipped. * @param fiber the serialised stack of the flow. + * @param progressStep the current progress tracker step. 
*/ data class Suspend( val ioRequest: FlowIORequest<*>, val maySkipCheckpoint: Boolean, - val fiber: SerializedBytes<FlowStateMachineImpl<*>> + val fiber: SerializedBytes<FlowStateMachineImpl<*>>, + var progressStep: ProgressTracker.Step? ) : Event() { override fun toString() = "Suspend(" + "ioRequest=$ioRequest, " + "maySkipCheckpoint=$maySkipCheckpoint, " + "fiber=${fiber.hash}, " + + "currentStep=${progressStep?.label}" + ")" } @@ -150,6 +154,15 @@ sealed class Event { override fun toString() = "RetryFlowFromSafePoint" } + /** + * Keeps a flow for overnight observation. Overnight observation practically sends the fiber to get suspended, + * in [FlowStateMachineImpl.processEventsUntilFlowIsResumed]. Since the fiber's channel will have no more events to process, + * the fiber gets suspended (i.e. hospitalized). + */ + object OvernightObservation : Event() { + override fun toString() = "OvernightObservation" + } + /** * Wake a flow up from its sleep. */ diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt new file mode 100644 index 0000000000..be8026b73f --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowCreator.kt @@ -0,0 +1,191 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.FiberScheduler +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.channels.Channels +import net.corda.core.concurrent.CordaFuture +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.internal.CheckpointSerializationContext +import 
net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.core.utilities.contextLogger +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.messaging.DeduplicationHandler +import net.corda.node.services.statemachine.transitions.StateMachine +import net.corda.node.utilities.isEnabledTimedFlow +import net.corda.nodeapi.internal.persistence.CordaPersistence +import org.apache.activemq.artemis.utils.ReusableLatch +import java.security.SecureRandom + +class Flow<A>(val fiber: FlowStateMachineImpl<A>, val resultFuture: OpenFuture<Any?>) + +class NonResidentFlow(val runId: StateMachineRunId, val checkpoint: Checkpoint) { + val externalEvents = mutableListOf<Event.DeliverSessionMessage>() + + fun addExternalEvent(message: Event.DeliverSessionMessage) { + externalEvents.add(message) + } +} + +class FlowCreator( + val checkpointSerializationContext: CheckpointSerializationContext, + private val checkpointStorage: CheckpointStorage, + val scheduler: FiberScheduler, + val database: CordaPersistence, + val transitionExecutor: TransitionExecutor, + val actionExecutor: ActionExecutor, + val secureRandom: SecureRandom, + val serviceHub: ServiceHubInternal, + val unfinishedFibers: ReusableLatch, + val resetCustomTimeout: (StateMachineRunId, Long) -> Unit) { + + companion object { + private val logger = contextLogger() + } + + fun createFlowFromNonResidentFlow(nonResidentFlow: NonResidentFlow): Flow<*>? { + // As for paused flows we don't extract the serialized flow state we need to re-extract the checkpoint from the database. 
+ val checkpoint = when (nonResidentFlow.checkpoint.status) { + Checkpoint.FlowStatus.PAUSED -> { + val serialized = database.transaction { + checkpointStorage.getCheckpoint(nonResidentFlow.runId) + } + serialized?.copy(status = Checkpoint.FlowStatus.RUNNABLE)?.deserialize(checkpointSerializationContext) ?: return null + } + else -> nonResidentFlow.checkpoint + } + return createFlowFromCheckpoint(nonResidentFlow.runId, checkpoint) + } + + fun createFlowFromCheckpoint(runId: StateMachineRunId, oldCheckpoint: Checkpoint): Flow<*>? { + val checkpoint = oldCheckpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE) + val fiber = checkpoint.getFiberFromCheckpoint(runId) ?: return null + val resultFuture = openFuture<Any?>() + fiber.transientValues = TransientReference(createTransientValues(runId, resultFuture)) + fiber.logic.stateMachine = fiber + verifyFlowLogicIsSuspendable(fiber.logic) + val state = createStateMachineState(checkpoint, fiber, true) + fiber.transientState = TransientReference(state) + return Flow(fiber, resultFuture) + } + + @Suppress("LongParameterList") + fun <A> createFlowFromLogic( + flowId: StateMachineRunId, + invocationContext: InvocationContext, + flowLogic: FlowLogic<A>, + flowStart: FlowStart, + ourIdentity: Party, + existingCheckpoint: Checkpoint?, + deduplicationHandler: DeduplicationHandler?, + senderUUID: String?): Flow<A> { + // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties + // have access to the fiber (and thereby the service hub) + val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) + val resultFuture = openFuture<Any?>() + flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) + flowLogic.stateMachine = flowStateMachineImpl + val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext) + val flowCorDappVersion = 
FlowStateMachineImpl.createSubFlowVersion( + serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) + + val checkpoint = existingCheckpoint?.copy(status = Checkpoint.FlowStatus.RUNNABLE) ?: Checkpoint.create( + invocationContext, + flowStart, + flowLogic.javaClass, + frozenFlowLogic, + ourIdentity, + flowCorDappVersion, + flowLogic.isEnabledTimedFlow() + ).getOrThrow() + + val state = createStateMachineState( + checkpoint, + flowStateMachineImpl, + existingCheckpoint != null, + deduplicationHandler, + senderUUID) + flowStateMachineImpl.transientState = TransientReference(state) + return Flow(flowStateMachineImpl, resultFuture) + } + + private fun Checkpoint.getFiberFromCheckpoint(runId: StateMachineRunId): FlowStateMachineImpl<*>? { + return when (this.flowState) { + is FlowState.Unstarted -> { + val logic = tryCheckpointDeserialize(this.flowState.frozenFlowLogic, runId) ?: return null + FlowStateMachineImpl(runId, logic, scheduler) + } + is FlowState.Started -> tryCheckpointDeserialize(this.flowState.frozenFiber, runId) ?: return null + // Places calling this function is rely on it to return null if the flow cannot be created from the checkpoint. + else -> { + return null + } + } + } + + @Suppress("TooGenericExceptionCaught") + private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? { + return try { + bytes.checkpointDeserialize(context = checkpointSerializationContext) + } catch (e: Exception) { + logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", e) + null + } + } + + private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) { + // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's + // easy to forget to add this when creating a new flow, so we check here to give the user a better error. 
+ // + // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which + // forwards to the void method and then returns Unit. However annotations do not get copied across to this + // bridge, so we have to do a more complex scan here. + val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } + if (call.getAnnotation(Suspendable::class.java) == null) { + throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") + } + } + + private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues { + return FlowStateMachineImpl.TransientValues( + eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK), + resultFuture = resultFuture, + database = database, + transitionExecutor = transitionExecutor, + actionExecutor = actionExecutor, + stateMachine = StateMachine(id, secureRandom), + serviceHub = serviceHub, + checkpointSerializationContext = checkpointSerializationContext, + unfinishedFibers = unfinishedFibers, + waitTimeUpdateHook = { flowId, timeout -> resetCustomTimeout(flowId, timeout) } + ) + } + + private fun createStateMachineState( + checkpoint: Checkpoint, + fiber: FlowStateMachineImpl<*>, + anyCheckpointPersisted: Boolean, + deduplicationHandler: DeduplicationHandler? = null, + senderUUID: String? 
= null): StateMachineState { + return StateMachineState( + checkpoint = checkpoint, + pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + future = null, + isWaitingForFuture = false, + isAnyCheckpointPersisted = anyCheckpointPersisted, + isStartIdempotent = false, + isRemoved = false, + isKilled = false, + flowLogic = fiber.logic, + senderUUID = senderUUID) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt index b947f62f2b..9f80005880 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt @@ -81,7 +81,7 @@ internal class FlowMonitor( is FlowIORequest.WaitForLedgerCommit -> "for the ledger to commit transaction with hash ${request.hash}" is FlowIORequest.GetFlowInfo -> "to get flow information from parties ${request.sessions.partiesInvolved()}" is FlowIORequest.Sleep -> "to wake up from sleep ending at ${LocalDateTime.ofInstant(request.wakeUpAfter, ZoneId.systemDefault())}" - FlowIORequest.WaitForSessionConfirmations -> "for sessions to be confirmed" + is FlowIORequest.WaitForSessionConfirmations -> "for sessions to be confirmed" is FlowIORequest.ExecuteAsyncOperation -> "for asynchronous operation of type ${request.operation::javaClass} to complete" FlowIORequest.ForceCheckpoint -> "for forcing a checkpoint at an arbitrary point in a flow" } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index c76d4aa2e9..408d8a12b5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -8,6 +8,7 @@ import 
co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.channels.Channel import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext +import net.corda.core.contracts.StateRef import net.corda.core.cordapp.Cordapp import net.corda.core.flows.Destination import net.corda.core.flows.FlowException @@ -128,14 +129,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, */ override val logger = log override val resultFuture: CordaFuture<R> get() = uncheckedCast(getTransientField(TransientValues::resultFuture)) - override val context: InvocationContext get() = transientState!!.value.checkpoint.invocationContext - override val ourIdentity: Party get() = transientState!!.value.checkpoint.ourIdentity + override val context: InvocationContext get() = transientState!!.value.checkpoint.checkpointState.invocationContext + override val ourIdentity: Party get() = transientState!!.value.checkpoint.checkpointState.ourIdentity override val isKilled: Boolean get() = transientState!!.value.isKilled - internal var hasSoftLockedStates: Boolean = false - set(value) { - if (value) field = value else throw IllegalArgumentException("Can only set to true") - } + internal val softLockedStates = mutableSetOf<StateRef>() /** * Processes an event by creating the associated transition and executing it using the given executor. 
@@ -297,7 +295,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, Thread.currentThread().contextClassLoader = (serviceHub.cordappProvider as CordappProviderImpl).cordappLoader.appClassLoader val result = logic.call() - suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true) + suspend(FlowIORequest.WaitForSessionConfirmations(), maySkipCheckpoint = true) Try.Success(result) } catch (t: Throwable) { if(t.isUnrecoverable()) { @@ -306,7 +304,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, logFlowError(t) Try.Failure<R>(t) } - val softLocksId = if (hasSoftLockedStates) logic.runId.uuid else null + val softLocksId = if (softLockedStates.isNotEmpty()) logic.runId.uuid else null val finalEvent = when (resultOrError) { is Try.Success -> { Event.FlowFinish(resultOrError.value, softLocksId) @@ -400,7 +398,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, */ @Suspendable private fun checkpointIfSubflowIdempotent(subFlow: Class<FlowLogic<*>>) { - val currentFlow = snapshot().checkpoint.subFlowStack.last().flowClass + val currentFlow = snapshot().checkpoint.checkpointState.subFlowStack.last().flowClass if (!currentFlow.isIdempotentFlow() && subFlow.isIdempotentFlow()) { suspend(FlowIORequest.ForceCheckpoint, false) } @@ -489,7 +487,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, Event.Suspend( ioRequest = ioRequest, maySkipCheckpoint = skipPersistingCheckpoint, - fiber = this.checkpointSerialize(context = serializationContext.value) + fiber = this.checkpointSerialize(context = serializationContext.value), + progressStep = logic.progressTracker?.currentStep ) } catch (exception: Exception) { Event.Error(exception) @@ -518,7 +517,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId, } private fun containsIdempotentFlows(): Boolean { - val subFlowStack = snapshot().checkpoint.subFlowStack + val subFlowStack = snapshot().checkpoint.checkpointState.subFlowStack 
return subFlowStack.any { IdempotentFlow::class.java.isAssignableFrom(it.flowClass) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index 595b644493..c0353eac0a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -2,9 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.instrument.JavaAgent -import co.paralleluniverse.strands.channels.Channels import com.codahale.metrics.Gauge import com.google.common.util.concurrent.ThreadFactoryBuilder import net.corda.core.concurrent.CordaFuture @@ -24,12 +22,9 @@ import net.corda.core.internal.concurrent.mapError import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.mapNotNull import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationDefaults -import net.corda.core.serialization.internal.checkpointDeserialize -import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger @@ -39,13 +34,11 @@ import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.shouldCheckCheckpoints import net.corda.node.services.messaging.DeduplicationHandler -import 
net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.createSubFlowVersion import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor import net.corda.node.services.statemachine.interceptors.FiberDeserializationChecker import net.corda.node.services.statemachine.interceptors.FiberDeserializationCheckingInterceptor import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor import net.corda.node.services.statemachine.interceptors.PrintingInterceptor -import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.errorAndTerminate import net.corda.node.utilities.injectOldProgressTracker @@ -61,6 +54,7 @@ import java.lang.Integer.min import java.security.SecureRandom import java.time.Duration import java.util.HashSet +import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -90,8 +84,6 @@ class SingleThreadedStateMachineManager( private val logger = contextLogger() } - private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>) - private data class ScheduledTimeout( /** Will fire a [FlowTimeoutException] indicating to the flow hospital to restart the flow. */ val scheduledFuture: ScheduledFuture<*>, @@ -105,7 +97,8 @@ class SingleThreadedStateMachineManager( val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!! /** True if we're shutting down, so don't resume anything. */ var stopping = false - val flows = HashMap<StateMachineRunId, Flow>() + val flows = HashMap<StateMachineRunId, Flow<*>>() + val pausedFlows = HashMap<StateMachineRunId, NonResidentFlow>() val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>() /** Flows scheduled to be retried if not finished within the specified timeout period. 
*/ val timedFlows = HashMap<StateMachineRunId, ScheduledTimeout>() @@ -127,7 +120,7 @@ class SingleThreadedStateMachineManager( private val ourSenderUUID = serviceHub.networkService.ourSenderUUID private var checkpointSerializationContext: CheckpointSerializationContext? = null - private var actionExecutor: ActionExecutor? = null + private lateinit var flowCreator: FlowCreator override val flowHospital: StaffedFlowHospital = makeFlowHospital() private val transitionExecutor = makeTransitionExecutor() @@ -146,7 +139,7 @@ class SingleThreadedStateMachineManager( */ override val changes: Observable<StateMachineManager.Change> = mutex.content.changesPublisher - override fun start(tokenizableServices: List<Any>) : CordaFuture<Unit> { + override fun start(tokenizableServices: List<Any>, startMode: StateMachineManager.StartMode): CordaFuture<Unit> { checkQuasarJavaAgentPresence() val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( CheckpointSerializeAsTokenContextImpl( @@ -157,8 +150,24 @@ class SingleThreadedStateMachineManager( ) ) this.checkpointSerializationContext = checkpointSerializationContext - this.actionExecutor = makeActionExecutor(checkpointSerializationContext) + val actionExecutor = makeActionExecutor(checkpointSerializationContext) fiberDeserializationChecker?.start(checkpointSerializationContext) + when (startMode) { + StateMachineManager.StartMode.ExcludingPaused -> {} + StateMachineManager.StartMode.Safe -> markAllFlowsAsPaused() + } + this.flowCreator = FlowCreator( + checkpointSerializationContext, + checkpointStorage, + scheduler, + database, + transitionExecutor, + actionExecutor, + secureRandom, + serviceHub, + unfinishedFibers, + ::resetCustomTimeout) + val fibers = restoreFlowsFromCheckpoints() metrics.register("Flows.InFlight", Gauge<Int> { mutex.content.flows.size }) Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> @@ -168,6 +177,17 @@ class SingleThreadedStateMachineManager( 
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) } } + + val pausedFlows = restoreNonResidentFlowsFromPausedCheckpoints() + mutex.locked { + this.pausedFlows.putAll(pausedFlows) + for ((id, flow) in pausedFlows) { + val checkpoint = flow.checkpoint + for (sessionId in getFlowSessionIds(checkpoint)) { + sessionToFlow[sessionId] = id + } + } + } return serviceHub.networkMapCache.nodeReady.map { logger.info("Node ready, info: ${serviceHub.myInfo}") resumeRestoredFlows(fibers) @@ -241,8 +261,7 @@ class SingleThreadedStateMachineManager( flowLogic = flowLogic, flowStart = FlowStart.Explicit, ourIdentity = ourIdentity ?: ourFirstIdentity, - deduplicationHandler = deduplicationHandler, - isStartIdempotent = false + deduplicationHandler = deduplicationHandler ) } @@ -282,6 +301,22 @@ class SingleThreadedStateMachineManager( } } + private fun markAllFlowsAsPaused() { + return checkpointStorage.markAllPaused() + } + + override fun unPauseFlow(id: StateMachineRunId): Boolean { + mutex.locked { + val pausedFlow = pausedFlows.remove(id) ?: return false + val flow = flowCreator.createFlowFromNonResidentFlow(pausedFlow) ?: return false + addAndStartFlow(flow.fiber.id, flow) + for (event in pausedFlow.externalEvents) { + flow.fiber.scheduleEvent(event) + } + } + return true + } + override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) { val previousFlowId = sessionToFlow.put(sessionId, flowId) if (previousFlowId != null) { @@ -351,23 +386,28 @@ class SingleThreadedStateMachineManager( liveFibers.countUp() } - private fun restoreFlowsFromCheckpoints(): List<Flow> { - return checkpointStorage.getAllCheckpoints().use { + private fun restoreFlowsFromCheckpoints(): List<Flow<*>> { + return checkpointStorage.getCheckpointsToRun().use { it.mapNotNull { (id, serializedCheckpoint) -> // If a flow is added before start() then don't attempt to restore it mutex.locked { if (id in flows) return@mapNotNull null } - 
createFlowFromCheckpoint( - id = id, - serializedCheckpoint = serializedCheckpoint, - initialDeduplicationHandler = null, - isAnyCheckpointPersisted = true, - isStartIdempotent = false - ) + val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@mapNotNull null + flowCreator.createFlowFromCheckpoint(id, checkpoint) }.toList() } } - private fun resumeRestoredFlows(flows: List<Flow>) { + private fun restoreNonResidentFlowsFromPausedCheckpoints(): Map<StateMachineRunId, NonResidentFlow> { + return checkpointStorage.getPausedCheckpoints().use { + it.mapNotNull { (id, serializedCheckpoint) -> + // If a flow is added before start() then don't attempt to restore it + val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@mapNotNull null + id to NonResidentFlow(id, checkpoint) + }.toList().toMap() + } + } + + private fun resumeRestoredFlows(flows: List<Flow<*>>) { for (flow in flows) { addAndStartFlow(flow.fiber.id, flow) } @@ -392,14 +432,10 @@ class SingleThreadedStateMachineManager( logger.error("Unable to find database checkpoint for flow $flowId. Something is very wrong. The flow will not retry.") return } + + val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: return // Resurrect flow - createFlowFromCheckpoint( - id = flowId, - serializedCheckpoint = serializedCheckpoint, - initialDeduplicationHandler = null, - isAnyCheckpointPersisted = true, - isStartIdempotent = false - ) ?: return + flowCreator.createFlowFromCheckpoint(flowId, checkpoint) ?: return } else { // Just flow initiation message null @@ -433,9 +469,9 @@ class SingleThreadedStateMachineManager( // Failed to retry - manually put the flow in for observation rather than // relying on the [HospitalisingInterceptor] to do so val exceptions = (currentState.checkpoint.errorState as? 
ErrorState.Errored) - ?.errors - ?.map { it.exception } - ?.plus(e) ?: emptyList() + ?.errors + ?.map { it.exception } + ?.plus(e) ?: emptyList() logger.info("Failed to retry flow $flowId, keeping in for observation and aborting") flowHospital.forceIntoOvernightObservation(flowId, exceptions) throw e @@ -455,11 +491,11 @@ class SingleThreadedStateMachineManager( private fun <T> onExternalStartFlow(event: ExternalEvent.ExternalStartFlowEvent<T>) { val future = startFlow( - event.flowId, - event.flowLogic, - event.context, - ourIdentity = null, - deduplicationHandler = event.deduplicationHandler + event.flowId, + event.flowLogic, + event.context, + ourIdentity = null, + deduplicationHandler = event.deduplicationHandler ) event.wireUpFuture(future) } @@ -502,9 +538,13 @@ class SingleThreadedStateMachineManager( logger.info("Cannot find flow corresponding to session ID - $recipientId.") } } else { - mutex.locked { flows[flowId] }?.run { - fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)) - } ?: logger.info("Cannot find fiber corresponding to flow ID $flowId") + val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender) + mutex.locked { + flows[flowId]?.run { fiber.scheduleEvent(event) } + // If flow is not running add it to the list of external events to be processed if/when the flow resumes. 
+ ?: pausedFlows[flowId]?.run { addExternalEvent(event) } + ?: logger.info("Cannot find fiber corresponding to flow ID $flowId") + } } } catch (exception: Exception) { logger.error("Exception while routing $sessionMessage", exception) @@ -527,14 +567,14 @@ class SingleThreadedStateMachineManager( is InitiatedFlowFactory.CorDapp -> null } startInitiatedFlow( - event.flowId, - flowLogic, - event.deduplicationHandler, - senderSession, - initiatedSessionId, - sessionMessage, - senderCoreFlowVersion, - initiatedFlowInfo + event.flowId, + flowLogic, + event.deduplicationHandler, + senderSession, + initiatedSessionId, + sessionMessage, + senderCoreFlowVersion, + initiatedFlowInfo ) } catch (t: Throwable) { logger.warn("Unable to initiate flow from $sender (appName=${sessionMessage.appName} " + @@ -581,8 +621,7 @@ class SingleThreadedStateMachineManager( flowLogic, flowStart, ourIdentity, - initiatingMessageDeduplicationHandler, - isStartIdempotent = false + initiatingMessageDeduplicationHandler ) } @@ -593,20 +632,9 @@ class SingleThreadedStateMachineManager( flowLogic: FlowLogic<A>, flowStart: FlowStart, ourIdentity: Party, - deduplicationHandler: DeduplicationHandler?, - isStartIdempotent: Boolean + deduplicationHandler: DeduplicationHandler? ): CordaFuture<FlowStateMachine<A>> { - // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties - // have access to the fiber (and thereby the service hub) - val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) - val resultFuture = openFuture<Any?>() - flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) - flowLogic.stateMachine = flowStateMachineImpl - val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!) 
- - val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) - val flowAlreadyExists = mutex.locked { flows[flowId] != null } val existingCheckpoint = if (flowAlreadyExists) { @@ -614,7 +642,7 @@ class SingleThreadedStateMachineManager( // The checkpoint will be missing if the flow failed before persisting the original checkpoint // CORDA-3359 - Do not start/retry a flow that failed after deleting its checkpoint (the whole of the flow might replay) checkpointStorage.getCheckpoint(flowId)?.let { serializedCheckpoint -> - val checkpoint = tryCheckpointDeserialize(serializedCheckpoint, flowId) + val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) if (checkpoint == null) { return openFuture<FlowStateMachine<A>>().mapError { IllegalStateException("Unable to deserialize database checkpoint for flow $flowId. " + @@ -628,37 +656,15 @@ class SingleThreadedStateMachineManager( // This is a brand new flow null } - val checkpoint = existingCheckpoint ?: Checkpoint.create( - invocationContext, - flowStart, - flowLogic.javaClass, - frozenFlowLogic, - ourIdentity, - flowCorDappVersion, - flowLogic.isEnabledTimedFlow() - ).getOrThrow() + val flow = flowCreator.createFlowFromLogic(flowId, invocationContext, flowLogic, flowStart, ourIdentity, existingCheckpoint, deduplicationHandler, ourSenderUUID) val startedFuture = openFuture<Unit>() - val initialState = StateMachineState( - checkpoint = checkpoint, - pendingDeduplicationHandlers = deduplicationHandler?.let { listOf(it) } ?: emptyList(), - isFlowResumed = false, - isWaitingForFuture = false, - future = null, - isAnyCheckpointPersisted = existingCheckpoint != null, - isStartIdempotent = isStartIdempotent, - isRemoved = false, - isKilled = false, - flowLogic = flowLogic, - senderUUID = ourSenderUUID - ) - flowStateMachineImpl.transientState = TransientReference(initialState) mutex.locked { startedFutures[flowId] = startedFuture } 
totalStartedFlows.inc() - addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture)) - return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> } + addAndStartFlow(flowId, flow) + return startedFuture.map { flow.fiber as FlowStateMachine<A> } } override fun scheduleFlowTimeout(flowId: StateMachineRunId) { @@ -738,7 +744,7 @@ class SingleThreadedStateMachineManager( } /** Schedules a [FlowTimeoutException] to be fired in order to restart the flow. */ - private fun scheduleTimeoutException(flow: Flow, delay: Long): ScheduledFuture<*> { + private fun scheduleTimeoutException(flow: Flow<*>, delay: Long): ScheduledFuture<*> { return with(serviceHub.configuration.flowTimeout) { scheduledFutureExecutor.schedule({ val event = Event.Error(FlowTimeoutException()) @@ -766,108 +772,16 @@ class SingleThreadedStateMachineManager( } } - private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) { - // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's - // easy to forget to add this when creating a new flow, so we check here to give the user a better error. - // - // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which - // forwards to the void method and then returns Unit. However annotations do not get copied across to this - // bridge, so we have to do a more complex scan here. - val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } - if (call.getAnnotation(Suspendable::class.java) == null) { - throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. 
Please fix this.") - } - } - - private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues { - return FlowStateMachineImpl.TransientValues( - eventQueue = Channels.newChannel(-1, Channels.OverflowPolicy.BLOCK), - resultFuture = resultFuture, - database = database, - transitionExecutor = transitionExecutor, - actionExecutor = actionExecutor!!, - stateMachine = StateMachine(id, secureRandom), - serviceHub = serviceHub, - checkpointSerializationContext = checkpointSerializationContext!!, - unfinishedFibers = unfinishedFibers, - waitTimeUpdateHook = { flowId, timeout -> resetCustomTimeout(flowId, timeout) } - ) - } - - private inline fun <reified T : Any> tryCheckpointDeserialize(bytes: SerializedBytes<T>, flowId: StateMachineRunId): T? { + private fun tryDeserializeCheckpoint(serializedCheckpoint: Checkpoint.Serialized, flowId: StateMachineRunId): Checkpoint? { return try { - bytes.checkpointDeserialize(context = checkpointSerializationContext!!) + serializedCheckpoint.deserialize(checkpointSerializationContext!!) } catch (e: Exception) { logger.error("Unable to deserialize checkpoint for flow $flowId. Something is very wrong and this flow will be ignored.", e) null } } - private fun createFlowFromCheckpoint( - id: StateMachineRunId, - serializedCheckpoint: SerializedBytes<Checkpoint>, - isAnyCheckpointPersisted: Boolean, - isStartIdempotent: Boolean, - initialDeduplicationHandler: DeduplicationHandler? - ): Flow? 
{ - val checkpoint = tryCheckpointDeserialize(serializedCheckpoint, id) ?: return null - val flowState = checkpoint.flowState - val resultFuture = openFuture<Any?>() - val fiber = when (flowState) { - is FlowState.Unstarted -> { - val logic = tryCheckpointDeserialize(flowState.frozenFlowLogic, id) ?: return null - val state = StateMachineState( - checkpoint = checkpoint, - pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), - isFlowResumed = false, - isWaitingForFuture = false, - future = null, - isAnyCheckpointPersisted = isAnyCheckpointPersisted, - isStartIdempotent = isStartIdempotent, - isRemoved = false, - isKilled = false, - flowLogic = logic, - senderUUID = null - ) - val fiber = FlowStateMachineImpl(id, logic, scheduler) - fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) - fiber.transientState = TransientReference(state) - fiber.logic.stateMachine = fiber - fiber - } - is FlowState.Started -> { - val fiber = tryCheckpointDeserialize(flowState.frozenFiber, id) ?: return null - val state = StateMachineState( - // Do a trivial checkpoint copy below, to update the Checkpoint#timestamp value. - // The Checkpoint#timestamp is being used by FlowMonitor as the starting time point of a potential suspension. - // We need to refresh the Checkpoint#timestamp here, in case of an e.g. node start up after a long period. - // If not then, there is a time window (until the next checkpoint update) in which the FlowMonitor - // could log this flow as a waiting flow, from the last checkpoint update i.e. before the node's start up. 
- checkpoint = checkpoint.copy(), - pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), - isFlowResumed = false, - isWaitingForFuture = false, - future = null, - isAnyCheckpointPersisted = isAnyCheckpointPersisted, - isStartIdempotent = isStartIdempotent, - isRemoved = false, - isKilled = false, - flowLogic = fiber.logic, - senderUUID = null - ) - fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) - fiber.transientState = TransientReference(state) - fiber.logic.stateMachine = fiber - fiber - } - } - - verifyFlowLogicIsSuspendable(fiber.logic) - - return Flow(fiber, resultFuture) - } - - private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) { + private fun addAndStartFlow(id: StateMachineRunId, flow: Flow<*>) { val checkpoint = flow.fiber.snapshot().checkpoint for (sessionId in getFlowSessionIds(checkpoint)) { sessionToFlow[sessionId] = id @@ -887,24 +801,29 @@ class SingleThreadedStateMachineManager( val flowLogic = flow.fiber.logic if (flowLogic.isEnabledTimedFlow()) scheduleTimeout(id) flow.fiber.scheduleEvent(Event.DoRemainingWork) - when (checkpoint.flowState) { - is FlowState.Unstarted -> { - flow.fiber.start() - } - is FlowState.Started -> { - Fiber.unparkDeserialized(flow.fiber, scheduler) - } - } + startOrResume(checkpoint, flow) } } } + private fun startOrResume(checkpoint: Checkpoint, flow: Flow<*>) { + when (checkpoint.flowState) { + is FlowState.Unstarted -> { + flow.fiber.start() + } + is FlowState.Started -> { + Fiber.unparkDeserialized(flow.fiber, scheduler) + } + is FlowState.Completed -> throw IllegalStateException("Cannot start (or resume) a completed flow.") + } + } + private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> { val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? 
FlowStart.Initiated return if (initiatedFlowStart == null) { - checkpoint.sessions.keys + checkpoint.checkpointState.sessions.keys } else { - checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId + checkpoint.checkpointState.sessions.keys + initiatedFlowStart.initiatedSessionId } } @@ -914,8 +833,7 @@ class SingleThreadedStateMachineManager( checkpointStorage, flowMessaging, this, - checkpointSerializationContext, - metrics + checkpointSerializationContext ) } @@ -942,7 +860,7 @@ class SingleThreadedStateMachineManager( } private fun InnerState.removeFlowOrderly( - flow: Flow, + flow: Flow<*>, removalReason: FlowRemovalReason.OrderlyFinish, lastState: StateMachineState ) { @@ -950,7 +868,7 @@ class SingleThreadedStateMachineManager( // final sanity checks require(lastState.pendingDeduplicationHandlers.isEmpty()) { "Flow cannot be removed until all pending deduplications have completed" } require(lastState.isRemoved) { "Flow must be in removable state before removal" } - require(lastState.checkpoint.subFlowStack.size == 1) { "Checkpointed stack must be empty" } + require(lastState.checkpoint.checkpointState.subFlowStack.size == 1) { "Checkpointed stack must be empty" } require(flow.fiber.id !in sessionToFlow.values) { "Flow fibre must not be needed by an existing session" } flow.resultFuture.set(removalReason.flowReturnValue) lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE @@ -958,7 +876,7 @@ class SingleThreadedStateMachineManager( } private fun InnerState.removeFlowError( - flow: Flow, + flow: Flow<*>, removalReason: FlowRemovalReason.ErrorFinish, lastState: StateMachineState ) { @@ -972,7 +890,7 @@ class SingleThreadedStateMachineManager( } // The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here. 
- private fun drainFlowEventQueue(flow: Flow) { + private fun drainFlowEventQueue(flow: Flow<*>) { while (true) { val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return when (event) { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt index 6408748b42..af7d197d50 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StaffedFlowHospital.kt @@ -224,7 +224,7 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging, log.info("Flow error kept for overnight observation by ${report.by} (error was ${report.error.message})") // We don't schedule a next event for the flow - it will automatically retry from its checkpoint on node restart onFlowKeptForOvernightObservation.forEach { hook -> hook.invoke(flowFiber.id, report.by.map{it.toString()}) } - Triple(Outcome.OVERNIGHT_OBSERVATION, null, 0.seconds) + Triple(Outcome.OVERNIGHT_OBSERVATION, Event.OvernightObservation, 0.seconds) } Diagnosis.NOT_MY_SPECIALTY, Diagnosis.TERMINAL -> { // None of the staff care for these errors, or someone decided it is a terminal condition, so we let them propagate @@ -233,20 +233,19 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging, } } - val record = MedicalRecord.Flow(time, flowFiber.id, currentState.checkpoint.numberOfSuspends, errors, report.by, outcome) + val numberOfSuspends = currentState.checkpoint.checkpointState.numberOfSuspends + val record = MedicalRecord.Flow(time, flowFiber.id, numberOfSuspends, errors, report.by, outcome) medicalHistory.records += record recordsPublisher.onNext(record) Pair(event, backOffForChronicCondition) } - if (event != null) { - if (backOffForChronicCondition.isZero) { + if (backOffForChronicCondition.isZero) { + flowFiber.scheduleEvent(event) + } else { + 
hospitalJobTimer.schedule(timerTask { flowFiber.scheduleEvent(event) - } else { - hospitalJobTimer.schedule(timerTask { - flowFiber.scheduleEvent(event) - }, backOffForChronicCondition.toMillis()) - } + }, backOffForChronicCondition.toMillis()) } } @@ -319,7 +318,7 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging, } fun timesDischargedForTheSameThing(by: Staff, currentState: StateMachineState): Int { - val lastAdmittanceSuspendCount = currentState.checkpoint.numberOfSuspends + val lastAdmittanceSuspendCount = currentState.checkpoint.checkpointState.numberOfSuspends return records.count { it.outcome == Outcome.DISCHARGE && by in it.by && it.suspendCount == lastAdmittanceSuspendCount } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 7fa4c22e9b..6079fbccf1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -30,12 +30,18 @@ import java.time.Duration * TODO: Don't store all active flows in memory, load from the database on demand. */ interface StateMachineManager { + + enum class StartMode { + ExcludingPaused, // Resume all flows except paused flows. + Safe // Mark all flows as paused. + } + /** * Starts the state machine manager, loading and starting the state machines in storage. * * @return `Future` which completes when SMM is fully started */ - fun start(tokenizableServices: List<Any>) : CordaFuture<Unit> + fun start(tokenizableServices: List<Any>, startMode: StartMode = StartMode.ExcludingPaused) : CordaFuture<Unit> /** * Stops the state machine manager gracefully, waiting until all but [allowedUnsuspendedFiberCount] flows reach the @@ -80,6 +86,13 @@ interface StateMachineManager { */ fun killFlow(id: StateMachineRunId): Boolean + /** + * Start a paused flow. 
+ * + * @return whether the flow was successfully started. + */ + fun unPauseFlow(id: StateMachineRunId): Boolean + /** * Deliver an external event to the state machine. Such an event might be a new P2P message, or a request to start a flow. * The event may be replayed if a flow fails and attempts to retry. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index 96b4ad04c1..58a072fc99 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -7,7 +7,12 @@ import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.utilities.Try import net.corda.node.services.messaging.DeduplicationHandler import java.time.Instant @@ -53,25 +58,36 @@ data class StateMachineState( ) /** - * @param invocationContext the initiator of the flow. - * @param ourIdentity the identity the flow is run as. - * @param sessions map of source session ID to session state. - * @param subFlowStack the stack of currently executing subflows. + * @param checkpointState the state of the checkpoint * @param flowState the state of the flow itself, including the frozen fiber/FlowLogic. * @param errorState the "dirtiness" state including the involved errors and their propagation status. - * @param numberOfSuspends the number of flow suspends due to IO API calls. 
*/ data class Checkpoint( - val invocationContext: InvocationContext, - val ourIdentity: Party, - val sessions: SessionMap, // This must preserve the insertion order! - val subFlowStack: List<SubFlow>, + val checkpointState: CheckpointState, val flowState: FlowState, val errorState: ErrorState, - val numberOfSuspends: Int + val result: Any? = null, + val status: FlowStatus = FlowStatus.RUNNABLE, + val progressStep: String? = null, + val flowIoRequest: String? = null, + val compatible: Boolean = true ) { + @CordaSerializable + enum class FlowStatus { + RUNNABLE, + FAILED, + COMPLETED, + HOSPITALIZED, + KILLED, + PAUSED + } - val timestamp: Instant = Instant.now() // This will get updated every time a Checkpoint object is created/ created by copy. + /** + * [timestamp] will get updated every time a [Checkpoint] object is created/ created by copy. + * It will be updated, therefore, for example when a flow is being suspended or whenever a flow + * is being loaded from [Checkpoint] through [Serialized.deserialize]. + */ + val timestamp: Instant = Instant.now() companion object { @@ -86,19 +102,109 @@ data class Checkpoint( ): Try<Checkpoint> { return SubFlow.create(flowLogicClass, subFlowVersion, isEnabledTimedFlow).map { topLevelSubFlow -> Checkpoint( - invocationContext = invocationContext, - ourIdentity = ourIdentity, - sessions = emptyMap(), - subFlowStack = listOf(topLevelSubFlow), - flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), - errorState = ErrorState.Clean, + checkpointState = CheckpointState( + invocationContext, + ourIdentity, + emptyMap(), + listOf(topLevelSubFlow), numberOfSuspends = 0 + ), + errorState = ErrorState.Clean, + flowState = FlowState.Unstarted(flowStart, frozenFlowLogic) ) } } } + + /** + * Returns a copy of the Checkpoint with a new session map. + * @param sessions the new map of session ID to session state. 
+ */ + fun setSessions(sessions: SessionMap) : Checkpoint { + return copy(checkpointState = checkpointState.copy(sessions = sessions)) + } + + /** + * Returns a copy of the Checkpoint with an extra session added to the session map. + * @param session the extra session to add. + */ + fun addSession(session: Pair<SessionId, SessionState>) : Checkpoint { + return copy(checkpointState = checkpointState.copy(sessions = checkpointState.sessions + session)) + } + + /** + * Returns a copy of the Checkpoint with a new subFlow stack. + * @param subFlows the new List of subFlows. + */ + fun setSubflows(subFlows: List<SubFlow>) : Checkpoint { + return copy(checkpointState = checkpointState.copy(subFlowStack = subFlows)) + } + + /** + * Returns a copy of the Checkpoint with an extra subflow added to the subFlow Stack. + * @param subFlow the subFlow to add to the stack of subFlows + */ + fun addSubflow(subFlow: SubFlow) : Checkpoint { + return copy(checkpointState = checkpointState.copy(subFlowStack = checkpointState.subFlowStack + subFlow)) + } + + /** + * A partially serialized form of [Checkpoint]. + * + * [Checkpoint.Serialized] contains the same fields as [Checkpoint] except that some of its fields are still serialized. The checkpoint + * can then be deserialized as needed. + */ + data class Serialized( + val serializedCheckpointState: SerializedBytes<CheckpointState>, + val serializedFlowState: SerializedBytes<FlowState>?, + val errorState: ErrorState, + val result: SerializedBytes<Any>?, + val status: FlowStatus, + val progressStep: String?, + val flowIoRequest: String?, + val compatible: Boolean + ) { + /** + * Deserializes the serialized fields contained in [Checkpoint.Serialized]. 
+ * + * @return A [Checkpoint] with all its fields filled in from [Checkpoint.Serialized] + */ + fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint { + val flowState = when(status) { + FlowStatus.PAUSED -> FlowState.Paused + FlowStatus.COMPLETED -> FlowState.Completed + else -> serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext) + } + return Checkpoint( + checkpointState = serializedCheckpointState.checkpointDeserialize(checkpointSerializationContext), + flowState = flowState, + errorState = errorState, + result = result?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT), + status = status, + progressStep = progressStep, + flowIoRequest = flowIoRequest, + compatible = compatible + ) + } + } } +/** + * @param invocationContext the initiator of the flow. + * @param ourIdentity the identity the flow is run as. + * @param sessions map of source session ID to session state. + * @param subFlowStack the stack of currently executing subflows. + * @param numberOfSuspends the number of flow suspends due to IO API calls. + */ +@CordaSerializable +data class CheckpointState( + val invocationContext: InvocationContext, + val ourIdentity: Party, + val sessions: SessionMap, // This must preserve the insertion order! + val subFlowStack: List<SubFlow>, + val numberOfSuspends: Int +) + /** * The state of a session. */ @@ -205,6 +311,17 @@ sealed class FlowState { ) : FlowState() { override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})" } + + /** + * The flow is paused. To save memory we don't store the FlowState + */ + object Paused: FlowState() + + /** + * The flow has completed. It does not have a running fiber that needs to be serialized and checkpointed. + */ + object Completed : FlowState() + } /** @@ -213,17 +330,20 @@ sealed class FlowState { * @param exception the exception itself. 
Note that this may not contain information about the source error depending * on whether the source error was a FlowException or otherwise. */ +@CordaSerializable data class FlowError(val errorId: Long, val exception: Throwable) /** * The flow's error state. */ +@CordaSerializable sealed class ErrorState { abstract fun addErrors(newErrors: List<FlowError>): ErrorState /** * The flow is in a clean state. */ + @CordaSerializable object Clean : ErrorState() { override fun addErrors(newErrors: List<FlowError>): ErrorState { return Errored(newErrors, 0, false) @@ -240,6 +360,7 @@ sealed class ErrorState { * @param propagating true if error propagation was triggered. If this is set the dirtiness is permanent as the * sessions associated with the flow have been (or about to be) dirtied in counter-flows. */ + @CordaSerializable data class Errored( val errors: List<FlowError>, val propagatedIndex: Int, @@ -258,4 +379,4 @@ sealed class SubFlowVersion { abstract val platformVersion: Int data class CoreFlow(override val platformVersion: Int) : SubFlowVersion() data class CorDappFlow(override val platformVersion: Int, val corDappName: String, val corDappHash: SecureHash) : SubFlowVersion() -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt index f225eca1c4..0aa58241eb 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DeliverSessionMessageTransition.kt @@ -47,7 +47,7 @@ class DeliverSessionMessageTransition( pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers + event.deduplicationHandler ) // Check whether we have a session corresponding to the message. 
- val existingSession = startingState.checkpoint.sessions[event.sessionMessage.recipientSessionId] + val existingSession = startingState.checkpoint.checkpointState.sessions[event.sessionMessage.recipientSessionId] if (existingSession == null) { freshErrorTransition(CannotFindSessionException(event.sessionMessage.recipientSessionId)) } else { @@ -80,8 +80,8 @@ class DeliverSessionMessageTransition( errors = emptyList(), deduplicationSeed = sessionState.deduplicationSeed ) - val newCheckpoint = currentState.checkpoint.copy( - sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession) + val newCheckpoint = currentState.checkpoint.addSession( + event.sessionMessage.recipientSessionId to initiatedSession ) // Send messages that were buffered pending confirmation of session. val sendActions = sessionState.bufferedMessages.map { (deduplicationId, bufferedMessage) -> @@ -103,9 +103,10 @@ class DeliverSessionMessageTransition( val newSessionState = sessionState.copy( receivedMessages = sessionState.receivedMessages + message ) + currentState = currentState.copy( - checkpoint = currentState.checkpoint.copy( - sessions = startingState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to newSessionState) + checkpoint = currentState.checkpoint.addSession( + event.sessionMessage.recipientSessionId to newSessionState ) ) } @@ -137,9 +138,7 @@ class DeliverSessionMessageTransition( val flowError = FlowError(payload.errorId, exception) val newSessionState = sessionState.copy(errors = sessionState.errors + flowError) currentState = currentState.copy( - checkpoint = checkpoint.copy( - sessions = checkpoint.sessions + (sessionId to newSessionState) - ) + checkpoint = checkpoint.addSession(sessionId to newSessionState) ) } else -> freshErrorTransition(UnexpectedEventInState()) @@ -158,9 +157,7 @@ class DeliverSessionMessageTransition( val sessionId = event.sessionMessage.recipientSessionId val flowError = 
FlowError(payload.errorId, exception) currentState = currentState.copy( - checkpoint = checkpoint.copy( - sessions = checkpoint.sessions + (sessionId to sessionState.copy(rejectionError = flowError)) - ) + checkpoint = checkpoint.addSession(sessionId to sessionState.copy(rejectionError = flowError)) ) } } @@ -170,7 +167,7 @@ class DeliverSessionMessageTransition( private fun TransitionBuilder.endMessageTransition() { val sessionId = event.sessionMessage.recipientSessionId - val sessions = currentState.checkpoint.sessions + val sessions = currentState.checkpoint.checkpointState.sessions val sessionState = sessions[sessionId] if (sessionState == null) { return freshErrorTransition(CannotFindSessionException(sessionId)) @@ -179,9 +176,8 @@ class DeliverSessionMessageTransition( is SessionState.Initiated -> { val newSessionState = sessionState.copy(initiatedState = InitiatedSessionState.Ended) currentState = currentState.copy( - checkpoint = currentState.checkpoint.copy( - sessions = sessions + (sessionId to newSessionState) - ) + checkpoint = currentState.checkpoint.addSession(sessionId to newSessionState) + ) } else -> { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt index 53214c71fb..21b06c6e40 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt @@ -24,10 +24,12 @@ class DoRemainingWorkTransition( // If the flow is clean check the FlowState private fun cleanTransition(): TransitionResult { - val checkpoint = startingState.checkpoint - return when (checkpoint.flowState) { - is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, checkpoint.flowState).transition() - is FlowState.Started -> StartedFlowTransition(context, 
startingState, checkpoint.flowState).transition() + val flowState = startingState.checkpoint.flowState + return when (flowState) { + is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, flowState).transition() + is FlowState.Started -> StartedFlowTransition(context, startingState, flowState).transition() + is FlowState.Completed -> throw IllegalStateException("Cannot transition a state with completed flow state.") + is FlowState.Paused -> throw IllegalStateException("Cannot transition a state with paused flow state.") } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index 89b1b00a29..551807fcdf 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -40,10 +40,13 @@ class ErrorFlowTransition( return builder { // If we're errored and propagating do the actual propagation and update the index. if (remainingErrorsToPropagate.isNotEmpty() && errorState.propagating) { - val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages) + val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions( + startingState.checkpoint.checkpointState.sessions, + errorMessages + ) val newCheckpoint = startingState.checkpoint.copy( errorState = errorState.copy(propagatedIndex = allErrors.size), - sessions = newSessions + checkpointState = startingState.checkpoint.checkpointState.copy(sessions = newSessions) ) currentState = currentState.copy(checkpoint = newCheckpoint) actions.add(Action.PropagateErrors(errorMessages, initiatedSessions, startingState.senderUUID)) @@ -56,19 +59,20 @@ class ErrorFlowTransition( // If we haven't been removed yet remove the flow. 
if (!currentState.isRemoved) { - actions.add(Action.CreateTransaction) - if (currentState.isAnyCheckpointPersisted) { - actions.add(Action.RemoveCheckpoint(context.id)) - } + val newCheckpoint = startingState.checkpoint.copy(status = Checkpoint.FlowStatus.FAILED) + actions.addAll(arrayOf( + Action.CreateTransaction, + Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = currentState.isAnyCheckpointPersisted), Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.ReleaseSoftLocks(context.id.uuid), Action.CommitTransaction, Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), - Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) + Action.RemoveSessionBindings(currentState.checkpoint.checkpointState.sessions.keys) )) currentState = currentState.copy( + checkpoint = newCheckpoint, pendingDeduplicationHandlers = emptyList(), isRemoved = true ) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt index de25acac71..5c7b095e80 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt @@ -25,10 +25,11 @@ class KilledFlowTransition( val killedFlowErrorMessage = createErrorMessageFromError(killedFlowError) val errorMessages = listOf(killedFlowErrorMessage) - val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions(startingState.checkpoint.sessions, errorMessages) - val newCheckpoint = startingState.checkpoint.copy( - sessions = newSessions + val (initiatedSessions, newSessions) = bufferErrorMessagesInInitiatingSessions( + startingState.checkpoint.checkpointState.sessions, + errorMessages ) + val newCheckpoint = startingState.checkpoint.setSessions(sessions = newSessions) 
currentState = currentState.copy(checkpoint = newCheckpoint) actions.add( Action.PropagateErrors( @@ -42,7 +43,7 @@ class KilledFlowTransition( actions.add(Action.CreateTransaction) } // The checkpoint and soft locks are also removed directly in [StateMachineManager.killFlow] - if(startingState.isAnyCheckpointPersisted) { + if (startingState.isAnyCheckpointPersisted) { actions.add(Action.RemoveCheckpoint(context.id)) } actions.addAll( @@ -51,7 +52,7 @@ class KilledFlowTransition( Action.ReleaseSoftLocks(context.id.uuid), Action.CommitTransaction, Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), - Action.RemoveSessionBindings(currentState.checkpoint.sessions.keys) + Action.RemoveSessionBindings(currentState.checkpoint.checkpointState.sessions.keys) ) ) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt index ab874763fc..904ab3f06a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt @@ -47,7 +47,7 @@ class StartedFlowTransition( private fun waitForSessionConfirmationsTransition(): TransitionResult { return builder { - if (currentState.checkpoint.sessions.values.any { it is SessionState.Initiating }) { + if (currentState.checkpoint.checkpointState.sessions.values.any { it is SessionState.Initiating }) { FlowContinuation.ProcessEvents } else { resumeFlowLogic(Unit) @@ -77,7 +77,7 @@ class StartedFlowTransition( val checkpoint = currentState.checkpoint val resultMap = LinkedHashMap<FlowSession, FlowInfo>() for ((sessionId, session) in sessionIdToSession) { - val sessionState = checkpoint.sessions[sessionId] + val sessionState = checkpoint.checkpointState.sessions[sessionId] if (sessionState is SessionState.Initiated) { resultMap[session] = 
sessionState.peerFlowInfo } else { @@ -169,14 +169,14 @@ class StartedFlowTransition( sourceSessionIdToSessionMap: Map<SessionId, FlowSessionImpl> ): Map<FlowSession, SerializedBytes<Any>>? { val checkpoint = currentState.checkpoint - val pollResult = pollSessionMessages(checkpoint.sessions, sourceSessionIdToSessionMap.keys) ?: return null + val pollResult = pollSessionMessages(checkpoint.checkpointState.sessions, sourceSessionIdToSessionMap.keys) ?: return null val resultMap = LinkedHashMap<FlowSession, SerializedBytes<Any>>() for ((sessionId, message) in pollResult.messages) { val session = sourceSessionIdToSessionMap[sessionId]!! resultMap[session] = message } currentState = currentState.copy( - checkpoint = checkpoint.copy(sessions = pollResult.newSessionMap) + checkpoint = checkpoint.setSessions(sessions = pollResult.newSessionMap) ) return resultMap } @@ -215,10 +215,10 @@ class StartedFlowTransition( private fun TransitionBuilder.sendInitialSessionMessagesIfNeeded(sourceSessions: Set<SessionId>) { val checkpoint = startingState.checkpoint - val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.sessions) + val newSessions = LinkedHashMap<SessionId, SessionState>(checkpoint.checkpointState.sessions) var index = 0 for (sourceSessionId in sourceSessions) { - val sessionState = checkpoint.sessions[sourceSessionId] + val sessionState = checkpoint.checkpointState.sessions[sourceSessionId] if (sessionState == null) { return freshErrorTransition(CannotFindSessionException(sourceSessionId)) } @@ -235,7 +235,7 @@ class StartedFlowTransition( actions.add(Action.SendInitial(sessionState.destination, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) newSessions[sourceSessionId] = newSessionState } - currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + currentState = currentState.copy(checkpoint = checkpoint.setSessions(sessions = newSessions)) } private fun sendTransition(flowIORequest: 
FlowIORequest.Send): TransitionResult { @@ -254,17 +254,17 @@ class StartedFlowTransition( private fun TransitionBuilder.sendToSessionsTransition(sourceSessionIdToMessage: Map<SessionId, SerializedBytes<Any>>) { val checkpoint = startingState.checkpoint - val newSessions = LinkedHashMap(checkpoint.sessions) + val newSessions = LinkedHashMap(checkpoint.checkpointState.sessions) var index = 0 for ((sourceSessionId, _) in sourceSessionIdToMessage) { - val existingSessionState = checkpoint.sessions[sourceSessionId] ?: return freshErrorTransition(CannotFindSessionException(sourceSessionId)) + val existingSessionState = checkpoint.checkpointState.sessions[sourceSessionId] ?: return freshErrorTransition(CannotFindSessionException(sourceSessionId)) if (existingSessionState is SessionState.Initiated && existingSessionState.initiatedState is InitiatedSessionState.Ended) { return freshErrorTransition(IllegalStateException("Tried to send to ended session $sourceSessionId")) } } val messagesByType = sourceSessionIdToMessage.toList() - .map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.sessions[sourceSessionId]!!, message) } + .map { (sourceSessionId, message) -> Triple(sourceSessionId, checkpoint.checkpointState.sessions[sourceSessionId]!!, message) } .groupBy { it.second::class } val sendInitialActions = messagesByType[SessionState.Uninitiated::class]?.map { (sourceSessionId, sessionState, message) -> @@ -301,7 +301,7 @@ class StartedFlowTransition( if (sendInitialActions.isNotEmpty() || sendExistingActions.isNotEmpty()) { actions.add(Action.SendMultiple(sendInitialActions, sendExistingActions)) } - currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions)) } private fun sessionToSessionId(session: FlowSession): SessionId { @@ -310,7 +310,7 @@ class StartedFlowTransition( private fun collectErroredSessionErrors(sessionIds: Collection<SessionId>, 
checkpoint: Checkpoint): List<Throwable> { return sessionIds.flatMap { sessionId -> - val sessionState = checkpoint.sessions[sessionId]!! + val sessionState = checkpoint.checkpointState.sessions[sessionId]!! when (sessionState) { is SessionState.Uninitiated -> emptyList() is SessionState.Initiating -> { @@ -326,14 +326,14 @@ class StartedFlowTransition( } private fun collectErroredInitiatingSessionErrors(checkpoint: Checkpoint): List<Throwable> { - return checkpoint.sessions.values.mapNotNull { sessionState -> + return checkpoint.checkpointState.sessions.values.mapNotNull { sessionState -> (sessionState as? SessionState.Initiating)?.rejectionError?.exception } } private fun collectEndedSessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> { return sessionIds.mapNotNull { sessionId -> - val sessionState = checkpoint.sessions[sessionId]!! + val sessionState = checkpoint.checkpointState.sessions[sessionId]!! when (sessionState) { is SessionState.Initiated -> { if (sessionState.initiatedState === InitiatedSessionState.Ended) { @@ -353,7 +353,7 @@ class StartedFlowTransition( private fun collectEndedEmptySessionErrors(sessionIds: Collection<SessionId>, checkpoint: Checkpoint): List<Throwable> { return sessionIds.mapNotNull { sessionId -> - val sessionState = checkpoint.sessions[sessionId]!! + val sessionState = checkpoint.checkpointState.sessions[sessionId]!! 
when (sessionState) { is SessionState.Initiated -> { if (sessionState.initiatedState === InitiatedSessionState.Ended && @@ -387,7 +387,7 @@ class StartedFlowTransition( collectErroredSessionErrors(sessionIds, checkpoint) + collectEndedSessionErrors(sessionIds, checkpoint) } is FlowIORequest.WaitForLedgerCommit -> { - collectErroredSessionErrors(checkpoint.sessions.keys, checkpoint) + collectErroredSessionErrors(checkpoint.checkpointState.sessions.keys, checkpoint) } is FlowIORequest.GetFlowInfo -> { collectErroredSessionErrors(flowIORequest.sessions.map(this::sessionToSessionId), checkpoint) @@ -431,7 +431,7 @@ class StartedFlowTransition( builder { // The `numberOfSuspends` is added to the deduplication ID in case an async // operation is executed multiple times within the same flow. - val deduplicationId = context.id.toString() + ":" + currentState.checkpoint.numberOfSuspends.toString() + val deduplicationId = context.id.toString() + ":" + currentState.checkpoint.checkpointState.numberOfSuspends.toString() actions.add(Action.ExecuteAsyncOperation(deduplicationId, flowIORequest.operation)) currentState = currentState.copy(isWaitingForFuture = true) FlowContinuation.ProcessEvents diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index 58ce438e6a..037d408928 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -55,6 +55,7 @@ class TopLevelTransition( is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event) is Event.AsyncOperationThrows -> asyncOperationThrowsTransition(event) is Event.RetryFlowFromSafePoint -> retryFlowFromSafePointTransition(startingState) + is Event.OvernightObservation -> overnightObservationTransition() is Event.WakeUpFromSleep 
-> wakeUpFromSleepTransition() } } @@ -98,7 +99,7 @@ class TopLevelTransition( return TransitionResult( newState = lastState, actions = listOf( - Action.RemoveSessionBindings(startingState.checkpoint.sessions.keys), + Action.RemoveSessionBindings(startingState.checkpoint.checkpointState.sessions.keys), Action.RemoveFlow(context.id, FlowRemovalReason.SoftShutdown, lastState) ), continuation = FlowContinuation.Abort @@ -128,11 +129,9 @@ class TopLevelTransition( val subFlow = SubFlow.create(event.subFlowClass, event.subFlowVersion, event.isEnabledTimedFlow) when (subFlow) { is Try.Success -> { - val containsTimedSubflow = containsTimedFlows(currentState.checkpoint.subFlowStack) + val containsTimedSubflow = containsTimedFlows(currentState.checkpoint.checkpointState.subFlowStack) currentState = currentState.copy( - checkpoint = currentState.checkpoint.copy( - subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value - ) + checkpoint = currentState.checkpoint.addSubflow(subFlow.value) ) // We don't schedule a timeout if there already is a timed subflow on the stack - a timeout had // been scheduled already. 
@@ -151,17 +150,15 @@ class TopLevelTransition( private fun leaveSubFlowTransition(): TransitionResult { return builder { val checkpoint = currentState.checkpoint - if (checkpoint.subFlowStack.isEmpty()) { + if (checkpoint.checkpointState.subFlowStack.isEmpty()) { freshErrorTransition(UnexpectedEventInState()) } else { - val isLastSubFlowTimed = checkpoint.subFlowStack.last().isEnabledTimedFlow - val newSubFlowStack = checkpoint.subFlowStack.dropLast(1) + val isLastSubFlowTimed = checkpoint.checkpointState.subFlowStack.last().isEnabledTimedFlow + val newSubFlowStack = checkpoint.checkpointState.subFlowStack.dropLast(1) currentState = currentState.copy( - checkpoint = checkpoint.copy( - subFlowStack = newSubFlowStack - ) + checkpoint = checkpoint.setSubflows(newSubFlowStack) ) - if (isLastSubFlowTimed && !containsTimedFlows(currentState.checkpoint.subFlowStack)) { + if (isLastSubFlowTimed && !containsTimedFlows(currentState.checkpoint.checkpointState.subFlowStack)) { actions.add(Action.CancelFlowTimeout(currentState.flowLogic.runId)) } } @@ -175,10 +172,22 @@ class TopLevelTransition( private fun suspendTransition(event: Event.Suspend): TransitionResult { return builder { - val newCheckpoint = currentState.checkpoint.copy( + val newCheckpoint = currentState.checkpoint.run { + val newCheckpointState = if (checkpointState.invocationContext.arguments.isNotEmpty()) { + checkpointState.copy( + invocationContext = checkpointState.invocationContext.copy(arguments = emptyList()), + numberOfSuspends = checkpointState.numberOfSuspends + 1 + ) + } else { + checkpointState.copy(numberOfSuspends = checkpointState.numberOfSuspends + 1) + } + copy( flowState = FlowState.Started(event.ioRequest, event.fiber), - numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1 - ) + checkpointState = newCheckpointState, + flowIoRequest = event.ioRequest::class.java.simpleName, + progressStep = event.progressStep?.label + ) + } if (event.maySkipCheckpoint) { actions.addAll(arrayOf( 
Action.CommitTransaction, @@ -215,18 +224,23 @@ class TopLevelTransition( val pendingDeduplicationHandlers = currentState.pendingDeduplicationHandlers currentState = currentState.copy( checkpoint = checkpoint.copy( - numberOfSuspends = checkpoint.numberOfSuspends + 1 + checkpointState = checkpoint.checkpointState.copy( + numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1 + ), + flowState = FlowState.Completed, + result = event.returnValue, + status = Checkpoint.FlowStatus.COMPLETED ), pendingDeduplicationHandlers = emptyList(), isFlowResumed = false, isRemoved = true ) - val allSourceSessionIds = checkpoint.sessions.keys + val allSourceSessionIds = checkpoint.checkpointState.sessions.keys if (currentState.isAnyCheckpointPersisted) { actions.add(Action.RemoveCheckpoint(context.id)) } actions.addAll(arrayOf( - Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), + Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), Action.ReleaseSoftLocks(event.softLocksId), Action.CommitTransaction, Action.AcknowledgeMessages(pendingDeduplicationHandlers), @@ -247,7 +261,7 @@ class TopLevelTransition( } private fun TransitionBuilder.sendEndMessages() { - val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state -> + val sendEndMessageActions = currentState.checkpoint.checkpointState.sessions.values.mapIndexed { index, state -> if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state) @@ -269,15 +283,15 @@ class TopLevelTransition( } val sourceSessionId = SessionId.createRandom(context.secureRandom) val sessionImpl = FlowSessionImpl(event.destination, event.wellKnownParty, sourceSessionId) - val newSessions = checkpoint.sessions + (sourceSessionId to 
SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong())) - currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) + val newSessions = checkpoint.checkpointState.sessions + (sourceSessionId to SessionState.Uninitiated(event.destination, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong())) + currentState = currentState.copy(checkpoint = checkpoint.setSessions(newSessions)) actions.add(Action.AddSessionBinding(context.id, sourceSessionId)) FlowContinuation.Resume(sessionImpl) } } private fun getClosestAncestorInitiatingSubFlow(checkpoint: Checkpoint): SubFlow.Initiating? { - for (subFlow in checkpoint.subFlowStack.asReversed()) { + for (subFlow in checkpoint.checkpointState.subFlowStack.asReversed()) { if (subFlow is SubFlow.Initiating) { return subFlow } @@ -307,9 +321,20 @@ class TopLevelTransition( } } + private fun overnightObservationTransition(): TransitionResult { + return builder { + val newCheckpoint = startingState.checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED) + actions.add(Action.CreateTransaction) + actions.add(Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = currentState.isAnyCheckpointPersisted)) + actions.add(Action.CommitTransaction) + currentState = currentState.copy(checkpoint = newCheckpoint) + FlowContinuation.ProcessEvents + } + } + private fun wakeUpFromSleepTransition(): TransitionResult { return builder { resumeFlowLogic(Unit) } } -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt index ac3c232377..c85830fb03 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt +++ 
b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/UnstartedFlowTransition.kt @@ -61,9 +61,7 @@ class UnstartedFlowTransition( val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) currentState = currentState.copy( - checkpoint = currentState.checkpoint.copy( - sessions = mapOf(flowStart.initiatedSessionId to initiatedState) - ) + checkpoint = currentState.checkpoint.setSessions(mapOf(flowStart.initiatedSessionId to initiatedState)) ) actions.add( Action.SendExisting( diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 438e2c8dec..6dfd3b6e04 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -42,11 +42,6 @@ import javax.persistence.criteria.CriteriaUpdate import javax.persistence.criteria.Predicate import javax.persistence.criteria.Root -private fun CriteriaBuilder.executeUpdate(session: Session, configure: Root<*>.(CriteriaUpdate<*>) -> Any?) = createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java).let { update -> - update.from(VaultSchemaV1.VaultStates::class.java).run { configure(update) } - session.createQuery(update).executeUpdate() -} - /** * The vault service handles storage, retrieval and querying of states. * @@ -67,6 +62,8 @@ class NodeVaultService( companion object { private val log = contextLogger() + val MAX_SQL_IN_CLAUSE_SET = 16 + /** * Establish whether a given state is relevant to a node, given the node's public keys. 
* @@ -462,13 +459,20 @@ class NodeVaultService( } } + /** + * Whenever executed inside a [FlowStateMachineImpl], if [lockId] refers to the currently running [FlowStateMachineImpl], + * then in that case the [FlowStateMachineImpl] instance is locking states with its [FlowStateMachineImpl.id]'s [UUID]. + * In this case alone, we keep the reserved set of [StateRef] in [FlowStateMachineImpl.softLockedStates]. This set will be then + * used by default in [softLockRelease]. + */ + @Suppress("NestedBlockDepth", "ComplexMethod") @Throws(StatesNotAvailableException::class) override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) { val softLockTimestamp = clock.instant() try { val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder - fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>) -> Any?) = criteriaBuilder.executeUpdate(session) { update -> + fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>) -> Any?) = criteriaBuilder.executeUpdate(session, null) { update, _ -> val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } val compositeKey = get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name) val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) @@ -485,7 +489,11 @@ class NodeVaultService( } if (updatedRows > 0 && updatedRows == stateRefs.size) { log.trace { "Reserving soft lock states for $lockId: $stateRefs" } - FlowStateMachineImpl.currentStateMachine()?.hasSoftLockedStates = true + FlowStateMachineImpl.currentStateMachine()?.let { + if (lockId == it.id.uuid) { + it.softLockedStates.addAll(stateRefs) + } + } } else { // revert partial soft locks val revertUpdatedRows = execute { update, commonPredicates -> @@ -508,19 +516,44 @@ class NodeVaultService( } } + /** + * Whenever executed inside a [FlowStateMachineImpl], if [lockId] refers to the currently running [FlowStateMachineImpl] and [stateRefs] is null, 
+ * then in that case the [FlowStateMachineImpl] instance will, by default, retrieve its set of [StateRef] + * from [FlowStateMachineImpl.softLockedStates] (previously reserved from [softLockReserve]). This set will then be explicitly provided + * to the below query which then leads the database query optimizer to use the primary key index in VAULT_STATES table, instead of lock_id_idx + * in order to search rows to be updated. That way the query will be aligned with the rest of the queries that are following that route as well + * (i.e. making use of the primary key), and therefore its locking order of resources within the database will be aligned + * with the rest of the queries' locking orders (solving SQL deadlocks). + * + * If [lockId] does not refer to the currently running [FlowStateMachineImpl] and [stateRefs] is null, then it will be using only [lockId] in + * the below query. + */ + @Suppress("NestedBlockDepth", "ComplexMethod") override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) { val softLockTimestamp = clock.instant() val session = currentDBSession() val criteriaBuilder = session.criteriaBuilder - fun execute(configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>) -> Any?) = criteriaBuilder.executeUpdate(session) { update -> + fun execute(stateRefs: NonEmptySet<StateRef>?, configure: Root<*>.(CriteriaUpdate<*>, Array<Predicate>, List<PersistentStateRef>?) -> Any?) 
= + criteriaBuilder.executeUpdate(session, stateRefs) { update, persistentStateRefs -> val stateStatusPredication = criteriaBuilder.equal(get<Vault.StateStatus>(VaultSchemaV1.VaultStates::stateStatus.name), Vault.StateStatus.UNCONSUMED) val lockIdPredicate = criteriaBuilder.equal(get<String>(VaultSchemaV1.VaultStates::lockId.name), lockId.toString()) update.set<String>(get<String>(VaultSchemaV1.VaultStates::lockId.name), criteriaBuilder.nullLiteral(String::class.java)) update.set(get<Instant>(VaultSchemaV1.VaultStates::lockUpdateTime.name), softLockTimestamp) - configure(update, arrayOf(stateStatusPredication, lockIdPredicate)) + configure(update, arrayOf(stateStatusPredication, lockIdPredicate), persistentStateRefs) } - if (stateRefs == null) { - val update = execute { update, commonPredicates -> + + val stateRefsToBeReleased = + stateRefs ?: FlowStateMachineImpl.currentStateMachine()?.let { + // We only hold states under our flowId. For all other lockId fall back to old query mechanism, i.e. 
stateRefsToBeReleased = null + if (lockId == it.id.uuid && it.softLockedStates.isNotEmpty()) { + NonEmptySet.copyOf(it.softLockedStates) + } else { + null + } + } + + if (stateRefsToBeReleased == null) { + val update = execute(null) { update, commonPredicates, _ -> update.where(*commonPredicates) } if (update > 0) { @@ -528,19 +561,21 @@ class NodeVaultService( } } else { try { - val updatedRows = execute { update, commonPredicates -> - val persistentStateRefs = stateRefs.map { PersistentStateRef(it.txhash.bytes.toHexString(), it.index) } + val updatedRows = execute(stateRefsToBeReleased) { update, commonPredicates, persistentStateRefs -> val compositeKey = get<PersistentStateRef>(VaultSchemaV1.VaultStates::stateRef.name) val stateRefsPredicate = criteriaBuilder.and(compositeKey.`in`(persistentStateRefs)) update.where(*commonPredicates, stateRefsPredicate) } if (updatedRows > 0) { - log.trace { "Releasing $updatedRows soft locked states for $lockId and stateRefs $stateRefs" } + FlowStateMachineImpl.currentStateMachine()?.let { + if (lockId == it.id.uuid) { + it.softLockedStates.removeAll(stateRefsToBeReleased) + } + } + log.trace { "Releasing $updatedRows soft locked states for $lockId and stateRefs $stateRefsToBeReleased" } } } catch (e: Exception) { - log.error("""soft lock update error attempting to release states for $lockId and $stateRefs") - $e. - """) + log.error("Soft lock update error attempting to release states for $lockId and $stateRefsToBeReleased", e) throw e } } @@ -819,5 +854,29 @@ class NodeVaultService( } } +private fun CriteriaBuilder.executeUpdate( + session: Session, + stateRefs: NonEmptySet<StateRef>?, + configure: Root<*>.(CriteriaUpdate<*>, List<PersistentStateRef>?) -> Any? 
+): Int { + fun doUpdate(persistentStateRefs: List<PersistentStateRef>?): Int { + createCriteriaUpdate(VaultSchemaV1.VaultStates::class.java).let { update -> + update.from(VaultSchemaV1.VaultStates::class.java).run { configure(update, persistentStateRefs) } + return session.createQuery(update).executeUpdate() + } + } + return stateRefs?.let { + // Increase SQL server performance by processing updates in chunks, allowing the database's optimizer to make use of the index. + var updatedRows = 0 + it.asSequence() + .map { stateRef -> PersistentStateRef(stateRef.txhash.bytes.toHexString(), stateRef.index) } + .chunked(NodeVaultService.MAX_SQL_IN_CLAUSE_SET) + .forEach { persistentStateRefs -> + updatedRows += doUpdate(persistentStateRefs) + } + updatedRows + } ?: doUpdate(null) +} + /** The Observable returned allows subscribing with custom SafeSubscribers to source [Observable]. */ internal fun<T> Observable<T>.resilientOnError(): Observable<T> = Observable.unsafeCreate(OnResilientSubscribe(this, false)) \ No newline at end of file diff --git a/node/src/main/resources/migration/node-core.changelog-master.xml b/node/src/main/resources/migration/node-core.changelog-master.xml index 28842e0825..9e96e93d01 100644 --- a/node/src/main/resources/migration/node-core.changelog-master.xml +++ b/node/src/main/resources/migration/node-core.changelog-master.xml @@ -31,4 +31,8 @@ <include file="migration/node-core.changelog-v14-data.xml"/> + <include file="migration/node-core.changelog-v19.xml"/> + <include file="migration/node-core.changelog-v19-postgres.xml"/> + <include file="migration/node-core.changelog-v19-keys.xml"/> + </databaseChangeLog> diff --git a/node/src/main/resources/migration/node-core.changelog-v19-keys.xml b/node/src/main/resources/migration/node-core.changelog-v19-keys.xml new file mode 100644 index 0000000000..26359ecd2f --- /dev/null +++ b/node/src/main/resources/migration/node-core.changelog-v19-keys.xml @@ -0,0 +1,26 @@ +<?xml version="1.1" encoding="UTF-8" 
standalone="no"?> +<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd" + logicalFilePath="migration/node-services.changelog-init.xml"> + + <changeSet author="R3.Corda" id="add_new_checkpoint_schema_primary_keys" dbms="mssql,azure"> + <addPrimaryKey columnNames="flow_id" constraintName="node_checkpoints_pk" tableName="node_checkpoints" clustered="false"/> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_checkpoint_schema_primary_keys-h2_postgres_oracle" dbms="h2,postgresql,oracle"> + <addPrimaryKey columnNames="flow_id" constraintName="node_checkpoints_pk" tableName="node_checkpoints"/> + </changeSet> + + <!-- TODO: add indexes for the rest of the tables as well (Results + Exceptions) --> + <!-- TODO: the following only add indexes so maybe also align name of file? --> + <changeSet author="R3.Corda" id="add_new_checkpoint_schema_indexes"> + <createIndex indexName="node_checkpoint_blobs_idx" tableName="node_checkpoint_blobs" clustered="false" unique="true"> + <column name="flow_id"/> + </createIndex> + <createIndex indexName="node_flow_metadata_idx" tableName="node_flow_metadata" clustered="false" unique="true"> + <column name="flow_id"/> + </createIndex> + </changeSet> + +</databaseChangeLog> diff --git a/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml b/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml new file mode 100644 index 0000000000..3f9ed5cab1 --- /dev/null +++ b/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml @@ -0,0 +1,131 @@ +<?xml version="1.1" encoding="UTF-8" standalone="no"?> +<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog 
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd" + logicalFilePath="migration/node-services.changelog-init.xml"> + + <changeSet author="R3.Corda" id="add_new_checkpoints_table-postgres" dbms="postgresql"> + <dropTable tableName="node_checkpoints"></dropTable> + <createTable tableName="node_checkpoints"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="status" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="compatible" type="BOOLEAN"> + <constraints nullable="false"/> + </column> + <column name="progress_step" type="NVARCHAR(256)"> + <constraints nullable="true"/> + </column> + <!-- net.corda.core.internal.FlowIORequest.SendAndReceive --> + <column name="flow_io_request" type="NVARCHAR(128)"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_checkpoint_blob_table-postgres" dbms="postgresql"> + <createTable tableName="node_checkpoint_blobs"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="checkpoint_value" type="varbinary(33554432)"> + <constraints nullable="false"/> + </column> + <column name="flow_state" type="varbinary(33554432)"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + <column name="hmac" type="VARBINARY(32)"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + + <changeSet author="R3.Corda" id="add_new_flow_result_table-postgres" dbms="postgresql"> + <createTable tableName="node_flow_results"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="result_value" type="varbinary(33554432)"> + <constraints nullable="false"/> + </column> 
+ <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_flow_exceptions_table-postgres" dbms="postgresql"> + <createTable tableName="node_flow_exceptions"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="type" type="NVARCHAR(255)"> + <constraints nullable="false"/> + </column> + <column name="exception_message" type="NVARCHAR(4000)"> + <constraints nullable="true"/> + </column> + <column name="stack_trace" type="NVARCHAR(4000)"> + <constraints nullable="false"/> + </column> + <column name="exception_value" type="varbinary(33554432)"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_flow_metadata_table-postgres" dbms="postgresql"> + <createTable tableName="node_flow_metadata"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="invocation_id" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="flow_name" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="flow_identifier" type="NVARCHAR(512)"> + <constraints nullable="true"/> + </column> + <column name="started_type" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="flow_parameters" type="varbinary(33554432)"> + <constraints nullable="false"/> + </column> + <column name="cordapp_name" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="platform_version" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="started_by" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="invocation_time" type="java.sql.Types.TIMESTAMP"> + 
<constraints nullable="false"/> + </column> + <column name="start_time" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + <column name="finish_time" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="true"/> + </column> + </createTable> + </changeSet> + +</databaseChangeLog> diff --git a/node/src/main/resources/migration/node-core.changelog-v19.xml b/node/src/main/resources/migration/node-core.changelog-v19.xml new file mode 100644 index 0000000000..cba014503c --- /dev/null +++ b/node/src/main/resources/migration/node-core.changelog-v19.xml @@ -0,0 +1,131 @@ +<?xml version="1.1" encoding="UTF-8" standalone="no"?> +<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd" + logicalFilePath="migration/node-services.changelog-init.xml"> + + <changeSet author="R3.Corda" id="add_new_checkpoints_table" dbms="!postgresql"> + <dropTable tableName="node_checkpoints"></dropTable> + <createTable tableName="node_checkpoints"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="status" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="compatible" type="BOOLEAN"> + <constraints nullable="false"/> + </column> + <column name="progress_step" type="NVARCHAR(256)"> + <constraints nullable="true"/> + </column> + <!-- net.corda.core.internal.FlowIORequest.SendAndReceive --> + <column name="flow_io_request" type="NVARCHAR(128)"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_checkpoint_blob_table" dbms="!postgresql"> + <createTable tableName="node_checkpoint_blobs"> + <column name="flow_id" 
type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="checkpoint_value" type="blob"> + <constraints nullable="false"/> + </column> + <column name="flow_state" type="blob"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + <column name="hmac" type="VARBINARY(32)"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + + <changeSet author="R3.Corda" id="add_new_flow_result_table" dbms="!postgresql"> + <createTable tableName="node_flow_results"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="result_value" type="blob"> + <constraints nullable="false"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_flow_exceptions_table" dbms="!postgresql"> + <createTable tableName="node_flow_exceptions"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="type" type="NVARCHAR(255)"> + <constraints nullable="false"/> + </column> + <column name="exception_message" type="NVARCHAR(4000)"> + <constraints nullable="true"/> + </column> + <column name="stack_trace" type="NVARCHAR(4000)"> + <constraints nullable="false"/> + </column> + <column name="exception_value" type="blob"> + <constraints nullable="true"/> + </column> + <column name="timestamp" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + </createTable> + </changeSet> + + <changeSet author="R3.Corda" id="add_new_flow_metadata_table" dbms="!postgresql"> + <createTable tableName="node_flow_metadata"> + <column name="flow_id" type="NVARCHAR(64)"> + <constraints nullable="false"/> + </column> + <column name="invocation_id" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + 
<column name="flow_name" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="flow_identifier" type="NVARCHAR(512)"> + <constraints nullable="true"/> + </column> + <column name="started_type" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="flow_parameters" type="blob"> + <constraints nullable="false"/> + </column> + <column name="cordapp_name" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="platform_version" type="INTEGER"> + <constraints nullable="false"/> + </column> + <column name="started_by" type="NVARCHAR(128)"> + <constraints nullable="false"/> + </column> + <column name="invocation_time" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + <column name="start_time" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="false"/> + </column> + <column name="finish_time" type="java.sql.Types.TIMESTAMP"> + <constraints nullable="true"/> + </column> + </createTable> + </changeSet> + +</databaseChangeLog> diff --git a/node/src/test/kotlin/net/corda/node/internal/NodeStartupCliTest.kt b/node/src/test/kotlin/net/corda/node/internal/NodeStartupCliTest.kt index 8fc0154f37..125d38f81b 100644 --- a/node/src/test/kotlin/net/corda/node/internal/NodeStartupCliTest.kt +++ b/node/src/test/kotlin/net/corda/node/internal/NodeStartupCliTest.kt @@ -32,6 +32,7 @@ class NodeStartupCliTest { Assertions.assertThat(startup.verbose).isEqualTo(false) Assertions.assertThat(startup.loggingLevel).isEqualTo(Level.INFO) Assertions.assertThat(startup.cmdLineOptions.noLocalShell).isEqualTo(false) + Assertions.assertThat(startup.cmdLineOptions.safeMode).isEqualTo(false) Assertions.assertThat(startup.cmdLineOptions.sshdServer).isEqualTo(false) Assertions.assertThat(startup.cmdLineOptions.justGenerateNodeInfo).isEqualTo(false) Assertions.assertThat(startup.cmdLineOptions.justGenerateRpcSslCerts).isEqualTo(false) diff --git 
a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 85a99576f2..6c8ce1d7bd 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -33,9 +33,10 @@ import net.corda.finance.contracts.asset.CASH import net.corda.finance.contracts.asset.Cash import net.corda.finance.flows.TwoPartyTradeFlow.Buyer import net.corda.finance.flows.TwoPartyTradeFlow.Seller +import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.persistence.DBTransactionStorage -import net.corda.node.services.persistence.checkpoints +import net.corda.node.services.statemachine.Checkpoint import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.core.* import net.corda.testing.dsl.LedgerDSL @@ -56,10 +57,17 @@ import java.io.ByteArrayOutputStream import java.util.* import java.util.jar.JarOutputStream import java.util.zip.ZipEntry +import kotlin.streams.toList import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertTrue +internal fun CheckpointStorage.getAllIncompleteCheckpoints(): List<Checkpoint.Serialized> { + return getCheckpointsToRun().use { + it.map { it.second }.toList() + }.filter { it.status != Checkpoint.FlowStatus.COMPLETED } +} + /** * In this example, Alice wishes to sell her commercial paper to Bob in return for $1,000,000 and they wish to do * it on the ledger atomically. Therefore they must work together to build a transaction. 
@@ -135,11 +143,11 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.dispose() aliceNode.database.transaction { - assertThat(aliceNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(aliceNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } aliceNode.internals.manuallyCloseDB() bobNode.database.transaction { - assertThat(bobNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(bobNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } bobNode.internals.manuallyCloseDB() } @@ -191,11 +199,11 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.dispose() aliceNode.database.transaction { - assertThat(aliceNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(aliceNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } aliceNode.internals.manuallyCloseDB() bobNode.database.transaction { - assertThat(bobNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(bobNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } bobNode.internals.manuallyCloseDB() } @@ -245,7 +253,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. 
bobNode.database.transaction { - assertThat(bobNode.internals.checkpointStorage.checkpoints()).hasSize(1) + assertThat(bobNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).hasSize(1) } val storage = bobNode.services.validatedTransactions @@ -278,10 +286,10 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty() bobNode.database.transaction { - assertThat(bobNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(bobNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } aliceNode.database.transaction { - assertThat(aliceNode.internals.checkpointStorage.checkpoints()).isEmpty() + assertThat(aliceNode.internals.checkpointStorage.getAllIncompleteCheckpoints()).isEmpty() } bobNode.database.transaction { diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index c7ceb785da..a75960523b 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -1,20 +1,29 @@ package net.corda.node.services.persistence import net.corda.core.context.InvocationContext +import net.corda.core.context.InvocationOrigin import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId +import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.toSet import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.core.utilities.contextLogger import net.corda.node.internal.CheckpointIncompatibleException import net.corda.node.internal.CheckpointVerifier import net.corda.node.services.api.CheckpointStorage import 
net.corda.node.services.statemachine.Checkpoint +import net.corda.node.services.statemachine.CheckpointState +import net.corda.node.services.statemachine.ErrorState +import net.corda.node.services.statemachine.FlowError import net.corda.node.services.statemachine.FlowStart +import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.SubFlowVersion import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.TestIdentity @@ -22,22 +31,33 @@ import net.corda.testing.internal.LogHelper import net.corda.testing.internal.configureDatabase import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties +import org.apache.commons.lang3.exception.ExceptionUtils import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull import org.junit.Before +import org.junit.Ignore import org.junit.Rule import org.junit.Test +import java.time.Clock +import java.util.* import kotlin.streams.toList +import kotlin.test.assertEquals +import kotlin.test.assertTrue -internal fun CheckpointStorage.checkpoints(): List<SerializedBytes<Checkpoint>> { - return getAllCheckpoints().use { +internal fun CheckpointStorage.checkpoints(): List<Checkpoint.Serialized> { + return getCheckpoints().use { it.map { it.second }.toList() } } class DBCheckpointStorageTests { private companion object { + + val log = contextLogger() + val ALICE = TestIdentity(ALICE_NAME, 70).party } @@ -61,27 +81,129 @@ class DBCheckpointStorageTests { 
LogHelper.reset(PersistentUniquenessProvider::class) } - @Test(timeout=300_000) - fun `add new checkpoint`() { + @Test(timeout = 300_000) + fun `add new checkpoint`() { val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + assertEquals(serializedFlowState, checkpointStorage.checkpoints().single().serializedFlowState) + assertEquals( + checkpoint, + checkpointStorage.checkpoints().single().deserialize() + ) } newCheckpointStorage() database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + assertEquals( + checkpoint, + checkpointStorage.checkpoints().single().deserialize() + ) + session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).also { + assertNotNull(it) + assertNotNull(it.blob) + } } } - @Test(timeout=300_000) - fun `remove checkpoint`() { + @Test(timeout = 300_000) + fun `update a checkpoint`() { val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } + val logic: FlowLogic<*> = object : FlowLogic<String>() { + override fun call(): String { + return "Updated flow logic" + } + } + val updatedCheckpoint = checkpoint.copy( + checkpointState = checkpoint.checkpointState.copy(numberOfSuspends = 20), + flowState = FlowState.Unstarted( + FlowStart.Explicit, + logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + ), + progressStep = "I have made progress", + flowIoRequest = 
FlowIORequest.SendAndReceive::class.java.simpleName + ) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) + } + database.transaction { + assertEquals( + updatedCheckpoint, + checkpointStorage.checkpoints().single().deserialize() + ) + } + } + + @Test(timeout = 300_000) + fun `update a checkpoint to completed`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + + val completedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED) + database.transaction { + checkpointStorage.updateCheckpoint(id, completedCheckpoint, null, completedCheckpoint.serializeCheckpointState()) + } + database.transaction { + assertEquals( + completedCheckpoint.copy(flowState = FlowState.Completed), + checkpointStorage.checkpoints().single().deserialize() + ) + } + } + + @Test(timeout = 300_000) + fun `update a checkpoint to paused`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + + val pausedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) + database.transaction { + checkpointStorage.updateCheckpoint(id, pausedCheckpoint, null, pausedCheckpoint.serializeCheckpointState()) + } + database.transaction { + assertEquals( + pausedCheckpoint.copy(flowState = FlowState.Paused), + checkpointStorage.checkpoints().single().deserialize() + ) + } + } + + @Ignore + @Test(timeout = 300_000) + fun `removing a checkpoint deletes from all checkpoint tables`() { + val exception = IllegalStateException("I am a 
naughty exception") + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.addError(exception).copy(result = "The result") + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } + + database.transaction { + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + // The result not stored yet + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + // The saving of checkpoint blobs needs to be fixed + assertEquals(2, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } + database.transaction { checkpointStorage.removeCheckpoint(id) } @@ -92,60 +214,173 @@ class DBCheckpointStorageTests { database.transaction { assertThat(checkpointStorage.checkpoints()).isEmpty() } + + database.transaction { + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + // The saving of checkpoint blobs needs to be fixed + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } } - @Test(timeout=300_000) - fun `add and remove checkpoint in single commit operate`() { + @Ignore + @Test(timeout = 300_000) + fun `removing a 
checkpoint when there is no result does not fail`() { + val exception = IllegalStateException("I am a naughty exception") val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.addError(exception) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } + + database.transaction { + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + // The result not stored yet + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + // The saving of checkpoint blobs needs to be fixed + assertEquals(2, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } + + database.transaction { + checkpointStorage.removeCheckpoint(id) + } + database.transaction { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + newCheckpointStorage() + database.transaction { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + + database.transaction { + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + // The saving of checkpoint blobs needs to be fixed + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } + } + + 
@Test(timeout = 300_000) + fun `removing a checkpoint when there is no exception does not fail`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.copy(result = "The result") + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } + + database.transaction { + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + // The result not stored yet + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } + + database.transaction { + checkpointStorage.removeCheckpoint(id) + } + database.transaction { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + newCheckpointStorage() + database.transaction { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + + database.transaction { + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + } + } + + @Test(timeout = 300_000) + fun `add and remove checkpoint in single commit operation`() { + val (id, checkpoint) = 
newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() val (id2, checkpoint2) = newCheckpoint() + val serializedFlowState2 = checkpoint2.serializeFlowState() /* must come from checkpoint2, the checkpoint stored under id2 */ database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint) - checkpointStorage.addCheckpoint(id2, checkpoint2) + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(id2, checkpoint2, serializedFlowState2, checkpoint2.serializeCheckpointState()) checkpointStorage.removeCheckpoint(id) } database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + assertEquals( + checkpoint2, + checkpointStorage.checkpoints().single().deserialize() + ) } newCheckpointStorage() database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + assertEquals( + checkpoint2, + checkpointStorage.checkpoints().single().deserialize() + ) } } - @Test(timeout=300_000) - fun `add two checkpoints then remove first one`() { + @Test(timeout = 300_000) + fun `add two checkpoints then remove first one`() { val (id, firstCheckpoint) = newCheckpoint() + val serializedFirstFlowState = firstCheckpoint.serializeFlowState() + database.transaction { - checkpointStorage.addCheckpoint(id, firstCheckpoint) + checkpointStorage.addCheckpoint(id, firstCheckpoint, serializedFirstFlowState, firstCheckpoint.serializeCheckpointState()) } val (id2, secondCheckpoint) = newCheckpoint() + val serializedSecondFlowState = secondCheckpoint.serializeFlowState() database.transaction { - checkpointStorage.addCheckpoint(id2, secondCheckpoint) + checkpointStorage.addCheckpoint(id2, secondCheckpoint, serializedSecondFlowState, secondCheckpoint.serializeCheckpointState()) } database.transaction { checkpointStorage.removeCheckpoint(id) } database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + assertEquals( + secondCheckpoint, +
checkpointStorage.checkpoints().single().deserialize() + ) } newCheckpointStorage() database.transaction { - assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + assertEquals( + secondCheckpoint, + checkpointStorage.checkpoints().single().deserialize() + ) } } - @Test(timeout=300_000) - fun `add checkpoint and then remove after 'restart'`() { + @Test(timeout = 300_000) + fun `add checkpoint and then remove after 'restart'`() { val (id, originalCheckpoint) = newCheckpoint() + val serializedOriginalFlowState = originalCheckpoint.serializeFlowState() database.transaction { - checkpointStorage.addCheckpoint(id, originalCheckpoint) + checkpointStorage.addCheckpoint(id, originalCheckpoint, serializedOriginalFlowState, originalCheckpoint.serializeCheckpointState()) } newCheckpointStorage() val reconstructedCheckpoint = database.transaction { checkpointStorage.checkpoints().single() } database.transaction { - assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) + assertEquals(originalCheckpoint, reconstructedCheckpoint.deserialize()) + assertThat(reconstructedCheckpoint.serializedFlowState).isEqualTo(serializedOriginalFlowState) + .isNotSameAs(serializedOriginalFlowState) } database.transaction { checkpointStorage.removeCheckpoint(id) @@ -155,12 +390,53 @@ class DBCheckpointStorageTests { } } - @Test(timeout=300_000) - fun `verify checkpoints compatible`() { + @Test(timeout = 300_000) + fun `adding a new checkpoint creates a metadata record`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + database.transaction { + session.get(DBCheckpointStorage.DBFlowMetadata::class.java, id.uuid.toString()).also { + assertNotNull(it) + } + } + } + + @Test(timeout = 300_000) + fun `updating a checkpoint does not change the 
metadata record`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val metadata = database.transaction { + session.get(DBCheckpointStorage.DBFlowMetadata::class.java, id.uuid.toString()).also { + assertNotNull(it) + } + } + val updatedCheckpoint = checkpoint.copy( + checkpointState = checkpoint.checkpointState.copy( + invocationContext = InvocationContext.newInstance(InvocationOrigin.Peer(ALICE_NAME)) + ) + ) + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint, serializedFlowState, updatedCheckpoint.serializeCheckpointState()) + } + val potentiallyUpdatedMetadata = database.transaction { + session.get(DBCheckpointStorage.DBFlowMetadata::class.java, id.uuid.toString()) + } + assertEquals(metadata, potentiallyUpdatedMetadata) + } + + @Test(timeout = 300_000) + fun `verify checkpoints compatible`() { val mockServices = MockServices(emptyList(), ALICE.name) database.transaction { val (id, checkpoint) = newCheckpoint(1) - checkpointStorage.addCheckpoint(id, checkpoint) + val serializedFlowState = checkpoint.serializeFlowState() + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } database.transaction { @@ -169,7 +445,8 @@ class DBCheckpointStorageTests { database.transaction { val (id1, checkpoint1) = newCheckpoint(2) - checkpointStorage.addCheckpoint(id1, checkpoint1) + val serializedFlowState1 = checkpoint1.serializeFlowState() + checkpointStorage.addCheckpoint(id1, checkpoint1, serializedFlowState1, checkpoint1.serializeCheckpointState()) } assertThatThrownBy { @@ -179,21 +456,491 @@ class DBCheckpointStorageTests { }.isInstanceOf(CheckpointIncompatibleException::class.java) } - private fun newCheckpointStorage() { + @Test(timeout = 300_000) + @Ignore + fun `update checkpoint with 
result information creates new result database record`() { + val result = "This is the result" + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = + checkpoint.serializeFlowState() database.transaction { - checkpointStorage = DBCheckpointStorage() + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.copy(result = result) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) + } + database.transaction { + assertEquals( + result, + checkpointStorage.getCheckpoint(id)!!.deserialize().result + ) + assertNotNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).result) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) } } - private fun newCheckpoint(version: Int = 1): Pair<StateMachineRunId, SerializedBytes<Checkpoint>> { + @Test(timeout = 300_000) + @Ignore + fun `update checkpoint with result information updates existing result database record`() { + val result = "This is the result" + val somehowThereIsANewResult = "Another result (which should not be possible!)" + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = + checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.copy(result = result) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) + } + val updatedCheckpoint2 = checkpoint.copy(result = somehowThereIsANewResult) + val updatedSerializedFlowState2 = 
updatedCheckpoint2.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) + } + database.transaction { + assertEquals( + somehowThereIsANewResult, + checkpointStorage.getCheckpoint(id)!!.deserialize().result + ) + assertNotNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).result) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + } + } + + @Test(timeout = 300_000) + fun `removing result information from checkpoint deletes existing result database record`() { + val result = "This is the result" + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = + checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.copy(result = result) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) + } + val updatedCheckpoint2 = checkpoint.copy(result = null) + val updatedSerializedFlowState2 = updatedCheckpoint2.serializeFlowState() + database.transaction { + checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) + } + database.transaction { + assertNull(checkpointStorage.getCheckpoint(id)!!.deserialize().result) + assertNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).result) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowResult>().size) + } + } + + @Ignore + @Test(timeout = 300_000) + fun `update checkpoint with error information creates a new error database record`() { + val exception = 
IllegalStateException("I am a naughty exception") + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.addError(exception) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } + database.transaction { + // Checkpoint always returns clean error state when retrieved via [getCheckpoint] + assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) + val exceptionDetails = session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).exceptionDetails + assertNotNull(exceptionDetails) + assertEquals(exception::class.java.name, exceptionDetails!!.type) + assertEquals(exception.message, exceptionDetails.message) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + } + } + + @Ignore + @Test(timeout = 300_000) + fun `update checkpoint with new error information updates the existing error database record`() { + val illegalStateException = IllegalStateException("I am a naughty exception") + val illegalArgumentException = IllegalArgumentException("I am a very naughty exception") + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint1 = checkpoint.addError(illegalStateException) + val updatedSerializedFlowState1 = updatedCheckpoint1.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint1, updatedSerializedFlowState1, 
updatedCheckpoint1.serializeCheckpointState()) } + // Set back to clean + val updatedCheckpoint2 = checkpoint.addError(illegalArgumentException) + val updatedSerializedFlowState2 = updatedCheckpoint2.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) } + database.transaction { + assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) + val exceptionDetails = session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).exceptionDetails + assertNotNull(exceptionDetails) + assertEquals(illegalArgumentException::class.java.name, exceptionDetails!!.type) + assertEquals(illegalArgumentException.message, exceptionDetails.message) + assertEquals(1, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + } + } + + @Test(timeout = 300_000) + fun `clean checkpoints delete the error record from the database`() { + val exception = IllegalStateException("I am a naughty exception") + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + val updatedCheckpoint = checkpoint.addError(exception) + val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() + database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } + database.transaction { + // Checkpoint always returns clean error state when retrieved via [getCheckpoint] + assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) + } + // Set back to clean + database.transaction { checkpointStorage.updateCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } + 
database.transaction { + assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) + assertNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).exceptionDetails) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowException>().size) + } + } + + @Test(timeout = 300_000) + fun `Checkpoint can be updated with flow io request information`() { + val (id, checkpoint) = newCheckpoint(1) + database.transaction { + val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + val checkpointFromStorage = checkpointStorage.getCheckpoint(id) + assertNull(checkpointFromStorage!!.flowIoRequest) + } + database.transaction { + val newCheckpoint = checkpoint.copy(flowIoRequest = FlowIORequest.Sleep::class.java.simpleName) + val serializedFlowState = newCheckpoint.flowState.checkpointSerialize( + context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT + ) + checkpointStorage.updateCheckpoint(id, newCheckpoint, serializedFlowState, newCheckpoint.serializeCheckpointState()) + } + database.transaction { + val checkpointFromStorage = checkpointStorage.getCheckpoint(id) + assertNotNull(checkpointFromStorage!!.flowIoRequest) + val flowIORequest = checkpointFromStorage.flowIoRequest + assertEquals(FlowIORequest.Sleep::class.java.simpleName, flowIORequest) + } + } + + @Test(timeout = 300_000) + fun `Checkpoint truncates long progressTracker step name`() { + val maxProgressStepLength = 256 + val (id, checkpoint) = newCheckpoint(1) + database.transaction { + val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + val checkpointFromStorage = 
checkpointStorage.getCheckpoint(id) + assertNull(checkpointFromStorage!!.progressStep) + } + val longString = """Long string Long string Long string Long string Long string Long string Long string Long string Long string + Long string Long string Long string Long string Long string Long string Long string Long string Long string Long string + Long string Long string Long string Long string Long string Long string Long string Long string Long string Long string + """.trimIndent() + database.transaction { + val newCheckpoint = checkpoint.copy(progressStep = longString) + val serializedFlowState = newCheckpoint.flowState.checkpointSerialize( + context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT + ) + checkpointStorage.updateCheckpoint(id, newCheckpoint, serializedFlowState, newCheckpoint.serializeCheckpointState()) + } + database.transaction { + val checkpointFromStorage = checkpointStorage.getCheckpoint(id) + assertEquals(longString.take(maxProgressStepLength), checkpointFromStorage!!.progressStep) + } + } + + @Test(timeout = 300_000) + fun `Checkpoints can be fetched from the database by status`() { + val (_, checkpoint) = newCheckpoint(1) + // runnables + val runnable = checkpoint.copy(status = Checkpoint.FlowStatus.RUNNABLE) + val hospitalized = checkpoint.copy(status = Checkpoint.FlowStatus.HOSPITALIZED) + // not runnables + val completed = checkpoint.copy(status = Checkpoint.FlowStatus.COMPLETED) + val failed = checkpoint.copy(status = Checkpoint.FlowStatus.FAILED) + val killed = checkpoint.copy(status = Checkpoint.FlowStatus.KILLED) + // paused + val paused = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) + + database.transaction { + val serializedFlowState = + checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + + checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), runnable, serializedFlowState, runnable.serializeCheckpointState()) + 
checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), hospitalized, serializedFlowState, hospitalized.serializeCheckpointState()) + checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), completed, serializedFlowState, completed.serializeCheckpointState()) + checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), failed, serializedFlowState, failed.serializeCheckpointState()) + checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), killed, serializedFlowState, killed.serializeCheckpointState()) + checkpointStorage.addCheckpoint(StateMachineRunId.createRandom(), paused, serializedFlowState, paused.serializeCheckpointState()) + } + + database.transaction { + val toRunStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.HOSPITALIZED) + val pausedStatuses = setOf(Checkpoint.FlowStatus.PAUSED) + val customStatuses = setOf(Checkpoint.FlowStatus.RUNNABLE, Checkpoint.FlowStatus.KILLED) + val customStatuses1 = setOf(Checkpoint.FlowStatus.PAUSED, Checkpoint.FlowStatus.HOSPITALIZED, Checkpoint.FlowStatus.FAILED) + + assertEquals(toRunStatuses, checkpointStorage.getCheckpointsToRun().map { it.second.status }.toSet()) + assertEquals(pausedStatuses, checkpointStorage.getPausedCheckpoints().map { it.second.status }.toSet()) + assertEquals(customStatuses, checkpointStorage.getCheckpoints(customStatuses).map { it.second.status }.toSet()) + assertEquals(customStatuses1, checkpointStorage.getCheckpoints(customStatuses1).map { it.second.status }.toSet()) + } + } + + @Ignore + @Test(timeout = 300_000) + fun `-not greater than DBCheckpointStorage_MAX_STACKTRACE_LENGTH- stackTrace gets persisted as a whole`() { + val smallerDummyStackTrace = ArrayList<StackTraceElement>() + val dummyStackTraceElement = StackTraceElement("class", "method", "file", 0) + + for (i in 0 until iterationsBasedOnLineSeparatorLength()) { + smallerDummyStackTrace.add(dummyStackTraceElement) + } + + val smallerStackTraceException = 
java.lang.IllegalStateException() + .apply { + this.stackTrace = smallerDummyStackTrace.toTypedArray() + } + val smallerStackTraceSize = ExceptionUtils.getStackTrace(smallerStackTraceException).length + assertTrue(smallerStackTraceSize < DBCheckpointStorage.MAX_STACKTRACE_LENGTH) + + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + database.transaction { + checkpointStorage.updateCheckpoint(id, checkpoint.addError(smallerStackTraceException), serializedFlowState, checkpoint.serializeCheckpointState()) + } + database.transaction { + val persistedCheckpoint = checkpointStorage.getDBCheckpoint(id) + val persistedStackTrace = persistedCheckpoint!!.exceptionDetails!!.stackTrace + assertEquals(smallerStackTraceSize, persistedStackTrace.length) + assertEquals(ExceptionUtils.getStackTrace(smallerStackTraceException), persistedStackTrace) + } + } + + @Ignore + @Test(timeout = 300_000) + fun `-greater than DBCheckpointStorage_MAX_STACKTRACE_LENGTH- stackTrace gets truncated to MAX_LENGTH_VARCHAR, and persisted`() { + val smallerDummyStackTrace = ArrayList<StackTraceElement>() + val dummyStackTraceElement = StackTraceElement("class", "method", "file", 0) + + for (i in 0 until iterationsBasedOnLineSeparatorLength()) { + smallerDummyStackTrace.add(dummyStackTraceElement) + } + + val smallerStackTraceException = java.lang.IllegalStateException() + .apply { + this.stackTrace = smallerDummyStackTrace.toTypedArray() + } + val smallerStackTraceSize = ExceptionUtils.getStackTrace(smallerStackTraceException).length + log.info("smallerStackTraceSize = $smallerStackTraceSize") + assertTrue(smallerStackTraceSize < DBCheckpointStorage.MAX_STACKTRACE_LENGTH) + + val biggerStackTraceException = java.lang.IllegalStateException() + .apply { + this.stackTrace = (smallerDummyStackTrace + 
dummyStackTraceElement).toTypedArray() + } + val biggerStackTraceSize = ExceptionUtils.getStackTrace(biggerStackTraceException).length + log.info("biggerStackTraceSize = $biggerStackTraceSize") + assertTrue(biggerStackTraceSize > DBCheckpointStorage.MAX_STACKTRACE_LENGTH) + + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + database.transaction { + checkpointStorage.updateCheckpoint(id, checkpoint.addError(biggerStackTraceException), serializedFlowState, checkpoint.serializeCheckpointState()) + } + database.transaction { + val persistedCheckpoint = checkpointStorage.getDBCheckpoint(id) + val persistedStackTrace = persistedCheckpoint!!.exceptionDetails!!.stackTrace + // last line of DBFlowException.stackTrace was a half line; will be truncated to the last whole line, + // therefore it will have the exact same length as 'notGreaterThanDummyException' exception + assertEquals(smallerStackTraceSize, persistedStackTrace.length) + assertEquals(ExceptionUtils.getStackTrace(smallerStackTraceException), persistedStackTrace) + } + } + + private fun iterationsBasedOnLineSeparatorLength() = when { + System.getProperty("line.separator").length == 1 -> // Linux or Mac + 158 + System.getProperty("line.separator").length == 2 -> // Windows + 152 + else -> throw IllegalStateException("Unknown line.separator") + } + + @Test(timeout = 300_000) + fun `paused checkpoints can be extracted`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.serializeFlowState() + val pausedCheckpoint = checkpoint.copy(status = Checkpoint.FlowStatus.PAUSED) + database.transaction { + checkpointStorage.addCheckpoint(id, pausedCheckpoint, serializedFlowState, pausedCheckpoint.serializeCheckpointState()) + } + + database.transaction { + val (extractedId, extractedCheckpoint) = 
checkpointStorage.getPausedCheckpoints().toList().single() + assertEquals(id, extractedId) + //We don't extract the result or the flowstate from a paused checkpoint + assertEquals(null, extractedCheckpoint.serializedFlowState) + assertEquals(null, extractedCheckpoint.result) + + assertEquals(pausedCheckpoint.status, extractedCheckpoint.status) + assertEquals(pausedCheckpoint.progressStep, extractedCheckpoint.progressStep) + assertEquals(pausedCheckpoint.flowIoRequest, extractedCheckpoint.flowIoRequest) + + val deserialisedCheckpoint = extractedCheckpoint.deserialize() + assertEquals(pausedCheckpoint.checkpointState, deserialisedCheckpoint.checkpointState) + assertEquals(FlowState.Paused, deserialisedCheckpoint.flowState) + } + } + + @Test(timeout = 300_000) + fun `checkpoints correctly change there status to paused`() { + val (_, checkpoint) = newCheckpoint(1) + // runnables + val runnable = changeStatus(checkpoint, Checkpoint.FlowStatus.RUNNABLE) + val hospitalized = changeStatus(checkpoint, Checkpoint.FlowStatus.HOSPITALIZED) + // not runnables + val completed = changeStatus(checkpoint, Checkpoint.FlowStatus.COMPLETED) + val failed = changeStatus(checkpoint, Checkpoint.FlowStatus.FAILED) + val killed = changeStatus(checkpoint, Checkpoint.FlowStatus.KILLED) + // paused + val paused = changeStatus(checkpoint, Checkpoint.FlowStatus.PAUSED) + database.transaction { + val serializedFlowState = + checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + + checkpointStorage.addCheckpoint(runnable.id, runnable.checkpoint, serializedFlowState, runnable.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(hospitalized.id, hospitalized.checkpoint, serializedFlowState, hospitalized.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(completed.id, completed.checkpoint, serializedFlowState, completed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(failed.id, 
failed.checkpoint, serializedFlowState, failed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(killed.id, killed.checkpoint, serializedFlowState, killed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(paused.id, paused.checkpoint, serializedFlowState, paused.checkpoint.serializeCheckpointState()) + } + + database.transaction { + checkpointStorage.markAllPaused() + } + + database.transaction { + //Hospitalised and paused checkpoints status should update + assertEquals(Checkpoint.FlowStatus.PAUSED, checkpointStorage.getDBCheckpoint(runnable.id)!!.status) + assertEquals(Checkpoint.FlowStatus.PAUSED, checkpointStorage.getDBCheckpoint(hospitalized.id)!!.status) + //Other checkpoints should not be updated + assertEquals(Checkpoint.FlowStatus.COMPLETED, checkpointStorage.getDBCheckpoint(completed.id)!!.status) + assertEquals(Checkpoint.FlowStatus.FAILED, checkpointStorage.getDBCheckpoint(failed.id)!!.status) + assertEquals(Checkpoint.FlowStatus.KILLED, checkpointStorage.getDBCheckpoint(killed.id)!!.status) + assertEquals(Checkpoint.FlowStatus.PAUSED, checkpointStorage.getDBCheckpoint(paused.id)!!.status) + } + } + + @Test(timeout = 300_000) + fun `updateCheckpoint setting DBFlowCheckpoint_blob to null whenever flow fails or gets hospitalized doesn't break ORM relationship`() { + val (id, checkpoint) = newCheckpoint() + val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) + } + + database.transaction { + val paused = changeStatus(checkpoint, Checkpoint.FlowStatus.FAILED) // the exact same behaviour applies for 'HOSPITALIZED' as well + checkpointStorage.updateCheckpoint(id, paused.checkpoint, serializedFlowState, paused.checkpoint.serializeCheckpointState()) + } + + database.transaction { + val dbFlowCheckpoint= 
checkpointStorage.getDBCheckpoint(id) + assert(dbFlowCheckpoint!!.blob != null) + } + } + + data class IdAndCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) + + private fun changeStatus(oldCheckpoint: Checkpoint, status: Checkpoint.FlowStatus): IdAndCheckpoint { + return IdAndCheckpoint(StateMachineRunId.createRandom(), oldCheckpoint.copy(status = status)) + } + + private fun newCheckpointStorage() { + database.transaction { + checkpointStorage = DBCheckpointStorage( + object : CheckpointPerformanceRecorder { + override fun record( + serializedCheckpointState: SerializedBytes<CheckpointState>, + serializedFlowState: SerializedBytes<FlowState>? + ) { + // do nothing + } + }, + Clock.systemUTC() + ) + } + } + + private fun newCheckpoint(version: Int = 1): Pair<StateMachineRunId, Checkpoint> { val id = StateMachineRunId.createRandom() val logic: FlowLogic<*> = object : FlowLogic<Unit>() { override fun call() {} } val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) - val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version), false) - .getOrThrow() - return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + val checkpoint = Checkpoint.create( + InvocationContext.shell(), + FlowStart.Explicit, + logic.javaClass, + frozenLogic, + ALICE, + SubFlowVersion.CoreFlow(version), + false + ) + .getOrThrow() + return id to checkpoint } + private fun Checkpoint.serializeFlowState(): SerializedBytes<FlowState> { + return flowState.checkpointSerialize(CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + } + + private fun Checkpoint.serializeCheckpointState(): SerializedBytes<CheckpointState> { + return checkpointState.checkpointSerialize(CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + } + + private fun Checkpoint.Serialized.deserialize(): Checkpoint { + return 
deserialize(CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + } + + private fun Checkpoint.addError(exception: Exception): Checkpoint { + return copy( + errorState = ErrorState.Errored( + listOf( + FlowError( + 0, + exception + ) + ), 0, false + ) + ) + } + + private inline fun <reified T> DatabaseTransaction.findRecordsFromDatabase(): List<T> { + val criteria = session.criteriaBuilder.createQuery(T::class.java) + criteria.select(criteria.from(T::class.java)) + return session.createQuery(criteria).resultList + } } diff --git a/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt b/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt index bf9a482e66..4037bd80f0 100644 --- a/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt @@ -5,6 +5,7 @@ import com.natpryce.hamkrest.containsSubstring import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever +import junit.framework.TestCase.assertNull import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId @@ -20,13 +21,16 @@ import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.checkpointSerialize -import net.corda.nodeapi.internal.lifecycle.NodeServicesContext -import net.corda.nodeapi.internal.lifecycle.NodeLifecycleEvent import net.corda.node.internal.NodeStartup +import net.corda.node.services.persistence.CheckpointPerformanceRecorder import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.statemachine.Checkpoint +import net.corda.node.services.statemachine.CheckpointState import 
net.corda.node.services.statemachine.FlowStart +import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.SubFlowVersion +import net.corda.nodeapi.internal.lifecycle.NodeLifecycleEvent +import net.corda.nodeapi.internal.lifecycle.NodeServicesContext import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.TestIdentity @@ -104,11 +108,33 @@ class CheckpointDumperImplTest { // add a checkpoint val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint, serializeFlowState(checkpoint), serializeCheckpointState(checkpoint)) } dumper.dumpCheckpoints() - checkDumpFile() + checkDumpFile() + } + + @Test(timeout=300_000) + fun `Checkpoint dumper doesn't output completed checkpoints`() { + val dumper = CheckpointDumperImpl(checkpointStorage, database, services, baseDirectory) + dumper.update(mockAfterStartEvent) + + // add a checkpoint + val (id, checkpoint) = newCheckpoint() + database.transaction { + checkpointStorage.addCheckpoint(id, checkpoint, serializeFlowState(checkpoint), serializeCheckpointState(checkpoint)) + } + val newCheckpoint = checkpoint.copy( + flowState = FlowState.Completed, + status = Checkpoint.FlowStatus.COMPLETED + ) + database.transaction { + checkpointStorage.updateCheckpoint(id, newCheckpoint, null, serializeCheckpointState(newCheckpoint)) + } + + dumper.dumpCheckpoints() + checkDumpFileEmpty() } private fun checkDumpFile() { @@ -120,6 +146,13 @@ class CheckpointDumperImplTest { } } + private fun checkDumpFileEmpty() { + ZipInputStream(file.inputStream()).use { zip -> + val entry = zip.nextEntry + assertNull(entry) + } + } + // This test will only succeed when the VM startup includes the "checkpoint-agent": // -javaagent:tools/checkpoint-agent/build/libs/checkpoint-agent.jar @Test(timeout=300_000) @@ -130,7 +163,7 @@ 
class CheckpointDumperImplTest { // add a checkpoint val (id, checkpoint) = newCheckpoint() database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint) + checkpointStorage.addCheckpoint(id, checkpoint, serializeFlowState(checkpoint), serializeCheckpointState(checkpoint)) } dumper.dumpCheckpoints() @@ -140,11 +173,21 @@ class CheckpointDumperImplTest { private fun newCheckpointStorage() { database.transaction { - checkpointStorage = DBCheckpointStorage() + checkpointStorage = DBCheckpointStorage( + object : CheckpointPerformanceRecorder { + override fun record( + serializedCheckpointState: SerializedBytes<CheckpointState>, + serializedFlowState: SerializedBytes<FlowState>? + ) { + // do nothing + } + }, + Clock.systemUTC() + ) } } - private fun newCheckpoint(version: Int = 1): Pair<StateMachineRunId, SerializedBytes<Checkpoint>> { + private fun newCheckpoint(version: Int = 1): Pair<StateMachineRunId, Checkpoint> { val id = StateMachineRunId.createRandom() val logic: FlowLogic<*> = object : FlowLogic<Unit>() { override fun call() {} @@ -152,6 +195,14 @@ class CheckpointDumperImplTest { val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, myself.identity.party, SubFlowVersion.CoreFlow(version), false) .getOrThrow() - return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + return id to checkpoint } -} \ No newline at end of file + + private fun serializeFlowState(checkpoint: Checkpoint): SerializedBytes<FlowState> { + return checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + } + + private fun serializeCheckpointState(checkpoint: Checkpoint): SerializedBytes<CheckpointState> { + return checkpoint.checkpointState.checkpointSerialize(context = 
CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + } +} diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index ce92f2954b..feafb34279 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -9,13 +9,27 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.ContractState import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue -import net.corda.core.flows.* +import net.corda.core.flows.Destination +import net.corda.core.flows.FinalityFlow +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.HospitalizeFlowException +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.ReceiveFinalityFlow +import net.corda.core.flows.StateMachineRunId +import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.identity.Party import net.corda.core.internal.DeclaredField +import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.MessageRecipients import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.queryBy +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.toFuture @@ -24,8 +38,13 @@ import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker.Change import net.corda.core.utilities.getOrThrow +import 
net.corda.core.utilities.seconds import net.corda.core.utilities.unwrap +import net.corda.node.services.persistence.CheckpointPerformanceRecorder +import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.checkpoints +import net.corda.nodeapi.internal.persistence.DatabaseTransaction +import net.corda.nodeapi.internal.persistence.currentDBSession import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState import net.corda.testing.core.ALICE_NAME @@ -36,23 +55,39 @@ import net.corda.testing.flows.registerCordappFlowFactory import net.corda.testing.internal.LogHelper import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin -import net.corda.testing.node.internal.* +import net.corda.testing.node.internal.DUMMY_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.FINANCE_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.InternalMockNodeParameters +import net.corda.testing.node.internal.TestStartedNode +import net.corda.testing.node.internal.getMessage +import net.corda.testing.node.internal.startFlow +import org.apache.commons.lang3.exception.ExceptionUtils import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatIllegalArgumentException import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType import org.assertj.core.api.Condition import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull import org.junit.Before +import org.junit.Ignore import org.junit.Test import rx.Notification import rx.Observable +import java.sql.SQLTransientConnectionException +import java.time.Clock import java.time.Duration import java.time.Instant -import 
java.util.* +import java.util.ArrayList +import java.util.concurrent.TimeoutException import java.util.function.Predicate import kotlin.reflect.KClass -import kotlin.test.assertEquals +import kotlin.streams.toList import kotlin.test.assertFailsWith +import kotlin.test.assertTrue class FlowFrameworkTests { companion object { @@ -69,10 +104,22 @@ class FlowFrameworkTests { private lateinit var notaryIdentity: Party private val receivedSessionMessages = ArrayList<SessionTransfer>() + private val dbCheckpointStorage = DBCheckpointStorage( + object : CheckpointPerformanceRecorder { + override fun record( + serializedCheckpointState: SerializedBytes<CheckpointState>, + serializedFlowState: SerializedBytes<FlowState>? + ) { + // do nothing + } + }, + Clock.systemUTC() + ) + @Before fun setUpMockNet() { mockNet = InternalMockNetwork( - cordappsForAllNodes = listOf(DUMMY_CONTRACTS_CORDAPP), + cordappsForAllNodes = listOf(DUMMY_CONTRACTS_CORDAPP, FINANCE_CONTRACTS_CORDAPP), servicePeerAllocationStrategy = RoundRobin() ) @@ -95,6 +142,9 @@ class FlowFrameworkTests { fun cleanUp() { mockNet.stopNodes() receivedSessionMessages.clear() + + SuspendingFlow.hookBeforeCheckpoint = {} + SuspendingFlow.hookAfterCheckpoint = {} } @Test(timeout=300_000) @@ -208,6 +258,19 @@ class FlowFrameworkTests { script(FlowMonitor(aliceNode.smm, Duration.ZERO, Duration.ZERO), FlowMonitor(bobNode.smm, Duration.ZERO, Duration.ZERO)) } + @Test(timeout = 300_000) + fun `flow status is updated in database when flow suspends on ioRequest`() { + val terminationSignal = Semaphore(0) + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { NoOpFlow( terminateUponSignal = terminationSignal) } + val flowId = aliceNode.services.startFlow(ReceiveFlow(bob)).id + mockNet.runNetwork() + aliceNode.database.transaction { + val checkpoint = dbCheckpointStorage.getCheckpoint(flowId) + assertEquals(FlowIORequest.Receive::class.java.simpleName, checkpoint?.flowIoRequest) + } + terminationSignal.release() + } + 
@Test(timeout=300_000) fun `receiving unexpected session end before entering sendAndReceive`() { bobNode.registerCordappFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } @@ -241,7 +304,8 @@ class FlowFrameworkTests { .withStackTraceContaining(ReceiveFlow::class.java.name) // Make sure the stack trace is that of the receiving flow .withStackTraceContaining("Received counter-flow exception from peer") bobNode.database.transaction { - assertThat(bobNode.internals.checkpointStorage.checkpoints()).isEmpty() + val checkpoint = bobNode.internals.checkpointStorage.checkpoints().single() + assertEquals(Checkpoint.FlowStatus.FAILED, checkpoint.status) } assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING) @@ -285,6 +349,64 @@ class FlowFrameworkTests { }, "FlowException's private peer field has value set")) } + //We should update this test when we do the work to persists the flow result. + @Test(timeout = 300_000) + fun `Checkpoint and all its related records are deleted when the flow finishes`() { + val terminationSignal = Semaphore(0) + val flow = aliceNode.services.startFlow(NoOpFlow( terminateUponSignal = terminationSignal)) + mockNet.waitQuiescent() // current thread needs to wait fiber running on a different thread, has reached the blocking point + aliceNode.database.transaction { + val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id) + assertNull(checkpoint!!.result) + assertNotNull(checkpoint.serializedFlowState) + assertNotEquals(Checkpoint.FlowStatus.COMPLETED, checkpoint.status) + } + terminationSignal.release() + mockNet.waitQuiescent() + aliceNode.database.transaction { + val checkpoint = dbCheckpointStorage.getCheckpoint(flow.id) + assertNull(checkpoint) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowMetadata>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpointBlob>().size) + assertEquals(0, findRecordsFromDatabase<DBCheckpointStorage.DBFlowCheckpoint>().size) + 
} + } + + // Ignoring test since completed flows are not currently keeping their checkpoints in the database + @Ignore + @Test(timeout = 300_000) + fun `Flow metadata finish time is set in database when the flow finishes`() { + val terminationSignal = Semaphore(0) + val flow = aliceNode.services.startFlow(NoOpFlow(terminateUponSignal = terminationSignal)) + mockNet.waitQuiescent() + aliceNode.database.transaction { + val metadata = session.find(DBCheckpointStorage.DBFlowMetadata::class.java, flow.id.uuid.toString()) + assertNull(metadata.finishInstant) + } + terminationSignal.release() + mockNet.waitQuiescent() + aliceNode.database.transaction { + val metadata = session.find(DBCheckpointStorage.DBFlowMetadata::class.java, flow.id.uuid.toString()) + assertNotNull(metadata.finishInstant) + assertTrue(metadata.finishInstant!! >= metadata.startInstant) + } + } + + @Test(timeout = 300_000) + fun `Flow persists progress tracker in the database when the flow suspends`() { + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedReceiveFlow(it) } + val aliceFlowId = aliceNode.services.startFlow(ReceiveFlow(bob)).id + mockNet.runNetwork() + aliceNode.database.transaction { + val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoint(aliceFlowId) + assertEquals(ReceiveFlow.START_STEP.label, checkpoint!!.progressStep) + } + bobNode.database.transaction { + val checkpoints = bobNode.internals.checkpointStorage.checkpoints().single() + assertEquals(InitiatedReceiveFlow.START_STEP.label, checkpoints.progressStep) + } + } + private class ConditionalExceptionFlow(val otherPartySession: FlowSession, val sendPayload: Any) : FlowLogic<Unit>() { @Suspendable override fun call() { @@ -546,6 +668,176 @@ class FlowFrameworkTests { assertThat(result.getOrThrow()).isEqualTo("HelloHello") } + @Test(timeout=300_000) + fun `Checkpoint status changes to RUNNABLE when flow is loaded from checkpoint - FlowState Unstarted`() { + var firstExecution = true + var flowState: 
FlowState? = null + var dbCheckpointStatusBeforeSuspension: Checkpoint.FlowStatus? = null + var dbCheckpointStatusAfterSuspension: Checkpoint.FlowStatus? = null + var inMemoryCheckpointStatusBeforeSuspension: Checkpoint.FlowStatus? = null + val futureFiber = openFuture<FlowStateMachineImpl<*>>().toCompletableFuture() + + SuspendingFlow.hookBeforeCheckpoint = { + val flowFiber = this as? FlowStateMachineImpl<*> + flowState = flowFiber!!.transientState!!.value.checkpoint.flowState + + if (firstExecution) { + throw HospitalizeFlowException() + } else { + dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status + currentDBSession().clear() // clear session as Hibernate with fails with 'org.hibernate.NonUniqueObjectException' once it tries to save a DBFlowCheckpoint upon checkpoint + inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState!!.value.checkpoint.status + + futureFiber.complete(flowFiber) + } + } + SuspendingFlow.hookAfterCheckpoint = { + dbCheckpointStatusAfterSuspension = aliceNode.internals.checkpointStorage.getCheckpointsToRun().toList().single() + .second.status + } + + assertFailsWith<TimeoutException> { + aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(30.seconds) // wait till flow gets hospitalized + } + // flow is in hospital + assertTrue(flowState is FlowState.Unstarted) + val inMemoryHospitalizedCheckpointStatus = aliceNode.internals.smm.snapshot().first().transientState?.value?.checkpoint?.status + assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, inMemoryHospitalizedCheckpointStatus) + aliceNode.database.transaction { + val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second + assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) + } + // restart Node - flow will be loaded from checkpoint + firstExecution = false + aliceNode = mockNet.restartNode(aliceNode) + 
futureFiber.get().resultFuture.getOrThrow() // wait until the flow has completed + // checkpoint states ,after flow retried, before and after suspension + assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, dbCheckpointStatusBeforeSuspension) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, inMemoryCheckpointStatusBeforeSuspension) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, dbCheckpointStatusAfterSuspension) + } + + @Test(timeout=300_000) + fun `Checkpoint status changes to RUNNABLE when flow is loaded from checkpoint - FlowState Started`() { + var firstExecution = true + var flowState: FlowState? = null + var dbCheckpointStatus: Checkpoint.FlowStatus? = null + var inMemoryCheckpointStatus: Checkpoint.FlowStatus? = null + val futureFiber = openFuture<FlowStateMachineImpl<*>>().toCompletableFuture() + + SuspendingFlow.hookAfterCheckpoint = { + val flowFiber = this as? FlowStateMachineImpl<*> + flowState = flowFiber!!.transientState!!.value.checkpoint.flowState + + if (firstExecution) { + throw HospitalizeFlowException() + } else { + dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status + inMemoryCheckpointStatus = flowFiber.transientState!!.value.checkpoint.status + + futureFiber.complete(flowFiber) + } + } + + assertFailsWith<TimeoutException> { + aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(30.seconds) // wait till flow gets hospitalized + } + // flow is in hospital + assertTrue(flowState is FlowState.Started) + aliceNode.database.transaction { + val checkpoint = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second + assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) + } + // restart Node - flow will be loaded from checkpoint + firstExecution = false + aliceNode = mockNet.restartNode(aliceNode) + futureFiber.get().resultFuture.getOrThrow() // wait until the flow has completed + // checkpoint states ,after flow retried, after suspension + 
assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, dbCheckpointStatus) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, inMemoryCheckpointStatus) + } + + // Upon implementing CORDA-3681 unignore this test; DBFlowException is not currently integrated + @Ignore + @Test(timeout=300_000) + fun `Checkpoint is updated in DB with FAILED status and the error when flow fails`() { + var flowId: StateMachineRunId? = null + + val e = assertFailsWith<FlowException> { + val fiber = aliceNode.services.startFlow(ExceptionFlow { FlowException("Just an exception") }) + flowId = fiber.id + fiber.resultFuture.getOrThrow() + } + + aliceNode.database.transaction { + val checkpoint = aliceNode.internals.checkpointStorage.checkpoints().single() + assertEquals(Checkpoint.FlowStatus.FAILED, checkpoint.status) + + // assert all fields of DBFlowException + val persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowId!!)!!.exceptionDetails + assertEquals(FlowException::class.java.name, persistedException!!.type) + assertEquals("Just an exception", persistedException.message) + assertEquals(ExceptionUtils.getStackTrace(e), persistedException.stackTrace) + assertEquals(null, persistedException.value) + } + } + + // Upon implementing CORDA-3681 unignore this test; DBFlowException is not currently integrated + @Ignore + @Test(timeout=300_000) + fun `Checkpoint is updated in DB with HOSPITALIZED status and the error when flow is kept for overnight observation` () { + var flowId: StateMachineRunId? 
= null + + assertFailsWith<TimeoutException> { + val fiber = aliceNode.services.startFlow(ExceptionFlow { HospitalizeFlowException("Overnight observation") }) + flowId = fiber.id + fiber.resultFuture.getOrThrow(10.seconds) + } + + aliceNode.database.transaction { + val checkpoint = aliceNode.internals.checkpointStorage.checkpoints().single() + assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) + + // assert all fields of DBFlowException + val persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowId!!)!!.exceptionDetails + assertEquals(HospitalizeFlowException::class.java.name, persistedException!!.type) + assertEquals("Overnight observation", persistedException.message) + assertEquals(null, persistedException.value) + } + } + + @Test(timeout=300_000) + fun `Checkpoint status and error in memory and in DB are not dirtied upon flow retry`() { + var firstExecution = true + var dbCheckpointStatus: Checkpoint.FlowStatus? = null + var inMemoryCheckpointStatus: Checkpoint.FlowStatus? = null + var persistedException: DBCheckpointStorage.DBFlowException? = null + + SuspendingFlow.hookAfterCheckpoint = { + if (firstExecution) { + firstExecution = false + throw SQLTransientConnectionException("connection is not available") + } else { + val flowFiber = this as? 
FlowStateMachineImpl<*> + dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status + inMemoryCheckpointStatus = flowFiber!!.transientState!!.value.checkpoint.status + persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowFiber.id)!!.exceptionDetails + } + } + + aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow() + // checkpoint states ,after flow retried, after suspension + assertEquals(Checkpoint.FlowStatus.RUNNABLE, dbCheckpointStatus) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, inMemoryCheckpointStatus) + assertEquals(null, persistedException) + } + + private inline fun <reified T> DatabaseTransaction.findRecordsFromDatabase(): List<T> { + val criteria = session.criteriaBuilder.createQuery(T::class.java) + criteria.select(criteria.from(T::class.java)) + return session.createQuery(criteria).resultList + } + //region Helpers private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) @@ -898,4 +1190,19 @@ internal class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic< exceptionThrown = exception() throw exceptionThrown } +} + +internal class SuspendingFlow : FlowLogic<Unit>() { + + companion object { + var hookBeforeCheckpoint: FlowStateMachine<*>.() -> Unit = {} + var hookAfterCheckpoint: FlowStateMachine<*>.() -> Unit = {} + } + + @Suspendable + override fun call() { + stateMachine.hookBeforeCheckpoint() + sleep(1.seconds) // flow checkpoints => checkpoint is in DB + stateMachine.hookAfterCheckpoint() + } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt new file mode 100644 index 0000000000..ddac3afba8 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt @@ -0,0 +1,556 @@ +package 
net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.client.rpc.CordaRPCClient +import net.corda.core.context.InvocationContext +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.SchedulableState +import net.corda.core.contracts.ScheduledActivity +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.UniqueIdentifier +import net.corda.core.flows.FlowExternalAsyncOperation +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowLogicRefFactory +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.SchedulableFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.flows.StartableByService +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.PLATFORM_VERSION +import net.corda.core.internal.uncheckedCast +import net.corda.core.messaging.startFlow +import net.corda.core.node.AppServiceHub +import net.corda.core.node.services.CordaService +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.serialization.deserialize +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.minutes +import net.corda.node.services.Permissions +import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.testing.contracts.DummyContract +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.User +import 
org.assertj.core.api.Assertions.assertThat +import org.junit.Before +import org.junit.Ignore +import org.junit.Test +import java.time.Instant +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors +import java.util.concurrent.Semaphore +import java.util.function.Supplier +import kotlin.reflect.jvm.jvmName +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class FlowMetadataRecordingTest { + + private val user = User("mark", "dadada", setOf(Permissions.all())) + private val string = "I must be delivered for 4.5" + private val someObject = SomeObject("Store me in the database please", 1234) + + @Before + fun before() { + MyFlow.hookAfterInitialCheckpoint = null + MyFlow.hookAfterSuspend = null + MyResponder.hookAfterInitialCheckpoint = null + MyFlowWithoutParameters.hookAfterInitialCheckpoint = null + } + + @Test(timeout = 300_000) + fun `rpc started flows have metadata recorded`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + var flowId: StateMachineRunId? = null + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlow.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + context = contextFromHook + metadata = metadataFromHook + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + } + + metadata!!.let { + assertEquals(context!!.trace.invocationId.value, it.invocationId) + assertEquals(flowId!!.uuid.toString(), it.flowId) + assertEquals(MyFlow::class.java.name, it.flowName) + // Should be changed when [userSuppliedIdentifier] gets filled in future changes + assertNull(it.userSuppliedIdentifier) + assertEquals(DBCheckpointStorage.StartReason.RPC, it.startType) + assertEquals( + listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), + it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + assertThat(it.launchingCordapp).contains("custom-cordapp") + assertEquals(PLATFORM_VERSION, it.platformVersion) + assertEquals(user.username, it.startedBy) + assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) + assertTrue(it.startInstant >= it.invocationInstant) + assertNull(it.finishInstant) + } + } + } + + @Test(timeout = 300_000) + fun `rpc started flows have metadata recorded - no parameters`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + + var flowId: StateMachineRunId? = null + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlowWithoutParameters.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + context = contextFromHook + metadata = metadataFromHook + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow(::MyFlowWithoutParameters).returnValue.getOrThrow(1.minutes) + } + + metadata!!.let { + assertEquals(context!!.trace.invocationId.value, it.invocationId) + assertEquals(flowId!!.uuid.toString(), it.flowId) + assertEquals(MyFlowWithoutParameters::class.java.name, it.flowName) + // Should be changed when [userSuppliedIdentifier] gets filled in future changes + assertNull(it.userSuppliedIdentifier) + assertEquals(DBCheckpointStorage.StartReason.RPC, it.startType) + assertEquals( + emptyList<Any?>(), + it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + assertThat(it.launchingCordapp).contains("custom-cordapp") + assertEquals(PLATFORM_VERSION, it.platformVersion) + assertEquals(user.username, it.startedBy) + assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) + assertTrue(it.startInstant >= it.invocationInstant) + assertNull(it.finishInstant) + } + } + } + + @Test(timeout = 300_000) + fun `rpc started flows have their arguments removed from in-memory checkpoint after zero'th checkpoint`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlow.hookAfterInitialCheckpoint = + { _, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + context = contextFromHook + metadata = metadataFromHook + } + + var context2: InvocationContext? = null + var metadata2: DBCheckpointStorage.DBFlowMetadata? = null + MyFlow.hookAfterSuspend = + { contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + context2 = contextFromHook + metadata2 = metadataFromHook + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + } + + assertEquals( + listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), + uncheckedCast<Any?, Array<Any?>>(context!!.arguments[1]).toList() + ) + assertEquals( + listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), + metadata!!.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + + assertEquals( + emptyList(), + context2!!.arguments + ) + assertEquals( + listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), + metadata2!!.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + } + } + + @Test(timeout = 300_000) + fun `initiated flows have metadata recorded`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + var flowId: StateMachineRunId? = null + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyResponder.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + context = contextFromHook + metadata = metadataFromHook + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + } + + metadata!!.let { + assertEquals(context!!.trace.invocationId.value, it.invocationId) + assertEquals(flowId!!.uuid.toString(), it.flowId) + assertEquals(MyResponder::class.java.name, it.flowName) + assertNull(it.userSuppliedIdentifier) + assertEquals(DBCheckpointStorage.StartReason.INITIATED, it.startType) + assertEquals( + emptyList<Any?>(), + it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + assertThat(it.launchingCordapp).contains("custom-cordapp") + assertEquals(7, it.platformVersion) + assertEquals(nodeAHandle.nodeInfo.singleIdentity().name.toString(), it.startedBy) + assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) + assertTrue(it.startInstant >= it.invocationInstant) + assertNull(it.finishInstant) + } + } + } + + @Test(timeout = 300_000) + fun `service started flows have metadata recorded`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + var flowId: StateMachineRunId? = null + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlow.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + context = contextFromHook + metadata = metadataFromHook + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyServiceStartingFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + } + + metadata!!.let { + assertEquals(context!!.trace.invocationId.value, it.invocationId) + assertEquals(flowId!!.uuid.toString(), it.flowId) + assertEquals(MyFlow::class.java.name, it.flowName) + assertNull(it.userSuppliedIdentifier) + assertEquals(DBCheckpointStorage.StartReason.SERVICE, it.startType) + assertEquals( + emptyList<Any?>(), + it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + assertThat(it.launchingCordapp).contains("custom-cordapp") + assertEquals(PLATFORM_VERSION, it.platformVersion) + assertEquals(MyService::class.java.name, it.startedBy) + assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) + assertTrue(it.startInstant >= it.invocationInstant) + assertNull(it.finishInstant) + } + } + } + + @Test(timeout = 300_000) + fun `scheduled flows have metadata recorded`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + val lock = Semaphore(0) + + var flowId: StateMachineRunId? = null + var context: InvocationContext? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlow.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, contextFromHook: InvocationContext, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + context = contextFromHook + metadata = metadataFromHook + // Release the lock so the asserts can be processed + lock.release() + } + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyStartedScheduledFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + } + + // Block here until released in the hook + lock.acquire() + + metadata!!.let { + assertEquals(context!!.trace.invocationId.value, it.invocationId) + assertEquals(flowId!!.uuid.toString(), it.flowId) + assertEquals(MyFlow::class.java.name, it.flowName) + assertNull(it.userSuppliedIdentifier) + assertEquals(DBCheckpointStorage.StartReason.SCHEDULED, it.startType) + assertEquals( + emptyList<Any?>(), + it.initialParameters.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + ) + assertThat(it.launchingCordapp).contains("custom-cordapp") + assertEquals(PLATFORM_VERSION, it.platformVersion) + assertEquals("Scheduler", it.startedBy) + assertEquals(context!!.trace.invocationId.timestamp, it.invocationInstant) + assertTrue(it.startInstant >= it.invocationInstant) + assertNull(it.finishInstant) + } + } + } + + // Ignoring test since completed flows are not currently keeping their checkpoints in the database + @Ignore + @Test(timeout = 300_000) + fun `flows have their finish time recorded when completed`() { + driver(DriverParameters(startNodesInProcess = true)) { + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + var flowId: StateMachineRunId? = null + var metadata: DBCheckpointStorage.DBFlowMetadata? 
= null + MyFlow.hookAfterInitialCheckpoint = + { flowIdFromHook: StateMachineRunId, _, metadataFromHook: DBCheckpointStorage.DBFlowMetadata -> + flowId = flowIdFromHook + metadata = metadataFromHook + } + + val finishTime = CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + it.proxy.startFlow( + ::MyFlow, + nodeBHandle.nodeInfo.singleIdentity(), + string, + someObject + ).returnValue.getOrThrow(1.minutes) + it.proxy.startFlow( + ::GetFlowFinishTimeFlow, + flowId!! + ).returnValue.getOrThrow(1.minutes) + } + + metadata!!.let { + assertNull(it.finishInstant) + assertNotNull(finishTime) + assertTrue(finishTime!! >= it.startInstant) + } + } + } + + @InitiatingFlow + @StartableByRPC + @StartableByService + @SchedulableFlow + @Suppress("UNUSED_PARAMETER") + class MyFlow(private val party: Party, string: String, someObject: SomeObject) : + FlowLogic<Unit>() { + + companion object { + var hookAfterInitialCheckpoint: (( + flowId: StateMachineRunId, + context: InvocationContext, + metadata: DBCheckpointStorage.DBFlowMetadata + ) -> Unit)? = null + var hookAfterSuspend: (( + context: InvocationContext, + metadata: DBCheckpointStorage.DBFlowMetadata + ) -> Unit)? = null + } + + @Suspendable + override fun call() { + hookAfterInitialCheckpoint?.let { + it( + stateMachine.id, + stateMachine.context, + serviceHub.cordaService(MyService::class.java).findMetadata(stateMachine.id) + ) + } + initiateFlow(party).sendAndReceive<String>("Hello there") + hookAfterSuspend?.let { + it( + stateMachine.context, + serviceHub.cordaService(MyService::class.java).findMetadata(stateMachine.id) + ) + } + } + } + + @InitiatedBy(MyFlow::class) + class MyResponder(private val session: FlowSession) : FlowLogic<Unit>() { + + companion object { + var hookAfterInitialCheckpoint: (( + flowId: StateMachineRunId, + context: InvocationContext, + metadata: DBCheckpointStorage.DBFlowMetadata + ) -> Unit)? 
= null + } + + @Suspendable + override fun call() { + session.receive<String>() + hookAfterInitialCheckpoint?.let { + it( + stateMachine.id, + stateMachine.context, + serviceHub.cordaService(MyService::class.java).findMetadata(stateMachine.id) + ) + } + session.send("Hello there") + } + } + + @StartableByRPC + class MyFlowWithoutParameters : FlowLogic<Unit>() { + + companion object { + var hookAfterInitialCheckpoint: (( + flowId: StateMachineRunId, + context: InvocationContext, + metadata: DBCheckpointStorage.DBFlowMetadata + ) -> Unit)? = null + } + + @Suspendable + override fun call() { + hookAfterInitialCheckpoint?.let { + it( + stateMachine.id, + stateMachine.context, + serviceHub.cordaService(MyService::class.java).findMetadata(stateMachine.id) + ) + } + } + } + + @StartableByRPC + class MyServiceStartingFlow(private val party: Party, private val string: String, private val someObject: SomeObject) : + FlowLogic<Unit>() { + + @Suspendable + override fun call() { + await(object : FlowExternalAsyncOperation<Unit> { + override fun execute(deduplicationId: String): CompletableFuture<Unit> { + return serviceHub.cordaService(MyService::class.java).startFlow(party, string, someObject) + } + }) + } + } + + @StartableByRPC + class MyStartedScheduledFlow(private val party: Party, private val string: String, private val someObject: SomeObject) : + FlowLogic<Unit>() { + + @Suspendable + override fun call() { + val tx = TransactionBuilder(serviceHub.networkMapCache.notaryIdentities.first()).apply { + addOutputState(ScheduledState(party, string, someObject, listOf(ourIdentity))) + addCommand(DummyContract.Commands.Create(), ourIdentity.owningKey) + } + val stx = serviceHub.signInitialTransaction(tx) + serviceHub.recordTransactions(stx) + } + } + + @StartableByRPC + class GetFlowFinishTimeFlow(private val flowId: StateMachineRunId) : FlowLogic<Instant?>() { + @Suspendable + override fun call(): Instant? 
{ + return serviceHub.cordaService(MyService::class.java).findMetadata(flowId).finishInstant + } + } + + @CordaService + class MyService(private val services: AppServiceHub) : SingletonSerializeAsToken() { + + private val executorService = Executors.newFixedThreadPool(1) + + fun findMetadata(flowId: StateMachineRunId): DBCheckpointStorage.DBFlowMetadata { + return services.database.transaction { + session.find(DBCheckpointStorage.DBFlowMetadata::class.java, flowId.uuid.toString()) + } + } + + fun startFlow(party: Party, string: String, someObject: SomeObject): CompletableFuture<Unit> { + return CompletableFuture.supplyAsync( + Supplier<Unit> { services.startFlow(MyFlow(party, string, someObject)).returnValue.getOrThrow() }, + executorService + ) + } + } + + @CordaSerializable + data class SomeObject(private val string: String, private val number: Int) + + @BelongsToContract(DummyContract::class) + data class ScheduledState( + val party: Party, + val string: String, + val someObject: SomeObject, + override val participants: List<Party>, + override val linearId: UniqueIdentifier = UniqueIdentifier() + ) : SchedulableState, LinearState { + + override fun nextScheduledActivity(thisStateRef: StateRef, flowLogicRefFactory: FlowLogicRefFactory): ScheduledActivity? 
{ + val logicRef = flowLogicRefFactory.create(MyFlow::class.jvmName, party, string, someObject) + return ScheduledActivity(logicRef, Instant.now()) + } + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowPausingTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowPausingTests.kt new file mode 100644 index 0000000000..1a0415892b --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowPausingTests.kt @@ -0,0 +1,77 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.flows.FlowLogic +import net.corda.core.internal.FlowStateMachine +import net.corda.node.services.config.NodeConfiguration +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.InternalMockNodeParameters +import net.corda.testing.node.internal.TestStartedNode +import net.corda.testing.node.internal.startFlow +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.time.Duration +import kotlin.test.assertEquals + +class FlowPausingTests { + + companion object { + const val NUMBER_OF_FLOWS = 4 + const val SLEEP_TIME = 1000L + } + + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: TestStartedNode + private lateinit var bobNode: TestStartedNode + + @Before + fun setUpMockNet() { + mockNet = InternalMockNetwork() + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + } + + @After + fun cleanUp() { + mockNet.stopNodes() + } + + private fun restartNode(node: TestStartedNode, smmStartMode: StateMachineManager.StartMode) : TestStartedNode { + val parameters = 
InternalMockNodeParameters(configOverrides = { + conf: NodeConfiguration -> + doReturn(smmStartMode).whenever(conf).smmStartMode + }) + return mockNet.restartNode(node, parameters = parameters) + } + + @Test(timeout = 300_000) + fun `All are paused when the node is restarted in safe start mode`() { + val flows = ArrayList<FlowStateMachine<Unit>>() + for (i in 1..NUMBER_OF_FLOWS) { + flows += aliceNode.services.startFlow(CheckpointingFlow()) + } + //All of the flows must not resume before the node restarts. + val restartedAlice = restartNode(aliceNode, StateMachineManager.StartMode.Safe) + assertEquals(0, restartedAlice.smm.snapshot().size) + //We need to wait long enough here so any running flows would finish. + Thread.sleep(NUMBER_OF_FLOWS * SLEEP_TIME) + restartedAlice.database.transaction { + for (flow in flows) { + val checkpoint = restartedAlice.internals.checkpointStorage.getCheckpoint(flow.id) + assertEquals(Checkpoint.FlowStatus.PAUSED, checkpoint!!.status) + } + } + } + + internal class CheckpointingFlow: FlowLogic<Unit>() { + @Suspendable + override fun call() { + sleep(Duration.ofMillis(SLEEP_TIME)) + } + } +} diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt new file mode 100644 index 0000000000..6f0fa3278c --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt @@ -0,0 +1,330 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.contracts.Command +import net.corda.core.contracts.StateRef +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FlowIORequest +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultService +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.QueryCriteria +import 
net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds +import net.corda.finance.DOLLARS +import net.corda.finance.contracts.asset.Cash +import net.corda.node.services.statemachine.FlowSoftLocksTests.Companion.queryCashStates +import net.corda.node.services.vault.NodeVaultServiceTest +import net.corda.testing.contracts.DummyContract +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOC_NAME +import net.corda.testing.core.TestIdentity +import net.corda.testing.core.singleIdentity +import net.corda.testing.internal.vault.VaultFiller +import net.corda.testing.node.internal.DUMMY_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.FINANCE_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.InternalMockNodeParameters +import net.corda.testing.node.internal.TestStartedNode +import net.corda.testing.node.internal.startFlow +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Before +import org.junit.Test +import java.lang.IllegalStateException +import java.sql.SQLTransientConnectionException +import java.util.UUID +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class FlowSoftLocksTests { + + companion object { + fun queryCashStates(softLockingType: QueryCriteria.SoftLockingType, vaultService: VaultService) = + vaultService.queryBy<Cash.State>( + QueryCriteria.VaultQueryCriteria( + softLockingCondition = QueryCriteria.SoftLockingCondition( + softLockingType + ) + ) + ).states.map { it.ref }.toSet() + + val EMPTY_SET = emptySet<StateRef>() + } + + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: TestStartedNode + private lateinit var notaryIdentity: Party + + @Before + fun setUpMockNet() { + mockNet = InternalMockNetwork( + cordappsForAllNodes = listOf(DUMMY_CONTRACTS_CORDAPP, 
FINANCE_CONTRACTS_CORDAPP) + ) + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + notaryIdentity = mockNet.defaultNotaryIdentity + } + + @After + fun cleanUp() { + mockNet.stopNodes() + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with its own flow id and then manually releases them`() { + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = vaultStates), + SoftLockAction(SoftLockingAction.UNLOCK, null, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.UNLOCKED_ONLY), expectedSoftLockedStates = EMPTY_SET) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(vaultStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with its own flow id and by default releases them when completing`() { + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = vaultStates) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(vaultStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with its own flow id and by default releases them when errors`() { + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref 
}.toSet() + val softLockActions = arrayOf( + SoftLockAction( + SoftLockingAction.LOCK, + null, + vaultStates, + ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), + expectedSoftLockedStates = vaultStates, + exception = IllegalStateException("Throwing error after flow has soft locked states") + ) + ) + assertFailsWith<IllegalStateException> { + aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + } + assertEquals(vaultStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + LockingUnlockingFlow.throwOnlyOnce = true + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with random id and then manually releases them`() { + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, randomId, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = EMPTY_SET), + SoftLockAction(SoftLockingAction.UNLOCK, randomId, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.UNLOCKED_ONLY), expectedSoftLockedStates = EMPTY_SET) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(vaultStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with random id and does not release them upon completing`() { + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, randomId, vaultStates, ExpectedSoftLocks(vaultStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = 
EMPTY_SET) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(vaultStates, queryCashStates(QueryCriteria.SoftLockingType.LOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow only releases by default reserved states with flow id upon completing`() { + // lock with flow id and random id, dont manually release any. At the end, check that only flow id ones got unlocked. + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toList() + val flowIdStates = vaultStates.subList(0, vaultStates.size / 2).toSet() + val randomIdStates = vaultStates.subList(vaultStates.size / 2, vaultStates.size).toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, flowIdStates, ExpectedSoftLocks(flowIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + SoftLockAction(SoftLockingAction.LOCK, randomId, randomIdStates, ExpectedSoftLocks(flowIdStates + randomIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(flowIdStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + assertEquals(randomIdStates, queryCashStates(QueryCriteria.SoftLockingType.LOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with flow id and random id, then releases the flow id ones - assert the random id ones are still locked`() { + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toList() + val flowIdStates = vaultStates.subList(0, vaultStates.size / 2).toSet() + val 
randomIdStates = vaultStates.subList(vaultStates.size / 2, vaultStates.size).toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, flowIdStates, ExpectedSoftLocks(flowIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + SoftLockAction(SoftLockingAction.LOCK, randomId, randomIdStates, ExpectedSoftLocks(flowIdStates + randomIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + SoftLockAction(SoftLockingAction.UNLOCK, null, flowIdStates, ExpectedSoftLocks(randomIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = EMPTY_SET) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(flowIdStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + assertEquals(randomIdStates, queryCashStates(QueryCriteria.SoftLockingType.LOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow reserves fungible states with flow id and random id, then releases the random id ones - assert the flow id ones are still locked inside the flow`() { + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toList() + val flowIdStates = vaultStates.subList(0, vaultStates.size / 2).toSet() + val randomIdStates = vaultStates.subList(vaultStates.size / 2, vaultStates.size).toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, flowIdStates, ExpectedSoftLocks(flowIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + SoftLockAction(SoftLockingAction.LOCK, randomId, randomIdStates, ExpectedSoftLocks(flowIdStates + randomIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + 
SoftLockAction(SoftLockingAction.UNLOCK, randomId, randomIdStates, ExpectedSoftLocks(flowIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(flowIdStates + randomIdStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + } + + @Test(timeout=300_000) + fun `flow soft locks fungible state upon creation`() { + var lockedStates = 0 + CreateFungibleStateFLow.hook = { vaultService -> + lockedStates = vaultService.queryBy<NodeVaultServiceTest.FungibleFoo>( + QueryCriteria.VaultQueryCriteria(softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.LOCKED_ONLY)) + ).states.size + } + aliceNode.services.startFlow(CreateFungibleStateFLow()).resultFuture.getOrThrow(30.seconds) + assertEquals(1, lockedStates) + } + + @Test(timeout=300_000) + fun `when flow soft locks, then errors and retries from previous checkpoint, softLockedStates are reverted back correctly`() { + val randomId = UUID.randomUUID() + val vaultStates = fillVault(aliceNode, 10)!!.states.map { it.ref }.toList() + val flowIdStates = vaultStates.subList(0, vaultStates.size / 2).toSet() + val randomIdStates = vaultStates.subList(vaultStates.size / 2, vaultStates.size).toSet() + val softLockActions = arrayOf( + SoftLockAction(SoftLockingAction.LOCK, null, flowIdStates, ExpectedSoftLocks(flowIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = flowIdStates), + SoftLockAction( + SoftLockingAction.LOCK, + randomId, + randomIdStates, + ExpectedSoftLocks(flowIdStates + randomIdStates, QueryCriteria.SoftLockingType.LOCKED_ONLY), + expectedSoftLockedStates = flowIdStates, + doCheckpoint = true + ), + SoftLockAction(SoftLockingAction.UNLOCK, null, flowIdStates, ExpectedSoftLocks(randomIdStates, 
QueryCriteria.SoftLockingType.LOCKED_ONLY), expectedSoftLockedStates = EMPTY_SET), + SoftLockAction( + SoftLockingAction.UNLOCK, + randomId, + randomIdStates, + ExpectedSoftLocks(EMPTY_SET, QueryCriteria.SoftLockingType.LOCKED_ONLY), + expectedSoftLockedStates = EMPTY_SET, + exception = SQLTransientConnectionException("connection is not available") + ) + ) + val flowCompleted = aliceNode.services.startFlow(LockingUnlockingFlow(softLockActions)).resultFuture.getOrThrow(30.seconds) + assertTrue(flowCompleted) + assertEquals(flowIdStates + randomIdStates, queryCashStates(QueryCriteria.SoftLockingType.UNLOCKED_ONLY, aliceNode.services.vaultService)) + LockingUnlockingFlow.throwOnlyOnce = true + } + + private fun fillVault(node: TestStartedNode, thisManyStates: Int): Vault<Cash.State>? { + val bankNode = mockNet.createPartyNode(BOC_NAME) + val bank = bankNode.info.singleIdentity() + val cashIssuer = bank.ref(1) + return node.database.transaction { + VaultFiller(node.services, TestIdentity(notaryIdentity.name, 20), notaryIdentity).fillWithSomeTestCash( + 100.DOLLARS, + bankNode.services, + thisManyStates, + thisManyStates, + cashIssuer + ) + } + } +} + +enum class SoftLockingAction { + LOCK, + UNLOCK +} + +data class ExpectedSoftLocks(val states: Set<StateRef>, val queryCriteria: QueryCriteria.SoftLockingType) + +/** + * If [lockId] is set to null, it will be populated with the flowId within the flow. + */ +data class SoftLockAction(val action: SoftLockingAction, + var lockId: UUID?, + val states: Set<StateRef>, + val expectedSoftLocks: ExpectedSoftLocks, + val expectedSoftLockedStates: Set<StateRef>, + val exception: Exception? 
= null, + val doCheckpoint: Boolean = false) + +internal class LockingUnlockingFlow(private val softLockActions: Array<SoftLockAction>): FlowLogic<Boolean>() { + + companion object { + var throwOnlyOnce = true + } + + @Suspendable + override fun call(): Boolean { + for (softLockAction in softLockActions) { + if (softLockAction.lockId == null) { softLockAction.lockId = stateMachine.id.uuid } + + when (softLockAction.action) { + SoftLockingAction.LOCK -> { + serviceHub.vaultService.softLockReserve(softLockAction.lockId!!, NonEmptySet.copyOf(softLockAction.states)) + // We checkpoint here so that, upon retrying to assert state after reserving + if (softLockAction.doCheckpoint) { + stateMachine.suspend(FlowIORequest.ForceCheckpoint, false) + } + assertEquals(softLockAction.expectedSoftLocks.states, queryCashStates(softLockAction.expectedSoftLocks.queryCriteria, serviceHub.vaultService)) + assertEquals(softLockAction.expectedSoftLockedStates, (stateMachine as? FlowStateMachineImpl<*>)!!.softLockedStates) + } + SoftLockingAction.UNLOCK -> { + serviceHub.vaultService.softLockRelease(softLockAction.lockId!!, NonEmptySet.copyOf(softLockAction.states)) + assertEquals(softLockAction.expectedSoftLocks.states, queryCashStates(softLockAction.expectedSoftLocks.queryCriteria, serviceHub.vaultService)) + assertEquals(softLockAction.expectedSoftLockedStates, (stateMachine as? FlowStateMachineImpl<*>)!!.softLockedStates) + } + } + + softLockAction.exception?.let { + if (throwOnlyOnce) { + throwOnlyOnce = false + throw it + } + } + } + return true + } +} + +internal class CreateFungibleStateFLow : FlowLogic<Unit>() { + + companion object { + var hook: ((VaultService) -> Unit)? 
= null + } + + @Suspendable + override fun call() { + val issuer = serviceHub.myInfo.legalIdentities.first() + val notary = serviceHub.networkMapCache.notaryIdentities[0] + val fungibleState = NodeVaultServiceTest.FungibleFoo(100.DOLLARS, listOf(issuer)) + val txCommand = Command(DummyContract.Commands.Create(), issuer.owningKey) + val txBuilder = TransactionBuilder(notary) + .addOutputState(fungibleState, DummyContract.PROGRAM_ID) + .addCommand(txCommand) + val signedTx = serviceHub.signInitialTransaction(txBuilder) + serviceHub.recordTransactions(signedTx) + hook?.invoke(serviceHub.vaultService) + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt index 7b7a051491..320137e8b4 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt @@ -406,6 +406,47 @@ class NodeVaultServiceTest { } } + @Test(timeout=300_000) + fun `softLockRelease - correctly releases n locked states`() { + fun queryStates(softLockingType: SoftLockingType) = + vaultService.queryBy<Cash.State>(VaultQueryCriteria(softLockingCondition = SoftLockingCondition(softLockingType))).states + + database.transaction { + vaultFiller.fillWithSomeTestCash(100.DOLLARS, issuerServices, 100, DUMMY_CASH_ISSUER) + } + + val softLockId = UUID.randomUUID() + val lockCount = NodeVaultService.MAX_SQL_IN_CLAUSE_SET * 2 + database.transaction { + assertEquals(100, queryStates(SoftLockingType.UNLOCKED_ONLY).size) + val unconsumedStates = vaultService.queryBy<Cash.State>().states + + val lockSet = mutableListOf<StateRef>() + for (i in 0 until lockCount) { + lockSet.add(unconsumedStates[i].ref) + } + vaultService.softLockReserve(softLockId, NonEmptySet.copyOf(lockSet)) + assertEquals(lockCount, queryStates(SoftLockingType.LOCKED_ONLY).size) + + val unlockSet0 = 
mutableSetOf<StateRef>() + for (i in 0 until NodeVaultService.MAX_SQL_IN_CLAUSE_SET + 1) { + unlockSet0.add(lockSet[i]) + } + vaultService.softLockRelease(softLockId, NonEmptySet.copyOf(unlockSet0)) + assertEquals(NodeVaultService.MAX_SQL_IN_CLAUSE_SET - 1, queryStates(SoftLockingType.LOCKED_ONLY).size) + + val unlockSet1 = mutableSetOf<StateRef>() + for (i in NodeVaultService.MAX_SQL_IN_CLAUSE_SET + 1 until NodeVaultService.MAX_SQL_IN_CLAUSE_SET + 3) { + unlockSet1.add(lockSet[i]) + } + vaultService.softLockRelease(softLockId, NonEmptySet.copyOf(unlockSet1)) + assertEquals(NodeVaultService.MAX_SQL_IN_CLAUSE_SET - 1 - 2, queryStates(SoftLockingType.LOCKED_ONLY).size) + + vaultService.softLockRelease(softLockId) // release the rest + assertEquals(100, queryStates(SoftLockingType.UNLOCKED_ONLY).size) + } + } + @Test(timeout=300_000) fun `unconsumedStatesForSpending exact amount`() { database.transaction { diff --git a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt index e9f91f7cbe..46e954ca06 100644 --- a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt +++ b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt @@ -19,7 +19,7 @@ import net.corda.testing.driver.* import net.corda.testing.node.TestCordapp import net.corda.testing.node.User import net.corda.testing.node.internal.FINANCE_CORDAPPS -import net.corda.testing.node.internal.assertCheckpoints +import net.corda.testing.node.internal.assertUncompletedCheckpoints import net.corda.testing.node.internal.poll import net.corda.traderdemo.flow.CommercialPaperIssueFlow import net.corda.traderdemo.flow.SellerFlow @@ -100,7 +100,7 @@ class TraderDemoTest { val saleFuture = seller.rpc.startFlow(::SellerFlow, buyer.nodeInfo.singleIdentity(), 5.DOLLARS).returnValue buyer.rpc.stateMachinesFeed().updates.toBlocking().first() // 
wait until initiated flow starts buyer.stop() - assertCheckpoints(DUMMY_BANK_A_NAME, 1) + assertUncompletedCheckpoints(DUMMY_BANK_A_NAME, 1) val buyer2 = startNode(providedName = DUMMY_BANK_A_NAME, customOverrides = mapOf("p2pAddress" to buyer.p2pAddress.toString())).getOrThrow() saleFuture.getOrThrow() assertThat(buyer2.rpc.getCashBalance(USD)).isEqualTo(95.DOLLARS) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt index aea0e9d5d0..857a171373 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt @@ -638,6 +638,7 @@ private fun mockNodeConfiguration(certificatesDirectory: Path): NodeConfiguratio doReturn(NetworkParameterAcceptanceSettings()).whenever(it).networkParameterAcceptanceSettings doReturn(rigorousMock<ConfigurationWithOptions>()).whenever(it).configurationWithOptions doReturn(2).whenever(it).flowExternalOperationThreadPoolSize + doReturn(StateMachineManager.StartMode.ExcludingPaused).whenever(it).smmStartMode } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt index 0793a740ba..1f4b3ab632 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt @@ -23,6 +23,7 @@ import net.corda.core.utilities.millis import net.corda.core.utilities.seconds import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.messaging.Message +import net.corda.node.services.statemachine.Checkpoint import net.corda.testing.driver.DriverDSL import net.corda.testing.driver.NodeHandle import 
net.corda.testing.internal.chooseIdentity @@ -273,9 +274,10 @@ fun CordaRPCOps.waitForShutdown(): Observable<Unit> { return completable } -fun DriverDSL.assertCheckpoints(name: CordaX500Name, expected: Long) { +fun DriverDSL.assertUncompletedCheckpoints(name: CordaX500Name, expected: Long) { + val sqlStatement = "select count(*) from node_checkpoints where status not in (${Checkpoint.FlowStatus.COMPLETED.ordinal})" DriverManager.getConnection("jdbc:h2:file:${baseDirectory(name) / "persistence"}", "sa", "").use { connection -> - connection.createStatement().executeQuery("select count(*) from NODE_CHECKPOINTS").use { rs -> + connection.createStatement().executeQuery(sqlStatement).use { rs -> rs.next() assertThat(rs.getLong(1)).isEqualTo(expected) }