Merge pull request #665 from corda/aslemmer-cleanup-rpc-resources-on-startup-failure

RPC: call close() on startup failure, add thread leak tests
This commit is contained in:
Andras Slemmer 2017-05-15 19:18:37 +01:00 committed by GitHub
commit 63d5aa03e9
16 changed files with 423 additions and 238 deletions

View File

@ -5,28 +5,181 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import net.corda.core.getOrThrow import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.*
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.millis import net.corda.node.driver.poll
import net.corda.core.random63BitValue
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCKryo import net.corda.nodeapi.RPCKryo
import net.corda.testing.* import net.corda.testing.*
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject import rx.subjects.UnicastSubject
import java.time.Duration import java.time.Duration
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals
class RPCStabilityTests { class RPCStabilityTests {
object DummyOps : RPCOps {
override val protocolVersion = 0
}
private fun waitUntilNumberOfThreadsStable(executorService: ScheduledExecutorService): Int {
val values = ConcurrentLinkedQueue<Int>()
return poll(executorService, "number of threads to become stable", 250.millis) {
values.add(Thread.activeCount())
if (values.size > 5) {
values.poll()
}
val first = values.peek()
if (values.size == 5 && values.all { it == first }) {
first
} else {
null
}
}.get()
}
@Test
fun `client and server dont leak threads`() {
val executor = Executors.newScheduledThreadPool(1)
fun startAndStop() {
rpcDriver {
val server = startRpcServer<RPCOps>(ops = DummyOps)
startRpcClient<RPCOps>(server.get().broker.hostAndPort!!).get()
}
}
repeat(5) {
startAndStop()
}
val numberOfThreadsBefore = waitUntilNumberOfThreadsStable(executor)
repeat(5) {
startAndStop()
}
val numberOfThreadsAfter = waitUntilNumberOfThreadsStable(executor)
// This is a less than check because threads from other tests may be shutting down while this test is running.
// This is therefore a "best effort" check. When this test is run on its own this should be a strict equality.
assertTrue(numberOfThreadsBefore >= numberOfThreadsAfter)
executor.shutdownNow()
}
@Test
fun `client doesnt leak threads when it fails to start`() {
val executor = Executors.newScheduledThreadPool(1)
fun startAndStop() {
rpcDriver {
ErrorOr.catch { startRpcClient<RPCOps>(HostAndPort.fromString("localhost:9999")).get() }
val server = startRpcServer<RPCOps>(ops = DummyOps)
ErrorOr.catch { startRpcClient<RPCOps>(
server.get().broker.hostAndPort!!,
configuration = RPCClientConfiguration.default.copy(minimumServerProtocolVersion = 1)
).get() }
}
}
repeat(5) {
startAndStop()
}
val numberOfThreadsBefore = waitUntilNumberOfThreadsStable(executor)
repeat(5) {
startAndStop()
}
val numberOfThreadsAfter = waitUntilNumberOfThreadsStable(executor)
assertTrue(numberOfThreadsBefore >= numberOfThreadsAfter)
executor.shutdownNow()
}
fun RpcBrokerHandle.getStats(): Map<String, Any> {
return serverControl.run {
mapOf(
"connections" to listConnectionIDs().toSet(),
"sessionCount" to listConnectionIDs().flatMap { listSessions(it).toList() }.size,
"consumerCount" to totalConsumerCount
)
}
}
@Test
fun `rpc server close doesnt leak broker resources`() {
rpcDriver {
fun startAndCloseServer(broker: RpcBrokerHandle) {
startRpcServerWithBrokerRunning(
configuration = RPCServerConfiguration.default.copy(consumerPoolSize = 1, producerPoolBound = 1),
ops = DummyOps,
brokerHandle = broker
).rpcServer.close()
}
val broker = startRpcBroker().get()
startAndCloseServer(broker)
val initial = broker.getStats()
repeat(100) {
startAndCloseServer(broker)
}
pollUntilTrue("broker resources to be released") {
initial == broker.getStats()
}
}
}
@Test
fun `rpc client close doesnt leak broker resources`() {
rpcDriver {
val server = startRpcServer(configuration = RPCServerConfiguration.default.copy(consumerPoolSize = 1, producerPoolBound = 1), ops = DummyOps).get()
RPCClient<RPCOps>(server.broker.hostAndPort!!).start(RPCOps::class.java, rpcTestUser.username, rpcTestUser.password).close()
val initial = server.broker.getStats()
repeat(100) {
val connection = RPCClient<RPCOps>(server.broker.hostAndPort!!).start(RPCOps::class.java, rpcTestUser.username, rpcTestUser.password)
connection.close()
}
pollUntilTrue("broker resources to be released") {
initial == server.broker.getStats()
}
}
}
@Test
fun `rpc server close is idempotent`() {
rpcDriver {
val server = startRpcServer(ops = DummyOps).get()
repeat(10) {
server.rpcServer.close()
}
}
}
@Test
fun `rpc client close is idempotent`() {
rpcDriver {
val serverShutdown = shutdownManager.follower()
val server = startRpcServer(ops = DummyOps).get()
serverShutdown.unfollow()
// With the server up
val connection1 = RPCClient<RPCOps>(server.broker.hostAndPort!!).start(RPCOps::class.java, rpcTestUser.username, rpcTestUser.password)
repeat(10) {
connection1.close()
}
val connection2 = RPCClient<RPCOps>(server.broker.hostAndPort!!).start(RPCOps::class.java, rpcTestUser.username, rpcTestUser.password)
serverShutdown.shutdown()
// With the server down
repeat(10) {
connection2.close()
}
}
}
interface LeakObservableOps: RPCOps { interface LeakObservableOps: RPCOps {
fun leakObservable(): Observable<Nothing> fun leakObservable(): Observable<Nothing>
} }
@ -44,7 +197,7 @@ class RPCStabilityTests {
} }
} }
val server = startRpcServer<LeakObservableOps>(ops = leakObservableOpsImpl) val server = startRpcServer<LeakObservableOps>(ops = leakObservableOpsImpl)
val proxy = startRpcClient<LeakObservableOps>(server.get().hostAndPort).get() val proxy = startRpcClient<LeakObservableOps>(server.get().broker.hostAndPort!!).get()
// Leak many observables // Leak many observables
val N = 200 val N = 200
(1..N).toList().parallelStream().forEach { (1..N).toList().parallelStream().forEach {
@ -71,7 +224,7 @@ class RPCStabilityTests {
override fun ping() = "pong" override fun ping() = "pong"
} }
val serverFollower = shutdownManager.follower() val serverFollower = shutdownManager.follower()
val serverPort = startRpcServer<ReconnectOps>(ops = ops).getOrThrow().hostAndPort val serverPort = startRpcServer<ReconnectOps>(ops = ops).getOrThrow().broker.hostAndPort!!
serverFollower.unfollow() serverFollower.unfollow()
val clientFollower = shutdownManager.follower() val clientFollower = shutdownManager.follower()
val client = startRpcClient<ReconnectOps>(serverPort).getOrThrow() val client = startRpcClient<ReconnectOps>(serverPort).getOrThrow()
@ -113,7 +266,7 @@ class RPCStabilityTests {
val numberOfClients = 4 val numberOfClients = 4
val clients = Futures.allAsList((1 .. numberOfClients).map { val clients = Futures.allAsList((1 .. numberOfClients).map {
startRandomRpcClient<TrackSubscriberOps>(server.hostAndPort) startRandomRpcClient<TrackSubscriberOps>(server.broker.hostAndPort!!)
}).get() }).get()
// Poll until all clients connect // Poll until all clients connect
@ -158,7 +311,7 @@ class RPCStabilityTests {
// Construct an RPC session manually so that we can hang in the message handler // Construct an RPC session manually so that we can hang in the message handler
val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}" val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}"
val session = startArtemisSession(server.hostAndPort) val session = startArtemisSession(server.broker.hostAndPort!!)
session.createTemporaryQueue(myQueue, myQueue) session.createTemporaryQueue(myQueue, myQueue)
val consumer = session.createConsumer(myQueue, null, -1, -1, false) val consumer = session.createConsumer(myQueue, null, -1, -1, false)
consumer.setMessageHandler { consumer.setMessageHandler {
@ -190,7 +343,7 @@ class RPCStabilityTests {
fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) { fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) {
pollUntilTrue("number of RPC clients to become $expected") { pollUntilTrue("number of RPC clients to become $expected") {
val clientAddresses = server.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) } val clientAddresses = server.broker.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) }
clientAddresses.size == expected clientAddresses.size == expected
}.get() }.get()
} }

