mirror of
https://github.com/corda/corda.git
synced 2025-01-18 18:56:28 +00:00
Implement controlled stop of StateMachineManager.
This commit is contained in:
parent
c0e08bee60
commit
fe6bf0e6ea
@ -1,6 +1,7 @@
|
||||
package com.r3corda.node.internal
|
||||
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
@ -255,12 +256,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
|
||||
runOnStop += Runnable { net.stop() }
|
||||
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
|
||||
smm.start()
|
||||
// Shut down the SMM so no Fibers are scheduled.
|
||||
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
|
||||
scheduler.start()
|
||||
}
|
||||
started = true
|
||||
return this
|
||||
}
|
||||
|
||||
/**
 * Number of Fibers that may still be live when the node is stopped; the value is handed to the
 * state machine manager's stop call. The base implementation allows none, subclasses may override.
 */
@VisibleForTesting
protected open fun acceptableLiveFiberCountOnStop(): Int {
    return 0
}
|
||||
|
||||
private fun hasSSLCertificates(): Boolean {
|
||||
val keyStore = try {
|
||||
// This will throw exception if key file not found or keystore password is incorrect.
|
||||
|
@ -1,17 +1,14 @@
|
||||
package com.r3corda.node.services.api
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.MessagingService
|
||||
import com.r3corda.core.node.PluginServiceHub
|
||||
import com.r3corda.core.node.ServiceHub
|
||||
import com.r3corda.core.node.services.TxWritableStorageService
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||
import org.slf4j.LoggerFactory
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
interface MessagingServiceInternal : MessagingService {
|
||||
/**
|
||||
|
@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
||||
import com.r3corda.node.utilities.StrandLocalTransactionManager
|
||||
import com.r3corda.node.utilities.createDatabaseTransaction
|
||||
import com.r3corda.node.utilities.databaseTransaction
|
||||
import org.jetbrains.exposed.sql.Database
|
||||
import org.jetbrains.exposed.sql.Transaction
|
||||
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
||||
@ -230,7 +231,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
|
||||
txTrampoline = TransactionManager.currentOrNull()
|
||||
StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||
ioRequest.session.waitingForResponse = true
|
||||
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
|
||||
parkAndSerialize { fiber, serializer ->
|
||||
logger.trace { "Suspended on $ioRequest" }
|
||||
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||
@ -246,13 +247,16 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
processException(t)
|
||||
}
|
||||
}
|
||||
ioRequest.session.waitingForResponse = false
|
||||
logger.trace { "Resumed from $ioRequest" }
|
||||
createTransaction()
|
||||
}
|
||||
|
||||
private fun processException(t: Throwable) {
|
||||
actionOnEnd()
|
||||
_resultFuture?.setException(t)
|
||||
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
|
||||
databaseTransaction(database) {
|
||||
actionOnEnd()
|
||||
_resultFuture?.setException(t)
|
||||
}
|
||||
}
|
||||
|
||||
internal fun resume(scheduler: FiberScheduler) {
|
||||
|
@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove
|
||||
import com.r3corda.node.utilities.AffinityExecutor
|
||||
import com.r3corda.node.utilities.isolatedTransaction
|
||||
import kotlinx.support.jdk8.collections.removeIf
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import org.jetbrains.exposed.sql.Database
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
})
|
||||
|
||||
// True if we're shutting down, so don't resume anything.
|
||||
@Volatile private var stopping = false
|
||||
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
|
||||
private val liveFibers = ReusableLatch()
|
||||
|
||||
// Monitoring support.
|
||||
private val metrics = serviceHub.monitoringService.metrics
|
||||
|
||||
@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
|
||||
}
|
||||
|
||||
/** Records that one Fiber is no longer live by counting the [liveFibers] latch down by one. */
private fun decrementLiveFibers() = liveFibers.countDown()
|
||||
|
||||
/** Records that one more Fiber is live by counting the [liveFibers] latch up by one. */
private fun incrementLiveFibers() = liveFibers.countUp()
|
||||
|
||||
/**
 * Begins the shutdown process and brings the [StateMachineManager] to a controlled stop. The call blocks on the
 * live-fiber latch, so by the time it returns all Fibers have been suspended and checkpointed, or have completed.
 *
 * @param allowedUnsuspendedFiberCount Number of Fibers that are permitted to remain unsuspended; the latch is
 * counted down by this amount before waiting. Optional, and used in some tests.
 * @throws IllegalStateException if [allowedUnsuspendedFiberCount] is negative, or if a stop is already underway.
 */
fun stop(allowedUnsuspendedFiberCount: Int = 0) {
    check(allowedUnsuspendedFiberCount >= 0)
    mutex.locked {
        // Flip the flag under the lock so that only one caller can ever initiate the stop.
        check(!stopping) { "Already stopping!" }
        stopping = true
    }
    // Account for any expected Fibers in a test scenario.
    liveFibers.countDown(allowedUnsuspendedFiberCount)
    liveFibers.await()
}
|
||||
|
||||
/**
|
||||
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
||||
* calls to [allStateMachines]
|
||||
@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
session.receivedMessages += message
|
||||
if (session.waitingForResponse) {
|
||||
// We only want to resume once, so immediately reset the flag.
|
||||
session.waitingForResponse = false
|
||||
updateCheckpoint(session.psm)
|
||||
resumeFiber(session.psm)
|
||||
}
|
||||
@ -285,15 +318,20 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||
psm.commitTransaction()
|
||||
processIORequest(ioRequest)
|
||||
decrementLiveFibers()
|
||||
}
|
||||
psm.actionOnEnd = {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
mutex.locked {
|
||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
totalFinishedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
try {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
mutex.locked {
|
||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
totalFinishedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
}
|
||||
endAllFiberSessions(psm)
|
||||
} finally {
|
||||
decrementLiveFibers()
|
||||
}
|
||||
endAllFiberSessions(psm)
|
||||
}
|
||||
mutex.locked {
|
||||
totalStartedProtocols.inc()
|
||||
@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
|
||||
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
|
||||
executor.executeASAP {
|
||||
// Avoid race condition when setting stopping to true and then checking liveFibers
|
||||
incrementLiveFibers()
|
||||
if (!stopping) executor.executeASAP {
|
||||
psm.resume(scheduler)
|
||||
} else {
|
||||
psm.logger.debug("Not resuming as SMM is stopping.")
|
||||
decrementLiveFibers()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper
|
||||
import com.r3corda.core.utilities.TEST_TX_TIME
|
||||
import com.r3corda.node.internal.AbstractNode
|
||||
import com.r3corda.node.services.config.NodeConfiguration
|
||||
import com.r3corda.node.services.persistence.*
|
||||
import com.r3corda.node.services.persistence.DBTransactionStorage
|
||||
import com.r3corda.node.services.persistence.NodeAttachmentService
|
||||
import com.r3corda.node.services.persistence.StorageServiceImpl
|
||||
import com.r3corda.node.services.persistence.checkpoints
|
||||
import com.r3corda.node.utilities.databaseTransaction
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
||||
@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests {
|
||||
databaseTransaction(bobNode.database) {
|
||||
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
|
||||
}
|
||||
aliceNode.manuallyCloseDB()
|
||||
bobNode.manuallyCloseDB()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,6 +51,7 @@ class StateMachineManagerTests {
|
||||
@Test
fun `newly added protocol is preserved on restart`() {
    // The protocol is created as non-terminating, so allow for one live Fiber when the node is stopped.
    node1.smm.add(NoOpProtocol(nonTerminating = true))
    node1.acceptableLiveFiberCountOnStop = 1
    val restored = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
    assertThat(restored.protocolStarted).isTrue()
}
|
||||
@ -75,6 +76,7 @@ class StateMachineManagerTests {
|
||||
// We push through just enough messages to get only the payload sent
|
||||
node2.pumpReceive()
|
||||
node2.disableDBCloseOnStop()
|
||||
node2.acceptableLiveFiberCountOnStop = 1
|
||||
node2.stop()
|
||||
net.runNetwork()
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
|
||||
@ -206,6 +208,9 @@ class StateMachineManagerTests {
|
||||
node1 sent sessionEnd() to node3
|
||||
//There's no session end from the other protocols as they're manually suspended
|
||||
)
|
||||
|
||||
node2.acceptableLiveFiberCountOnStop = 1
|
||||
node3.acceptableLiveFiberCountOnStop = 1
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -218,6 +223,7 @@ class StateMachineManagerTests {
|
||||
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
|
||||
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
|
||||
node1.smm.add(multiReceiveProtocol)
|
||||
node1.acceptableLiveFiberCountOnStop = 1
|
||||
net.runNetwork()
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
|
||||
@ -271,6 +277,7 @@ class StateMachineManagerTests {
|
||||
disableDBCloseOnStop() //Handover DB to new node copy
|
||||
stop()
|
||||
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
|
||||
newNode.acceptableLiveFiberCountOnStop = 1
|
||||
manuallyCloseDB()
|
||||
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
|
||||
return newNode.getSingleProtocol<P>().first
|
||||
|
@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
override val log: Logger = loggerFor<MockNode>()
|
||||
override val serverThread: AffinityExecutor =
|
||||
if (mockNet.threadPerNode)
|
||||
ServiceAffinityExecutor("Mock node thread", 1)
|
||||
ServiceAffinityExecutor("Mock node $id thread", 1)
|
||||
else {
|
||||
mockNet.sharedUserCount.incrementAndGet()
|
||||
mockNet.sharedServerThread
|
||||
@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
dbCloser?.run()
|
||||
dbCloser = null
|
||||
}
|
||||
|
||||
/**
 * Live-fiber allowance this mock node reports on stop. Raise it above zero when a test uses a custom
 * [ProtocolLogic] that parks itself, e.g. [StateMachineManagerTests].
 */
var acceptableLiveFiberCountOnStop: Int = 0

override fun acceptableLiveFiberCountOnStop(): Int {
    return acceptableLiveFiberCountOnStop
}
|
||||
}
|
||||
|
||||
/** Returns a node, optionally created by the passed factory method. */
|
||||
|
Loading…
Reference in New Issue
Block a user