Implement controlled stop of StateMachineManager.

This commit is contained in:
rick.parker 2016-11-04 10:19:51 +00:00
parent c0e08bee60
commit fe6bf0e6ea
7 changed files with 82 additions and 17 deletions

View File

@ -1,6 +1,7 @@
package com.r3corda.node.internal package com.r3corda.node.internal
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
@ -255,12 +256,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
runOnStop += Runnable { net.stop() } runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) _networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
smm.start() smm.start()
// Shut down the SMM so no Fibers are scheduled.
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
scheduler.start() scheduler.start()
} }
started = true started = true
return this return this
} }
/**
 * Number of still-live (unsuspended) fibers to tolerate when the node stops. The value is handed to `smm.stop(...)`
 * by the stop hook registered right after the state machine manager is started. Defaults to zero; declared `open`
 * so that subclasses can supply a different count.
 */
@VisibleForTesting
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
private fun hasSSLCertificates(): Boolean { private fun hasSSLCertificates(): Boolean {
val keyStore = try { val keyStore = try {
// This will throw exception if key file not found or keystore password is incorrect. // This will throw exception if key file not found or keystore password is incorrect.

View File

@ -1,17 +1,14 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.PluginServiceHub import com.r3corda.core.node.PluginServiceHub
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService { interface MessagingServiceInternal : MessagingService {
/** /**

View File

@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.StrandLocalTransactionManager import com.r3corda.node.utilities.StrandLocalTransactionManager
import com.r3corda.node.utilities.createDatabaseTransaction import com.r3corda.node.utilities.createDatabaseTransaction
import com.r3corda.node.utilities.databaseTransaction
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
@ -230,7 +231,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull() txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null) StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = true ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" } logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
@ -246,13 +247,16 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
processException(t) processException(t)
} }
} }
ioRequest.session.waitingForResponse = false logger.trace { "Resumed from $ioRequest" }
createTransaction() createTransaction()
} }
private fun processException(t: Throwable) { private fun processException(t: Throwable) {
actionOnEnd() // This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
_resultFuture?.setException(t) databaseTransaction(database) {
actionOnEnd()
_resultFuture?.setException(t)
}
} }
internal fun resume(scheduler: FiberScheduler) { internal fun resume(scheduler: FiberScheduler) {

View File

@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.isolatedTransaction import com.r3corda.node.utilities.isolatedTransaction
import kotlinx.support.jdk8.collections.removeIf import kotlinx.support.jdk8.collections.removeIf
import org.apache.activemq.artemis.utils.ReusableLatch
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
}) })
// True if we're shutting down, so don't resume anything. Set by stop() while holding the mutex and checked by
// resumeFiber() before it schedules a Fiber.
// NOTE(review): presumably @Volatile because resumeFiber() reads it outside mutex.locked - confirm.
@Volatile private var stopping = false
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
// Counted up in resumeFiber(); counted down when a Fiber suspends, when it ends, and when a resume is declined
// because we are stopping. stop() waits on this latch.
private val liveFibers = ReusableLatch()
// Monitoring support. // Monitoring support.
private val metrics = serviceHub.monitoringService.metrics private val metrics = serviceHub.monitoringService.metrics
@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() } serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
} }
// Records that one Fiber is no longer live by counting the latch down once.
private fun decrementLiveFibers() = liveFibers.countDown()
// Records one more live Fiber by counting the latch up once.
private fun incrementLiveFibers() = liveFibers.countUp()
/**
 * Brings the [StateMachineManager] to a controlled stop. The manager is first flagged as stopping, after which
 * this call waits on the live Fiber latch; by the time it returns every Fiber has either been suspended and
 * checkpointed, or has run to completion.
 *
 * @param allowedUnsuspendedFiberCount Number of Fibers permitted to remain unsuspended; must not be negative.
 * Defaults to zero and is only expected to be non-zero in some tests.
 * @throws IllegalStateException if the count is negative or a stop is already in progress.
 */
fun stop(allowedUnsuspendedFiberCount: Int = 0) {
    check(allowedUnsuspendedFiberCount >= 0)
    mutex.locked {
        // A second stop request is a programming error rather than a no-op.
        check(!stopping) { "Already stopping!" }
        stopping = true
    }
    liveFibers.run {
        // Discount the Fibers a test scenario expects to be left unsuspended before waiting on the rest.
        countDown(allowedUnsuspendedFiberCount)
        await()
    }
}
/** /**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines] * calls to [allStateMachines]
@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
session.receivedMessages += message session.receivedMessages += message
if (session.waitingForResponse) { if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
updateCheckpoint(session.psm) updateCheckpoint(session.psm)
resumeFiber(session.psm) resumeFiber(session.psm)
} }
@ -285,15 +318,20 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// This will free up the ThreadLocal so on return the caller can carry on with other transactions // This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction() psm.commitTransaction()
processIORequest(ioRequest) processIORequest(ioRequest)
decrementLiveFibers()
} }
psm.actionOnEnd = { psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE try {
mutex.locked { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } mutex.locked {
totalFinishedProtocols.inc() stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(psm, AddOrRemove.REMOVE) totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
} finally {
decrementLiveFibers()
} }
endAllFiberSessions(psm)
} }
mutex.locked { mutex.locked {
totalStartedProtocols.inc() totalStartedProtocols.inc()
@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) { private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
executor.executeASAP { // Avoid race condition when setting stopping to true and then checking liveFibers
incrementLiveFibers()
if (!stopping) executor.executeASAP {
psm.resume(scheduler) psm.resume(scheduler)
} else {
psm.logger.debug("Not resuming as SMM is stopping.")
decrementLiveFibers()
} }
} }

View File

@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.persistence.* import com.r3corda.node.services.persistence.DBTransactionStorage
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.utilities.databaseTransaction import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests {
databaseTransaction(bobNode.database) { databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
} }
aliceNode.manuallyCloseDB() bobNode.manuallyCloseDB()
} }
} }

View File

@ -51,6 +51,7 @@ class StateMachineManagerTests {
@Test @Test
fun `newly added protocol is preserved on restart`() { fun `newly added protocol is preserved on restart`() {
node1.smm.add(NoOpProtocol(nonTerminating = true)) node1.smm.add(NoOpProtocol(nonTerminating = true))
node1.acceptableLiveFiberCountOnStop = 1
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>() val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
assertThat(restoredProtocol.protocolStarted).isTrue() assertThat(restoredProtocol.protocolStarted).isTrue()
} }
@ -75,6 +76,7 @@ class StateMachineManagerTests {
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
node2.pumpReceive() node2.pumpReceive()
node2.disableDBCloseOnStop() node2.disableDBCloseOnStop()
node2.acceptableLiveFiberCountOnStop = 1
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
@ -206,6 +208,9 @@ class StateMachineManagerTests {
node1 sent sessionEnd() to node3 node1 sent sessionEnd() to node3
//There's no session end from the other protocols as they're manually suspended //There's no session end from the other protocols as they're manually suspended
) )
node2.acceptableLiveFiberCountOnStop = 1
node3.acceptableLiveFiberCountOnStop = 1
} }
@Test @Test
@ -218,6 +223,7 @@ class StateMachineManagerTests {
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity) val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
node1.smm.add(multiReceiveProtocol) node1.smm.add(multiReceiveProtocol)
node1.acceptableLiveFiberCountOnStop = 1
net.runNetwork() net.runNetwork()
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
@ -271,6 +277,7 @@ class StateMachineManagerTests {
disableDBCloseOnStop() //Handover DB to new node copy disableDBCloseOnStop() //Handover DB to new node copy
stop() stop()
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray()) val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
newNode.acceptableLiveFiberCountOnStop = 1
manuallyCloseDB() manuallyCloseDB()
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first return newNode.getSingleProtocol<P>().first

View File

@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override val log: Logger = loggerFor<MockNode>() override val log: Logger = loggerFor<MockNode>()
override val serverThread: AffinityExecutor = override val serverThread: AffinityExecutor =
if (mockNet.threadPerNode) if (mockNet.threadPerNode)
ServiceAffinityExecutor("Mock node thread", 1) ServiceAffinityExecutor("Mock node $id thread", 1)
else { else {
mockNet.sharedUserCount.incrementAndGet() mockNet.sharedUserCount.incrementAndGet()
mockNet.sharedServerThread mockNet.sharedServerThread
@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
dbCloser?.run() dbCloser?.run()
dbCloser = null dbCloser = null
} }
// Number of Fibers this node may leave unsuspended when it stops. Raise it above zero when a test uses custom
// [ProtocolLogic] implementations that park themselves, e.g. in [StateMachineManagerTests].
var acceptableLiveFiberCountOnStop: Int = 0
// Exposes the configurable value above through the overridable hook.
override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop
} }
/** Returns a node, optionally created by the passed factory method. */ /** Returns a node, optionally created by the passed factory method. */