diff --git a/node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowMultiThreadedSMMTests.kt similarity index 56% rename from node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt rename to node/src/integration-test/kotlin/net/corda/node/services/TimedFlowMultiThreadedSMMTests.kt index 442f9c0e8a..3541682d2a 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowMultiThreadedSMMTests.kt @@ -18,8 +18,13 @@ import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.TimedFlow import net.corda.core.messaging.startFlow +import net.corda.core.messaging.startTrackedFlow +import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap +import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_1 +import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_2 +import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_3 import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.DUMMY_NOTARY_NAME @@ -32,10 +37,12 @@ import net.corda.testing.internal.IntegrationTest import net.corda.testing.internal.IntegrationTestSchemas import net.corda.testing.internal.toDatabaseSchemaName import net.corda.testing.node.User +import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.ClassRule import org.junit.Test import java.time.Duration +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import kotlin.test.assertEquals @@ -78,6 +85,30 @@ class TimedFlowMultiThreadedSMMTests : IntegrationTest() { } } + @Test + fun `progress tracker is preserved after flow is retried`() { + val user = User("test", "pwd", setOf(Permissions.startFlow(), Permissions.startFlow())) + driver(DriverParameters(isDebug = true, startNodesInProcess = true, + portAllocation = RandomFree)) { + + val configOverrides = mapOf("flowTimeout" to mapOf( + "timeout" to Duration.ofSeconds(2), + "maxRestartCount" to 2, + "backoffBase" to 1.0 + )) + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), customOverrides = configOverrides).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { rpc -> + resetCounters() + whenInvokedDirectlyAndTracked(rpc, nodeBHandle) + assertEquals(2, invocationCount.get()) + } + } + } + + private fun whenInvokedDirectly(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) { rpc.proxy.startFlow(::TimedInitiatorFlow, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow() /* The TimedInitiatorFlow is expected to time out the first time, and succeed the second time. */ @@ -97,14 +128,49 @@ class TimedFlowMultiThreadedSMMTests : IntegrationTest() { } } + private fun whenInvokedDirectlyAndTracked(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) { + val flowHandle = rpc.proxy.startTrackedFlow(::TimedInitiatorFlow, nodeBHandle.nodeInfo.singleIdentity()) + + val stepsCount = 4 + assertEquals(stepsCount, flowHandle.stepsTreeFeed!!.snapshot.size, "Expected progress tracker to return the last step") + + val doneIndex = 3 + val doneIndexStepFromSnapshot = flowHandle.stepsTreeIndexFeed!!.snapshot + val doneIndexFromUpdates = flowHandle.stepsTreeIndexFeed!!.updates.takeFirst { it == doneIndex } + .timeout(5, TimeUnit.SECONDS).onErrorResumeNext(rx.Observable.empty()).toBlocking().singleOrDefault(0) + // we got the last index either via snapshot or update + assertThat(setOf(doneIndexStepFromSnapshot, doneIndexFromUpdates)).contains(doneIndex).withFailMessage("Expected the last step to be reached") + + val doneLabel = "Done" + val doneStep = flowHandle.progress.takeFirst { it == doneLabel } + .timeout(5, TimeUnit.SECONDS).onErrorResumeNext(rx.Observable.empty()).toBlocking().singleOrDefault("") + assertEquals(doneLabel, doneStep) + + flowHandle.returnValue.getOrThrow() + } + + /** This abstract class is required to test that the progress tracker gets preserved after restart correctly. */ + abstract class AbstractTimedFlow(override val progressTracker: ProgressTracker) : FlowLogic() { + companion object { + object STEP_1 : ProgressTracker.Step("Step 1") + object STEP_2 : ProgressTracker.Step("Step 2") + object STEP_3 : ProgressTracker.Step("Step 3") + + fun tracker() = ProgressTracker(STEP_1, STEP_2, STEP_3) + } + } + @StartableByRPC @InitiatingFlow - class TimedInitiatorFlow(private val other: Party) : FlowLogic(), TimedFlow { + class TimedInitiatorFlow(private val other: Party) : AbstractTimedFlow(tracker()), TimedFlow { @Suspendable override fun call() { + progressTracker.currentStep = STEP_1 invocationCount.incrementAndGet() + progressTracker.currentStep = STEP_2 val session = initiateFlow(other) session.sendAndReceive(Unit) + progressTracker.currentStep = STEP_3 } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt index 6dd1766e7c..2d3816fc8c 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt @@ -43,6 +43,7 @@ import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.node.utilities.AffinityExecutor +import net.corda.node.utilities.injectOldProgressTracker import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.serialization.internal.SerializeAsTokenContextImpl import net.corda.serialization.internal.withTokenContext @@ -377,7 +378,10 @@ class MultiThreadedStateMachineManager( for (sessionId in getFlowSessionIds(currentState.checkpoint)) { sessionToFlow.remove(sessionId) } - if (flow != null) addAndStartFlow(flowId, flow) + if (flow != null) { + injectOldProgressTracker(currentState.flowLogic.progressTracker, flow.fiber.logic) + addAndStartFlow(flowId, flow) + } // Deliver all the external events from the old flow instance. val unprocessedExternalEvents = mutableListOf() do { diff --git a/node/src/main/kotlin/net/corda/node/utilities/StateMachineManagerUtils.kt b/node/src/main/kotlin/net/corda/node/utilities/StateMachineManagerUtils.kt index d4cdf6b2d7..e7bbdfb15b 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/StateMachineManagerUtils.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/StateMachineManagerUtils.kt @@ -1,3 +1,13 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + package net.corda.node.utilities import net.corda.core.flows.FlowLogic