mirror of
https://github.com/corda/corda.git
synced 2025-03-14 00:06:45 +00:00
CORDA-1610: Retain progress tracker during flow retry - multithreaded SMM (#1049)
* Test to show progress tracker losing updates when a flow is retried. * Fix the injection logic - attach the new children to the old progress tracker
This commit is contained in:
parent
20c53a5a45
commit
75e30c8114
@ -18,8 +18,13 @@ import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.TimedFlow
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.messaging.startTrackedFlow
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_1
|
||||
import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_2
|
||||
import net.corda.node.services.TimedFlowMultiThreadedSMMTests.AbstractTimedFlow.Companion.STEP_3
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.core.DUMMY_NOTARY_NAME
|
||||
@ -32,10 +37,12 @@ import net.corda.testing.internal.IntegrationTest
|
||||
import net.corda.testing.internal.IntegrationTestSchemas
|
||||
import net.corda.testing.internal.toDatabaseSchemaName
|
||||
import net.corda.testing.node.User
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Before
|
||||
import org.junit.ClassRule
|
||||
import org.junit.Test
|
||||
import java.time.Duration
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
@ -78,6 +85,30 @@ class TimedFlowMultiThreadedSMMTests : IntegrationTest() {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `progress tracker is preserved after flow is retried`() {
|
||||
val user = User("test", "pwd", setOf(Permissions.startFlow<TimedInitiatorFlow>(), Permissions.startFlow<SuperFlow>()))
|
||||
driver(DriverParameters(isDebug = true, startNodesInProcess = true,
|
||||
portAllocation = RandomFree)) {
|
||||
|
||||
val configOverrides = mapOf("flowTimeout" to mapOf(
|
||||
"timeout" to Duration.ofSeconds(2),
|
||||
"maxRestartCount" to 2,
|
||||
"backoffBase" to 1.0
|
||||
))
|
||||
|
||||
val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), customOverrides = configOverrides).getOrThrow()
|
||||
val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow()
|
||||
|
||||
CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { rpc ->
|
||||
resetCounters()
|
||||
whenInvokedDirectlyAndTracked(rpc, nodeBHandle)
|
||||
assertEquals(2, invocationCount.get())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun whenInvokedDirectly(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) {
|
||||
rpc.proxy.startFlow(::TimedInitiatorFlow, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow()
|
||||
/* The TimedInitiatorFlow is expected to time out the first time, and succeed the second time. */
|
||||
@ -97,14 +128,49 @@ class TimedFlowMultiThreadedSMMTests : IntegrationTest() {
|
||||
}
|
||||
}
|
||||
|
||||
private fun whenInvokedDirectlyAndTracked(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) {
|
||||
val flowHandle = rpc.proxy.startTrackedFlow(::TimedInitiatorFlow, nodeBHandle.nodeInfo.singleIdentity())
|
||||
|
||||
val stepsCount = 4
|
||||
assertEquals(stepsCount, flowHandle.stepsTreeFeed!!.snapshot.size, "Expected progress tracker to return the last step")
|
||||
|
||||
val doneIndex = 3
|
||||
val doneIndexStepFromSnapshot = flowHandle.stepsTreeIndexFeed!!.snapshot
|
||||
val doneIndexFromUpdates = flowHandle.stepsTreeIndexFeed!!.updates.takeFirst { it == doneIndex }
|
||||
.timeout(5, TimeUnit.SECONDS).onErrorResumeNext(rx.Observable.empty()).toBlocking().singleOrDefault(0)
|
||||
// we got the last index either via snapshot or update
|
||||
assertThat(setOf(doneIndexStepFromSnapshot, doneIndexFromUpdates)).contains(doneIndex).withFailMessage("Expected the last step to be reached")
|
||||
|
||||
val doneLabel = "Done"
|
||||
val doneStep = flowHandle.progress.takeFirst { it == doneLabel }
|
||||
.timeout(5, TimeUnit.SECONDS).onErrorResumeNext(rx.Observable.empty()).toBlocking().singleOrDefault("")
|
||||
assertEquals(doneLabel, doneStep)
|
||||
|
||||
flowHandle.returnValue.getOrThrow()
|
||||
}
|
||||
|
||||
/** This abstract class is required to test that the progress tracker gets preserved after restart correctly. */
|
||||
abstract class AbstractTimedFlow(override val progressTracker: ProgressTracker) : FlowLogic<Unit>() {
|
||||
companion object {
|
||||
object STEP_1 : ProgressTracker.Step("Step 1")
|
||||
object STEP_2 : ProgressTracker.Step("Step 2")
|
||||
object STEP_3 : ProgressTracker.Step("Step 3")
|
||||
|
||||
fun tracker() = ProgressTracker(STEP_1, STEP_2, STEP_3)
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class TimedInitiatorFlow(private val other: Party) : FlowLogic<Unit>(), TimedFlow {
|
||||
class TimedInitiatorFlow(private val other: Party) : AbstractTimedFlow(tracker()), TimedFlow {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
progressTracker.currentStep = STEP_1
|
||||
invocationCount.incrementAndGet()
|
||||
progressTracker.currentStep = STEP_2
|
||||
val session = initiateFlow(other)
|
||||
session.sendAndReceive<Unit>(Unit)
|
||||
progressTracker.currentStep = STEP_3
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,7 @@ import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.services.statemachine.interceptors.*
|
||||
import net.corda.node.services.statemachine.transitions.StateMachine
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.node.utilities.injectOldProgressTracker
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.serialization.internal.SerializeAsTokenContextImpl
|
||||
import net.corda.serialization.internal.withTokenContext
|
||||
@ -377,7 +378,10 @@ class MultiThreadedStateMachineManager(
|
||||
for (sessionId in getFlowSessionIds(currentState.checkpoint)) {
|
||||
sessionToFlow.remove(sessionId)
|
||||
}
|
||||
if (flow != null) addAndStartFlow(flowId, flow)
|
||||
if (flow != null) {
|
||||
injectOldProgressTracker(currentState.flowLogic.progressTracker, flow.fiber.logic)
|
||||
addAndStartFlow(flowId, flow)
|
||||
}
|
||||
// Deliver all the external events from the old flow instance.
|
||||
val unprocessedExternalEvents = mutableListOf<ExternalEvent>()
|
||||
do {
|
||||
|
@ -1,3 +1,13 @@
|
||||
/*
|
||||
* R3 Proprietary and Confidential
|
||||
*
|
||||
* Copyright (c) 2018 R3 Limited. All rights reserved.
|
||||
*
|
||||
* The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law.
|
||||
*
|
||||
* Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited.
|
||||
*/
|
||||
|
||||
package net.corda.node.utilities
|
||||
|
||||
import net.corda.core.flows.FlowLogic
|
||||
|
Loading…
x
Reference in New Issue
Block a user