View File

@ -147,26 +147,32 @@ class RPCClient<I : RPCOps>(
} }
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass) val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass)
proxyHandler.start() try {
proxyHandler.start()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
val serverProtocolVersion = ops.protocolVersion val serverProtocolVersion = ops.protocolVersion
if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) { if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" + throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
" than the server's supported protocol version ($serverProtocolVersion)") " than the server's supported protocol version ($serverProtocolVersion)")
}
proxyHandler.setServerProtocolVersion(serverProtocolVersion)
log.debug("RPC connected, returning proxy")
object : RPCConnection<I> {
override val proxy = ops
override val serverProtocolVersion = serverProtocolVersion
override fun close() {
proxyHandler.close()
serverLocator.close()
} }
proxyHandler.setServerProtocolVersion(serverProtocolVersion)
log.debug("RPC connected, returning proxy")
object : RPCConnection<I> {
override val proxy = ops
override val serverProtocolVersion = serverProtocolVersion
override fun close() {
proxyHandler.close()
serverLocator.close()
}
}
} catch (exception: Throwable) {
proxyHandler.close()
serverLocator.close()
throw exception
} }
} }
} }

View File

@ -22,19 +22,15 @@ import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientMessage import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ServerLocator import org.apache.activemq.artemis.api.core.client.ServerLocator
import org.apache.activemq.artemis.core.client.impl.ClientConsumerInternal
import rx.Notification import rx.Notification
import rx.Observable import rx.Observable
import rx.subjects.UnicastSubject import rx.subjects.UnicastSubject
import sun.reflect.CallerSensitive
import java.lang.reflect.InvocationHandler import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method import java.lang.reflect.Method
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.*
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.collections.ArrayList
import kotlin.reflect.jvm.javaMethod import kotlin.reflect.jvm.javaMethod
/** /**
@ -87,10 +83,7 @@ class RPCClientProxyHandler(
} }
// Used for reaping // Used for reaping
private val reaperExecutor = Executors.newScheduledThreadPool( private var reaperExecutor: ScheduledExecutorService? = null
1,
ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
)
// A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering. // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build() private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build()
@ -109,7 +102,7 @@ class RPCClientProxyHandler(
hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>()) hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
) )
// Holds a reference to the scheduled reaper. // Holds a reference to the scheduled reaper.
private lateinit var reaperScheduledFuture: ScheduledFuture<*> private var reaperScheduledFuture: ScheduledFuture<*>? = null
// The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion] // The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
private var serverProtocolVersion: Int? = null private var serverProtocolVersion: Int? = null
@ -145,7 +138,7 @@ class RPCClientProxyHandler(
// TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a // TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a
// single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is // single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is
// integrated properly. // integrated properly.
private lateinit var sessionAndConsumer: ArtemisConsumer private var sessionAndConsumer: ArtemisConsumer? = null
// Pool producers to reduce contention on the client side. // Pool producers to reduce contention on the client side.
private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) { private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) {
// Note how we create new sessions *and* session factories per producer. // Note how we create new sessions *and* session factories per producer.
@ -162,7 +155,12 @@ class RPCClientProxyHandler(
* Start the client. This creates the per-client queue, starts the consumer session and the reaper. * Start the client. This creates the per-client queue, starts the consumer session and the reaper.
*/ */
fun start() { fun start() {
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( lifeCycle.requireState(State.UNSTARTED)
reaperExecutor = Executors.newScheduledThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
)
reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate(
this::reapObservables, this::reapObservables,
rpcConfiguration.reapInterval.toMillis(), rpcConfiguration.reapInterval.toMillis(),
rpcConfiguration.reapInterval.toMillis(), rpcConfiguration.reapInterval.toMillis(),
@ -187,7 +185,7 @@ class RPCClientProxyHandler(
if (method == toStringMethod) { if (method == toStringMethod) {
return "Client RPC proxy for $rpcOpsClass" return "Client RPC proxy for $rpcOpsClass"
} }
if (sessionAndConsumer.session.isClosed) { if (sessionAndConsumer!!.session.isClosed) {
throw RPCException("RPC Proxy is closed") throw RPCException("RPC Proxy is closed")
} }
val rpcId = RPCApi.RpcRequestId(random63BitValue()) val rpcId = RPCApi.RpcRequestId(random63BitValue())
@ -268,24 +266,19 @@ class RPCClientProxyHandler(
* Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors. * Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors.
*/ */
fun close() { fun close() {
sessionAndConsumer.consumer.close() sessionAndConsumer?.sessionFactory?.close()
sessionAndConsumer.session.close() reaperScheduledFuture?.cancel(false)
sessionAndConsumer.sessionFactory.close()
reaperScheduledFuture.cancel(false)
observableContext.observableMap.invalidateAll() observableContext.observableMap.invalidateAll()
reapObservables() reapObservables()
reaperExecutor.shutdownNow() reaperExecutor?.shutdownNow()
sessionAndProducerPool.close().forEach { sessionAndProducerPool.close().forEach {
it.producer.close()
it.session.close()
it.sessionFactory.close() it.sessionFactory.close()
} }
// Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may // Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may
// leak borrowed executors. // leak borrowed executors.
val observationExecutors = observationExecutorPool.close() val observationExecutors = observationExecutorPool.close()
observationExecutors.forEach { it.shutdownNow() } observationExecutors.forEach { it.shutdownNow() }
observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) } lifeCycle.justTransition(State.FINISHED)
lifeCycle.transition(State.STARTED, State.FINISHED)
} }
/** /**

View File

@ -47,8 +47,8 @@ open class AbstractRPCTest {
}.get() }.get()
RPCTestMode.Netty -> RPCTestMode.Netty ->
startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server -> startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server ->
startRpcClient<I>(server.hostAndPort, rpcUser.username, rpcUser.password, clientConfiguration).map { startRpcClient<I>(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startArtemisSession(server.hostAndPort, rpcUser.username, rpcUser.password) }) TestProxy(it, { startArtemisSession(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password) })
} }
}.get() }.get()
} }

View File

@ -2,6 +2,7 @@ package net.corda.core.contracts
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.DeserializeAsKotlinObjectDef
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import java.security.PublicKey import java.security.PublicKey
@ -60,7 +61,7 @@ sealed class TransactionType {
abstract fun verifyTransaction(tx: LedgerTransaction) abstract fun verifyTransaction(tx: LedgerTransaction)
/** A general transaction type where transaction validity is determined by custom contract code */ /** A general transaction type where transaction validity is determined by custom contract code */
object General : TransactionType() { object General : TransactionType(), DeserializeAsKotlinObjectDef {
/** Just uses the default [TransactionBuilder] with no special logic */ /** Just uses the default [TransactionBuilder] with no special logic */
class Builder(notary: Party?) : TransactionBuilder(General, notary) class Builder(notary: Party?) : TransactionBuilder(General, notary)
@ -140,7 +141,7 @@ sealed class TransactionType {
* A special transaction type for reassigning a notary for a state. Validation does not involve running * A special transaction type for reassigning a notary for a state. Validation does not involve running
* any contract code, it just checks that the states are unmodified apart from the notary field. * any contract code, it just checks that the states are unmodified apart from the notary field.
*/ */
object NotaryChange : TransactionType() { object NotaryChange : TransactionType(), DeserializeAsKotlinObjectDef {
/** /**
* A transaction builder that automatically sets the transaction type to [NotaryChange] * A transaction builder that automatically sets the transaction type to [NotaryChange]
* and adds the list of participants to the signers set for every input state. * and adds the list of participants to the signers set for every input state.

View File

@ -59,8 +59,10 @@ class LazyPool<A>(
* the returned iterable will be inaccurate. * the returned iterable will be inaccurate.
*/ */
fun close(): Iterable<A> { fun close(): Iterable<A> {
lifeCycle.transition(State.STARTED, State.FINISHED) lifeCycle.justTransition(State.FINISHED)
return poolQueue val elements = poolQueue.toList()
poolQueue.clear()
return elements
} }
inline fun <R> run(withInstance: (A) -> R): R { inline fun <R> run(withInstance: (A) -> R): R {

View File

@ -13,7 +13,7 @@ class LifeCycle<S : Enum<S>>(initial: S) {
private val lock = ReentrantReadWriteLock() private val lock = ReentrantReadWriteLock()
private var state = initial private var state = initial
/** Assert that the lifecycle in the [requiredState] */ /** Assert that the lifecycle in the [requiredState]. */
fun requireState(requiredState: S) { fun requireState(requiredState: S) {
requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState } requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState }
} }
@ -28,11 +28,18 @@ class LifeCycle<S : Enum<S>>(initial: S) {
} }
} }
/** Transition the state from [from] to [to] */ /** Transition the state from [from] to [to]. */
fun transition(from: S, to: S) { fun transition(from: S, to: S) {
lock.writeLock().withLock { lock.writeLock().withLock {
require(state == from) { "Required state to be $from to transition to $to, was $state" } require(state == from) { "Required state to be $from to transition to $to, was $state" }
state = to state = to
} }
} }
/** Transition the state to [to] without performing a current state check. */
fun justTransition(to: S) {
lock.writeLock().withLock {
state = to
}
}
} }

View File

@ -109,7 +109,7 @@ class ContractUpgradeFlowTest {
rpcAddress = startRpcServer( rpcAddress = startRpcServer(
rpcUser = user, rpcUser = user,
ops = CordaRPCOpsImpl(node.services, node.smm, node.database) ops = CordaRPCOpsImpl(node.services, node.smm, node.database)
).get().hostAndPort, ).get().broker.hostAndPort!!,
username = user.username, username = user.username,
password = user.password password = user.password
).get() ).get()

View File

@ -1,10 +1,11 @@
package net.corda.node package net.corda.node
import com.google.common.base.Stopwatch import com.google.common.base.Stopwatch
import net.corda.node.driver.FalseNetworkMap import net.corda.node.driver.NetworkMapStartStrategy
import net.corda.node.driver.driver import net.corda.node.driver.driver
import org.junit.Ignore import org.junit.Ignore
import org.junit.Test import org.junit.Test
import java.util.*
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
@Ignore("Only use locally") @Ignore("Only use locally")
@ -13,8 +14,8 @@ class NodeStartupPerformanceTests {
// Measure the startup time of nodes. Note that this includes an RPC roundtrip, which causes e.g. Kryo initialisation. // Measure the startup time of nodes. Note that this includes an RPC roundtrip, which causes e.g. Kryo initialisation.
@Test @Test
fun `single node startup time`() { fun `single node startup time`() {
driver(networkMapStrategy = FalseNetworkMap) { driver(networkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false)) {
startNetworkMapService().get() startDedicatedNetworkMapService().get()
val times = ArrayList<Long>() val times = ArrayList<Long>()
for (i in 1 .. 10) { for (i in 1 .. 10) {
val time = Stopwatch.createStarted().apply { val time = Stopwatch.createStarted().apply {

View File

@ -104,7 +104,7 @@ interface DriverDSLExposedInterface {
* Starts a network map service node. Note that only a single one should ever be running, so you will probably want * Starts a network map service node. Note that only a single one should ever be running, so you will probably want
* to set networkMapStrategy to FalseNetworkMap in your [driver] call. * to set networkMapStrategy to FalseNetworkMap in your [driver] call.
*/ */
fun startNetworkMapService(): ListenableFuture<Unit> fun startDedicatedNetworkMapService(): ListenableFuture<Unit>
fun waitForAllNodesToFinish() fun waitForAllNodesToFinish()
@ -168,6 +168,11 @@ sealed class PortAllocation {
} }
} }
sealed class NetworkMapStartStrategy {
data class Dedicated(val startAutomatically: Boolean) : NetworkMapStartStrategy()
data class Nominated(val legalName: X500Name, val address: HostAndPort) : NetworkMapStartStrategy()
}
/** /**
* [driver] allows one to start up nodes like this: * [driver] allows one to start up nodes like this:
* driver { * driver {
@ -201,7 +206,7 @@ fun <A> driver(
debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
systemProperties: Map<String, String> = emptyMap(), systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false, useTestClock: Boolean = false,
networkMapStrategy: NetworkMapStrategy = DedicatedNetworkMap, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true),
dsl: DriverDSLExposedInterface.() -> A dsl: DriverDSLExposedInterface.() -> A
) = genericDriver( ) = genericDriver(
driverDsl = DriverDSL( driverDsl = DriverDSL(
@ -210,7 +215,7 @@ fun <A> driver(
systemProperties = systemProperties, systemProperties = systemProperties,
driverDirectory = driverDirectory.toAbsolutePath(), driverDirectory = driverDirectory.toAbsolutePath(),
useTestClock = useTestClock, useTestClock = useTestClock,
networkMapStrategy = networkMapStrategy, networkMapStartStrategy = networkMapStartStrategy,
isDebug = isDebug isDebug = isDebug
), ),
coerce = { it }, coerce = { it },
@ -412,13 +417,14 @@ class DriverDSL(
val driverDirectory: Path, val driverDirectory: Path,
val useTestClock: Boolean, val useTestClock: Boolean,
val isDebug: Boolean, val isDebug: Boolean,
val networkMapStrategy: NetworkMapStrategy val networkMapStartStrategy: NetworkMapStartStrategy
) : DriverDSLInternalInterface { ) : DriverDSLInternalInterface {
private val dedicatedNetworkMapAddress = portAllocation.nextHostAndPort() private val dedicatedNetworkMapAddress = portAllocation.nextHostAndPort()
val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator( private val dedicatedNetworkMapLegalName = DUMMY_MAP.name
Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build()) var _executorService: ListeningScheduledExecutorService? = null
) val executorService get() = _executorService!!
override val shutdownManager = ShutdownManager(executorService) var _shutdownManager: ShutdownManager? = null
override val shutdownManager get() = _shutdownManager!!
class State { class State {
val processes = ArrayList<ListenableFuture<Process>>() val processes = ArrayList<ListenableFuture<Process>>()
@ -449,8 +455,8 @@ class DriverDSL(
} }
override fun shutdown() { override fun shutdown() {
shutdownManager.shutdown() _shutdownManager?.shutdown()
executorService.shutdown() _executorService?.shutdownNow()
} }
private fun establishRpc(nodeAddress: HostAndPort, sslConfig: SSLConfiguration): ListenableFuture<CordaRPCOps> { private fun establishRpc(nodeAddress: HostAndPort, sslConfig: SSLConfiguration): ListenableFuture<CordaRPCOps> {
@ -467,6 +473,12 @@ class DriverDSL(
} }
} }
// TODO move to cmopanion
private fun toServiceConfig(address: HostAndPort, legalName: X500Name) = mapOf(
"address" to address.toString(),
"legalName" to legalName.toString()
)
override fun startNode( override fun startNode(
providedName: X500Name?, providedName: X500Name?,
advertisedServices: Set<ServiceInfo>, advertisedServices: Set<ServiceInfo>,
@ -487,7 +499,17 @@ class DriverDSL(
"rpcAddress" to rpcAddress.toString(), "rpcAddress" to rpcAddress.toString(),
"webAddress" to webAddress.toString(), "webAddress" to webAddress.toString(),
"extraAdvertisedServiceIds" to advertisedServices.map { it.toString() }, "extraAdvertisedServiceIds" to advertisedServices.map { it.toString() },
"networkMapService" to networkMapStrategy.serviceConfig(dedicatedNetworkMapAddress, name, p2pAddress), "networkMapService" to when (networkMapStartStrategy) {
is NetworkMapStartStrategy.Dedicated -> toServiceConfig(dedicatedNetworkMapAddress, dedicatedNetworkMapLegalName)
is NetworkMapStartStrategy.Nominated -> networkMapStartStrategy.run {
if (name != legalName) {
toServiceConfig(address, legalName)
} else {
p2pAddress == address || throw IllegalArgumentException("Passed-in address $address of nominated network map $legalName is wrong, it should be: $p2pAddress")
null
}
}
},
"useTestClock" to useTestClock, "useTestClock" to useTestClock,
"rpcUsers" to rpcUsers.map { "rpcUsers" to rpcUsers.map {
mapOf( mapOf(
@ -574,21 +596,24 @@ class DriverDSL(
} }
override fun start() { override fun start() {
if (networkMapStrategy.startDedicated) { _executorService = MoreExecutors.listeningDecorator(
startNetworkMapService() Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build())
)
_shutdownManager = ShutdownManager(executorService)
if (networkMapStartStrategy is NetworkMapStartStrategy.Dedicated && networkMapStartStrategy.startAutomatically) {
startDedicatedNetworkMapService()
} }
} }
override fun startNetworkMapService(): ListenableFuture<Unit> { override fun startDedicatedNetworkMapService(): ListenableFuture<Unit> {
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
val apiAddress = portAllocation.nextHostAndPort().toString() val apiAddress = portAllocation.nextHostAndPort().toString()
val networkMapLegalName = networkMapStrategy.legalName val baseDirectory = driverDirectory / dedicatedNetworkMapLegalName.commonName
val baseDirectory = driverDirectory / networkMapLegalName.commonName
val config = ConfigHelper.loadConfig( val config = ConfigHelper.loadConfig(
baseDirectory = baseDirectory, baseDirectory = baseDirectory,
allowMissingConfig = true, allowMissingConfig = true,
configOverrides = mapOf( configOverrides = mapOf(
"myLegalName" to networkMapLegalName.toString(), "myLegalName" to dedicatedNetworkMapLegalName.toString(),
// TODO: remove the webAddress as NMS doesn't need to run a web server. This will cause all // TODO: remove the webAddress as NMS doesn't need to run a web server. This will cause all
// node port numbers to be shifted, so all demos and docs need to be updated accordingly. // node port numbers to be shifted, so all demos and docs need to be updated accordingly.
"webAddress" to apiAddress, "webAddress" to apiAddress,
@ -684,3 +709,4 @@ fun writeConfig(path: Path, filename: String, config: Config) {
path.toFile().mkdirs() path.toFile().mkdirs()
File("$path/$filename").writeText(config.root().render(ConfigRenderOptions.defaults())) File("$path/$filename").writeText(config.root().render(ConfigRenderOptions.defaults()))
} }

View File

@ -1,47 +0,0 @@
package net.corda.node.driver
import com.google.common.net.HostAndPort
import net.corda.core.utilities.DUMMY_MAP
import org.bouncycastle.asn1.x500.X500Name
/**
* Instruct the driver how to set up the network map, if at all.
* @see FalseNetworkMap
* @see DedicatedNetworkMap
* @see NominatedNetworkMap
*/
abstract class NetworkMapStrategy(internal val startDedicated: Boolean, internal val legalName: X500Name) {
internal abstract fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort): Map<String, String>?
}
private fun toServiceConfig(address: HostAndPort, legalName: X500Name) = mapOf(
"address" to address.toString(),
"legalName" to legalName.toString()
)
abstract class AbstractDedicatedNetworkMap(start: Boolean) : NetworkMapStrategy(start, DUMMY_MAP.name) {
override fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort) = toServiceConfig(dedicatedAddress, legalName)
}
/**
* Do not start a network map.
*/
object FalseNetworkMap : AbstractDedicatedNetworkMap(false)
/**
* Start a dedicated node to host the network map.
*/
object DedicatedNetworkMap : AbstractDedicatedNetworkMap(true)
/**
* As in gradle-based demos, nominate a node to host the network map, so that there is one fewer node in total than in the [DedicatedNetworkMap] case.
* Will fail if the port you pass in does not match the P2P port the driver assigns to the named node.
*/
class NominatedNetworkMap(legalName: X500Name, private val address: HostAndPort) : NetworkMapStrategy(false, legalName) {
override fun serviceConfig(dedicatedAddress: HostAndPort, nodeName: X500Name, p2pAddress: HostAndPort) = if (nodeName != legalName) {
toServiceConfig(address, legalName)
} else {
p2pAddress == address || throw IllegalArgumentException("Passed-in address $address of nominated network map $legalName is wrong, it should be: $p2pAddress")
null
}
}

View File

@ -13,7 +13,6 @@ import com.google.common.collect.Multimaps
import com.google.common.collect.SetMultimap import com.google.common.collect.SetMultimap
import com.google.common.util.concurrent.ThreadFactoryBuilder import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.ErrorOr import net.corda.core.ErrorOr
import net.corda.core.crypto.commonName
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
import net.corda.core.seconds import net.corda.core.seconds
@ -42,10 +41,8 @@ import rx.Subscriber
import rx.Subscription import rx.Subscription
import java.lang.reflect.InvocationTargetException import java.lang.reflect.InvocationTargetException
import java.time.Duration import java.time.Duration
import java.util.concurrent.ExecutorService import java.util.*
import java.util.concurrent.Executors import java.util.concurrent.*
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
data class RPCServerConfiguration( data class RPCServerConfiguration(
/** The number of threads to use for handling RPC requests */ /** The number of threads to use for handling RPC requests */
@ -101,22 +98,11 @@ class RPCServer(
// A mapping from client addresses to IDs of associated Observables // A mapping from client addresses to IDs of associated Observables
private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create<SimpleString, RPCApi.ObservableId>()) private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create<SimpleString, RPCApi.ObservableId>())
// The scheduled reaper handle. // The scheduled reaper handle.
private lateinit var reaperScheduledFuture: ScheduledFuture<*> private var reaperScheduledFuture: ScheduledFuture<*>? = null
private val observationSendExecutor = Executors.newFixedThreadPool( private var observationSendExecutor: ExecutorService? = null
1, private var rpcExecutor: ScheduledExecutorService? = null
ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build() private var reaperExecutor: ScheduledExecutorService? = null
)
private val rpcExecutor = Executors.newScheduledThreadPool(
rpcConfiguration.rpcThreadPoolSize,
ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build()
)
private val reaperExecutor = Executors.newScheduledThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build()
)
private val sessionAndConsumers = ArrayList<ArtemisConsumer>(rpcConfiguration.consumerPoolSize) private val sessionAndConsumers = ArrayList<ArtemisConsumer>(rpcConfiguration.consumerPoolSize)
private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) { private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) {
@ -125,8 +111,8 @@ class RPCServer(
session.start() session.start()
ArtemisProducer(sessionFactory, session, session.createProducer()) ArtemisProducer(sessionFactory, session, session.createProducer())
} }
private lateinit var clientBindingRemovalConsumer: ClientConsumer private var clientBindingRemovalConsumer: ClientConsumer? = null
private lateinit var serverControl: ActiveMQServerControl private var serverControl: ActiveMQServerControl? = null
private fun createObservableSubscriptionMap(): ObservableSubscriptionMap { private fun createObservableSubscriptionMap(): ObservableSubscriptionMap {
val onObservableRemove = RemovalListener<RPCApi.ObservableId, ObservableSubscription> { val onObservableRemove = RemovalListener<RPCApi.ObservableId, ObservableSubscription> {
@ -137,52 +123,64 @@ class RPCServer(
} }
fun start(activeMqServerControl: ActiveMQServerControl) { fun start(activeMqServerControl: ActiveMQServerControl) {
log.info("Starting RPC server with configuration $rpcConfiguration") try {
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( lifeCycle.requireState(State.UNSTARTED)
this::reapSubscriptions, log.info("Starting RPC server with configuration $rpcConfiguration")
rpcConfiguration.reapInterval.toMillis(), observationSendExecutor = Executors.newFixedThreadPool(
rpcConfiguration.reapInterval.toMillis(), 1,
TimeUnit.MILLISECONDS ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build()
) )
val sessions = ArrayList<ClientSession>() rpcExecutor = Executors.newScheduledThreadPool(
for (i in 1 .. rpcConfiguration.consumerPoolSize) { rpcConfiguration.rpcThreadPoolSize,
val sessionFactory = serverLocator.createSessionFactory() ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build()
val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) )
val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME) reaperExecutor = Executors.newScheduledThreadPool(
consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler) 1,
sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer)) ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build()
sessions.add(session) )
} reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate(
clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS) this::reapSubscriptions,
clientBindingRemovalConsumer.setMessageHandler(this::bindingRemovalArtemisMessageHandler) rpcConfiguration.reapInterval.toMillis(),
serverControl = activeMqServerControl rpcConfiguration.reapInterval.toMillis(),
lifeCycle.transition(State.UNSTARTED, State.STARTED) TimeUnit.MILLISECONDS
// We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be )
// fully initialised. val sessions = ArrayList<ClientSession>()
sessions.forEach { for (i in 1 .. rpcConfiguration.consumerPoolSize) {
it.start() val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME)
consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler)
sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer))
sessions.add(session)
}
clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS)
clientBindingRemovalConsumer!!.setMessageHandler(this::bindingRemovalArtemisMessageHandler)
serverControl = activeMqServerControl
lifeCycle.transition(State.UNSTARTED, State.STARTED)
// We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be
// fully initialised.
sessions.forEach {
it.start()
}
} catch (exception: Throwable) {
close()
throw exception
} }
} }
fun close() { fun close() {
reaperScheduledFuture.cancel(false) reaperScheduledFuture?.cancel(false)
rpcExecutor.shutdownNow() rpcExecutor?.shutdownNow()
reaperExecutor.shutdownNow() reaperExecutor?.shutdownNow()
rpcExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
reaperExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
sessionAndConsumers.forEach { sessionAndConsumers.forEach {
it.consumer.close()
it.session.close()
it.sessionFactory.close() it.sessionFactory.close()
} }
observableMap.invalidateAll() observableMap.invalidateAll()
reapSubscriptions() reapSubscriptions()
sessionAndProducerPool.close().forEach { sessionAndProducerPool.close().forEach {
it.producer.close()
it.session.close()
it.sessionFactory.close() it.sessionFactory.close()
} }
lifeCycle.transition(State.STARTED, State.FINISHED) lifeCycle.justTransition(State.FINISHED)
} }
private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) { private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) {
@ -211,7 +209,7 @@ class RPCServer(
val rpcContext = RpcContext( val rpcContext = RpcContext(
currentUser = getUser(artemisMessage) currentUser = getUser(artemisMessage)
) )
rpcExecutor.submit { rpcExecutor!!.submit {
val result = ErrorOr.catch { val result = ErrorOr.catch {
try { try {
CURRENT_RPC_CONTEXT.set(rpcContext) CURRENT_RPC_CONTEXT.set(rpcContext)
@ -239,9 +237,9 @@ class RPCServer(
observableMap, observableMap,
clientAddressToObservables, clientAddressToObservables,
clientToServer.clientAddress, clientToServer.clientAddress,
serverControl, serverControl!!,
sessionAndProducerPool, sessionAndProducerPool,
observationSendExecutor, observationSendExecutor!!,
kryoPool kryoPool
) )
observableContext.sendMessage(reply) observableContext.sendMessage(reply)
@ -255,7 +253,6 @@ class RPCServer(
} }
private fun reapSubscriptions() { private fun reapSubscriptions() {
lifeCycle.requireState(State.STARTED)
observableMap.cleanUp() observableMap.cleanUp()
} }

View File

@ -7,7 +7,7 @@ import net.corda.core.utilities.ALICE
import net.corda.core.utilities.BOB import net.corda.core.utilities.BOB
import net.corda.core.utilities.DUMMY_NOTARY import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.flows.NotaryFlow import net.corda.flows.NotaryFlow
import net.corda.node.driver.NominatedNetworkMap import net.corda.node.driver.NetworkMapStartStrategy
import net.corda.node.driver.PortAllocation import net.corda.node.driver.PortAllocation
import net.corda.node.driver.driver import net.corda.node.driver.driver
import net.corda.node.services.startFlowPermission import net.corda.node.services.startFlowPermission
@ -20,8 +20,8 @@ import java.nio.file.Paths
/** Creates and starts all nodes required for the demo. */ /** Creates and starts all nodes required for the demo. */
fun main(args: Array<String>) { fun main(args: Array<String>) {
val demoUser = listOf(User("demo", "demo", setOf(startFlowPermission<DummyIssueAndMove>(), startFlowPermission<NotaryFlow.Client>()))) val demoUser = listOf(User("demo", "demo", setOf(startFlowPermission<DummyIssueAndMove>(), startFlowPermission<NotaryFlow.Client>())))
val networkMap = NominatedNetworkMap(DUMMY_NOTARY.name.appendToCommonName("1"), HostAndPort.fromParts("localhost", 10009)) val networkMap = NetworkMapStartStrategy.Nominated(DUMMY_NOTARY.name.appendToCommonName("1"), HostAndPort.fromParts("localhost", 10009))
driver(isDebug = true, driverDirectory = Paths.get("build") / "notary-demo-nodes", networkMapStrategy = networkMap, portAllocation = PortAllocation.Incremental(10001)) { driver(isDebug = true, driverDirectory = Paths.get("build") / "notary-demo-nodes", networkMapStartStrategy = networkMap, portAllocation = PortAllocation.Incremental(10001)) {
startNode(ALICE.name, rpcUsers = demoUser) startNode(ALICE.name, rpcUsers = demoUser)
startNode(BOB.name) startNode(BOB.name)
startNotaryCluster(X500Name("CN=Raft,O=R3,OU=corda,L=Zurich,C=CH"), clusterSize = 3, type = RaftValidatingNotaryService.type) startNotaryCluster(X500Name("CN=Raft,O=R3,OU=corda,L=Zurich,C=CH"), clusterSize = 3, type = RaftValidatingNotaryService.type)

View File

@ -9,6 +9,7 @@ import net.corda.client.mock.string
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.div import net.corda.core.div
import net.corda.core.map
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
import net.corda.core.utilities.ProcessUtilities import net.corda.core.utilities.ProcessUtilities
@ -64,7 +65,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface {
maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE, maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE,
configuration: RPCServerConfiguration = RPCServerConfiguration.default, configuration: RPCServerConfiguration = RPCServerConfiguration.default,
ops : I ops : I
): ListenableFuture<Unit> ): ListenableFuture<RpcServerHandle>
/** /**
* Starts an In-VM RPC client. * Starts an In-VM RPC client.
@ -156,6 +157,28 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface {
username: String = rpcTestUser.username, username: String = rpcTestUser.username,
password: String = rpcTestUser.password password: String = rpcTestUser.password
): ClientSession ): ClientSession
fun startRpcBroker(
serverName: String = "driver-rpc-server-${random63BitValue()}",
rpcUser: User = rpcTestUser,
maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE,
maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE,
customPort: HostAndPort? = null
): ListenableFuture<RpcBrokerHandle>
fun startInVmRpcBroker(
rpcUser: User = rpcTestUser,
maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE,
maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE
): ListenableFuture<RpcBrokerHandle>
fun <I : RPCOps> startRpcServerWithBrokerRunning(
rpcUser: User = rpcTestUser,
nodeLegalName: X500Name = fakeNodeLegalName,
configuration: RPCServerConfiguration = RPCServerConfiguration.default,
ops: I,
brokerHandle: RpcBrokerHandle
): RpcServerHandle
} }
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startInVmRpcClient( inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startInVmRpcClient(
username: String = rpcTestUser.username, username: String = rpcTestUser.username,
@ -176,11 +199,17 @@ inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startRpcClient(
interface RPCDriverInternalDSLInterface : DriverDSLInternalInterface, RPCDriverExposedDSLInterface interface RPCDriverInternalDSLInterface : DriverDSLInternalInterface, RPCDriverExposedDSLInterface
data class RpcServerHandle( data class RpcBrokerHandle(
val hostAndPort: HostAndPort, val hostAndPort: HostAndPort?, /** null if this is an InVM broker */
val clientTransportConfiguration: TransportConfiguration,
val serverControl: ActiveMQServerControl val serverControl: ActiveMQServerControl
) )
data class RpcServerHandle(
val broker: RpcBrokerHandle,
val rpcServer: RPCServer
)
val rpcTestUser = User("user1", "test", permissions = emptySet()) val rpcTestUser = User("user1", "test", permissions = emptySet())
val fakeNodeLegalName = X500Name("CN=not:a:valid:name") val fakeNodeLegalName = X500Name("CN=not:a:valid:name")
@ -194,7 +223,7 @@ fun <A> rpcDriver(
debugPortAllocation: PortAllocation = globalDebugPortAllocation, debugPortAllocation: PortAllocation = globalDebugPortAllocation,
systemProperties: Map<String, String> = emptyMap(), systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false, useTestClock: Boolean = false,
networkMapStrategy: NetworkMapStrategy = FalseNetworkMap, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false),
dsl: RPCDriverExposedDSLInterface.() -> A dsl: RPCDriverExposedDSLInterface.() -> A
) = genericDriver( ) = genericDriver(
driverDsl = RPCDriverDSL( driverDsl = RPCDriverDSL(
@ -204,7 +233,7 @@ fun <A> rpcDriver(
systemProperties = systemProperties, systemProperties = systemProperties,
driverDirectory = driverDirectory.toAbsolutePath(), driverDirectory = driverDirectory.toAbsolutePath(),
useTestClock = useTestClock, useTestClock = useTestClock,
networkMapStrategy = networkMapStrategy, networkMapStartStrategy = networkMapStartStrategy,
isDebug = isDebug isDebug = isDebug
) )
), ),
@ -293,21 +322,9 @@ data class RPCDriverDSL(
maxBufferedBytesPerClient: Long, maxBufferedBytesPerClient: Long,
configuration: RPCServerConfiguration, configuration: RPCServerConfiguration,
ops: I ops: I
): ListenableFuture<Unit> { ): ListenableFuture<RpcServerHandle> {
return driverDSL.executorService.submit<Unit> { return startInVmRpcBroker(rpcUser, maxFileSize, maxBufferedBytesPerClient).map { broker ->
val artemisConfig = createInVmRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient) startRpcServerWithBrokerRunning(rpcUser, nodeLegalName, configuration, ops, broker)
val server = EmbeddedActiveMQ()
server.setConfiguration(artemisConfig)
server.setSecurityManager(SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.activeMQServer.stop()
server.stop()
}
startRpcServerWithBrokerRunning(
rpcUser, nodeLegalName, configuration, ops, inVmClientTransportConfiguration,
server.activeMQServer.activeMQServerControl
)
} }
} }
@ -344,22 +361,8 @@ data class RPCDriverDSL(
customPort: HostAndPort?, customPort: HostAndPort?,
ops: I ops: I
): ListenableFuture<RpcServerHandle> { ): ListenableFuture<RpcServerHandle> {
val hostAndPort = customPort ?: driverDSL.portAllocation.nextHostAndPort() return startRpcBroker(serverName, rpcUser, maxFileSize, maxBufferedBytesPerClient, customPort).map { broker ->
addressMustNotBeBound(driverDSL.executorService, hostAndPort) startRpcServerWithBrokerRunning(rpcUser, nodeLegalName, configuration, ops, broker)
return driverDSL.executorService.submit<RpcServerHandle> {
val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort)
val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.stop()
addressMustNotBeBound(driverDSL.executorService, hostAndPort).get()
}
val transportConfiguration = createNettyClientTransportConfiguration(hostAndPort)
startRpcServerWithBrokerRunning(
rpcUser, nodeLegalName, configuration, ops, transportConfiguration,
server.activeMQServerControl
)
RpcServerHandle(hostAndPort, server.activeMQServerControl)
} }
} }
@ -401,16 +404,58 @@ data class RPCDriverDSL(
return session return session
} }
override fun startRpcBroker(
serverName: String,
rpcUser: User,
maxFileSize: Int,
maxBufferedBytesPerClient: Long,
customPort: HostAndPort?
): ListenableFuture<RpcBrokerHandle> {
val hostAndPort = customPort ?: driverDSL.portAllocation.nextHostAndPort()
addressMustNotBeBound(driverDSL.executorService, hostAndPort)
return driverDSL.executorService.submit<RpcBrokerHandle> {
val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort)
val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.stop()
addressMustNotBeBound(driverDSL.executorService, hostAndPort).get()
}
RpcBrokerHandle(
hostAndPort = hostAndPort,
clientTransportConfiguration = createNettyClientTransportConfiguration(hostAndPort),
serverControl = server.activeMQServerControl
)
}
}
private fun <I : RPCOps> startRpcServerWithBrokerRunning( override fun startInVmRpcBroker(rpcUser: User, maxFileSize: Int, maxBufferedBytesPerClient: Long): ListenableFuture<RpcBrokerHandle> {
return driverDSL.executorService.submit<RpcBrokerHandle> {
val artemisConfig = createInVmRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient)
val server = EmbeddedActiveMQ()
server.setConfiguration(artemisConfig)
server.setSecurityManager(SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.activeMQServer.stop()
server.stop()
}
RpcBrokerHandle(
hostAndPort = null,
clientTransportConfiguration = inVmClientTransportConfiguration,
serverControl = server.activeMQServer.activeMQServerControl
)
}
}
override fun <I : RPCOps> startRpcServerWithBrokerRunning(
rpcUser: User, rpcUser: User,
nodeLegalName: X500Name, nodeLegalName: X500Name,
configuration: RPCServerConfiguration, configuration: RPCServerConfiguration,
ops: I, ops: I,
transportConfiguration: TransportConfiguration, brokerHandle: RpcBrokerHandle
serverControl: ActiveMQServerControl ): RpcServerHandle {
) { val locator = ActiveMQClient.createServerLocatorWithoutHA(brokerHandle.clientTransportConfiguration).apply {
val locator = ActiveMQClient.createServerLocatorWithoutHA(transportConfiguration).apply {
minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE
} }
val userService = object : RPCUserService { val userService = object : RPCUserService {
@ -430,7 +475,8 @@ data class RPCDriverDSL(
rpcServer.close() rpcServer.close()
locator.close() locator.close()
} }
rpcServer.start(serverControl) rpcServer.start(brokerHandle.serverControl)
return RpcServerHandle(brokerHandle, rpcServer)
} }
} }

