Merged in parkri-protocol-error-handling-fix (pull request #447)

Fix handling of node shutdown so protocols don't blow up when they encounter the messaging layer already shutdown.  Protocols will also stop resuming once shutdown has commenced.
This commit is contained in:
Rick Parker 2016-11-08 18:15:35 +00:00
commit af8859ebf1
7 changed files with 82 additions and 17 deletions

View File

@ -1,6 +1,7 @@
package com.r3corda.node.internal package com.r3corda.node.internal
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
@ -254,12 +255,17 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
runOnStop += Runnable { net.stop() } runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) _networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
smm.start() smm.start()
// Shut down the SMM so no Fibers are scheduled.
runOnStop += Runnable { smm.stop(acceptableLiveFiberCountOnStop()) }
scheduler.start() scheduler.start()
} }
started = true started = true
return this return this
} }
@VisibleForTesting
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
private fun hasSSLCertificates(): Boolean { private fun hasSSLCertificates(): Boolean {
val keyStore = try { val keyStore = try {
// This will throw exception if key file not found or keystore password is incorrect. // This will throw exception if key file not found or keystore password is incorrect.

View File

@ -1,17 +1,14 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.PluginServiceHub import com.r3corda.core.node.PluginServiceHub
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService { interface MessagingServiceInternal : MessagingService {
/** /**

View File

@ -18,6 +18,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.StrandLocalTransactionManager import com.r3corda.node.utilities.StrandLocalTransactionManager
import com.r3corda.node.utilities.createDatabaseTransaction import com.r3corda.node.utilities.createDatabaseTransaction
import com.r3corda.node.utilities.databaseTransaction
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
@ -230,7 +231,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull() txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null) StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = true ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" } logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
@ -246,13 +247,16 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
processException(t) processException(t)
} }
} }
ioRequest.session.waitingForResponse = false logger.trace { "Resumed from $ioRequest" }
createTransaction() createTransaction()
} }
private fun processException(t: Throwable) { private fun processException(t: Throwable) {
actionOnEnd() // This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
_resultFuture?.setException(t) databaseTransaction(database) {
actionOnEnd()
_resultFuture?.setException(t)
}
} }
internal fun resume(scheduler: FiberScheduler) { internal fun resume(scheduler: FiberScheduler) {

View File

@ -29,6 +29,7 @@ import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.isolatedTransaction import com.r3corda.node.utilities.isolatedTransaction
import kotlinx.support.jdk8.collections.removeIf import kotlinx.support.jdk8.collections.removeIf
import org.apache.activemq.artemis.utils.ReusableLatch
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -95,6 +96,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
}) })
// True if we're shutting down, so don't resume anything.
@Volatile private var stopping = false
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch()
// Monitoring support. // Monitoring support.
private val metrics = serviceHub.monitoringService.metrics private val metrics = serviceHub.monitoringService.metrics
@ -144,6 +150,31 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() } serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
} }
private fun decrementLiveFibers() {
liveFibers.countDown()
}
private fun incrementLiveFibers() {
liveFibers.countUp()
}
/**
* Start the shutdown process, bringing the [StateMachineManager] to a controlled stop. When this method returns,
* all Fibers have been suspended and checkpointed, or have completed.
*
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
*/
fun stop(allowedUnsuspendedFiberCount: Int = 0) {
check(allowedUnsuspendedFiberCount >= 0)
mutex.locked {
if (stopping) throw IllegalStateException("Already stopping!")
stopping = true
}
// Account for any expected Fibers in a test scenario.
liveFibers.countDown(allowedUnsuspendedFiberCount)
liveFibers.await()
}
/** /**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines] * calls to [allStateMachines]
@ -203,6 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
session.receivedMessages += message session.receivedMessages += message
if (session.waitingForResponse) { if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
updateCheckpoint(session.psm) updateCheckpoint(session.psm)
resumeFiber(session.psm) resumeFiber(session.psm)
} }
@ -285,15 +318,20 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// This will free up the ThreadLocal so on return the caller can carry on with other transactions // This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction() psm.commitTransaction()
processIORequest(ioRequest) processIORequest(ioRequest)
decrementLiveFibers()
} }
psm.actionOnEnd = { psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE try {
mutex.locked { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } mutex.locked {
totalFinishedProtocols.inc() stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(psm, AddOrRemove.REMOVE) totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
} finally {
decrementLiveFibers()
} }
endAllFiberSessions(psm)
} }
mutex.locked { mutex.locked {
totalStartedProtocols.inc() totalStartedProtocols.inc()
@ -370,8 +408,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) { private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
executor.executeASAP { // Avoid race condition when setting stopping to true and then checking liveFibers
incrementLiveFibers()
if (!stopping) executor.executeASAP {
psm.resume(scheduler) psm.resume(scheduler)
} else {
psm.logger.debug("Not resuming as SMM is stopping.")
decrementLiveFibers()
} }
} }

View File

@ -21,7 +21,10 @@ import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.persistence.* import com.r3corda.node.services.persistence.DBTransactionStorage
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.utilities.databaseTransaction import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
@ -114,7 +117,7 @@ class TwoPartyTradeProtocolTests {
databaseTransaction(bobNode.database) { databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
} }
aliceNode.manuallyCloseDB() bobNode.manuallyCloseDB()
} }
} }

View File

@ -51,6 +51,7 @@ class StateMachineManagerTests {
@Test @Test
fun `newly added protocol is preserved on restart`() { fun `newly added protocol is preserved on restart`() {
node1.smm.add(NoOpProtocol(nonTerminating = true)) node1.smm.add(NoOpProtocol(nonTerminating = true))
node1.acceptableLiveFiberCountOnStop = 1
val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>() val restoredProtocol = node1.restartAndGetRestoredProtocol<NoOpProtocol>()
assertThat(restoredProtocol.protocolStarted).isTrue() assertThat(restoredProtocol.protocolStarted).isTrue()
} }
@ -75,6 +76,7 @@ class StateMachineManagerTests {
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
node2.pumpReceive() node2.pumpReceive()
node2.disableDBCloseOnStop() node2.disableDBCloseOnStop()
node2.acceptableLiveFiberCountOnStop = 1
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
@ -206,6 +208,9 @@ class StateMachineManagerTests {
node1 sent sessionEnd() to node3 node1 sent sessionEnd() to node3
//There's no session end from the other protocols as they're manually suspended //There's no session end from the other protocols as they're manually suspended
) )
node2.acceptableLiveFiberCountOnStop = 1
node3.acceptableLiveFiberCountOnStop = 1
} }
@Test @Test
@ -218,6 +223,7 @@ class StateMachineManagerTests {
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) } node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity) val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.legalIdentity, node3.info.legalIdentity)
node1.smm.add(multiReceiveProtocol) node1.smm.add(multiReceiveProtocol)
node1.acceptableLiveFiberCountOnStop = 1
net.runNetwork() net.runNetwork()
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload) assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload) assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
@ -271,6 +277,7 @@ class StateMachineManagerTests {
disableDBCloseOnStop() //Handover DB to new node copy disableDBCloseOnStop() //Handover DB to new node copy
stop() stop()
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray()) val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
newNode.acceptableLiveFiberCountOnStop = 1
manuallyCloseDB() manuallyCloseDB()
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first return newNode.getSingleProtocol<P>().first

View File

@ -112,7 +112,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override val log: Logger = loggerFor<MockNode>() override val log: Logger = loggerFor<MockNode>()
override val serverThread: AffinityExecutor = override val serverThread: AffinityExecutor =
if (mockNet.threadPerNode) if (mockNet.threadPerNode)
ServiceAffinityExecutor("Mock node thread", 1) ServiceAffinityExecutor("Mock node $id thread", 1)
else { else {
mockNet.sharedUserCount.incrementAndGet() mockNet.sharedUserCount.incrementAndGet()
mockNet.sharedServerThread mockNet.sharedServerThread
@ -171,6 +171,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
dbCloser?.run() dbCloser?.run()
dbCloser = null dbCloser = null
} }
// You can change this from zero if you have custom [ProtocolLogic] that park themselves. e.g. [StateMachineManagerTests]
var acceptableLiveFiberCountOnStop: Int = 0
override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop
} }
/** Returns a node, optionally created by the passed factory method. */ /** Returns a node, optionally created by the passed factory method. */