Speed-up NodeRegistrationTest (#2873)

* Improve logging for NetworkMap requests

* Allow interrupt in polling if the process started successfully

* Put `advertiseNewParameters` back

* Additional log line to indicate when all the nodes are started

* Improve logging and use concurrent map since it is updated from multiple threads

* Change NetworkMap response validity duration and rename parameter accordingly

* Changes following code review from @shamsasari
This commit is contained in:
Viktor Kolomeyko 2018-04-03 17:33:42 +01:00 committed by GitHub
parent 65ff214130
commit 1f5559e3c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 23 deletions

View File

@ -2,11 +2,9 @@ package net.corda.node.utilities.registration
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.transpose import net.corda.core.internal.concurrent.transpose
import net.corda.core.internal.logElapsedTime
import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlow
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.*
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.minutes
import net.corda.finance.DOLLARS import net.corda.finance.DOLLARS
import net.corda.finance.flows.CashIssueAndPaymentFlow import net.corda.finance.flows.CashIssueAndPaymentFlow
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
@ -49,6 +47,7 @@ class NodeRegistrationTest {
private val notaryName = CordaX500Name("NotaryService", "Zurich", "CH") private val notaryName = CordaX500Name("NotaryService", "Zurich", "CH")
private val aliceName = CordaX500Name("Alice", "London", "GB") private val aliceName = CordaX500Name("Alice", "London", "GB")
private val genevieveName = CordaX500Name("Genevieve", "London", "GB") private val genevieveName = CordaX500Name("Genevieve", "London", "GB")
private val log = contextLogger()
} }
@Rule @Rule
@ -63,7 +62,7 @@ class NodeRegistrationTest {
@Before @Before
fun startServer() { fun startServer() {
server = NetworkMapServer( server = NetworkMapServer(
cacheTimeout = 1.minutes, pollInterval = 1.seconds,
hostAndPort = portAllocation.nextHostAndPort(), hostAndPort = portAllocation.nextHostAndPort(),
myHostNameValue = "localhost", myHostNameValue = "localhost",
additionalServices = registrationHandler) additionalServices = registrationHandler)
@ -93,6 +92,9 @@ class NodeRegistrationTest {
startNode(providedName = genevieveName), startNode(providedName = genevieveName),
defaultNotaryNode defaultNotaryNode
).transpose().getOrThrow() ).transpose().getOrThrow()
log.info("Nodes started")
val (alice, genevieve) = nodes val (alice, genevieve) = nodes
assertThat(registrationHandler.idsPolled).containsOnly( assertThat(registrationHandler.idsPolled).containsOnly(
@ -119,25 +121,33 @@ class RegistrationHandler(private val rootCertAndKeyPair: CertificateAndKeyPair)
private val certPaths = HashMap<String, CertPath>() private val certPaths = HashMap<String, CertPath>()
val idsPolled = HashSet<String>() val idsPolled = HashSet<String>()
companion object {
val log = loggerFor<RegistrationHandler>()
}
@POST @POST
@Consumes(MediaType.APPLICATION_OCTET_STREAM) @Consumes(MediaType.APPLICATION_OCTET_STREAM)
@Produces(MediaType.TEXT_PLAIN) @Produces(MediaType.TEXT_PLAIN)
fun registration(input: InputStream): Response { fun registration(input: InputStream): Response {
val certificationRequest = input.use { JcaPKCS10CertificationRequest(it.readBytes()) } return log.logElapsedTime("Registration") {
val (certPath, name) = createSignedClientCertificate( val certificationRequest = input.use { JcaPKCS10CertificationRequest(it.readBytes()) }
certificationRequest, val (certPath, name) = createSignedClientCertificate(
rootCertAndKeyPair.keyPair, certificationRequest,
listOf(rootCertAndKeyPair.certificate)) rootCertAndKeyPair.keyPair,
require(!name.organisation.contains("\\s".toRegex())) { "Whitespace in the organisation name not supported" } listOf(rootCertAndKeyPair.certificate))
certPaths[name.organisation] = certPath require(!name.organisation.contains("\\s".toRegex())) { "Whitespace in the organisation name not supported" }
return Response.ok(name.organisation).build() certPaths[name.organisation] = certPath
Response.ok(name.organisation).build()
}
} }
@GET @GET
@Path("{id}") @Path("{id}")
fun reply(@PathParam("id") id: String): Response { fun reply(@PathParam("id") id: String): Response {
idsPolled += id return log.logElapsedTime("Reply by Id") {
return buildResponse(certPaths[id]!!.certificates) idsPolled += id
buildResponse(certPaths[id]!!.certificates)
}
} }
private fun buildResponse(certificates: List<Certificate>): Response { private fun buildResponse(certificates: List<Certificate>): Response {

View File

@ -73,6 +73,7 @@ import java.time.Instant
import java.time.ZoneOffset.UTC import java.time.ZoneOffset.UTC
import java.time.format.DateTimeFormatter import java.time.format.DateTimeFormatter
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@ -101,7 +102,7 @@ class DriverDSLImpl(
override val shutdownManager get() = _shutdownManager!! override val shutdownManager get() = _shutdownManager!!
private val cordappPackages = extraCordappPackagesToScan + getCallerPackage() private val cordappPackages = extraCordappPackagesToScan + getCallerPackage()
// Map from a nodes legal name to an observable emitting the number of nodes in its network map. // Map from a nodes legal name to an observable emitting the number of nodes in its network map.
private val countObservables = mutableMapOf<CordaX500Name, Observable<Int>>() private val countObservables = ConcurrentHashMap<CordaX500Name, Observable<Int>>()
private val nodeNames = mutableSetOf<CordaX500Name>() private val nodeNames = mutableSetOf<CordaX500Name>()
/** /**
* Future which completes when the network map is available, whether a local one or one from the CZ. This future acts * Future which completes when the network map is available, whether a local one or one from the CZ. This future acts
@ -575,15 +576,17 @@ class DriverDSLImpl(
} }
/** /**
* @nodeName the name of the node which performs counting
* @param initial number of nodes currently in the network map of a running node. * @param initial number of nodes currently in the network map of a running node.
* @param networkMapCacheChangeObservable an observable returning the updates to the node network map. * @param networkMapCacheChangeObservable an observable returning the updates to the node network map.
* @return a [ConnectableObservable] which emits a new [Int] every time the number of registered nodes changes * @return a [ConnectableObservable] which emits a new [Int] every time the number of registered nodes changes
* the initial value emitted is always [initial] * the initial value emitted is always [initial]
*/ */
private fun nodeCountObservable(initial: Int, networkMapCacheChangeObservable: Observable<NetworkMapCache.MapChange>): private fun nodeCountObservable(nodeName: CordaX500Name, initial: Int, networkMapCacheChangeObservable: Observable<NetworkMapCache.MapChange>):
ConnectableObservable<Int> { ConnectableObservable<Int> {
val count = AtomicInteger(initial) val count = AtomicInteger(initial)
return networkMapCacheChangeObservable.map { return networkMapCacheChangeObservable.map {
log.debug("nodeCountObservable for '$nodeName' received '$it'")
when (it) { when (it) {
is NetworkMapCache.MapChange.Added -> count.incrementAndGet() is NetworkMapCache.MapChange.Added -> count.incrementAndGet()
is NetworkMapCache.MapChange.Removed -> count.decrementAndGet() is NetworkMapCache.MapChange.Removed -> count.decrementAndGet()
@ -599,8 +602,9 @@ class DriverDSLImpl(
*/ */
private fun allNodesConnected(rpc: CordaRPCOps): CordaFuture<Int> { private fun allNodesConnected(rpc: CordaRPCOps): CordaFuture<Int> {
val (snapshot, updates) = rpc.networkMapFeed() val (snapshot, updates) = rpc.networkMapFeed()
val counterObservable = nodeCountObservable(snapshot.size, updates) val nodeName = rpc.nodeInfo().legalIdentities[0].name
countObservables[rpc.nodeInfo().legalIdentities[0].name] = counterObservable val counterObservable = nodeCountObservable(nodeName, snapshot.size, updates)
countObservables[nodeName] = counterObservable
/* TODO: this might not always be the exact number of nodes one has to wait for, /* TODO: this might not always be the exact number of nodes one has to wait for,
* for example in the following sequence * for example in the following sequence
* 1 start 3 nodes in order, A, B, C. * 1 start 3 nodes in order, A, B, C.
@ -611,6 +615,7 @@ class DriverDSLImpl(
// This is an observable which yield the minimum number of nodes in each node network map. // This is an observable which yield the minimum number of nodes in each node network map.
val smallestSeenNetworkMapSize = Observable.combineLatest(countObservables.values.toList()) { args: Array<Any> -> val smallestSeenNetworkMapSize = Observable.combineLatest(countObservables.values.toList()) { args: Array<Any> ->
log.debug("smallestSeenNetworkMapSize for '$nodeName' is: ${args.toList()}")
args.map { it as Int }.min() ?: 0 args.map { it as Int }.min() ?: 0
} }
val future = smallestSeenNetworkMapSize.filter { it >= requiredNodes }.toFuture() val future = smallestSeenNetworkMapSize.filter { it >= requiredNodes }.toFuture()
@ -701,7 +706,8 @@ class DriverDSLImpl(
if (it == processDeathFuture) { if (it == processDeathFuture) {
throw ListenProcessDeathException(config.corda.p2pAddress, process) throw ListenProcessDeathException(config.corda.p2pAddress, process)
} }
processDeathFuture.cancel(false) // Will interrupt polling for process death as this is no longer relevant since the process been successfully started and reflected itself in the NetworkMap.
processDeathFuture.cancel(true)
log.info("Node handle is ready. NodeInfo: ${rpc.nodeInfo()}, WebAddress: $webAddress") log.info("Node handle is ready. NodeInfo: ${rpc.nodeInfo()}, WebAddress: $webAddress")
OutOfProcessImpl(rpc.nodeInfo(), rpc, config.corda, webAddress, useHTTPS, debugPort, process, onNodeExit) OutOfProcessImpl(rpc.nodeInfo(), rpc, config.corda, webAddress, useHTTPS, debugPort, process, onNodeExit)
} }

View File

@ -32,7 +32,7 @@ import javax.ws.rs.core.Response
import javax.ws.rs.core.Response.ok import javax.ws.rs.core.Response.ok
import javax.ws.rs.core.Response.status import javax.ws.rs.core.Response.status
class NetworkMapServer(private val cacheTimeout: Duration, class NetworkMapServer(private val pollInterval: Duration,
hostAndPort: NetworkHostAndPort, hostAndPort: NetworkHostAndPort,
private val networkMapCertAndKeyPair: CertificateAndKeyPair = createDevNetworkMapCa(), private val networkMapCertAndKeyPair: CertificateAndKeyPair = createDevNetworkMapCa(),
private val myHostNameValue: String = "test.host.name", private val myHostNameValue: String = "test.host.name",
@ -137,7 +137,7 @@ class NetworkMapServer(private val cacheTimeout: Duration,
fun getNetworkMap(): Response { fun getNetworkMap(): Response {
val networkMap = NetworkMap(nodeInfoMap.keys.toList(), signedNetParams.raw.hash, parametersUpdate) val networkMap = NetworkMap(nodeInfoMap.keys.toList(), signedNetParams.raw.hash, parametersUpdate)
val signedNetworkMap = networkMapCertAndKeyPair.sign(networkMap) val signedNetworkMap = networkMapCertAndKeyPair.sign(networkMap)
return Response.ok(signedNetworkMap.serialize().bytes).header("Cache-Control", "max-age=${cacheTimeout.seconds}").build() return Response.ok(signedNetworkMap.serialize().bytes).header("Cache-Control", "max-age=${pollInterval.seconds}").build()
} }
// Remove nodeInfo for testing. // Remove nodeInfo for testing.
@ -177,4 +177,4 @@ class NetworkMapServer(private val cacheTimeout: Duration,
@Path("my-hostname") @Path("my-hostname")
fun getHostName(): Response = Response.ok(myHostNameValue).build() fun getHostName(): Response = Response.ok(myHostNameValue).build()
} }
} }