Merged in mnesbit-cor-259-controlled-smm-start (pull request #361)

Delay State Machine fiber start until network map cache is fully populated.
This commit is contained in:
Matthew Nesbit 2016-09-19 13:08:25 +01:00
commit b3c1940e1d
8 changed files with 124 additions and 13 deletions

View File

@ -1,5 +1,6 @@
package com.r3corda.core.node.services
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.Contract
import com.r3corda.core.crypto.Party
@ -32,6 +33,8 @@ interface NetworkMapCache {
val partyNodes: List<NodeInfo>
/** Tracks changes to the network map cache */
val changed: Observable<MapChange>
/**
 * Future to track completion of the NetworkMapService registration.
 * Callers can attach a continuation to it to defer work until registration has finished.
 */
val mapServiceRegistered: ListenableFuture<Unit>
/**
* A list of nodes that advertise a regulatory service. Identifying the correct regulator for a trade is outside
@ -97,6 +100,12 @@ interface NetworkMapCache {
* @param service the network map service to fetch current state from.
*/
fun deregisterForUpdates(net: MessagingService, service: NodeInfo): ListenableFuture<Unit>
/**
 * Marks the network map service as immediately ready.
 * For use in tests where the network map cache is manipulated directly.
 */
@VisibleForTesting
fun runWithoutMapService()
}
sealed class NetworkCacheError : Exception() {

View File

@ -297,10 +297,11 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
}
services.networkMapCache.addNode(info)
// In the unit test environment, we may run without any network map service sometimes.
if (networkMapService == null && inNodeNetworkMapService == null)
if (networkMapService == null && inNodeNetworkMapService == null) {
services.networkMapCache.runWithoutMapService()
return noNetworkMapConfigured()
else
return registerWithNetworkMap(networkMapService ?: info.address)
}
return registerWithNetworkMap(networkMapService ?: info.address)
}
private fun registerWithNetworkMap(networkMapServiceAddress: SingleMessageRecipient): ListenableFuture<Unit> {

View File

@ -34,4 +34,10 @@ data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: ProtocolIORequest?,
val receivedPayload: Any?
)
) {
// Tracks whether a fiber has already been created (though not necessarily started) for this checkpoint.
// This matters when messages for protocols arrive at startup before the system has fully loaded.
// The flag is not serialised, so it is always false on a checkpoint freshly loaded from storage.
@Transient
var fiberCreated: Boolean = false
}

View File

@ -1,5 +1,6 @@
package com.r3corda.node.services.network
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
@ -45,6 +46,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
get() = registeredNodes.map { it.value }
private val _changed = PublishSubject.create<MapChange>()
override val changed: Observable<MapChange> = _changed
// Backing future for [mapServiceRegistered]; exposed read-only through the property below.
// NOTE(review): setFuture() is called on this from more than one registration path. Per the Guava docs,
// setFuture() has no effect once the future is already set or already delegating, so only the first call
// determines when this completes — confirm that is the intended behaviour.
private val _registrationFuture = SettableFuture.create<Unit>()
override val mapServiceRegistered: ListenableFuture<Unit>
get() = _registrationFuture
private var registeredForPush = false
protected var registeredNodes = Collections.synchronizedMap(HashMap<Party, NodeInfo>())
@ -82,6 +86,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
// Add a message handler for the response, and prepare a future to put the data into.
// Note that the message handler will run on the network thread (not this one).
val future = SettableFuture.create<Unit>()
_registrationFuture.setFuture(future)
net.runOnNextMessage(NetworkMapService.FETCH_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
val resp = message.data.deserialize<NetworkMapService.FetchMapResponse>()
// We may not receive any nodes back, if the map hasn't changed since the version specified
@ -120,6 +125,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
// Add a message handler for the response, and prepare a future to put the data into.
// Note that the message handler will run on the network thread (not this one).
val future = SettableFuture.create<Unit>()
_registrationFuture.setFuture(future)
net.runOnNextMessage(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
val resp = message.data.deserialize<NetworkMapService.SubscribeResponse>()
if (resp.confirmed) {
@ -151,4 +157,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
AddOrRemove.REMOVE -> removeNode(reg.node)
}
}
/**
 * Completes the registration future straight away, so anything waiting on [mapServiceRegistered]
 * can proceed without a network map service being contacted.
 */
@VisibleForTesting
override fun runWithoutMapService() {
_registrationFuture.set(Unit)
}
}

View File

@ -14,6 +14,7 @@ import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.serialization.*
import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
@ -73,6 +74,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private val checkpointingMeter = metrics.meter("Protocols.Checkpointing Rate")
private val totalStartedProtocols = metrics.counter("Protocols.Started")
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
private var started = false
// Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
@ -118,13 +120,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
/**
 * Processes every stored checkpoint up front, then waits for the network map cache's
 * mapServiceRegistered future before marking the manager as started and restarting the fibers
 * held in stateMachines.
 */
