mirror of
https://github.com/corda/corda.git
synced 2024-12-19 13:08:04 +00:00
Implement controlled stop of StateMachineManager.
This commit is contained in:
parent
c0e08bee60
commit
fe6bf0e6ea
@ -1,6 +1,7 @@
|
|||||||
package com.r3corda.node.internal
|
package com.r3corda.node.internal
|
||||||
|
|
||||||
import com.codahale.metrics.MetricRegistry
|
import com.codahale.metrics.MetricRegistry
|
||||||
|
import com.google.common.annotations.VisibleForTesting
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.MoreExecutors
|
import com.google.common.util.concurrent.MoreExecutors
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
import com.google.common.util.concurrent.SettableFuture
|
||||||
@ -255,12 +256,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
|
|||||||
runOnStop += Runnable { net.stop() }
|
runOnStop += Runnable { net.stop() }
|
||||||
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
|
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
|
||||||
smm.start()
|
smm.start()
|
||||||
|
// Shut down the SMM so no Fibers are scheduled.
|
||||||
|
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
}
|
}
|
||||||
started = true
|
started = true
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@VisibleForTesting
|
||||||
|
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
|
||||||
|
|
||||||
private fun hasSSLCertificates(): Boolean {
|
private fun hasSSLCertificates(): Boolean {
|
||||||
val keyStore = try {
|
val keyStore = try {
|
||||||
// This will throw exception if key file not found or keystore password is incorrect.
|
// This will throw exception if key file not found or keystore password is incorrect.
|
||||||
|
@ -1,17 +1,14 @@
|
|||||||
package com.r3corda.node.services.api
|
package com.r3corda.node.services.api
|
||||||
|
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.r3corda.core.crypto.Party
|
|
||||||
import com.r3corda.core.messaging.MessagingService
|
import com.r3corda.core.messaging.MessagingService
|
||||||
import com.r3corda.core.node.PluginServiceHub
|
import com.r3corda.core.node.PluginServiceHub
|
||||||
import com.r3corda.core.node.ServiceHub
|
|
||||||
import com.r3corda.core.node.services.TxWritableStorageService
|
import com.r3corda.core.node.services.TxWritableStorageService
|
||||||
import com.r3corda.core.protocols.ProtocolLogic
|
import com.r3corda.core.protocols.ProtocolLogic
|
||||||
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
||||||
import com.r3corda.core.transactions.SignedTransaction
|
import com.r3corda.core.transactions.SignedTransaction
|
||||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||||
import org.slf4j.LoggerFactory
|
import org.slf4j.LoggerFactory
|
||||||
import kotlin.reflect.KClass
|
|
||||||
|
|
||||||
interface MessagingServiceInternal : MessagingService {
|
interface MessagingServiceInternal : MessagingService {
|
||||||
/**
|
/**
|
||||||
|
@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
|
|||||||
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
||||||
import com.r3corda.node.utilities.StrandLocalTransactionManager
|
import com.r3corda.node.utilities.StrandLocalTransactionManager
|
||||||
import com.r3corda.node.utilities.createDatabaseTransaction
|
import com.r3corda.node.utilities.createDatabaseTransaction
|
||||||
|
import com.r3corda.node.utilities.databaseTransaction
|
||||||
import org.jetbrains.exposed.sql.Database
|
import org.jetbrains.exposed.sql.Database
|
||||||
import org.jetbrains.exposed.sql.Transaction
|
import org.jetbrains.exposed.sql.Transaction
|
||||||
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
||||||
@ -230,7 +231,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
|
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
|
||||||
txTrampoline = TransactionManager.currentOrNull()
|
txTrampoline = TransactionManager.currentOrNull()
|
||||||
StrandLocalTransactionManager.setThreadLocalTx(null)
|
StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||||
ioRequest.session.waitingForResponse = true
|
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
logger.trace { "Suspended on $ioRequest" }
|
logger.trace { "Suspended on $ioRequest" }
|
||||||
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||||
@ -246,13 +247,16 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
processException(t)
|
processException(t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ioRequest.session.waitingForResponse = false
|
logger.trace { "Resumed from $ioRequest" }
|
||||||
createTransaction()
|
createTransaction()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun processException(t: Throwable) {
|
private fun processException(t: Throwable) {
|
||||||
actionOnEnd()
|
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
|
||||||
_resultFuture?.setException(t)
|
databaseTransaction(database) {
|
||||||
|
actionOnEnd()
|
||||||
|
_resultFuture?.setException(t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun resume(scheduler: FiberScheduler) {
|
internal fun resume(scheduler: FiberScheduler) {
|
||||||
|
@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove
|
|||||||
import com.r3corda.node.utilities.AffinityExecutor
|
import com.r3corda.node.utilities.AffinityExecutor
|
||||||
import com.r3corda.node.utilities.isolatedTransaction
|
import com.r3corda.node.utilities.isolatedTransaction
|
||||||
import kotlinx.support.jdk8.collections.removeIf
|
import kotlinx.support.jdk8.collections.removeIf
|
||||||
|
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||||
import org.jetbrains.exposed.sql.Database
|
import org.jetbrains.exposed.sql.Database
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import rx.subjects.PublishSubject
|
import rx.subjects.PublishSubject
|
||||||
@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// True if we're shutting down, so don't resume anything.
|
||||||
|
@Volatile private var stopping = false
|
||||||
|
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
|
||||||
|
private val liveFibers = ReusableLatch()
|
||||||
|
|
||||||
// Monitoring support.
|
// Monitoring support.
|
||||||
private val metrics = serviceHub.monitoringService.metrics
|
private val metrics = serviceHub.monitoringService.metrics
|
||||||
|
|
||||||
@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
|
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun decrementLiveFibers() {
|
||||||
|
liveFibers.countDown()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun incrementLiveFibers() {
|
||||||
|
liveFibers.countUp()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Start the shutdown process, bringing the [StateMachineManager] to a controlled stop. When this method returns,
|
||||||
|
* all Fibers have been suspended and checkpointed, or have completed.
|
||||||
|
*
|
||||||
|
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
|
||||||
|
*/
|
||||||
|
fun stop(allowedUnsuspendedFiberCount: Int = 0) {
|
||||||
|
check(allowedUnsuspendedFiberCount >= 0)
|
||||||
|
mutex.locked {
|
||||||
|
if (stopping) throw IllegalStateException("Already stopping!")
|
||||||
|
stopping = true
|
||||||
|
}
|
||||||
|
// Account for any expected Fibers in a test scenario.
|
||||||
|
liveFibers.countDown(allowedUnsuspendedFiberCount)
|
||||||
|
liveFibers.await()
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
||||||
* calls to [allStateMachines]
|
* calls to [allStateMachines]
|
||||||
@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
session.receivedMessages += message
|
session.receivedMessages += message
|
||||||
if (session.waitingForResponse) {
|
if (session.waitingForResponse) {
|
||||||
|
// We only want to resume once, so immediately reset the flag.
|
||||||
|
session.waitingForResponse = false
|
||||||
updateCheckpoint(session.psm)
|
updateCheckpoint(session.psm)
|
||||||
resumeFiber(session.psm)
|
resumeFiber(session.psm)
|
||||||
}
|
}
|
||||||
@ -285,15 +318,20 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||||
psm.commitTransaction()
|
psm.commitTransaction()
|
||||||
processIORequest(ioRequest)
|
processIORequest(ioRequest)
|
||||||
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
psm.actionOnEnd = {
|
psm.actionOnEnd = {
|
||||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
try {
|
||||||
mutex.locked {
|
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
mutex.locked {
|
||||||
totalFinishedProtocols.inc()
|
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
totalFinishedProtocols.inc()
|
||||||
|
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||||
|
}
|
||||||
|
endAllFiberSessions(psm)
|
||||||
|
} finally {
|
||||||
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
endAllFiberSessions(psm)
|
|
||||||
}
|
}
|
||||||
mutex.locked {
|
mutex.locked {
|
||||||
totalStartedProtocols.inc()
|
totalStartedProtocols.inc()
|
||||||
@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
|
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
|
||||||
executor.executeASAP {
|
// Avoid race condition when setting stopping to true and then checking liveFibers
|
||||||
|
incrementLiveFibers()
|
||||||
|
if (!stopping) executor.executeASAP {
|
||||||
psm.resume(scheduler)
|
psm.resume(scheduler)
|
||||||
|
} else {
|
||||||
|
psm.logger.debug("Not resuming as SMM is stopping.")
|
||||||
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper
|
|||||||
import com.r3corda.core.utilities.TEST_TX_TIME
|
import com.r3corda.core.utilities.TEST_TX_TIME
|
||||||
import com.r3corda.node.internal.AbstractNode
|
import com.r3corda.node.internal.AbstractNode
|
||||||
import com.r3corda.node.services.config.NodeConfiguration
|
import com.r3corda.node.services.config.NodeConfiguration
|
||||||
import com.r3corda.node.services.persistence.*
|
import com.r3corda.node.services.persistence.DBTransactionStorage
|
||||||
|
import com.r3corda.node.services.persistence.NodeAttachmentService
|
||||||
|
import com.r3corda.node.services.persistence.StorageServiceImpl
|
||||||
|
import com.r3corda.node.services.persistence.checkpoints
|
||||||
import com.r3corda.node.utilities.databaseTransaction
|
import com.r3corda.node.utilities.databaseTransaction
|
||||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
||||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
||||||
@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests {
|
|||||||
databaseTransaction(bobNode.database) {
|
databaseTransaction(bobNode.database) {
|
||||||
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
|
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
|
||||||
}
|
}
|
||||||
aliceNode.manuallyCloseDB()
|
bobNode.manuallyCloseDB()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +51,7 @@ class StateMachineManagerTests {
|
|||||||
@Test
|
@Test
|
||||||
fun `newly added protocol is preserved on restart`() {
|
fun `newly added protocol is preserved on restart`() {
|
||||||
node1.smm.add(NoOpProtocol(nonTerminating = true))
|
node1.smm.add(NoOpProtocol(nonTerminating = true))
|
||||||
|
node1.acceptableLiveFiberCountOnStop = 1
|
||||||
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
|
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
|
||||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||||
}
|
}
|
||||||
@ -75,6 +76,7 @@ class StateMachineManagerTests {
|
|||||||
// We push through just enough messages to get only the payload sent
|
// We push through just enough messages to get only the payload sent
|
||||||
node2.pumpReceive()
|
node2.pumpReceive()
|
||||||
node2.disableDBCloseOnStop()
|
node2.disableDBCloseOnStop()
|
||||||
|
node2.acceptableLiveFiberCountOnStop = 1
|
||||||
node2.stop()
|
node2.stop()
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
|
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
|
||||||
@ -206,6 +208,9 @@ class StateMachineManagerTests {
|
|||||||
node1 sent sessionEnd() to node3
|
node1 sent sessionEnd() to node3
|
||||||
//There's no session end from the other protocols as they're manually suspended
|
//There's no session end from the other protocols as they're manually suspended
|
||||||
)
|
)
|
||||||
|
|
||||||
|
node2.acceptableLiveFiberCountOnStop = 1
|
||||||
|
node3.acceptableLiveFiberCountOnStop = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -218,6 +223,7 @@ class StateMachineManagerTests {
|
|||||||
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
|
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
|
||||||
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
|
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
|
||||||
node1.smm.add(multiReceiveProtocol)
|
node1.smm.add(multiReceiveProtocol)
|
||||||
|
node1.acceptableLiveFiberCountOnStop = 1
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
|
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
|
||||||
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
|
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
|
||||||
@ -271,6 +277,7 @@ class StateMachineManagerTests {
|
|||||||
disableDBCloseOnStop() //Handover DB to new node copy
|
disableDBCloseOnStop() //Handover DB to new node copy
|
||||||
stop()
|
stop()
|
||||||
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
|
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
|
||||||
|
newNode.acceptableLiveFiberCountOnStop = 1
|
||||||
manuallyCloseDB()
|
manuallyCloseDB()
|
||||||
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
|
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
|
||||||
return newNode.getSingleProtocol<P>().first
|
return newNode.getSingleProtocol<P>().first
|
||||||
|
@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
override val log: Logger = loggerFor<MockNode>()
|
override val log: Logger = loggerFor<MockNode>()
|
||||||
override val serverThread: AffinityExecutor =
|
override val serverThread: AffinityExecutor =
|
||||||
if (mockNet.threadPerNode)
|
if (mockNet.threadPerNode)
|
||||||
ServiceAffinityExecutor("Mock node thread", 1)
|
ServiceAffinityExecutor("Mock node $id thread", 1)
|
||||||
else {
|
else {
|
||||||
mockNet.sharedUserCount.incrementAndGet()
|
mockNet.sharedUserCount.incrementAndGet()
|
||||||
mockNet.sharedServerThread
|
mockNet.sharedServerThread
|
||||||
@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
dbCloser?.run()
|
dbCloser?.run()
|
||||||
dbCloser = null
|
dbCloser = null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// You can change this from zero if you have custom [ProtocolLogic] that park themselves. e.g. [StateMachineManagerTests]
|
||||||
|
var acceptableLiveFiberCountOnStop: Int = 0
|
||||||
|
|
||||||
|
override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns a node, optionally created by the passed factory method. */
|
/** Returns a node, optionally created by the passed factory method. */
|
||||||
|
Loading…
Reference in New Issue
Block a user