View File

@ -78,7 +78,7 @@ fun <A> verifierDriver(
debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
systemProperties: Map<String, String> = emptyMap(), systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false, useTestClock: Boolean = false,
networkMapStrategy: NetworkMapStrategy = DedicatedNetworkMap, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false),
dsl: VerifierExposedDSLInterface.() -> A dsl: VerifierExposedDSLInterface.() -> A
) = genericDriver( ) = genericDriver(
driverDsl = VerifierDriverDSL( driverDsl = VerifierDriverDSL(
@ -88,7 +88,7 @@ fun <A> verifierDriver(
systemProperties = systemProperties, systemProperties = systemProperties,
driverDirectory = driverDirectory.toAbsolutePath(), driverDirectory = driverDirectory.toAbsolutePath(),
useTestClock = useTestClock, useTestClock = useTestClock,
networkMapStrategy = networkMapStrategy, networkMapStartStrategy = networkMapStartStrategy,
isDebug = isDebug isDebug = isDebug
) )
), ),

View File

@ -13,7 +13,7 @@ import net.corda.core.utilities.ALICE
import net.corda.core.utilities.DUMMY_NOTARY import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.flows.CashIssueFlow import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow import net.corda.flows.CashPaymentFlow
import net.corda.node.driver.FalseNetworkMap import net.corda.node.driver.NetworkMapStartStrategy
import net.corda.node.services.config.VerifierType import net.corda.node.services.config.VerifierType
import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.node.services.transactions.ValidatingNotaryService
import org.junit.Test import org.junit.Test
@ -35,7 +35,7 @@ class VerifierTests {
@Test @Test
fun `single verifier works with requestor`() { fun `single verifier works with requestor`() {
verifierDriver(networkMapStrategy = FalseNetworkMap) { verifierDriver {
val aliceFuture = startVerificationRequestor(ALICE.name) val aliceFuture = startVerificationRequestor(ALICE.name)
val transactions = generateTransactions(100) val transactions = generateTransactions(100)
val alice = aliceFuture.get() val alice = aliceFuture.get()
@ -52,7 +52,7 @@ class VerifierTests {
@Test @Test
fun `multiple verifiers work with requestor`() { fun `multiple verifiers work with requestor`() {
verifierDriver(networkMapStrategy = FalseNetworkMap) { verifierDriver {
val aliceFuture = startVerificationRequestor(ALICE.name) val aliceFuture = startVerificationRequestor(ALICE.name)
val transactions = generateTransactions(100) val transactions = generateTransactions(100)
val alice = aliceFuture.get() val alice = aliceFuture.get()
@ -72,7 +72,7 @@ class VerifierTests {
@Test @Test
fun `verification redistributes on verifier death`() { fun `verification redistributes on verifier death`() {
verifierDriver(networkMapStrategy = FalseNetworkMap) { verifierDriver {
val aliceFuture = startVerificationRequestor(ALICE.name) val aliceFuture = startVerificationRequestor(ALICE.name)
val numberOfTransactions = 100 val numberOfTransactions = 100
val transactions = generateTransactions(numberOfTransactions) val transactions = generateTransactions(numberOfTransactions)
@ -100,7 +100,7 @@ class VerifierTests {
@Test @Test
fun `verification request waits until verifier comes online`() { fun `verification request waits until verifier comes online`() {
verifierDriver(networkMapStrategy = FalseNetworkMap) { verifierDriver {
val aliceFuture = startVerificationRequestor(ALICE.name) val aliceFuture = startVerificationRequestor(ALICE.name)
val transactions = generateTransactions(100) val transactions = generateTransactions(100)
val alice = aliceFuture.get() val alice = aliceFuture.get()
@ -112,7 +112,7 @@ class VerifierTests {
@Test @Test
fun `single verifier works with a node`() { fun `single verifier works with a node`() {
verifierDriver { verifierDriver(networkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true)) {
val aliceFuture = startNode(ALICE.name) val aliceFuture = startNode(ALICE.name)
val notaryFuture = startNode(DUMMY_NOTARY.name, advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type)), verifierType = VerifierType.OutOfProcess) val notaryFuture = startNode(DUMMY_NOTARY.name, advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type)), verifierType = VerifierType.OutOfProcess)
val alice = aliceFuture.get() val alice = aliceFuture.get()