Implement controlled stop of StateMachineManager.

This commit is contained in:
rick.parker 2016-11-04 10:19:51 +00:00
parent c0e08bee60
commit fe6bf0e6ea
7 changed files with 82 additions and 17 deletions

View File

@ -1,6 +1,7 @@
package com.r3corda.node.internal
import com.codahale.metrics.MetricRegistry
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
@ -255,12 +256,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
smm.start()
// Shut down the SMM so no Fibers are scheduled.
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
scheduler.start()
}
started = true
return this
}
/**
 * Returns the number of live fibers that is acceptable when the node stops; the stop hook registered during
 * start-up passes this value to `smm.stop`. The default is zero, and the function is open so that subclasses
 * can supply a different count.
 */
@VisibleForTesting
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
private fun hasSSLCertificates(): Boolean {
val keyStore = try {
// This will throw exception if key file not found or keystore password is incorrect.

View File

@ -1,17 +1,14 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.PluginServiceHub
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService {
/**

View File

@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.StrandLocalTransactionManager
import com.r3corda.node.utilities.createDatabaseTransaction
import com.r3corda.node.utilities.databaseTransaction
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager
@ -230,7 +231,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = true
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
@ -246,14 +247,17 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
processException(t)
}
}
ioRequest.session.waitingForResponse = false
logger.trace { "Resumed from $ioRequest" }
createTransaction()
}
/**
 * Runs [actionOnEnd] and then, when `_resultFuture` is non-null, sets [t] as that future's exception.
 * Both steps are executed inside a [databaseTransaction] block on [database].
 *
 * @param t the throwable to report through the result future.
 */
private fun processException(t: Throwable) {
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
databaseTransaction(database) {
actionOnEnd()
_resultFuture?.setException(t)
}
}
internal fun resume(scheduler: FiberScheduler) {
try {

View File

@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.isolatedTransaction
import kotlinx.support.jdk8.collections.removeIf
import org.apache.activemq.artemis.utils.ReusableLatch
import org.jetbrains.exposed.sql.Database
import rx.Observable
import rx.subjects.PublishSubject
@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
})
// True if we're shutting down, so don't resume anything.
@Volatile private var stopping = false
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch()
// Monitoring support.
private val metrics = serviceHub.monitoringService.metrics
@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
}
/** Counts the [liveFibers] latch down by one. */
private fun decrementLiveFibers(): Unit = liveFibers.countDown()
/** Counts the [liveFibers] latch up by one. */
private fun incrementLiveFibers(): Unit = liveFibers.countUp()
/**
 * Start the shutdown process, bringing the [StateMachineManager] to a controlled stop. When this method returns,
 * all Fibers have been suspended and checkpointed, or have completed.
 *
 * The `stopping` flag is set under the manager's mutex, after which the calling thread waits on the
 * [liveFibers] latch.
 *
 * @param allowedUnsuspendedFiberCount Number of Fibers that are expected to still be live when stopping. This
 * amount is counted down from the live-fiber latch before waiting on it. Defaults to zero; a non-zero value is
 * used in some tests.
 * @throws IllegalStateException if [allowedUnsuspendedFiberCount] is negative, or if a stop has already been
 * requested.
 */
fun stop(allowedUnsuspendedFiberCount: Int = 0) {
    // Kept as check (rather than require) so the exception type seen by existing callers does not change.
    check(allowedUnsuspendedFiberCount >= 0) {
        "allowedUnsuspendedFiberCount must not be negative, but was $allowedUnsuspendedFiberCount"
    }
    mutex.locked {
        check(!stopping) { "Already stopping!" }
        stopping = true
    }
    // Account for any expected Fibers in a test scenario.
    liveFibers.countDown(allowedUnsuspendedFiberCount)
    liveFibers.await()
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
session.receivedMessages += message
if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
updateCheckpoint(session.psm)
resumeFiber(session.psm)
}
@ -285,8 +318,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction()
processIORequest(ioRequest)
decrementLiveFibers()
}
psm.actionOnEnd = {
try {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
@ -294,6 +329,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
} finally {
decrementLiveFibers()
}
}
mutex.locked {
totalStartedProtocols.inc()
@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
executor.executeASAP {
// Avoid race condition when setting stopping to true and then checking liveFibers
incrementLiveFibers()
if (!stopping) executor.executeASAP {
psm.resume(scheduler)
} else {
psm.logger.debug("Not resuming as SMM is stopping.")
decrementLiveFibers()
}
}

View File

@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.persistence.*
import com.r3corda.node.services.persistence.DBTransactionStorage
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests {
databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}
aliceNode.manuallyCloseDB()
bobNode.manuallyCloseDB()
}
}

View File

@ -51,6 +51,7 @@ class StateMachineManagerTests {
@Test
fun `newly added protocol is preserved on restart`() {
node1.smm.add(NoOpProtocol(nonTerminating = true))
// NOTE(review): presumably the non-terminating protocol is still live when node1 is stopped for the restart,
// so one live fiber is declared acceptable here — confirm.
node1.acceptableLiveFiberCountOnStop = 1
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
assertThat(restoredProtocol.protocolStarted).isTrue()
}
@ -75,6 +76,7 @@ class StateMachineManagerTests {
// We push through just enough messages to get only the payload sent
node2.pumpReceive()
node2.disableDBCloseOnStop()
node2.acceptableLiveFiberCountOnStop = 1
node2.stop()
net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
@ -206,6 +208,9 @@ class StateMachineManagerTests {
node1 sent sessionEnd() to node3
//There's no session end from the other protocols as they're manually suspended
)
node2.acceptableLiveFiberCountOnStop = 1
node3.acceptableLiveFiberCountOnStop = 1
}
@Test
@ -218,6 +223,7 @@ class StateMachineManagerTests {
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
node1.smm.add(multiReceiveProtocol)
node1.acceptableLiveFiberCountOnStop = 1
net.runNetwork()
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
@ -271,6 +277,7 @@ class StateMachineManagerTests {
disableDBCloseOnStop() //Handover DB to new node copy
stop()
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
newNode.acceptableLiveFiberCountOnStop = 1
manuallyCloseDB()
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first

View File

@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override val log: Logger = loggerFor<MockNode>()
override val serverThread: AffinityExecutor =
if (mockNet.threadPerNode)
ServiceAffinityExecutor("Mock node thread", 1)
ServiceAffinityExecutor("Mock node $id thread", 1)
else {
mockNet.sharedUserCount.incrementAndGet()
mockNet.sharedServerThread
@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
dbCloser?.run()
dbCloser = null
}
// Value returned by the acceptableLiveFiberCountOnStop() override below. You can change this from zero if you
// have custom [ProtocolLogic]s that park themselves, e.g. as done in [StateMachineManagerTests].
var acceptableLiveFiberCountOnStop: Int = 0
// Exposes the mutable property above through the overridden function.
override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop
}
/** Returns a node, optionally created by the passed factory method. */