diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt index 35a1639714..e2ad9d69b2 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/FlowIsKilledTest.kt @@ -15,7 +15,6 @@ import net.corda.core.messaging.startFlow import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.minutes import net.corda.core.utilities.seconds -import net.corda.node.services.statemachine.Checkpoint import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.CHARLIE_NAME @@ -77,9 +76,10 @@ class FlowIsKilledTest { assertEquals(11, AFlowThatWantsToDieAndKillsItsFriends.position) assertTrue(AFlowThatWantsToDieAndKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatWantsToDieAndKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) - assertEquals(1, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, aliceCheckpoints) + val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, bobCheckpoints) } } } @@ -109,9 +109,10 @@ class FlowIsKilledTest { handle.returnValue.getOrThrow(1.minutes) } assertEquals(11, AFlowThatGetsMurderedByItsFriendResponder.position) - assertEquals(2, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, alice.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, aliceCheckpoints) + val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, bobCheckpoints) } } @@ -360,18 +361,4 @@ class FlowIsKilledTest { } } } - - @StartableByRPC - class GetNumberOfFailedCheckpointsFlow : FlowLogic() { - override fun call(): Long { - return serviceHub.jdbcSession() - .prepareStatement("select count(*) from node_checkpoints where status = ${Checkpoint.FlowStatus.FAILED.ordinal}") - .use { ps -> - ps.executeQuery().use { rs -> - rs.next() - rs.getLong(1) - } - } - } - } } \ No newline at end of file diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/WithFinality.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/WithFinality.kt index 5e1daa8a09..9ed9b04679 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/WithFinality.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/WithFinality.kt @@ -6,7 +6,7 @@ import com.natpryce.hamkrest.Matcher import com.natpryce.hamkrest.equalTo import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.startFlow @@ -16,7 +16,7 @@ import net.corda.testing.node.internal.TestStartedNode interface WithFinality : WithMockNet { //region Operations - fun TestStartedNode.finalise(stx: SignedTransaction, vararg recipients: Party): FlowStateMachine { + fun TestStartedNode.finalise(stx: SignedTransaction, vararg recipients: Party): FlowStateMachineHandle { return startFlowAndRunNetwork(FinalityInvoker(stx, recipients.toSet(), emptySet())) } diff --git a/core-tests/src/test/kotlin/net/corda/coretests/flows/WithMockNet.kt b/core-tests/src/test/kotlin/net/corda/coretests/flows/WithMockNet.kt index 8069b6d807..4a4574112e 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/flows/WithMockNet.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/flows/WithMockNet.kt @@ -6,7 +6,7 @@ import net.corda.core.flows.FlowLogic import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.testing.core.makeUnique @@ -48,12 +48,12 @@ interface WithMockNet { /** * Start a flow */ - fun TestStartedNode.startFlow(logic: FlowLogic): FlowStateMachine = services.startFlow(logic) + fun TestStartedNode.startFlow(logic: FlowLogic): FlowStateMachineHandle = services.startFlow(logic) /** * Start a flow and run the network immediately afterwards */ - fun TestStartedNode.startFlowAndRunNetwork(logic: FlowLogic): FlowStateMachine = + fun TestStartedNode.startFlowAndRunNetwork(logic: FlowLogic): FlowStateMachineHandle = startFlow(logic).andRunNetwork() fun TestStartedNode.createConfidentialIdentity(party: Party) = diff --git a/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt b/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt index ef90810b05..06e9210801 100644 --- a/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt +++ b/core/src/main/kotlin/net/corda/core/context/InvocationContext.kt @@ -24,7 +24,8 @@ data class InvocationContext( val actor: Actor?, val externalTrace: Trace? = null, val impersonatedActor: Actor? = null, - val arguments: List = emptyList() + val arguments: List? = emptyList(), // 'arguments' is nullable so that a - >= 4.6 version - RPC client can be backwards compatible against - < 4.6 version - nodes + val clientId: String? = null ) { constructor( @@ -49,8 +50,9 @@ data class InvocationContext( actor: Actor? = null, externalTrace: Trace? = null, impersonatedActor: Actor? = null, - arguments: List = emptyList() - ) = InvocationContext(origin, trace, actor, externalTrace, impersonatedActor, arguments) + arguments: List = emptyList(), + clientId: String? = null + ) = InvocationContext(origin, trace, actor, externalTrace, impersonatedActor, arguments, clientId) /** * Creates an [InvocationContext] with [InvocationOrigin.RPC] origin. @@ -113,7 +115,8 @@ data class InvocationContext( actor = actor, externalTrace = externalTrace, impersonatedActor = impersonatedActor, - arguments = arguments + arguments = arguments, + clientId = clientId ) } } diff --git a/core/src/main/kotlin/net/corda/core/flows/ResultSerializationException.kt b/core/src/main/kotlin/net/corda/core/flows/ResultSerializationException.kt new file mode 100644 index 0000000000..34e463d1ac --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/ResultSerializationException.kt @@ -0,0 +1,11 @@ +package net.corda.core.flows + +import net.corda.core.CordaRuntimeException +import net.corda.core.serialization.internal.MissingSerializerException + +/** + * Thrown whenever a flow result cannot be serialized when attempting to save it in the database + */ +class ResultSerializationException private constructor(message: String?) : CordaRuntimeException(message) { + constructor(e: MissingSerializerException): this(e.message) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index c8a96da1cd..42db120f36 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -11,10 +11,19 @@ import net.corda.core.node.ServiceHub import net.corda.core.serialization.SerializedBytes import org.slf4j.Logger +@DeleteForDJVM +@DoNotImplement +interface FlowStateMachineHandle { + val logic: FlowLogic? + val id: StateMachineRunId + val resultFuture: CordaFuture + val clientId: String? +} + /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ @DeleteForDJVM @DoNotImplement -interface FlowStateMachine { +interface FlowStateMachine : FlowStateMachineHandle { @Suspendable fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): SUSPENDRETURN @@ -38,14 +47,11 @@ interface FlowStateMachine { fun updateTimedFlowTimeout(timeoutSeconds: Long) - val logic: FlowLogic val serviceHub: ServiceHub val logger: Logger - val id: StateMachineRunId - val resultFuture: CordaFuture val context: InvocationContext val ourIdentity: Party val ourSenderUUID: String? val creationTime: Long val isKilled: Boolean -} +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt index 6098b0c707..826f52f0d9 100644 --- a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt +++ b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt @@ -264,6 +264,25 @@ interface CordaRPCOps : RPCOps { @RPCReturnsObservables fun startFlowDynamic(logicType: Class>, vararg args: Any?): FlowHandle + /** + * Start the given flow with the given arguments and a [clientId]. + * + * The flow's result/ exception will be available for the client to re-connect and retrieve even after the flow's lifetime, + * by re-calling [startFlowDynamicWithClientId] with the same [clientId]. The [logicType] and [args] will be ignored if the + * [clientId] matches an existing flow. If you don't have the original values, consider using [reattachFlowWithClientId]. + * + * Upon calling [removeClientId], the node's resources holding the result/ exception will be freed and the result/ exception will + * no longer be available. + * + * [logicType] must be annotated with [net.corda.core.flows.StartableByRPC]. + * + * @param clientId The client id to relate the flow to (or is already related to if the flow already exists) + * @param logicType The [FlowLogic] to start + * @param args The arguments to pass to the flow + */ + @RPCReturnsObservables + fun startFlowDynamicWithClientId(clientId: String, logicType: Class>, vararg args: Any?): FlowHandleWithClientId + /** * Start the given flow with the given arguments, returning an [Observable] with a single observation of the * result of running the flow. [logicType] must be annotated with [net.corda.core.flows.StartableByRPC]. @@ -278,6 +297,30 @@ interface CordaRPCOps : RPCOps { */ fun killFlow(id: StateMachineRunId): Boolean + /** + * Reattach to an existing flow that was started with [startFlowDynamicWithClientId] and has a [clientId]. + * + * If there is a flow matching the [clientId] then its result or exception is returned. + * + * When there is no flow matching the [clientId] then [null] is returned directly (not a future/[FlowHandleWithClientId]). + * + * Calling [reattachFlowWithClientId] after [removeClientId] with the same [clientId] will cause the function to return [null] as + * the result/exception of the flow will no longer be available. + * + * @param clientId The client id relating to an existing flow + */ + @RPCReturnsObservables + fun reattachFlowWithClientId(clientId: String): FlowHandleWithClientId? + + /** + * Removes a flow's [clientId] to result/ exception mapping. If the mapping is of a running flow, then the mapping will not get removed. + * + * See [startFlowDynamicWithClientId] for more information. + * + * @return whether the mapping was removed. + */ + fun removeClientId(clientId: String): Boolean + /** Returns Node's NodeInfo, assuming this will not change while the node is running. */ fun nodeInfo(): NodeInfo @@ -542,6 +585,79 @@ inline fun > CordaRPCOps.startFlow arg5: F ): FlowHandle = startFlowDynamic(R::class.java, arg0, arg1, arg2, arg3, arg4, arg5) +/** + * Extension function for type safe invocation of flows from Kotlin, with [clientId]. + */ +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: () -> R +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A) -> R, + arg0: A +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A, B) -> R, + arg0: A, + arg1: B +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0, arg1) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A, B, C) -> R, + arg0: A, + arg1: B, + arg2: C +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0, arg1, arg2) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A, B, C, D) -> R, + arg0: A, + arg1: B, + arg2: C, + arg3: D +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0, arg1, arg2, arg3) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A, B, C, D, E) -> R, + arg0: A, + arg1: B, + arg2: C, + arg3: D, + arg4: E +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0, arg1, arg2, arg3, arg4) + +@Suppress("unused") +inline fun > CordaRPCOps.startFlowWithClientId( + clientId: String, + @Suppress("unused_parameter") + flowConstructor: (A, B, C, D, E, F) -> R, + arg0: A, + arg1: B, + arg2: C, + arg3: D, + arg4: E, + arg5: F +): FlowHandleWithClientId = startFlowDynamicWithClientId(clientId, R::class.java, arg0, arg1, arg2, arg3, arg4, arg5) + /** * Extension function for type safe invocation of flows from Kotlin, with progress tracking enabled. */ diff --git a/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt b/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt index 4d540d69c8..88bff4fe6d 100644 --- a/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt +++ b/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt @@ -28,6 +28,14 @@ interface FlowHandle : AutoCloseable { override fun close() } +interface FlowHandleWithClientId : FlowHandle { + + /** + * The [clientId] with which the client has started the flow. + */ + val clientId: String +} + /** * [FlowProgressHandle] is a serialisable handle for the started flow, parameterised by the type of the flow's return value. */ @@ -66,6 +74,18 @@ data class FlowHandleImpl( } } +@CordaSerializable +data class FlowHandleWithClientIdImpl( + override val id: StateMachineRunId, + override val returnValue: CordaFuture, + override val clientId: String) : FlowHandleWithClientId { + + // Remember to add @Throws to FlowHandle.close() if this throws an exception. + override fun close() { + returnValue.cancel(false) + } +} + @CordaSerializable data class FlowProgressHandleImpl @JvmOverloads constructor( override val id: StateMachineRunId, diff --git a/detekt-baseline.xml b/detekt-baseline.xml index 401dfbe681..cf8fc1beac 100644 --- a/detekt-baseline.xml +++ b/detekt-baseline.xml @@ -640,6 +640,9 @@ LongParameterList:CordaRPCOps.kt$( @Suppress("UNUSED_PARAMETER") flowConstructor: (A, B, C, D, E, F) -> R, arg0: A, arg1: B, arg2: C, arg3: D, arg4: E, arg5: F ) LongParameterList:CordaRPCOps.kt$( @Suppress("unused_parameter") flowConstructor: (A, B, C, D, E) -> R, arg0: A, arg1: B, arg2: C, arg3: D, arg4: E ) LongParameterList:CordaRPCOps.kt$( @Suppress("unused_parameter") flowConstructor: (A, B, C, D, E, F) -> R, arg0: A, arg1: B, arg2: C, arg3: D, arg4: E, arg5: F ) + LongParameterList:CordaRPCOps.kt$( clientId: String, @Suppress("unused_parameter") flowConstructor: (A, B, C, D) -> R, arg0: A, arg1: B, arg2: C, arg3: D ) + LongParameterList:CordaRPCOps.kt$( clientId: String, @Suppress("unused_parameter") flowConstructor: (A, B, C, D, E) -> R, arg0: A, arg1: B, arg2: C, arg3: D, arg4: E ) + LongParameterList:CordaRPCOps.kt$( clientId: String, @Suppress("unused_parameter") flowConstructor: (A, B, C, D, E, F) -> R, arg0: A, arg1: B, arg2: C, arg3: D, arg4: E, arg5: F ) LongParameterList:Driver.kt$DriverParameters$( isDebug: Boolean, driverDirectory: Path, portAllocation: PortAllocation, debugPortAllocation: PortAllocation, systemProperties: Map<String, String>, useTestClock: Boolean, startNodesInProcess: Boolean, waitForAllNodesToFinish: Boolean, notarySpecs: List<NotarySpec>, extraCordappPackagesToScan: List<String>, jmxPolicy: JmxPolicy, networkParameters: NetworkParameters ) LongParameterList:Driver.kt$DriverParameters$( isDebug: Boolean, driverDirectory: Path, portAllocation: PortAllocation, debugPortAllocation: PortAllocation, systemProperties: Map<String, String>, useTestClock: Boolean, startNodesInProcess: Boolean, waitForAllNodesToFinish: Boolean, notarySpecs: List<NotarySpec>, extraCordappPackagesToScan: List<String>, jmxPolicy: JmxPolicy, networkParameters: NetworkParameters, cordappsForAllNodes: Set<TestCordapp>? ) LongParameterList:DriverDSL.kt$DriverDSL$( defaultParameters: NodeParameters = NodeParameters(), providedName: CordaX500Name? = defaultParameters.providedName, rpcUsers: List<User> = defaultParameters.rpcUsers, verifierType: VerifierType = defaultParameters.verifierType, customOverrides: Map<String, Any?> = defaultParameters.customOverrides, startInSameProcess: Boolean? = defaultParameters.startInSameProcess, maximumHeapSize: String = defaultParameters.maximumHeapSize ) @@ -1261,7 +1264,6 @@ SpreadOperator:ConfigUtilities.kt$(*pairs) SpreadOperator:Configuration.kt$Configuration.Validation.Error$(*(containingPath.toList() + this.containingPath).toTypedArray()) SpreadOperator:ContractJarTestUtils.kt$ContractJarTestUtils$(jarName, *contractNames.map{ "${it.replace(".", "/")}.class" }.toTypedArray()) - SpreadOperator:CordaRPCOpsImpl.kt$CordaRPCOpsImpl$(logicType, context(), *args) SpreadOperator:CordaX500Name.kt$CordaX500Name.Companion$(*Locale.getISOCountries(), unspecifiedCountry) SpreadOperator:CustomCordapp.kt$CustomCordapp$(*classes.map { it.name }.toTypedArray()) SpreadOperator:CustomCordapp.kt$CustomCordapp$(*packages.map { it.replace('.', '/') }.toTypedArray()) diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowReloadAfterCheckpointTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowReloadAfterCheckpointTest.kt index 981fcc3ba3..dd51ad621d 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowReloadAfterCheckpointTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowReloadAfterCheckpointTest.kt @@ -519,4 +519,8 @@ class FlowReloadAfterCheckpointTest { stateMachine.suspend(FlowIORequest.ForceCheckpoint, false) } } +} + +internal class BrokenMap(delegate: MutableMap = mutableMapOf()) : MutableMap by delegate { + override fun put(key: K, value: V): V? = throw IllegalStateException("Broken on purpose") } \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt index 1dda43c691..499dcfd232 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowRetryTest.kt @@ -161,7 +161,7 @@ class FlowRetryTest { } @Test(timeout = 300_000) - fun `General external exceptions are not retried and propagate`() { + fun `general external exceptions are not retried and propagate`() { driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList(), cordappsForAllNodes = cordapps)) { val (nodeAHandle, nodeBHandle) = listOf(ALICE_NAME, BOB_NAME) @@ -176,10 +176,7 @@ class FlowRetryTest { ).returnValue.getOrThrow() } assertEquals(0, GeneralExternalFailureFlow.retryCount) - assertEquals( - 1, - nodeAHandle.rpc.startFlow(::GetCheckpointNumberOfStatusFlow, Checkpoint.FlowStatus.FAILED).returnValue.get() - ) + assertEquals(0, nodeAHandle.rpc.startFlow(::GetCheckpointNumberOfStatusFlow, Checkpoint.FlowStatus.FAILED).returnValue.get()) } } @@ -304,10 +301,6 @@ enum class Step { First, BeforeInitiate, AfterInitiate, AfterInitiateSendReceive data class Visited(val sessionNum: Int, val iterationNum: Int, val step: Step) -class BrokenMap(delegate: MutableMap = mutableMapOf()) : MutableMap by delegate { - override fun put(key: K, value: V): V? = throw IllegalStateException("Broken on purpose") -} - @StartableByRPC class RetryFlow() : FlowLogic(), IdempotentFlow { companion object { diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt new file mode 100644 index 0000000000..ec1bff03e5 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowWithClientIdTest.kt @@ -0,0 +1,174 @@ +package net.corda.node.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.CordaRuntimeException +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.ResultSerializationException +import net.corda.core.flows.StartableByRPC +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.messaging.startFlowWithClientId +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import org.assertj.core.api.Assertions +import org.junit.Before +import org.junit.Test +import rx.Observable +import java.util.UUID +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue + +class FlowWithClientIdTest { + + @Before + fun reset() { + ResultFlow.hook = null + } + + @Test(timeout=300_000) + fun `start flow with client id`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode().getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5) + + assertEquals(5, flowHandle.returnValue.getOrThrow(20.seconds)) + assertEquals(clientId, flowHandle.clientId) + } + } + + @Test(timeout=300_000) + fun `remove client id`() { + val clientId = UUID.randomUUID().toString() + var counter = 0 + ResultFlow.hook = { counter++ } + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode().getOrThrow() + + val flowHandle0 = nodeA.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5) + flowHandle0.returnValue.getOrThrow(20.seconds) + + val removed = nodeA.rpc.removeClientId(clientId) + + val flowHandle1 = nodeA.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5) + flowHandle1.returnValue.getOrThrow(20.seconds) + + assertTrue(removed) + assertNotEquals(flowHandle0.id, flowHandle1.id) + assertEquals(flowHandle0.clientId, flowHandle1.clientId) + assertEquals(2, counter) // this asserts that 2 different flows were spawned indeed + } + } + + @Test(timeout=300_000) + fun `on flow unserializable result a 'CordaRuntimeException' is thrown containing in its message the unserializable type`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode().getOrThrow() + + val e = assertFailsWith { + nodeA.rpc.startFlowWithClientId(clientId, ::UnserializableResultFlow).returnValue.getOrThrow(20.seconds) + } + + val errorMessage = e.message + assertTrue(errorMessage!!.contains("Unable to create an object serializer for type class ${UnserializableResultFlow.UNSERIALIZABLE_OBJECT::class.java.name}")) + } + } + + @Test(timeout=300_000) + fun `If flow has an unserializable exception result then it gets converted into a 'CordaRuntimeException'`() { + ResultFlow.hook = { + throw UnserializableException() + } + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val node = startNode().getOrThrow() + + // the below exception is the one populating the flows future. It will get serialized on node jvm, sent over to client and + // deserialized on client's. + val e0 = assertFailsWith { + node.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5).returnValue.getOrThrow() + } + + // the below exception is getting fetched from the database first, and deserialized on node's jvm, + // then serialized on node jvm, sent over to client and deserialized on client's. + val e1 = assertFailsWith { + node.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5).returnValue.getOrThrow() + } + + assertTrue(e0 !is UnserializableException) + assertTrue(e1 !is UnserializableException) + assertEquals(UnserializableException::class.java.name, e0.originalExceptionClassName) + assertEquals(UnserializableException::class.java.name, e1.originalExceptionClassName) + } + } + + @Test(timeout=300_000) + fun `reattachFlowWithClientId can retrieve results from existing flow future`() { + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode().getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5) + val reattachedFlowHandle = nodeA.rpc.reattachFlowWithClientId(clientId) + assertEquals(5, flowHandle.returnValue.getOrThrow(20.seconds)) + assertEquals(clientId, flowHandle.clientId) + assertEquals(flowHandle.id, reattachedFlowHandle?.id) + assertEquals(flowHandle.returnValue.get(), reattachedFlowHandle?.returnValue?.get()) + } + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve exception from existing flow future`() { + ResultFlow.hook = { throw IllegalStateException("Bla bla bla") } + val clientId = UUID.randomUUID().toString() + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = emptySet())) { + val nodeA = startNode().getOrThrow() + val flowHandle = nodeA.rpc.startFlowWithClientId(clientId, ::ResultFlow, 5) + val reattachedFlowHandle = nodeA.rpc.reattachFlowWithClientId(clientId) + + // [CordaRunTimeException] returned because [IllegalStateException] is not serializable + Assertions.assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + flowHandle.returnValue.getOrThrow(20.seconds) + }.withMessage("java.lang.IllegalStateException: Bla bla bla") + + Assertions.assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + reattachedFlowHandle?.returnValue?.getOrThrow() + }.withMessage("java.lang.IllegalStateException: Bla bla bla") + } + } +} + +@StartableByRPC +internal class ResultFlow(private val result: A): FlowLogic() { + companion object { + var hook: (() -> Unit)? = null + var suspendableHook: FlowLogic? = null + } + + @Suspendable + override fun call(): A { + hook?.invoke() + suspendableHook?.let { subFlow(it) } + return result + } +} + +@StartableByRPC +internal class UnserializableResultFlow: FlowLogic>>() { + companion object { + val UNSERIALIZABLE_OBJECT = openFuture>().also { it.set(Observable.empty())} + } + + @Suspendable + override fun call(): OpenFuture> { + return UNSERIALIZABLE_OBJECT + } +} + +internal class UnserializableException( + val unserializableObject: BrokenMap = BrokenMap() +): CordaRuntimeException("123") \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt index 9af99f30f1..f29248d161 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/KillFlowTest.kt @@ -26,7 +26,6 @@ import net.corda.core.utilities.seconds import net.corda.finance.DOLLARS import net.corda.finance.contracts.asset.Cash import net.corda.finance.flows.CashIssueFlow -import net.corda.node.services.statemachine.Checkpoint import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.CHARLIE_NAME @@ -62,7 +61,8 @@ class KillFlowTest { assertFailsWith { handle.returnValue.getOrThrow(1.minutes) } - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, checkpoints) } } } @@ -89,11 +89,12 @@ class KillFlowTest { AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.locks.forEach { it.value.acquire() } assertTrue(AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedWhenItTriesToSuspendAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val aliceCheckpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, aliceCheckpoints) + val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, bobCheckpoints) + val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, charlieCheckpoints) } } } @@ -113,7 +114,8 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, checkpoints) } } } @@ -151,7 +153,8 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - assertEquals(1, startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val checkpoints = startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, checkpoints) } @Test(timeout = 300_000) @@ -169,7 +172,8 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, checkpoints) } } } @@ -189,7 +193,8 @@ class KillFlowTest { } assertTrue(time < 1.minutes.toMillis(), "It should at a minimum, take less than a minute to kill this flow") assertTrue(time < 5.seconds.toMillis(), "Really, it should take less than a few seconds to kill a flow") - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val checkpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, checkpoints) } } } @@ -219,11 +224,12 @@ class KillFlowTest { } assertTrue(AFlowThatGetsMurderedAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedAndSomehowKillsItsFriendsResponder.receivedKilledExceptions[CHARLIE_NAME]!!) - assertEquals(1, rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, bob.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val aliceCheckpoints = rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, aliceCheckpoints) + val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, bobCheckpoints) + val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, charlieCheckpoints) } } } @@ -253,11 +259,12 @@ class KillFlowTest { assertTrue(AFlowThatGetsMurderedByItsFriend.receivedKilledException) assertFalse(AFlowThatGetsMurderedByItsFriendResponder.receivedKilledExceptions[BOB_NAME]!!) assertTrue(AFlowThatGetsMurderedByItsFriendResponder.receivedKilledExceptions[CHARLIE_NAME]!!) - assertEquals(2, alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, alice.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(2, charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds)) - assertEquals(1, charlie.rpc.startFlow(::GetNumberOfFailedCheckpointsFlow).returnValue.getOrThrow(20.seconds)) + val aliceCheckpoints = alice.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, aliceCheckpoints) + val bobCheckpoints = bob.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, bobCheckpoints) + val charlieCheckpoints = charlie.rpc.startFlow(::GetNumberOfCheckpointsFlow).returnValue.getOrThrow(20.seconds) + assertEquals(1, charlieCheckpoints) } } @@ -590,18 +597,4 @@ class KillFlowTest { } } } - - @StartableByRPC - class GetNumberOfFailedCheckpointsFlow : FlowLogic() { - override fun call(): Long { - return serviceHub.jdbcSession() - .prepareStatement("select count(*) from node_checkpoints where status = ${Checkpoint.FlowStatus.FAILED.ordinal}") - .use { ps -> - ps.executeQuery().use { rs -> - rs.next() - rs.getLong(1) - } - } - } - } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 1ff856ddef..3c22064c9c 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -28,7 +28,7 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.AttachmentTrustCalculator -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.internal.NODE_INFO_DIRECTORY import net.corda.core.internal.NamedCacheFactory import net.corda.core.internal.NetworkParametersStorage @@ -351,7 +351,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val checkpointStorage = DBCheckpointStorage(DBCheckpointPerformanceRecorder(services.monitoringService.metrics), platformClock) @Suppress("LeakingThis") val smm = makeStateMachineManager() - val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory) + val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory, DBCheckpointStorage.MAX_CLIENT_ID_LENGTH) private val schedulerService = NodeSchedulerService( platformClock, database, @@ -1374,13 +1374,22 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } // TODO Move this into its own file -class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { - override fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> { - smm.deliverExternalEvent(event) +class FlowStarterImpl( + private val smm: StateMachineManager, + private val flowLogicRefFactory: FlowLogicRefFactory, + private val maxClientIdLength: Int +) : FlowStarter { + override fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> { + val clientId = event.context.clientId + if (clientId != null && clientId.length > maxClientIdLength) { + throw IllegalArgumentException("clientId cannot be longer than $maxClientIdLength characters") + } else { + smm.deliverExternalEvent(event) + } return event.future } - override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { + override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { val startFlowEvent = object : ExternalEvent.ExternalStartFlowEvent, DeduplicationHandler { override fun insideDatabaseTransaction() {} @@ -1397,12 +1406,12 @@ class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogi override val context: InvocationContext get() = context - override fun wireUpFuture(flowFuture: CordaFuture>) { + override fun wireUpFuture(flowFuture: CordaFuture>) { _future.captureLater(flowFuture) } - private val _future = openFuture>() - override val future: CordaFuture> + private val _future = openFuture>() + override val future: CordaFuture> get() = _future } return startFlow(startFlowEvent) @@ -1411,7 +1420,7 @@ class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogi override fun invokeFlowAsync( logicType: Class>, context: InvocationContext, - vararg args: Any?): CordaFuture> { + vararg args: Any?): CordaFuture> { val logicRef = flowLogicRefFactory.createForRPC(logicType, *args) val logic: FlowLogic = uncheckedCast(flowLogicRefFactory.toFlowLogic(logicRef)) return startFlow(logic, context) diff --git a/node/src/main/kotlin/net/corda/node/internal/AppServiceHubImpl.kt b/node/src/main/kotlin/net/corda/node/internal/AppServiceHubImpl.kt index 9c90173cb2..922316045b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AppServiceHubImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AppServiceHubImpl.kt @@ -3,7 +3,7 @@ package net.corda.node.internal import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByService -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowHandleImpl @@ -78,7 +78,7 @@ internal class AppServiceHubImpl(private val serviceHub: S return FlowProgressHandleImpl( id = stateMachine.id, returnValue = stateMachine.resultFuture, - progress = stateMachine.logic.track()?.updates ?: Observable.empty() + progress = stateMachine.logic?.track()?.updates ?: Observable.empty() ) } @@ -95,7 +95,7 @@ internal class AppServiceHubImpl(private val serviceHub: S } } - private fun startFlowChecked(flow: FlowLogic): FlowStateMachine { + private fun startFlowChecked(flow: FlowLogic): FlowStateMachineHandle { val logicType = flow.javaClass require(logicType.isAnnotationPresent(StartableByService::class.java)) { "${logicType.name} was not designed for starting by a CordaService" } // TODO check service permissions diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index 6d058aaf37..1a758a4050 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -19,7 +19,7 @@ import net.corda.core.identity.AbstractParty import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.AttachmentTrustInfo -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.internal.RPC_UPLOADER import net.corda.core.internal.STRUCTURAL_STEP_PREFIX import net.corda.core.internal.messaging.InternalCordaRPCOps @@ -27,6 +27,8 @@ import net.corda.core.internal.sign import net.corda.core.messaging.DataFeed import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.FlowHandleImpl +import net.corda.core.messaging.FlowHandleWithClientId +import net.corda.core.messaging.FlowHandleWithClientIdImpl import net.corda.core.messaging.FlowProgressHandle import net.corda.core.messaging.FlowProgressHandleImpl import net.corda.core.messaging.ParametersUpdateInfo @@ -170,6 +172,14 @@ internal class CordaRPCOpsImpl( override fun killFlow(id: StateMachineRunId): Boolean = smm.killFlow(id) + override fun reattachFlowWithClientId(clientId: String): FlowHandleWithClientId? { + return smm.reattachFlowWithClientId(clientId)?.run { + FlowHandleWithClientIdImpl(id = id, returnValue = resultFuture, clientId = clientId) + } + } + + override fun removeClientId(clientId: String): Boolean = smm.removeClientId(clientId) + override fun stateMachinesFeed(): DataFeed, StateMachineUpdate> { val (allStateMachines, changes) = smm.track() @@ -236,27 +246,38 @@ internal class CordaRPCOpsImpl( } override fun startTrackedFlowDynamic(logicType: Class>, vararg args: Any?): FlowProgressHandle { - val stateMachine = startFlow(logicType, args) + val stateMachine = startFlow(logicType, context(), args) return FlowProgressHandleImpl( id = stateMachine.id, returnValue = stateMachine.resultFuture, - progress = stateMachine.logic.track()?.updates?.filter { !it.startsWith(STRUCTURAL_STEP_PREFIX) } ?: Observable.empty(), - stepsTreeIndexFeed = stateMachine.logic.trackStepsTreeIndex(), - stepsTreeFeed = stateMachine.logic.trackStepsTree() + progress = stateMachine.logic?.track()?.updates?.filter { !it.startsWith(STRUCTURAL_STEP_PREFIX) } ?: Observable.empty(), + stepsTreeIndexFeed = stateMachine.logic?.trackStepsTreeIndex(), + stepsTreeFeed = stateMachine.logic?.trackStepsTree() ) } override fun startFlowDynamic(logicType: Class>, vararg args: Any?): FlowHandle { - val stateMachine = startFlow(logicType, args) + val stateMachine = startFlow(logicType, context(), args) return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture) } - private fun startFlow(logicType: Class>, args: Array): FlowStateMachine { + override fun startFlowDynamicWithClientId( + clientId: String, + logicType: Class>, + vararg args: Any? + ): FlowHandleWithClientId { + return startFlow(logicType, context().withClientId(clientId), args).run { + FlowHandleWithClientIdImpl(id = id, returnValue = resultFuture, clientId = clientId) + } + } + + @Suppress("SpreadOperator") + private fun startFlow(logicType: Class>, context: InvocationContext, args: Array): FlowStateMachineHandle { if (!logicType.isAnnotationPresent(StartableByRPC::class.java)) throw NonRpcFlowException(logicType) if (isFlowsDrainingModeEnabled()) { throw RejectedCommandException("Node is draining before shutdown. Cannot start new flows through RPC.") } - return flowStarter.invokeFlowAsync(logicType, context(), *args).getOrThrow() + return flowStarter.invokeFlowAsync(logicType, context, *args).getOrThrow() } override fun attachmentExists(id: SecureHash): Boolean { @@ -464,4 +485,6 @@ internal class CordaRPCOpsImpl( private inline fun Class<*>.checkIsA() { require(TARGET::class.java.isAssignableFrom(this)) { "$name is not a ${TARGET::class.java.name}" } } + + private fun InvocationContext.withClientId(clientId: String) = copy(clientId = clientId) } diff --git a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt index c7624266e7..49ea860589 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/CheckpointStorage.kt @@ -4,6 +4,7 @@ import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.SerializedBytes import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.CheckpointState +import net.corda.node.services.statemachine.FlowResultMetadata import net.corda.node.services.statemachine.FlowState import java.util.stream.Stream @@ -41,9 +42,12 @@ interface CheckpointStorage { /** * Remove existing checkpoint from the store. + * + * [mayHavePersistentResults] is used for optimization. If set to [false] it will not attempt to delete the database result or the database exception. + * Please note that if there is a doubt on whether a flow could be finished or not [mayHavePersistentResults] should be set to [true]. * @return whether the id matched a checkpoint that was removed. */ - fun removeCheckpoint(id: StateMachineRunId): Boolean + fun removeCheckpoint(id: StateMachineRunId, mayHavePersistentResults: Boolean = true): Boolean /** * Load an existing checkpoint from the store. @@ -75,4 +79,20 @@ interface CheckpointStorage { * This method does not fetch [Checkpoint.Serialized.serializedFlowState] to save memory. */ fun getPausedCheckpoints(): Stream> + + fun getFinishedFlowsResultsMetadata(): Stream> + + /** + * Load a flow result from the store. If [throwIfMissing] is true then it throws an [IllegalStateException] + * if the flow result is missing in the database. + */ + fun getFlowResult(id: StateMachineRunId, throwIfMissing: Boolean = false): Any? + + /** + * Load a flow exception from the store. If [throwIfMissing] is true then it throws an [IllegalStateException] + * if the flow exception is missing in the database. + */ + fun getFlowException(id: StateMachineRunId, throwIfMissing: Boolean = false): Any? + + fun removeFlowException(id: StateMachineRunId): Boolean } diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 6fa3ed5869..273a95dfaa 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -215,13 +215,13 @@ interface FlowStarter { * just synthesizes an [ExternalEvent.ExternalStartFlowEvent] and calls the method below. * @param context indicates who started the flow, see: [InvocationContext]. */ - fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> + fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> /** * Starts a flow as described by an [ExternalEvent.ExternalStartFlowEvent]. If a transient error * occurs during invocation, it will re-attempt to start the flow. */ - fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> + fun startFlow(event: ExternalEvent.ExternalStartFlowEvent): CordaFuture> /** * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the flow. @@ -232,9 +232,10 @@ interface FlowStarter { * [logicType] or [args]. */ fun invokeFlowAsync( - logicType: Class>, - context: InvocationContext, - vararg args: Any?): CordaFuture> + logicType: Class>, + context: InvocationContext, + vararg args: Any? + ): CordaFuture> } interface StartedNodeServices : ServiceHubInternal, FlowStarter diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index 57c3254cfc..ff341af5d2 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -258,7 +258,7 @@ class NodeSchedulerService(private val clock: CordaClock, return "${javaClass.simpleName}($scheduledState)" } - override fun wireUpFuture(flowFuture: CordaFuture>) { + override fun wireUpFuture(flowFuture: CordaFuture>) { _future.captureLater(flowFuture) val future = _future.flatMap { it.resultFuture } future.then { @@ -266,8 +266,8 @@ class NodeSchedulerService(private val clock: CordaClock, } } - private val _future = openFuture>() - override val future: CordaFuture> + private val _future = openFuture>() + override val future: CordaFuture> get() = _future } diff --git a/node/src/main/kotlin/net/corda/node/services/logging/ContextualLoggingUtils.kt b/node/src/main/kotlin/net/corda/node/services/logging/ContextualLoggingUtils.kt index 2e2211b695..03a4f6fac9 100644 --- a/node/src/main/kotlin/net/corda/node/services/logging/ContextualLoggingUtils.kt +++ b/node/src/main/kotlin/net/corda/node/services/logging/ContextualLoggingUtils.kt @@ -13,6 +13,12 @@ internal fun InvocationContext.pushToLoggingContext() { origin.pushToLoggingContext() externalTrace?.pushToLoggingContext("external_") impersonatedActor?.pushToLoggingContext("impersonating_") + + clientId?.let { + MDC.getMDCAdapter().apply { + put("client_id", it) + } + } } internal fun Trace.pushToLoggingContext(prefix: String = "") { diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index 2a815fba94..0b96ccfb7c 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -6,8 +6,11 @@ import net.corda.core.flows.StateMachineRunId import net.corda.core.internal.PLATFORM_VERSION import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.uncheckedCast +import net.corda.core.flows.ResultSerializationException import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.internal.MissingSerializerException import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.node.services.api.CheckpointStorage @@ -15,6 +18,7 @@ import net.corda.node.services.statemachine.Checkpoint import net.corda.node.services.statemachine.Checkpoint.FlowStatus import net.corda.node.services.statemachine.CheckpointState import net.corda.node.services.statemachine.ErrorState +import net.corda.node.services.statemachine.FlowResultMetadata import net.corda.node.services.statemachine.FlowState import net.corda.node.services.statemachine.SubFlowVersion import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX @@ -55,9 +59,28 @@ class DBCheckpointStorage( private const val MAX_EXC_TYPE_LENGTH = 256 private const val MAX_FLOW_NAME_LENGTH = 128 private const val MAX_PROGRESS_STEP_LENGTH = 256 + const val MAX_CLIENT_ID_LENGTH = 512 private val RUNNABLE_CHECKPOINTS = setOf(FlowStatus.RUNNABLE, FlowStatus.HOSPITALIZED) + // This is a dummy [DBFlowMetadata] object which help us whenever we want to persist a [DBFlowCheckpoint], but not persist its [DBFlowMetadata]. + // [DBFlowCheckpoint] needs to always reference a [DBFlowMetadata] ([DBFlowCheckpoint.flowMetadata] is not nullable). + // However, since we do not -hibernate- cascade, it does not get persisted into the database. + private val dummyDBFlowMetadata: DBFlowMetadata = DBFlowMetadata( + flowId = "dummyFlowId", + invocationId = "dummyInvocationId", + flowName = "dummyFlowName", + userSuppliedIdentifier = "dummyUserSuppliedIdentifier", + startType = StartReason.INITIATED, + initialParameters = ByteArray(0), + launchingCordapp = "dummyLaunchingCordapp", + platformVersion = -1, + startedBy = "dummyStartedBy", + invocationInstant = Instant.now(), + startInstant = Instant.now(), + finishInstant = null + ) + /** * This needs to run before Hibernate is initialised. * @@ -137,7 +160,7 @@ class DBCheckpointStorage( var checkpoint: ByteArray = EMPTY_BYTE_ARRAY, @Type(type = "corda-blob") - @Column(name = "flow_state") + @Column(name = "flow_state", nullable = true) var flowStack: ByteArray?, @Type(type = "corda-wrapper-binary") @@ -184,28 +207,31 @@ class DBCheckpointStorage( var flow_id: String, @Type(type = "corda-blob") - @Column(name = "result_value", nullable = false) - var value: ByteArray = EMPTY_BYTE_ARRAY, + @Column(name = "result_value", nullable = true) + var value: ByteArray? = null, @Column(name = "timestamp") val persistedInstant: Instant ) { + @Suppress("ComplexMethod") override fun equals(other: Any?): Boolean { if (this === other) return true if (javaClass != other?.javaClass) return false - other as DBFlowResult - if (flow_id != other.flow_id) return false - if (!value.contentEquals(other.value)) return false + val value = value + val otherValue = other.value + if (value != null) { + if (otherValue == null) return false + if (!value.contentEquals(otherValue)) return false + } else if (otherValue != null) return false if (persistedInstant != other.persistedInstant) return false - return true } override fun hashCode(): Int { var result = flow_id.hashCode() - result = 31 * result + value.contentHashCode() + result = 31 * result + (value?.contentHashCode() ?: 0) result = 31 * result + persistedInstant.hashCode() return result } @@ -298,7 +324,7 @@ class DBCheckpointStorage( @Column(name = "invocation_time", nullable = false) var invocationInstant: Instant, - @Column(name = "start_time", nullable = true) + @Column(name = "start_time", nullable = false) var startInstant: Instant, @Column(name = "finish_time", nullable = true) @@ -362,7 +388,7 @@ class DBCheckpointStorage( now ) - val metadata = createDBFlowMetadata(flowId, checkpoint) + val metadata = createDBFlowMetadata(flowId, checkpoint, now) // Most fields are null as they cannot have been set when creating the initial checkpoint val dbFlowCheckpoint = DBFlowCheckpoint( @@ -383,15 +409,25 @@ class DBCheckpointStorage( currentDBSession().save(metadata) } + @Suppress("ComplexMethod") override fun updateCheckpoint( - id: StateMachineRunId, checkpoint: Checkpoint, serializedFlowState: SerializedBytes?, + id: StateMachineRunId, + checkpoint: Checkpoint, + serializedFlowState: SerializedBytes?, serializedCheckpointState: SerializedBytes ) { val now = clock.instant() val flowId = id.uuid.toString() - // Do not update in DB [Checkpoint.checkpointState] or [Checkpoint.flowState] if flow failed or got hospitalized - val blob = if (checkpoint.status == FlowStatus.FAILED || checkpoint.status == FlowStatus.HOSPITALIZED) { + val blob = if (checkpoint.status == FlowStatus.HOSPITALIZED) { + // Do not update 'checkpointState' or 'flowState' if flow hospitalized + null + } else if (checkpoint.status == FlowStatus.FAILED) { + // We need to update only the 'flowState' to null, and we don't want to update the checkpoint state + // because we want to retain the last clean checkpoint state, therefore just use a query for that update. + val sqlQuery = "Update ${NODE_DATABASE_PREFIX}checkpoint_blobs set flow_state = null where flow_id = '$flowId'" + val query = currentDBSession().createNativeQuery(sqlQuery) + query.executeUpdate() null } else { checkpointPerformanceRecorder.record(serializedCheckpointState, serializedFlowState) @@ -403,18 +439,31 @@ class DBCheckpointStorage( ) } - //This code needs to be added back in when we want to persist the result. For now this requires the result to be @CordaSerializable. - //val result = updateDBFlowResult(entity, checkpoint, now) - val exceptionDetails = updateDBFlowException(flowId, checkpoint, now) + val dbFlowResult = if (checkpoint.status == FlowStatus.COMPLETED) { + try { + createDBFlowResult(flowId, checkpoint.result, now) + } catch (e: MissingSerializerException) { + throw ResultSerializationException(e) + } + } else { + null + } - val metadata = createDBFlowMetadata(flowId, checkpoint) + val dbFlowException = if (checkpoint.status == FlowStatus.FAILED || checkpoint.status == FlowStatus.HOSPITALIZED) { + val errored = checkpoint.errorState as? ErrorState.Errored + errored?.let { createDBFlowException(flowId, it, now) } + ?: throw IllegalStateException("Found '${checkpoint.status}' checkpoint whose error state is not ${ErrorState.Errored::class.java.simpleName}") + } else { + null + } + // Updates to children entities ([DBFlowCheckpointBlob], [DBFlowResult], [DBFlowException], [DBFlowMetadata]) are not cascaded to children tables. val dbFlowCheckpoint = DBFlowCheckpoint( flowId = flowId, blob = blob, - result = null, - exceptionDetails = exceptionDetails, - flowMetadata = metadata, + result = dbFlowResult, + exceptionDetails = dbFlowException, + flowMetadata = dummyDBFlowMetadata, // [DBFlowMetadata] will only update its 'finish_time' when a checkpoint finishes status = checkpoint.status, compatible = checkpoint.compatible, progressStep = checkpoint.progressStep?.take(MAX_PROGRESS_STEP_LENGTH), @@ -424,9 +473,10 @@ class DBCheckpointStorage( currentDBSession().update(dbFlowCheckpoint) blob?.let { currentDBSession().update(it) } + dbFlowResult?.let { currentDBSession().save(it) } + dbFlowException?.let { currentDBSession().save(it) } if (checkpoint.isFinished()) { - metadata.finishInstant = now - currentDBSession().update(metadata) + setDBFlowMetadataFinishTime(flowId, now) } } @@ -439,17 +489,18 @@ class DBCheckpointStorage( query.executeUpdate() } - // DBFlowResult and DBFlowException to be integrated with rest of schema @Suppress("MagicNumber") - override fun removeCheckpoint(id: StateMachineRunId): Boolean { + override fun removeCheckpoint(id: StateMachineRunId, mayHavePersistentResults: Boolean): Boolean { var deletedRows = 0 val flowId = id.uuid.toString() - deletedRows += deleteRow(DBFlowMetadata::class.java, DBFlowMetadata::flowId.name, flowId) - deletedRows += deleteRow(DBFlowCheckpointBlob::class.java, DBFlowCheckpointBlob::flowId.name, flowId) deletedRows += deleteRow(DBFlowCheckpoint::class.java, DBFlowCheckpoint::flowId.name, flowId) -// resultId?.let { deletedRows += deleteRow(DBFlowResult::class.java, DBFlowResult::flow_id.name, it.toString()) } -// exceptionId?.let { deletedRows += deleteRow(DBFlowException::class.java, DBFlowException::flow_id.name, it.toString()) } - return deletedRows == 3 + deletedRows += deleteRow(DBFlowCheckpointBlob::class.java, DBFlowCheckpointBlob::flowId.name, flowId) + if (mayHavePersistentResults) { + deletedRows += deleteRow(DBFlowResult::class.java, DBFlowResult::flow_id.name, flowId) + deletedRows += deleteRow(DBFlowException::class.java, DBFlowException::flow_id.name, flowId) + } + deletedRows += deleteRow(DBFlowMetadata::class.java, DBFlowMetadata::flowId.name, flowId) + return deletedRows >= 2 } private fun deleteRow(clazz: Class, pk: String, value: String): Int { @@ -487,6 +538,14 @@ class DBCheckpointStorage( return currentDBSession().find(DBFlowCheckpoint::class.java, id.uuid.toString()) } + private fun getDBFlowResult(id: StateMachineRunId): DBFlowResult? { + return currentDBSession().find(DBFlowResult::class.java, id.uuid.toString()) + } + + private fun getDBFlowException(id: StateMachineRunId): DBFlowException? { + return currentDBSession().find(DBFlowException::class.java, id.uuid.toString()) + } + override fun getPausedCheckpoints(): Stream> { val session = currentDBSession() val jpqlQuery = """select new ${DBPausedFields::class.java.name}(checkpoint.id, blob.checkpoint, checkpoint.status, @@ -499,6 +558,42 @@ class DBCheckpointStorage( } } + override fun getFinishedFlowsResultsMetadata(): Stream> { + val session = currentDBSession() + val jpqlQuery = + """select new ${DBFlowResultMetadataFields::class.java.name}(checkpoint.id, checkpoint.status, metadata.userSuppliedIdentifier) + from ${DBFlowCheckpoint::class.java.name} checkpoint + join ${DBFlowMetadata::class.java.name} metadata on metadata.id = checkpoint.flowMetadata + where checkpoint.status = ${FlowStatus.COMPLETED.ordinal} or checkpoint.status = ${FlowStatus.FAILED.ordinal}""".trimIndent() + val query = session.createQuery(jpqlQuery, DBFlowResultMetadataFields::class.java) + return query.resultList.stream().map { + StateMachineRunId(UUID.fromString(it.id)) to FlowResultMetadata(it.status, it.clientId) + } + } + + override fun getFlowResult(id: StateMachineRunId, throwIfMissing: Boolean): Any? { + val dbFlowResult = getDBFlowResult(id) + if (throwIfMissing && dbFlowResult == null) { + throw IllegalStateException("Flow's $id result was not found in the database. Something is very wrong.") + } + val serializedFlowResult = dbFlowResult?.value?.let { SerializedBytes(it) } + return serializedFlowResult?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + } + + override fun getFlowException(id: StateMachineRunId, throwIfMissing: Boolean): Any? { + val dbFlowException = getDBFlowException(id) + if (throwIfMissing && dbFlowException == null) { + throw IllegalStateException("Flow's $id exception was not found in the database. Something is very wrong.") + } + val serializedFlowException = dbFlowException?.value?.let { SerializedBytes(it) } + return serializedFlowException?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + } + + override fun removeFlowException(id: StateMachineRunId): Boolean { + val flowId = id.uuid.toString() + return deleteRow(DBFlowException::class.java, DBFlowException::flow_id.name, flowId) == 1 + } + override fun updateStatus(runId: StateMachineRunId, flowStatus: FlowStatus) { val update = "Update ${NODE_DATABASE_PREFIX}checkpoints set status = ${flowStatus.ordinal} where flow_id = '${runId.uuid}'" currentDBSession().createNativeQuery(update).executeUpdate() @@ -509,7 +604,7 @@ class DBCheckpointStorage( currentDBSession().createNativeQuery(update).executeUpdate() } - private fun createDBFlowMetadata(flowId: String, checkpoint: Checkpoint): DBFlowMetadata { + private fun createDBFlowMetadata(flowId: String, checkpoint: Checkpoint, now: Instant): DBFlowMetadata { val context = checkpoint.checkpointState.invocationContext val flowInfo = checkpoint.checkpointState.subFlowStack.first() return DBFlowMetadata( @@ -518,15 +613,14 @@ class DBCheckpointStorage( // Truncate the flow name to fit into the database column // Flow names are unlikely to be this long flowName = flowInfo.flowClass.name.take(MAX_FLOW_NAME_LENGTH), - // will come from the context - userSuppliedIdentifier = null, + userSuppliedIdentifier = context.clientId, startType = context.getStartedType(), initialParameters = context.getFlowParameters().storageSerialize().bytes, launchingCordapp = (flowInfo.subFlowVersion as? SubFlowVersion.CorDappFlow)?.corDappName ?: "Core flow", platformVersion = PLATFORM_VERSION, startedBy = context.principal().name, invocationInstant = context.trace.invocationId.timestamp, - startInstant = clock.instant(), + startInstant = now, finishInstant = null ) } @@ -546,70 +640,14 @@ class DBCheckpointStorage( ) } - /** - * Creates, updates or deletes the result related to the current flow/checkpoint. - * - * This is needed because updates are not cascading via Hibernate, therefore operations must be handled manually. - * - * A [DBFlowResult] is created if [DBFlowCheckpoint.result] does not exist and the [Checkpoint] has a result.. - * The existing [DBFlowResult] is updated if [DBFlowCheckpoint.result] exists and the [Checkpoint] has a result. - * The existing [DBFlowResult] is deleted if [DBFlowCheckpoint.result] exists and the [Checkpoint] has no result. - * Nothing happens if both [DBFlowCheckpoint] and [Checkpoint] do not have a result. - */ - private fun updateDBFlowResult(flowId: String, entity: DBFlowCheckpoint, checkpoint: Checkpoint, now: Instant): DBFlowResult? { - val result = checkpoint.result?.let { createDBFlowResult(flowId, it, now) } - if (entity.result != null) { - if (result != null) { - result.flow_id = entity.result!!.flow_id - currentDBSession().update(result) - } else { - currentDBSession().delete(entity.result) - } - } else if (result != null) { - currentDBSession().save(result) - } - return result - } - - private fun createDBFlowResult(flowId: String, result: Any, now: Instant): DBFlowResult { + private fun createDBFlowResult(flowId: String, result: Any?, now: Instant): DBFlowResult { return DBFlowResult( flow_id = flowId, - value = result.storageSerialize().bytes, + value = result?.storageSerialize()?.bytes, persistedInstant = now ) } - /** - * Creates, updates or deletes the error related to the current flow/checkpoint. - * - * This is needed because updates are not cascading via Hibernate, therefore operations must be handled manually. - * - * A [DBFlowException] is created if [DBFlowCheckpoint.exceptionDetails] does not exist and the [Checkpoint] has an error attached to it. - * The existing [DBFlowException] is updated if [DBFlowCheckpoint.exceptionDetails] exists and the [Checkpoint] has an error. - * The existing [DBFlowException] is deleted if [DBFlowCheckpoint.exceptionDetails] exists and the [Checkpoint] has no error. - * Nothing happens if both [DBFlowCheckpoint] and [Checkpoint] are related to no errors. - */ - // DBFlowException to be integrated with rest of schema - // Add a flag notifying if an exception is already saved in the database for below logic (are we going to do this after all?) - private fun updateDBFlowException(flowId: String, checkpoint: Checkpoint, now: Instant): DBFlowException? { - val exceptionDetails = (checkpoint.errorState as? ErrorState.Errored)?.let { createDBFlowException(flowId, it, now) } -// if (checkpoint.dbExoSkeleton.dbFlowExceptionId != null) { -// if (exceptionDetails != null) { -// exceptionDetails.flow_id = checkpoint.dbExoSkeleton.dbFlowExceptionId!! -// currentDBSession().update(exceptionDetails) -// } else { -// val session = currentDBSession() -// val entity = session.get(DBFlowException::class.java, checkpoint.dbExoSkeleton.dbFlowExceptionId) -// session.delete(entity) -// return null -// } -// } else if (exceptionDetails != null) { -// currentDBSession().save(exceptionDetails) -// checkpoint.dbExoSkeleton.dbFlowExceptionId = exceptionDetails.flow_id -// } - return exceptionDetails - } - private fun createDBFlowException(flowId: String, errorState: ErrorState.Errored, now: Instant): DBFlowException { return errorState.errors.last().exception.let { DBFlowException( @@ -617,12 +655,20 @@ class DBCheckpointStorage( type = it::class.java.name.truncate(MAX_EXC_TYPE_LENGTH, true), message = it.message?.truncate(MAX_EXC_MSG_LENGTH, false), stackTrace = it.stackTraceToString(), - value = null, // TODO to be populated upon implementing https://r3-cev.atlassian.net/browse/CORDA-3681 + value = it.storageSerialize().bytes, persistedInstant = now ) } } + private fun setDBFlowMetadataFinishTime(flowId: String, now: Instant) { + val session = currentDBSession() + val sqlQuery = "Update ${NODE_DATABASE_PREFIX}flow_metadata set finish_time = '$now' " + + "where flow_id = '$flowId'" + val query = session.createNativeQuery(sqlQuery) + query.executeUpdate() + } + private fun InvocationContext.getStartedType(): StartReason { return when (origin) { is InvocationOrigin.RPC, is InvocationOrigin.Shell -> StartReason.RPC @@ -632,10 +678,14 @@ class DBCheckpointStorage( } } + @Suppress("MagicNumber") private fun InvocationContext.getFlowParameters(): List { - // Only RPC flows have parameters which are found in index 1 - return if (arguments.isNotEmpty()) { - uncheckedCast>(arguments[1]).toList() + // Only RPC flows have parameters which are found in index 1 or index 2 (if called with client id) + return if (arguments!!.isNotEmpty()) { + arguments!!.run { + check(size == 2 || size == 3) { "Unexpected argument number provided in rpc call" } + uncheckedCast>(last()).toList() + } } else { emptyList() } @@ -649,7 +699,7 @@ class DBCheckpointStorage( // Always load as a [Clean] checkpoint to represent that the checkpoint is the last _good_ checkpoint errorState = ErrorState.Clean, // A checkpoint with a result should not normally be loaded (it should be [null] most of the time) - result = result?.let { SerializedBytes(it.value) }, + result = result?.let { dbFlowResult -> dbFlowResult.value?.let { SerializedBytes(it) } }, status = status, progressStep = progressStep, flowIoRequest = ioRequestType, @@ -680,6 +730,12 @@ class DBCheckpointStorage( } } + private class DBFlowResultMetadataFields( + val id: String, + val status: FlowStatus, + val clientId: String? + ) + private fun T.storageSerialize(): SerializedBytes { return serialize(context = SerializationDefaults.STORAGE_CONTEXT) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt index b98578472b..d5faad801e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -63,9 +63,11 @@ sealed class Action { data class UpdateFlowStatus(val id: StateMachineRunId, val status: Checkpoint.FlowStatus): Action() /** - * Remove the checkpoint corresponding to [id]. + * Remove the checkpoint corresponding to [id]. [mayHavePersistentResults] denotes that at the time of injecting a [RemoveCheckpoint] + * the flow could have persisted its database result or exception. + * For more information see [CheckpointStorage.removeCheckpoint]. */ - data class RemoveCheckpoint(val id: StateMachineRunId) : Action() + data class RemoveCheckpoint(val id: StateMachineRunId, val mayHavePersistentResults: Boolean = false) : Action() /** * Persist the deduplication facts of [deduplicationHandlers]. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index 977102c3ed..b1162a390b 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -85,7 +85,7 @@ internal class ActionExecutorImpl( val checkpoint = action.checkpoint val flowState = checkpoint.flowState val serializedFlowState = when(flowState) { - FlowState.Completed -> null + FlowState.Finished -> null // upon implementing CORDA-3816: If we have errored or hospitalized then we don't need to serialize the flowState as it will not get saved in the DB else -> flowState.checkpointSerialize(checkpointSerializationContext) } @@ -94,8 +94,8 @@ internal class ActionExecutorImpl( if (action.isCheckpointUpdate) { checkpointStorage.updateCheckpoint(action.id, checkpoint, serializedFlowState, serializedCheckpointState) } else { - if (flowState is FlowState.Completed) { - throw IllegalStateException("A new checkpoint cannot be created with a Completed FlowState.") + if (flowState is FlowState.Finished) { + throw IllegalStateException("A new checkpoint cannot be created with a finished flow state.") } checkpointStorage.addCheckpoint(action.id, checkpoint, serializedFlowState!!, serializedCheckpointState) } @@ -158,7 +158,7 @@ internal class ActionExecutorImpl( @Suspendable private fun executeRemoveCheckpoint(action: Action.RemoveCheckpoint) { - checkpointStorage.removeCheckpoint(action.id) + checkpointStorage.removeCheckpoint(action.id, action.mayHavePersistentResults) } @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index b157b0d575..5ea2ea6fcc 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -146,6 +146,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override val context: InvocationContext get() = transientState.checkpoint.checkpointState.invocationContext override val ourIdentity: Party get() = transientState.checkpoint.checkpointState.ourIdentity override val isKilled: Boolean get() = transientState.isKilled + override val clientId: String? get() = transientState.checkpoint.checkpointState.invocationContext.clientId + /** * What sender identifier to put on messages sent by this flow. This will either be the identifier for the current * state machine manager / messaging client, or null to indicate this flow is restored from a checkpoint and diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index 71fda8c194..00cb28d0da 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -14,11 +14,15 @@ import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle +import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.castIfPossible +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.internal.concurrent.map -import net.corda.core.internal.concurrent.mapError import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.mapNotNull import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed import net.corda.core.serialization.deserialize @@ -55,6 +59,7 @@ import javax.annotation.concurrent.ThreadSafe import kotlin.collections.component1 import kotlin.collections.component2 import kotlin.collections.set +import kotlin.streams.toList /** * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which @@ -80,12 +85,21 @@ internal class SingleThreadedStateMachineManager( Checkpoint.FlowStatus.HOSPITALIZED, Checkpoint.FlowStatus.PAUSED ) + + @VisibleForTesting + var beforeClientIDCheck: (() -> Unit)? = null + @VisibleForTesting + var onClientIDNotFound: (() -> Unit)? = null + @VisibleForTesting + var onCallingStartFlowInternal: (() -> Unit)? = null + @VisibleForTesting + var onStartFlowInternalThrewAndAboutToRemove: (() -> Unit)? = null } private val innerState = StateMachineInnerStateImpl() private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor) private val scheduledFutureExecutor = Executors.newSingleThreadScheduledExecutor( - ThreadFactoryBuilder().setNameFormat("flow-scheduled-future-thread").setDaemon(true).build() + ThreadFactoryBuilder().setNameFormat("flow-scheduled-future-thread").setDaemon(true).build() ) // How many Fibers are running (this includes suspended flows). If zero and stopping is true, then we are halted. private val liveFibers = ReusableLatch() @@ -138,6 +152,7 @@ internal class SingleThreadedStateMachineManager( */ override val changes: Observable = innerState.changesPublisher + @Suppress("ComplexMethod") override fun start(tokenizableServices: List, startMode: StateMachineManager.StartMode): CordaFuture { checkQuasarJavaAgentPresence() val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( @@ -181,6 +196,33 @@ internal class SingleThreadedStateMachineManager( } } } + + // - Incompatible checkpoints need to be handled upon implementing CORDA-3897 + for (flow in fibers.values) { + flow.fiber.clientId?.let { + innerState.clientIdsToFlowIds[it] = FlowWithClientIdStatus.Active(doneFuture(flow.fiber)) + } + } + + for (pausedFlow in pausedFlows) { + pausedFlow.value.checkpoint.checkpointState.invocationContext.clientId?.let { + innerState.clientIdsToFlowIds[it] = FlowWithClientIdStatus.Active( + doneClientIdFuture(pausedFlow.key, pausedFlow.value.resultFuture, it) + ) + } + } + + val finishedFlowsResults = checkpointStorage.getFinishedFlowsResultsMetadata().toList() + for ((id, finishedFlowResult) in finishedFlowsResults) { + finishedFlowResult.clientId?.let { + if (finishedFlowResult.status == Checkpoint.FlowStatus.COMPLETED) { + innerState.clientIdsToFlowIds[it] = FlowWithClientIdStatus.Removed(id, true) + } else { + innerState.clientIdsToFlowIds[it] = FlowWithClientIdStatus.Removed(id, false) + } + } ?: logger.error("Found finished flow $id without a client id. Something is very wrong and this flow will be ignored.") + } + return serviceHub.networkMapCache.nodeReady.map { logger.info("Node ready, info: ${serviceHub.myInfo}") resumeRestoredFlows(fibers) @@ -248,21 +290,62 @@ internal class SingleThreadedStateMachineManager( } } + @Suppress("ComplexMethod") private fun startFlow( flowId: StateMachineRunId, flowLogic: FlowLogic, context: InvocationContext, ourIdentity: Party?, deduplicationHandler: DeduplicationHandler? - ): CordaFuture> { - return startFlowInternal( + ): CordaFuture> { + beforeClientIDCheck?.invoke() + + var newFuture: OpenFuture>? = null + + val clientId = context.clientId + if (clientId != null) { + var existingStatus: FlowWithClientIdStatus? = null + innerState.withLock { + clientIdsToFlowIds.compute(clientId) { _, status -> + if (status != null) { + existingStatus = status + status + } else { + newFuture = openFuture() + FlowWithClientIdStatus.Active(newFuture!!) + } + } + } + + // Flow -started with client id- already exists, return the existing's flow future and don't start a new flow. + existingStatus?.let { + val existingFuture = activeOrRemovedClientIdFuture(it, clientId) + return@startFlow uncheckedCast(existingFuture) + } + onClientIDNotFound?.invoke() + } + + return try { + startFlowInternal( flowId, invocationContext = context, flowLogic = flowLogic, flowStart = FlowStart.Explicit, ourIdentity = ourIdentity ?: ourFirstIdentity, deduplicationHandler = deduplicationHandler - ) + ).also { + newFuture?.captureLater(uncheckedCast(it)) + } + } catch (t: Throwable) { + onStartFlowInternalThrewAndAboutToRemove?.invoke() + innerState.withLock { + clientIdsToFlowIds.remove(clientId) + newFuture?.setException(t) + } + // Throwing the exception plain here is the same as to return an exceptionally completed future since the caller calls + // getOrThrow() on the returned future at [CordaRPCOpsImpl.startFlow]. + throw t + } } override fun killFlow(id: StateMachineRunId): Boolean { @@ -273,7 +356,7 @@ internal class SingleThreadedStateMachineManager( // The checkpoint and soft locks are removed here instead of relying on the processing of the next event after setting // the killed flag. This is to ensure a flow can be removed from the database, even if it is stuck in a infinite loop. database.transaction { - checkpointStorage.removeCheckpoint(id) + checkpointStorage.removeCheckpoint(id, mayHavePersistentResults = true) serviceHub.vaultService.softLockRelease(id.uuid) } @@ -285,7 +368,7 @@ internal class SingleThreadedStateMachineManager( } } else { // It may be that the id refers to a checkpoint that couldn't be deserialised into a flow, so we delete it if it exists. - database.transaction { checkpointStorage.removeCheckpoint(id) } + database.transaction { checkpointStorage.removeCheckpoint(id, mayHavePersistentResults = true) } } return killFlowResult || flowHospital.dropSessionInit(id) @@ -370,7 +453,15 @@ internal class SingleThreadedStateMachineManager( checkpointStorage.getCheckpointsToRun().forEach Checkpoints@{(id, serializedCheckpoint) -> // If a flow is added before start() then don't attempt to restore it innerState.withLock { if (id in flows) return@Checkpoints } - val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id) ?: return@Checkpoints + val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, id)?.also { + if (it.status == Checkpoint.FlowStatus.HOSPITALIZED) { + checkpointStorage.updateStatus(id, Checkpoint.FlowStatus.RUNNABLE) + if (!checkpointStorage.removeFlowException(id)) { + logger.error("Unable to remove database exception for flow $id. Something is very wrong. The flow will not be loaded and run.") + return@Checkpoints + } + } + } ?: return@Checkpoints val flow = flowCreator.createFlowFromCheckpoint(id, checkpoint) if (flow == null) { // Set the flowState to paused so we don't waste memory storing it anymore. @@ -415,6 +506,10 @@ internal class SingleThreadedStateMachineManager( tryDeserializeCheckpoint(serializedCheckpoint, flowId)?.also { if (it.status == Checkpoint.FlowStatus.HOSPITALIZED) { checkpointStorage.updateStatus(flowId, Checkpoint.FlowStatus.RUNNABLE) + if (!checkpointStorage.removeFlowException(flowId)) { + logger.error("Unable to remove database exception for flow $flowId. Something is very wrong. The flow will not retry.") + return@transaction null + } } } ?: return@transaction null } ?: return @@ -658,6 +753,7 @@ internal class SingleThreadedStateMachineManager( ourIdentity: Party, deduplicationHandler: DeduplicationHandler? ): CordaFuture> { + onCallingStartFlowInternal?.invoke() val existingFlow = innerState.withLock { flows[flowId] } val existingCheckpoint = if (existingFlow != null && existingFlow.fiber.transientState.isAnyCheckpointPersisted) { @@ -666,17 +762,9 @@ internal class SingleThreadedStateMachineManager( // CORDA-3359 - Do not start/retry a flow that failed after deleting its checkpoint (the whole of the flow might replay) val existingCheckpoint = database.transaction { checkpointStorage.getCheckpoint(flowId) } existingCheckpoint?.let { serializedCheckpoint -> - val checkpoint = tryDeserializeCheckpoint(serializedCheckpoint, flowId) - if (checkpoint == null) { - return openFuture>().mapError { - IllegalStateException( - "Unable to deserialize database checkpoint for flow $flowId. " + - "Something is very wrong. The flow will not retry." - ) - } - } else { - checkpoint - } + tryDeserializeCheckpoint(serializedCheckpoint, flowId) ?: throw IllegalStateException( + "Unable to deserialize database checkpoint for flow $flowId. Something is very wrong. The flow will not retry." + ) } } else { // This is a brand new flow @@ -780,7 +868,7 @@ internal class SingleThreadedStateMachineManager( is FlowState.Started -> { Fiber.unparkDeserialized(flow.fiber, scheduler) } - is FlowState.Completed -> throw IllegalStateException("Cannot start (or resume) a completed flow.") + is FlowState.Finished -> throw IllegalStateException("Cannot start (or resume) a finished flow.") } } @@ -834,6 +922,7 @@ internal class SingleThreadedStateMachineManager( require(lastState.isRemoved) { "Flow must be in removable state before removal" } require(lastState.checkpoint.checkpointState.subFlowStack.size == 1) { "Checkpointed stack must be empty" } require(flow.fiber.id !in sessionToFlow.values) { "Flow fibre must not be needed by an existing session" } + flow.fiber.clientId?.let { setClientIdAsSucceeded(it, flow.fiber.id) } flow.resultFuture.set(removalReason.flowReturnValue) lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue))) @@ -845,13 +934,19 @@ internal class SingleThreadedStateMachineManager( lastState: StateMachineState ) { drainFlowEventQueue(flow) + // Complete the started future, needed when the flow fails during flow init (before completing an [UnstartedFlowTransition]) + startedFutures.remove(flow.fiber.id)?.set(Unit) + flow.fiber.clientId?.let { + if (flow.fiber.isKilled) { + clientIdsToFlowIds.remove(it) + } else { + setClientIdAsFailed(it, flow.fiber.id) } + } val flowError = removalReason.flowErrors[0] // TODO what to do with several? val exception = flowError.exception (exception as? FlowException)?.originalErrorId = flowError.errorId flow.resultFuture.setException(exception) lastState.flowLogic.progressTracker?.endWithError(exception) - // Complete the started future, needed when the flow fails during flow init (before completing an [UnstartedFlowTransition]) - startedFutures.remove(flow.fiber.id)?.set(Unit) changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure(exception))) } @@ -887,4 +982,117 @@ internal class SingleThreadedStateMachineManager( future = null } } + + private fun StateMachineInnerState.setClientIdAsSucceeded(clientId: String, id: StateMachineRunId) { + setClientIdAsRemoved(clientId, id, true) + } + + private fun StateMachineInnerState.setClientIdAsFailed(clientId: String, id: StateMachineRunId) { + setClientIdAsRemoved(clientId, id, false) + } + + private fun StateMachineInnerState.setClientIdAsRemoved( + clientId: String, + id: StateMachineRunId, + succeeded: Boolean + ) { + clientIdsToFlowIds.compute(clientId) { _, existingStatus -> + require(existingStatus != null && existingStatus is FlowWithClientIdStatus.Active) + FlowWithClientIdStatus.Removed(id, succeeded) + } + } + + private fun activeOrRemovedClientIdFuture(existingStatus: FlowWithClientIdStatus, clientId: String) = when (existingStatus) { + is FlowWithClientIdStatus.Active -> existingStatus.flowStateMachineFuture + is FlowWithClientIdStatus.Removed -> { + val flowId = existingStatus.flowId + val resultFuture = if (existingStatus.succeeded) { + val flowResult = database.transaction { checkpointStorage.getFlowResult(existingStatus.flowId, throwIfMissing = true) } + doneFuture(flowResult) + } else { + val flowException = + database.transaction { checkpointStorage.getFlowException(existingStatus.flowId, throwIfMissing = true) } + openFuture().apply { setException(flowException as Throwable) } + } + + doneClientIdFuture(flowId, resultFuture, clientId) + } + } + + /** + * The flow out of which a [doneFuture] will be produced should be a started flow, + * i.e. it should not exist in [mutex.content.startedFutures]. + */ + private fun doneClientIdFuture( + id: StateMachineRunId, + resultFuture: CordaFuture, + clientId: String + ): CordaFuture> = + doneFuture(object : FlowStateMachineHandle { + override val logic: Nothing? = null + override val id: StateMachineRunId = id + override val resultFuture: CordaFuture = resultFuture + override val clientId: String? = clientId + } + ) + + override fun reattachFlowWithClientId(clientId: String): FlowStateMachineHandle? { + return innerState.withLock { + clientIdsToFlowIds[clientId]?.let { + val existingFuture = activeOrRemovedClientIdFutureForReattach(it, clientId) + existingFuture?.let { uncheckedCast(existingFuture.get()) } + } + } + } + + @Suppress("NestedBlockDepth") + private fun activeOrRemovedClientIdFutureForReattach( + existingStatus: FlowWithClientIdStatus, + clientId: String + ): CordaFuture>? { + return when (existingStatus) { + is FlowWithClientIdStatus.Active -> existingStatus.flowStateMachineFuture + is FlowWithClientIdStatus.Removed -> { + val flowId = existingStatus.flowId + val resultFuture = if (existingStatus.succeeded) { + try { + val flowResult = + database.transaction { checkpointStorage.getFlowResult(existingStatus.flowId, throwIfMissing = true) } + doneFuture(flowResult) + } catch (e: IllegalStateException) { + null + } + } else { + try { + val flowException = + database.transaction { checkpointStorage.getFlowException(existingStatus.flowId, throwIfMissing = true) } + openFuture().apply { setException(flowException as Throwable) } + } catch (e: IllegalStateException) { + null + } + } + + resultFuture?.let { doneClientIdFuture(flowId, it, clientId) } + } + } + } + + override fun removeClientId(clientId: String): Boolean { + var removedFlowId: StateMachineRunId? = null + innerState.withLock { + clientIdsToFlowIds.computeIfPresent(clientId) { _, existingStatus -> + if (existingStatus is FlowWithClientIdStatus.Removed) { + removedFlowId = existingStatus.flowId + null + } else { + existingStatus + } + } + } + + removedFlowId?.let { + return database.transaction { checkpointStorage.removeCheckpoint(it, mayHavePersistentResults = true) } + } + return false + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineInnerState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineInnerState.kt index 0252e21e80..66017480ca 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineInnerState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineInnerState.kt @@ -17,6 +17,7 @@ internal interface StateMachineInnerState { val changesPublisher: PublishSubject /** Flows scheduled to be retried if not finished within the specified timeout period. */ val timedFlows: MutableMap + val clientIdsToFlowIds: MutableMap fun withMutex(block: StateMachineInnerState.() -> R): R } @@ -30,6 +31,7 @@ internal class StateMachineInnerStateImpl : StateMachineInnerState { override val pausedFlows = HashMap() override val startedFutures = HashMap>() override val timedFlows = HashMap() + override val clientIdsToFlowIds = HashMap() override fun withMutex(block: StateMachineInnerState.() -> R): R = lock.withLock { block(this) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 123f9e920c..832af2cc99 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -5,7 +5,9 @@ import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.messaging.DataFeed +import net.corda.core.messaging.FlowHandleWithClientId import net.corda.core.utilities.Try import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.ReceivedMessage @@ -97,6 +99,27 @@ interface StateMachineManager { * Returns a snapshot of all [FlowStateMachineImpl]s currently managed. */ fun snapshot(): Set> + + /** + * Reattach to an existing flow that was started with [startFlowDynamicWithClientId] and has a [clientId]. + * + * If there is a flow matching the [clientId] then its result or exception is returned. + * + * When there is no flow matching the [clientId] then [null] is returned directly (not a future/[FlowHandleWithClientId]). + * + * Calling [reattachFlowWithClientId] after [removeClientId] with the same [clientId] will cause the function to return [null] as + * the result/exception of the flow will no longer be available. + * + * @param clientId The client id relating to an existing flow + */ + fun reattachFlowWithClientId(clientId: String): FlowStateMachineHandle? + + /** + * Removes a flow's [clientId] to result/ exception mapping. + * + * @return whether the mapping was removed. + */ + fun removeClientId(clientId: String): Boolean } // These must be idempotent! A later failure in the state transition may error the flow state, and a replay may call @@ -138,11 +161,11 @@ interface ExternalEvent { /** * A callback for the state machine to pass back the [CordaFuture] associated with the flow start to the submitter. */ - fun wireUpFuture(flowFuture: CordaFuture>) + fun wireUpFuture(flowFuture: CordaFuture>) /** * The future representing the flow start, passed back from the state machine to the submitter of this event. */ - val future: CordaFuture> + val future: CordaFuture> } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index c94e38187a..b734810f0b 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -4,13 +4,16 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.crypto.SecureHash import net.corda.core.flows.Destination import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes @@ -128,8 +131,8 @@ data class Checkpoint( listOf(topLevelSubFlow), numberOfSuspends = 0 ), - errorState = ErrorState.Clean, - flowState = FlowState.Unstarted(flowStart, frozenFlowLogic) + flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), + errorState = ErrorState.Clean ) } } @@ -207,7 +210,7 @@ data class Checkpoint( fun deserialize(checkpointSerializationContext: CheckpointSerializationContext): Checkpoint { val flowState = when(status) { FlowStatus.PAUSED -> FlowState.Paused - FlowStatus.COMPLETED -> FlowState.Completed + FlowStatus.COMPLETED, FlowStatus.FAILED -> FlowState.Finished else -> serializedFlowState!!.checkpointDeserialize(checkpointSerializationContext) } return Checkpoint( @@ -350,9 +353,9 @@ sealed class FlowState { object Paused: FlowState() /** - * The flow has completed. It does not have a running fiber that needs to be serialized and checkpointed. + * The flow has finished. It does not have a running fiber that needs to be serialized and checkpointed. */ - object Completed : FlowState() + object Finished : FlowState() } @@ -412,3 +415,13 @@ sealed class SubFlowVersion { data class CoreFlow(override val platformVersion: Int) : SubFlowVersion() data class CorDappFlow(override val platformVersion: Int, val corDappName: String, val corDappHash: SecureHash) : SubFlowVersion() } + +sealed class FlowWithClientIdStatus { + data class Active(val flowStateMachineFuture: CordaFuture>) : FlowWithClientIdStatus() + data class Removed(val flowId: StateMachineRunId, val succeeded: Boolean) : FlowWithClientIdStatus() +} + +data class FlowResultMetadata( + val status: Checkpoint.FlowStatus, + val clientId: String? +) \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt index 8b22573421..a2fa8b5bb3 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutorImpl.kt @@ -1,6 +1,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.ResultSerializationException import net.corda.core.utilities.contextLogger import net.corda.node.services.statemachine.transitions.FlowContinuation import net.corda.node.services.statemachine.transitions.TransitionResult @@ -73,22 +74,12 @@ class TransitionExecutorImpl( log.info("Error while executing $action, with event $event, erroring state", exception) } - // distinguish between a DatabaseTransactionException and an actual StateTransitionException - val stateTransitionOrDatabaseTransactionException = - if (exception is DatabaseTransactionException) { - // if the exception is a DatabaseTransactionException then it is not really a StateTransitionException - // it is actually an exception that previously broke a DatabaseTransaction and was suppressed by user code - // it was rethrown on [DatabaseTransaction.commit]. Unwrap the original exception and pass it to flow hospital - exception.cause - } else { - // Wrap the exception with [StateTransitionException] for handling by the flow hospital - StateTransitionException(action, event, exception) - } + val flowError = createError(exception, action, event) val newState = previousState.copy( checkpoint = previousState.checkpoint.copy( errorState = previousState.checkpoint.errorState.addErrors( - listOf(FlowError(secureRandom.nextLong(), stateTransitionOrDatabaseTransactionException)) + listOf(flowError) ) ), isFlowResumed = false @@ -121,4 +112,23 @@ class TransitionExecutorImpl( } } } + + private fun createError(e: Exception, action: Action, event: Event): FlowError { + // distinguish between a DatabaseTransactionException and an actual StateTransitionException + val stateTransitionOrOtherException: Throwable = + if (e is DatabaseTransactionException) { + // if the exception is a DatabaseTransactionException then it is not really a StateTransitionException + // it is actually an exception that previously broke a DatabaseTransaction and was suppressed by user code + // it was rethrown on [DatabaseTransaction.commit]. Unwrap the original exception and pass it to flow hospital + e.cause + } else if (e is ResultSerializationException) { + // We must not wrap a [ResultSerializationException] with a [StateTransitionException], + // because we will propagate the exception to rpc clients and [StateTransitionException] cannot be propagated to rpc clients. + e + } else { + // Wrap the exception with [StateTransitionException] for handling by the flow hospital + StateTransitionException(action, event, e) + } + return FlowError(secureRandom.nextLong(), stateTransitionOrOtherException) + } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt index 7d56967c24..05872aea7f 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/DoRemainingWorkTransition.kt @@ -25,12 +25,13 @@ class DoRemainingWorkTransition( } // If the flow is clean check the FlowState + @Suppress("ThrowsCount") private fun cleanTransition(): TransitionResult { val flowState = startingState.checkpoint.flowState return when (flowState) { is FlowState.Unstarted -> UnstartedFlowTransition(context, startingState, flowState).transition() is FlowState.Started -> StartedFlowTransition(context, startingState, flowState).transition() - is FlowState.Completed -> throw IllegalStateException("Cannot transition a state with completed flow state.") + is FlowState.Finished -> throw IllegalStateException("Cannot transition a state with finished flow state.") is FlowState.Paused -> throw IllegalStateException("Cannot transition a state with paused flow state.") } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt index ba5ecaa6bd..4f4e6cd51e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/ErrorFlowTransition.kt @@ -61,9 +61,15 @@ class ErrorFlowTransition( if (!currentState.isRemoved) { val newCheckpoint = startingState.checkpoint.copy(status = Checkpoint.FlowStatus.FAILED) + val removeOrPersistCheckpoint = if (currentState.checkpoint.checkpointState.invocationContext.clientId == null) { + Action.RemoveCheckpoint(context.id) + } else { + Action.PersistCheckpoint(context.id, newCheckpoint.copy(flowState = FlowState.Finished), isCheckpointUpdate = currentState.isAnyCheckpointPersisted) + } + actions.addAll(arrayOf( Action.CreateTransaction, - Action.PersistCheckpoint(context.id, newCheckpoint, isCheckpointUpdate = currentState.isAnyCheckpointPersisted), + removeOrPersistCheckpoint, Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.ReleaseSoftLocks(context.id.uuid), Action.CommitTransaction, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt index 9c44f5988c..bc059668d3 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/KilledFlowTransition.kt @@ -44,7 +44,7 @@ class KilledFlowTransition( } // The checkpoint and soft locks are also removed directly in [StateMachineManager.killFlow] if (startingState.isAnyCheckpointPersisted) { - actions.add(Action.RemoveCheckpoint(context.id)) + actions.add(Action.RemoveCheckpoint(context.id, mayHavePersistentResults = true)) } actions.addAll( arrayOf( diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index 511eca00d5..7ab0328e86 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -180,7 +180,7 @@ class TopLevelTransition( private fun suspendTransition(event: Event.Suspend): TransitionResult { return builder { val newCheckpoint = currentState.checkpoint.run { - val newCheckpointState = if (checkpointState.invocationContext.arguments.isNotEmpty()) { + val newCheckpointState = if (checkpointState.invocationContext.arguments!!.isNotEmpty()) { checkpointState.copy( invocationContext = checkpointState.invocationContext.copy(arguments = emptyList()), numberOfSuspends = checkpointState.numberOfSuspends + 1 @@ -234,7 +234,7 @@ class TopLevelTransition( checkpointState = checkpoint.checkpointState.copy( numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1 ), - flowState = FlowState.Completed, + flowState = FlowState.Finished, result = event.returnValue, status = Checkpoint.FlowStatus.COMPLETED ), @@ -242,10 +242,22 @@ class TopLevelTransition( isFlowResumed = false, isRemoved = true ) - val allSourceSessionIds = checkpoint.checkpointState.sessions.keys + if (currentState.isAnyCheckpointPersisted) { - actions.add(Action.RemoveCheckpoint(context.id)) + if (currentState.checkpoint.checkpointState.invocationContext.clientId == null) { + actions.add(Action.RemoveCheckpoint(context.id)) + } else { + actions.add( + Action.PersistCheckpoint( + context.id, + currentState.checkpoint, + isCheckpointUpdate = currentState.isAnyCheckpointPersisted + ) + ) + } } + + val allSourceSessionIds = currentState.checkpoint.checkpointState.sessions.keys actions.addAll(arrayOf( Action.PersistDeduplicationFacts(pendingDeduplicationHandlers), Action.ReleaseSoftLocks(event.softLocksId), diff --git a/node/src/main/resources/migration/node-core.changelog-v19-keys.xml b/node/src/main/resources/migration/node-core.changelog-v19-keys.xml index 26359ecd2f..15460c310a 100644 --- a/node/src/main/resources/migration/node-core.changelog-v19-keys.xml +++ b/node/src/main/resources/migration/node-core.changelog-v19-keys.xml @@ -12,12 +12,18 @@ - + + + + + + + diff --git a/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml b/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml index 6aedc510b4..d7098a5076 100644 --- a/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml +++ b/node/src/main/resources/migration/node-core.changelog-v19-postgres.xml @@ -49,14 +49,13 @@ - - + diff --git a/node/src/main/resources/migration/node-core.changelog-v19.xml b/node/src/main/resources/migration/node-core.changelog-v19.xml index 6b8c1e9b24..03165ce7d6 100644 --- a/node/src/main/resources/migration/node-core.changelog-v19.xml +++ b/node/src/main/resources/migration/node-core.changelog-v19.xml @@ -49,14 +49,13 @@ - - + diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 7017a19e65..a7420977b9 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -1,12 +1,15 @@ package net.corda.node.services.persistence +import net.corda.core.CordaRuntimeException import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationOrigin import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.internal.FlowIORequest import net.corda.core.internal.toSet +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.contextLogger @@ -38,7 +41,6 @@ import org.junit.After import org.junit.Assert.assertNotNull import org.junit.Assert.assertNull import org.junit.Before -import org.junit.Ignore import org.junit.Rule import org.junit.Test import java.time.Clock @@ -155,7 +157,7 @@ class DBCheckpointStorageTests { } database.transaction { assertEquals( - completedCheckpoint.copy(flowState = FlowState.Completed), + completedCheckpoint.copy(flowState = FlowState.Finished), checkpointStorage.checkpoints().single().deserialize() ) } @@ -181,51 +183,6 @@ class DBCheckpointStorageTests { } } - @Ignore - @Test(timeout = 300_000) - fun `removing a checkpoint deletes from all checkpoint tables`() { - val exception = IllegalStateException("I am a naughty exception") - val (id, checkpoint) = newCheckpoint() - val serializedFlowState = checkpoint.serializeFlowState() - database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) - } - val updatedCheckpoint = checkpoint.addError(exception).copy(result = "The result") - val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() - database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } - - database.transaction { - assertEquals(1, findRecordsFromDatabase().size) - // The result not stored yet - assertEquals(0, findRecordsFromDatabase().size) - assertEquals(1, findRecordsFromDatabase().size) - // The saving of checkpoint blobs needs to be fixed - assertEquals(2, findRecordsFromDatabase().size) - assertEquals(1, findRecordsFromDatabase().size) - } - - database.transaction { - checkpointStorage.removeCheckpoint(id) - } - database.transaction { - assertThat(checkpointStorage.checkpoints()).isEmpty() - } - newCheckpointStorage() - database.transaction { - assertThat(checkpointStorage.checkpoints()).isEmpty() - } - - database.transaction { - assertEquals(0, findRecordsFromDatabase().size) - assertEquals(0, findRecordsFromDatabase().size) - assertEquals(0, findRecordsFromDatabase().size) - // The saving of checkpoint blobs needs to be fixed - assertEquals(1, findRecordsFromDatabase().size) - assertEquals(0, findRecordsFromDatabase().size) - } - } - - @Ignore @Test(timeout = 300_000) fun `removing a checkpoint when there is no result does not fail`() { val exception = IllegalStateException("I am a naughty exception") @@ -240,11 +197,9 @@ class DBCheckpointStorageTests { database.transaction { assertEquals(1, findRecordsFromDatabase().size) - // The result not stored yet assertEquals(0, findRecordsFromDatabase().size) assertEquals(1, findRecordsFromDatabase().size) - // The saving of checkpoint blobs needs to be fixed - assertEquals(2, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) assertEquals(1, findRecordsFromDatabase().size) } @@ -263,8 +218,7 @@ class DBCheckpointStorageTests { assertEquals(0, findRecordsFromDatabase().size) assertEquals(0, findRecordsFromDatabase().size) assertEquals(0, findRecordsFromDatabase().size) - // The saving of checkpoint blobs needs to be fixed - assertEquals(1, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) assertEquals(0, findRecordsFromDatabase().size) } } @@ -276,14 +230,13 @@ class DBCheckpointStorageTests { database.transaction { checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } - val updatedCheckpoint = checkpoint.copy(result = "The result") + val updatedCheckpoint = checkpoint.copy(result = "The result", status = Checkpoint.FlowStatus.COMPLETED) val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } database.transaction { assertEquals(0, findRecordsFromDatabase().size) - // The result not stored yet - assertEquals(0, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) assertEquals(1, findRecordsFromDatabase().size) assertEquals(1, findRecordsFromDatabase().size) assertEquals(1, findRecordsFromDatabase().size) @@ -457,7 +410,6 @@ class DBCheckpointStorageTests { } @Test(timeout = 300_000) - @Ignore fun `update checkpoint with result information creates new result database record`() { val result = "This is the result" val (id, checkpoint) = newCheckpoint() @@ -466,7 +418,7 @@ class DBCheckpointStorageTests { database.transaction { checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } - val updatedCheckpoint = checkpoint.copy(result = result) + val updatedCheckpoint = checkpoint.copy(result = result, status = Checkpoint.FlowStatus.COMPLETED) val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) @@ -481,64 +433,6 @@ class DBCheckpointStorageTests { } } - @Test(timeout = 300_000) - @Ignore - fun `update checkpoint with result information updates existing result database record`() { - val result = "This is the result" - val somehowThereIsANewResult = "Another result (which should not be possible!)" - val (id, checkpoint) = newCheckpoint() - val serializedFlowState = - checkpoint.serializeFlowState() - database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) - } - val updatedCheckpoint = checkpoint.copy(result = result) - val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() - database.transaction { - checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) - } - val updatedCheckpoint2 = checkpoint.copy(result = somehowThereIsANewResult) - val updatedSerializedFlowState2 = updatedCheckpoint2.serializeFlowState() - database.transaction { - checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) - } - database.transaction { - assertEquals( - somehowThereIsANewResult, - checkpointStorage.getCheckpoint(id)!!.deserialize().result - ) - assertNotNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).result) - assertEquals(1, findRecordsFromDatabase().size) - } - } - - @Test(timeout = 300_000) - fun `removing result information from checkpoint deletes existing result database record`() { - val result = "This is the result" - val (id, checkpoint) = newCheckpoint() - val serializedFlowState = - checkpoint.serializeFlowState() - database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) - } - val updatedCheckpoint = checkpoint.copy(result = result) - val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() - database.transaction { - checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) - } - val updatedCheckpoint2 = checkpoint.copy(result = null) - val updatedSerializedFlowState2 = updatedCheckpoint2.serializeFlowState() - database.transaction { - checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) - } - database.transaction { - assertNull(checkpointStorage.getCheckpoint(id)!!.deserialize().result) - assertNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).result) - assertEquals(0, findRecordsFromDatabase().size) - } - } - - @Ignore @Test(timeout = 300_000) fun `update checkpoint with error information creates a new error database record`() { val exception = IllegalStateException("I am a naughty exception") @@ -557,58 +451,12 @@ class DBCheckpointStorageTests { assertNotNull(exceptionDetails) assertEquals(exception::class.java.name, exceptionDetails!!.type) assertEquals(exception.message, exceptionDetails.message) - assertEquals(1, findRecordsFromDatabase().size) - } - } - - @Ignore - @Test(timeout = 300_000) - fun `update checkpoint with new error information updates the existing error database record`() { - val illegalStateException = IllegalStateException("I am a naughty exception") - val illegalArgumentException = IllegalArgumentException("I am a very naughty exception") - val (id, checkpoint) = newCheckpoint() - val serializedFlowState = checkpoint.serializeFlowState() - database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) - } - val updatedCheckpoint1 = checkpoint.addError(illegalStateException) - val updatedSerializedFlowState1 = updatedCheckpoint1.serializeFlowState() - database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint1, updatedSerializedFlowState1, updatedCheckpoint1.serializeCheckpointState()) } - // Set back to clean - val updatedCheckpoint2 = checkpoint.addError(illegalArgumentException) - val updatedSerializedFlowState2 = updatedCheckpoint2.serializeFlowState() - database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint2, updatedSerializedFlowState2, updatedCheckpoint2.serializeCheckpointState()) } - database.transaction { - assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) - val exceptionDetails = session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).exceptionDetails - assertNotNull(exceptionDetails) - assertEquals(illegalArgumentException::class.java.name, exceptionDetails!!.type) - assertEquals(illegalArgumentException.message, exceptionDetails.message) - assertEquals(1, findRecordsFromDatabase().size) - } - } - - @Test(timeout = 300_000) - fun `clean checkpoints delete the error record from the database`() { - val exception = IllegalStateException("I am a naughty exception") - val (id, checkpoint) = newCheckpoint() - val serializedFlowState = checkpoint.serializeFlowState() - database.transaction { - checkpointStorage.addCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) - } - val updatedCheckpoint = checkpoint.addError(exception) - val updatedSerializedFlowState = updatedCheckpoint.serializeFlowState() - database.transaction { checkpointStorage.updateCheckpoint(id, updatedCheckpoint, updatedSerializedFlowState, updatedCheckpoint.serializeCheckpointState()) } - database.transaction { - // Checkpoint always returns clean error state when retrieved via [getCheckpoint] - assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) - } - // Set back to clean - database.transaction { checkpointStorage.updateCheckpoint(id, checkpoint, serializedFlowState, checkpoint.serializeCheckpointState()) } - database.transaction { - assertTrue(checkpointStorage.getCheckpoint(id)!!.deserialize().errorState is ErrorState.Clean) - assertNull(session.get(DBCheckpointStorage.DBFlowCheckpoint::class.java, id.uuid.toString()).exceptionDetails) - assertEquals(0, findRecordsFromDatabase().size) + val deserializedException = exceptionDetails.value?.let { SerializedBytes(it) }?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + // IllegalStateException does not implement [CordaThrowable] therefore gets deserialized as a [CordaRuntimeException] + assertTrue(deserializedException is CordaRuntimeException) + val cordaRuntimeException = deserializedException as CordaRuntimeException + assertEquals(IllegalStateException::class.java.name, cordaRuntimeException.originalExceptionClassName) + assertEquals("I am a naughty exception", cordaRuntimeException.originalMessage!!) } } @@ -701,7 +549,6 @@ class DBCheckpointStorageTests { } } - @Ignore @Test(timeout = 300_000) fun `-not greater than DBCheckpointStorage_MAX_STACKTRACE_LENGTH- stackTrace gets persisted as a whole`() { val smallerDummyStackTrace = ArrayList() @@ -734,7 +581,6 @@ class DBCheckpointStorageTests { } } - @Ignore @Test(timeout = 300_000) fun `-greater than DBCheckpointStorage_MAX_STACKTRACE_LENGTH- stackTrace gets truncated to MAX_LENGTH_VARCHAR, and persisted`() { val smallerDummyStackTrace = ArrayList() @@ -780,9 +626,9 @@ class DBCheckpointStorageTests { private fun iterationsBasedOnLineSeparatorLength() = when { System.getProperty("line.separator").length == 1 -> // Linux or Mac - 158 + 78 System.getProperty("line.separator").length == 2 -> // Windows - 152 + 75 else -> throw IllegalStateException("Unknown line.separator") } @@ -853,7 +699,7 @@ class DBCheckpointStorageTests { } @Test(timeout = 300_000) - fun `updateCheckpoint setting DBFlowCheckpoint_blob to null whenever flow fails or gets hospitalized doesn't break ORM relationship`() { + fun `'updateCheckpoint' setting 'DBFlowCheckpoint_blob' to null whenever flow fails or gets hospitalized doesn't break ORM relationship`() { val (id, checkpoint) = newCheckpoint() val serializedFlowState = checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) @@ -862,8 +708,8 @@ class DBCheckpointStorageTests { } database.transaction { - val paused = changeStatus(checkpoint, Checkpoint.FlowStatus.FAILED) // the exact same behaviour applies for 'HOSPITALIZED' as well - checkpointStorage.updateCheckpoint(id, paused.checkpoint, serializedFlowState, paused.checkpoint.serializeCheckpointState()) + val failed = checkpoint.addError(IllegalStateException()) // the exact same behaviour applies for 'HOSPITALIZED' as well + checkpointStorage.updateCheckpoint(id, failed, serializedFlowState, failed.serializeCheckpointState()) } database.transaction { @@ -908,6 +754,43 @@ class DBCheckpointStorageTests { } } + @Test(timeout = 300_000) + fun `'getFinishedFlowsResultsMetadata' fetches flows results metadata for finished flows only`() { + val (_, checkpoint) = newCheckpoint(1) + val runnable = changeStatus(checkpoint, Checkpoint.FlowStatus.RUNNABLE) + val hospitalized = changeStatus(checkpoint, Checkpoint.FlowStatus.HOSPITALIZED) + val completed = changeStatus(checkpoint, Checkpoint.FlowStatus.COMPLETED) + val failed = changeStatus(checkpoint, Checkpoint.FlowStatus.FAILED) + val killed = changeStatus(checkpoint, Checkpoint.FlowStatus.KILLED) + val paused = changeStatus(checkpoint, Checkpoint.FlowStatus.PAUSED) + + database.transaction { + val serializedFlowState = + checkpoint.flowState.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) + + checkpointStorage.addCheckpoint(runnable.id, runnable.checkpoint, serializedFlowState, runnable.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(hospitalized.id, hospitalized.checkpoint, serializedFlowState, hospitalized.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(completed.id, completed.checkpoint, serializedFlowState, completed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(failed.id, failed.checkpoint, serializedFlowState, failed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(killed.id, killed.checkpoint, serializedFlowState, killed.checkpoint.serializeCheckpointState()) + checkpointStorage.addCheckpoint(paused.id, paused.checkpoint, serializedFlowState, paused.checkpoint.serializeCheckpointState()) + } + + val checkpointsInDb = database.transaction { + checkpointStorage.getCheckpoints().toList().size + } + + val resultsMetadata = database.transaction { + checkpointStorage.getFinishedFlowsResultsMetadata() + }.toList() + + assertEquals(6, checkpointsInDb) + + val finishedStatuses = resultsMetadata.map { it.second.status } + assertTrue(Checkpoint.FlowStatus.COMPLETED in finishedStatuses) + assertTrue(Checkpoint.FlowStatus.FAILED in finishedStatuses) + } + data class IdAndCheckpoint(val id: StateMachineRunId, val checkpoint: Checkpoint) private fun changeStatus(oldCheckpoint: Checkpoint, status: Checkpoint.FlowStatus): IdAndCheckpoint { @@ -970,7 +853,8 @@ class DBCheckpointStorageTests { exception ) ), 0, false - ) + ), + status = Checkpoint.FlowStatus.FAILED ) } diff --git a/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt b/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt index 4037bd80f0..d8644ccede 100644 --- a/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/rpc/CheckpointDumperImplTest.kt @@ -126,7 +126,7 @@ class CheckpointDumperImplTest { checkpointStorage.addCheckpoint(id, checkpoint, serializeFlowState(checkpoint), serializeCheckpointState(checkpoint)) } val newCheckpoint = checkpoint.copy( - flowState = FlowState.Completed, + flowState = FlowState.Finished, status = Checkpoint.FlowStatus.COMPLETED ) database.transaction { diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowClientIdTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowClientIdTests.kt new file mode 100644 index 0000000000..c945b55aa4 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowClientIdTests.kt @@ -0,0 +1,809 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.concurrent.Semaphore +import net.corda.core.CordaRuntimeException +import net.corda.core.flows.FlowLogic +import net.corda.core.internal.FlowIORequest +import net.corda.core.internal.FlowStateMachineHandle +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds +import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.node.InMemoryMessagingNetwork +import net.corda.testing.node.internal.DUMMY_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.FINANCE_CONTRACTS_CORDAPP +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.InternalMockNodeParameters +import net.corda.testing.node.internal.TestStartedNode +import net.corda.testing.node.internal.startFlow +import net.corda.testing.node.internal.startFlowWithClientId +import net.corda.core.flows.KilledFlowException +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.junit.After +import org.junit.Assert +import org.junit.Before +import org.junit.Test +import rx.Observable +import java.lang.IllegalArgumentException +import java.sql.SQLTransientConnectionException +import java.util.UUID +import java.util.concurrent.atomic.AtomicInteger +import kotlin.IllegalStateException +import kotlin.concurrent.thread +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class FlowClientIdTests { + + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: TestStartedNode + + @Before + fun setUpMockNet() { + mockNet = InternalMockNetwork( + cordappsForAllNodes = listOf(DUMMY_CONTRACTS_CORDAPP, FINANCE_CONTRACTS_CORDAPP), + servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin() + ) + + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + } + + @After + fun cleanUp() { + mockNet.stopNodes() + ResultFlow.hook = null + ResultFlow.suspendableHook = null + UnSerializableResultFlow.firstRun = true + SingleThreadedStateMachineManager.beforeClientIDCheck = null + SingleThreadedStateMachineManager.onClientIDNotFound = null + SingleThreadedStateMachineManager.onCallingStartFlowInternal = null + SingleThreadedStateMachineManager.onStartFlowInternalThrewAndAboutToRemove = null + + StaffedFlowHospital.onFlowErrorPropagated.clear() + } + + @Test(timeout = 300_000) + fun `no new flow starts if the client id provided pre exists`() { + var counter = 0 + ResultFlow.hook = { counter++ } + val clientId = UUID.randomUUID().toString() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + Assert.assertEquals(1, counter) + } + + @Test(timeout = 300_000) + fun `flow's result gets persisted if the flow is started with a client id`() { + val clientId = UUID.randomUUID().toString() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)).resultFuture.getOrThrow() + + aliceNode.database.transaction { + assertEquals(1, findRecordsFromDatabase().size) + } + } + + @Test(timeout = 300_000) + fun `flow's result is retrievable after flow's lifetime, when flow is started with a client id - different parameters are ignored`() { + val clientId = UUID.randomUUID().toString() + val handle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + val clientId0 = handle0.clientId + val flowId0 = handle0.id + val result0 = handle0.resultFuture.getOrThrow() + + val handle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)) + val clientId1 = handle1.clientId + val flowId1 = handle1.id + val result1 = handle1.resultFuture.getOrThrow() + + Assert.assertEquals(clientId0, clientId1) + Assert.assertEquals(flowId0, flowId1) + Assert.assertEquals(result0, result1) + } + + @Test(timeout = 300_000) + fun `if flow's result is not found in the database an IllegalStateException is thrown`() { + val clientId = UUID.randomUUID().toString() + val handle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + val flowId0 = handle0.id + handle0.resultFuture.getOrThrow() + + // manually remove the checkpoint (including DBFlowResult) from the database + aliceNode.database.transaction { + aliceNode.internals.checkpointStorage.removeCheckpoint(flowId0) + } + + assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + } + } + + @Test(timeout = 300_000) + fun `flow returning null gets retrieved after flow's lifetime when started with client id`() { + val clientId = UUID.randomUUID().toString() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(null)).resultFuture.getOrThrow() + + val flowResult = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(null)).resultFuture.getOrThrow() + assertNull(flowResult) + } + + @Test(timeout = 300_000) + fun `flow returning Unit gets retrieved after flow's lifetime when started with client id`() { + val clientId = UUID.randomUUID().toString() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(Unit)).resultFuture.getOrThrow() + + val flowResult = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(Unit)).resultFuture.getOrThrow() + assertEquals(Unit, flowResult) + } + + @Test(timeout = 300_000) + fun `flow's result is available if reconnect after flow had retried from previous checkpoint, when flow is started with a client id`() { + var firstRun = true + ResultFlow.hook = { + if (firstRun) { + firstRun = false + throw SQLTransientConnectionException("connection is not available") + } + } + + val clientId = UUID.randomUUID().toString() + val result0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + val result1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + Assert.assertEquals(result0, result1) + } + + @Test(timeout = 300_000) + fun `flow's result is available if reconnect during flow's retrying from previous checkpoint, when flow is started with a client id`() { + var firstRun = true + val waitForSecondRequest = Semaphore(0) + val waitUntilFlowHasRetried = Semaphore(0) + ResultFlow.suspendableHook = object : FlowLogic() { + @Suspendable + override fun call() { + if (firstRun) { + firstRun = false + throw SQLTransientConnectionException("connection is not available") + } else { + waitUntilFlowHasRetried.release() + waitForSecondRequest.acquire() + } + } + } + + var result1 = 0 + val clientId = UUID.randomUUID().toString() + val handle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + waitUntilFlowHasRetried.acquire() + val t = thread { result1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() } + + Thread.sleep(1000) + waitForSecondRequest.release() + val result0 = handle0.resultFuture.getOrThrow() + t.join() + Assert.assertEquals(result0, result1) + } + + @Test(timeout = 300_000) + fun `failing flow's exception is available after flow's lifetime if flow is started with a client id`() { + var counter = 0 + ResultFlow.hook = { + counter++ + throw IllegalStateException() + } + val clientId = UUID.randomUUID().toString() + + var flowHandle0: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle0!!.resultFuture.getOrThrow() + } + + var flowHandle1: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle1!!.resultFuture.getOrThrow() + } + + // Assert no new flow has started + assertEquals(flowHandle0!!.id, flowHandle1!!.id) + assertEquals(1, counter) + } + + @Test(timeout = 300_000) + fun `failed flow's exception is available after flow's lifetime on node start if flow was started with a client id`() { + var counter = 0 + ResultFlow.hook = { + counter++ + throw IllegalStateException() + } + val clientId = UUID.randomUUID().toString() + + var flowHandle0: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle0!!.resultFuture.getOrThrow() + } + + aliceNode = mockNet.restartNode(aliceNode) + + var flowHandle1: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle1!!.resultFuture.getOrThrow() + } + + // Assert no new flow has started + assertEquals(flowHandle0!!.id, flowHandle1!!.id) + assertEquals(1, counter) + } + + @Test(timeout = 300_000) + fun `killing a flow, removes the flow from the client id mapping`() { + var counter = 0 + val flowIsRunning = Semaphore(0) + val waitUntilFlowIsRunning = Semaphore(0) + ResultFlow.suspendableHook = object : FlowLogic() { + var firstRun = true + + @Suspendable + override fun call() { + ++counter + if (firstRun) { + firstRun = false + waitUntilFlowIsRunning.release() + flowIsRunning.acquire() + } + } + } + val clientId = UUID.randomUUID().toString() + + var flowHandle0: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + waitUntilFlowIsRunning.acquire() + aliceNode.internals.smm.killFlow(flowHandle0!!.id) + flowIsRunning.release() + flowHandle0!!.resultFuture.getOrThrow() + } + + // a new flow will start since the client id mapping was removed when flow got killed + val flowHandle1: FlowStateMachineHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle1.resultFuture.getOrThrow() + + assertNotEquals(flowHandle0!!.id, flowHandle1.id) + assertEquals(2, counter) + } + + @Test(timeout = 300_000) + fun `flow's client id mapping gets removed upon request`() { + val clientId = UUID.randomUUID().toString() + var counter = 0 + ResultFlow.hook = { counter++ } + val flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle0.resultFuture.getOrThrow(20.seconds) + val removed = aliceNode.smm.removeClientId(clientId) + // On new request with clientId, after the same clientId was removed, a brand new flow will start with that clientId + val flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle1.resultFuture.getOrThrow(20.seconds) + + assertTrue(removed) + Assert.assertNotEquals(flowHandle0.id, flowHandle1.id) + Assert.assertEquals(flowHandle0.clientId, flowHandle1.clientId) + Assert.assertEquals(2, counter) + } + + @Test(timeout = 300_000) + fun `removing a client id result clears resources properly`() { + val clientId = UUID.randomUUID().toString() + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + // assert database status before remove + aliceNode.services.database.transaction { + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + } + + aliceNode.smm.removeClientId(clientId) + + // assert database status after remove + aliceNode.services.database.transaction { + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + } + } + + @Test(timeout=300_000) + fun `removing a client id exception clears resources properly`() { + val clientId = UUID.randomUUID().toString() + ResultFlow.hook = { throw IllegalStateException() } + assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(Unit)).resultFuture.getOrThrow() + } + // assert database status before remove + aliceNode.services.database.transaction { + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + } + + aliceNode.smm.removeClientId(clientId) + + // assert database status after remove + aliceNode.services.database.transaction { + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + } + } + + @Test(timeout=300_000) + fun `flow's client id mapping can only get removed once the flow gets removed`() { + val clientId = UUID.randomUUID().toString() + var tries = 0 + val maxTries = 10 + var failedRemovals = 0 + val semaphore = Semaphore(0) + ResultFlow.suspendableHook = object : FlowLogic() { + @Suspendable + override fun call() { + semaphore.acquire() + } + } + val flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + + var removed = false + while (!removed) { + removed = aliceNode.smm.removeClientId(clientId) + if (!removed) ++failedRemovals + ++tries + if (tries >= maxTries) { + semaphore.release() + flowHandle0.resultFuture.getOrThrow(20.seconds) + } + } + + assertTrue(removed) + Assert.assertEquals(maxTries, failedRemovals) + } + + @Test(timeout = 300_000) + fun `only one flow starts upon concurrent requests with the same client id`() { + val requests = 2 + val counter = AtomicInteger(0) + val resultsCounter = AtomicInteger(0) + ResultFlow.hook = { counter.incrementAndGet() } + //(aliceNode.smm as SingleThreadedStateMachineManager).concurrentRequests = true + + val clientId = UUID.randomUUID().toString() + val threads = arrayOfNulls(requests) + for (i in 0 until requests) { + threads[i] = Thread { + val result = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + resultsCounter.addAndGet(result) + } + } + + val beforeCount = AtomicInteger(0) + SingleThreadedStateMachineManager.beforeClientIDCheck = { + beforeCount.incrementAndGet() + } + + val clientIdNotFound = Semaphore(0) + val waitUntilClientIdNotFound = Semaphore(0) + SingleThreadedStateMachineManager.onClientIDNotFound = { + // Only the first request should reach this point + waitUntilClientIdNotFound.release() + clientIdNotFound.acquire() + } + + for (i in 0 until requests) { + threads[i]!!.start() + } + + waitUntilClientIdNotFound.acquire() + for (i in 0 until requests) { + clientIdNotFound.release() + } + + for (thread in threads) { + thread!!.join() + } + Assert.assertEquals(1, counter.get()) + Assert.assertEquals(2, beforeCount.get()) + Assert.assertEquals(10, resultsCounter.get()) + } + + @Test(timeout = 300_000) + fun `on node start -running- flows with client id are hook-able`() { + val clientId = UUID.randomUUID().toString() + var firstRun = true + val flowIsRunning = Semaphore(0) + val waitUntilFlowIsRunning = Semaphore(0) + + ResultFlow.suspendableHook = object : FlowLogic() { + @Suspendable + override fun call() { + waitUntilFlowIsRunning.release() + + if (firstRun) { + firstRun = false + // high sleeping time doesn't matter because the fiber will get an [Event.SoftShutdown] on node restart, which will wake up the fiber + sleep(100.seconds, maySkipCheckpoint = true) + } + + flowIsRunning.acquire() // make flow wait here to impersonate a running flow + } + } + + val flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + waitUntilFlowIsRunning.acquire() + val aliceNode = mockNet.restartNode(aliceNode) + + waitUntilFlowIsRunning.acquire() + // Re-hook a running flow + val flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowIsRunning.release() + + Assert.assertEquals(flowHandle0.id, flowHandle1.id) + Assert.assertEquals(clientId, flowHandle1.clientId) + Assert.assertEquals(5, flowHandle1.resultFuture.getOrThrow(20.seconds)) + } + + // the below test has to be made available only in ENT +// @Test(timeout=300_000) +// fun `on node restart -paused- flows with client id are hook-able`() { +// val clientId = UUID.randomUUID().toString() +// var noSecondFlowWasSpawned = 0 +// var firstRun = true +// var firstFiber: Fiber? = null +// val flowIsRunning = Semaphore(0) +// val waitUntilFlowIsRunning = Semaphore(0) +// +// ResultFlow.suspendableHook = object : FlowLogic() { +// @Suspendable +// override fun call() { +// if (firstRun) { +// firstFiber = Fiber.currentFiber() +// firstRun = false +// } +// +// waitUntilFlowIsRunning.release() +// try { +// flowIsRunning.acquire() // make flow wait here to impersonate a running flow +// } catch (e: InterruptedException) { +// flowIsRunning.release() +// throw e +// } +// +// noSecondFlowWasSpawned++ +// } +// } +// +// val flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) +// waitUntilFlowIsRunning.acquire() +// aliceNode.internals.acceptableLiveFiberCountOnStop = 1 +// // Pause the flow on node restart +// val aliceNode = mockNet.restartNode(aliceNode, +// InternalMockNodeParameters( +// configOverrides = { +// doReturn(StateMachineManager.StartMode.Safe).whenever(it).smmStartMode +// } +// )) +// // Blow up the first fiber running our flow as it is leaked here, on normal node shutdown that fiber should be gone +// firstFiber!!.interrupt() +// +// // Re-hook a paused flow +// val flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) +// +// Assert.assertEquals(flowHandle0.id, flowHandle1.id) +// Assert.assertEquals(clientId, flowHandle1.clientId) +// aliceNode.smm.unPauseFlow(flowHandle1.id) +// Assert.assertEquals(5, flowHandle1.resultFuture.getOrThrow(20.seconds)) +// Assert.assertEquals(1, noSecondFlowWasSpawned) +// } + + @Test(timeout = 300_000) + fun `on node start -completed- flows with client id are hook-able`() { + val clientId = UUID.randomUUID().toString() + var counter = 0 + ResultFlow.hook = { + counter++ + } + + val flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle0.resultFuture.getOrThrow() + val aliceNode = mockNet.restartNode(aliceNode) + + // Re-hook a completed flow + val flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + val result1 = flowHandle1.resultFuture.getOrThrow(20.seconds) + + Assert.assertEquals(1, counter) // assert flow has run only once + Assert.assertEquals(flowHandle0.id, flowHandle1.id) + Assert.assertEquals(clientId, flowHandle1.clientId) + Assert.assertEquals(5, result1) + } + + @Test(timeout = 300_000) + fun `On 'startFlowInternal' throwing, subsequent request with same client id does not get de-duplicated and starts a new flow`() { + val clientId = UUID.randomUUID().toString() + var firstRequest = true + SingleThreadedStateMachineManager.onCallingStartFlowInternal = { + if (firstRequest) { + firstRequest = false + throw IllegalStateException("Yet another one") + } + } + var counter = 0 + ResultFlow.hook = { counter++ } + + assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + } + + val flowHandle1 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle1.resultFuture.getOrThrow(20.seconds) + + assertEquals(clientId, flowHandle1.clientId) + assertEquals(1, counter) + } + + // the below test has to be made available only in ENT +// @Test(timeout=300_000) +// fun `On 'startFlowInternal' throwing, subsequent request with same client hits the time window in which the previous request was about to remove the client id mapping`() { +// val clientId = UUID.randomUUID().toString() +// var firstRequest = true +// SingleThreadedStateMachineManager.onCallingStartFlowInternal = { +// if (firstRequest) { +// firstRequest = false +// throw IllegalStateException("Yet another one") +// } +// } +// +// val wait = Semaphore(0) +// val waitForFirstRequest = Semaphore(0) +// SingleThreadedStateMachineManager.onStartFlowInternalThrewAndAboutToRemove = { +// waitForFirstRequest.release() +// wait.acquire() +// Thread.sleep(10000) +// } +// var counter = 0 +// ResultFlow.hook = { counter++ } +// +// thread { +// assertFailsWith { +// aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) +// } +// } +// +// waitForFirstRequest.acquire() +// wait.release() +// assertFailsWith { +// // the subsequent request will not hang on a never ending future, because the previous request ,upon failing, will also complete the future exceptionally +// aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) +// } +// +// assertEquals(0, counter) +// } + + @Test(timeout = 300_000) + fun `if flow fails to serialize its result then the result gets converted to an exception result`() { + val clientId = UUID.randomUUID().toString() + assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow>(Observable.empty())).resultFuture.getOrThrow() + } + + // flow has failed to serialize its result => table 'node_flow_results' should be empty, 'node_flow_exceptions' should get one row instead + aliceNode.services.database.transaction { + val checkpointStatus = findRecordsFromDatabase().single().status + assertEquals(Checkpoint.FlowStatus.FAILED, checkpointStatus) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(1, findRecordsFromDatabase().size) + } + + assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow>(Observable.empty())).resultFuture.getOrThrow() + } + } + + /** + * The below test does not follow a valid path. Normally it should error and propagate. + * However, we want to assert that a flow that fails to serialize its result its retriable. + */ + @Test(timeout = 300_000) + fun `flow failing to serialize its result gets retried and succeeds if returning a different result`() { + val clientId = UUID.randomUUID().toString() + // before the hospital schedules a [Event.Error] we manually schedule a [Event.RetryFlowFromSafePoint] + StaffedFlowHospital.onFlowErrorPropagated.add { _, _ -> + FlowStateMachineImpl.currentStateMachine()!!.scheduleEvent(Event.RetryFlowFromSafePoint) + } + val result = aliceNode.services.startFlowWithClientId(clientId, UnSerializableResultFlow()).resultFuture.getOrThrow() + assertEquals(5, result) + } + + @Test(timeout = 300_000) + fun `flow that fails does not retain its checkpoint nor its exception in the database if not started with a client id`() { + assertFailsWith { + aliceNode.services.startFlow(ExceptionFlow { IllegalStateException("another exception") }).resultFuture.getOrThrow() + } + + aliceNode.services.database.transaction { + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + assertEquals(0, findRecordsFromDatabase().size) + } + } + + @Test(timeout = 300_000) + fun `subsequent request to failed flow that cannot find a 'DBFlowException' in the database, fails with 'IllegalStateException'`() { + ResultFlow.hook = { + // just throwing a different exception from the one expected out of startFlowWithClientId second call below ([IllegalStateException]) + // to be sure [IllegalStateException] gets thrown from [DBFlowException] that is missing + throw IllegalArgumentException() + } + val clientId = UUID.randomUUID().toString() + + var flowHandle0: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle0 = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle0!!.resultFuture.getOrThrow() + } + + // manually remove [DBFlowException] from the database to impersonate missing [DBFlowException] + val removed = aliceNode.services.database.transaction { + aliceNode.internals.checkpointStorage.removeFlowException(flowHandle0!!.id) + } + assertTrue(removed) + + val e = assertFailsWith { + aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)).resultFuture.getOrThrow() + } + + assertEquals("Flow's ${flowHandle0!!.id} exception was not found in the database. Something is very wrong.", e.message) + } + + @Test(timeout=300_000) + fun `completed flow started with a client id nulls its flow state in database after its lifetime`() { + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle.resultFuture.getOrThrow() + + aliceNode.services.database.transaction { + val dbFlowCheckpoint = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowHandle.id) + assertNull(dbFlowCheckpoint!!.blob!!.flowStack) + } + } + + @Test(timeout=300_000) + fun `failed flow started with a client id nulls its flow state in database after its lifetime`() { + val clientId = UUID.randomUUID().toString() + ResultFlow.hook = { throw IllegalStateException() } + + var flowHandle: FlowStateMachineHandle? = null + assertFailsWith { + flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(5)) + flowHandle!!.resultFuture.getOrThrow() + } + + aliceNode.services.database.transaction { + val dbFlowCheckpoint = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowHandle!!.id) + assertNull(dbFlowCheckpoint!!.blob!!.flowStack) + } + } + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve existing flow future`() { + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)) + val reattachedFlowHandle = aliceNode.smm.reattachFlowWithClientId(clientId) + + assertEquals(10, flowHandle.resultFuture.getOrThrow(20.seconds)) + assertEquals(clientId, flowHandle.clientId) + assertEquals(flowHandle.id, reattachedFlowHandle?.id) + assertEquals(flowHandle.resultFuture.get(), reattachedFlowHandle?.resultFuture?.get()) + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve a null result from a flow future`() { + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(null)) + val reattachedFlowHandle = aliceNode.smm.reattachFlowWithClientId(clientId) + + assertEquals(null, flowHandle.resultFuture.getOrThrow(20.seconds)) + assertEquals(clientId, flowHandle.clientId) + assertEquals(flowHandle.id, reattachedFlowHandle?.id) + assertEquals(flowHandle.resultFuture.get(), reattachedFlowHandle?.resultFuture?.get()) + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve result from completed flow`() { + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)) + + assertEquals(10, flowHandle.resultFuture.getOrThrow(20.seconds)) + assertEquals(clientId, flowHandle.clientId) + + val reattachedFlowHandle = aliceNode.smm.reattachFlowWithClientId(clientId) + + assertEquals(flowHandle.id, reattachedFlowHandle?.id) + assertEquals(flowHandle.resultFuture.get(), reattachedFlowHandle?.resultFuture?.get()) + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId returns null if no flow matches the client id`() { + assertEquals(null, aliceNode.smm.reattachFlowWithClientId(UUID.randomUUID().toString())) + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve exception from existing flow future`() { + ResultFlow.hook = { throw IllegalStateException("Bla bla bla") } + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)) + val reattachedFlowHandle = aliceNode.smm.reattachFlowWithClientId(clientId) + + assertThatExceptionOfType(IllegalStateException::class.java).isThrownBy { + flowHandle.resultFuture.getOrThrow(20.seconds) + }.withMessage("Bla bla bla") + + assertThatExceptionOfType(IllegalStateException::class.java).isThrownBy { + reattachedFlowHandle?.resultFuture?.getOrThrow() + }.withMessage("Bla bla bla") + } + + @Test(timeout = 300_000) + fun `reattachFlowWithClientId can retrieve exception from completed flow`() { + ResultFlow.hook = { throw IllegalStateException("Bla bla bla") } + val clientId = UUID.randomUUID().toString() + val flowHandle = aliceNode.services.startFlowWithClientId(clientId, ResultFlow(10)) + + assertThatExceptionOfType(IllegalStateException::class.java).isThrownBy { + flowHandle.resultFuture.getOrThrow(20.seconds) + }.withMessage("Bla bla bla") + + val reattachedFlowHandle = aliceNode.smm.reattachFlowWithClientId(clientId) + + // [CordaRunTimeException] returned because [IllegalStateException] is not serializable + assertThatExceptionOfType(CordaRuntimeException::class.java).isThrownBy { + reattachedFlowHandle?.resultFuture?.getOrThrow() + }.withMessage("java.lang.IllegalStateException: Bla bla bla") + } +} + +internal class ResultFlow(private val result: A): FlowLogic() { + companion object { + var hook: (() -> Unit)? = null + var suspendableHook: FlowLogic? = null + } + + @Suspendable + override fun call(): A { + hook?.invoke() + suspendableHook?.let { subFlow(it) } + return result + } +} + +internal class UnSerializableResultFlow: FlowLogic() { + companion object { + var firstRun = true + } + + @Suspendable + override fun call(): Any { + stateMachine.suspend(FlowIORequest.ForceCheckpoint, false) + return if (firstRun) { + firstRun = false + Observable.empty() + } else { + 5 // serializable result + } + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 1967f9ff63..fd7c926d0d 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -30,6 +30,7 @@ import net.corda.core.internal.declaredField import net.corda.core.messaging.MessageRecipients import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.queryBy +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize @@ -63,7 +64,7 @@ import net.corda.testing.node.internal.InternalMockNodeParameters import net.corda.testing.node.internal.TestStartedNode import net.corda.testing.node.internal.getMessage import net.corda.testing.node.internal.startFlow -import org.apache.commons.lang3.exception.ExceptionUtils +import net.corda.testing.node.internal.startFlowWithClientId import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatIllegalArgumentException import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType @@ -74,7 +75,6 @@ import org.junit.Assert.assertNotEquals import org.junit.Assert.assertNotNull import org.junit.Assert.assertNull import org.junit.Before -import org.junit.Ignore import org.junit.Test import rx.Notification import rx.Observable @@ -82,9 +82,10 @@ import java.sql.SQLTransientConnectionException import java.time.Clock import java.time.Duration import java.time.Instant -import java.util.ArrayList +import java.util.UUID import java.util.concurrent.TimeoutException import java.util.function.Predicate +import kotlin.concurrent.thread import kotlin.reflect.KClass import kotlin.streams.toList import kotlin.test.assertFailsWith @@ -308,8 +309,7 @@ class FlowFrameworkTests { .withStackTraceContaining(ReceiveFlow::class.java.name) // Make sure the stack trace is that of the receiving flow .withStackTraceContaining("Received counter-flow exception from peer") bobNode.database.transaction { - val checkpoint = bobNode.internals.checkpointStorage.checkpoints().single() - assertEquals(Checkpoint.FlowStatus.FAILED, checkpoint.status) + assertThat(bobNode.internals.checkpointStorage.checkpoints()).isEmpty() } assertThat(receivingFiber.state).isEqualTo(Strand.State.WAITING) @@ -376,12 +376,11 @@ class FlowFrameworkTests { } } - // Ignoring test since completed flows are not currently keeping their checkpoints in the database - @Ignore @Test(timeout = 300_000) fun `Flow metadata finish time is set in database when the flow finishes`() { val terminationSignal = Semaphore(0) - val flow = aliceNode.services.startFlow(NoOpFlow(terminateUponSignal = terminationSignal)) + val clientId = UUID.randomUUID().toString() + val flow = aliceNode.services.startFlowWithClientId(clientId, NoOpFlow(terminateUponSignal = terminationSignal)) mockNet.waitQuiescent() aliceNode.database.transaction { val metadata = session.find(DBCheckpointStorage.DBFlowMetadata::class.java, flow.id.uuid.toString()) @@ -686,8 +685,12 @@ class FlowFrameworkTests { flowState = flowFiber!!.transientState.checkpoint.flowState if (firstExecution) { + firstExecution = false throw HospitalizeFlowException() } else { + // the below sleep should be removed once we fix : The thread's transaction executing StateMachineManager.start takes long + // and doesn't commit before flow starts running. + Thread.sleep(3000) dbCheckpointStatusBeforeSuspension = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status currentDBSession().clear() // clear session as Hibernate with fails with 'org.hibernate.NonUniqueObjectException' once it tries to save a DBFlowCheckpoint upon checkpoint inMemoryCheckpointStatusBeforeSuspension = flowFiber.transientState.checkpoint.status @@ -701,7 +704,7 @@ class FlowFrameworkTests { } assertFailsWith { - aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(30.seconds) // wait till flow gets hospitalized + aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(10.seconds) // wait till flow gets hospitalized } // flow is in hospital assertTrue(flowState is FlowState.Unstarted) @@ -712,11 +715,10 @@ class FlowFrameworkTests { assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) } // restart Node - flow will be loaded from checkpoint - firstExecution = false aliceNode = mockNet.restartNode(aliceNode) futureFiber.get().resultFuture.getOrThrow() // wait until the flow has completed // checkpoint states ,after flow retried, before and after suspension - assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, dbCheckpointStatusBeforeSuspension) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, dbCheckpointStatusBeforeSuspension) assertEquals(Checkpoint.FlowStatus.RUNNABLE, inMemoryCheckpointStatusBeforeSuspension) assertEquals(Checkpoint.FlowStatus.RUNNABLE, dbCheckpointStatusAfterSuspension) } @@ -734,8 +736,12 @@ class FlowFrameworkTests { flowState = flowFiber!!.transientState.checkpoint.flowState if (firstExecution) { + firstExecution = false throw HospitalizeFlowException() } else { + // the below sleep should be removed once we fix : The thread's transaction executing StateMachineManager.start takes long + // and doesn't commit before flow starts running. + Thread.sleep(3000) dbCheckpointStatus = aliceNode.internals.checkpointStorage.getCheckpoints().toList().single().second.status inMemoryCheckpointStatus = flowFiber.transientState.checkpoint.status @@ -744,7 +750,7 @@ class FlowFrameworkTests { } assertFailsWith { - aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(30.seconds) // wait till flow gets hospitalized + aliceNode.services.startFlow(SuspendingFlow()).resultFuture.getOrThrow(10.seconds) // wait till flow gets hospitalized } // flow is in hospital assertTrue(flowState is FlowState.Started) @@ -753,41 +759,13 @@ class FlowFrameworkTests { assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) } // restart Node - flow will be loaded from checkpoint - firstExecution = false aliceNode = mockNet.restartNode(aliceNode) futureFiber.get().resultFuture.getOrThrow() // wait until the flow has completed // checkpoint states ,after flow retried, after suspension - assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, dbCheckpointStatus) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, dbCheckpointStatus) assertEquals(Checkpoint.FlowStatus.RUNNABLE, inMemoryCheckpointStatus) } - // Upon implementing CORDA-3681 unignore this test; DBFlowException is not currently integrated - @Ignore - @Test(timeout=300_000) - fun `Checkpoint is updated in DB with FAILED status and the error when flow fails`() { - var flowId: StateMachineRunId? = null - - val e = assertFailsWith { - val fiber = aliceNode.services.startFlow(ExceptionFlow { FlowException("Just an exception") }) - flowId = fiber.id - fiber.resultFuture.getOrThrow() - } - - aliceNode.database.transaction { - val checkpoint = aliceNode.internals.checkpointStorage.checkpoints().single() - assertEquals(Checkpoint.FlowStatus.FAILED, checkpoint.status) - - // assert all fields of DBFlowException - val persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowId!!)!!.exceptionDetails - assertEquals(FlowException::class.java.name, persistedException!!.type) - assertEquals("Just an exception", persistedException.message) - assertEquals(ExceptionUtils.getStackTrace(e), persistedException.stackTrace) - assertEquals(null, persistedException.value) - } - } - - // Upon implementing CORDA-3681 unignore this test; DBFlowException is not currently integrated - @Ignore @Test(timeout=300_000) fun `Checkpoint is updated in DB with HOSPITALIZED status and the error when flow is kept for overnight observation` () { var flowId: StateMachineRunId? = null @@ -803,10 +781,13 @@ class FlowFrameworkTests { assertEquals(Checkpoint.FlowStatus.HOSPITALIZED, checkpoint.status) // assert all fields of DBFlowException - val persistedException = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowId!!)!!.exceptionDetails - assertEquals(HospitalizeFlowException::class.java.name, persistedException!!.type) - assertEquals("Overnight observation", persistedException.message) - assertEquals(null, persistedException.value) + val exceptionDetails = aliceNode.internals.checkpointStorage.getDBCheckpoint(flowId!!)!!.exceptionDetails + assertEquals(HospitalizeFlowException::class.java.name, exceptionDetails!!.type) + assertEquals("Overnight observation", exceptionDetails.message) + val deserializedException = exceptionDetails.value?.let { SerializedBytes(it) }?.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + assertNotNull(deserializedException) + val hospitalizeFlowException = deserializedException as HospitalizeFlowException + assertEquals("Overnight observation", hospitalizeFlowException.message) } } @@ -836,13 +817,84 @@ class FlowFrameworkTests { assertEquals(null, persistedException) } - private inline fun DatabaseTransaction.findRecordsFromDatabase(): List { - val criteria = session.criteriaBuilder.createQuery(T::class.java) - criteria.select(criteria.from(T::class.java)) - return session.createQuery(criteria).resultList + // When ported to ENT use the existing API there to properly retry the flow + @Test(timeout=300_000) + fun `Hospitalized flow, resets to 'RUNNABLE' and clears exception when retried`() { + var firstRun = true + var counter = 0 + val waitUntilHospitalizedTwice = Semaphore(-1) + + StaffedFlowHospital.onFlowKeptForOvernightObservation.add { _, _ -> + ++counter + if (firstRun) { + firstRun = false + val fiber = FlowStateMachineImpl.currentStateMachine()!! + thread { + // schedule a [RetryFlowFromSafePoint] after the [OvernightObservation] gets scheduled by the hospital + Thread.sleep(2000) + fiber.scheduleEvent(Event.RetryFlowFromSafePoint) + } + } + waitUntilHospitalizedTwice.release() + } + + var counterRes = 0 + StaffedFlowHospital.onFlowResuscitated.add { _, _, _ -> ++counterRes } + + aliceNode.services.startFlow(ExceptionFlow { HospitalizeFlowException("hospitalizing") }) + + waitUntilHospitalizedTwice.acquire() + assertEquals(2, counter) + assertEquals(0, counterRes) } - //region Helpers + @Test(timeout=300_000) + fun `Hospitalized flow, resets to 'RUNNABLE' and clears database exception on node start`() { + var checkpointStatusAfterRestart: Checkpoint.FlowStatus? = null + var dbExceptionAfterRestart: List? = null + + var secondRun = false + SuspendingFlow.hookBeforeCheckpoint = { + if(secondRun) { + // the below sleep should be removed once we fix : The thread's transaction executing StateMachineManager.start takes long + // and doesn't commit before flow starts running. + Thread.sleep(3000) + aliceNode.database.transaction { + checkpointStatusAfterRestart = findRecordsFromDatabase().single().status + dbExceptionAfterRestart = findRecordsFromDatabase() + } + } else { + secondRun = true + } + + throw HospitalizeFlowException("hospitalizing") + } + + var counter = 0 + val waitUntilHospitalized = Semaphore(0) + StaffedFlowHospital.onFlowKeptForOvernightObservation.add { _, _ -> + ++counter + waitUntilHospitalized.release() + } + + var counterRes = 0 + StaffedFlowHospital.onFlowResuscitated.add { _, _, _ -> ++counterRes } + + aliceNode.services.startFlow(SuspendingFlow()) + + waitUntilHospitalized.acquire() + Thread.sleep(3000) // wait until flow saves overnight observation state in database + aliceNode = mockNet.restartNode(aliceNode) + + + waitUntilHospitalized.acquire() + Thread.sleep(3000) // wait until flow saves overnight observation state in database + assertEquals(2, counter) + assertEquals(0, counterRes) + assertEquals(Checkpoint.FlowStatus.RUNNABLE, checkpointStatusAfterRestart) + assertEquals(0, dbExceptionAfterRestart!!.size) + } + //region Helpers private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) @@ -1027,6 +1079,12 @@ internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destina } } +inline fun DatabaseTransaction.findRecordsFromDatabase(): List { + val criteria = session.criteriaBuilder.createQuery(T::class.java) + criteria.select(criteria.from(T::class.java)) + return session.createQuery(criteria).resultList +} + internal fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) @@ -1207,7 +1265,7 @@ internal class SuspendingFlow : FlowLogic() { @Suspendable override fun call() { stateMachine.hookBeforeCheckpoint() - sleep(1.seconds) // flow checkpoints => checkpoint is in DB + stateMachine.suspend(FlowIORequest.ForceCheckpoint, maySkipCheckpoint = false) // flow checkpoints => checkpoint is in DB stateMachine.hookAfterCheckpoint() } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt index 8d6fbf6c0e..97fa69f6a5 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowMetadataRecordingTest.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable import net.corda.client.rpc.CordaRPCClient +import net.corda.core.CordaRuntimeException import net.corda.core.context.InvocationContext import net.corda.core.contracts.BelongsToContract import net.corda.core.contracts.LinearState @@ -47,12 +48,14 @@ import org.junit.Before import org.junit.Ignore import org.junit.Test import java.time.Instant +import java.util.UUID import java.util.concurrent.CompletableFuture import java.util.concurrent.Executors import java.util.concurrent.Semaphore import java.util.function.Supplier import kotlin.reflect.jvm.jvmName import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue @@ -90,9 +93,11 @@ class FlowMetadataRecordingTest { metadata = metadataFromHook } + val clientId = UUID.randomUUID().toString() CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { - it.proxy.startFlow( - ::MyFlow, + it.proxy.startFlowDynamicWithClientId( + clientId, + MyFlow::class.java, nodeBHandle.nodeInfo.singleIdentity(), string, someObject @@ -104,7 +109,7 @@ class FlowMetadataRecordingTest { assertEquals(flowId!!.uuid.toString(), it.flowId) assertEquals(MyFlow::class.java.name, it.flowName) // Should be changed when [userSuppliedIdentifier] gets filled in future changes - assertNull(it.userSuppliedIdentifier) + assertEquals(clientId, it.userSuppliedIdentifier) assertEquals(DBCheckpointStorage.StartReason.RPC, it.startType) assertEquals( listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), @@ -197,7 +202,7 @@ class FlowMetadataRecordingTest { assertEquals( listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), - uncheckedCast>(context!!.arguments[1]).toList() + uncheckedCast>(context!!.arguments!![1]).toList() ) assertEquals( listOf(nodeBHandle.nodeInfo.singleIdentity(), string, someObject), @@ -406,6 +411,19 @@ class FlowMetadataRecordingTest { } } + @Test(timeout = 300_000) + fun `assert that flow started with longer client id than MAX_CLIENT_ID_LENGTH fails`() { + val clientId = "1".repeat(513) // DBCheckpointStorage.MAX_CLIENT_ID_LENGTH == 512 + driver(DriverParameters(startNodesInProcess = true)) { + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)).getOrThrow() + val rpc = CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).proxy + + assertFailsWith("clientId cannot be longer than ${DBCheckpointStorage.MAX_CLIENT_ID_LENGTH} characters") { + rpc.startFlowDynamicWithClientId(clientId, EmptyFlow::class.java).returnValue.getOrThrow() + } + } + } + @InitiatingFlow @StartableByRPC @StartableByService @@ -566,4 +584,11 @@ class FlowMetadataRecordingTest { return ScheduledActivity(logicRef, Instant.now()) } } + + @StartableByRPC + class EmptyFlow : FlowLogic() { + @Suspendable + override fun call() { + } + } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt index ecaa28f0fe..aca41ff4e2 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt @@ -12,7 +12,6 @@ import net.corda.core.flows.KilledFlowException import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party -import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.concurrent.flatMap import net.corda.core.messaging.MessageRecipients import net.corda.core.utilities.UntrustworthyData @@ -156,7 +155,7 @@ class RetryFlowMockTest { // Make sure we have seen an update from the hospital, and thus the flow went there. val alice = TestIdentity(CordaX500Name.parse("L=London,O=Alice Ltd,OU=Trade,C=GB")).party val records = nodeA.smm.flowHospital.track().updates.toBlocking().toIterable().iterator() - val flow: FlowStateMachine = nodeA.services.startFlow(FinalityHandler(object : FlowSession() { + val flow = nodeA.services.startFlow(FinalityHandler(object : FlowSession() { override val destination: Destination get() = alice override val counterparty: Party get() = alice diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/matchers/flow/FlowMatchers.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/matchers/flow/FlowMatchers.kt index 42b27f10c9..abb1ee9ab9 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/matchers/flow/FlowMatchers.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/matchers/flow/FlowMatchers.kt @@ -2,23 +2,23 @@ package net.corda.coretesting.internal.matchers.flow import com.natpryce.hamkrest.Matcher import com.natpryce.hamkrest.equalTo -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.coretesting.internal.matchers.* /** * Matches a Flow that succeeds with a result matched by the given matcher */ -fun willReturn(): Matcher> = net.corda.coretesting.internal.matchers.future.willReturn() - .extrude(FlowStateMachine::resultFuture) +fun willReturn(): Matcher> = net.corda.coretesting.internal.matchers.future.willReturn() + .extrude(FlowStateMachineHandle::resultFuture) .redescribe { "is a flow that will return" } -fun willReturn(expected: T): Matcher> = willReturn(equalTo(expected)) +fun willReturn(expected: T): Matcher> = willReturn(equalTo(expected)) /** * Matches a Flow that succeeds with a result matched by the given matcher */ fun willReturn(successMatcher: Matcher) = net.corda.coretesting.internal.matchers.future.willReturn(successMatcher) - .extrude(FlowStateMachine::resultFuture) + .extrude(FlowStateMachineHandle::resultFuture) .redescribe { "is a flow that will return with a value that ${successMatcher.description}" } /** @@ -26,7 +26,7 @@ fun willReturn(successMatcher: Matcher) = net.corda.coretesting.internal. */ inline fun willThrow(failureMatcher: Matcher) = net.corda.coretesting.internal.matchers.future.willThrow(failureMatcher) - .extrude(FlowStateMachine<*>::resultFuture) + .extrude(FlowStateMachineHandle<*>::resultFuture) .redescribe { "is a flow that will fail, throwing an exception that ${failureMatcher.description}" } /** @@ -34,5 +34,5 @@ inline fun willThrow(failureMatcher: Matcher) = */ inline fun willThrow() = net.corda.coretesting.internal.matchers.future.willThrow() - .extrude(FlowStateMachine<*>::resultFuture) + .extrude(FlowStateMachineHandle<*>::resultFuture) .redescribe { "is a flow that will fail with an exception of type ${E::class.java.simpleName}" } \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt index 3c6de690bf..32cd878b88 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt @@ -7,7 +7,7 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.identity.CordaX500Name -import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.FlowStateMachineHandle import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.div @@ -269,7 +269,10 @@ class NodeListenProcessDeathException(hostAndPort: NetworkHostAndPort, listenPro """.trimIndent() ) -fun StartedNodeServices.startFlow(logic: FlowLogic): FlowStateMachine = startFlow(logic, newContext()).getOrThrow() +fun StartedNodeServices.startFlow(logic: FlowLogic): FlowStateMachineHandle = startFlow(logic, newContext()).getOrThrow() + +fun StartedNodeServices.startFlowWithClientId(clientId: String, logic: FlowLogic): FlowStateMachineHandle = + startFlow(logic, newContext().copy(clientId = clientId)).getOrThrow() fun StartedNodeServices.newContext(): InvocationContext = testContext(myInfo.chooseIdentity().name)