From 35274cd15cc71c189f824f328abefe61bb53ba03 Mon Sep 17 00:00:00 2001 From: "rick.parker" Date: Tue, 6 Sep 2016 16:46:36 +0100 Subject: [PATCH] Refactor network map service in preparation for persistence. Removed currently superfluous clock. --- .../com/r3corda/node/internal/AbstractNode.kt | 2 +- .../network/InMemoryNetworkMapCache.kt | 4 +- .../services/network/NetworkMapService.kt | 116 ++++++++++++------ .../services/InMemoryNetworkMapServiceTest.kt | 28 ++--- 4 files changed, 97 insertions(+), 53 deletions(-) diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index f6b5078c07..51d664ee48 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -5,7 +5,6 @@ import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.RunOnCallerThread -import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.node.CityDatabase @@ -21,6 +20,7 @@ import com.r3corda.core.seconds import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.serialize +import com.r3corda.core.transactions.SignedTransaction import com.r3corda.node.api.APIServer import com.r3corda.node.services.api.* import com.r3corda.node.services.config.NodeConfiguration diff --git a/node/src/main/kotlin/com/r3corda/node/services/network/InMemoryNetworkMapCache.kt b/node/src/main/kotlin/com/r3corda/node/services/network/InMemoryNetworkMapCache.kt index b76dd04229..46920c83bd 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/network/InMemoryNetworkMapCache.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/network/InMemoryNetworkMapCache.kt @@ -5,7 +5,6 @@ import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.contracts.Contract import com.r3corda.core.crypto.Party -import com.r3corda.core.crypto.SecureHash import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.send @@ -65,9 +64,8 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach net.addMessageHandler(NetworkMapService.PUSH_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, null) { message, r -> try { val req = message.data.deserialize() - val hash = SecureHash.sha256(req.wireReg.serialize().bits) val ackMessage = net.createMessage(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, - NetworkMapService.UpdateAcknowledge(hash, net.myAddress).serialize().bits) + NetworkMapService.UpdateAcknowledge(req.mapVersion, net.myAddress).serialize().bits) net.send(ackMessage, req.replyTo) processUpdatePush(req) } catch(e: NodeMapError) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt b/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt index e3679f6810..e86366febe 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/network/NetworkMapService.kt @@ -1,8 +1,11 @@ package com.r3corda.node.services.network -import co.paralleluniverse.common.util.VisibleForTesting +import com.google.common.annotations.VisibleForTesting import com.r3corda.core.ThreadBox -import com.r3corda.core.crypto.* +import com.r3corda.core.crypto.DigitalSignature +import com.r3corda.core.crypto.Party +import com.r3corda.core.crypto.SignedData +import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.messaging.MessageRecipients import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.SingleMessageRecipient @@ -16,6 +19,7 @@ import com.r3corda.core.serialization.serialize import com.r3corda.node.services.api.AbstractNodeService import com.r3corda.node.utilities.AddOrRemove import com.r3corda.protocols.ServiceRequestMessage +import kotlinx.support.jdk8.collections.compute import org.slf4j.LoggerFactory import java.security.PrivateKey import java.security.SignatureException @@ -73,17 +77,43 @@ interface NetworkMapService { data class RegistrationResponse(val success: Boolean) class SubscribeRequest(val subscribe: Boolean, replyTo: MessageRecipients, override val sessionID: Long) : NetworkMapRequestMessage(replyTo) data class SubscribeResponse(val confirmed: Boolean) - data class Update(val wireReg: WireNodeRegistration, val replyTo: MessageRecipients) - data class UpdateAcknowledge(val wireRegHash: SecureHash, val replyTo: MessageRecipients) + data class Update(val wireReg: WireNodeRegistration, val mapVersion: Int, val replyTo: MessageRecipients) + data class UpdateAcknowledge(val mapVersion: Int, val replyTo: MessageRecipients) } - @ThreadSafe -class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, val cache: NetworkMapCache) : NetworkMapService, AbstractNodeService(net, cache) { - private val registeredNodes = ConcurrentHashMap() - // Map from subscriber address, to a list of unacknowledged updates - private val subscribers = ThreadBox(mutableMapOf>()) - private val mapVersion = AtomicInteger(1) +class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, cache: NetworkMapCache) : + AbstractNetworkMapService(net, cache) { + + override val registeredNodes: MutableMap = ConcurrentHashMap() + override val subscribers = ThreadBox(mutableMapOf()) + + init { + setup(home) + } +} + +/** + * Abstracted out core functionality as the basis for a persistent implementation, as well as existing in-memory implementation. + * + * Design is slightly refactored to track time and map version of last acknowledge per subscriber to facilitate + * subscriber clean up and is simpler to persist than the previous implementation based on a set of missing messages acks. + */ +@ThreadSafe +abstract class AbstractNetworkMapService(net: MessagingService, val cache: NetworkMapCache) : NetworkMapService, AbstractNodeService(net, cache) { + protected abstract val registeredNodes: MutableMap + + // Map from subscriber address, to most recently acknowledged update map version. + protected abstract val subscribers: ThreadBox> + + protected val _mapVersion = AtomicInteger(0) + + @VisibleForTesting + val mapVersion: Int + get() = _mapVersion.get() + + private fun mapVersionIncrementAndGet(): Int = _mapVersion.incrementAndGet() + /** Maximum number of unacknowledged updates to send to a node before automatically unregistering them for updates */ val maxUnacknowledgedUpdates = 10 /** @@ -94,12 +124,13 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v // Filter reduces this to the entries that add a node to the map override val nodes: List - get() = registeredNodes.mapNotNull { if (it.value.type == AddOrRemove.ADD) it.value.node else null } + get() = registeredNodes.mapNotNull { if (it.value.reg.type == AddOrRemove.ADD) it.value.reg.node else null } - init { + protected fun setup(home: NodeRegistration) { // Register the local node with the service val homeIdentity = home.node.identity - registeredNodes[homeIdentity] = home + val registrationInfo = NodeRegistrationInfo(home, mapVersionIncrementAndGet()) + registeredNodes[homeIdentity] = registrationInfo // Register message handlers addMessageHandler(NetworkMapService.FETCH_PROTOCOL_TOPIC, @@ -118,13 +149,15 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v val req = message.data.deserialize() processAcknowledge(req) } + + // TODO: notify subscribers of name service registration. Network service is not up, so how? } private fun addSubscriber(subscriber: MessageRecipients) { if (subscriber !is SingleMessageRecipient) throw NodeMapError.InvalidSubscriber() subscribers.locked { if (!containsKey(subscriber)) { - put(subscriber, mutableListOf()) + put(subscriber, LastAcknowledgeInfo(mapVersion)) } } } @@ -135,24 +168,31 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v } @VisibleForTesting - fun getUnacknowledgedCount(subscriber: SingleMessageRecipient): Int? - = subscribers.locked { get(subscriber)?.count() } + fun getUnacknowledgedCount(subscriber: SingleMessageRecipient, mapVersion: Int): Int? { + return subscribers.locked { + val subscriberMapVersion = get(subscriber)?.mapVersion + if (subscriberMapVersion != null) { + mapVersion - subscriberMapVersion + } else { + null + } + } + } @VisibleForTesting - fun notifySubscribers(wireReg: WireNodeRegistration) { + fun notifySubscribers(wireReg: WireNodeRegistration, mapVersion: Int) { // TODO: Once we have a better established messaging system, we can probably send - // to a MessageRecipientGroup that nodes join/leave, rather than the network map - // service itself managing the group - val update = NetworkMapService.Update(wireReg, net.myAddress).serialize().bits + // to a MessageRecipientGroup that nodes join/leave, rather than the network map + // service itself managing the group + val update = NetworkMapService.Update(wireReg, mapVersion, net.myAddress).serialize().bits val message = net.createMessage(NetworkMapService.PUSH_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, update) subscribers.locked { val toRemove = mutableListOf() - val hash = SecureHash.sha256(wireReg.raw.bits) - forEach { subscriber: Map.Entry> -> - val unacknowledged = subscriber.value - if (unacknowledged.count() < maxUnacknowledgedUpdates) { - unacknowledged.add(hash) + forEach { subscriber: Map.Entry -> + val unacknowledgedCount = mapVersion - subscriber.value.mapVersion + // TODO: introduce some concept of time in the condition to avoid unsubscribes when there's a message burst. + if (unacknowledgedCount <= maxUnacknowledgedUpdates) { net.send(message, subscriber.key) } else { toRemove.add(subscriber.key) @@ -164,8 +204,12 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v @VisibleForTesting fun processAcknowledge(req: NetworkMapService.UpdateAcknowledge): Unit { + if (req.replyTo !is SingleMessageRecipient) throw NodeMapError.InvalidSubscriber() subscribers.locked { - this[req.replyTo]?.remove(req.wireRegHash) + val lastVersionAcked = this[req.replyTo]?.mapVersion + if ((lastVersionAcked ?: 0) < req.mapVersion) { + this[req.replyTo] = LastAcknowledgeInfo(req.mapVersion) + } } } @@ -174,9 +218,9 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v if (req.subscribe) { addSubscriber(req.replyTo) } - val ver = mapVersion.get() + val ver = mapVersion if (req.ifChangedSinceVersion == null || req.ifChangedSinceVersion < ver) { - val nodes = ArrayList(registeredNodes.values) // Snapshot to avoid attempting to serialise ConcurrentHashMap internals + val nodes = ArrayList(registeredNodes.values.map { it.reg }) // Snapshot to avoid attempting to serialise Map internals return NetworkMapService.FetchMapResponse(nodes, ver) } else { return NetworkMapService.FetchMapResponse(null, ver) @@ -185,7 +229,7 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v @VisibleForTesting fun processQueryRequest(req: NetworkMapService.QueryIdentityRequest): NetworkMapService.QueryIdentityResponse { - val candidate = registeredNodes[req.identity] + val candidate = registeredNodes[req.identity]?.reg // If the most recent record we have is of the node being removed from the map, then it's considered // as no match. @@ -212,12 +256,12 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v // Update the current value atomically, so that if multiple updates come // in on different threads, there is no risk of a race condition while checking // sequence numbers. - registeredNodes.compute(node.identity, { mapKey: Party, existing: NodeRegistration? -> - changed = existing == null || existing.serial < change.serial + val registrationInfo = registeredNodes.compute(node.identity, { mapKey: Party, existing: NodeRegistrationInfo? -> + changed = existing == null || existing.reg.serial < change.serial if (changed) { when (change.type) { - AddOrRemove.ADD -> change - AddOrRemove.REMOVE -> change + AddOrRemove.ADD -> NodeRegistrationInfo(change, mapVersionIncrementAndGet()) + AddOrRemove.REMOVE -> NodeRegistrationInfo(change, mapVersionIncrementAndGet()) else -> throw NodeMapError.UnknownChangeType() } } else { @@ -225,7 +269,7 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v } }) if (changed) { - notifySubscribers(req.wireReg) + notifySubscribers(req.wireReg, registrationInfo!!.mapVersion) // Update the local cache // TODO: Once local messaging is fixed, this should go over the network layer as it does to other @@ -241,7 +285,6 @@ class InMemoryNetworkMapService(net: MessagingService, home: NodeRegistration, v } } - mapVersion.incrementAndGet() } return NetworkMapService.RegistrationResponse(changed) } @@ -304,3 +347,6 @@ sealed class NodeMapError : Exception() { /** Thrown if a change arrives which is of an unknown type */ class UnknownChangeType : NodeMapError() } + +data class LastAcknowledgeInfo(val mapVersion: Int) +data class NodeRegistrationInfo(val reg: NodeRegistration, val mapVersion: Int) diff --git a/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt index a5601f54ee..d8524c572e 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/InMemoryNetworkMapServiceTest.kt @@ -1,15 +1,14 @@ package com.r3corda.node.services import co.paralleluniverse.fibers.Suspendable -import com.r3corda.core.crypto.SecureHash import com.r3corda.core.node.NodeInfo import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.random63BitValue -import com.r3corda.testing.node.MockNetwork import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NodeRegistration import com.r3corda.node.utilities.AddOrRemove +import com.r3corda.testing.node.MockNetwork import org.junit.Before import org.junit.Test import java.security.PrivateKey @@ -63,11 +62,11 @@ class InMemoryNetworkMapServiceTest { assert(!service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success) } - class TestAcknowledgePSM(val server: NodeInfo, val hash: SecureHash) : ProtocolLogic() { + class TestAcknowledgePSM(val server: NodeInfo, val mapVersion: Int) : ProtocolLogic() { override val topic: String get() = NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC @Suspendable override fun call() { - val req = NetworkMapService.UpdateAcknowledge(hash, serviceHub.networkService.myAddress) + val req = NetworkMapService.UpdateAcknowledge(mapVersion, serviceHub.networkService.myAddress) send(server.identity, 0, req) } } @@ -145,39 +144,40 @@ class InMemoryNetworkMapServiceTest { network.runNetwork() subscribePsm.get() + val startingMapVersion = service.mapVersion + // Check the unacknowledged count is zero - assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address)) + assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion)) // Fire off an update val nodeKey = registerNode.storage.myLegalIdentityKey - var seq = 1L + var seq = 0L val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD var reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) var wireReg = reg.toWire(nodeKey.private) - service.notifySubscribers(wireReg) + service.notifySubscribers(wireReg, startingMapVersion + 1) // Check the unacknowledged count is one - assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address)) + assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) // Send in an acknowledgment and verify the count goes down - val hash = SecureHash.sha256(wireReg.raw.bits) val acknowledgePsm = registerNode.services.startProtocol(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, - TestAcknowledgePSM(mapServiceNode.info, hash)) + TestAcknowledgePSM(mapServiceNode.info, startingMapVersion + 1)) network.runNetwork() acknowledgePsm.get() - assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address)) + assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1)) // Intentionally fill the pending acknowledgements to verify it doesn't drop subscribers before the limit // is hit. On the last iteration overflow the pending list, and check the node is unsubscribed for (i in 0..service.maxUnacknowledgedUpdates) { reg = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires) wireReg = reg.toWire(nodeKey.private) - service.notifySubscribers(wireReg) + service.notifySubscribers(wireReg, i + startingMapVersion + 2) if (i < service.maxUnacknowledgedUpdates) { - assertEquals(i + 1, service.getUnacknowledgedCount(registerNode.info.address)) + assertEquals(i + 1, service.getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) } else { - assertNull(service.getUnacknowledgedCount(registerNode.info.address)) + assertNull(service.getUnacknowledgedCount(registerNode.info.address, i + startingMapVersion + 2)) } } }