mirror of
https://github.com/corda/corda.git
synced 2025-01-19 11:16:54 +00:00
Merged in mnesbit-cor-259-controlled-smm-start (pull request #361)
Delay State Machine fiber start until network map cache is fully populated.
This commit is contained in:
commit
b3c1940e1d
@ -1,5 +1,6 @@
|
||||
package com.r3corda.core.node.services
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.contracts.Contract
|
||||
import com.r3corda.core.crypto.Party
|
||||
@ -32,6 +33,8 @@ interface NetworkMapCache {
|
||||
val partyNodes: List<NodeInfo>
|
||||
/** Tracks changes to the network map cache */
|
||||
val changed: Observable<MapChange>
|
||||
/** Future to track completion of the NetworkMapService registration. */
|
||||
val mapServiceRegistered: ListenableFuture<Unit>
|
||||
|
||||
/**
|
||||
* A list of nodes that advertise a regulatory service. Identifying the correct regulator for a trade is outside
|
||||
@ -97,6 +100,12 @@ interface NetworkMapCache {
|
||||
* @param service the network map service to fetch current state from.
|
||||
*/
|
||||
fun deregisterForUpdates(net: MessagingService, service: NodeInfo): ListenableFuture<Unit>
|
||||
|
||||
/**
|
||||
* For testing, where the network map cache is manipulated directly: marks the map service registration as complete immediately.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
fun runWithoutMapService()
|
||||
}
|
||||
|
||||
sealed class NetworkCacheError : Exception() {
|
||||
|
@ -297,10 +297,11 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
|
||||
}
|
||||
services.networkMapCache.addNode(info)
|
||||
// In the unit test environment, we may run without any network map service sometimes.
|
||||
if (networkMapService == null && inNodeNetworkMapService == null)
|
||||
if (networkMapService == null && inNodeNetworkMapService == null) {
|
||||
services.networkMapCache.runWithoutMapService()
|
||||
return noNetworkMapConfigured()
|
||||
else
|
||||
return registerWithNetworkMap(networkMapService ?: info.address)
|
||||
}
|
||||
return registerWithNetworkMap(networkMapService ?: info.address)
|
||||
}
|
||||
|
||||
private fun registerWithNetworkMap(networkMapServiceAddress: SingleMessageRecipient): ListenableFuture<Unit> {
|
||||
|
@ -34,4 +34,10 @@ data class Checkpoint(
|
||||
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
val request: ProtocolIORequest?,
|
||||
val receivedPayload: Any?
|
||||
)
|
||||
) {
|
||||
// This flag is always false when loaded from storage as it isn't serialised.
|
||||
// It is used to track when the associated fiber has been created, but not necessarily started when
|
||||
// messages for protocols arrive before the system has fully loaded at startup.
|
||||
@Transient
|
||||
var fiberCreated: Boolean = false
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package com.r3corda.node.services.network
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
@ -45,6 +46,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
|
||||
get() = registeredNodes.map { it.value }
|
||||
private val _changed = PublishSubject.create<MapChange>()
|
||||
override val changed: Observable<MapChange> = _changed
|
||||
// Backing future for [mapServiceRegistered]. Within this class it is either chained to a per-request
// future via setFuture(...) or completed directly by runWithoutMapService().
// NOTE(review): setFuture(...) is called from two separate request paths in this class; Guava's
// SettableFuture only honours the first set/setFuture call, so later calls are ignored — confirm that
// "first caller wins" is the intended behaviour.
private val _registrationFuture = SettableFuture.create<Unit>()
|
||||
/** Read-only view of the registration future; exposes the same instance as [_registrationFuture]. */
override val mapServiceRegistered: ListenableFuture<Unit>
|
||||
get() = _registrationFuture
|
||||
|
||||
private var registeredForPush = false
|
||||
protected var registeredNodes = Collections.synchronizedMap(HashMap<Party, NodeInfo>())
|
||||
@ -82,6 +86,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
|
||||
// Add a message handler for the response, and prepare a future to put the data into.
|
||||
// Note that the message handler will run on the network thread (not this one).
|
||||
val future = SettableFuture.create<Unit>()
|
||||
_registrationFuture.setFuture(future)
|
||||
net.runOnNextMessage(NetworkMapService.FETCH_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
|
||||
val resp = message.data.deserialize<NetworkMapService.FetchMapResponse>()
|
||||
// We may not receive any nodes back, if the map hasn't changed since the version specified
|
||||
@ -120,6 +125,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
|
||||
// Add a message handler for the response, and prepare a future to put the data into.
|
||||
// Note that the message handler will run on the network thread (not this one).
|
||||
val future = SettableFuture.create<Unit>()
|
||||
_registrationFuture.setFuture(future)
|
||||
net.runOnNextMessage(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
|
||||
val resp = message.data.deserialize<NetworkMapService.SubscribeResponse>()
|
||||
if (resp.confirmed) {
|
||||
@ -151,4 +157,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
|
||||
AddOrRemove.REMOVE -> removeNode(reg.node)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Test hook: completes the registration future straight away, so anything waiting on
 * [mapServiceRegistered] proceeds without a network map service exchange taking place.
 */
@VisibleForTesting
|
||||
override fun runWithoutMapService() {
|
||||
// Completes the future with Unit; there is no payload, only the completion signal matters.
_registrationFuture.set(Unit)
|
||||
}
|
||||
}
|
@ -14,6 +14,7 @@ import com.r3corda.core.messaging.send
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.serialization.*
|
||||
import com.r3corda.core.then
|
||||
import com.r3corda.core.utilities.ProgressTracker
|
||||
import com.r3corda.core.utilities.trace
|
||||
import com.r3corda.node.services.api.Checkpoint
|
||||
@ -73,6 +74,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
private val checkpointingMeter = metrics.meter("Protocols.Checkpointing Rate")
|
||||
private val totalStartedProtocols = metrics.counter("Protocols.Started")
|
||||
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
|
||||
private var started = false
|
||||
|
||||
// Context for tokenized services in checkpoints
|
||||
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
|
||||
@ -118,13 +120,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
|
||||
/**
 * Creates a fiber for every stored checkpoint, then defers the actual restart: a callback is attached to
 * the network map cache's mapServiceRegistered future (scheduled on [executor] via the `then` helper)
 * which sets [started] and calls restartFiber for every entry currently in stateMachines.
 */
fun start() {
|
||||
// NOTE(review): in this diff rendering the restoreFromCheckpoint row below appears to be the removed
// pre-change line, and the createFiberForCheckpoint row after it its replacement — confirm against the repo.
checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
|
||||
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
|
||||
// Fibers are only restarted once the map service registration future has completed.
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
|
||||
// NOTE(review): `started` is a `var` Boolean, so the monitor used here is the boxed current value and it
// changes as soon as `started = true` runs; add() locks the same way. A dedicated lock object would be
// safer — confirm whether any real cross-thread exclusion is relied upon.
synchronized(started) {
|
||||
started = true
|
||||
stateMachines.forEach { restartFiber(it.key, it.value) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
|
||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||
initFiber(fiber, { checkpoint })
|
||||
/**
 * Deserialises the fiber stored in [checkpoint] and registers it through [initFiber], unless the
 * checkpoint is already flagged as having had its fiber created.
 */
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
    // Guard: a checkpoint whose fiberCreated flag is set already has a fiber, so there is nothing to do.
    if (checkpoint.fiberCreated) return
    val restoredFiber = deserializeFiber(checkpoint.serialisedFiber)
    initFiber(restoredFiber) { checkpoint }
}
|
||||
|
||||
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
||||
if (checkpoint.request is ReceiveRequest<*>) {
|
||||
val topicSession = checkpoint.request.receiveTopicSession
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
|
||||
@ -179,7 +191,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
|
||||
psm.serviceHub = serviceHub
|
||||
psm.suspendAction = { request ->
|
||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||
@ -194,8 +206,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
totalFinishedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
}
|
||||
stateMachines[psm] = startingCheckpoint()
|
||||
val checkpoint = startingCheckpoint()
|
||||
checkpoint.fiberCreated = true
|
||||
totalStartedProtocols.inc()
|
||||
stateMachines[psm] = checkpoint
|
||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
||||
return checkpoint
|
||||
}
|
||||
|
||||
/**
|
||||
@ -206,17 +222,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
|
||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||
// Need to add before iterating in case of immediate completion
|
||||
initFiber(fiber) {
|
||||
val checkpoint = initFiber(fiber) {
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
checkpoint
|
||||
}
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
synchronized(started) { // If we are not started then our checkpoint will be picked up during start
|
||||
if (!started) {
|
||||
return fiber
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, null) {
|
||||
fiber.start()
|
||||
}
|
||||
totalStartedProtocols.inc()
|
||||
}
|
||||
} catch (e: ExecutionException) {
|
||||
// There are two ways we can take exceptions in this method:
|
||||
|
@ -91,6 +91,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
|
||||
smmHasRemovedAllProtocols.countDown()
|
||||
}
|
||||
}
|
||||
mockSMM.start()
|
||||
services.smm = mockSMM
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,8 @@ import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class StateMachineManagerTests {
|
||||
|
||||
@ -62,13 +64,72 @@ class StateMachineManagerTests {
|
||||
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
|
||||
}
|
||||
|
||||
/**
 * Adds a protocol to a newly created node before the mock network has been pumped and checks that
 * the protocol has not run at that point, but has run once net.runNetwork() has been called.
 */
@Test
|
||||
fun `protocol added before network map does run after init`() {
|
||||
val node3 = net.createNode(node1.info.address) //create vanilla node
|
||||
val protocol = ProtocolNoBlocking()
|
||||
node3.smm.add("test", protocol)
|
||||
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
assertEquals(true, protocol.protocolStarted) // Now we should have run the protocol
|
||||
}
|
||||
|
||||
/**
 * Three-phase restart test:
 *  1. add a protocol to a new node and stop the node before the network is pumped;
 *  2. recreate the node with the same id, check the protocol is found again but not yet run, then pump
 *     the network and flush the executor and check it has run;
 *  3. recreate the node once more and check that no ProtocolNoBlocking state machine is found.
 */
@Test
|
||||
fun `protocol added before network map will be init checkpointed`() {
|
||||
var node3 = net.createNode(node1.info.address) //create vanilla node
|
||||
val protocol = ProtocolNoBlocking()
|
||||
node3.smm.add("test", protocol)
|
||||
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
node3.stop()
|
||||
|
||||
// Phase 2: same node id, so the recreated node is expected to pick up what the first instance stored.
node3 = net.createNode(node1.info.address, forcedID = node3.id)
|
||||
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
|
||||
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
node3.smm.executor.flush()
|
||||
assertEquals(true, restoredProtocol.protocolStarted) // Now we should have run the protocol and hopefully cleared the init checkpoint
|
||||
node3.stop()
|
||||
|
||||
// Now it is completed the protocol should leave no Checkpoint.
|
||||
node3 = net.createNode(node1.info.address, forcedID = node3.id)
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
node3.smm.executor.flush()
|
||||
assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty())
|
||||
}
|
||||
|
||||
/**
 * Checks that a payload sent while the receiving node is stopped is still delivered to the receive
 * protocol after that node is restarted: the receiver is added and its node stopped, the sender is then
 * added on node1, and finally node2 is restarted and the restored protocol's payload is compared.
 */
@Test
|
||||
fun `protocol loaded from checkpoint will respond to messages from before start`() {
|
||||
val topic = "send-and-receive"
|
||||
val payload = random63BitValue()
|
||||
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
|
||||
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
|
||||
connectProtocols(sendProtocol, receiveProtocol)
|
||||
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
|
||||
node2.stop() // kill receiver
|
||||
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService
|
||||
// Restart passes node1's address as the network map address; the helper pumps the network before returning.
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
|
||||
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
|
||||
}
|
||||
|
||||
/**
 * Recreates this mock node on its mock network with the same id and advertised services, pumps the
 * network, and returns the first component of the single state machine entry found for type [P].
 *
 * @param networkMapAddress forwarded unchanged as the first argument of createNode; defaults to null.
 */
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
    val restartedNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
    // allow NetworkMapService messages to stabilise and thus start the state machine
    mockNet.runNetwork()
    val restoredMachines = restartedNode.smm.findStateMachines(P::class.java)
    return restoredMachines.single().first
}
|
||||
|
||||
|
||||
/**
 * Minimal test protocol: call() only records that it ran by setting [protocolStarted].
 */
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
|
||||
// Marked @Transient; the restart test above expects a restored instance to report false until it runs.
@Transient var protocolStarted = false
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
protocolStarted = true
|
||||
}
|
||||
|
||||
// No topic is defined for this protocol; reading the property throws.
override val topic: String get() = throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
|
||||
|
||||
@Transient var protocolStarted = false
|
||||
|
@ -22,6 +22,7 @@ class MockNetworkMapCache() : com.r3corda.node.services.network.InMemoryNetworkM
|
||||
val mockNodeB = NodeInfo(MockAddress("bankD:8080"), Party("Bank D", DummyPublicKey("Bank D")))
|
||||
registeredNodes[mockNodeA.identity] = mockNodeA
|
||||
registeredNodes[mockNodeB.identity] = mockNodeB
|
||||
runWithoutMapService()
|
||||
}
|
||||
|
||||
/**
|
||||
|
Loading…
Reference in New Issue
Block a user