diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 34cf8f82df..e8172e5556 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -1,6 +1,7 @@ package com.r3corda.node.internal import com.codahale.metrics.MetricRegistry +import com.google.common.annotations.VisibleForTesting import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.SettableFuture @@ -254,12 +255,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo runOnStop += Runnable { net.stop() } _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) smm.start() + // Shut down the SMM so no Fibers are scheduled. + runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) } scheduler.start() } started = true return this } + @VisibleForTesting + protected open fun acceptableLiveFiberCountOnStop(): Int = 0 + private fun hasSSLCertificates(): Boolean { val keyStore = try { // This will throw exception if key file not found or keystore password is incorrect. diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt index 224ac57d27..e5c12b7151 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt @@ -1,17 +1,14 @@ package com.r3corda.node.services.api import com.google.common.util.concurrent.ListenableFuture -import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.MessagingService import com.r3corda.core.node.PluginServiceHub -import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.transactions.SignedTransaction import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import org.slf4j.LoggerFactory -import kotlin.reflect.KClass interface MessagingServiceInternal : MessagingService { /** diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 1b2226f4aa..7c0c5b684a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.utilities.StrandLocalTransactionManager import com.r3corda.node.utilities.createDatabaseTransaction +import com.r3corda.node.utilities.databaseTransaction import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.transactions.TransactionManager @@ -230,7 +231,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. txTrampoline = TransactionManager.currentOrNull() StrandLocalTransactionManager.setThreadLocalTx(null) - ioRequest.session.waitingForResponse = true + ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>) parkAndSerialize { fiber, serializer -> logger.trace { "Suspended on $ioRequest" } // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB @@ -246,13 +247,16 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, processException(t) } } - ioRequest.session.waitingForResponse = false + logger.trace { "Resumed from $ioRequest" } createTransaction() } private fun processException(t: Throwable) { - actionOnEnd() - _resultFuture?.setException(t) + // This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here. + databaseTransaction(database) { + actionOnEnd() + _resultFuture?.setException(t) + } } internal fun resume(scheduler: FiberScheduler) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index 0bf58db553..5a355e3647 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.isolatedTransaction import kotlinx.support.jdk8.collections.removeIf +import org.apache.activemq.artemis.utils.ReusableLatch import org.jetbrains.exposed.sql.Database import rx.Observable import rx.subjects.PublishSubject @@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } }) + // True if we're shutting down, so don't resume anything. + @Volatile private var stopping = false + // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. + private val liveFibers = ReusableLatch() + // Monitoring support. private val metrics = serviceHub.monitoringService.metrics @@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() } } + private fun decrementLiveFibers() { + liveFibers.countDown() + } + + private fun incrementLiveFibers() { + liveFibers.countUp() + } + + /** + * Start the shutdown process, bringing the [StateMachineManager] to a controlled stop. When this method returns, + * all Fibers have been suspended and checkpointed, or have completed. + * + * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. + */ + fun stop(allowedUnsuspendedFiberCount: Int = 0) { + check(allowedUnsuspendedFiberCount >= 0) + mutex.locked { + if (stopping) throw IllegalStateException("Already stopping!") + stopping = true + } + // Account for any expected Fibers in a test scenario. + liveFibers.countDown(allowedUnsuspendedFiberCount) + liveFibers.await() + } + /** * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * calls to [allStateMachines] @@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } session.receivedMessages += message if (session.waitingForResponse) { + // We only want to resume once, so immediately reset the flag. + session.waitingForResponse = false updateCheckpoint(session.psm) resumeFiber(session.psm) } @@ -285,15 +318,20 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, // This will free up the ThreadLocal so on return the caller can carry on with other transactions psm.commitTransaction() processIORequest(ioRequest) + decrementLiveFibers() } psm.actionOnEnd = { - psm.logic.progressTracker?.currentStep = ProgressTracker.DONE - mutex.locked { - stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } - totalFinishedProtocols.inc() - notifyChangeObservers(psm, AddOrRemove.REMOVE) + try { + psm.logic.progressTracker?.currentStep = ProgressTracker.DONE + mutex.locked { + stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } + totalFinishedProtocols.inc() + notifyChangeObservers(psm, AddOrRemove.REMOVE) + } + endAllFiberSessions(psm) + } finally { + decrementLiveFibers() } - endAllFiberSessions(psm) } mutex.locked { totalStartedProtocols.inc() @@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) { - executor.executeASAP { + // Avoid race condition when setting stopping to true and then checking liveFibers + incrementLiveFibers() + if (!stopping) executor.executeASAP { psm.resume(scheduler) + } else { + psm.logger.debug("Not resuming as SMM is stopping.") + decrementLiveFibers() } } diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index 1aa6d712a4..ca787d7395 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper import com.r3corda.core.utilities.TEST_TX_TIME import com.r3corda.node.internal.AbstractNode import com.r3corda.node.services.config.NodeConfiguration -import com.r3corda.node.services.persistence.* +import com.r3corda.node.services.persistence.DBTransactionStorage +import com.r3corda.node.services.persistence.NodeAttachmentService +import com.r3corda.node.services.persistence.StorageServiceImpl +import com.r3corda.node.services.persistence.checkpoints import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Seller @@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests { databaseTransaction(bobNode.database) { assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() } - aliceNode.manuallyCloseDB() + bobNode.manuallyCloseDB() } } diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index f4b8c7b4f3..9d3a80b070 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -51,6 +51,7 @@ class StateMachineManagerTests { @Test fun `newly added protocol is preserved on restart`() { node1.smm.add(NoOpProtocol(nonTerminating = true)) + node1.acceptableLiveFiberCountOnStop = 1 val restoredProtocol = node1.restartAndGetRestoredProtocol() assertThat(restoredProtocol.protocolStarted).isTrue() } @@ -75,6 +76,7 @@ class StateMachineManagerTests { // We push through just enough messages to get only the payload sent node2.pumpReceive() node2.disableDBCloseOnStop() + node2.acceptableLiveFiberCountOnStop = 1 node2.stop() net.runNetwork() val restoredProtocol = node2.restartAndGetRestoredProtocol(node1) @@ -206,6 +208,9 @@ class StateMachineManagerTests { node1 sent sessionEnd() to node3 //There's no session end from the other protocols as they're manually suspended ) + + node2.acceptableLiveFiberCountOnStop = 1 + node3.acceptableLiveFiberCountOnStop = 1 } @Test @@ -218,6 +223,7 @@ class StateMachineManagerTests { node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity) node1.smm.add(multiReceiveProtocol) + node1.acceptableLiveFiberCountOnStop = 1 net.runNetwork() assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) @@ -271,6 +277,7 @@ class StateMachineManagerTests { disableDBCloseOnStop() //Handover DB to new node copy stop() val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray()) + newNode.acceptableLiveFiberCountOnStop = 1 manuallyCloseDB() mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine return newNode.getSingleProtocol

().first diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index 9ef1fd9cfd..8b27014cae 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, override val log: Logger = loggerFor() override val serverThread: AffinityExecutor = if (mockNet.threadPerNode) - ServiceAffinityExecutor("Mock node thread", 1) + ServiceAffinityExecutor("Mock node $id thread", 1) else { mockNet.sharedUserCount.incrementAndGet() mockNet.sharedServerThread @@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, dbCloser?.run() dbCloser = null } + + // You can change this from zero if you have custom [ProtocolLogic] that park themselves. e.g. [StateMachineManagerTests] + var acceptableLiveFiberCountOnStop: Int = 0 + + override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop } /** Returns a node, optionally created by the passed factory method. */