diff --git a/core/src/main/kotlin/net/corda/core/internal/IdempotentFlow.kt b/core/src/main/kotlin/net/corda/core/internal/IdempotentFlow.kt index 22d7f1c681..3e4c4f6fce 100644 --- a/core/src/main/kotlin/net/corda/core/internal/IdempotentFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/IdempotentFlow.kt @@ -1,3 +1,13 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + package net.corda.core.internal /** diff --git a/node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt new file mode 100644 index 0000000000..b15fe3eb19 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/MultiThreadedTimedFlowTests.kt @@ -0,0 +1,123 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. 
+ */ + +package net.corda.node.services + +import co.paralleluniverse.fibers.Suspendable +import net.corda.client.rpc.CordaRPCClient +import net.corda.client.rpc.CordaRPCConnection +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.internal.TimedFlow +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.NodeHandle +import net.corda.testing.driver.driver +import net.corda.testing.driver.internal.RandomFree +import net.corda.testing.internal.IntegrationTest +import net.corda.testing.internal.IntegrationTestSchemas +import net.corda.testing.internal.toDatabaseSchemaName +import net.corda.testing.node.User +import org.junit.Before +import org.junit.ClassRule +import org.junit.Test +import java.time.Duration +import java.util.concurrent.atomic.AtomicInteger +import kotlin.test.assertEquals + +class TimedFlowMultiThreadedSMMTests : IntegrationTest() { + companion object { + @ClassRule + @JvmField + val databaseSchemas = IntegrationTestSchemas(ALICE_NAME.toDatabaseSchemaName(), BOB_NAME.toDatabaseSchemaName(), DUMMY_NOTARY_NAME.toDatabaseSchemaName()) + + val requestCount = AtomicInteger(0) + val invocationCount = AtomicInteger(0) + } + + @Before + fun resetCounters() { + requestCount.set(0) + invocationCount.set(0) + } + + @Test + fun `timed flow is retried`() { + val user = User("test", "pwd", setOf(Permissions.startFlow<TimedInitiatorFlow>(), Permissions.startFlow<SuperFlow>())) + driver(DriverParameters(isDebug = true, startNodesInProcess = true, + portAllocation = RandomFree)) { + + val configOverrides = mapOf("p2pMessagingRetry" to mapOf( + 
"messageRedeliveryDelay" to Duration.ofSeconds(1), + "maxRetryCount" to 2, + "backoffBase" to 1.0 + )) + + val nodeAHandle = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), customOverrides = configOverrides).getOrThrow() + val nodeBHandle = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)).getOrThrow() + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { rpc -> + whenInvokedDirectly(rpc, nodeBHandle) + resetCounters() + whenInvokedAsSubflow(rpc, nodeBHandle) + } + } + } + + private fun whenInvokedDirectly(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) { + rpc.proxy.startFlow(::TimedInitiatorFlow, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow() + /* The TimedInitiatorFlow is expected to time out the first time, and succeed the second time. */ + assertEquals(2, invocationCount.get()) + } + + private fun whenInvokedAsSubflow(rpc: CordaRPCConnection, nodeBHandle: NodeHandle) { + rpc.proxy.startFlow(::SuperFlow, nodeBHandle.nodeInfo.singleIdentity()).returnValue.getOrThrow() + assertEquals(2, invocationCount.get()) + } + + @StartableByRPC + class SuperFlow(private val other: Party) : FlowLogic<Unit>() { + @Suspendable + override fun call() { + subFlow(TimedInitiatorFlow(other)) + } + } + + @StartableByRPC + @InitiatingFlow + class TimedInitiatorFlow(private val other: Party) : FlowLogic<Unit>(), TimedFlow { + @Suspendable + override fun call() { + invocationCount.incrementAndGet() + val session = initiateFlow(other) + session.sendAndReceive<Unit>(Unit) + } + } + + @InitiatedBy(TimedInitiatorFlow::class) + class InitiatedFlow(val session: FlowSession) : FlowLogic<Any>() { + @Suspendable + override fun call() { + val value = session.receive<Unit>().unwrap { } + if (TimedFlowMultiThreadedSMMTests.requestCount.getAndIncrement() == 0) { + waitForLedgerCommit(SecureHash.randomSHA256()) + } else { + session.send(value) + } + } + } +} \ No newline at end of file diff --git 
a/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowTests.kt index 180d2024af..7ac0f1a59e 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/TimedFlowTests.kt @@ -1,3 +1,13 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + package net.corda.node.services import co.paralleluniverse.fibers.Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt index 40077b495c..4e06845d80 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt @@ -51,9 +51,7 @@ import rx.Observable import rx.subjects.PublishSubject import java.security.SecureRandom import java.util.* -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.ExecutorService -import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.* import java.util.concurrent.locks.ReentrantReadWriteLock import javax.annotation.concurrent.ThreadSafe import kotlin.collections.ArrayList @@ -80,6 +78,14 @@ class MultiThreadedStateMachineManager( } private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>) + + private data class ScheduledTimeout( + /** Will fire a [FlowTimeoutException] indicating to the flow hospital to restart the flow. 
*/ + val scheduledFuture: ScheduledFuture<*>, + /** Specifies the number of times this flow has been retried. */ + val retryCount: Int = 0 + ) + private enum class State { UNSTARTED, STARTED, @@ -92,11 +98,15 @@ class MultiThreadedStateMachineManager( val flows = ConcurrentHashMap<StateMachineRunId, Flow>() val startedFutures = ConcurrentHashMap<StateMachineRunId, OpenFuture<Unit>>() val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!! + /** Flows scheduled to be retried if not finished within the specified timeout period. */ + val timedFlows = ConcurrentHashMap<StateMachineRunId, ScheduledTimeout>() } private val concurrentBox = ConcurrentBox(InnerState()) private val scheduler = FiberExecutorScheduler("Flow fiber scheduler", executor) + private val timeoutScheduler = AffinityExecutor.ServiceAffinityExecutor("Flow timeout scheduler", 1) + // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. private val liveFibers = ReusableLatch() // Monitoring support. @@ -212,6 +222,7 @@ class MultiThreadedStateMachineManager( override fun killFlow(id: StateMachineRunId): Boolean { concurrentBox.concurrent { + cancelTimeoutIfScheduled(id) val flow = flows.remove(id) if (flow != null) { logger.debug("Killing flow known to physical node.") @@ -263,6 +274,7 @@ class MultiThreadedStateMachineManager( override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) { concurrentBox.concurrent { + cancelTimeoutIfScheduled(flowId) val flow = flows.remove(flowId) if (flow != null) { decrementLiveFibers() @@ -427,7 +439,7 @@ class MultiThreadedStateMachineManager( "unknown session $recipientId, discarding..." 
} } else { - throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId") + logger.warn("Cannot find flow corresponding to session ID $recipientId.") } } else { val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") @@ -533,7 +545,7 @@ class MultiThreadedStateMachineManager( flowLogic.stateMachine = flowStateMachineImpl val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) - val flowCorDappVersion= FlowStateMachineImpl.createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) + val flowCorDappVersion = FlowStateMachineImpl.createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed, flowCorDappVersion).getOrThrow() val startedFuture = openFuture<Unit>() val initialState = StateMachineState( @@ -556,6 +568,59 @@ class MultiThreadedStateMachineManager( return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> } } + override fun scheduleFlowTimeout(flowId: StateMachineRunId) { + concurrentBox.concurrent { scheduleTimeout(flowId) } + } + + override fun cancelFlowTimeout(flowId: StateMachineRunId) { + concurrentBox.concurrent { cancelTimeoutIfScheduled(flowId) } + } + + /** + * Schedules the flow [flowId] to be retried if it does not finish within the timeout period + * specified in the config. + * + * Assumes a read lock is taken on the [InnerState]. 
+     */
+    private fun InnerState.scheduleTimeout(flowId: StateMachineRunId) {
+        val flow = flows[flowId]
+        if (flow != null) {
+            val scheduledTimeout = timedFlows[flowId]
+            val retryCount = if (scheduledTimeout != null) {
+                val timeoutFuture = scheduledTimeout.scheduledFuture
+                if (!timeoutFuture.isDone) scheduledTimeout.scheduledFuture.cancel(true)
+                scheduledTimeout.retryCount
+            } else 0
+            val scheduledFuture = scheduleTimeoutException(flow, retryCount)
+            timedFlows[flowId] = ScheduledTimeout(scheduledFuture, retryCount + 1)
+        } else {
+            logger.warn("Unable to schedule timeout for flow $flowId – flow not found.")
+        }
+    }
+
+    /** Schedules a [FlowTimeoutException] to restart the flow after `messageRedeliveryDelay` scaled by `backoffBase` to the power of [retryCount]; the product is truncated to whole seconds only after scaling, so fractional backoff bases take effect. */
+    private fun scheduleTimeoutException(flow: Flow, retryCount: Int): ScheduledFuture<*> {
+        return with(serviceHub.configuration.p2pMessagingRetry) {
+            val timeoutDelaySeconds = (messageRedeliveryDelay.seconds * Math.pow(backoffBase, retryCount.toDouble())).toLong()
+            timeoutScheduler.schedule({
+                val event = Event.Error(FlowTimeoutException(maxRetryCount))
+                flow.fiber.scheduleEvent(event)
+            }, timeoutDelaySeconds, TimeUnit.SECONDS)
+        }
+    }
+
+    /**
+     * Cancels any scheduled flow timeout for [flowId].
+     *
+     * Assumes a read lock is taken on the [InnerState].
+     */
+    private fun InnerState.cancelTimeoutIfScheduled(flowId: StateMachineRunId) {
+        timedFlows[flowId]?.let { (future, _) ->
+            if (!future.isDone) future.cancel(true)
+            timedFlows.remove(flowId)
+        }
+    }
+
     private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
         return try {
             serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
@@ -659,6 +724,8 @@ class MultiThreadedStateMachineManager(
                 } else {
                     oldFlow.resultFuture.captureLater(flow.resultFuture)
                 }
+                val flowLogic = flow.fiber.logic
+                if (flowLogic is TimedFlow) scheduleTimeout(id)
                 flow.fiber.scheduleEvent(Event.DoRemainingWork)
                 when (checkpoint.flowState) {
                     is FlowState.Unstarted -> {