fun start() {
// NOTE(review): this line and the next look like the before/after pair of the same change (this text is a
// rendered diff without +/- markers); presumably only the createFiberForCheckpoint call is current — confirm.
checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
// Presumably `then` runs the block on `executor` once the future completes — verify against its definition.
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
// NOTE(review): `started` is a Boolean var, so synchronized(started) locks on the boxed value
// (Boolean.FALSE or Boolean.TRUE). The monitor therefore changes identity as soon as `started` is flipped
// below, and it is a JVM-wide shared object. A thread that reads `started` afterwards locks a different
// monitor, so this does not give reliable mutual exclusion with other synchronized(started) sections.
// Consider a dedicated private lock object used by every such section.
synchronized(started) {
started = true
stateMachines.forEach { restartFiber(it.key, it.value) }
}
}
}
private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint })
/**
 * Deserialises the fiber stored in [checkpoint] and registers it through initFiber.
 * Does nothing when the checkpoint is already flagged as having had its fiber created.
 */
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
    // Guard: a fiber already exists for this checkpoint, so there is nothing to create.
    if (checkpoint.fiberCreated) return
    val restoredFiber = deserializeFiber(checkpoint.serialisedFiber)
    initFiber(restoredFiber) { checkpoint }
}
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
if (checkpoint.request is ReceiveRequest<*>) {
val topicSession = checkpoint.request.receiveTopicSession
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
@ -179,7 +191,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
psm.serviceHub = serviceHub
psm.suspendAction = { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
@ -194,8 +206,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
stateMachines[psm] = startingCheckpoint()
val checkpoint = startingCheckpoint()
checkpoint.fiberCreated = true
totalStartedProtocols.inc()
stateMachines[psm] = checkpoint
notifyChangeObservers(psm, AddOrRemove.ADD)
return checkpoint
}
/**
@ -206,17 +222,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
initFiber(fiber) {
val checkpoint = initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
checkpointStorage.addCheckpoint(checkpoint)
checkpoint
}
checkpointStorage.addCheckpoint(checkpoint)
synchronized(started) { // If we are not started then our checkpoint will be picked up during start
if (!started) {
return fiber
}
}
try {
executor.executeASAP {
iterateStateMachine(fiber, null) {
fiber.start()
}
totalStartedProtocols.inc()
}
} catch (e: ExecutionException) {
// There are two ways we can take exceptions in this method:

View File

@ -91,6 +91,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
smmHasRemovedAllProtocols.countDown()
}
}
mockSMM.start()
services.smm = mockSMM
}

View File

@ -13,6 +13,8 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class StateMachineManagerTests {
@ -62,13 +64,72 @@ class StateMachineManagerTests {
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
}
@Test
fun `protocol added before network map does run after init`() {
    // A vanilla node that uses node1 as its network map service.
    val lateNode = net.createNode(node1.info.address)
    val logic = ProtocolNoBlocking()
    lateNode.smm.add("test", logic)

    // No network activity has been allowed yet, so the protocol must not have started.
    assertEquals(false, logic.protocolStarted)

    // Let the network map messages flow; the protocol should now have run.
    net.runNetwork()
    assertEquals(true, logic.protocolStarted)
}
@Test
fun `protocol added before network map will be init checkpointed`() {
    // First incarnation: a vanilla node that is stopped before any network activity is allowed.
    val firstBoot = net.createNode(node1.info.address)
    val logic = ProtocolNoBlocking()
    firstBoot.smm.add("test", logic)
    assertEquals(false, logic.protocolStarted)
    firstBoot.stop()

    // Second incarnation with the same ID: the protocol should come back, still not started.
    val secondBoot = net.createNode(node1.info.address, forcedID = firstBoot.id)
    val restoredLogic = secondBoot.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
    assertEquals(false, restoredLogic.protocolStarted)

    // Allow network map messages to flow; the restored protocol should now have run.
    net.runNetwork()
    secondBoot.smm.executor.flush()
    assertEquals(true, restoredLogic.protocolStarted)
    secondBoot.stop()

    // Third incarnation: the completed protocol should have left no checkpoint behind.
    val thirdBoot = net.createNode(node1.info.address, forcedID = secondBoot.id)
    net.runNetwork()
    thirdBoot.smm.executor.flush()
    assertTrue(thirdBoot.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty())
}
@Test
fun `protocol loaded from checkpoint will respond to messages from before start`() {
    val sessionTopic = "send-and-receive"
    val expectedPayload = random63BitValue()
    val sender = SendProtocol(sessionTopic, node2.info.identity, expectedPayload)
    val receiver = ReceiveProtocol(sessionTopic, node1.info.identity)
    connectProtocols(sender, receiver)

    // Checkpoint the receiving protocol on node2, then kill that node.
    node2.smm.add("test", receiver)
    node2.stop()

    // Send while the receiver is down, so the message is spooled up and arrives ahead of the
    // NetworkMapService messages after the restart.
    node1.smm.add("test", sender)

    val restoredReceiver = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
    assertThat(restoredReceiver.receivedPayload).isEqualTo(expectedPayload)
}
/**
 * Creates a replacement node with this node's id and advertised services, runs the mock network so the
 * NetworkMapService messages can stabilise (and thus start the state machine), and returns the single
 * restored protocol of type [P].
 */
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
    val restartedNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
    mockNet.runNetwork()
    val restoredMachines = restartedNode.smm.findStateMachines(P::class.java)
    return restoredMachines.single().first
}
/**
 * Minimal protocol that records that it was run and then finishes immediately.
 * Reading its topic is unsupported.
 */
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
    // Set when call() is invoked.
    @Transient var protocolStarted = false

    override val topic: String
        get() = throw UnsupportedOperationException()

    @Suspendable
    override fun call() {
        protocolStarted = true
    }
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@Transient var protocolStarted = false

View File

@ -22,6 +22,7 @@ class MockNetworkMapCache() : com.r3corda.node.services.network.InMemoryNetworkM
val mockNodeB = NodeInfo(MockAddress("bankD:8080"), Party("Bank D", DummyPublicKey("Bank D")))
registeredNodes[mockNodeA.identity] = mockNodeA
registeredNodes[mockNodeB.identity] = mockNodeB
runWithoutMapService()
}
/**