RPC: call close() on startup failure, add thread leak tests

Andras Slemmer 2017-05-10 15:30:36 +01:00
parent 5f5f51bf51
commit 7c3a566197
12 changed files with 242 additions and 176 deletions
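
The pattern the commit title refers to, applied to both RPCClient and RPCServer below: wrap the body of start() in try/catch and, on any failure, call the component's own close() before rethrowing, so threads and Artemis sessions created by earlier startup steps are reclaimed instead of leaking. A minimal sketch of the idea in isolation (this Service class is illustrative, not code from the commit):

    import java.util.concurrent.Executors
    import java.util.concurrent.ScheduledExecutorService

    // Illustrative stand-in for a component like RPCServer: start() rolls back
    // whatever it already allocated if a later startup step throws.
    class Service {
        private var executor: ScheduledExecutorService? = null

        fun start(failLateInStartup: Boolean) {
            try {
                executor = Executors.newScheduledThreadPool(1) // resource allocated
                check(!failLateInStartup) { "a later startup step failed" }
            } catch (exception: Throwable) {
                close()      // reclaim partial state, then surface the original error
                throw exception
            }
        }

        // Null-safe calls make close() valid at any point of a partial startup.
        fun close() {
            executor?.shutdownNow()
        }
    }

    fun main() {
        val service = Service()
        try {
            service.start(failLateInStartup = true)
        } catch (e: IllegalStateException) {
            println("startup failed but the pool thread was shut down: ${e.message}")
        }
    }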

RPCStabilityTests.kt

@@ -5,28 +5,100 @@ import com.esotericsoftware.kryo.Serializer
 import com.esotericsoftware.kryo.io.Input
 import com.esotericsoftware.kryo.io.Output
 import com.esotericsoftware.kryo.pool.KryoPool
+import com.google.common.net.HostAndPort
 import com.google.common.util.concurrent.Futures
+import net.corda.client.rpc.internal.RPCClientConfiguration
+import net.corda.core.ErrorOr
 import net.corda.core.getOrThrow
 import net.corda.core.messaging.RPCOps
 import net.corda.core.millis
 import net.corda.core.random63BitValue
+import net.corda.node.driver.poll
 import net.corda.node.services.messaging.RPCServerConfiguration
 import net.corda.nodeapi.RPCApi
 import net.corda.nodeapi.RPCKryo
 import net.corda.testing.*
 import org.apache.activemq.artemis.api.core.SimpleString
+import org.junit.Assert.assertEquals
 import org.junit.Test
 import rx.Observable
 import rx.subjects.PublishSubject
 import rx.subjects.UnicastSubject
 import java.time.Duration
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.Executors
+import java.util.concurrent.ScheduledExecutorService
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
-import kotlin.test.assertEquals
 class RPCStabilityTests {
+    object DummyOps : RPCOps {
+        override val protocolVersion = 0
+    }
+
+    private fun waitUntilNumberOfThreadsStable(executorService: ScheduledExecutorService): Int {
+        val values = ConcurrentLinkedQueue<Int>()
+        return poll(executorService, "number of threads to become stable", 250.millis) {
+            values.add(Thread.activeCount())
+            if (values.size > 5) {
+                values.poll()
+            }
+            val first = values.peek()
+            if (values.size == 5 && values.all { it == first }) {
+                first
+            } else {
+                null
+            }
+        }.get()
+    }
+
+    @Test
+    fun `client and server dont leak threads`() {
+        val executor = Executors.newScheduledThreadPool(1)
+        fun startAndStop() {
+            rpcDriver {
+                val server = startRpcServer<RPCOps>(ops = DummyOps)
+                startRpcClient<RPCOps>(server.get().hostAndPort).get()
+            }
+        }
+        for (i in 1 .. 5) {
+            startAndStop()
+        }
+        val numberOfThreadsBefore = waitUntilNumberOfThreadsStable(executor)
+        for (i in 1 .. 5) {
+            startAndStop()
+        }
+        val numberOfThreadsAfter = waitUntilNumberOfThreadsStable(executor)
+        // This is a less-than-or-equal check because threads from other tests may be shutting down while this test is running.
+        // This is therefore a "best effort" check. When this test is run on its own this should be a strict equality.
+        require(numberOfThreadsBefore >= numberOfThreadsAfter)
+        executor.shutdownNow()
+    }
+
+    @Test
+    fun `client doesnt leak threads when it fails to start`() {
+        val executor = Executors.newScheduledThreadPool(1)
+        fun startAndStop() {
+            rpcDriver {
+                ErrorOr.catch { startRpcClient<RPCOps>(HostAndPort.fromString("localhost:9999")).get() }
+                val server = startRpcServer<RPCOps>(ops = DummyOps)
+                ErrorOr.catch { startRpcClient<RPCOps>(server.get().hostAndPort, configuration = RPCClientConfiguration.default.copy(minimumServerProtocolVersion = 1)).get() }
+            }
+        }
+        for (i in 1 .. 5) {
+            startAndStop()
+        }
+        val numberOfThreadsBefore = waitUntilNumberOfThreadsStable(executor)
+        for (i in 1 .. 5) {
+            startAndStop()
+        }
+        val numberOfThreadsAfter = waitUntilNumberOfThreadsStable(executor)
+        require(numberOfThreadsBefore >= numberOfThreadsAfter)
+        executor.shutdownNow()
+    }
+
     interface LeakObservableOps: RPCOps {
         fun leakObservable(): Observable<Nothing>
     }
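
waitUntilNumberOfThreadsStable treats the thread count as settled once five consecutive 250 ms samples of Thread.activeCount() agree, which filters out threads still winding down from earlier activity. The same idea as a standalone blocking loop, without the driver's poll() utility (the function name and defaults here are illustrative):

    import java.util.concurrent.TimeUnit

    // Sample Thread.activeCount() until `samplesRequired` consecutive readings
    // agree, then return that stable count.
    fun waitForStableThreadCount(samplesRequired: Int = 5, intervalMillis: Long = 250): Int {
        val samples = ArrayDeque<Int>()
        while (true) {
            samples.addLast(Thread.activeCount())
            if (samples.size > samplesRequired) samples.removeFirst()
            val first = samples.first()
            if (samples.size == samplesRequired && samples.all { it == first }) return first
            TimeUnit.MILLISECONDS.sleep(intervalMillis)
        }
    }

    fun main() {
        val before = waitForStableThreadCount()
        // ... repeatedly start and stop the component under test here ...
        val after = waitForStableThreadCount()
        // Best effort, like the tests above: unrelated threads may exit concurrently.
        require(after <= before) { "possible thread leak: $before -> $after threads" }
    }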

RPCClient.kt

@@ -147,26 +147,32 @@ class RPCClient<I : RPCOps>(
         }
         val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass)
-        proxyHandler.start()
-        @Suppress("UNCHECKED_CAST")
-        val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
-        val serverProtocolVersion = ops.protocolVersion
-        if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
-            throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
-                    " than the server's supported protocol version ($serverProtocolVersion)")
-        }
-        proxyHandler.setServerProtocolVersion(serverProtocolVersion)
-        log.debug("RPC connected, returning proxy")
-        object : RPCConnection<I> {
-            override val proxy = ops
-            override val serverProtocolVersion = serverProtocolVersion
-            override fun close() {
-                proxyHandler.close()
-                serverLocator.close()
-            }
-        }
+        try {
+            proxyHandler.start()
+            @Suppress("UNCHECKED_CAST")
+            val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
+            val serverProtocolVersion = ops.protocolVersion
+            if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
+                throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
+                        " than the server's supported protocol version ($serverProtocolVersion)")
+            }
+            proxyHandler.setServerProtocolVersion(serverProtocolVersion)
+            log.debug("RPC connected, returning proxy")
+            object : RPCConnection<I> {
+                override val proxy = ops
+                override val serverProtocolVersion = serverProtocolVersion
+                override fun close() {
+                    proxyHandler.close()
+                    serverLocator.close()
+                }
+            }
+        } catch (exception: Throwable) {
+            proxyHandler.close()
+            serverLocator.close()
+            throw exception
+        }
     }
 }

RPCClientProxyHandler.kt

@@ -25,16 +25,11 @@ import org.apache.activemq.artemis.api.core.client.ServerLocator
 import rx.Notification
 import rx.Observable
 import rx.subjects.UnicastSubject
-import sun.reflect.CallerSensitive
 import java.lang.reflect.InvocationHandler
 import java.lang.reflect.Method
 import java.util.*
-import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.Executors
-import java.util.concurrent.ScheduledFuture
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.*
 import java.util.concurrent.atomic.AtomicInteger
-import kotlin.collections.ArrayList
 import kotlin.reflect.jvm.javaMethod

 /**
@@ -87,10 +82,7 @@
     }

     // Used for reaping
-    private val reaperExecutor = Executors.newScheduledThreadPool(
-            1,
-            ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
-    )
+    private var reaperExecutor: ScheduledExecutorService? = null

     // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
     private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build()
@@ -109,7 +101,7 @@
             hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
     )
     // Holds a reference to the scheduled reaper.
-    private lateinit var reaperScheduledFuture: ScheduledFuture<*>
+    private var reaperScheduledFuture: ScheduledFuture<*>? = null
     // The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
     private var serverProtocolVersion: Int? = null
@@ -145,7 +137,7 @@
     // TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a
     // single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is
     // integrated properly.
-    private lateinit var sessionAndConsumer: ArtemisConsumer
+    private var sessionAndConsumer: ArtemisConsumer? = null
     // Pool producers to reduce contention on the client side.
     private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) {
         // Note how we create new sessions *and* session factories per producer.
@@ -162,7 +154,12 @@
     /**
      * Start the client. This creates the per-client queue, starts the consumer session and the reaper.
      */
     fun start() {
-        reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
+        lifeCycle.requireState(State.UNSTARTED)
+        reaperExecutor = Executors.newScheduledThreadPool(
+                1,
+                ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
+        )
+        reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate(
                 this::reapObservables,
                 rpcConfiguration.reapInterval.toMillis(),
                 rpcConfiguration.reapInterval.toMillis(),
@@ -187,7 +184,7 @@
         if (method == toStringMethod) {
             return "Client RPC proxy for $rpcOpsClass"
         }
-        if (sessionAndConsumer.session.isClosed) {
+        if (sessionAndConsumer!!.session.isClosed) {
             throw RPCException("RPC Proxy is closed")
         }
         val rpcId = RPCApi.RpcRequestId(random63BitValue())
@@ -268,13 +265,13 @@
     /**
      * Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors.
      */
     fun close() {
-        sessionAndConsumer.consumer.close()
-        sessionAndConsumer.session.close()
-        sessionAndConsumer.sessionFactory.close()
-        reaperScheduledFuture.cancel(false)
+        sessionAndConsumer?.consumer?.close()
+        sessionAndConsumer?.session?.close()
+        sessionAndConsumer?.sessionFactory?.close()
+        reaperScheduledFuture?.cancel(false)
         observableContext.observableMap.invalidateAll()
         reapObservables()
-        reaperExecutor.shutdownNow()
+        reaperExecutor?.shutdownNow()
         sessionAndProducerPool.close().forEach {
             it.producer.close()
             it.session.close()
@@ -284,8 +281,7 @@
         // leak borrowed executors.
         val observationExecutors = observationExecutorPool.close()
         observationExecutors.forEach { it.shutdownNow() }
-        observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) }
-        lifeCycle.transition(State.STARTED, State.FINISHED)
+        lifeCycle.transition(State.FINISHED)
     }

     /**
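
The shape of the change above recurs in RPCServer below: each lateinit var becomes a nullable var, happy-path accesses use !!, and close() uses ?. throughout. That is what lets close() double as the rollback handler, since it may now run when start() never executed or stopped halfway. A contrast between the two styles, with hypothetical names:

    // With lateinit, cleanup must guard each field explicitly; a plain access
    // before start() throws UninitializedPropertyAccessException.
    class LateinitVersion {
        private lateinit var resource: AutoCloseable
        fun close() {
            if (::resource.isInitialized) resource.close()
        }
    }

    // With a nullable var, `?.` silently skips cleanup steps that never ran.
    class NullableVersion {
        private var resource: AutoCloseable? = null
        fun close() {
            resource?.close()
        }
    }

    fun main() {
        LateinitVersion().close() // safe only because of the explicit guard
        NullableVersion().close() // safe by construction
    }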

LifeCycle.kt

@@ -13,7 +13,7 @@ class LifeCycle<S : Enum<S>>(initial: S) {
     private val lock = ReentrantReadWriteLock()
     private var state = initial

-    /** Assert that the lifecycle in the [requiredState] */
+    /** Assert that the lifecycle is in the [requiredState]. */
     fun requireState(requiredState: S) {
         requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState }
     }
@@ -28,11 +28,18 @@
         }
     }

-    /** Transition the state from [from] to [to] */
+    /** Transition the state from [from] to [to]. */
     fun transition(from: S, to: S) {
         lock.writeLock().withLock {
             require(state == from) { "Required state to be $from to transition to $to, was $state" }
             state = to
         }
     }
+
+    /** Transition the state to [to] without performing a current state check. */
+    fun transition(to: S) {
+        lock.writeLock().withLock {
+            state = to
+        }
+    }
 }
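
The new single-argument overload exists precisely for the close() paths in this commit: close() may run as a rollback from any state, including UNSTARTED when start() fails early, where the old transition(State.STARTED, State.FINISHED) would itself throw. A usage sketch, assuming LifeCycle is imported and a three-state enum like the RPC classes use:

    enum class State { UNSTARTED, STARTED, FINISHED }

    fun main() {
        val lifeCycle = LifeCycle(State.UNSTARTED)
        lifeCycle.requireState(State.UNSTARTED)               // assert the current state
        lifeCycle.transition(State.UNSTARTED, State.STARTED)  // checked transition

        // Unchecked transition: equally legal straight from UNSTARTED, which is
        // what close()-as-rollback needs when start() failed before STARTED.
        lifeCycle.transition(State.FINISHED)
    }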

NodeStartupPerformanceTests.kt

@@ -1,10 +1,11 @@
 package net.corda.node

 import com.google.common.base.Stopwatch
-import net.corda.node.driver.FalseNetworkMap
+import net.corda.node.driver.NetworkMapStartStrategy
 import net.corda.node.driver.driver
 import org.junit.Ignore
 import org.junit.Test
+import java.util.*
 import java.util.concurrent.TimeUnit

 @Ignore("Only use locally")
@@ -13,8 +14,8 @@ class NodeStartupPerformanceTests {
     // Measure the startup time of nodes. Note that this includes an RPC roundtrip, which causes e.g. Kryo initialisation.
     @Test
     fun `single node startup time`() {
-        driver(networkMapStrategy = FalseNetworkMap) {
-            startNetworkMapService().get()
+        driver(networkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false)) {
+            startDedicatedNetworkMapService().get()
             val times = ArrayList<Long>()
             for (i in 1 .. 10) {
                 val time = Stopwatch.createStarted().apply {

Driver.kt

@@ -104,7 +104,7 @@ interface DriverDSLExposedInterface {
      * Starts a network map service node. Note that only a single one should ever be running, so you will probably want
      * to set networkMapStrategy to FalseNetworkMap in your [driver] call.
      */
-    fun startNetworkMapService(): ListenableFuture<Unit>
+    fun startDedicatedNetworkMapService(): ListenableFuture<Unit>

     fun waitForAllNodesToFinish()
@@ -168,6 +168,11 @@
     }
 }

+sealed class NetworkMapStartStrategy {
+    data class Dedicated(val startAutomatically: Boolean) : NetworkMapStartStrategy()
+    data class Nominated(val legalName: X500Name, val address: HostAndPort) : NetworkMapStartStrategy()
+}
+
 /**
  * [driver] allows one to start up nodes like this:
  *     driver {
@@ -201,7 +206,7 @@
         debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
         systemProperties: Map<String, String> = emptyMap(),
         useTestClock: Boolean = false,
-        networkMapStrategy: NetworkMapStrategy = DedicatedNetworkMap,
+        networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true),
         dsl: DriverDSLExposedInterface.() -> A
 ) = genericDriver(
         driverDsl = DriverDSL(
@@ -210,7 +215,7 @@
                 systemProperties = systemProperties,
                 driverDirectory = driverDirectory.toAbsolutePath(),
                 useTestClock = useTestClock,
-                networkMapStrategy = networkMapStrategy,
+                networkMapStartStrategy = networkMapStartStrategy,
                 isDebug = isDebug
         ),
         coerce = { it },
@@ -412,13 +417,14 @@ class DriverDSL(
         val driverDirectory: Path,
         val useTestClock: Boolean,
         val isDebug: Boolean,
-        val networkMapStrategy: NetworkMapStrategy
+        val networkMapStartStrategy: NetworkMapStartStrategy
 ) : DriverDSLInternalInterface {
     private val dedicatedNetworkMapAddress = portAllocation.nextHostAndPort()
-    val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator(
-            Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build())
-    )
-    override val shutdownManager = ShutdownManager(executorService)
+    private val dedicatedNetworkMapLegalName = DUMMY_MAP.name
+    var _executorService: ListeningScheduledExecutorService? = null
+    val executorService get() = _executorService!!
+    var _shutdownManager: ShutdownManager? = null
+    override val shutdownManager get() = _shutdownManager!!

     class State {
         val processes = ArrayList<ListenableFuture<Process>>()
@@ -449,8 +455,11 @@
     }

     override fun shutdown() {
-        shutdownManager.shutdown()
-        executorService.shutdown()
+        _shutdownManager?.shutdown()
+        _executorService?.apply {
+            shutdownNow()
+            require(awaitTermination(1, TimeUnit.SECONDS))
+        }
     }

     private fun establishRpc(nodeAddress: HostAndPort, sslConfig: SSLConfiguration): ListenableFuture<CordaRPCOps> {
@@ -467,6 +476,12 @@
         }
     }

+    // TODO move to companion
+    private fun toServiceConfig(address: HostAndPort, legalName: X500Name) = mapOf(
+            "address" to address.toString(),
+            "legalName" to legalName.toString()
+    )
+
     override fun startNode(
             providedName: X500Name?,
             advertisedServices: Set<ServiceInfo>,
@@ -487,7 +502,17 @@
                 "rpcAddress" to rpcAddress.toString(),
                 "webAddress" to webAddress.toString(),
                 "extraAdvertisedServiceIds" to advertisedServices.map { it.toString() },
-                "networkMapService" to networkMapStrategy.serviceConfig(dedicatedNetworkMapAddress, name, p2pAddress),
+                "networkMapService" to when (networkMapStartStrategy) {
+                    is NetworkMapStartStrategy.Dedicated -> toServiceConfig(dedicatedNetworkMapAddress, dedicatedNetworkMapLegalName)
+                    is NetworkMapStartStrategy.Nominated -> networkMapStartStrategy.run {
+                        if (name != legalName) {
+                            toServiceConfig(address, legalName)
+                        } else {
+                            p2pAddress == address || throw IllegalArgumentException("Passed-in address $address of nominated network map $legalName is wrong, it should be: $p2pAddress")
+                            null
+                        }
+                    }
+                },
                 "useTestClock" to useTestClock,
                 "rpcUsers" to rpcUsers.map {
                     mapOf(
@@ -574,21 +599,24 @@
     }

     override fun start() {
-        if (networkMapStrategy.startDedicated) {
-            startNetworkMapService()
+        _executorService = MoreExecutors.listeningDecorator(
+                Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build())
+        )
+        _shutdownManager = ShutdownManager(executorService)
+        if (networkMapStartStrategy is NetworkMapStartStrategy.Dedicated && networkMapStartStrategy.startAutomatically) {
+            startDedicatedNetworkMapService()
         }
     }

-    override fun startNetworkMapService(): ListenableFuture<Unit> {
+    override fun startDedicatedNetworkMapService(): ListenableFuture<Unit> {
         val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
         val apiAddress = portAllocation.nextHostAndPort().toString()
-        val networkMapLegalName = networkMapStrategy.legalName
-        val baseDirectory = driverDirectory / networkMapLegalName.commonName
+        val baseDirectory = driverDirectory / dedicatedNetworkMapLegalName.commonName
         val config = ConfigHelper.loadConfig(
                 baseDirectory = baseDirectory,
                 allowMissingConfig = true,
                 configOverrides = mapOf(
-                        "myLegalName" to networkMapLegalName.toString(),
+                        "myLegalName" to dedicatedNetworkMapLegalName.toString(),
                         // TODO: remove the webAddress as NMS doesn't need to run a web server. This will cause all
                         // node port numbers to be shifted, so all demos and docs need to be updated accordingly.
                         "webAddress" to apiAddress,
@@ -684,3 +712,4 @@ fun writeConfig(path: Path, filename: String, config: Config) {
     path.toFile().mkdirs()
     File("$path/$filename").writeText(config.root().render(ConfigRenderOptions.defaults()))
 }
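
NetworkMapStartStrategy replaces the NetworkMapStrategy class hierarchy (deleted below) with plain data that DriverDSL interprets itself in startNode() and start(). Typical calls against the new signature look like this (the node names and port are illustrative):

    // Default: a dedicated network map node is started automatically.
    driver {
        startNode(ALICE.name)
    }

    // Dedicated map node, but started explicitly by the test when needed.
    driver(networkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false)) {
        startDedicatedNetworkMapService().get()
        startNode(ALICE.name)
    }

    // Nominate an ordinary node to double as the network map; the address must
    // match the P2P address the driver will assign to that node.
    driver(networkMapStartStrategy = NetworkMapStartStrategy.Nominated(DUMMY_NOTARY.name, HostAndPort.fromParts("localhost", 10002))) {
        startNode(DUMMY_NOTARY.name)
    }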

NetworkMapStrategy.kt (deleted)

@@ -1,47 +0,0 @@
-package net.corda.node.driver
-
-import com.google.common.net.HostAndPort
-import net.corda.core.utilities.DUMMY_MAP
-import org.bouncycastle.asn1.x500.X500Name
-
-/**
- * Instruct the driver how to set up the network map, if at all.
- * @see FalseNetworkMap
- * @see DedicatedNetworkMap
- * @see NominatedNetworkMap
- */
-abstract class NetworkMapStrategy(internal val startDedicated: Boolean, internal val legalName: X500Name) {
-    internal abstract fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort): Map<String, String>?
-}
-
-private fun toServiceConfig(address: HostAndPort, legalName: X500Name) = mapOf(
-        "address" to address.toString(),
-        "legalName" to legalName.toString()
-)
-
-abstract class AbstractDedicatedNetworkMap(start: Boolean) : NetworkMapStrategy(start, DUMMY_MAP.name) {
-    override fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort) = toServiceConfig(dedicatedAddress, legalName)
-}
-
-/**
- * Do not start a network map.
- */
-object FalseNetworkMap : AbstractDedicatedNetworkMap(false)
-
-/**
- * Start a dedicated node to host the network map.
- */
-object DedicatedNetworkMap : AbstractDedicatedNetworkMap(true)
-
-/**
- * As in gradle-based demos, nominate a node to host the network map, so that there is one fewer node in total than in the [DedicatedNetworkMap] case.
- * Will fail if the port you pass in does not match the P2P port the driver assigns to the named node.
- */
-class NominatedNetworkMap(legalName: X500Name, private val address: HostAndPort) : NetworkMapStrategy(false, legalName) {
-    override fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort) = if (nodeName != legalName) {
-        toServiceConfig(address, legalName)
-    } else {
-        p2pAddress == address || throw IllegalArgumentException("Passed-in address $address of nominated network map $legalName is wrong, it should be: $p2pAddress")
-        null
-    }
-}

RPCServer.kt

@@ -13,7 +13,6 @@ import com.google.common.collect.Multimaps
 import com.google.common.collect.SetMultimap
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import net.corda.core.ErrorOr
-import net.corda.core.crypto.commonName
 import net.corda.core.messaging.RPCOps
 import net.corda.core.random63BitValue
 import net.corda.core.seconds
@@ -42,10 +41,8 @@ import rx.Subscriber
 import rx.Subscription
 import java.lang.reflect.InvocationTargetException
 import java.time.Duration
-import java.util.concurrent.ExecutorService
-import java.util.concurrent.Executors
-import java.util.concurrent.ScheduledFuture
-import java.util.concurrent.TimeUnit
+import java.util.*
+import java.util.concurrent.*

 data class RPCServerConfiguration(
         /** The number of threads to use for handling RPC requests */
@@ -101,22 +98,11 @@ class RPCServer(
     // A mapping from client addresses to IDs of associated Observables
     private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create<SimpleString, RPCApi.ObservableId>())
     // The scheduled reaper handle.
-    private lateinit var reaperScheduledFuture: ScheduledFuture<*>
-
-    private val observationSendExecutor = Executors.newFixedThreadPool(
-            1,
-            ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build()
-    )
-    private val rpcExecutor = Executors.newScheduledThreadPool(
-            rpcConfiguration.rpcThreadPoolSize,
-            ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build()
-    )
-    private val reaperExecutor = Executors.newScheduledThreadPool(
-            1,
-            ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build()
-    )
+    private var reaperScheduledFuture: ScheduledFuture<*>? = null
+
+    private var observationSendExecutor: ExecutorService? = null
+    private var rpcExecutor: ScheduledExecutorService? = null
+    private var reaperExecutor: ScheduledExecutorService? = null

     private val sessionAndConsumers = ArrayList<ArtemisConsumer>(rpcConfiguration.consumerPoolSize)
     private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) {
@@ -125,8 +111,8 @@
         session.start()
         ArtemisProducer(sessionFactory, session, session.createProducer())
     }
-    private lateinit var clientBindingRemovalConsumer: ClientConsumer
-    private lateinit var serverControl: ActiveMQServerControl
+    private var clientBindingRemovalConsumer: ClientConsumer? = null
+    private var serverControl: ActiveMQServerControl? = null

     private fun createObservableSubscriptionMap(): ObservableSubscriptionMap {
         val onObservableRemove = RemovalListener<RPCApi.ObservableId, ObservableSubscription> {
@@ -137,39 +123,55 @@
     }

     fun start(activeMqServerControl: ActiveMQServerControl) {
-        log.info("Starting RPC server with configuration $rpcConfiguration")
-        reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
-                this::reapSubscriptions,
-                rpcConfiguration.reapInterval.toMillis(),
-                rpcConfiguration.reapInterval.toMillis(),
-                TimeUnit.MILLISECONDS
-        )
-        val sessions = ArrayList<ClientSession>()
-        for (i in 1 .. rpcConfiguration.consumerPoolSize) {
-            val sessionFactory = serverLocator.createSessionFactory()
-            val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
-            val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME)
-            consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler)
-            sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer))
-            sessions.add(session)
-        }
-        clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS)
-        clientBindingRemovalConsumer.setMessageHandler(this::bindingRemovalArtemisMessageHandler)
-        serverControl = activeMqServerControl
-        lifeCycle.transition(State.UNSTARTED, State.STARTED)
-        // We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be
-        // fully initialised.
-        sessions.forEach {
-            it.start()
-        }
+        try {
+            lifeCycle.requireState(State.UNSTARTED)
+            log.info("Starting RPC server with configuration $rpcConfiguration")
+            observationSendExecutor = Executors.newFixedThreadPool(
+                    1,
+                    ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build()
+            )
+            rpcExecutor = Executors.newScheduledThreadPool(
+                    rpcConfiguration.rpcThreadPoolSize,
+                    ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build()
+            )
+            reaperExecutor = Executors.newScheduledThreadPool(
+                    1,
+                    ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build()
+            )
+            reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate(
+                    this::reapSubscriptions,
+                    rpcConfiguration.reapInterval.toMillis(),
+                    rpcConfiguration.reapInterval.toMillis(),
+                    TimeUnit.MILLISECONDS
+            )
+            val sessions = ArrayList<ClientSession>()
+            for (i in 1 .. rpcConfiguration.consumerPoolSize) {
+                val sessionFactory = serverLocator.createSessionFactory()
+                val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
+                val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME)
+                consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler)
+                sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer))
+                sessions.add(session)
+            }
+            clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS)
+            clientBindingRemovalConsumer!!.setMessageHandler(this::bindingRemovalArtemisMessageHandler)
+            serverControl = activeMqServerControl
+            lifeCycle.transition(State.UNSTARTED, State.STARTED)
+            // We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be
+            // fully initialised.
+            sessions.forEach {
+                it.start()
+            }
+        } catch (exception: Throwable) {
+            close()
+            throw exception
+        }
     }

     fun close() {
-        reaperScheduledFuture.cancel(false)
-        rpcExecutor.shutdownNow()
-        reaperExecutor.shutdownNow()
-        rpcExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
-        reaperExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
+        reaperScheduledFuture?.cancel(false)
+        rpcExecutor?.shutdownNow()
+        reaperExecutor?.shutdownNow()
         sessionAndConsumers.forEach {
             it.consumer.close()
             it.session.close()
@@ -182,7 +184,7 @@
             it.session.close()
             it.sessionFactory.close()
         }
-        lifeCycle.transition(State.STARTED, State.FINISHED)
+        lifeCycle.transition(State.FINISHED)
     }

     private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) {
@@ -211,7 +213,7 @@
         val rpcContext = RpcContext(
                 currentUser = getUser(artemisMessage)
         )
-        rpcExecutor.submit {
+        rpcExecutor!!.submit {
             val result = ErrorOr.catch {
                 try {
                     CURRENT_RPC_CONTEXT.set(rpcContext)
@@ -239,9 +241,9 @@
                 observableMap,
                 clientAddressToObservables,
                 clientToServer.clientAddress,
-                serverControl,
+                serverControl!!,
                 sessionAndProducerPool,
-                observationSendExecutor,
+                observationSendExecutor!!,
                 kryoPool
         )
         observableContext.sendMessage(reply)

Main.kt (notary demo)

@@ -7,7 +7,7 @@ import net.corda.core.utilities.ALICE
 import net.corda.core.utilities.BOB
 import net.corda.core.utilities.DUMMY_NOTARY
 import net.corda.flows.NotaryFlow
-import net.corda.node.driver.NominatedNetworkMap
+import net.corda.node.driver.NetworkMapStartStrategy
 import net.corda.node.driver.PortAllocation
 import net.corda.node.driver.driver
 import net.corda.node.services.startFlowPermission
@@ -20,8 +20,8 @@ import java.nio.file.Paths
 /** Creates and starts all nodes required for the demo. */
 fun main(args: Array<String>) {
     val demoUser = listOf(User("demo", "demo", setOf(startFlowPermission<DummyIssueAndMove>(), startFlowPermission<NotaryFlow.Client>())))
-    val networkMap = NominatedNetworkMap(DUMMY_NOTARY.name.appendToCommonName("1"), HostAndPort.fromParts("localhost", 10009))
-    driver(isDebug = true, driverDirectory = Paths.get("build") / "notary-demo-nodes", networkMapStrategy = networkMap, portAllocation = PortAllocation.Incremental(10001)) {
+    val networkMap = NetworkMapStartStrategy.Nominated(DUMMY_NOTARY.name.appendToCommonName("1"), HostAndPort.fromParts("localhost", 10009))
+    driver(isDebug = true, driverDirectory = Paths.get("build") / "notary-demo-nodes", networkMapStartStrategy = networkMap, portAllocation = PortAllocation.Incremental(10001)) {
         startNode(ALICE.name, rpcUsers = demoUser)
         startNode(BOB.name)
         startNotaryCluster(X500Name("CN=Raft,O=R3,OU=corda,L=Zurich,C=CH"), clusterSize = 3, type = RaftValidatingNotaryService.type)

RPCDriver.kt

@@ -194,7 +194,7 @@ fun <A> rpcDriver(
         debugPortAllocation: PortAllocation = globalDebugPortAllocation,
         systemProperties: Map<String, String> = emptyMap(),
         useTestClock: Boolean = false,
-        networkMapStrategy: NetworkMapStrategy = FalseNetworkMap,
+        networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false),
         dsl: RPCDriverExposedDSLInterface.() -> A
 ) = genericDriver(
         driverDsl = RPCDriverDSL(
@@ -204,7 +204,7 @@
                 systemProperties = systemProperties,
                 driverDirectory = driverDirectory.toAbsolutePath(),
                 useTestClock = useTestClock,
-                networkMapStrategy = networkMapStrategy,
+                networkMapStartStrategy = networkMapStartStrategy,
                 isDebug = isDebug
         )
 ),

VerifierDriver.kt

@@ -78,7 +78,7 @@ fun <A> verifierDriver(
         debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
         systemProperties: Map<String, String> = emptyMap(),
         useTestClock: Boolean = false,
-        networkMapStrategy: NetworkMapStrategy = DedicatedNetworkMap,
+        networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false),
         dsl: VerifierExposedDSLInterface.() -> A
 ) = genericDriver(
         driverDsl = VerifierDriverDSL(
@@ -88,7 +88,7 @@
                 systemProperties = systemProperties,
                 driverDirectory = driverDirectory.toAbsolutePath(),
                 useTestClock = useTestClock,
-                networkMapStrategy = networkMapStrategy,
+                networkMapStartStrategy = networkMapStartStrategy,
                 isDebug = isDebug
         )
 ),

VerifierTests.kt

@@ -13,7 +13,7 @@ import net.corda.core.utilities.ALICE
 import net.corda.core.utilities.DUMMY_NOTARY
 import net.corda.flows.CashIssueFlow
 import net.corda.flows.CashPaymentFlow
-import net.corda.node.driver.FalseNetworkMap
+import net.corda.node.driver.NetworkMapStartStrategy
 import net.corda.node.services.config.VerifierType
 import net.corda.node.services.transactions.ValidatingNotaryService
 import org.junit.Test
@@ -35,7 +35,7 @@ class VerifierTests {
     @Test
     fun `single verifier works with requestor`() {
-        verifierDriver(networkMapStrategy = FalseNetworkMap) {
+        verifierDriver {
             val aliceFuture = startVerificationRequestor(ALICE.name)
             val transactions = generateTransactions(100)
             val alice = aliceFuture.get()
@@ -52,7 +52,7 @@
     @Test
     fun `multiple verifiers work with requestor`() {
-        verifierDriver(networkMapStrategy = FalseNetworkMap) {
+        verifierDriver {
             val aliceFuture = startVerificationRequestor(ALICE.name)
             val transactions = generateTransactions(100)
             val alice = aliceFuture.get()
@@ -72,7 +72,7 @@
     @Test
     fun `verification redistributes on verifier death`() {
-        verifierDriver(networkMapStrategy = FalseNetworkMap) {
+        verifierDriver {
             val aliceFuture = startVerificationRequestor(ALICE.name)
             val numberOfTransactions = 100
             val transactions = generateTransactions(numberOfTransactions)
@@ -100,7 +100,7 @@
     @Test
     fun `verification request waits until verifier comes online`() {
-        verifierDriver(networkMapStrategy = FalseNetworkMap) {
+        verifierDriver {
             val aliceFuture = startVerificationRequestor(ALICE.name)
             val transactions = generateTransactions(100)
             val alice = aliceFuture.get()
@@ -112,7 +112,7 @@
     @Test
     fun `single verifier works with a node`() {
-        verifierDriver {
+        verifierDriver(networkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true)) {
             val aliceFuture = startNode(ALICE.name)
             val notaryFuture = startNode(DUMMY_NOTARY.name, advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type)), verifierType = VerifierType.OutOfProcess)
             val alice = aliceFuture.get()