From de88ad4f4062885e92cc05a4d96e2cd549d3a4d9 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Wed, 29 Mar 2017 17:28:02 +0100 Subject: [PATCH 1/6] RPC muxing, multithreading, RPC driver, performance tests --- build.gradle | 1 + .../client/jfx/model/NodeMonitorModel.kt | 14 +- .../kotlin/net/corda/client/mock/Generator.kt | 17 + client/rpc/build.gradle | 1 + .../corda/client/rpc/CordaRPCClientTest.kt | 55 +- .../net/corda/client/rpc/RPCStabilityTests.kt | 170 +++++++ .../net/corda/client/rpc/CordaRPCClient.kt | 189 ++----- .../corda/client/rpc/CordaRPCClientImpl.kt | 418 ---------------- .../corda/client/rpc/internal/RPCClient.kt | 169 +++++++ .../rpc/internal/RPCClientProxyHandler.kt | 423 ++++++++++++++++ .../corda/client/rpc/AbstractClientRPCTest.kt | 102 ---- .../net/corda/client/rpc/AbstractRPCTest.kt | 56 +++ .../rpc/ClientRPCInfrastructureTests.kt | 190 ++++--- .../corda/client/rpc/RPCConcurrencyTests.kt | 194 ++++++++ .../corda/client/rpc/RPCPerformanceTests.kt | 315 ++++++++++++ .../corda/client/rpc/RPCPermissionsTest.kt | 85 ---- .../corda/client/rpc/RPCPermissionsTests.kt | 93 ++++ .../client/rpc/RepeatingBytesInputStream.kt | 25 + core/src/main/kotlin/net/corda/core/Utils.kt | 2 + .../net/corda/core/serialization/Kryo.kt | 47 +- .../net/corda/core/utilities/LazyPool.kt | 79 +++ .../corda/core/utilities/LazyStickyPool.kt | 67 +++ .../net/corda/core/utilities/LifeCycle.kt | 38 ++ .../core/flows/ContractUpgradeFlowTest.kt | 103 ++-- .../corda/docs/IntegrationTestingTutorial.kt | 11 +- .../net/corda/docs/ClientRpcTutorial.kt | 3 +- .../nodeapi/ArtemisMessagingComponent.kt | 3 - .../main/kotlin/net/corda/nodeapi/RPCApi.kt | 206 ++++++++ .../kotlin/net/corda/nodeapi/RPCStructures.kt | 60 +-- .../kotlin/net/corda/node/BootTests.kt | 5 +- .../node/services/DistributedServiceTests.kt | 3 +- .../messaging/MQSecurityAsNodeTest.kt | 4 +- .../services/messaging/MQSecurityAsRPCTest.kt | 6 +- .../services/messaging/MQSecurityTest.kt | 39 +- .../kotlin/net/corda/node/driver/Driver.kt | 67 ++- .../net/corda/node/internal/AbstractNode.kt | 2 +- .../corda/node/internal/CordaRPCOpsImpl.kt | 12 +- .../kotlin/net/corda/node/internal/Node.kt | 2 +- .../messaging/ArtemisMessagingServer.kt | 34 +- .../services/messaging/NodeMessagingClient.kt | 58 +-- .../node/services/messaging/RPCDispatcher.kt | 219 -------- .../node/services/messaging/RPCServer.kt | 346 +++++++++++++ .../services/messaging/RPCServerStructures.kt | 4 +- .../net/corda/node/shell/InteractiveShell.kt | 5 +- .../net/corda/node/CordaRPCOpsImplTest.kt | 9 +- .../messaging/ArtemisMessagingTests.kt | 2 +- .../attachmentdemo/AttachmentDemoTest.kt | 8 +- .../corda/attachmentdemo/AttachmentDemo.kt | 14 +- .../corda/bank/BankOfCordaRPCClientTest.kt | 6 +- .../corda/bank/api/BankOfCordaClientApi.kt | 21 +- .../kotlin/net/corda/irs/IRSDemoTest.kt | 3 +- .../kotlin/net/corda/notarydemo/NotaryDemo.kt | 4 +- .../net/corda/traderdemo/TraderDemoTest.kt | 18 +- .../kotlin/net/corda/traderdemo/TraderDemo.kt | 6 +- test-utils/build.gradle | 1 + .../main/kotlin/net/corda/testing/Measure.kt | 53 ++ .../kotlin/net/corda/testing/RPCDriver.kt | 469 ++++++++++++++++++ .../net/corda/testing/node/NodeBasedTest.kt | 16 +- .../net/corda/testing/node/SimpleNode.kt | 2 +- .../kotlin/net/corda/demobench/rpc/NodeRPC.kt | 9 +- .../main/kotlin/net/corda/explorer/Main.kt | 24 +- .../net/corda/loadtest/ConnectionManager.kt | 29 +- .../corda/webserver/internal/NodeWebServer.kt | 4 +- 63 files changed, 3223 insertions(+), 1417 deletions(-) create mode 100644 client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt delete mode 100644 client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClientImpl.kt create mode 100644 client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt create mode 100644 client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt delete mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractClientRPCTest.kt create mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt create mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt create mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt delete mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTest.kt create mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt create mode 100644 client/rpc/src/test/kotlin/net/corda/client/rpc/RepeatingBytesInputStream.kt create mode 100644 core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt create mode 100644 core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt create mode 100644 core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt delete mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/RPCDispatcher.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt create mode 100644 test-utils/src/main/kotlin/net/corda/testing/Measure.kt create mode 100644 test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt diff --git a/build.gradle b/build.gradle index 00b8adeb4e..8293e8812c 100644 --- a/build.gradle +++ b/build.gradle @@ -1,3 +1,4 @@ + buildscript { // For sharing constants between builds Properties constants = new Properties() diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt index a4eacbc47d..946815f91c 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt @@ -3,6 +3,7 @@ package net.corda.client.jfx.model import com.google.common.net.HostAndPort import javafx.beans.property.SimpleObjectProperty import net.corda.client.rpc.CordaRPCClient +import net.corda.client.rpc.CordaRPCClientConfiguration import net.corda.core.flows.StateMachineRunId import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.StateMachineInfo @@ -52,11 +53,14 @@ class NodeMonitorModel { * TODO provide an unsubscribe mechanism */ fun register(nodeHostAndPort: HostAndPort, username: String, password: String) { - val client = CordaRPCClient(nodeHostAndPort) { - maxRetryInterval = 10.seconds.toMillis() - } - client.start(username, password) - val proxy = client.proxy() + val client = CordaRPCClient( + hostAndPort = nodeHostAndPort, + configuration = CordaRPCClientConfiguration.default.copy( + connectionMaxRetryInterval = 10.seconds + ) + ) + val connection = client.start(username, password) + val proxy = connection.proxy val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates() // Extract the flow tracking stream diff --git a/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt b/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt index bf737bc898..74c052e723 100644 --- a/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt +++ b/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt @@ -144,6 +144,23 @@ fun Generator.Companion.doubleRange(from: Double, to: Double): Generator from + it.nextDouble() * (to - from) } +fun Generator.Companion.char() = Generator { + val codePoint = Math.abs(it.nextInt()) % (17 * (1 shl 16)) + if (Character.isValidCodePoint(codePoint)) { + return@Generator ErrorOr(codePoint.toChar()) + } else { + ErrorOr.of(IllegalStateException("Could not generate valid codepoint")) + } +} + +fun Generator.Companion.string(meanSize: Double = 16.0) = replicatePoisson(meanSize, char()).map { + val builder = StringBuilder() + it.forEach { + builder.append(it) + } + builder.toString() +} + fun Generator.Companion.replicate(number: Int, generator: Generator): Generator> { val generators = mutableListOf>() for (i in 1..number) { diff --git a/client/rpc/build.gradle b/client/rpc/build.gradle index 27cfe636e9..ec52494023 100644 --- a/client/rpc/build.gradle +++ b/client/rpc/build.gradle @@ -36,6 +36,7 @@ dependencies { testCompile "org.assertj:assertj-core:${assertj_version}" testCompile project(':test-utils') + testCompile project(':client:mock') // Integration test helpers integrationTestCompile "junit:junit:$junit_version" diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt index b7f51b1b55..d6d0f2719b 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt @@ -1,15 +1,10 @@ package net.corda.client.rpc import net.corda.core.contracts.DOLLARS -import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowInitiator import net.corda.core.getOrThrow -import net.corda.core.messaging.FlowHandle -import net.corda.core.messaging.FlowProgressHandle -import net.corda.core.messaging.CordaRPCOps -import net.corda.core.messaging.StateMachineUpdate -import net.corda.core.messaging.startFlow -import net.corda.core.messaging.startTrackedFlow +import net.corda.core.messaging.* import net.corda.core.node.services.ServiceInfo import net.corda.core.random63BitValue import net.corda.core.serialization.OpaqueBytes @@ -27,7 +22,9 @@ import org.junit.After import org.junit.Before import org.junit.Test import java.util.* -import kotlin.test.* +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue class CordaRPCClientTest : NodeBasedTest() { private val rpcUser = User("user1", "test", permissions = setOf( @@ -36,6 +33,11 @@ class CordaRPCClientTest : NodeBasedTest() { )) private lateinit var node: Node private lateinit var client: CordaRPCClient + private var connection: CordaRPCConnection? = null + + private fun login(username: String, password: String) { + connection = client.start(username, password) + } @Before fun setUp() { @@ -45,33 +47,35 @@ class CordaRPCClientTest : NodeBasedTest() { @After fun done() { - client.close() + connection?.close() } @Test fun `log in with valid username and password`() { - client.start(rpcUser.username, rpcUser.password) + login(rpcUser.username, rpcUser.password) } @Test fun `log in with unknown user`() { assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy { - client.start(random63BitValue().toString(), rpcUser.password) + login(random63BitValue().toString(), rpcUser.password) } } @Test fun `log in with incorrect password`() { assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy { - client.start(rpcUser.username, random63BitValue().toString()) + login(rpcUser.username, random63BitValue().toString()) } } @Test fun `close-send deadlock and premature shutdown on empty observable`() { - val proxy = createRpcProxy(rpcUser.username, rpcUser.password) + println("Starting client") + login(rpcUser.username, rpcUser.password) + println("Creating proxy") println("Starting flow") - val flowHandle = proxy.startTrackedFlow( + val flowHandle = connection!!.proxy.startTrackedFlow( ::CashIssueFlow, 20.DOLLARS, OpaqueBytes.of(0), node.info.legalIdentity, node.info.legalIdentity) println("Started flow, waiting on result") @@ -83,9 +87,8 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `FlowException thrown by flow`() { - client.start(rpcUser.username, rpcUser.password) - val proxy = client.proxy() - val handle = proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity) + login(rpcUser.username, rpcUser.password) + val handle = connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity) // TODO Restrict this to CashException once RPC serialisation has been fixed assertThatExceptionOfType(FlowException::class.java).isThrownBy { handle.returnValue.getOrThrow() @@ -94,9 +97,8 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `check basic flow has no progress`() { - client.start(rpcUser.username, rpcUser.password) - val proxy = client.proxy() - proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use { + login(rpcUser.username, rpcUser.password) + connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use { assertFalse(it is FlowProgressHandle<*>) assertTrue(it is FlowHandle<*>) } @@ -104,7 +106,8 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `get cash balances`() { - val proxy = createRpcProxy(rpcUser.username, rpcUser.password) + login(rpcUser.username, rpcUser.password) + val proxy = connection!!.proxy val startCash = proxy.getCashBalances() assertTrue(startCash.isEmpty(), "Should not start with any cash") @@ -123,7 +126,8 @@ class CordaRPCClientTest : NodeBasedTest() { @Test fun `flow initiator via RPC`() { - val proxy = createRpcProxy(rpcUser.username, rpcUser.password) + login(rpcUser.username, rpcUser.password) + val proxy = connection!!.proxy val smUpdates = proxy.stateMachinesAndUpdates() var countRpcFlows = 0 var countShellFlows = 0 @@ -148,11 +152,4 @@ class CordaRPCClientTest : NodeBasedTest() { assertEquals(2, countRpcFlows) assertEquals(1, countShellFlows) } - - private fun createRpcProxy(username: String, password: String): CordaRPCOps { - println("Starting client") - client.start(username, password) - println("Creating proxy") - return client.proxy() - } } diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt new file mode 100644 index 0000000000..63a4dd6673 --- /dev/null +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -0,0 +1,170 @@ +package net.corda.client.rpc + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import com.google.common.util.concurrent.Futures +import net.corda.core.messaging.RPCOps +import net.corda.core.millis +import net.corda.core.random63BitValue +import net.corda.node.services.messaging.RPCServerConfiguration +import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.RPCKryo +import net.corda.testing.* +import org.apache.activemq.artemis.api.core.SimpleString +import org.bouncycastle.crypto.tls.ConnectionEnd.server +import org.junit.Test +import rx.Observable +import rx.subjects.PublishSubject +import rx.subjects.UnicastSubject +import java.time.Duration +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger + + +class RPCStabilityTests { + + interface LeakObservableOps: RPCOps { + fun leakObservable(): Observable + } + + @Test + fun `client cleans up leaked observables`() { + rpcDriver { + val leakObservableOpsImpl = object : LeakObservableOps { + val leakedUnsubscribedCount = AtomicInteger(0) + override val protocolVersion = 0 + override fun leakObservable(): Observable { + return PublishSubject.create().doOnUnsubscribe { + leakedUnsubscribedCount.incrementAndGet() + } + } + } + val server = startRpcServer(ops = leakObservableOpsImpl) + val proxy = startRpcClient(server.get().hostAndPort).get() + // Leak many observables + val N = 200 + (1..N).toList().parallelStream().forEach { + proxy.leakObservable() + } + // In a loop force GC and check whether the server is notified + while (true) { + System.gc() + if (leakObservableOpsImpl.leakedUnsubscribedCount.get() == N) break + Thread.sleep(100) + } + } + } + + interface TrackSubscriberOps : RPCOps { + fun subscribe(): Observable + } + + /** + * In this test we create a number of out of process RPC clients that call [TrackSubscriberOps.subscribe] in a loop. + */ + @Test + fun `server cleans up queues after disconnected clients`() { + rpcDriver { + val trackSubscriberOpsImpl = object : TrackSubscriberOps { + override val protocolVersion = 0 + val subscriberCount = AtomicInteger(0) + val trackSubscriberCountObservable = UnicastSubject.create().share(). + doOnSubscribe { subscriberCount.incrementAndGet() }. + doOnUnsubscribe { subscriberCount.decrementAndGet() } + override fun subscribe(): Observable { + return trackSubscriberCountObservable + } + } + val server = startRpcServer( + configuration = RPCServerConfiguration.default.copy( + reapIntervalMs = 100 + ), + ops = trackSubscriberOpsImpl + ).get() + + val numberOfClients = 4 + val clients = Futures.allAsList((1 .. numberOfClients).map { + startRandomRpcClient(server.hostAndPort) + }).get() + + // Poll until all clients connect + pollUntilClientNumber(server, numberOfClients) + pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() >= 100 }.get() + // Kill one client + clients[0].destroyForcibly() + pollUntilClientNumber(server, numberOfClients - 1) + // Kill the rest + (1 .. numberOfClients - 1).forEach { + clients[it].destroyForcibly() + } + pollUntilClientNumber(server, 0) + // Now poll until the server detects the disconnects and unsubscribes from all obserables. + pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() == 0 }.get() + } + } + + interface SlowConsumerRPCOps : RPCOps { + fun streamAtInterval(interval: Duration, size: Int): Observable + } + class SlowConsumerRPCOpsImpl : SlowConsumerRPCOps { + override val protocolVersion = 0 + + override fun streamAtInterval(interval: Duration, size: Int): Observable { + val chunk = ByteArray(size) + return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk } + } + } + val dummyObservableSerialiser = object : Serializer>() { + override fun write(kryo: Kryo?, output: Output?, `object`: Observable?) { + } + override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { + return Observable.empty() + } + } + @Test + fun `slow consumers are kicked`() { + val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build() + rpcDriver { + val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get() + + // Construct an RPC session manually so that we can hang in the message handler + val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}" + val session = startArtemisSession(server.hostAndPort) + session.createTemporaryQueue(myQueue, myQueue) + val consumer = session.createConsumer(myQueue, null, -1, -1, false) + consumer.setMessageHandler { + Thread.sleep(50) // 5x slower than the server producer + it.acknowledge() + } + val producer = session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) + session.start() + + pollUntilClientNumber(server, 1) + + val message = session.createMessage(false) + val request = RPCApi.ClientToServer.RpcRequest( + clientAddress = SimpleString(myQueue), + id = RPCApi.RpcRequestId(random63BitValue()), + methodName = SlowConsumerRPCOps::streamAtInterval.name, + arguments = listOf(10.millis, 123456) + ) + request.writeToClientMessage(kryoPool, message) + producer.send(message) + session.commit() + + // We are consuming slower than the server is producing, so we should be kicked after a while + pollUntilClientNumber(server, 0) + } + } + +} + +fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) { + pollUntilTrue("number of RPC clients to become $expected") { + val clientAddresses = server.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) } + clientAddresses.size == expected + }.get() +} \ No newline at end of file diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 35768010f7..a6c97c3e3a 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -1,168 +1,49 @@ package net.corda.client.rpc import com.google.common.net.HostAndPort -import net.corda.core.ThreadBox -import net.corda.core.logElapsedTime +import net.corda.client.rpc.internal.RPCClient +import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.messaging.CordaRPCOps -import net.corda.core.minutes -import net.corda.core.seconds -import net.corda.core.utilities.loggerFor -import net.corda.nodeapi.ArtemisMessagingComponent import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ConnectionDirection -import net.corda.nodeapi.RPCException import net.corda.nodeapi.config.SSLConfiguration -import net.corda.nodeapi.rpcLog -import org.apache.activemq.artemis.api.core.ActiveMQException -import org.apache.activemq.artemis.api.core.client.ActiveMQClient -import org.apache.activemq.artemis.api.core.client.ClientSession -import org.apache.activemq.artemis.api.core.client.ClientSessionFactory -import org.apache.activemq.artemis.api.core.client.ServerLocator -import rx.Observable -import java.io.Closeable import java.time.Duration -import javax.annotation.concurrent.ThreadSafe -/** - * An RPC client connects to the specified server and allows you to make calls to the server that perform various - * useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works. - * - * @param host The hostname and messaging port of the node. - * @param config If specified, the SSL configuration to use. If not specified, SSL will be disabled and the node will only be authenticated on non-SSL RPC port, the RPC traffic with not be encrypted when SSL is disabled. - */ -@ThreadSafe -class CordaRPCClient(val host: HostAndPort, override val config: SSLConfiguration? = null, val serviceConfigurationOverride: (ServerLocator.() -> Unit)? = null) : Closeable, ArtemisMessagingComponent() { - private companion object { - val log = loggerFor() - /** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */ - @JvmStatic val MAX_FILE_SIZE = 10485760 +class CordaRPCConnection internal constructor( + connection: RPCClient.RPCConnection +) : RPCClient.RPCConnection by connection +data class CordaRPCClientConfiguration( + val connectionMaxRetryInterval: Duration +) { + internal fun toRpcClientConfiguration(): RPCClientConfiguration { + return RPCClientConfiguration.default.copy( + connectionMaxRetryInterval = connectionMaxRetryInterval + ) + } + companion object { + @JvmStatic + val default = CordaRPCClientConfiguration( + connectionMaxRetryInterval = RPCClientConfiguration.default.connectionMaxRetryInterval + ) + } +} + +class CordaRPCClient( + hostAndPort: HostAndPort, + sslConfiguration: SSLConfiguration? = null, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default +) { + private val rpcClient = RPCClient( + tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), + configuration.toRpcClientConfiguration() + ) + + fun start(username: String, password: String): CordaRPCConnection { + return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password)) } - // TODO: Certificate handling for clients needs more work. - private inner class State { - var running = false - lateinit var sessionFactory: ClientSessionFactory - lateinit var session: ClientSession - lateinit var clientImpl: CordaRPCClientImpl - } - - private val state = ThreadBox(State()) - - /** - * Opens the connection to the server with the given username and password, then returns itself. - * Registers a JVM shutdown hook to cleanly disconnect. - */ - @Throws(ActiveMQException::class) - fun start(username: String, password: String): CordaRPCClient { - state.locked { - check(!running) - log.logElapsedTime("Startup") { - checkStorePasswords() - val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport(ConnectionDirection.Outbound(), host, config, enableSSL = config != null)).apply { - // TODO: Put these in config file or make it user configurable? - threadPoolMaxSize = 1 - confirmationWindowSize = 100000 // a guess - retryInterval = 5.seconds.toMillis() - retryIntervalMultiplier = 1.5 // Exponential backoff - maxRetryInterval = 3.minutes.toMillis() - minLargeMessageSize = MAX_FILE_SIZE - serviceConfigurationOverride?.invoke(this) - } - sessionFactory = serverLocator.createSessionFactory() - session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize) - session.start() - clientImpl = CordaRPCClientImpl(session, state.lock, username) - running = true - } - } - - Runtime.getRuntime().addShutdownHook(Thread { - close() - }) - - return this - } - - /** - * A convenience function that opens a connection with the given credentials, executes the given code block with all - * available RPCs in scope and shuts down the RPC connection again. It's meant for quick prototyping and demos. For - * more control you probably want to control the lifecycle of the client and proxies independently, as well as - * configuring a timeout and other such features via the [proxy] method. - * - * After this method returns the client is closed and can't be restarted. - */ - @Throws(ActiveMQException::class) - fun use(username: String, password: String, block: CordaRPCOps.() -> T): T { - require(!state.locked { running }) - start(username, password) - (this as Closeable).use { - return proxy().block() - } - } - - /** Shuts down the client and lets the server know it can free the used resources (in a nice way). */ - override fun close() { - state.locked { - if (!running) return - session.close() - sessionFactory.close() - running = false - } - } - - /** - * Returns a fresh proxy that lets you invoke RPCs on the server. Calls on it block, and if the server throws an - * exception then it will be rethrown on the client. Proxies are thread safe but only one RPC can be in flight at - * once. If you'd like to perform multiple RPCs in parallel, use this function multiple times to get multiple - * proxies. - * - * Creation of a proxy is a somewhat expensive operation that involves calls to the server, so if you want to do - * calls from many threads at once you should cache one proxy per thread and reuse them. This function itself is - * thread safe though so requires no extra synchronisation. - * - * RPC sends and receives are logged on the net.corda.rpc logger. - * - * By default there are no timeouts on calls. This is deliberate, RPCs without timeouts can survive restarts, - * maintenance downtime and moves of the server. RPCs can survive temporary losses or changes in client connectivity, - * like switching between wifi networks. You can specify a timeout on the level of a proxy. If a call times - * out it will throw [RPCException.Deadline]. - * - * The [CordaRPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the - * object graph returned then the server-side observable is transparently linked to a messaging queue, and that - * queue linked to another observable on the client side here. *You are expected to use it*. The server will begin - * buffering messages immediately that it will expect you to drain by subscribing to the returned observer. You can - * opt-out of this by simply casting the [Observable] to [Closeable] or [AutoCloseable] and then calling the close - * method on it. You don't have to explicitly close the observable if you actually subscribe to it: it will close - * itself and free up the server-side resources either when the client or JVM itself is shutdown, or when there are - * no more subscribers to it. Once all the subscribers to a returned observable are unsubscribed, the observable is - * closed and you can't then re-subscribe again: you'll have to re-request a fresh observable with another RPC. - * - * The proxy and linked observables consume some small amount of resources on the server. It's OK to just exit your - * process and let the server clean up, but in a long running process where you only need something for a short - * amount of time it is polite to cast the objects to [Closeable] or [AutoCloseable] and close it when you are done. - * Finalizers are in place to warn you if you lose a reference to an unclosed proxy or observable. - * - * @throws RPCException if the server version is too low or if the server isn't reachable within the given time. - */ - @JvmOverloads - @Throws(RPCException::class) - fun proxy(timeout: Duration? = null, minVersion: Int = 0): CordaRPCOps { - return state.locked { - check(running) { "Client must have been started first" } - log.logElapsedTime("Proxy build") { - clientImpl.proxyFor(CordaRPCOps::class.java, timeout, minVersion) - } - } - } - - @Suppress("UNUSED") - private fun finalize() { - state.locked { - if (running) { - rpcLog.warn("A CordaMQClient is being finalised whilst still running, did you forget to call close?") - close() - } - } + inline fun use(username: String, password: String, block: (CordaRPCConnection) -> A): A { + return start(username, password).use(block) } } \ No newline at end of file diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClientImpl.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClientImpl.kt deleted file mode 100644 index c5ba696ce1..0000000000 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClientImpl.kt +++ /dev/null @@ -1,418 +0,0 @@ -package net.corda.client.rpc - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.KryoException -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool -import com.google.common.cache.CacheBuilder -import net.corda.core.ErrorOr -import net.corda.core.bufferUntilSubscribed -import net.corda.core.messaging.RPCOps -import net.corda.core.messaging.RPCReturnsObservables -import net.corda.core.random63BitValue -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize -import net.corda.core.utilities.debug -import net.corda.nodeapi.* -import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException -import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID -import org.apache.activemq.artemis.api.core.SimpleString -import org.apache.activemq.artemis.api.core.client.ClientConsumer -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.apache.activemq.artemis.api.core.client.ClientProducer -import org.apache.activemq.artemis.api.core.client.ClientSession -import rx.Observable -import rx.subjects.PublishSubject -import java.io.Closeable -import java.lang.ref.WeakReference -import java.lang.reflect.InvocationHandler -import java.lang.reflect.Method -import java.lang.reflect.Proxy -import java.time.Duration -import java.util.* -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.locks.ReentrantLock -import javax.annotation.concurrent.GuardedBy -import javax.annotation.concurrent.ThreadSafe -import kotlin.concurrent.withLock -import kotlin.reflect.jvm.javaMethod - -/** - * Core RPC engine implementation, to learn how to use RPC you should be looking at [CordaRPCClient]. - * - * # Design notes - * - * The way RPCs are handled is fairly standard except for the handling of observables. When an RPC might return - * an [Observable] it is specially tagged. This causes the client to create a new transient queue for the - * receiving of observables and their observations with a random ID in the name. This ID is sent to the server in - * a message header. All observations are sent via this single queue. - * - * The reason for doing it this way and not the more obvious approach of one-queue-per-observable is that we want - * the queues to be *transient*, meaning their lifetime in the broker is tied to the session that created them. - * A server side observable and its associated queue is not a cost-free thing, let alone the memory and resources - * needed to actually generate the observations themselves, therefore we want to ensure these cannot leak. A - * transient queue will be deleted automatically if the client session terminates, which by default happens on - * disconnect but can also be configured to happen after a short delay (this allows clients to e.g. switch IP - * address). On the server the deletion of the observations queue triggers unsubscription from the associated - * observables, which in turn may then be garbage collected. - * - * Creating a transient queue requires a roundtrip to the broker and thus doing an RPC that could return - * observables takes two server roundtrips instead of one. That's why we require RPCs to be marked with - * [RPCReturnsObservables] as needing this special treatment instead of always doing it. - * - * If the Artemis/JMS APIs allowed us to create transient queues assigned to someone else then we could - * potentially use a different design in which the node creates new transient queues (one per observable) on the - * fly. The client would then have to watch out for this and start consuming those queues as they were created. - * - * We use one queue per RPC because we don't know ahead of time how many observables the server might return and - * often the server doesn't know either, which pushes towards a single queue design, but at the same time the - * processing of observations returned by an RPC might be striped across multiple threads and we'd like - * backpressure management to not be scoped per client process but with more granularity. So we end up with - * a compromise where the unit of backpressure management is the response to a single RPC. - * - * TODO: Backpressure isn't propagated all the way through the MQ broker at the moment. - */ -class CordaRPCClientImpl(private val session: ClientSession, - private val sessionLock: ReentrantLock, - private val username: String) { - companion object { - private val closeableCloseMethod = Closeable::close.javaMethod - private val autocloseableCloseMethod = AutoCloseable::close.javaMethod - } - - /** - * Builds a proxy for the given type, which must descend from [RPCOps]. - * - * @see CordaRPCClient.proxy for more information about how to use the proxies. - */ - fun proxyFor(rpcInterface: Class, timeout: Duration? = null, minVersion: Int = 0): T { - sessionLock.withLock { - if (producer == null) - producer = session.createProducer() - } - val proxyImpl = RPCProxyHandler(timeout) - @Suppress("UNCHECKED_CAST") - val proxy = Proxy.newProxyInstance(rpcInterface.classLoader, arrayOf(rpcInterface, Closeable::class.java), proxyImpl) as T - proxyImpl.serverProtocolVersion = proxy.protocolVersion - if (minVersion > proxyImpl.serverProtocolVersion) - throw RPCException("Requested minimum protocol version $minVersion is higher than the server's supported protocol version (${proxyImpl.serverProtocolVersion})") - return proxy - } - - ////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // - //region RPC engine - // - // You can find docs on all this in the api doc for the proxyFor method, and in the docsite. - - // Utility to quickly suck out the contents of an Artemis message. There's probably a more efficient way to - // do this. - private fun ClientMessage.deserialize(kryo: Kryo): T = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }.deserialize(kryo) - - // We by default use a weak reference so GC can happen, otherwise they persist for the life of the client. - @GuardedBy("sessionLock") - private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build() - // This is used to hold a reference counted hard reference when we know there are subscribers. - private val hardReferencesToQueuedObservables = Collections.synchronizedSet(mutableSetOf()) - - private var producer: ClientProducer? = null - - class ObservableDeserializer : Serializer>() { - override fun read(kryo: Kryo, input: Input, type: Class>): Observable { - val qName = kryo.context[RPCKryoQNameKey] as String - val rpcName = kryo.context[RPCKryoMethodNameKey] as String - val rpcLocation = kryo.context[RPCKryoLocationKey] as Throwable - val rpcClient = kryo.context[RPCKryoClientKey] as CordaRPCClientImpl - val handle = input.readInt(true) - val ob = rpcClient.sessionLock.withLock { - rpcClient.addressToQueuedObservables.getIfPresent(qName) ?: rpcClient.QueuedObservable(qName, rpcName, rpcLocation).apply { - rpcClient.addressToQueuedObservables.put(qName, this) - } - } - val result = ob.getForHandle(handle) - rpcLog.debug { "Deserializing and connecting a new observable for $rpcName on $qName: $result" } - return result - } - - override fun write(kryo: Kryo, output: Output, `object`: Observable) { - throw UnsupportedOperationException("not implemented") - } - } - - /** - * The proxy class returned to the client is auto-generated on the fly by the java.lang.reflect Proxy - * infrastructure. The JDK Proxy class writes bytecode into memory for a class that implements the requested - * interfaces and then routes all method calls to the invoke method below in a conveniently reified form. - * We can then easily take the data about the method call and turn it into an RPC. This avoids the need - * for the compile-time code generation which is so common in RPC systems. - */ - @ThreadSafe - private inner class RPCProxyHandler(private val timeout: Duration?) : InvocationHandler, Closeable { - private val proxyId = random63BitValue() - private val consumer: ClientConsumer - - var serverProtocolVersion = 0 - - init { - val proxyAddress = constructAddress(proxyId) - consumer = sessionLock.withLock { - session.createTemporaryQueue(proxyAddress, proxyAddress) - session.createConsumer(proxyAddress) - } - } - - private fun constructAddress(addressId: Long) = "${ArtemisMessagingComponent.CLIENTS_PREFIX}$username.rpc.$addressId" - - @Synchronized - override fun invoke(proxy: Any, method: Method, args: Array?): Any? { - if (isCloseInvocation(method)) { - close() - return null - } - if (method.name == "toString" && args == null) - return "Client RPC proxy" - - if (consumer.isClosed) - throw RPCException("RPC Proxy is closed") - - // All invoked methods on the proxy end up here. - val location = Throwable() - rpcLog.debug { - val argStr = args?.joinToString() ?: "" - "-> RPC -> ${method.name}($argStr): ${method.returnType}" - } - - checkMethodVersion(method) - - val msg: ClientMessage = createMessage(method) - // We could of course also check the return type of the method to see if it's Observable, but I'd - // rather haved the annotation be used consistently. - val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java) - val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else createRPCKryoForDeserialization(this@CordaRPCClientImpl) - val next: ErrorOr<*> = try { - sendRequest(args, msg) - receiveResponse(kryo, method, timeout) - } finally { - releaseRPCKryoForDeserialization(kryo) - } - rpcLog.debug { "<- RPC <- ${method.name} = $next" } - return unwrapOrThrow(next) - } - - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - private fun unwrapOrThrow(next: ErrorOr<*>): Any? { - val ex = next.error - if (ex != null) { - // Replace the stack trace because that's an implementation detail of the server that isn't so - // helpful to the user who wants to see where the error was on their side, and serialising stack - // frame objects is a bit annoying. We slice it here to avoid the invoke() machinery being exposed. - // The resulting exception looks like it was thrown from inside the called method. - (ex as java.lang.Throwable).stackTrace = java.lang.Throwable().stackTrace.let { it.sliceArray(1..it.size - 1) } - throw ex - } else { - return next.value - } - } - - private fun receiveResponse(kryo: Kryo, method: Method, timeout: Duration?): ErrorOr<*> { - val artemisMessage: ClientMessage = - if (timeout == null) - consumer.receive() ?: throw ActiveMQObjectClosedException() - else - consumer.receive(timeout.toMillis()) ?: throw RPCException.DeadlineExceeded(method.name) - artemisMessage.acknowledge() - val next = artemisMessage.deserialize>(kryo) - return next - } - - private fun sendRequest(args: Array?, msg: ClientMessage) { - sessionLock.withLock { - val argsKryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl) - val serializedArgs = try { - (args ?: emptyArray()).serialize(argsKryo) - } catch (e: KryoException) { - throw RPCException("Could not serialize RPC arguments", e) - } finally { - releaseRPCKryoForDeserialization(argsKryo) - } - msg.writeBodyBufferBytes(serializedArgs.bytes) - producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg) - } - } - - private fun maybePrepareForObservables(location: Throwable, method: Method, msg: ClientMessage): Kryo { - // Create a temporary queue just for the emissions on any observables that are returned. - val observationsId = random63BitValue() - val observationsQueueName = constructAddress(observationsId) - session.createTemporaryQueue(observationsQueueName, observationsQueueName) - msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId) - // And make sure that we deserialise observable handles so that they're linked to the right - // queue. Also record a bit of metadata for debugging purposes. - return createRPCKryoForDeserialization(this@CordaRPCClientImpl, observationsQueueName, method.name, location) - } - - private fun createMessage(method: Method): ClientMessage { - return session.createMessage(false).apply { - putStringProperty(ClientRPCRequestMessage.METHOD_NAME, method.name) - putLongProperty(ClientRPCRequestMessage.REPLY_TO, proxyId) - // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) - } - } - - private fun checkMethodVersion(method: Method) { - val methodVersion = method.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0 - if (methodVersion > serverProtocolVersion) - throw UnsupportedOperationException("Method ${method.name} was added in RPC protocol version $methodVersion but the server is running $serverProtocolVersion") - } - - private fun isCloseInvocation(method: Method) = method == closeableCloseMethod || method == autocloseableCloseMethod - - override fun close() { - consumer.close() - sessionLock.withLock { session.deleteQueue(constructAddress(proxyId)) } - } - - override fun toString() = "Corda RPC Proxy listening on queue ${constructAddress(proxyId)}" - } - - /** - * When subscribed to, starts consuming from the given queue name and demultiplexing the observables being - * sent to it. The server queue is moved into in-memory buffers (one per attached server-side observable) - * until drained through a subscription. When the subscriptions are all gone, the server-side queue is deleted. - */ - @ThreadSafe - private inner class QueuedObservable(private val qName: String, - private val rpcName: String, - private val rpcLocation: Throwable) { - private val root = PublishSubject.create() - private val rootShared = root.doOnUnsubscribe { close() }.share() - - // This could be made more efficient by using a specialised IntMap - // When handling this map we don't synchronise on [this], otherwise there is a race condition between close() and deliver() - private val observables = Collections.synchronizedMap(HashMap>()) - - @GuardedBy("sessionLock") - private var consumer: ClientConsumer? = null - - private val referenceCount = AtomicInteger(0) - - // We have to create a weak reference, otherwise we cannot be GC'd. - init { - val weakThis = WeakReference(this) - consumer = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { weakThis.get()?.deliver(it) } - } - - /** - * We have to reference count subscriptions to the returned [Observable]s to prevent early GC because we are - * weak referenced. - * - * Derived [Observables] (e.g. filtered etc) hold a strong reference to the original, but for example, if - * the pattern as follows is used, the original passes out of scope and the direction of reference is from the - * original to the [Observer]. We use the reference counting to allow for this pattern. - * - * val observationsSubject = PublishSubject.create() - * originalObservable.subscribe(observationsSubject) - * return observationsSubject - */ - private fun refCountUp() { - if (referenceCount.andIncrement == 0) { - hardReferencesToQueuedObservables.add(this) - } - } - - private fun refCountDown() { - if (referenceCount.decrementAndGet() == 0) { - hardReferencesToQueuedObservables.remove(this) - } - } - - fun getForHandle(handle: Int): Observable { - synchronized(observables) { - return observables.getOrPut(handle) { - /** - * Note that the order of bufferUntilSubscribed() -> dematerialize() is very important here. - * - * In particular doing it the other way around may result in the following edge case: - * The RPC returns two (or more) Observables. The first Observable unsubscribes *during serialisation*, - * before the second one is hit, causing the [rootShared] to unsubscribe and consequently closing - * the underlying artemis queue, even though the second Observable was not even registered. - * - * The buffer -> dematerialize order ensures that the Observable may not unsubscribe until the caller - * subscribes, which must be after full deserialisation and registering of all top level Observables. - * - * In addition, when subscribe and unsubscribe is called on the [Observable] returned here, we - * reference count a hard reference to this [QueuedObservable] to prevent premature GC. - */ - rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize().doOnSubscribe { refCountUp() }.doOnUnsubscribe { refCountDown() }.share() - } - } - } - - private fun deliver(msg: ClientMessage) { - sessionLock.withLock { msg.acknowledge() } - val kryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl, qName, rpcName, rpcLocation) - val received: MarshalledObservation = try { - msg.deserialize(kryo) - } finally { - releaseRPCKryoForDeserialization(kryo) - } - rpcLog.debug { "<- Observable [$rpcName] <- Received $received" } - synchronized(observables) { - // Force creation of the buffer if it doesn't already exist. - getForHandle(received.forHandle) - root.onNext(received) - } - } - - fun close() { - sessionLock.withLock { - if (consumer != null) { - rpcLog.debug("Closing queue observable for call to $rpcName : $qName") - consumer?.close() - consumer = null - session.deleteQueue(qName) - } - } - } - - @Suppress("UNUSED") - fun finalize() { - val closed = sessionLock.withLock { - if (consumer != null) { - consumer!!.close() - consumer = null - true - } else - false - } - if (closed) { - rpcLog.warn("""A hot observable returned from an RPC ($rpcName) was never subscribed to. - This wastes server-side resources because it was queueing observations for retrieval. - It is being closed now, but please adjust your code to call .notUsed() on the observable - to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning - will appear less frequently in future versions of the platform and you can ignore it - if you want to. - """.trimIndent().replace('\n', ' '), rpcLocation) - } - } - } - //endregion -} - -private val rpcDesKryoPool = KryoPool.Builder { RPCKryo(CordaRPCClientImpl.ObservableDeserializer()) }.build() - -fun createRPCKryoForDeserialization(rpcClient: CordaRPCClientImpl, qName: String? = null, rpcName: String? = null, rpcLocation: Throwable? = null): Kryo { - val kryo = rpcDesKryoPool.borrow() - kryo.context.put(RPCKryoClientKey, rpcClient) - kryo.context.put(RPCKryoQNameKey, qName) - kryo.context.put(RPCKryoMethodNameKey, rpcName) - kryo.context.put(RPCKryoLocationKey, rpcLocation) - return kryo -} - -fun releaseRPCKryoForDeserialization(kryo: Kryo) { - rpcDesKryoPool.release(kryo) -} diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt new file mode 100644 index 0000000000..60c77c676e --- /dev/null +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -0,0 +1,169 @@ +package net.corda.client.rpc.internal + +import com.google.common.net.HostAndPort +import net.corda.core.logElapsedTime +import net.corda.core.messaging.RPCOps +import net.corda.core.minutes +import net.corda.core.random63BitValue +import net.corda.core.seconds +import net.corda.core.utilities.loggerFor +import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport +import net.corda.nodeapi.ConnectionDirection +import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.RPCException +import net.corda.nodeapi.config.SSLConfiguration +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.TransportConfiguration +import org.apache.activemq.artemis.api.core.client.ActiveMQClient +import java.io.Closeable +import java.lang.reflect.Proxy +import java.time.Duration + +/** + * This configuration may be used to tweak the internals of the RPC client. + */ +data class RPCClientConfiguration( + /** The minimum protocol version required from the server */ + val minimumServerProtocolVersion: Int, + /** + * If set to true the client will track RPC call sites. If an error occurs subsequently during the RPC or in a + * returned Observable stream the stack trace of the originating RPC will be shown as well. Note that + * constructing call stacks is a moderately expensive operation. + */ + val trackRpcCallSites: Boolean, + /** + * The interval of unused observable reaping in milliseconds. Leaked Observables (unused ones) are + * detected using weak references and are cleaned up in batches in this interval. If set too large it will waste + * server side resources for this duration. If set too low it wastes client side cycles. + */ + val reapIntervalMs: Long, + /** The number of threads to use for observations (for executing [Observable.onNext]) */ + val observationExecutorPoolSize: Int, + /** The maximum number of producers to create to handle outgoing messages */ + val producerPoolBound: Int, + /** + * Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines + * the limit on the number of leaked observables reaped because of garbage collection per reaping. + * See the implementation of [com.google.common.cache.LocalCache] for details. + */ + val cacheConcurrencyLevel: Int, + /** The retry interval of artemis connections in milliseconds */ + val connectionRetryInterval: Duration, + /** The retry interval multiplier for exponential backoff */ + val connectionRetryIntervalMultiplier: Double, + /** Maximum retry interval */ + val connectionMaxRetryInterval: Duration, + /** Maximum file size */ + val maxFileSize: Int +) { + companion object { + @JvmStatic + val default = RPCClientConfiguration( + minimumServerProtocolVersion = 0, + trackRpcCallSites = false, + reapIntervalMs = 1000, + observationExecutorPoolSize = 4, + producerPoolBound = 1, + cacheConcurrencyLevel = 8, + connectionRetryInterval = 5.seconds, + connectionRetryIntervalMultiplier = 1.5, + connectionMaxRetryInterval = 3.minutes, + /** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */ + maxFileSize = 10485760 + ) + } +} + +/** + * An RPC client that may be used to create connections to an RPC server. + * + * @param transport The Artemis transport to use to connect to the server. + * @param rpcConfiguration Configuration used to tweak client behaviour. + */ +class RPCClient( + val transport: TransportConfiguration, + val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default +) { + constructor( + hostAndPort: HostAndPort, + sslConfiguration: SSLConfiguration? = null, + configuration: RPCClientConfiguration = RPCClientConfiguration.default + ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration) + + companion object { + private val log = loggerFor>() + } + + /** + * Holds a proxy object implementing [I] that forwards requests to the RPC server. + * + * [Closeable.close] may be used to shut down the connection and release associated resources. + */ + interface RPCConnection : Closeable { + val proxy: I + /** The RPC protocol version reported by the server */ + val serverProtocolVersion: Int + } + + /** + * Returns an [RPCConnection] containing a proxy that lets you invoke RPCs on the server. Calls on it block, and if + * the server throws an exception then it will be rethrown on the client. Proxies are thread safe and may be used to + * invoke multiple RPCs in parallel. + * + * RPC sends and receives are logged on the net.corda.rpc logger. + * + * The [RPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the object + * graph returned then the server-side observable is transparently forwarded to the client side here. + * *You are expected to use it*. The server will begin buffering messages immediately that it will expect you to + * drain by subscribing to the returned observer. You can opt-out of this by simply calling the + * [net.corda.client.rpc.notUsed] method on it. You don't have to explicitly close the observable if you actually + * subscribe to it: it will close itself and free up the server-side resources either when the client or JVM itself + * is shutdown, or when there are no more subscribers to it. Once all the subscribers to a returned observable are + * unsubscribed or the observable completes successfully or with an error, the observable is closed and you can't + * then re-subscribe again: you'll have to re-request a fresh observable with another RPC. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + * @throws RPCException if the server version is too low or if the server isn't reachable within the given time. + */ + fun start( + rpcOpsClass: Class, + username: String, + password: String + ): RPCConnection { + return log.logElapsedTime("Startup") { + val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}") + + val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply { + retryInterval = rpcConfiguration.connectionRetryInterval.toMillis() + retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier + maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis() + minLargeMessageSize = rpcConfiguration.maxFileSize + } + + val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass) + proxyHandler.start() + + @Suppress("UNCHECKED_CAST") + val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I + + val serverProtocolVersion = ops.protocolVersion + if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) { + throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" + + " than the server's supported protocol version ($serverProtocolVersion)") + } + proxyHandler.setServerProtocolVersion(serverProtocolVersion) + + log.debug("RPC connected, returning proxy") + object : RPCConnection { + override val proxy = ops + override val serverProtocolVersion = serverProtocolVersion + override fun close() { + proxyHandler.close() + serverLocator.close() + } + } + } + } +} diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt new file mode 100644 index 0000000000..6309af41f8 --- /dev/null +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -0,0 +1,423 @@ +package net.corda.client.rpc.internal + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import com.google.common.cache.Cache +import com.google.common.cache.CacheBuilder +import com.google.common.cache.RemovalCause +import com.google.common.cache.RemovalListener +import com.google.common.util.concurrent.SettableFuture +import com.google.common.util.concurrent.ThreadFactoryBuilder +import net.corda.core.ThreadBox +import net.corda.core.getOrThrow +import net.corda.core.messaging.RPCOps +import net.corda.core.random63BitValue +import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.utilities.* +import net.corda.nodeapi.* +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE +import org.apache.activemq.artemis.api.core.client.ClientMessage +import org.apache.activemq.artemis.api.core.client.ServerLocator +import rx.Notification +import rx.Observable +import rx.subjects.UnicastSubject +import sun.reflect.CallerSensitive +import java.lang.reflect.InvocationHandler +import java.lang.reflect.Method +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger +import kotlin.collections.ArrayList +import kotlin.reflect.jvm.javaMethod + +/** + * This class provides a proxy implementation of an RPC interface for RPC clients. It translates API calls to lower-level + * RPC protocol messages. For this protocol see [RPCApi]. + * + * When a method is called on the interface the arguments are serialised and the request is forwarded to the server. The + * server then executes the code that implements the RPC and sends a reply. + * + * An RPC reply may contain [Observable]s, which are serialised simply as unique IDs. On the client side we create a + * [UnicastSubject] for each such ID. Subsequently the server may send observations attached to this ID, which are + * forwarded to the [UnicastSubject]. Note that the observations themselves may contain further [Observable]s, which are + * handled in the same way. + * + * To do the above we take advantage of Kryo's datastructure traversal. When the client is deserialising a message from + * the server that may contain Observables it is supplied with an [ObservableContext] that exposes the map used to demux + * the observations. When an [Observable] is encountered during traversal a new [UnicastSubject] is added to the map and + * we carry on. Each observation later contains the corresponding Observable ID, and we just forward that to the + * associated [UnicastSubject]. + * + * The client may signal that it no longer consumes a particular [Observable]. This may be done explicitly by + * unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually + * automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s. + * The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor]. + */ +class RPCClientProxyHandler( + private val rpcConfiguration: RPCClientConfiguration, + private val rpcUsername: String, + private val rpcPassword: String, + private val serverLocator: ServerLocator, + private val clientAddress: SimpleString, + private val rpcOpsClass: Class +) : InvocationHandler { + + private enum class State { + UNSTARTED, + SERVER_VERSION_NOT_SET, + STARTED, + FINISHED + } + private val lifeCycle = LifeCycle(State.UNSTARTED) + + private companion object { + val log = loggerFor() + // Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context + // to do that. However it may still be used for serialisation of RPC requests and related messages. + val kryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build() + // To check whether toString() is being invoked + val toStringMethod: Method = Object::toString.javaMethod!! + } + + // Used for reaping + private val reaperExecutor = Executors.newScheduledThreadPool( + 1, + ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build() + ) + + // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering. + private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build() + private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) { + Executors.newFixedThreadPool(1, observationExecutorThreadFactory) + } + + // Holds the RPC reply futures. + private val rpcReplyMap = RpcReplyMap() + // Optionally holds RPC call site stack traces to be shown on errors/warnings. + private val callSiteMap = if (rpcConfiguration.trackRpcCallSites) CallSiteMap() else null + // Holds the Observables and a reference store to keep Observables alive when subscribed to. + private val observableContext = ObservableContext( + callSiteMap = callSiteMap, + observableMap = createRpcObservableMap(), + hardReferenceStore = Collections.synchronizedSet(mutableSetOf>()) + ) + // Holds a reference to the scheduled reaper. + private lateinit var reaperScheduledFuture: ScheduledFuture<*> + // The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion] + private var serverProtocolVersion: Int? = null + + // Stores the Observable IDs that are already removed from the map but are not yet sent to the server. + private val observablesToReap = ThreadBox(object { + var observables = ArrayList() + }) + // A Kryo pool that automatically adds the observable context when an instance is requested. + private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext) + + private fun createRpcObservableMap(): RpcObservableMap { + val onObservableRemove = RemovalListener>> { + val rpcCallSite = callSiteMap?.remove(it.key.toLong) + if (it.cause == RemovalCause.COLLECTED) { + log.warn(listOf( + "A hot observable returned from an RPC was never subscribed to.", + "This wastes server-side resources because it was queueing observations for retrieval.", + "It is being closed now, but please adjust your code to call .notUsed() on the observable", + "to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning", + "will appear less frequently in future versions of the platform and you can ignore it", + "if you want to.").joinToString(" "), rpcCallSite) + } + observablesToReap.locked { observables.add(it.key) } + } + return CacheBuilder.newBuilder(). + weakValues(). + removalListener(onObservableRemove). + concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel). + build() + } + + // We cannot pool consumers as we need to preserve the original muxed message order. + // TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a + // single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is + // integrated properly. + private lateinit var sessionAndConsumer: ArtemisConsumer + // Pool producers to reduce contention on the client side. + private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) { + // Note how we create new sessions *and* session factories per producer. + // We cannot simply pool producers on one session because sessions are single threaded. + // We cannot simply pool sessions on one session factory because flow control credits are tied to factories, so + // sessions tend to starve each other when used concurrently. + val sessionFactory = serverLocator.createSessionFactory() + val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + session.start() + ArtemisProducer(sessionFactory, session, session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)) + } + + /** + * Start the client. This creates the per-client queue, starts the consumer session and the reaper. + */ + fun start() { + lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) + reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( + this::reapObservables, + rpcConfiguration.reapIntervalMs, + rpcConfiguration.reapIntervalMs, + TimeUnit.MILLISECONDS + ) + sessionAndProducerPool.run { + it.session.createTemporaryQueue(clientAddress, clientAddress) + } + val sessionFactory = serverLocator.createSessionFactory() + val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + val consumer = session.createConsumer(clientAddress) + consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler) + session.start() + sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer) + } + + // This is the general function that transforms a client side RPC to internal Artemis messages. + @CallerSensitive + override fun invoke(proxy: Any, method: Method, arguments: Array?): Any? { + lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET } + checkProtocolVersion(method) + if (method == toStringMethod) { + return "Client RPC proxy for $rpcOpsClass" + } + if (sessionAndConsumer.session.isClosed) { + throw RPCException("RPC Proxy is closed") + } + val rpcId = RPCApi.RpcRequestId(random63BitValue()) + callSiteMap?.set(rpcId.toLong, Throwable("")) + try { + val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, arguments?.toList() ?: emptyList()) + val replyFuture = SettableFuture.create() + sessionAndProducerPool.run { + val message = it.session.createMessage(false) + request.writeToClientMessage(kryoPool, message) + + log.debug { + val argumentsString = arguments?.joinToString() ?: "" + "-> RPC($rpcId) -> ${method.name}($argumentsString): ${method.returnType}" + } + + require(rpcReplyMap.put(rpcId, replyFuture) == null) { + "Generated several RPC requests with same ID $rpcId" + } + it.producer.send(message) + it.session.commit() + } + return replyFuture.getOrThrow() + } finally { + callSiteMap?.remove(rpcId.toLong) + } + } + + // The handler for Artemis messages. + private fun artemisMessageHandler(message: ClientMessage) { + val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message) + log.debug { "Got message from RPC server $serverToClient" } + when (serverToClient) { + is RPCApi.ServerToClient.RpcReply -> { + val replyFuture = rpcReplyMap.remove(serverToClient.id) + if (replyFuture == null) { + log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.") + } else { + val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong) + serverToClient.result.match( + onError = { + if (rpcCallSite != null) addRpcCallSiteToThrowable(it, rpcCallSite) + replyFuture.setException(it) + }, + onValue = { replyFuture.set(it) } + ) + } + } + is RPCApi.ServerToClient.Observation -> { + val observable = observableContext.observableMap.getIfPresent(serverToClient.id) + if (observable == null) { + log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " + + "This may be due to an observation arriving before the server was " + + "notified of observable shutdown") + } else { + // We schedule the onNext() on an executor sticky-pooled based on the Observable ID. + observationExecutorPool.run(serverToClient.id) { executor -> + executor.submit { + val content = serverToClient.content + if (content.isOnCompleted || content.isOnError) { + observableContext.observableMap.invalidate(serverToClient.id) + } + // Add call site information on error + if (content.isOnError) { + val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong) + if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite) + } + observable.onNext(content) + } + } + } + } + } + message.acknowledge() + } + + /** + * Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors. + */ + fun close() { + lifeCycle.transition(State.STARTED, State.FINISHED) + sessionAndConsumer.consumer.close() + sessionAndConsumer.session.close() + sessionAndConsumer.sessionFactory.close() + reaperScheduledFuture.cancel(false) + observableContext.observableMap.invalidateAll() + reapObservables() + reaperExecutor.shutdownNow() + sessionAndProducerPool.close().forEach { + it.producer.close() + it.session.close() + it.sessionFactory.close() + } + // Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may + // leak borrowed executors. + val observationExecutors = observationExecutorPool.close() + observationExecutors.forEach { it.shutdownNow() } + observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) } + } + + /** + * Check the [RPCSinceVersion] of the passed in [calledMethod] against the server's protocol version. + */ + private fun checkProtocolVersion(calledMethod: Method) { + val serverProtocolVersion = serverProtocolVersion + if (serverProtocolVersion == null) { + lifeCycle.requireState(State.SERVER_VERSION_NOT_SET) + } else { + lifeCycle.requireState(State.STARTED) + val sinceVersion = calledMethod.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0 + if (sinceVersion > serverProtocolVersion) { + throw UnsupportedOperationException("Method $calledMethod was added in RPC protocol version $sinceVersion but the server is running $serverProtocolVersion") + } + } + } + + /** + * Set the server's protocol version. Note that before doing so the client is not considered fully started, although + * RPCs already may be called with it. + */ + internal fun setServerProtocolVersion(version: Int) { + lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED) + if (serverProtocolVersion == null) { + serverProtocolVersion = version + } else { + throw IllegalStateException("setServerProtocolVersion called, but the protocol version was already set!") + } + } + + private fun reapObservables() { + observableContext.observableMap.cleanUp() + val observableIds = observablesToReap.locked { + if (observables.isNotEmpty()) { + val temporary = observables + observables = ArrayList() + temporary + } else { + null + } + } + if (observableIds != null) { + log.debug { "Reaping ${observableIds.size} observables" } + sessionAndProducerPool.run { + val message = it.session.createMessage(false) + RPCApi.ClientToServer.ObservablesClosed(observableIds).writeToClientMessage(message) + it.producer.send(message) + } + } + } +} + +private typealias RpcObservableMap = Cache>> +private typealias RpcReplyMap = ConcurrentHashMap> +private typealias CallSiteMap = ConcurrentHashMap + +/** + * Holds a context available during Kryo deserialisation of messages that are expected to contain Observables. + * + * @param observableMap holds the Observables that are ultimately exposed to the user. + * @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to. + */ +private data class ObservableContext( + val callSiteMap: CallSiteMap?, + val observableMap: RpcObservableMap, + val hardReferenceStore: MutableSet> +) + +/** + * A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext]. + */ +private object RpcClientObservableSerializer : Serializer>() { + private object RpcObservableContextKey + fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { + return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + } + + override fun read(kryo: Kryo, input: Input, type: Class>): Observable { + @Suppress("UNCHECKED_CAST") + val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext + val observableId = RPCApi.ObservableId(input.readLong(true)) + val observable = UnicastSubject.create>() + require(observableContext.observableMap.getIfPresent(observableId) == null) { + "Multiple Observables arrived with the same ID $observableId" + } + val rpcCallSite = getRpcCallSite(kryo, observableContext) + observableContext.observableMap.put(observableId, observable) + observableContext.callSiteMap?.put(observableId.toLong, rpcCallSite) + // We pin all Observables into a hard reference store (rooted in the RPC proxy) on subscription so that users + // don't need to store a reference to the Observables themselves. + return observable.pinInSubscriptions(observableContext.hardReferenceStore).doOnUnsubscribe { + // This causes Future completions to give warnings because the corresponding OnComplete sent from the server + // will arrive after the client unsubscribes from the observable and consequently invalidates the mapping. + // The unsubscribe is due to [ObservableToFuture]'s use of first(). + observableContext.observableMap.invalidate(observableId) + }.dematerialize() + } + + override fun write(kryo: Kryo, output: Output, observable: Observable) { + throw UnsupportedOperationException("Cannot serialise Observables on the client side") + } + + private fun getRpcCallSite(kryo: Kryo, observableContext: ObservableContext): Throwable? { + val rpcRequestOrObservableId = kryo.context[RPCApi.RpcRequestOrObservableIdKey] as Long + return observableContext.callSiteMap?.get(rpcRequestOrObservableId) + } +} + +private fun addRpcCallSiteToThrowable(throwable: Throwable, callSite: Throwable) { + var currentThrowable = throwable + while (true) { + val cause = currentThrowable.cause + if (cause == null) { + currentThrowable.initCause(callSite) + break + } else { + currentThrowable = cause + } + } +} + +private fun Observable.pinInSubscriptions(hardReferenceStore: MutableSet>): Observable { + val refCount = AtomicInteger(0) + return this.doOnSubscribe { + if (refCount.getAndIncrement() == 0) { + require(hardReferenceStore.add(this)) { "Reference store already contained reference $this on add" } + } + }.doOnUnsubscribe { + if (refCount.decrementAndGet() == 0) { + require(hardReferenceStore.remove(this)) { "Reference store did not contain reference $this on remove" } + } + } +} diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractClientRPCTest.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractClientRPCTest.kt deleted file mode 100644 index ae3420686e..0000000000 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractClientRPCTest.kt +++ /dev/null @@ -1,102 +0,0 @@ -package net.corda.client.rpc - -import net.corda.core.messaging.RPCOps -import net.corda.core.serialization.SerializedBytes -import net.corda.core.utilities.ALICE -import net.corda.core.utilities.LogHelper -import net.corda.node.services.RPCUserService -import net.corda.node.services.messaging.RPCDispatcher -import net.corda.node.utilities.AffinityExecutor -import net.corda.nodeapi.ArtemisMessagingComponent -import net.corda.nodeapi.User -import org.apache.activemq.artemis.api.core.Message -import org.apache.activemq.artemis.api.core.SimpleString -import org.apache.activemq.artemis.api.core.TransportConfiguration -import org.apache.activemq.artemis.api.core.client.ActiveMQClient -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.apache.activemq.artemis.api.core.client.ClientProducer -import org.apache.activemq.artemis.api.core.client.ClientSession -import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl -import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory -import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory -import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ -import org.junit.After -import org.junit.Before -import java.util.* -import java.util.concurrent.locks.ReentrantLock - -abstract class AbstractClientRPCTest { - lateinit var artemis: EmbeddedActiveMQ - lateinit var serverSession: ClientSession - lateinit var clientSession: ClientSession - lateinit var producer: ClientProducer - lateinit var serverThread: AffinityExecutor.ServiceAffinityExecutor - - @Before - fun rpcSetup() { - // Set up an in-memory Artemis with an RPC requests queue. - artemis = EmbeddedActiveMQ() - artemis.setConfiguration(ConfigurationImpl().apply { - acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name)) - isSecurityEnabled = false - isPersistenceEnabled = false - }) - artemis.start() - - val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(TransportConfiguration(InVMConnectorFactory::class.java.name)) - val sessionFactory = serverLocator.createSessionFactory() - serverSession = sessionFactory.createSession() - serverSession.start() - - serverSession.createTemporaryQueue(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, ArtemisMessagingComponent.RPC_REQUESTS_QUEUE) - producer = serverSession.createProducer() - serverThread = AffinityExecutor.ServiceAffinityExecutor("unit-tests-rpc-dispatch-thread", 1) - serverSession.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 'BINDING_REMOVED'") - - clientSession = sessionFactory.createSession() - clientSession.start() - - LogHelper.setLevel("+net.corda.rpc") - } - - @After - fun rpcShutdown() { - safeClose(producer) - clientSession.stop() - serverSession.stop() - artemis.stop() - serverThread.shutdownNow() - } - - fun rpcProxyFor(rpcUser: User, rpcImpl: T, type: Class): T { - val userService = object : RPCUserService { - override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null - override val users: List get() = listOf(rpcUser) - } - - val dispatcher = object : RPCDispatcher(rpcImpl, userService, ALICE.name) { - override fun send(data: SerializedBytes<*>, toAddress: String) { - val msg = serverSession.createMessage(false).apply { - writeBodyBufferBytes(data.bytes) - // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) - } - producer.send(toAddress, msg) - } - - override fun getUser(message: ClientMessage): User = rpcUser - } - - val serverNotifConsumer = serverSession.createConsumer("rpc.qremovals") - val serverConsumer = serverSession.createConsumer(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE) - dispatcher.start(serverConsumer, serverNotifConsumer, serverThread) - return CordaRPCClientImpl(clientSession, ReentrantLock(), rpcUser.username).proxyFor(type) - } - - fun safeClose(obj: Any) { - try { - (obj as AutoCloseable).close() - } catch (e: Exception) { - } - } -} \ No newline at end of file diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt new file mode 100644 index 0000000000..6139ad79fb --- /dev/null +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt @@ -0,0 +1,56 @@ +package net.corda.client.rpc + +import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.core.flatMap +import net.corda.core.map +import net.corda.core.messaging.RPCOps +import net.corda.node.services.messaging.RPCServerConfiguration +import net.corda.nodeapi.User +import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.rpcTestUser +import net.corda.testing.startInVmRpcClient +import net.corda.testing.startRpcClient +import org.apache.activemq.artemis.api.core.client.ClientSession +import org.junit.runners.Parameterized + +open class AbstractRPCTest { + enum class RPCTestMode { + InVm, + Netty + } + + companion object { + @JvmStatic @Parameterized.Parameters(name = "Mode = {0}") + fun defaultModes() = modes(RPCTestMode.InVm, RPCTestMode.Netty) + fun modes(vararg modes: RPCTestMode) = listOf(*modes).map { arrayOf(it) } + } + @Parameterized.Parameter + lateinit var mode: RPCTestMode + + data class TestProxy( + val ops: I, + val createSession: () -> ClientSession + ) + + inline fun RPCDriverExposedDSLInterface.testProxy( + ops: I, + rpcUser: User = rpcTestUser, + clientConfiguration: RPCClientConfiguration = RPCClientConfiguration.default, + serverConfiguration: RPCServerConfiguration = RPCServerConfiguration.default + ): TestProxy { + return when (mode) { + RPCTestMode.InVm -> + startInVmRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { + startInVmRpcClient(rpcUser.username, rpcUser.password, clientConfiguration).map { + TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) }) + } + }.get() + RPCTestMode.Netty -> + startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server -> + startRpcClient(server.hostAndPort, rpcUser.username, rpcUser.password, clientConfiguration).map { + TestProxy(it, { startArtemisSession(server.hostAndPort, rpcUser.username, rpcUser.password) }) + } + }.get() + } + } +} diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt index acc2fb4872..c9d3c65879 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt @@ -5,16 +5,16 @@ import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import net.corda.core.getOrThrow import net.corda.core.messaging.RPCOps -import net.corda.core.messaging.RPCReturnsObservables import net.corda.core.success -import net.corda.nodeapi.CURRENT_RPC_USER +import net.corda.node.services.messaging.getRpcContext import net.corda.nodeapi.RPCSinceVersion -import net.corda.nodeapi.User -import org.apache.activemq.artemis.api.core.SimpleString +import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.rpcDriver +import net.corda.testing.rpcTestUser import org.assertj.core.api.Assertions.assertThat -import org.junit.After -import org.junit.Before import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized import rx.Observable import rx.subjects.PublishSubject import java.util.concurrent.CountDownLatch @@ -23,22 +23,11 @@ import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertTrue -class ClientRPCInfrastructureTests : AbstractClientRPCTest() { +@RunWith(Parameterized::class) +class ClientRPCInfrastructureTests : AbstractRPCTest() { // TODO: Test that timeouts work - lateinit var proxy: TestOps - - private val authenticatedUser = User("test", "password", permissions = setOf()) - - @Before - fun setup() { - proxy = rpcProxyFor(authenticatedUser, TestOpsImpl(), TestOps::class.java) - } - - @After - fun shutdown() { - safeClose(proxy) - } + private fun RPCDriverExposedDSLInterface.testProxy() = testProxy(TestOpsImpl()).ops interface TestOps : RPCOps { @Throws(IllegalArgumentException::class) @@ -48,16 +37,12 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() { fun someCalculation(str: String, num: Int): String - @RPCReturnsObservables fun makeObservable(): Observable - @RPCReturnsObservables fun makeComplicatedObservable(): Observable>> - @RPCReturnsObservables fun makeListenableFuture(): ListenableFuture - @RPCReturnsObservables fun makeComplicatedListenableFuture(): ListenableFuture>> @RPCSinceVersion(2) @@ -78,117 +63,130 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() { override fun makeListenableFuture(): ListenableFuture = Futures.immediateFuture(1) override fun makeComplicatedObservable() = complicatedObservable override fun makeComplicatedListenableFuture(): ListenableFuture>> = complicatedListenableFuturee - override fun addedLater(): Unit = throw UnsupportedOperationException("not implemented") - override fun captureUser(): String = CURRENT_RPC_USER.get().username + override fun addedLater(): Unit = throw IllegalStateException() + override fun captureUser(): String = getRpcContext().currentUser.username } @Test fun `simple RPCs`() { - // Does nothing, doesn't throw. - proxy.void() + rpcDriver { + val proxy = testProxy() + // Does nothing, doesn't throw. + proxy.void() - assertEquals("Barf!", assertFailsWith { - proxy.barf() - }.message) + assertEquals("Barf!", assertFailsWith { + proxy.barf() + }.message) - assertEquals("hi 5", proxy.someCalculation("hi", 5)) + assertEquals("hi 5", proxy.someCalculation("hi", 5)) + } } @Test fun `simple observable`() { - // This tests that the observations are transmitted correctly, also completion is transmitted. - val observations = proxy.makeObservable().toBlocking().toIterable().toList() - assertEquals(listOf(1, 2, 3, 4), observations) + rpcDriver { + val proxy = testProxy() + // This tests that the observations are transmitted correctly, also completion is transmitted. + val observations = proxy.makeObservable().toBlocking().toIterable().toList() + assertEquals(listOf(1, 2, 3, 4), observations) + } } @Test fun `complex observables`() { - // This checks that we can return an object graph with complex usage of observables, like an observable - // that emits objects that contain more observables. - val serverQuotes = PublishSubject.create>>() - val unsubscribeLatch = CountDownLatch(1) - complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() } + rpcDriver { + val proxy = testProxy() + // This checks that we can return an object graph with complex usage of observables, like an observable + // that emits objects that contain more observables. + val serverQuotes = PublishSubject.create>>() + val unsubscribeLatch = CountDownLatch(1) + complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() } - val twainQuotes = "Mark Twain" to Observable.just( - "I have never let my schooling interfere with my education.", - "Clothes make the man. Naked people have little or no influence on society." - ) - val wildeQuotes = "Oscar Wilde" to Observable.just( - "I can resist everything except temptation.", - "Always forgive your enemies - nothing annoys them so much." - ) + val twainQuotes = "Mark Twain" to Observable.just( + "I have never let my schooling interfere with my education.", + "Clothes make the man. Naked people have little or no influence on society." + ) + val wildeQuotes = "Oscar Wilde" to Observable.just( + "I can resist everything except temptation.", + "Always forgive your enemies - nothing annoys them so much." + ) - val clientQuotes = LinkedBlockingQueue() - val clientObs = proxy.makeComplicatedObservable() + val clientQuotes = LinkedBlockingQueue() + val clientObs = proxy.makeComplicatedObservable() - val subscription = clientObs.subscribe { - val name = it.first - it.second.subscribe { - clientQuotes += "Quote by $name: $it" + val subscription = clientObs.subscribe { + val name = it.first + it.second.subscribe { + clientQuotes += "Quote by $name: $it" + } } + + assertThat(clientQuotes).isEmpty() + + serverQuotes.onNext(twainQuotes) + assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take()) + assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take()) + + serverQuotes.onNext(wildeQuotes) + assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take()) + assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take()) + + assertTrue(serverQuotes.hasObservers()) + subscription.unsubscribe() + unsubscribeLatch.await() } - - val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*") - assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size) - - assertThat(clientQuotes).isEmpty() - - serverQuotes.onNext(twainQuotes) - assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take()) - assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take()) - - serverQuotes.onNext(wildeQuotes) - assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take()) - assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take()) - - assertTrue(serverQuotes.hasObservers()) - subscription.unsubscribe() - unsubscribeLatch.await() - assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size) } @Test fun `simple ListenableFuture`() { - val value = proxy.makeListenableFuture().getOrThrow() - assertThat(value).isEqualTo(1) + rpcDriver { + val proxy = testProxy() + val value = proxy.makeListenableFuture().getOrThrow() + assertThat(value).isEqualTo(1) + } } @Test fun `complex ListenableFuture`() { - val serverQuote = SettableFuture.create>>() - complicatedListenableFuturee = serverQuote + rpcDriver { + val proxy = testProxy() + val serverQuote = SettableFuture.create>>() + complicatedListenableFuturee = serverQuote - val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.") + val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.") - val clientQuotes = LinkedBlockingQueue() - val clientFuture = proxy.makeComplicatedListenableFuture() + val clientQuotes = LinkedBlockingQueue() + val clientFuture = proxy.makeComplicatedListenableFuture() - clientFuture.success { - val name = it.first - it.second.success { - clientQuotes += "Quote by $name: $it" + clientFuture.success { + val name = it.first + it.second.success { + clientQuotes += "Quote by $name: $it" + } } + + assertThat(clientQuotes).isEmpty() + + serverQuote.set(twainQuote) + assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.") + + // TODO This final assert sometimes fails because the relevant queue hasn't been removed yet } - - val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*") - assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size) - - assertThat(clientQuotes).isEmpty() - - serverQuote.set(twainQuote) - assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.") - - // TODO This final assert sometimes fails because the relevant queue hasn't been removed yet -// assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size) } @Test fun versioning() { - assertFailsWith { proxy.addedLater() } + rpcDriver { + val proxy = testProxy() + assertFailsWith { proxy.addedLater() } + } } @Test fun `authenticated user is available to RPC`() { - assertThat(proxy.captureUser()).isEqualTo(authenticatedUser.username) + rpcDriver { + val proxy = testProxy() + assertThat(proxy.captureUser()).isEqualTo(rpcTestUser.username) + } } } diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt new file mode 100644 index 0000000000..2e563bc40c --- /dev/null +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt @@ -0,0 +1,194 @@ +package net.corda.client.rpc + +import com.google.common.util.concurrent.Futures +import com.google.common.util.concurrent.ListenableFuture +import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.core.future +import net.corda.core.messaging.RPCOps +import net.corda.core.random63BitValue +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.loggerFor +import net.corda.node.driver.poll +import net.corda.node.services.messaging.RPCServerConfiguration +import net.corda.nodeapi.RPCApi +import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.rpcDriver +import net.corda.testing.startRandomRpcClient +import net.corda.testing.startRpcClient +import org.apache.activemq.artemis.api.core.SimpleString +import org.junit.Ignore +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import rx.Observable +import rx.subjects.PublishSubject +import rx.subjects.UnicastSubject +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger + +@RunWith(Parameterized::class) +class RPCConcurrencyTests : AbstractRPCTest() { + + /** + * Holds a "rose"-tree of [Observable]s which allows us to test arbitrary [Observable] nesting in RPC replies. + */ + @CordaSerializable + data class ObservableRose(val value: A, val branches: Observable>) + + private interface TestOps : RPCOps { + fun newLatch(numberOfDowns: Int): Long + fun waitLatch(id: Long) + fun downLatch(id: Long) + fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose + fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose + } + + class TestOpsImpl : TestOps { + private val latches = ConcurrentHashMap() + override val protocolVersion = 0 + + override fun newLatch(numberOfDowns: Int): Long { + val id = random63BitValue() + val latch = CountDownLatch(numberOfDowns) + latches.put(id, latch) + return id + } + + override fun waitLatch(id: Long) { + latches[id]!!.await() + } + + override fun downLatch(id: Long) { + latches[id]!!.countDown() + } + + override fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose { + val branches = if (depth == 0) { + Observable.empty>() + } else { + Observable.just(getImmediateObservableTree(depth - 1, branchingFactor)).repeat(branchingFactor.toLong()) + } + return ObservableRose(depth, branches) + } + + override fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose { + val branches = if (depth == 0) { + Observable.empty>() + } else { + val publish = UnicastSubject.create>() + future { + (1..branchingFactor).toList().parallelStream().forEach { + publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) + } + publish.onCompleted() + } + publish + } + return ObservableRose(depth, branches) + } + } + + private lateinit var testOpsImpl: TestOpsImpl + private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy { + testOpsImpl = TestOpsImpl() + return testProxy( + testOpsImpl, + clientConfiguration = RPCClientConfiguration.default.copy( + reapIntervalMs = 100, + cacheConcurrencyLevel = 16 + ), + serverConfiguration = RPCServerConfiguration.default.copy( + rpcThreadPoolSize = 4 + ) + ) + } + + @Test + fun `call multiple RPCs in parallel`() { + rpcDriver { + val proxy = testProxy() + val numberOfBlockedCalls = 2 + val numberOfDownsRequired = 100 + val id = proxy.ops.newLatch(numberOfDownsRequired) + val done = CountDownLatch(numberOfBlockedCalls) + // Start a couple of blocking RPC calls + (1..numberOfBlockedCalls).forEach { + future { + proxy.ops.waitLatch(id) + done.countDown() + } + } + // Down the latch that the others are waiting for concurrently + (1..numberOfDownsRequired).toList().parallelStream().forEach { + proxy.ops.downLatch(id) + } + done.await() + } + } + + private fun intPower(base: Int, power: Int): Int { + return when (power) { + 0 -> 1 + 1 -> base + else -> { + val a = intPower(base, power / 2) + if (power and 1 == 0) { + a * a + } else { + a * a * base + } + } + } + } + + @Test + fun `nested immediate observables sequence correctly`() { + rpcDriver { + // We construct a rose tree of immediate Observables and check that parent observations arrive before children. + val proxy = testProxy() + val treeDepth = 6 + val treeBranchingFactor = 3 + val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1)) + val depthsSeen = Collections.synchronizedSet(HashSet()) + fun ObservableRose.subscribeToAll() { + remainingLatch.countDown() + this.branches.subscribe { tree -> + (tree.value + 1..treeDepth - 1).forEach { + require(it in depthsSeen) { "Got ${tree.value} before $it" } + } + depthsSeen.add(tree.value) + tree.subscribeToAll() + } + } + proxy.ops.getImmediateObservableTree(treeDepth, treeBranchingFactor).subscribeToAll() + remainingLatch.await() + } + } + + @Test + fun `parallel nested observables`() { + rpcDriver { + val proxy = testProxy() + val treeDepth = 2 + val treeBranchingFactor = 10 + val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1)) + val depthsSeen = Collections.synchronizedSet(HashSet()) + fun ObservableRose.subscribeToAll() { + remainingLatch.countDown() + branches.subscribe { tree -> + (tree.value + 1..treeDepth - 1).forEach { + require(it in depthsSeen) { "Got ${tree.value} before $it" } + } + depthsSeen.add(tree.value) + tree.subscribeToAll() + } + } + proxy.ops.getParallelObservableTree(treeDepth, treeBranchingFactor).subscribeToAll() + remainingLatch.await() + } + } +} \ No newline at end of file diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt new file mode 100644 index 0000000000..fa804dd2c3 --- /dev/null +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt @@ -0,0 +1,315 @@ +package net.corda.client.rpc + +import com.codahale.metrics.Gauge +import com.codahale.metrics.JmxReporter +import com.codahale.metrics.MetricRegistry +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import com.google.common.base.Stopwatch +import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.core.messaging.RPCOps +import net.corda.core.millis +import net.corda.core.random63BitValue +import net.corda.node.driver.ShutdownManager +import net.corda.node.services.messaging.RPCServerConfiguration +import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.RPCKryo +import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.measure +import net.corda.testing.rpcDriver +import org.apache.activemq.artemis.api.core.SimpleString +import org.junit.Ignore +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import rx.Observable +import java.time.Duration +import java.util.* +import java.util.concurrent.* +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.ReentrantLock +import javax.management.ObjectName +import kotlin.concurrent.thread +import kotlin.concurrent.withLock + +@Ignore("Only use this locally for profiling") +@RunWith(Parameterized::class) +class RPCPerformanceTests : AbstractRPCTest() { + companion object { + @JvmStatic @Parameterized.Parameters(name = "Mode = {0}") + fun modes() = modes(RPCTestMode.Netty) + } + private interface TestOps : RPCOps { + fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray + } + + class TestOpsImpl : TestOps { + override val protocolVersion = 0 + override fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray { + return ByteArray(sizeOfReply) + } + } + + private fun RPCDriverExposedDSLInterface.testProxy( + clientConfiguration: RPCClientConfiguration, + serverConfiguration: RPCServerConfiguration + ): TestProxy { + return testProxy( + TestOpsImpl(), + clientConfiguration = clientConfiguration, + serverConfiguration = serverConfiguration + ) + } + + private fun warmup() { + rpcDriver { + val proxy = testProxy( + RPCClientConfiguration.default, + RPCServerConfiguration.default + ) + val executor = Executors.newFixedThreadPool(4) + val N = 10000 + val latch = CountDownLatch(N) + for (i in 1 .. N) { + executor.submit { + proxy.ops.simpleReply(ByteArray(1024), 1024) + latch.countDown() + } + } + latch.await() + } + } + + data class SimpleRPCResult( + val requestPerSecond: Double, + val averageIndividualMs: Double, + val Mbps: Double + ) + @Test + fun `measure Megabytes per second for simple RPCs`() { + warmup() + val inputOutputSizes = listOf(1024, 4096, 100 * 1024) + val overallTraffic = 512 * 1024 * 1024L + measure(inputOutputSizes, (1..5)) { inputOutputSize, N -> + rpcDriver { + val proxy = testProxy( + RPCClientConfiguration.default.copy( + cacheConcurrencyLevel = 16, + observationExecutorPoolSize = 2, + producerPoolBound = 2 + ), + RPCServerConfiguration.default.copy( + rpcThreadPoolSize = 8, + consumerPoolSize = 2, + producerPoolBound = 8 + ) + ) + + val numberOfRequests = overallTraffic / (2 * inputOutputSize) + val timings = Collections.synchronizedList(ArrayList()) + val executor = Executors.newFixedThreadPool(8) + val totalElapsed = Stopwatch.createStarted().apply { + startInjectorWithBoundedQueue( + executor = executor, + numberOfInjections = numberOfRequests.toInt(), + queueBound = 100 + ) { + val elapsed = Stopwatch.createStarted().apply { + proxy.ops.simpleReply(ByteArray(inputOutputSize), inputOutputSize) + }.stop().elapsed(TimeUnit.MICROSECONDS) + timings.add(elapsed) + } + }.stop().elapsed(TimeUnit.MICROSECONDS) + executor.shutdownNow() + SimpleRPCResult( + requestPerSecond = 1000000.0 * numberOfRequests.toDouble() / totalElapsed.toDouble(), + averageIndividualMs = timings.average() / 1000.0, + Mbps = (overallTraffic.toDouble() / totalElapsed.toDouble()) * (1000000.0 / (1024.0 * 1024.0)) + ) + } + }.forEach(::println) + } + + /** + * Runs 20k RPCs per second for two minutes and publishes relevant stats to JMX. + */ + @Test + fun `consumption rate`() { + rpcDriver { + val metricRegistry = startJmxReporter() + val proxy = testProxy( + RPCClientConfiguration.default.copy( + reapIntervalMs = 100, + cacheConcurrencyLevel = 16 + ), + RPCServerConfiguration.default.copy( + rpcThreadPoolSize = 4, + consumerPoolSize = 4, + producerPoolBound = 4 + ) + ) + measurePerformancePublishMetrics( + metricRegistry = metricRegistry, + parallelism = 4, + overallDurationSecond = 120.0, + injectionRatePerSecond = 20000.0, + queueSizeMetricName = "$mode.QueueSize", + workDurationMetricName = "$mode.WorkDuration", + shutdownManager = this.shutdownManager, + work = { + proxy.ops.simpleReply(ByteArray(4096), 4096) + } + ) + } + } + + data class BigMessagesResult( + val Mbps: Double + ) + @Test + fun `big messages`() { + warmup() + measure(listOf(1)) { clientParallelism -> // TODO this hangs with more parallelism + rpcDriver { + val proxy = testProxy( + RPCClientConfiguration.default, + RPCServerConfiguration.default.copy( + consumerPoolSize = 1 + ) + ) + val executor = Executors.newFixedThreadPool(clientParallelism) + val numberOfMessages = 1000 + val bigSize = 10_000_000 + val elapsed = Stopwatch.createStarted().apply { + startInjectorWithBoundedQueue( + executor = executor, + numberOfInjections = numberOfMessages, + queueBound = 4 + ) { + proxy.ops.simpleReply(ByteArray(bigSize), 0) + } + }.stop().elapsed(TimeUnit.MICROSECONDS) + executor.shutdownNow() + BigMessagesResult( + Mbps = bigSize.toDouble() * numberOfMessages.toDouble() / elapsed * (1000000.0 / (1024.0 * 1024.0)) + ) + } + }.forEach(::println) + } +} + +fun measurePerformancePublishMetrics( + metricRegistry: MetricRegistry, + parallelism: Int, + overallDurationSecond: Double, + injectionRatePerSecond: Double, + queueSizeMetricName: String, + workDurationMetricName: String, + shutdownManager: ShutdownManager, + work: () -> Unit +) { + val workSemaphore = Semaphore(0) + metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() }) + val workDurationTimer = metricRegistry.timer(workDurationMetricName) + val executor = Executors.newSingleThreadScheduledExecutor() + val workExecutor = Executors.newFixedThreadPool(parallelism) + val timings = Collections.synchronizedList(ArrayList()) + for (i in 1 .. parallelism) { + workExecutor.submit { + try { + while (true) { + workSemaphore.acquire() + workDurationTimer.time { + timings.add( + Stopwatch.createStarted().apply { + work() + }.stop().elapsed(TimeUnit.MICROSECONDS) + ) + } + } + } catch (throwable: Throwable) { + throwable.printStackTrace() + } + } + } + val injector = executor.scheduleAtFixedRate( + { + workSemaphore.release(injectionRatePerSecond.toInt()) + }, + 0, + 1, + TimeUnit.SECONDS + ) + shutdownManager.registerShutdown { + injector.cancel(true) + workExecutor.shutdownNow() + executor.shutdownNow() + workExecutor.awaitTermination(1, TimeUnit.SECONDS) + executor.awaitTermination(1, TimeUnit.SECONDS) + } + Thread.sleep((overallDurationSecond * 1000).toLong()) +} + +fun startInjectorWithBoundedQueue( + executor: ExecutorService, + numberOfInjections: Int, + queueBound: Int, + work: () -> Unit +) { + val remainingLatch = CountDownLatch(numberOfInjections) + val queuedCount = AtomicInteger(0) + val lock = ReentrantLock() + val canQueueAgain = lock.newCondition() + val injectorShutdown = AtomicBoolean(false) + val injector = thread(name = "injector") { + while (true) { + if (injectorShutdown.get()) break + executor.submit { + work() + if (queuedCount.decrementAndGet() < queueBound / 2) { + lock.withLock { + canQueueAgain.signal() + } + } + remainingLatch.countDown() + } + if (queuedCount.incrementAndGet() > queueBound) { + lock.withLock { + canQueueAgain.await() + } + } + } + } + remainingLatch.await() + injectorShutdown.set(true) + injector.join() +} + +fun RPCDriverExposedDSLInterface.startJmxReporter(): MetricRegistry { + val metricRegistry = MetricRegistry() + val jmxReporter = thread { + JmxReporter. + forRegistry(metricRegistry). + inDomain("net.corda"). + createsObjectNamesWith { _, domain, name -> + // Make the JMX hierarchy a bit better organised. + val category = name.substringBefore('.') + val subName = name.substringAfter('.', "") + if (subName == "") + ObjectName("$domain:name=$category") + else + ObjectName("$domain:type=$category,name=$subName") + }. + build(). + start() + } + shutdownManager.registerShutdown { + jmxReporter.interrupt() + jmxReporter.join() + } + return metricRegistry +} diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTest.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTest.kt deleted file mode 100644 index d00e594b53..0000000000 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTest.kt +++ /dev/null @@ -1,85 +0,0 @@ -package net.corda.client.rpc - -import net.corda.core.messaging.RPCOps -import net.corda.node.services.messaging.requirePermission -import net.corda.nodeapi.PermissionException -import net.corda.nodeapi.User -import org.junit.After -import org.junit.Test -import kotlin.test.assertFailsWith - -class RPCPermissionsTest : AbstractClientRPCTest() { - companion object { - const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow" - const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow" - const val ALL_ALLOWED = "ALL" - } - - lateinit var proxy: TestOps - - @After - fun shutdown() { - safeClose(proxy) - } - - /* - * RPC operation. - */ - interface TestOps : RPCOps { - fun validatePermission(str: String) - } - - class TestOpsImpl : TestOps { - override val protocolVersion = 1 - override fun validatePermission(str: String) = requirePermission(str) - } - - /** - * Create an RPC proxy for the given user. - */ - private fun proxyFor(rpcUser: User): TestOps = rpcProxyFor(rpcUser, TestOpsImpl(), TestOps::class.java) - - private fun userOf(name: String, permissions: Set) = User(name, "password", permissions) - - @Test - fun `empty user cannot use any flows`() { - val emptyUser = userOf("empty", emptySet()) - proxy = proxyFor(emptyUser) - assertFailsWith(PermissionException::class, - "User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.", - { proxy.validatePermission(DUMMY_FLOW) }) - } - - @Test - fun `admin user can use any flow`() { - val adminUser = userOf("admin", setOf(ALL_ALLOWED)) - proxy = proxyFor(adminUser) - proxy.validatePermission(DUMMY_FLOW) - } - - @Test - fun `joe user is allowed to use DummyFlow`() { - val joeUser = userOf("joe", setOf(DUMMY_FLOW)) - proxy = proxyFor(joeUser) - proxy.validatePermission(DUMMY_FLOW) - } - - @Test - fun `joe user is not allowed to use OtherFlow`() { - val joeUser = userOf("joe", setOf(DUMMY_FLOW)) - proxy = proxyFor(joeUser) - assertFailsWith(PermissionException::class, - "User ${joeUser.username} should not be allowed to use $OTHER_FLOW", - { proxy.validatePermission(OTHER_FLOW) }) - } - - @Test - fun `check ALL is implemented the correct way round`() { - val joeUser = userOf("joe", setOf(DUMMY_FLOW)) - proxy = proxyFor(joeUser) - assertFailsWith(PermissionException::class, - "Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}", - { proxy.validatePermission(ALL_ALLOWED) }) - } - -} diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt new file mode 100644 index 0000000000..ebc9cef461 --- /dev/null +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt @@ -0,0 +1,93 @@ +package net.corda.client.rpc + +import net.corda.core.messaging.RPCOps +import net.corda.node.services.messaging.requirePermission +import net.corda.node.services.messaging.getRpcContext +import net.corda.nodeapi.PermissionException +import net.corda.nodeapi.User +import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.rpcDriver +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import kotlin.test.assertFailsWith + +@RunWith(Parameterized::class) +class RPCPermissionsTests : AbstractRPCTest() { + companion object { + const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow" + const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow" + const val ALL_ALLOWED = "ALL" + } + + /* + * RPC operation. + */ + interface TestOps : RPCOps { + fun validatePermission(str: String) + } + + class TestOpsImpl : TestOps { + override val protocolVersion = 1 + override fun validatePermission(str: String) = getRpcContext().requirePermission(str) + } + + /** + * Create an RPC proxy for the given user. + */ + private fun RPCDriverExposedDSLInterface.testProxyFor(rpcUser: User) = testProxy(TestOpsImpl(), rpcUser).ops + + private fun userOf(name: String, permissions: Set) = User(name, "password", permissions) + + @Test + fun `empty user cannot use any flows`() { + rpcDriver { + val emptyUser = userOf("empty", emptySet()) + val proxy = testProxyFor(emptyUser) + assertFailsWith(PermissionException::class, + "User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.", + { proxy.validatePermission(DUMMY_FLOW) }) + } + } + + @Test + fun `admin user can use any flow`() { + rpcDriver { + val adminUser = userOf("admin", setOf(ALL_ALLOWED)) + val proxy = testProxyFor(adminUser) + proxy.validatePermission(DUMMY_FLOW) + } + } + + @Test + fun `joe user is allowed to use DummyFlow`() { + rpcDriver { + val joeUser = userOf("joe", setOf(DUMMY_FLOW)) + val proxy = testProxyFor(joeUser) + proxy.validatePermission(DUMMY_FLOW) + } + } + + @Test + fun `joe user is not allowed to use OtherFlow`() { + rpcDriver { + val joeUser = userOf("joe", setOf(DUMMY_FLOW)) + val proxy = testProxyFor(joeUser) + assertFailsWith(PermissionException::class, + "User ${joeUser.username} should not be allowed to use $OTHER_FLOW", + { proxy.validatePermission(OTHER_FLOW) }) + } + } + + @Test + fun `check ALL is implemented the correct way round` () { + rpcDriver { + val joeUser = userOf("joe", setOf(DUMMY_FLOW)) + val proxy = testProxyFor(joeUser) + assertFailsWith(PermissionException::class, + "Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}", + { proxy.validatePermission(ALL_ALLOWED) }) + } + } + +} diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RepeatingBytesInputStream.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RepeatingBytesInputStream.kt new file mode 100644 index 0000000000..06ed23f1bb --- /dev/null +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RepeatingBytesInputStream.kt @@ -0,0 +1,25 @@ +package net.corda.client.rpc + +import java.io.InputStream + +class RepeatingBytesInputStream(val bytesToRepeat: ByteArray, val numberOfBytes: Int) : InputStream() { + private var bytesLeft = numberOfBytes + override fun available() = bytesLeft + override fun read(): Int { + if (bytesLeft == 0) { + return -1 + } else { + bytesLeft-- + return bytesToRepeat[(numberOfBytes - bytesLeft) % bytesToRepeat.size].toInt() + } + } + override fun read(byteArray: ByteArray, offset: Int, length: Int): Int { + val until = Math.min(Math.min(offset + length, byteArray.size), offset + bytesLeft) + for (i in offset .. until - 1) { + byteArray[i] = bytesToRepeat[(numberOfBytes - bytesLeft + i - offset) % bytesToRepeat.size] + } + val bytesRead = until - offset + bytesLeft -= bytesRead + return if (bytesRead == 0 && bytesLeft == 0) -1 else bytesRead + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/Utils.kt b/core/src/main/kotlin/net/corda/core/Utils.kt index d9fdbbcf6f..6bdb45d83e 100644 --- a/core/src/main/kotlin/net/corda/core/Utils.kt +++ b/core/src/main/kotlin/net/corda/core/Utils.kt @@ -404,6 +404,8 @@ data class ErrorOr private constructor(val value: A?, val error: Throwabl ErrorOr.of(error) } } + + fun mapError(function: (Throwable) -> Throwable) = ErrorOr(value, error?.let(function)) } /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt index bdd23b170e..95b82c118d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt @@ -3,6 +3,7 @@ package net.corda.core.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoCallback import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.util.MapReferenceResolver import com.google.common.annotations.VisibleForTesting @@ -10,6 +11,7 @@ import net.corda.core.contracts.* import net.corda.core.crypto.* import net.corda.core.node.AttachmentsClassLoader import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.LazyPool import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec @@ -19,7 +21,9 @@ import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.x500.X500Name import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.io.* +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.io.InputStream import java.lang.reflect.InvocationTargetException import java.nio.file.Files import java.nio.file.Path @@ -143,13 +147,20 @@ fun T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = fa return kryo.run { k -> serialize(k, internalOnly) } } + +private val serializeBufferPool = LazyPool { ByteArray(64 * 1024) } +private val serializeOutputStreamPool = LazyPool(ByteArrayOutputStream::reset) { ByteArrayOutputStream(64 * 1024) } fun T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes { - val stream = ByteArrayOutputStream() - Output(stream).use { - it.writeBytes(KryoHeaderV0_1.bytes) - kryo.writeClassAndObject(it, this) + return serializeOutputStreamPool.run { stream -> + serializeBufferPool.run { buffer -> + Output(buffer).use { + it.setOutputStream(stream) + it.writeBytes(KryoHeaderV0_1.bytes) + kryo.writeClassAndObject(it, this) + } + SerializedBytes(stream.toByteArray(), internalOnly) + } } - return SerializedBytes(stream.toByteArray(), internalOnly) } /** @@ -592,4 +603,26 @@ object X500NameSerializer : Serializer() { override fun write(kryo: Kryo, output: Output, obj: X500Name) { output.writeBytes(obj.encoded) } -} \ No newline at end of file +} + +class KryoPoolWithContext(val baseKryoPool: KryoPool, val contextKey: Any, val context: Any) : KryoPool { + override fun run(callback: KryoCallback): T { + val kryo = borrow() + try { + return callback.execute(kryo) + } finally { + release(kryo) + } + } + + override fun borrow(): Kryo { + val kryo = baseKryoPool.borrow() + require(kryo.context.put(contextKey, context) == null) { "KryoPool already has context" } + return kryo + } + + override fun release(kryo: Kryo) { + requireNotNull(kryo.context.remove(contextKey)) { "Kryo instance lost context while borrowed" } + baseKryoPool.release(kryo) + } +} diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt b/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt new file mode 100644 index 0000000000..a9df98637e --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt @@ -0,0 +1,79 @@ +package net.corda.core.utilities + +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.AtomicBoolean + +/** + * A lazy pool of resources [A]. + * + * @param clear If specified this function will be run on each borrowed instance before handing it over. + * @param bound If specified the pool will be bounded. Once all instances are borrowed subsequent borrows will block until an + * instance is released. + * @param create The function to call to lazily create a pooled resource. + */ +class LazyPool( + private val clear: ((A) -> Unit)? = null, + private val bound: Int? = null, + private val create: () -> A +) { + private val poolQueue = LinkedBlockingQueue() + private var poolSize = 0 + + private enum class State { + STARTED, + FINISHED + } + private val lifeCycle = LifeCycle(State.STARTED) + + private fun clearIfNeeded(instance: A): A { + clear?.invoke(instance) + return instance + } + + fun borrow(): A { + lifeCycle.requireState(State.STARTED) + val pooled = poolQueue.poll() + if (pooled == null) { + if (bound != null) { + val waitForRelease = synchronized(this) { + if (poolSize < bound) { + poolSize++ + false + } else { + true + } + } + if (waitForRelease) { + // Wait until one is released + return clearIfNeeded(poolQueue.take()) + } + } + return create() + } else { + return clearIfNeeded(pooled) + } + } + + fun release(instance: A) { + lifeCycle.requireState(State.STARTED) + poolQueue.add(instance) + } + + /** + * Closes the pool. Note that all borrowed instances must have been released before calling this function, otherwise + * the returned iterable will be inaccurate. + */ + fun close(): Iterable { + lifeCycle.transition(State.STARTED, State.FINISHED) + return poolQueue + } + + inline fun run(withInstance: (A) -> R): R { + val instance = borrow() + try { + return withInstance(instance) + } finally { + release(instance) + } + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt new file mode 100644 index 0000000000..298279f09f --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt @@ -0,0 +1,67 @@ +package net.corda.core.utilities + +import java.util.* +import java.util.concurrent.LinkedBlockingQueue + +/** + * A [LazyStickyPool] is a lazy pool of resources where a [borrow] may "stick" the borrowed instance to an object. + * Any subsequent borrows using the same object will return the same pooled instance. + */ +// TODO This could be implemented more efficiently. Currently the "non-sticky" use case is not optimised, it just chooses a random instance to wait on. +class LazyStickyPool( + size: Int, + private val newInstance: () -> A +) { + private class InstanceBox { + var instance: LinkedBlockingQueue? = null + } + private val random = Random() + private val boxes = Array(size) { InstanceBox() } + + private fun toIndex(stickTo: Any): Int { + return Math.abs(stickTo.hashCode()) % boxes.size + } + + fun borrow(stickTo: Any): A { + val box = boxes[toIndex(stickTo)] + val instance = synchronized(box) { + val instance = box.instance + if (instance == null) { + val newInstance = LinkedBlockingQueue(listOf(newInstance())) + box.instance = newInstance + newInstance + } else { + instance + } + } + return instance.take() + } + + fun borrow(): Pair { + val randomInt = random.nextInt() + val instance = borrow(randomInt) + return Pair(randomInt, instance) + } + + fun release(stickTo: Any, instance: A) { + val box = boxes[toIndex(stickTo)] + box.instance!!.add(instance) + } + + inline fun run(stickToOrNull: Any? = null, withInstance: (A) -> R): R { + val (stickTo, instance) = if (stickToOrNull == null) { + borrow() + } else { + Pair(stickToOrNull, borrow(stickToOrNull)) + } + try { + return withInstance(instance) + } finally { + release(stickTo, instance) + } + } + + fun close(): Iterable { + return boxes.map { it.instance?.poll() }.filterNotNull() + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt b/core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt new file mode 100644 index 0000000000..2a7b0c2bb3 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt @@ -0,0 +1,38 @@ +package net.corda.core.utilities + +import java.util.concurrent.locks.ReentrantReadWriteLock +import kotlin.concurrent.withLock + +/** + * This class provides simple tracking of the lifecycle of a service-type object. + * [S] is an enum enumerating the possible states the service can be in. + * + * @param initial The initial state. + */ +class LifeCycle>(initial: S) { + private val lock = ReentrantReadWriteLock() + private var state = initial + + /** Assert that the lifecycle in the [requiredState] */ + fun requireState(requiredState: S) { + requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState } + } + + /** Assert something about the current state atomically. */ + fun requireState( + errorMessage: (S) -> String = { "Predicate failed on state $it" }, + predicate: (S) -> Boolean + ) { + lock.readLock().withLock { + require(predicate(state)) { errorMessage(state) } + } + } + + /** Transition the state from [from] to [to] */ + fun transition(from: S, to: S) { + lock.writeLock().withLock { + require(state == from) { "Required state to be $from to transition to $to, was $state" } + state = to + } + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt index 78f18c4346..850d5d2efd 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt @@ -5,6 +5,7 @@ import net.corda.core.contracts.* import net.corda.core.crypto.Party import net.corda.core.crypto.SecureHash import net.corda.core.getOrThrow +import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow import net.corda.core.node.services.unconsumedStates import net.corda.core.serialization.OpaqueBytes @@ -15,9 +16,12 @@ import net.corda.flows.FinalityFlow import net.corda.node.internal.CordaRPCOpsImpl import net.corda.node.services.startFlowPermission import net.corda.node.utilities.transaction -import net.corda.nodeapi.CURRENT_RPC_USER import net.corda.nodeapi.User +import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.node.MockNetwork +import net.corda.testing.rpcDriver +import net.corda.testing.rpcTestUser +import net.corda.testing.startRpcClient import org.junit.After import org.junit.Before import org.junit.Test @@ -99,60 +103,71 @@ class ContractUpgradeFlowTest { check(b) } + private fun RPCDriverExposedDSLInterface.startProxy(node: MockNetwork.MockNode, user: User): CordaRPCOps { + return startRpcClient( + rpcAddress = startRpcServer( + rpcUser = user, + ops = CordaRPCOpsImpl(node.services, node.smm, node.database) + ).get().hostAndPort, + username = user.username, + password = user.password + ).get() + } + @Test fun `2 parties contract upgrade using RPC`() { - // Create dummy contract. - val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1)) - val stx = twoPartyDummyContract.signWith(a.services.legalIdentityKey) - .signWith(b.services.legalIdentityKey) - .toSignedTransaction() + rpcDriver { + // Create dummy contract. + val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1)) + val stx = twoPartyDummyContract.signWith(a.services.legalIdentityKey) + .signWith(b.services.legalIdentityKey) + .toSignedTransaction() - a.services.startFlow(FinalityFlow(stx, setOf(a.info.legalIdentity, b.info.legalIdentity))) - mockNet.runNetwork() + val user = rpcTestUser.copy(permissions = setOf( + startFlowPermission(), + startFlowPermission>() + )) + val rpcA = startProxy(a, user) + val rpcB = startProxy(b, user) + val handle = rpcA.startFlow(::FinalityFlow, stx, setOf(a.info.legalIdentity, b.info.legalIdentity)) + mockNet.runNetwork() + handle.returnValue.getOrThrow() - val atx = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(stx.id) } - val btx = b.database.transaction { b.services.storageService.validatedTransactions.getTransaction(stx.id) } - requireNotNull(atx) - requireNotNull(btx) + val atx = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(stx.id) } + val btx = b.database.transaction { b.services.storageService.validatedTransactions.getTransaction(stx.id) } + requireNotNull(atx) + requireNotNull(btx) - // The request is expected to be rejected because party B haven't authorise the upgrade yet. + val rejectedFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) }, + atx!!.tx.outRef(0), + DummyContractV2::class.java).returnValue - val rpcA = CordaRPCOpsImpl(a.services, a.smm, a.database) - val rpcB = CordaRPCOpsImpl(b.services, b.smm, b.database) + mockNet.runNetwork() + assertFailsWith(ExecutionException::class) { rejectedFuture.get() } - CURRENT_RPC_USER.set(User("user", "pwd", permissions = setOf( - startFlowPermission>() - ))) + // Party B authorise the contract state upgrade. + rpcB.authoriseContractUpgrade(btx!!.tx.outRef(0), DummyContractV2::class.java) - val rejectedFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) }, - atx!!.tx.outRef(0), - DummyContractV2::class.java).returnValue + // Party A initiates contract upgrade flow, expected to succeed this time. + val resultFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) }, + atx.tx.outRef(0), + DummyContractV2::class.java).returnValue - mockNet.runNetwork() - assertFailsWith(ExecutionException::class) { rejectedFuture.get() } + mockNet.runNetwork() + val result = resultFuture.get() + // Check results. + listOf(a, b).forEach { + val signedTX = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(result.ref.txhash) } + requireNotNull(signedTX) - // Party B authorise the contract state upgrade. - rpcB.authoriseContractUpgrade(btx!!.tx.outRef(0), DummyContractV2::class.java) + // Verify inputs. + val input = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(signedTX!!.tx.inputs.single().txhash) } + requireNotNull(input) + assertTrue(input!!.tx.outputs.single().data is DummyContract.State) - // Party A initiates contract upgrade flow, expected to succeed this time. - val resultFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) }, - atx.tx.outRef(0), - DummyContractV2::class.java).returnValue - - mockNet.runNetwork() - val result = resultFuture.get() - // Check results. - listOf(a, b).forEach { - val signedTX = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(result.ref.txhash) } - requireNotNull(signedTX) - - // Verify inputs. - val input = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(signedTX!!.tx.inputs.single().txhash) } - requireNotNull(input) - assertTrue(input!!.tx.outputs.single().data is DummyContract.State) - - // Verify outputs. - assertTrue(signedTX!!.tx.outputs.single().data is DummyContractV2.State) + // Verify outputs. + assertTrue(signedTX!!.tx.outputs.single().data is DummyContractV2.State) + } } } diff --git a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt index 24b1766837..c6c8afb8f8 100644 --- a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt +++ b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt @@ -18,7 +18,10 @@ import net.corda.node.driver.driver import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.nodeapi.User -import net.corda.testing.* +import net.corda.testing.expect +import net.corda.testing.expectEvents +import net.corda.testing.parallel +import net.corda.testing.sequence import org.junit.Test import java.util.* import kotlin.concurrent.thread @@ -44,12 +47,10 @@ class IntegrationTestingTutorial { // START 2 val aliceClient = alice.rpcClientToNode() - aliceClient.start("aliceUser", "testPassword1") - val aliceProxy = aliceClient.proxy() + val aliceProxy = aliceClient.start("aliceUser", "testPassword1").proxy val bobClient = bob.rpcClientToNode() - bobClient.start("bobUser", "testPassword2") - val bobProxy = bobClient.proxy() + val bobProxy = bobClient.start("bobUser", "testPassword2").proxy // END 2 // START 3 diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/ClientRpcTutorial.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/ClientRpcTutorial.kt index f2d8b82aee..b232f9fd6e 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/ClientRpcTutorial.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/ClientRpcTutorial.kt @@ -55,8 +55,7 @@ fun main(args: Array) { // START 2 val client = node.rpcClientToNode() - client.start("user", "password") - val proxy = client.proxy() + val proxy = client.start("user", "password").proxy thread { generateTransactions(proxy) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt index af48164958..0f16c2e056 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt @@ -30,10 +30,7 @@ abstract class ArtemisMessagingComponent : SingletonSerializeAsToken() { const val INTERNAL_PREFIX = "internal." const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers." const val SERVICES_PREFIX = "${INTERNAL_PREFIX}services." - const val CLIENTS_PREFIX = "clients." const val P2P_QUEUE = "p2p.inbound" - const val RPC_REQUESTS_QUEUE = "rpc.requests" - const val RPC_QUEUE_REMOVALS_QUEUE = "rpc.qremovals" const val NOTIFICATIONS_ADDRESS = "${INTERNAL_PREFIX}activemq.notifications" const val NETWORK_MAP_QUEUE = "${INTERNAL_PREFIX}networkmap" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt new file mode 100644 index 0000000000..39d85d90d2 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -0,0 +1,206 @@ +package net.corda.nodeapi + +import com.esotericsoftware.kryo.pool.KryoPool +import net.corda.core.ErrorOr +import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.nodeapi.RPCApi.ClientToServer +import net.corda.nodeapi.RPCApi.ObservableId +import net.corda.nodeapi.RPCApi.RPC_CLIENT_BINDING_REMOVALS +import net.corda.nodeapi.RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX +import net.corda.nodeapi.RPCApi.RPC_SERVER_QUEUE_NAME +import net.corda.nodeapi.RPCApi.RpcRequestId +import net.corda.nodeapi.RPCApi.ServerToClient +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.* +import org.apache.activemq.artemis.api.core.management.CoreNotificationType +import org.apache.activemq.artemis.api.core.management.ManagementHelper +import org.apache.activemq.artemis.reader.MessageUtil +import rx.Notification +import java.util.* + +/** + * The RPC protocol: + * + * The server consumes the queue "[RPC_SERVER_QUEUE_NAME]" and receives RPC requests ([ClientToServer.RpcRequest]) on it. + * When a client starts up it should create a queue for its inbound messages, this should be of the form + * "[RPC_CLIENT_QUEUE_NAME_PREFIX].$username.$nonce". Each RPC request contains this address (in + * [ClientToServer.RpcRequest.clientAddress]), this is where the server will send the reply to the request as well as + * subsequent Observations rooted in the RPC. The requests/replies are muxed using a unique [RpcRequestId] generated by + * the client for each request. + * + * If an RPC reply's payload ([ServerToClient.RpcReply.result]) contains [Observable]s then the server will generate a + * unique [ObservableId] for each and serialise them in place of the [Observable]s themselves. Subsequently the client + * should be prepared to receive observations ([ServerToClient.Observation]), muxed by the relevant [ObservableId]. + * In addition each observation itself may contain further [Observable]s, this case should behave the same as before. + * + * Additionally the client may send [ClientToServer.ObservablesClosed] messages indicating that certain observables + * aren't consumed anymore, which should subsequently stop the stream from the server. Note that some observations may + * already be in flight when this is sent, the client should handle this gracefully. + * + * An example session: + * Client Server + * ----------RpcRequest(RID0)-----------> // Client makes RPC request with ID "RID0" + * <----RpcReply(RID0, Payload(OID0))---- // Server sends reply containing an observable with ID "OID0" + * <---------Observation(OID0)----------- // Server sends observation onto "OID0" + * <---Observation(OID0, Payload(OID1))-- // Server sends another observation, this time containing another observable + * <---------Observation(OID1)----------- // Observation onto new "OID1" + * <---------Observation(OID0)----------- + * -----ObservablesClosed(OID0, OID1)---> // Client indicates it stopped consuming the Observables. + * <---------Observation(OID1)----------- // Observation was already in-flight before the previous message was processed + * (FIN) + * + * Note that multiple sessions like the above may interleave in an arbitrary fashion. + * + * Additionally the server may listen on client binding removals for cleanup using [RPC_CLIENT_BINDING_REMOVALS]. This + * requires the server to create a filter on the artemis notification address using + */ +object RPCApi { + private val TAG_FIELD_NAME = "tag" + private val RPC_ID_FIELD_NAME = "rpc-id" + private val OBSERVABLE_ID_FIELD_NAME = "observable-id" + private val METHOD_NAME_FIELD_NAME = "method-name" + + val RPC_SERVER_QUEUE_NAME = "rpc.server" + val RPC_CLIENT_QUEUE_NAME_PREFIX = "rpc.client" + val RPC_CLIENT_BINDING_REMOVALS = "rpc.clientqueueremovals" + + val RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION = + "${ManagementHelper.HDR_NOTIFICATION_TYPE} = '${CoreNotificationType.BINDING_REMOVED.name}' AND " + + "${ManagementHelper.HDR_ROUTING_NAME} LIKE '$RPC_CLIENT_QUEUE_NAME_PREFIX.%'" + + data class RpcRequestId(val toLong: Long) + data class ObservableId(val toLong: Long) + + object RpcRequestOrObservableIdKey + + private fun ClientMessage.getBodyAsByteArray(): ByteArray { + return ByteArray(bodySize).apply { bodyBuffer.readBytes(this) } + } + + sealed class ClientToServer { + private enum class Tag { + RPC_REQUEST, + OBSERVABLES_CLOSED + } + + data class RpcRequest( + val clientAddress: SimpleString, + val id: RpcRequestId, + val methodName: String, + val arguments: List + ) : ClientToServer() { + fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + MessageUtil.setJMSReplyTo(message, clientAddress) + message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal) + message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) + message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName) + message.bodyBuffer.writeBytes(arguments.serialize(kryoPool).bytes) + } + } + + data class ObservablesClosed( + val ids: List + ) : ClientToServer() { + fun writeToClientMessage(message: ClientMessage) { + message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVABLES_CLOSED.ordinal) + val buffer = message.bodyBuffer + buffer.writeInt(ids.size) + ids.forEach { + buffer.writeLong(it.toLong) + } + } + } + + companion object { + fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ClientToServer { + val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] + return when (tag) { + RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest( + clientAddress = MessageUtil.getJMSReplyTo(message), + id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)), + methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME), + arguments = message.getBodyAsByteArray().deserialize(kryoPool) + ) + RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> { + val ids = ArrayList() + val buffer = message.bodyBuffer + val numberOfIds = buffer.readInt() + for (i in 1 .. numberOfIds) { + ids.add(ObservableId(buffer.readLong())) + } + ObservablesClosed(ids) + } + } + } + } + } + + sealed class ServerToClient { + private enum class Tag { + RPC_REPLY, + OBSERVATION + } + + abstract fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) + + data class RpcReply( + val id: RpcRequestId, + val result: ErrorOr + ) : ServerToClient() { + override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal) + message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) + message.bodyBuffer.writeBytes(result.serialize(kryoPool).bytes) + } + } + + data class Observation( + val id: ObservableId, + val content: Notification + ) : ServerToClient() { + override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal) + message.putLongProperty(OBSERVABLE_ID_FIELD_NAME, id.toLong) + message.bodyBuffer.writeBytes(content.serialize(kryoPool).bytes) + } + } + + companion object { + fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ServerToClient { + val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] + return when (tag) { + RPCApi.ServerToClient.Tag.RPC_REPLY -> { + val id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)) + val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + RpcReply( + id = id, + result = message.getBodyAsByteArray().deserialize(poolWithIdContext) + ) + } + RPCApi.ServerToClient.Tag.OBSERVATION -> { + val id = ObservableId(message.getLongProperty(OBSERVABLE_ID_FIELD_NAME)) + val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + Observation( + id = id, + content = message.getBodyAsByteArray().deserialize(poolWithIdContext) + ) + } + } + } + } + } +} + +data class ArtemisProducer( + val sessionFactory: ClientSessionFactory, + val session: ClientSession, + val producer: ClientProducer +) + +data class ArtemisConsumer( + val sessionFactory: ClientSessionFactory, + val session: ClientSession, + val consumer: ClientConsumer +) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt index 171a1151dd..329155c6e5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt @@ -10,17 +10,10 @@ import net.corda.core.serialization.* import net.corda.core.toFuture import net.corda.core.toObservable import net.corda.nodeapi.config.OldConfig -import org.apache.commons.fileupload.MultipartStream -import org.slf4j.Logger -import org.slf4j.LoggerFactory -import rx.Notification import rx.Observable - -/** Global RPC logger */ -val rpcLog: Logger by lazy { LoggerFactory.getLogger("net.corda.rpc") } - -/** Used in the RPC wire protocol to wrap an observation with the handle of the observable it's intended for. */ -data class MarshalledObservation(val forHandle: Int, val what: Notification<*>) +import java.io.InputStream +import java.io.PrintWriter +import java.io.StringWriter data class User( @OldConfig("user") @@ -35,28 +28,6 @@ data class User( @MustBeDocumented annotation class RPCSinceVersion(val version: Int) -/** The contents of an RPC request message, separated from the MQ layer. */ -data class ClientRPCRequestMessage( - val args: SerializedBytes>, - val replyToAddress: String, - val observationsToAddress: String?, - val methodName: String, - val user: User -) { - companion object { - const val REPLY_TO = "reply-to" - const val OBSERVATIONS_TO = "observations-to" - const val METHOD_NAME = "method-name" - } -} - -/** - * This is available to RPC implementations to query the validated [User] that is calling it. Each user has a set of - * permissions they're entitled to which can be used to control access. - */ -@JvmField -val CURRENT_RPC_USER: ThreadLocal = ThreadLocal() - /** * Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked * method. @@ -64,19 +35,11 @@ val CURRENT_RPC_USER: ThreadLocal = ThreadLocal() @CordaSerializable open class RPCException(msg: String, cause: Throwable?) : RuntimeException(msg, cause) { constructor(msg: String) : this(msg, null) - - class DeadlineExceeded(rpcName: String) : RPCException("Deadline exceeded on call to $rpcName") } @CordaSerializable class PermissionException(msg: String) : RuntimeException(msg) -object RPCKryoClientKey -object RPCKryoDispatcherKey -object RPCKryoQNameKey -object RPCKryoMethodNameKey -object RPCKryoLocationKey - // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // because we can see everything we're using in one place. @@ -85,8 +48,7 @@ class RPCKryo(observableSerializer: Serializer>) : CordaKryo(mak DefaultKryoCustomizer.customize(this) // RPC specific classes - register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer) - register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class)) + register(InputStream::class.java, InputStreamSerializer) register(Observable::class.java, observableSerializer) @Suppress("UNCHECKED_CAST") register(ListenableFuture::class, @@ -110,14 +72,14 @@ class RPCKryo(observableSerializer: Serializer>) : CordaKryo(mak } override fun getRegistration(type: Class<*>): Registration { - val annotated = context[RPCKryoQNameKey] != null - if (Observable::class.java.isAssignableFrom(type)) { - return if (annotated) super.getRegistration(Observable::class.java) - else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables") + if (Observable::class.java != type && Observable::class.java.isAssignableFrom(type)) { + return super.getRegistration(Observable::class.java) } - if (ListenableFuture::class.java.isAssignableFrom(type)) { - return if (annotated) super.getRegistration(ListenableFuture::class.java) - else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables") + if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) { + return super.getRegistration(InputStream::class.java) + } + if (ListenableFuture::class.java != type && ListenableFuture::class.java.isAssignableFrom(type)) { + return super.getRegistration(ListenableFuture::class.java) } if (FlowException::class.java.isAssignableFrom(type)) return super.getRegistration(FlowException::class.java) diff --git a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt index bdc410e5db..2b0314bc00 100644 --- a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt @@ -17,9 +17,8 @@ class BootTests { fun `java deserialization is disabled`() { driver { val user = User("u", "p", setOf(startFlowPermission())) - val future = startNode(rpcUsers = listOf(user)).getOrThrow().rpcClientToNode().apply { - start(user.username, user.password) - }.proxy().startFlow(::ObjectInputStreamFlow).returnValue + val future = startNode(rpcUsers = listOf(user)).getOrThrow().rpcClientToNode(). + start(user.username, user.password).proxy.startFlow(::ObjectInputStreamFlow).returnValue assertThatThrownBy { future.getOrThrow() }.isInstanceOf(InvalidClassException::class.java).hasMessage("filter status: REJECTED") } } diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt index 21001405bb..61d983cbb4 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt @@ -60,8 +60,7 @@ class DistributedServiceTests : DriverBasedTest() { // Connect to Alice and the notaries fun connectRpc(node: NodeHandle): CordaRPCOps { val client = node.rpcClientToNode() - client.start("test", "test") - return client.proxy() + return client.start("test", "test").proxy } aliceProxy = connectRpc(alice) val rpcClientsToNotaries = notaries.map(::connectRpc) diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt index 5c5ac2edac..64a3ee71a6 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt @@ -2,7 +2,7 @@ package net.corda.services.messaging import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_REQUESTS_QUEUE +import net.corda.nodeapi.RPCApi import net.corda.testing.messaging.SimpleMQClient import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.ActiveMQClusterSecurityException @@ -24,7 +24,7 @@ class MQSecurityAsNodeTest : MQSecurityTest() { @Test fun `send message to RPC requests address`() { - assertSendAttackFails(RPC_REQUESTS_QUEUE) + assertSendAttackFails(RPCApi.RPC_SERVER_QUEUE_NAME) } @Test diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsRPCTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsRPCTest.kt index 6eab75c5f2..ca458158ed 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsRPCTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsRPCTest.kt @@ -1,6 +1,7 @@ package net.corda.services.messaging import net.corda.nodeapi.User +import net.corda.testing.configureTestSSL import net.corda.testing.messaging.SimpleMQClient import org.apache.activemq.artemis.api.core.ActiveMQSecurityException import org.assertj.core.api.Assertions.assertThatExceptionOfType @@ -23,14 +24,13 @@ class MQSecurityAsRPCTest : MQSecurityTest() { override val extraRPCUsers = listOf(User("evil", "pass", permissions = emptySet())) override fun startAttacker(attacker: SimpleMQClient) { - attacker.loginToRPC(extraRPCUsers[0]) + attacker.start(extraRPCUsers[0].username, extraRPCUsers[0].password, false) } @Test fun `login to a ssl port as a RPC user`() { - val attacker = clientTo(alice.configuration.p2pAddress) assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy { - attacker.loginToRPC(extraRPCUsers[0], enableSSL = true) + loginToRPC(alice.configuration.p2pAddress, extraRPCUsers[0], configureTestSSL()) } } } diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt index db794bba72..e9d4d1ea02 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt @@ -2,7 +2,7 @@ package net.corda.services.messaging import co.paralleluniverse.fibers.Suspendable import com.google.common.net.HostAndPort -import net.corda.client.rpc.CordaRPCClientImpl +import net.corda.client.rpc.CordaRPCClient import net.corda.core.crypto.Party import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.toBase58String @@ -10,19 +10,16 @@ import net.corda.core.flows.FlowLogic import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.random63BitValue -import net.corda.core.seconds import net.corda.core.utilities.ALICE import net.corda.core.utilities.BOB import net.corda.core.utilities.unwrap import net.corda.node.internal.Node -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.CLIENTS_PREFIX import net.corda.nodeapi.ArtemisMessagingComponent.Companion.INTERNAL_PREFIX import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NETWORK_MAP_QUEUE import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NOTIFICATIONS_ADDRESS import net.corda.nodeapi.ArtemisMessagingComponent.Companion.P2P_QUEUE import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEERS_PREFIX -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_QUEUE_REMOVALS_QUEUE -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_REQUESTS_QUEUE +import net.corda.nodeapi.RPCApi import net.corda.nodeapi.User import net.corda.nodeapi.config.SSLConfiguration import net.corda.testing.configureTestSSL @@ -36,7 +33,6 @@ import org.junit.After import org.junit.Before import org.junit.Test import java.util.* -import java.util.concurrent.locks.ReentrantLock import kotlin.test.assertEquals /** @@ -108,7 +104,7 @@ abstract class MQSecurityTest : NodeBasedTest() { @Test fun `consume message from RPC requests queue`() { - assertConsumeAttackFails(RPC_REQUESTS_QUEUE) + assertConsumeAttackFails(RPCApi.RPC_SERVER_QUEUE_NAME) } @Test @@ -119,21 +115,16 @@ abstract class MQSecurityTest : NodeBasedTest() { @Test fun `create queue for valid RPC user`() { - val user1Queue = "$CLIENTS_PREFIX${rpcUser.username}.rpc.${random63BitValue()}" + val user1Queue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${rpcUser.username}.${random63BitValue()}" assertTempQueueCreationAttackFails(user1Queue) } @Test fun `create queue for invalid RPC user`() { - val invalidRPCQueue = "$CLIENTS_PREFIX${random63BitValue()}.rpc.${random63BitValue()}" + val invalidRPCQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${random63BitValue()}.${random63BitValue()}" assertTempQueueCreationAttackFails(invalidRPCQueue) } - @Test - fun `consume message from RPC queue removals queue`() { - assertConsumeAttackFails(RPC_QUEUE_REMOVALS_QUEUE) - } - @Test fun `send message to notifications address`() { assertSendAttackFails(NOTIFICATIONS_ADDRESS) @@ -157,22 +148,16 @@ abstract class MQSecurityTest : NodeBasedTest() { return client } - fun loginToRPC(target: HostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): SimpleMQClient { - val client = clientTo(target, sslConfiguration) - client.loginToRPC(rpcUser) - return client - } - - fun SimpleMQClient.loginToRPC(rpcUser: User, enableSSL: Boolean = false): CordaRPCOps { - start(rpcUser.username, rpcUser.password, enableSSL) - val clientImpl = CordaRPCClientImpl(session, ReentrantLock(), rpcUser.username) - return clientImpl.proxyFor(CordaRPCOps::class.java, timeout = 1.seconds) + fun loginToRPC(target: HostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): CordaRPCOps { + return CordaRPCClient(target, sslConfiguration).start(rpcUser.username, rpcUser.password).proxy } fun loginToRPCAndGetClientQueue(): String { - val rpcClient = loginToRPC(alice.configuration.rpcAddress!!, rpcUser) - val clientQueueQuery = SimpleString("$CLIENTS_PREFIX${rpcUser.username}.rpc.*") - return rpcClient.session.addressQuery(clientQueueQuery).queueNames.single().toString() + loginToRPC(alice.configuration.rpcAddress!!, rpcUser) + val clientQueueQuery = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${rpcUser.username}.*") + val client = clientTo(alice.configuration.rpcAddress!!) + client.start(rpcUser.username, rpcUser.password, false) + return client.session.addressQuery(clientQueueQuery).queueNames.single().toString() } fun assertAllQueueCreationAttacksFail(queue: String) { diff --git a/node/src/main/kotlin/net/corda/node/driver/Driver.kt b/node/src/main/kotlin/net/corda/node/driver/Driver.kt index 0361d508bc..a0df2aff13 100644 --- a/node/src/main/kotlin/net/corda/node/driver/Driver.kt +++ b/node/src/main/kotlin/net/corda/node/driver/Driver.kt @@ -7,13 +7,10 @@ import com.google.common.util.concurrent.* import com.typesafe.config.Config import com.typesafe.config.ConfigRenderOptions import net.corda.client.rpc.CordaRPCClient -import net.corda.core.ThreadBox +import net.corda.core.* import net.corda.core.crypto.Party import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.commonName -import net.corda.core.div -import net.corda.core.flatMap -import net.corda.core.map import net.corda.core.messaging.CordaRPCOps import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceInfo @@ -40,6 +37,7 @@ import java.io.File import java.net.* import java.nio.file.Path import java.nio.file.Paths +import java.time.Duration import java.time.Instant import java.time.ZoneOffset.UTC import java.time.format.DateTimeFormatter @@ -110,6 +108,26 @@ interface DriverDSLExposedInterface { fun startNetworkMapService() fun waitForAllNodesToFinish() + + /** + * Polls a function until it returns a non-null value. Note that there is no timeout on the polling. + * + * @param pollName A description of what is being polled. + * @param pollInterval The interval of polling. + * @param warnCount The number of polls after the Driver gives a warning. + * @param check The function being polled. + * @return A future that completes with the non-null value [check] has returned. + */ + fun pollUntilNonNull(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A?): ListenableFuture + /** + * Polls the given function until it returns true. + * @see pollUntilNonNull + */ + fun pollUntilTrue(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> Boolean): ListenableFuture { + return pollUntilNonNull(pollName, pollInterval, warnCount) { if (check()) Unit else null } + } + + val shutdownManager: ShutdownManager } interface DriverDSLInternalInterface : DriverDSLExposedInterface { @@ -216,15 +234,13 @@ fun genericD var shutdownHook: Thread? = null try { driverDsl.start() - val returnValue = dsl(coerce(driverDsl)) shutdownHook = Thread({ driverDsl.shutdown() }) Runtime.getRuntime().addShutdownHook(shutdownHook) - return returnValue + return dsl(coerce(driverDsl)) } catch (exception: Throwable) { - println("Driver shutting down because of exception $exception") - exception.printStackTrace() + log.error("Driver shutting down because of exception", exception) throw exception } finally { driverDsl.shutdown() @@ -271,7 +287,7 @@ fun addressMustNotBeBound(executorService: ScheduledExecutorService, hostAndPort fun poll( executorService: ScheduledExecutorService, pollName: String, - pollIntervalMs: Long = 500, + pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A? ): ListenableFuture { @@ -286,7 +302,7 @@ fun poll( executorService.schedule(task@ { counter++ if (counter == warnCount) { - log.warn("Been polling $pollName for ${pollIntervalMs * warnCount / 1000.0} seconds...") + log.warn("Been polling $pollName for ${pollInterval.seconds * warnCount} seconds...") } val result = try { check() @@ -299,7 +315,7 @@ fun poll( } else { resultFuture.set(result) } - }, pollIntervalMs, MILLISECONDS) + }, pollInterval.toMillis(), MILLISECONDS) } schedulePoll() return resultFuture @@ -326,7 +342,13 @@ class ShutdownManager(private val executorService: ExecutorService) { /** Could not get all of them, collect what we have */ shutdownFutures.filter { it.isDone }.map { it.get() } } - shutdowns.reversed().forEach { it() } + shutdowns.reversed().forEach { shutdown -> + try { + shutdown() + } catch (throwable: Throwable) { + log.error("Exception while shutting down", throwable) + } + } } fun registerShutdown(shutdown: ListenableFuture<() -> Unit>) { @@ -335,6 +357,7 @@ class ShutdownManager(private val executorService: ExecutorService) { registeredShutdowns.add(shutdown) } } + fun registerShutdown(shutdown: () -> Unit) = registerShutdown(Futures.immediateFuture(shutdown)) fun registerProcessShutdown(processFuture: ListenableFuture) { val processShutdown = processFuture.map { process -> @@ -368,8 +391,10 @@ class DriverDSL( ) : DriverDSLInternalInterface { private val networkMapLegalName = DUMMY_MAP.name private val networkMapAddress = portAllocation.nextHostAndPort() - val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(2)) - val shutdownManager = ShutdownManager(executorService) + val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator( + Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build()) + ) + override val shutdownManager = ShutdownManager(executorService) class State { val processes = ArrayList>() @@ -401,9 +426,6 @@ class DriverDSL( override fun shutdown() { shutdownManager.shutdown() - - // Check that we shut down properly - addressMustNotBeBound(executorService, networkMapAddress).get() executorService.shutdown() } @@ -411,8 +433,9 @@ class DriverDSL( val client = CordaRPCClient(nodeAddress, sslConfig) return poll(executorService, "for RPC connection") { try { - client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) - return@poll client.proxy() + val connection = client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) + shutdownManager.registerShutdown { connection.close() } + return@poll connection.proxy } catch(e: Exception) { log.error("Exception $e, Retrying RPC connection at $nodeAddress") null @@ -566,6 +589,12 @@ class DriverDSL( registerProcess(startNode) } + override fun pollUntilNonNull(pollName: String, pollInterval: Duration, warnCount: Int, check: () -> A?): ListenableFuture { + val pollFuture = poll(executorService, pollName, pollInterval, warnCount, check) + shutdownManager.registerShutdown { pollFuture.cancel(true) } + return pollFuture + } + companion object { val name = arrayOf( ALICE.name, diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 690d8fe13f..27b61df002 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -97,7 +97,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, CashExitFlow::class.java to setOf(Amount::class.java, PartyAndReference::class.java), CashIssueFlow::class.java to setOf(Amount::class.java, OpaqueBytes::class.java, Party::class.java), CashPaymentFlow::class.java to setOf(Amount::class.java, Party::class.java), - FinalityFlow::class.java to emptySet(), + FinalityFlow::class.java to setOf(LinkedHashSet::class.java), ContractUpgradeFlow::class.java to emptySet() ) } diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index dfa2ca7a03..f853305e6b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -21,11 +21,11 @@ import net.corda.core.node.services.vault.Sort import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.messaging.getRpcContext import net.corda.node.services.messaging.requirePermission import net.corda.node.services.startFlowPermission import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.utilities.transaction -import net.corda.nodeapi.CURRENT_RPC_USER import org.bouncycastle.asn1.x500.X500Name import org.jetbrains.exposed.sql.Database import rx.Observable @@ -121,8 +121,9 @@ class CordaRPCOpsImpl( // TODO: Check that this flow is annotated as being intended for RPC invocation override fun startTrackedFlowDynamic(logicType: Class>, vararg args: Any?): FlowProgressHandle { - requirePermission(startFlowPermission(logicType)) - val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username) + val rpcContext = getRpcContext() + rpcContext.requirePermission(startFlowPermission(logicType)) + val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username) val stateMachine = services.invokeFlowAsync(logicType, currentUser, *args) return FlowProgressHandleImpl( id = stateMachine.id, @@ -133,8 +134,9 @@ class CordaRPCOpsImpl( // TODO: Check that this flow is annotated as being intended for RPC invocation override fun startFlowDynamic(logicType: Class>, vararg args: Any?): FlowHandle { - requirePermission(startFlowPermission(logicType)) - val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username) + val rpcContext = getRpcContext() + rpcContext.requirePermission(startFlowPermission(logicType)) + val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username) val stateMachine = services.invokeFlowAsync(logicType, currentUser, *args) return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture) } diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 515673622e..16767dee36 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -256,7 +256,7 @@ class Node(override val configuration: FullNodeConfiguration, /** Starts a blocking event loop for message dispatch. */ fun run() { - (net as NodeMessagingClient).run() + (net as NodeMessagingClient).run(messageBroker!!.serverControl) } // TODO: Do we really need setup? diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 54ab843746..4f6f8181d5 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -21,10 +21,10 @@ import net.corda.node.services.messaging.NodeLoginModule.Companion.PEER_ROLE import net.corda.node.services.messaging.NodeLoginModule.Companion.RPC_ROLE import net.corda.node.services.messaging.NodeLoginModule.Companion.VERIFIER_ROLE import net.corda.nodeapi.* -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.CLIENTS_PREFIX import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.core.config.BridgeConfiguration import org.apache.activemq.artemis.core.config.Configuration import org.apache.activemq.artemis.core.config.CoreQueueConfiguration @@ -37,6 +37,8 @@ import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactor import org.apache.activemq.artemis.core.security.Role import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl +import org.apache.activemq.artemis.core.settings.impl.AddressFullMessagePolicy +import org.apache.activemq.artemis.core.settings.impl.AddressSettings import org.apache.activemq.artemis.spi.core.remoting.* import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import org.apache.activemq.artemis.spi.core.security.jaas.CertificateCallback @@ -97,6 +99,7 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, private val mutex = ThreadBox(InnerState()) private lateinit var activeMQServer: ActiveMQServer + val serverControl: ActiveMQServerControl get() = activeMQServer.activeMQServerControl private val _networkMapConnectionFuture = config.networkMapService?.let { SettableFuture.create() } /** * A [ListenableFuture] which completes when the server successfully connects to the network map node. If a @@ -185,10 +188,19 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, // Create an RPC queue: this will service locally connected clients only (not via a bridge) and those // clients must have authenticated. We could use a single consumer for everything and perhaps we should, // but these queues are not worth persisting. - queueConfig(RPC_REQUESTS_QUEUE, durable = false), - // The custom name for the queue is intentional - we may wish other things to subscribe to the - // NOTIFICATIONS_ADDRESS with different filters in future - queueConfig(RPC_QUEUE_REMOVALS_QUEUE, address = NOTIFICATIONS_ADDRESS, filter = "_AMQ_NotifType = 1", durable = false) + queueConfig(RPCApi.RPC_SERVER_QUEUE_NAME, durable = false), + queueConfig( + name = RPCApi.RPC_CLIENT_BINDING_REMOVALS, + address = NOTIFICATIONS_ADDRESS, + filter = RPCApi.RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION, + durable = false + ) + ) + addressesSettings = mapOf( + "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.#" to AddressSettings().apply { + maxSizeBytes = 10L * MAX_FILE_SIZE + addressFullMessagePolicy = AddressFullMessagePolicy.FAIL + } ) configureAddressSecurity() } @@ -213,16 +225,16 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, val nodeInternalRole = Role(NODE_ROLE, true, true, true, true, true, true, true, true) securityRoles["$INTERNAL_PREFIX#"] = setOf(nodeInternalRole) // Do not add any other roles here as it's only for the node securityRoles[P2P_QUEUE] = setOf(nodeInternalRole, restrictedRole(PEER_ROLE, send = true)) - securityRoles[RPC_REQUESTS_QUEUE] = setOf(nodeInternalRole, restrictedRole(RPC_ROLE, send = true)) + securityRoles[RPCApi.RPC_SERVER_QUEUE_NAME] = setOf(nodeInternalRole, restrictedRole(RPC_ROLE, send = true)) // TODO remove the NODE_USER role once the webserver doesn't need it - securityRoles["$CLIENTS_PREFIX$NODE_USER.rpc.*"] = setOf(nodeInternalRole) + securityRoles["${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$NODE_USER.#"] = setOf(nodeInternalRole) for ((username) in userService.users) { - securityRoles["$CLIENTS_PREFIX$username.rpc.*"] = setOf( + securityRoles["${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.#"] = setOf( nodeInternalRole, - restrictedRole("$CLIENTS_PREFIX$username", consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true)) + restrictedRole("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username", consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true)) } securityRoles[VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, consume = true)) - securityRoles["${VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX}.*"] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, send = true)) + securityRoles["${VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX}.#"] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, send = true)) } private fun restrictedRole(name: String, send: Boolean = false, consume: Boolean = false, createDurableQueue: Boolean = false, @@ -629,7 +641,7 @@ class NodeLoginModule : LoginModule { throw FailedLoginException("Password for user $username does not match") } principals += RolePrincipal(RPC_ROLE) // This enables the RPC client to send requests - principals += RolePrincipal("$CLIENTS_PREFIX$username") // This enables the RPC client to receive responses + principals += RolePrincipal("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username") // This enables the RPC client to receive responses return username } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt index ef4ca7668d..d879e0b8ce 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt @@ -11,7 +11,6 @@ import net.corda.core.node.VersionInfo import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.TransactionVerifierService import net.corda.core.random63BitValue -import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.opaque import net.corda.core.success import net.corda.core.transactions.LedgerTransaction @@ -25,10 +24,7 @@ import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.transactions.OutOfProcessTransactionVerifierService import net.corda.node.utilities.* -import net.corda.nodeapi.ArtemisMessagingComponent -import net.corda.nodeapi.ArtemisTcpTransport -import net.corda.nodeapi.ConnectionDirection -import net.corda.nodeapi.VerifierApi +import net.corda.nodeapi.* import net.corda.nodeapi.VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME import net.corda.nodeapi.VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException @@ -36,6 +32,7 @@ import org.apache.activemq.artemis.api.core.Message.* import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE +import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.bouncycastle.asn1.x500.X500Name import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.ResultRow @@ -71,7 +68,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, val versionInfo: VersionInfo, val serverHostPort: HostAndPort, val myIdentity: PublicKey?, - val nodeExecutor: AffinityExecutor, + val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor, val database: Database, val networkMapRegistrationFuture: ListenableFuture, val monitoringService: MonitoringService @@ -100,11 +97,9 @@ class NodeMessagingClient(override val config: NodeConfiguration, var producer: ClientProducer? = null var p2pConsumer: ClientConsumer? = null var session: ClientSession? = null - var clientFactory: ClientSessionFactory? = null - var rpcDispatcher: RPCDispatcher? = null + var sessionFactory: ClientSessionFactory? = null + var rpcServer: RPCServer? = null // Consumer for inbound client RPC messages. - var rpcConsumer: ClientConsumer? = null - var rpcNotificationConsumer: ClientConsumer? = null var verificationResponseConsumer: ClientConsumer? = null } @@ -163,18 +158,19 @@ class NodeMessagingClient(override val config: NodeConfiguration, val tcpTransport = ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), serverHostPort, config) val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport) locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE - clientFactory = locator.createSessionFactory() + sessionFactory = locator.createSessionFactory() // Login using the node username. The broker will authentiate us as its node (as opposed to another peer) // using our TLS certificate. // Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer // size of 1MB is acknowledged. - val session = clientFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) + val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) this.session = session session.start() // Create a general purpose producer. - producer = session.createProducer() + val producer = session.createProducer() + this.producer = producer // Create a queue, consumer and producer for handling P2P network messages. p2pConsumer = makeP2PConsumer(session, true) @@ -190,9 +186,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, } } - rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE) - rpcNotificationConsumer = session.createConsumer(RPC_QUEUE_REMOVALS_QUEUE) - rpcDispatcher = createRPCDispatcher(rpcOps, userService, config.myLegalName) + rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, userService, config.myLegalName) fun checkVerifierCount() { if (session.queueQuery(SimpleString(VERIFICATION_REQUESTS_QUEUE_NAME)).consumerCount == 0) { @@ -269,12 +263,12 @@ class NodeMessagingClient(override val config: NodeConfiguration, return true } - private fun runPreNetworkMap() { + private fun runPreNetworkMap(serverControl: ActiveMQServerControl) { val consumer = state.locked { check(started) { "start must be called first" } check(!running) { "run can't be called twice" } running = true - rpcDispatcher!!.start(rpcConsumer!!, rpcNotificationConsumer!!, nodeExecutor) + rpcServer!!.start(serverControl) (verifierService as? OutOfProcessTransactionVerifierService)?.start(verificationResponseConsumer!!) p2pConsumer!! } @@ -300,9 +294,9 @@ class NodeMessagingClient(override val config: NodeConfiguration, * we get our network map fetch response. At that point the filtering consumer is closed and we proceed to the second loop and * consume all messages via a new consumer without a filter applied. */ - fun run() { + fun run(serverControl: ActiveMQServerControl) { // Build the network map. - runPreNetworkMap() + runPreNetworkMap(serverControl) // Process everything else once we have the network map. runPostNetworkMap() shutdownLatch.countDown() @@ -404,17 +398,13 @@ class NodeMessagingClient(override val config: NodeConfiguration, // Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests. if (running) { state.locked { - rpcConsumer?.close() - rpcConsumer = null - rpcNotificationConsumer?.close() - rpcNotificationConsumer = null producer?.close() producer = null // Ensure any trailing messages are committed to the journal session!!.commit() // Closing the factory closes all the sessions it produced as well. - clientFactory!!.close() - clientFactory = null + sessionFactory!!.close() + sessionFactory = null } } } @@ -547,22 +537,6 @@ class NodeMessagingClient(override val config: NodeConfiguration, } } - private fun createRPCDispatcher(ops: RPCOps, userService: RPCUserService, nodeLegalName: X500Name): RPCDispatcher = - object : RPCDispatcher(ops, userService, nodeLegalName) { - override fun send(data: SerializedBytes<*>, toAddress: String) { - messagingExecutor.fetchFrom { - state.locked { - val msg = session!!.createMessage(false).apply { - writeBodyBufferBytes(data.bytes) - // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) - } - producer!!.send(toAddress, msg) - } - } - } - } - private fun createOutOfProcessVerifierService(): TransactionVerifierService { return object : OutOfProcessTransactionVerifierService(monitoringService) { override fun sendRequest(nonce: Long, transaction: LedgerTransaction) { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCDispatcher.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCDispatcher.kt deleted file mode 100644 index 4e8e7de62a..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCDispatcher.kt +++ /dev/null @@ -1,219 +0,0 @@ -package net.corda.node.services.messaging - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.KryoException -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool -import com.google.common.annotations.VisibleForTesting -import com.google.common.collect.HashMultimap -import net.corda.core.ErrorOr -import net.corda.core.crypto.commonName -import net.corda.core.messaging.RPCOps -import net.corda.core.messaging.RPCReturnsObservables -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize -import net.corda.core.utilities.debug -import net.corda.node.services.RPCUserService -import net.corda.node.utilities.AffinityExecutor -import net.corda.nodeapi.* -import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER -import org.apache.activemq.artemis.api.core.Message -import org.apache.activemq.artemis.api.core.client.ClientConsumer -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.bouncycastle.asn1.x500.X500Name -import rx.Notification -import rx.Observable -import rx.Subscription -import java.lang.reflect.InvocationTargetException -import java.util.concurrent.atomic.AtomicInteger - -/** - * Intended to service transient clients only (not p2p nodes) for short-lived, transient request/response pairs. - * If you need robustness, this is the wrong system. If you don't want a response, this is probably the - * wrong system (you could just send a message). If you want complex customisation of how requests/responses - * are handled, this is probably the wrong system. - */ -// TODO remove the nodeLegalName parameter once the webserver doesn't need special privileges -abstract class RPCDispatcher(val ops: RPCOps, val userService: RPCUserService, val nodeLegalName: X500Name) { - // Throw an exception if there are overloaded methods - private val methodTable = ops.javaClass.declaredMethods.groupBy { it.name }.mapValues { it.value.single() } - - private val queueToSubscription = HashMultimap.create() - - private val handleCounter = AtomicInteger() - - // Created afresh for every RPC that is annotated as returning observables. Every time an observable is - // encountered either in the RPC response or in an object graph that is being emitted by one of those - // observables, the handle counter is incremented and the server-side observable is subscribed to. The - // materialized observations are then sent to the queue the client created where they can be picked up. - // - // When the observables are deserialised on the client side, the handle is read from the byte stream and - // the queue is filtered to extract just those observations. - class ObservableSerializer : Serializer>() { - private fun toQName(kryo: Kryo): String = kryo.context[RPCKryoQNameKey] as String - private fun toDispatcher(kryo: Kryo): RPCDispatcher = kryo.context[RPCKryoDispatcherKey] as RPCDispatcher - - override fun read(kryo: Kryo, input: Input, type: Class>): Observable { - throw UnsupportedOperationException("not implemented") - } - - override fun write(kryo: Kryo, output: Output, obj: Observable) { - val qName = toQName(kryo) - val dispatcher = toDispatcher(kryo) - val handle = dispatcher.handleCounter.andIncrement - output.writeInt(handle, true) - // Observables can do three kinds of callback: "next" with a content object, "completed" and "error". - // Materializing the observable converts these three kinds of callback into a single stream of objects - // representing what happened, which is useful for us to send over the wire. - val subscription = obj.materialize().subscribe { materialised: Notification -> - val newKryo = createRPCKryoForSerialization(qName, dispatcher) - val bits = try { - MarshalledObservation(handle, materialised).serialize(newKryo) - } finally { - releaseRPCKryoForSerialization(newKryo) - } - rpcLog.debug("RPC sending observation: $materialised") - dispatcher.send(bits, qName) - } - synchronized(dispatcher.queueToSubscription) { - dispatcher.queueToSubscription.put(qName, subscription) - } - } - } - - fun dispatch(msg: ClientRPCRequestMessage) { - val (argsBytes, replyTo, observationsTo, methodName) = msg - - val response: ErrorOr = ErrorOr.catch { - val method = methodTable[methodName] ?: throw RPCException("Received RPC for unknown method $methodName - possible client/server version skew?") - if (method.isAnnotationPresent(RPCReturnsObservables::class.java) && observationsTo == null) - throw RPCException("Received RPC without any destination for observations, but the RPC returns observables") - - val kryo = createRPCKryoForSerialization(observationsTo, this) - val args = try { - argsBytes.deserialize(kryo) - } finally { - releaseRPCKryoForSerialization(kryo) - } - - rpcLog.debug { "-> RPC -> $methodName(${args.joinToString()}) [reply to $replyTo]" } - - try { - method.invoke(ops, *args) - } catch (e: InvocationTargetException) { - throw e.cause!! - } - } - rpcLog.debug { "<- RPC <- $methodName = $response " } - - // Serialise, or send back a simple serialised ErrorOr structure if we couldn't do it. - val kryo = createRPCKryoForSerialization(observationsTo, this) - val responseBits = try { - response.serialize(kryo) - } catch (e: KryoException) { - rpcLog.error("Failed to respond to inbound RPC $methodName", e) - ErrorOr.of(e).serialize(kryo) - } finally { - releaseRPCKryoForSerialization(kryo) - } - send(responseBits, replyTo) - } - - abstract fun send(data: SerializedBytes<*>, toAddress: String) - - fun start(rpcConsumer: ClientConsumer, rpcNotificationConsumer: ClientConsumer?, onExecutor: AffinityExecutor) { - rpcNotificationConsumer?.setMessageHandler { msg -> - val qName = msg.getStringProperty("_AMQ_RoutingName") - val subscriptions = synchronized(queueToSubscription) { - queueToSubscription.removeAll(qName) - } - if (subscriptions.isNotEmpty()) { - rpcLog.debug("Observable queue was deleted, unsubscribing: $qName") - subscriptions.forEach { it.unsubscribe() } - } - } - rpcConsumer.setMessageHandler { msg -> - msg.acknowledge() - // All RPCs run on the main server thread, in order to avoid running concurrently with - // potentially state changing requests from other nodes and each other. If we need to - // give better latency to client RPCs in future we could use an executor that supports - // job priorities. - onExecutor.execute { - try { - val rpcMessage = msg.toRPCRequestMessage() - CURRENT_RPC_USER.set(rpcMessage.user) - dispatch(rpcMessage) - } catch(e: RPCException) { - rpcLog.warn("Received malformed client RPC message: ${e.message}") - rpcLog.trace("RPC exception", e) - } catch(e: Throwable) { - rpcLog.error("Uncaught exception when dispatching client RPC", e) - } finally { - CURRENT_RPC_USER.remove() - } - } - } - } - - private fun ClientMessage.requiredString(name: String): String { - return getStringProperty(name) ?: throw RPCException("missing $name property") - } - - /** Convert an Artemis [ClientMessage] to a MQ-neutral [ClientRPCRequestMessage]. */ - private fun ClientMessage.toRPCRequestMessage(): ClientRPCRequestMessage { - val user = getUser(this) - val replyTo = getReturnAddress(user, ClientRPCRequestMessage.REPLY_TO, true)!! - val observationsTo = getReturnAddress(user, ClientRPCRequestMessage.OBSERVATIONS_TO, false) - val argBytes = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) } - if (argBytes.isEmpty()) { - throw RPCException("empty serialized args") - } - val methodName = requiredString(ClientRPCRequestMessage.METHOD_NAME) - return ClientRPCRequestMessage(SerializedBytes(argBytes), replyTo, observationsTo, methodName, user) - } - - // TODO remove this User once webserver doesn't need it - private val nodeUser = User(NODE_USER, NODE_USER, setOf()) - - @VisibleForTesting - protected open fun getUser(message: ClientMessage): User { - val validatedUser = message.requiredString(Message.HDR_VALIDATED_USER.toString()) - val rpcUser = userService.getUser(validatedUser) - if (rpcUser != null) { - return rpcUser - } else { - try { - if (X500Name(validatedUser) == nodeLegalName) { - return nodeUser - } - } catch (ex: IllegalArgumentException) { - // Just means the two can't be compared, treat as no match - } - throw IllegalArgumentException("Validated user '$validatedUser' is not an RPC user nor the NODE user") - } - } - - private fun ClientMessage.getReturnAddress(user: User, property: String, required: Boolean): String? { - return if (containsProperty(property)) { - "${ArtemisMessagingComponent.CLIENTS_PREFIX}${user.username}.rpc.${getLongProperty(property)}" - } else { - if (required) throw RPCException("missing $property property") else null - } - } -} - -private val rpcSerKryoPool = KryoPool.Builder { RPCKryo(RPCDispatcher.ObservableSerializer()) }.build() - -fun createRPCKryoForSerialization(qName: String? = null, dispatcher: RPCDispatcher? = null): Kryo { - val kryo = rpcSerKryoPool.borrow() - kryo.context.put(RPCKryoQNameKey, qName) - kryo.context.put(RPCKryoDispatcherKey, dispatcher) - return kryo -} - -fun releaseRPCKryoForSerialization(kryo: Kryo) { - rpcSerKryoPool.release(kryo) -} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt new file mode 100644 index 0000000000..cb3f44ab28 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -0,0 +1,346 @@ +package net.corda.node.services.messaging + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import com.google.common.cache.Cache +import com.google.common.cache.CacheBuilder +import com.google.common.cache.RemovalListener +import com.google.common.collect.HashMultimap +import com.google.common.collect.Multimaps +import com.google.common.collect.SetMultimap +import com.google.common.util.concurrent.ThreadFactoryBuilder +import net.corda.core.ErrorOr +import net.corda.core.crypto.commonName +import net.corda.core.messaging.RPCOps +import net.corda.core.random63BitValue +import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.utilities.LazyStickyPool +import net.corda.core.utilities.debug +import net.corda.core.utilities.loggerFor +import net.corda.node.services.RPCUserService +import net.corda.nodeapi.* +import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER +import org.apache.activemq.artemis.api.core.Message +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE +import org.apache.activemq.artemis.api.core.client.ClientConsumer +import org.apache.activemq.artemis.api.core.client.ClientMessage +import org.apache.activemq.artemis.api.core.client.ServerLocator +import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl +import org.apache.activemq.artemis.api.core.management.CoreNotificationType +import org.apache.activemq.artemis.api.core.management.ManagementHelper +import org.bouncycastle.asn1.x500.X500Name +import rx.Notification +import rx.Observable +import rx.Subscriber +import rx.Subscription +import java.lang.reflect.InvocationTargetException +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledFuture +import java.util.concurrent.TimeUnit + +data class RPCServerConfiguration( + /** The number of threads to use for handling RPC requests */ + val rpcThreadPoolSize: Int, + /** The number of consumers to handle incoming messages */ + val consumerPoolSize: Int, + /** The maximum number of producers to create to handle outgoing messages */ + val producerPoolBound: Int, + /** The interval of subscription reaping in milliseconds */ + val reapIntervalMs: Long +) { + companion object { + val default = RPCServerConfiguration( + rpcThreadPoolSize = 4, + consumerPoolSize = 2, + producerPoolBound = 4, + reapIntervalMs = 1000 + ) + } +} + +/** + * The [RPCServer] implements the complement of [RPCClient]. When an RPC request arrives it dispatches to the + * corresponding function in [ops]. During serialisation of the reply (and later observations) the server subscribes to + * each Observable it encounters and captures the client address to associate with these Observables. Later it uses this + * address to forward observations arriving on the Observables. + * + * The way this is done is similar to that in [RPCClient], we use Kryo and add a context to stores the subscription map. + */ +class RPCServer( + private val ops: RPCOps, + private val rpcServerUsername: String, + private val rpcServerPassword: String, + private val serverLocator: ServerLocator, + private val userService: RPCUserService, + private val nodeLegalName: X500Name, + private val rpcConfiguration: RPCServerConfiguration = RPCServerConfiguration.default +) { + private companion object { + val log = loggerFor() + val kryoPool = KryoPool.Builder { RPCKryo(RpcServerObservableSerializer) }.build() + } + // The methodname->Method map to use for dispatching. + private val methodTable = ops.javaClass.declaredMethods.groupBy { it.name }.mapValues { it.value.single() } + // The observable subscription mapping. + private val observableMap = createObservableSubscriptionMap() + // A mapping from client addresses to IDs of associated Observables + private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create()) + // The scheduled reaper handle. + private lateinit var reaperScheduledFuture: ScheduledFuture<*> + + private val observationSendExecutor = Executors.newFixedThreadPool( + 1, + ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build() + ) + + private val rpcExecutor = Executors.newScheduledThreadPool( + rpcConfiguration.rpcThreadPoolSize, + ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build() + ) + + private val reaperExecutor = Executors.newScheduledThreadPool( + 1, + ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build() + ) + + private val sessionAndConsumers = ArrayList(rpcConfiguration.consumerPoolSize) + private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) { + val sessionFactory = serverLocator.createSessionFactory() + val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + session.start() + ArtemisProducer(sessionFactory, session, session.createProducer()) + } + private lateinit var clientBindingRemovalConsumer: ClientConsumer + private lateinit var serverControl: ActiveMQServerControl + + private fun createObservableSubscriptionMap(): ObservableSubscriptionMap { + val onObservableRemove = RemovalListener { + log.debug { "Unsubscribing from Observable with id ${it.key} because of ${it.cause}" } + it.value.subscription.unsubscribe() + } + return CacheBuilder.newBuilder().removalListener(onObservableRemove).build() + } + + fun start(activeMqServerControl: ActiveMQServerControl) { + log.info("Starting RPC server with configuration $rpcConfiguration") + reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( + this::reapSubscriptions, + rpcConfiguration.reapIntervalMs, + rpcConfiguration.reapIntervalMs, + TimeUnit.MILLISECONDS + ) + for (i in 1 .. rpcConfiguration.consumerPoolSize) { + val sessionFactory = serverLocator.createSessionFactory() + val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME) + consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler) + session.start() + sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer)) + } + clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS) + clientBindingRemovalConsumer.setMessageHandler(this::bindingRemovalArtemisMessageHandler) + serverControl = activeMqServerControl + } + + fun close() { + reaperScheduledFuture.cancel(false) + rpcExecutor.shutdownNow() + reaperExecutor.shutdownNow() + rpcExecutor.awaitTermination(500, TimeUnit.MILLISECONDS) + reaperExecutor.awaitTermination(500, TimeUnit.MILLISECONDS) + sessionAndConsumers.forEach { + it.consumer.close() + it.session.close() + it.sessionFactory.close() + } + observableMap.invalidateAll() + reapSubscriptions() + sessionAndProducerPool.close().forEach { + it.producer.close() + it.session.close() + it.sessionFactory.close() + } + } + + private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) { + val notificationType = artemisMessage.getStringProperty(ManagementHelper.HDR_NOTIFICATION_TYPE) + require(notificationType == CoreNotificationType.BINDING_REMOVED.name) + val clientAddress = artemisMessage.getStringProperty(ManagementHelper.HDR_ROUTING_NAME) + log.warn("Detected RPC client disconnect on address $clientAddress, scheduling for reaping") + invalidateClient(SimpleString(clientAddress)) + } + + // Note that this function operates on the *current* view of client observables. During invalidation further + // Observables may be serialised and thus registered. + private fun invalidateClient(clientAddress: SimpleString) { + val observableIds = clientAddressToObservables.removeAll(clientAddress) + observableMap.invalidateAll(observableIds) + } + + private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) { + val clientToServer = RPCApi.ClientToServer.fromClientMessage(kryoPool, artemisMessage) + log.debug { "-> RPC -> $clientToServer" } + when (clientToServer) { + is RPCApi.ClientToServer.RpcRequest -> { + val rpcContext = RpcContext( + currentUser = getUser(artemisMessage) + ) + rpcExecutor.submit { + val result = ErrorOr.catch { + try { + CURRENT_RPC_CONTEXT.set(rpcContext) + log.debug { "Calling ${clientToServer.methodName}" } + val method = methodTable[clientToServer.methodName] ?: + throw RPCException("Received RPC for unknown method ${clientToServer.methodName} - possible client/server version skew?") + method.invoke(ops, *clientToServer.arguments.toTypedArray()) + } finally { + CURRENT_RPC_CONTEXT.remove() + } + } + val resultWithExceptionUnwrapped = result.mapError { + if (it is InvocationTargetException) { + it.cause ?: RPCException("Caught InvocationTargetException without cause") + } else { + it + } + } + val reply = RPCApi.ServerToClient.RpcReply( + id = clientToServer.id, + result = resultWithExceptionUnwrapped + ) + val observableContext = ObservableContext( + clientToServer.id, + observableMap, + clientAddressToObservables, + clientToServer.clientAddress, + serverControl, + sessionAndProducerPool, + observationSendExecutor, + kryoPool + ) + observableContext.sendMessage(reply) + } + } + is RPCApi.ClientToServer.ObservablesClosed -> { + observableMap.invalidateAll(clientToServer.ids) + } + } + artemisMessage.acknowledge() + } + + private fun reapSubscriptions() { + observableMap.cleanUp() + } + + // TODO remove this User once webserver doesn't need it + private val nodeUser = User(NODE_USER, NODE_USER, setOf()) + private fun getUser(message: ClientMessage): User { + val validatedUser = message.getStringProperty(Message.HDR_VALIDATED_USER) ?: throw IllegalArgumentException("Missing validated user from the Artemis message") + val rpcUser = userService.getUser(validatedUser) + if (rpcUser != null) { + return rpcUser + } else if (X500Name(validatedUser) == nodeLegalName) { + return nodeUser + } else { + throw IllegalArgumentException("Validated user '$validatedUser' is not an RPC user nor the NODE user") + } + } +} + +@JvmField +internal val CURRENT_RPC_CONTEXT: ThreadLocal = ThreadLocal() +fun getRpcContext(): RpcContext = CURRENT_RPC_CONTEXT.get() + +/** + * @param currentUser This is available to RPC implementations to query the validated [User] that is calling it. Each + * user has a set of permissions they're entitled to which can be used to control access. + */ +data class RpcContext( + val currentUser: User +) + +class ObservableSubscription( + val subscription: Subscription +) + +typealias ObservableSubscriptionMap = Cache + +// We construct an observable context on each RPC request. If subsequently a nested Observable is +// encountered this same context is propagated by the instrumented KryoPool. This way all +// observations rooted in a single RPC will be muxed correctly. Note that the context construction +// itself is quite cheap. +class ObservableContext( + val rpcRequestId: RPCApi.RpcRequestId, + val observableMap: ObservableSubscriptionMap, + val clientAddressToObservables: SetMultimap, + val clientAddress: SimpleString, + val serverControl: ActiveMQServerControl, + val sessionAndProducerPool: LazyStickyPool, + val observationSendExecutor: ExecutorService, + kryoPool: KryoPool +) { + private companion object { + val log = loggerFor() + } + + private val kryoPoolWithObservableContext = RpcServerObservableSerializer.createPoolWithContext(kryoPool, this) + fun sendMessage(serverToClient: RPCApi.ServerToClient) { + try { + sessionAndProducerPool.run(rpcRequestId) { + val artemisMessage = it.session.createMessage(false) + serverToClient.writeToClientMessage(kryoPoolWithObservableContext, artemisMessage) + it.producer.send(clientAddress, artemisMessage) + log.debug("<- RPC <- $serverToClient") + } + } catch (throwable: Throwable) { + log.error("Failed to send message, kicking client. Message was $serverToClient", throwable) + serverControl.closeConsumerConnectionsForAddress(clientAddress.toString()) + } + } +} + +private object RpcServerObservableSerializer : Serializer>() { + private object RpcObservableContextKey + private val log = loggerFor() + + fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { + return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + } + + override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { + throw UnsupportedOperationException() + } + + override fun write(kryo: Kryo, output: Output, observable: Observable) { + val observableId = RPCApi.ObservableId(random63BitValue()) + val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext + output.writeLong(observableId.toLong, true) + val observableWithSubscription = ObservableSubscription( + // We capture [observableContext] in the subscriber. Note that all synchronisation/kryo borrowing + // must be done again within the subscriber + subscription = observable.materialize().subscribe( + object : Subscriber>() { + override fun onNext(observation: Notification) { + if (!isUnsubscribed) { + observableContext.observationSendExecutor.submit { + observableContext.sendMessage(RPCApi.ServerToClient.Observation(observableId, observation)) + } + } + } + override fun onError(exception: Throwable) { + log.error("onError called in materialize()d RPC Observable", exception) + } + override fun onCompleted() { + } + } + ) + ) + observableContext.clientAddressToObservables.put(observableContext.clientAddress, observableId) + observableContext.observableMap.put(observableId, observableWithSubscription) + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServerStructures.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServerStructures.kt index e1a3308867..ab546ea4d9 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServerStructures.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServerStructures.kt @@ -3,13 +3,11 @@ package net.corda.node.services.messaging import net.corda.nodeapi.ArtemisMessagingComponent -import net.corda.nodeapi.CURRENT_RPC_USER import net.corda.nodeapi.PermissionException /** Helper method which checks that the current RPC user is entitled for the given permission. Throws a [PermissionException] otherwise. */ -fun requirePermission(permission: String) { +fun RpcContext.requirePermission(permission: String) { // TODO remove the NODE_USER condition once webserver doesn't need it - val currentUser = CURRENT_RPC_USER.get() val currentUserPermissions = currentUser.permissions if (currentUser.username != ArtemisMessagingComponent.NODE_USER && currentUserPermissions.intersect(listOf(permission, "ALL")).isEmpty()) { throw PermissionException("User not permissioned for $permission, permissions are $currentUserPermissions") diff --git a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt index ade6394b49..603faa42ed 100644 --- a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt +++ b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt @@ -19,10 +19,11 @@ import net.corda.jackson.JacksonSupport import net.corda.jackson.StringToMethodCallParser import net.corda.node.internal.Node import net.corda.node.printBasicNodeInfo +import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT +import net.corda.node.services.messaging.RpcContext import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.utilities.ANSIProgressRenderer import net.corda.nodeapi.ArtemisMessagingComponent -import net.corda.nodeapi.CURRENT_RPC_USER import net.corda.nodeapi.User import org.crsh.command.InvocationContext import org.crsh.console.jline.JLineProcessor @@ -120,7 +121,7 @@ object InteractiveShell { InterruptHandler { jlineProcessor.interrupt() }.install() thread(name = "Command line shell processor", isDaemon = true) { // Give whoever has local shell access administrator access to the node. - CURRENT_RPC_USER.set(User(ArtemisMessagingComponent.NODE_USER, "", setOf())) + CURRENT_RPC_CONTEXT.set(RpcContext(User(ArtemisMessagingComponent.NODE_USER, "", setOf()))) Emoji.renderIfSupported { jlineProcessor.run() } diff --git a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt index afa13c3d95..36ed83a2ab 100644 --- a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt @@ -15,11 +15,12 @@ import net.corda.core.transactions.SignedTransaction import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow import net.corda.node.internal.CordaRPCOpsImpl +import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT +import net.corda.node.services.messaging.RpcContext import net.corda.node.services.network.NetworkMapService import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.node.utilities.transaction -import net.corda.nodeapi.CURRENT_RPC_USER import net.corda.nodeapi.PermissionException import net.corda.nodeapi.User import net.corda.testing.expect @@ -57,10 +58,10 @@ class CordaRPCOpsImplTest { aliceNode = network.createNode(networkMapAddress = networkMap.info.address) notaryNode = network.createNode(advertisedServices = ServiceInfo(SimpleNotaryService.type), networkMapAddress = networkMap.info.address) rpc = CordaRPCOpsImpl(aliceNode.services, aliceNode.smm, aliceNode.database) - CURRENT_RPC_USER.set(User("user", "pwd", permissions = setOf( + CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = setOf( startFlowPermission(), startFlowPermission() - ))) + )))) aliceNode.database.transaction { stateMachineUpdates = rpc.stateMachinesAndUpdates().second @@ -194,7 +195,7 @@ class CordaRPCOpsImplTest { @Test fun `cash command by user not permissioned for cash`() { - CURRENT_RPC_USER.set(User("user", "pwd", permissions = emptySet())) + CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = emptySet()))) assertThatExceptionOfType(PermissionException::class.java).isThrownBy { rpc.startFlow(::CashIssueFlow, Amount(100, USD), diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt index 58fd1e6601..41184f80f8 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt @@ -214,7 +214,7 @@ class ArtemisMessagingTests { receivedMessages.add(message) } // Run after the handlers are added, otherwise (some of) the messages get delivered and discarded / dead-lettered. - thread { messagingClient.run() } + thread { messagingClient.run(messagingServer!!.serverControl) } return messagingClient } diff --git a/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt b/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt index 98719e679b..6bcbb5f65c 100644 --- a/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt +++ b/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt @@ -25,14 +25,14 @@ class AttachmentDemoTest { ).getOrThrow() val senderThread = CompletableFuture.supplyAsync { - nodeA.rpcClientToNode().use(demoUser[0].username, demoUser[0].password) { - sender(this, numOfExpectedBytes) + nodeA.rpcClientToNode().start(demoUser[0].username, demoUser[0].password).use { + sender(it.proxy, numOfExpectedBytes) } }.exceptionally { it.printStackTrace() } val recipientThread = CompletableFuture.supplyAsync { - nodeB.rpcClientToNode().use(demoUser[0].username, demoUser[0].password) { - recipient(this) + nodeB.rpcClientToNode().start(demoUser[0].username, demoUser[0].password).use { + recipient(it.proxy) } }.exceptionally { it.printStackTrace() } diff --git a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt index 87935453a6..ef3941ff87 100644 --- a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt +++ b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt @@ -18,10 +18,12 @@ import net.corda.core.utilities.DUMMY_NOTARY import net.corda.core.utilities.DUMMY_NOTARY_KEY import net.corda.core.utilities.Emoji import net.corda.flows.FinalityFlow +import net.corda.node.driver.poll import java.io.InputStream import java.net.HttpURLConnection import java.net.URL import java.security.PublicKey +import java.util.concurrent.Executors import java.util.jar.JarInputStream import javax.servlet.http.HttpServletResponse.SC_OK import javax.ws.rs.core.HttpHeaders.CONTENT_DISPOSITION @@ -50,15 +52,15 @@ fun main(args: Array) { Role.SENDER -> { val host = HostAndPort.fromString("localhost:10006") println("Connecting to sender node ($host)") - CordaRPCClient(host).use("demo", "demo") { - sender(this) + CordaRPCClient(host).start("demo", "demo").use { + sender(it.proxy) } } Role.RECIPIENT -> { val host = HostAndPort.fromString("localhost:10009") println("Connecting to the recipient node ($host)") - CordaRPCClient(host).use("demo", "demo") { - recipient(this) + CordaRPCClient(host).start("demo", "demo").use { + recipient(it.proxy) } } } @@ -72,7 +74,8 @@ fun sender(rpc: CordaRPCOps, numOfClearBytes: Int = 1024) { // default size 1K. fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256) { // Get the identity key of the other side (the recipient). - val otherSide: Party = rpc.partyFromX500Name(DUMMY_BANK_B.name) ?: throw IllegalStateException("Could not find counterparty \"${DUMMY_BANK_B.name}\"") + val executor = Executors.newScheduledThreadPool(1) + val otherSide: Party = poll(executor, DUMMY_BANK_B.name.toString()) { rpc.partyFromX500Name(DUMMY_BANK_B.name) }.get() // Make sure we have the file in storage if (!rpc.attachmentExists(hash)) { @@ -97,6 +100,7 @@ fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256) val flowHandle = rpc.startTrackedFlow(::FinalityFlow, stx, setOf(otherSide)) flowHandle.progress.subscribe(::println) flowHandle.returnValue.getOrThrow() + println("Sent ${stx.id}") } fun recipient(rpc: CordaRPCOps) { diff --git a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt index 6ebc80facb..6a41c6546e 100644 --- a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt +++ b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt @@ -26,13 +26,11 @@ class BankOfCordaRPCClientTest { // Bank of Corda RPC Client val bocClient = nodeBankOfCorda.rpcClientToNode() - bocClient.start("bocManager", "password1") - val bocProxy = bocClient.proxy() + val bocProxy = bocClient.start("bocManager", "password1").proxy // Big Corporation RPC Client val bigCorpClient = nodeBigCorporation.rpcClientToNode() - bigCorpClient.start("bigCorpCFO", "password2") - val bigCorpProxy = bigCorpClient.proxy() + val bigCorpProxy = bigCorpClient.start("bigCorpCFO", "password2").proxy // Register for Bank of Corda Vault updates val vaultUpdatesBoc = bocProxy.vaultAndUpdates().second diff --git a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt index d0eb6b1bba..b4fef75d86 100644 --- a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt +++ b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt @@ -32,18 +32,19 @@ class BankOfCordaClientApi(val hostAndPort: HostAndPort) { fun requestRPCIssue(params: IssueRequestParams): SignedTransaction { val client = CordaRPCClient(hostAndPort) // TODO: privileged security controls required - client.start("bankUser", "test") - val proxy = client.proxy() + client.start("bankUser", "test").use { connection -> + val proxy = connection.proxy - // Resolve parties via RPC - val issueToParty = proxy.partyFromX500Name(params.issueToPartyName) - ?: throw Exception("Unable to locate ${params.issueToPartyName} in Network Map Service") - val issuerBankParty = proxy.partyFromX500Name(params.issuerBankName) - ?: throw Exception("Unable to locate ${params.issuerBankName} in Network Map Service") + // Resolve parties via RPC + val issueToParty = proxy.partyFromX500Name(params.issueToPartyName) + ?: throw Exception("Unable to locate ${params.issueToPartyName} in Network Map Service") + val issuerBankParty = proxy.partyFromX500Name(params.issuerBankName) + ?: throw Exception("Unable to locate ${params.issuerBankName} in Network Map Service") - val amount = Amount(params.amount, currency(params.currency)) - val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte()) + val amount = Amount(params.amount, currency(params.currency)) + val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte()) - return proxy.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty).returnValue.getOrThrow() + return proxy.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty).returnValue.getOrThrow() + } } } diff --git a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt index 39eed20d4a..d0fabb3c87 100644 --- a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt +++ b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt @@ -66,8 +66,7 @@ class IRSDemoTest : IntegrationTestCategory { fun getFixingDateObservable(config: FullNodeConfiguration): BlockingObservable { val client = CordaRPCClient(config.rpcAddress!!) - client.start("user", "password") - val proxy = client.proxy() + val proxy = client.start("user", "password").proxy val vaultUpdates = proxy.vaultAndUpdates().second val fixingDates = vaultUpdates.map { update -> diff --git a/samples/raft-notary-demo/src/main/kotlin/net/corda/notarydemo/NotaryDemo.kt b/samples/raft-notary-demo/src/main/kotlin/net/corda/notarydemo/NotaryDemo.kt index 84527acb6f..dfb792b91a 100644 --- a/samples/raft-notary-demo/src/main/kotlin/net/corda/notarydemo/NotaryDemo.kt +++ b/samples/raft-notary-demo/src/main/kotlin/net/corda/notarydemo/NotaryDemo.kt @@ -22,8 +22,8 @@ import kotlin.system.exitProcess fun main(args: Array) { val host = HostAndPort.fromString("localhost:10003") println("Connecting to the recipient node ($host)") - CordaRPCClient(host).use("demo", "demo") { - val api = NotaryDemoClientApi(this) + CordaRPCClient(host).start("demo", "demo").use { + val api = NotaryDemoClientApi(it.proxy) api.startNotarisation() } } diff --git a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt index 37c68fba2d..fb73735ec3 100644 --- a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt +++ b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt @@ -4,11 +4,13 @@ import com.google.common.util.concurrent.Futures import net.corda.client.rpc.CordaRPCClient import net.corda.core.contracts.DOLLARS import net.corda.core.getOrThrow +import net.corda.core.millis import net.corda.core.node.services.ServiceInfo import net.corda.core.utilities.DUMMY_BANK_A import net.corda.core.utilities.DUMMY_BANK_B import net.corda.core.utilities.DUMMY_NOTARY import net.corda.flows.IssuerFlow +import net.corda.node.driver.poll import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User @@ -17,6 +19,7 @@ import net.corda.testing.node.NodeBasedTest import net.corda.traderdemo.flow.SellerFlow import org.assertj.core.api.Assertions.assertThat import org.junit.Test +import java.util.concurrent.Executors class TraderDemoTest : NodeBasedTest() { @Test @@ -35,7 +38,7 @@ class TraderDemoTest : NodeBasedTest() { val (nodeARpc, nodeBRpc) = listOf(nodeA, nodeB).map { val client = CordaRPCClient(it.configuration.rpcAddress!!) - client.start(demoUser[0].username, demoUser[0].password).proxy() + client.start(demoUser[0].username, demoUser[0].password).proxy } val clientA = TraderDemoClientApi(nodeARpc) @@ -48,10 +51,19 @@ class TraderDemoTest : NodeBasedTest() { clientA.runBuyer(amount = 100.DOLLARS) clientB.runSeller(counterparty = nodeA.info.legalIdentity.name, amount = 5.DOLLARS) - val actualPaper = listOf(clientA.commercialPaperCount, clientB.commercialPaperCount) assertThat(clientA.cashCount).isGreaterThan(originalACash) assertThat(clientB.cashCount).isEqualTo(expectedBCash) - assertThat(actualPaper).isEqualTo(expectedPaper) + // Wait until A receives the commercial paper + val executor = Executors.newScheduledThreadPool(1) + poll(executor, "A to be notified of the commercial paper", pollInterval = 100.millis) { + val actualPaper = listOf(clientA.commercialPaperCount, clientB.commercialPaperCount) + if (actualPaper == expectedPaper) { + Unit + } else { + null + } + }.get() + executor.shutdown() assertThat(clientA.dollarCashBalance).isEqualTo(95.DOLLARS) assertThat(clientB.dollarCashBalance).isEqualTo(5.DOLLARS) } diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt index 3c16063890..82a0ab081c 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt @@ -45,13 +45,13 @@ private class TraderDemo { val role = options.valueOf(roleArg)!! if (role == Role.BUYER) { val host = HostAndPort.fromString("localhost:10006") - CordaRPCClient(host).use("demo", "demo") { - TraderDemoClientApi(this).runBuyer() + CordaRPCClient(host).start("demo", "demo").use { + TraderDemoClientApi(it.proxy).runBuyer() } } else { val host = HostAndPort.fromString("localhost:10009") CordaRPCClient(host).use("demo", "demo") { - TraderDemoClientApi(this).runSeller(1000.DOLLARS, DUMMY_BANK_A.name) + TraderDemoClientApi(it.proxy).runSeller(1000.DOLLARS, DUMMY_BANK_A.name) } } } diff --git a/test-utils/build.gradle b/test-utils/build.gradle index b4c5c8bbb4..8209fea019 100644 --- a/test-utils/build.gradle +++ b/test-utils/build.gradle @@ -16,6 +16,7 @@ dependencies { compile project(':node') compile project(':webserver') compile project(':verifier') + compile project(':client:mock') compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version" diff --git a/test-utils/src/main/kotlin/net/corda/testing/Measure.kt b/test-utils/src/main/kotlin/net/corda/testing/Measure.kt new file mode 100644 index 0000000000..d3a6f9f65d --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/Measure.kt @@ -0,0 +1,53 @@ +package net.corda.testing + +import kotlin.reflect.KCallable +import kotlin.reflect.jvm.reflect + +/** + * These functions may be used to run measurements of a function where the parameters are chosen from corresponding + * [Iterable]s in a lexical manner. An example use case would be benchmarking the speed of a certain function call using + * different combinations of parameters. + */ + +@Suppress("UNCHECKED_CAST") +fun measure(a: Iterable, f: (A) -> R) = + measure(listOf(a), f.reflect()!!) { (f as ((Any?)->R))(it[0]) } +@Suppress("UNCHECKED_CAST") +fun measure(a: Iterable, b: Iterable, f: (A, B) -> R) = + measure(listOf(a, b), f.reflect()!!) { (f as ((Any?,Any?)->R))(it[0], it[1]) } +@Suppress("UNCHECKED_CAST") +fun measure(a: Iterable, b: Iterable, c: Iterable, f: (A, B, C) -> R) = + measure(listOf(a, b, c), f.reflect()!!) { (f as ((Any?,Any?,Any?)->R))(it[0], it[1], it[2]) } +@Suppress("UNCHECKED_CAST") +fun measure(a: Iterable, b: Iterable, c: Iterable, d: Iterable, f: (A, B, C, D) -> R) = + measure(listOf(a, b, c, d), f.reflect()!!) { (f as ((Any?,Any?,Any?,Any?)->R))(it[0], it[1], it[2], it[3]) } + +private fun measure(paramIterables: List>, kCallable: KCallable, call: (Array) -> R): Iterable> { + val kParameters = kCallable.parameters + return iterateLexical(paramIterables).map { params -> + MeasureResult( + parameters = params.mapIndexed { index, param -> Pair(kParameters[index].name!!, param) }, + result = call(params.toTypedArray()) + ) + } +} + +data class MeasureResult( + val parameters: List>, + val result: R +) + +fun iterateLexical(iterables: List>): Iterable> { + val result = ArrayList>() + fun iterateLexicalHelper(index: Int, list: List) { + if (index < iterables.size) { + iterables[index].forEach { + iterateLexicalHelper(index + 1, list + it) + } + } else { + result.add(list) + } + } + iterateLexicalHelper(0, emptyList()) + return result +} diff --git a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt new file mode 100644 index 0000000000..8e3ae09eb4 --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt @@ -0,0 +1,469 @@ +package net.corda.testing + +import com.google.common.net.HostAndPort +import com.google.common.util.concurrent.ListenableFuture +import net.corda.client.mock.Generator +import net.corda.client.mock.generateOrFail +import net.corda.client.mock.int +import net.corda.client.mock.string +import net.corda.client.rpc.internal.RPCClient +import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.core.div +import net.corda.core.messaging.RPCOps +import net.corda.core.utilities.ProcessUtilities +import net.corda.node.driver.* +import net.corda.node.services.RPCUserService +import net.corda.node.services.messaging.ArtemisMessagingServer +import net.corda.node.services.messaging.RPCServer +import net.corda.node.services.messaging.RPCServerConfiguration +import net.corda.nodeapi.ArtemisTcpTransport +import net.corda.nodeapi.ConnectionDirection +import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.User +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.TransportConfiguration +import org.apache.activemq.artemis.api.core.client.ActiveMQClient +import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE +import org.apache.activemq.artemis.api.core.client.ClientSession +import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl +import org.apache.activemq.artemis.core.config.Configuration +import org.apache.activemq.artemis.core.config.CoreQueueConfiguration +import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl +import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory +import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory +import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory +import org.apache.activemq.artemis.core.security.CheckType +import org.apache.activemq.artemis.core.security.Role +import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ +import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl +import org.apache.activemq.artemis.core.settings.impl.AddressFullMessagePolicy +import org.apache.activemq.artemis.core.settings.impl.AddressSettings +import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection +import org.apache.activemq.artemis.spi.core.security.ActiveMQSecurityManager3 +import org.bouncycastle.asn1.x500.X500Name +import java.lang.reflect.Method +import java.nio.file.Path +import java.nio.file.Paths +import java.util.* +import javax.security.cert.X509Certificate + +interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { + /** + * Starts an In-VM RPC server. Note that only a single one may be started. + * + * @param rpcUser The single user who can access the server through RPC, and their permissions. + * @param nodeLegalName The legal name of the node to check against to authenticate a super user. + * @param configuration The RPC server configuration. + * @param ops The server-side implementation of the RPC interface. + */ + fun startInVmRpcServer( + rpcUser: User = rpcTestUser, + nodeLegalName: X500Name = fakeNodeLegalName, + maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE, + maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE, + configuration: RPCServerConfiguration = RPCServerConfiguration.default, + ops : I + ): ListenableFuture + + /** + * Starts an In-VM RPC client. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + * @param configuration The RPC client configuration. + */ + fun startInVmRpcClient( + rpcOpsClass: Class, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: RPCClientConfiguration = RPCClientConfiguration.default + ): ListenableFuture + + /** + * Starts an In-VM Artemis session connecting to the RPC server. + * + * @param username The username to authenticate with. + * @param password The password to authenticate with. + */ + fun startInVmArtemisSession( + username: String = rpcTestUser.username, + password: String = rpcTestUser.password + ): ClientSession + + /** + * Starts a Netty RPC server. + * + * @param serverName The name of the server, to be used for the folder created for Artemis files. + * @param rpcUser The single user who can access the server through RPC, and their permissions. + * @param nodeLegalName The legal name of the node to check against to authenticate a super user. + * @param configuration The RPC server configuration. + * @param ops The server-side implementation of the RPC interface. + */ + fun startRpcServer( + serverName: String = "driver-rpc-server", + rpcUser: User = rpcTestUser, + nodeLegalName: X500Name = fakeNodeLegalName, + maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE, + maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE, + configuration: RPCServerConfiguration = RPCServerConfiguration.default, + ops : I + ) : ListenableFuture + + /** + * Starts a Netty RPC client. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param rpcAddress The address of the RPC server to connect to. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + * @param configuration The RPC client configuration. + */ + fun startRpcClient( + rpcOpsClass: Class, + rpcAddress: HostAndPort, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: RPCClientConfiguration = RPCClientConfiguration.default + ): ListenableFuture + + /** + * Starts a Netty RPC client in a new JVM process that calls random RPCs with random arguments. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param rpcAddress The address of the RPC server to connect to. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + */ + fun startRandomRpcClient( + rpcOpsClass: Class, + rpcAddress: HostAndPort, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password + ): ListenableFuture + + /** + * Starts a Netty Artemis session connecting to an RPC server. + * + * @param rpcAddress The address of the RPC server. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + */ + fun startArtemisSession( + rpcAddress: HostAndPort, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password + ): ClientSession +} +inline fun RPCDriverExposedDSLInterface.startInVmRpcClient( + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: RPCClientConfiguration = RPCClientConfiguration.default +) = startInVmRpcClient(I::class.java, username, password, configuration) +inline fun RPCDriverExposedDSLInterface.startRandomRpcClient( + hostAndPort: HostAndPort, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password +) = startRandomRpcClient(I::class.java, hostAndPort, username, password) +inline fun RPCDriverExposedDSLInterface.startRpcClient( + rpcAddress: HostAndPort, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: RPCClientConfiguration = RPCClientConfiguration.default +) = startRpcClient(I::class.java, rpcAddress, username, password, configuration) + +interface RPCDriverInternalDSLInterface : DriverDSLInternalInterface, RPCDriverExposedDSLInterface + +data class RpcServerHandle( + val hostAndPort: HostAndPort, + val serverControl: ActiveMQServerControl +) + +val rpcTestUser = User("user1", "test", permissions = emptySet()) +val fakeNodeLegalName = X500Name("not:a:valid:name") + +fun rpcDriver( + isDebug: Boolean = false, + driverDirectory: Path = Paths.get("build", getTimestampAsDirectoryName()), + portAllocation: PortAllocation = PortAllocation.Incremental(10000), + debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), + systemProperties: Map = emptyMap(), + useTestClock: Boolean = false, + automaticallyStartNetworkMap: Boolean = false, + dsl: RPCDriverExposedDSLInterface.() -> A +) = genericDriver( + driverDsl = RPCDriverDSL( + DriverDSL( + portAllocation = portAllocation, + debugPortAllocation = debugPortAllocation, + systemProperties = systemProperties, + driverDirectory = driverDirectory.toAbsolutePath(), + useTestClock = useTestClock, + automaticallyStartNetworkMap = automaticallyStartNetworkMap, + isDebug = isDebug + ) + ), + coerce = { it }, + dsl = dsl +) + +private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 { + override fun validateUser(user: String?, password: String?) = isValid(user, password) + override fun validateUserAndRole(user: String?, password: String?, roles: MutableSet?, checkType: CheckType?) = isValid(user, password) + override fun validateUser(user: String?, password: String?, certificates: Array?): String? { + return validate(user, password) + } + override fun validateUserAndRole(user: String?, password: String?, roles: MutableSet?, checkType: CheckType?, address: String?, connection: RemotingConnection?): String? { + return validate(user, password) + } + + private fun isValid(user: String?, password: String?): Boolean { + return rpcUser.username == user && rpcUser.password == password + } + private fun validate(user: String?, password: String?): String? { + return if (isValid(user, password)) user else null + } +} + +data class RPCDriverDSL( + val driverDSL: DriverDSL +) : DriverDSLInternalInterface by driverDSL, RPCDriverInternalDSLInterface { + private companion object { + val notificationAddress = "notifications" + + private fun ConfigurationImpl.configureCommonSettings(maxFileSize: Int, maxBufferedBytesPerClient: Long) { + managementNotificationAddress = SimpleString(notificationAddress) + isPopulateValidatedUser = true + journalBufferSize_NIO = maxFileSize + journalBufferSize_AIO = maxFileSize + journalFileSize = maxFileSize + queueConfigurations = listOf( + CoreQueueConfiguration().apply { + name = RPCApi.RPC_SERVER_QUEUE_NAME + address = RPCApi.RPC_SERVER_QUEUE_NAME + isDurable = false + }, + CoreQueueConfiguration().apply { + name = RPCApi.RPC_CLIENT_BINDING_REMOVALS + address = notificationAddress + filterString = RPCApi.RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION + isDurable = false + } + ) + addressesSettings = mapOf( + "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.#" to AddressSettings().apply { + maxSizeBytes = maxBufferedBytesPerClient + addressFullMessagePolicy = AddressFullMessagePolicy.FAIL + } + ) + } + fun createInVmRpcServerArtemisConfig(maxFileSize: Int, maxBufferedBytesPerClient: Long): Configuration { + return ConfigurationImpl().apply { + acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name)) + isPersistenceEnabled = false + configureCommonSettings(maxFileSize, maxBufferedBytesPerClient) + } + } + fun createRpcServerArtemisConfig(maxFileSize: Int, maxBufferedBytesPerClient: Long, baseDirectory: Path, hostAndPort: HostAndPort): Configuration { + val connectionDirection = ConnectionDirection.Inbound(acceptorFactoryClassName = NettyAcceptorFactory::class.java.name) + return ConfigurationImpl().apply { + val artemisDir = "$baseDirectory/artemis" + bindingsDirectory = "$artemisDir/bindings" + journalDirectory = "$artemisDir/journal" + largeMessagesDirectory = "$artemisDir/large-messages" + acceptorConfigurations = setOf(ArtemisTcpTransport.tcpTransport(connectionDirection, hostAndPort, null)) + configureCommonSettings(maxFileSize, maxBufferedBytesPerClient) + } + } + val inVmClientTransportConfiguration = TransportConfiguration(InVMConnectorFactory::class.java.name) + fun createNettyClientTransportConfiguration(hostAndPort: HostAndPort): TransportConfiguration { + return ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), hostAndPort, null) + } + } + + override fun startInVmRpcServer( + rpcUser: User, + nodeLegalName: X500Name, + maxFileSize: Int, + maxBufferedBytesPerClient: Long, + configuration: RPCServerConfiguration, + ops: I + ): ListenableFuture { + return driverDSL.executorService.submit { + val artemisConfig = createInVmRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient) + val server = EmbeddedActiveMQ() + server.setConfiguration(artemisConfig) + server.setSecurityManager(SingleUserSecurityManager(rpcUser)) + server.start() + driverDSL.shutdownManager.registerShutdown { + server.activeMQServer.stop() + server.stop() + } + startRpcServerWithBrokerRunning( + rpcUser, nodeLegalName, configuration, ops, inVmClientTransportConfiguration, + server.activeMQServer.activeMQServerControl + ) + } + } + + override fun startInVmRpcClient(rpcOpsClass: Class, username: String, password: String, configuration: RPCClientConfiguration): ListenableFuture { + return driverDSL.executorService.submit { + val client = RPCClient(inVmClientTransportConfiguration, configuration) + val connection = client.start(rpcOpsClass, username, password) + driverDSL.shutdownManager.registerShutdown { + connection.close() + } + connection.proxy + } + } + + override fun startInVmArtemisSession(username: String, password: String): ClientSession { + val locator = ActiveMQClient.createServerLocatorWithoutHA(inVmClientTransportConfiguration) + val sessionFactory = locator.createSessionFactory() + val session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) + driverDSL.shutdownManager.registerShutdown { + session.close() + sessionFactory.close() + locator.close() + } + return session + } + + override fun startRpcServer( + serverName: String, + rpcUser: User, + nodeLegalName: X500Name, + maxFileSize: Int, + maxBufferedBytesPerClient: Long, + configuration: RPCServerConfiguration, + ops: I + ): ListenableFuture { + val hostAndPort = driverDSL.portAllocation.nextHostAndPort() + return driverDSL.executorService.submit { + val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort) + val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser)) + server.start() + driverDSL.shutdownManager.registerShutdown { + server.stop() + addressMustNotBeBound(driverDSL.executorService, hostAndPort).get() + } + val transportConfiguration = createNettyClientTransportConfiguration(hostAndPort) + startRpcServerWithBrokerRunning( + rpcUser, nodeLegalName, configuration, ops, transportConfiguration, + server.activeMQServerControl + ) + RpcServerHandle(hostAndPort, server.activeMQServerControl) + } + } + + override fun startRpcClient( + rpcOpsClass: Class, + rpcAddress: HostAndPort, + username: String, + password: String, + configuration: RPCClientConfiguration + ): ListenableFuture { + return driverDSL.executorService.submit { + val client = RPCClient(ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), rpcAddress, null), configuration) + val connection = client.start(rpcOpsClass, username, password) + driverDSL.shutdownManager.registerShutdown { + connection.close() + } + connection.proxy + } + } + + override fun startRandomRpcClient(rpcOpsClass: Class, rpcAddress: HostAndPort, username: String, password: String): ListenableFuture { + val processFuture = driverDSL.executorService.submit { + ProcessUtilities.startJavaProcess(listOf(rpcOpsClass.name, rpcAddress.toString(), username, password)) + } + driverDSL.shutdownManager.registerProcessShutdown(processFuture) + return processFuture + } + + override fun startArtemisSession(rpcAddress: HostAndPort, username: String, password: String): ClientSession { + val locator = ActiveMQClient.createServerLocatorWithoutHA(createNettyClientTransportConfiguration(rpcAddress)) + val sessionFactory = locator.createSessionFactory() + val session = sessionFactory.createSession(username, password, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + driverDSL.shutdownManager.registerShutdown { + session.close() + sessionFactory.close() + locator.close() + } + + return session + } + + + private fun startRpcServerWithBrokerRunning( + rpcUser: User, + nodeLegalName: X500Name, + configuration: RPCServerConfiguration, + ops: I, + transportConfiguration: TransportConfiguration, + serverControl: ActiveMQServerControl + ) { + val locator = ActiveMQClient.createServerLocatorWithoutHA(transportConfiguration).apply { + minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE + } + val userService = object : RPCUserService { + override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null + override val users: List get() = listOf(rpcUser) + } + val rpcServer = RPCServer( + ops, + rpcUser.username, + rpcUser.password, + locator, + userService, + nodeLegalName, + configuration + ) + driverDSL.shutdownManager.registerShutdown { + rpcServer.close() + locator.close() + } + rpcServer.start(serverControl) + } +} + +/** + * An out-of-process RPC user that connects to an RPC server and issues random RPCs with random arguments. + */ +class RandomRpcUser { + + companion object { + private inline fun HashMap, Generator<*>>.add(generator: Generator) = this.putIfAbsent(T::class.java, generator) + val generatorStore = HashMap, Generator<*>>().apply { + add(Generator.string()) + add(Generator.int()) + } + data class Call(val method: Method, val call: () -> Any?) + + @JvmStatic + fun main(args: Array) { + require(args.size == 4) + @Suppress("UNCHECKED_CAST") + val rpcClass = Class.forName(args[0]) as Class + val hostAndPort = HostAndPort.fromString(args[1]) + val username = args[2] + val password = args[3] + val handle = RPCClient(hostAndPort, null).start(rpcClass, username, password) + val callGenerators = rpcClass.declaredMethods.map { method -> + Generator.sequence(method.parameters.map { + generatorStore[it.type] ?: throw Exception("No generator for ${it.type}") + }).map { arguments -> + Call(method, { method.invoke(handle.proxy, *arguments.toTypedArray()) }) + } + } + val callGenerator = Generator.choice(callGenerators) + val random = SplittableRandom() + + while (true) { + val call = callGenerator.generateOrFail(random) + call.call() + Thread.sleep(100) + } + } + } +} diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt index 721eef8e4e..cf93e05ebe 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt @@ -2,15 +2,13 @@ package net.corda.testing.node import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.createDirectories +import net.corda.core.* import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.commonName -import net.corda.core.div -import net.corda.core.flatMap -import net.corda.core.map import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.core.utilities.DUMMY_MAP +import net.corda.node.driver.addressMustNotBeBound import net.corda.node.internal.Node import net.corda.node.services.config.ConfigHelper import net.corda.node.services.config.FullNodeConfiguration @@ -26,6 +24,7 @@ import org.junit.After import org.junit.Rule import org.junit.rules.TemporaryFolder import java.util.* +import java.util.concurrent.Executors import kotlin.concurrent.thread /** @@ -53,9 +52,18 @@ abstract class NodeBasedTest { */ @After fun stopAllNodes() { + val shutdownExecutor = Executors.newScheduledThreadPool(1) nodes.forEach(Node::stop) + // Wait until ports are released + val portNotBoundChecks = nodes.flatMap { + listOf( + it.configuration.p2pAddress.let { addressMustNotBeBound(shutdownExecutor, it) }, + it.configuration.rpcAddress?.let { addressMustNotBeBound(shutdownExecutor, it) } + ) + }.filterNotNull() nodes.clear() _networkMapNode = null + Futures.allAsList(portNotBoundChecks).getOrThrow() } /** diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt index 2631249c14..80b51f99a6 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt @@ -57,7 +57,7 @@ class SimpleNode(val config: NodeConfiguration, val address: HostAndPort = freeL }, userService) thread(name = config.myLegalName.commonName) { - net.run() + net.run(broker.serverControl) } } diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/rpc/NodeRPC.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/rpc/NodeRPC.kt index 9a8db6f387..43ec90a4cd 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/rpc/NodeRPC.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/rpc/NodeRPC.kt @@ -2,6 +2,7 @@ package net.corda.demobench.rpc import com.google.common.net.HostAndPort import net.corda.client.rpc.CordaRPCClient +import net.corda.client.rpc.CordaRPCConnection import net.corda.core.messaging.CordaRPCOps import net.corda.core.utilities.loggerFor import net.corda.demobench.model.NodeConfig @@ -17,14 +18,16 @@ class NodeRPC(config: NodeConfig, start: (NodeConfig, CordaRPCOps) -> Unit, invo private val rpcClient = CordaRPCClient(HostAndPort.fromParts("localhost", config.rpcPort)) private val timer = Timer() + private val connections = Collections.synchronizedCollection(ArrayList()) init { val setupTask = object : TimerTask() { override fun run() { try { val user = config.users.elementAt(0) - rpcClient.start(user.username, user.password) - val ops = rpcClient.proxy() + val connection = rpcClient.start(user.username, user.password) + connections.add(connection) + val ops = connection.proxy // Cancel the "setup" task now that we've created the RPC client. this.cancel() @@ -50,7 +53,7 @@ class NodeRPC(config: NodeConfig, start: (NodeConfig, CordaRPCOps) -> Unit, invo override fun close() { timer.cancel() - rpcClient.close() + connections.forEach(CordaRPCConnection::close) } } diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/Main.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/Main.kt index 47edb4c6bd..6d8df49bff 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/Main.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/Main.kt @@ -198,20 +198,20 @@ fun main(args: Array) { // Register with alice to use alice's RPC proxy to create random events. val aliceClient = aliceNode.rpcClientToNode() - aliceClient.start(user.username, user.password) - val aliceRPC = aliceClient.proxy() + val aliceConnection = aliceClient.start(user.username, user.password) + val aliceRPC = aliceConnection.proxy val bobClient = bobNode.rpcClientToNode() - bobClient.start(user.username, user.password) - val bobRPC = bobClient.proxy() + val bobConnection = bobClient.start(user.username, user.password) + val bobRPC = bobConnection.proxy val issuerClientGBP = issuerNodeGBP.rpcClientToNode() - issuerClientGBP.start(manager.username, manager.password) - val issuerRPCGBP = issuerClientGBP.proxy() + val issuerGBPConnection = issuerClientGBP.start(manager.username, manager.password) + val issuerRPCGBP = issuerGBPConnection.proxy val issuerClientUSD = issuerNodeUSD.rpcClientToNode() - issuerClientUSD.start(manager.username, manager.password) - val issuerRPCUSD = issuerClientUSD.proxy() + val issuerUSDConnection = issuerClientUSD.start(manager.username, manager.password) + val issuerRPCUSD = issuerUSDConnection.proxy val issuers = mapOf(USD to issuerRPCUSD, GBP to issuerRPCGBP) @@ -272,10 +272,10 @@ fun main(args: Array) { } println("Simulation completed") - aliceClient.close() - bobClient.close() - issuerClientGBP.close() - issuerClientUSD.close() + aliceConnection.close() + bobConnection.close() + issuerGBPConnection.close() + issuerUSDConnection.close() } waitForAllNodesToFinish() } diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt index 8b9031fadd..bf46ebff5b 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt @@ -6,6 +6,7 @@ import com.jcraft.jsch.agentproxy.AgentProxy import com.jcraft.jsch.agentproxy.connector.SSHAgentConnector import com.jcraft.jsch.agentproxy.usocket.JNAUSocketFactory import net.corda.client.rpc.CordaRPCClient +import net.corda.client.rpc.CordaRPCConnection import net.corda.core.messaging.CordaRPCOps import net.corda.node.driver.PortAllocation import org.slf4j.LoggerFactory @@ -137,9 +138,9 @@ class NodeConnection( private val rpcUsername: String, private val rpcPassword: String ) : Closeable { - private var client: CordaRPCClient? = null - private var _proxy: CordaRPCOps? = null - val proxy: CordaRPCOps get() = _proxy ?: throw IllegalStateException("proxy requested, but the client is not running") + private val client = CordaRPCClient(localTunnelAddress) + private var connection: CordaRPCConnection? = null + val proxy: CordaRPCOps get() = connection?.proxy ?: throw IllegalStateException("proxy requested, but the client is not running") data class ShellCommandOutput( val originalShellCommand: String, @@ -162,32 +163,24 @@ class NodeConnection( } fun doWhileClientStopped(action: () -> A): A { - val client = client - val proxy = _proxy - require(client != null && proxy != null) { "doWhileClientStopped called with no running client" } + val connection = connection + require(connection != null) { "doWhileClientStopped called with no running client" } log.info("Stopping RPC proxy to $hostName, tunnel at $localTunnelAddress") - client!!.close() + connection!!.close() try { return action() } finally { log.info("Starting new RPC proxy to $hostName, tunnel at $localTunnelAddress") - val newClient = CordaRPCClient(localTunnelAddress) // TODO expose these somehow? - newClient.start(rpcUsername, rpcPassword) - val newProxy = newClient.proxy() - this.client = newClient - this._proxy = newProxy + val newConnection = client.start(rpcUsername, rpcPassword) + this.connection = newConnection } } fun startClient() { log.info("Creating RPC proxy to $hostName, tunnel at $localTunnelAddress") - val client = CordaRPCClient(localTunnelAddress) - client.start(rpcUsername, rpcPassword) - val proxy = client.proxy() + connection = client.start(rpcUsername, rpcPassword) log.info("Proxy created") - this.client = client - this._proxy = proxy } /** @@ -229,7 +222,7 @@ class NodeConnection( } override fun close() { - client?.close() + connection?.close() jSchSession.disconnect() } } diff --git a/webserver/src/main/kotlin/net/corda/webserver/internal/NodeWebServer.kt b/webserver/src/main/kotlin/net/corda/webserver/internal/NodeWebServer.kt index 085142d13f..a0ce5b7d6b 100644 --- a/webserver/src/main/kotlin/net/corda/webserver/internal/NodeWebServer.kt +++ b/webserver/src/main/kotlin/net/corda/webserver/internal/NodeWebServer.kt @@ -199,8 +199,8 @@ class NodeWebServer(val config: WebServerConfig) { private fun connectLocalRpcAsNodeUser(): CordaRPCOps { log.info("Connecting to node at ${config.p2pAddress} as node user") val client = CordaRPCClient(config.p2pAddress, config) - client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) - return client.proxy() + val connection = client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) + return connection.proxy } /** Fetch CordaPluginRegistry classes registered in META-INF/services/net.corda.core.node.CordaPluginRegistry files that exist in the classpath */ From 34517f653aee661cc4bcf554044f57d29c5eb0b9 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Thu, 4 May 2017 13:41:28 +0100 Subject: [PATCH 2/6] #592: Address more comments --- .../corda/client/rpc/internal/RPCClient.kt | 10 ++-- .../rpc/internal/RPCClientProxyHandler.kt | 4 +- .../corda/client/rpc/RPCConcurrencyTests.kt | 16 +----- .../corda/client/rpc/RPCPerformanceTests.kt | 50 +++++++++---------- .../net/corda/core/serialization/Kryo.kt | 12 +++-- .../net/corda/core/utilities/LazyPool.kt | 37 ++++++-------- .../corda/core/utilities/LazyStickyPool.kt | 14 +++++- .../kotlin/net/corda/core/utilities/Rate.kt | 29 +++++++++++ 8 files changed, 101 insertions(+), 71 deletions(-) create mode 100644 core/src/main/kotlin/net/corda/core/utilities/Rate.kt diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index 60c77c676e..60d50928bd 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -32,11 +32,11 @@ data class RPCClientConfiguration( */ val trackRpcCallSites: Boolean, /** - * The interval of unused observable reaping in milliseconds. Leaked Observables (unused ones) are - * detected using weak references and are cleaned up in batches in this interval. If set too large it will waste - * server side resources for this duration. If set too low it wastes client side cycles. + * The interval of unused observable reaping. Leaked Observables (unused ones) are detected using weak references + * and are cleaned up in batches in this interval. If set too large it will waste server side resources for this + * duration. If set too low it wastes client side cycles. */ - val reapIntervalMs: Long, + val reapInterval: Duration, /** The number of threads to use for observations (for executing [Observable.onNext]) */ val observationExecutorPoolSize: Int, /** The maximum number of producers to create to handle outgoing messages */ @@ -61,7 +61,7 @@ data class RPCClientConfiguration( val default = RPCClientConfiguration( minimumServerProtocolVersion = 0, trackRpcCallSites = false, - reapIntervalMs = 1000, + reapInterval = 1.seconds, observationExecutorPoolSize = 4, producerPoolBound = 1, cacheConcurrencyLevel = 8, diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index 6309af41f8..fb0e874b8a 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -165,8 +165,8 @@ class RPCClientProxyHandler( lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( this::reapObservables, - rpcConfiguration.reapIntervalMs, - rpcConfiguration.reapIntervalMs, + rpcConfiguration.reapInterval.toMillis(), + rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) sessionAndProducerPool.run { diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt index 2e563bc40c..2ffe065832 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt @@ -1,34 +1,22 @@ package net.corda.client.rpc -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.future import net.corda.core.messaging.RPCOps +import net.corda.core.millis import net.corda.core.random63BitValue import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.loggerFor -import net.corda.node.driver.poll import net.corda.node.services.messaging.RPCServerConfiguration -import net.corda.nodeapi.RPCApi import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.rpcDriver -import net.corda.testing.startRandomRpcClient -import net.corda.testing.startRpcClient -import org.apache.activemq.artemis.api.core.SimpleString -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized import rx.Observable -import rx.subjects.PublishSubject import rx.subjects.UnicastSubject import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CountDownLatch -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger @RunWith(Parameterized::class) class RPCConcurrencyTests : AbstractRPCTest() { @@ -98,7 +86,7 @@ class RPCConcurrencyTests : AbstractRPCTest() { return testProxy( testOpsImpl, clientConfiguration = RPCClientConfiguration.default.copy( - reapIntervalMs = 100, + reapInterval = 100.millis, cacheConcurrencyLevel = 16 ), serverConfiguration = RPCServerConfiguration.default.copy( diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt index fa804dd2c3..8d1fdcb65b 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt @@ -1,31 +1,25 @@ package net.corda.client.rpc +import com.codahale.metrics.ConsoleReporter import com.codahale.metrics.Gauge import com.codahale.metrics.JmxReporter import com.codahale.metrics.MetricRegistry -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.base.Stopwatch import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.messaging.RPCOps -import net.corda.core.millis -import net.corda.core.random63BitValue +import net.corda.core.minutes +import net.corda.core.seconds +import net.corda.core.utilities.Rate +import net.corda.core.utilities.div import net.corda.node.driver.ShutdownManager import net.corda.node.services.messaging.RPCServerConfiguration -import net.corda.nodeapi.RPCApi -import net.corda.nodeapi.RPCKryo import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.measure import net.corda.testing.rpcDriver -import org.apache.activemq.artemis.api.core.SimpleString import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized -import rx.Observable import java.time.Duration import java.util.* import java.util.concurrent.* @@ -140,23 +134,24 @@ class RPCPerformanceTests : AbstractRPCTest() { @Test fun `consumption rate`() { rpcDriver { - val metricRegistry = startJmxReporter() + val metricRegistry = startReporter() val proxy = testProxy( RPCClientConfiguration.default.copy( - reapIntervalMs = 100, - cacheConcurrencyLevel = 16 + reapInterval = 1.seconds, + cacheConcurrencyLevel = 16, + producerPoolBound = 8 ), RPCServerConfiguration.default.copy( - rpcThreadPoolSize = 4, - consumerPoolSize = 4, - producerPoolBound = 4 + rpcThreadPoolSize = 8, + consumerPoolSize = 1, + producerPoolBound = 8 ) ) measurePerformancePublishMetrics( metricRegistry = metricRegistry, - parallelism = 4, - overallDurationSecond = 120.0, - injectionRatePerSecond = 20000.0, + parallelism = 8, + overallDuration = 5.minutes, + injectionRate = 20000L / TimeUnit.SECONDS, queueSizeMetricName = "$mode.QueueSize", workDurationMetricName = "$mode.WorkDuration", shutdownManager = this.shutdownManager, @@ -205,8 +200,8 @@ class RPCPerformanceTests : AbstractRPCTest() { fun measurePerformancePublishMetrics( metricRegistry: MetricRegistry, parallelism: Int, - overallDurationSecond: Double, - injectionRatePerSecond: Double, + overallDuration: Duration, + injectionRate: Rate, queueSizeMetricName: String, workDurationMetricName: String, shutdownManager: ShutdownManager, @@ -238,7 +233,7 @@ fun measurePerformancePublishMetrics( } val injector = executor.scheduleAtFixedRate( { - workSemaphore.release(injectionRatePerSecond.toInt()) + workSemaphore.release((injectionRate * TimeUnit.SECONDS).toInt()) }, 0, 1, @@ -251,7 +246,7 @@ fun measurePerformancePublishMetrics( workExecutor.awaitTermination(1, TimeUnit.SECONDS) executor.awaitTermination(1, TimeUnit.SECONDS) } - Thread.sleep((overallDurationSecond * 1000).toLong()) + Thread.sleep(overallDuration.toMillis()) } fun startInjectorWithBoundedQueue( @@ -289,7 +284,7 @@ fun startInjectorWithBoundedQueue( injector.join() } -fun RPCDriverExposedDSLInterface.startJmxReporter(): MetricRegistry { +fun RPCDriverExposedDSLInterface.startReporter(): MetricRegistry { val metricRegistry = MetricRegistry() val jmxReporter = thread { JmxReporter. @@ -307,9 +302,14 @@ fun RPCDriverExposedDSLInterface.startJmxReporter(): MetricRegistry { build(). start() } + val consoleReporter = thread { + ConsoleReporter.forRegistry(metricRegistry).build().start(1, TimeUnit.SECONDS) + } shutdownManager.registerShutdown { jmxReporter.interrupt() + consoleReporter.interrupt() jmxReporter.join() + consoleReporter.join() } return metricRegistry } diff --git a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt index 95b82c118d..d744ea574d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt @@ -148,13 +148,19 @@ fun T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = fa } -private val serializeBufferPool = LazyPool { ByteArray(64 * 1024) } -private val serializeOutputStreamPool = LazyPool(ByteArrayOutputStream::reset) { ByteArrayOutputStream(64 * 1024) } +private val serializeBufferPool = LazyPool( + newInstance = { ByteArray(64 * 1024) } +) +private val serializeOutputStreamPool = LazyPool( + clear = ByteArrayOutputStream::reset, + shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large + newInstance = { ByteArrayOutputStream(64 * 1024) } +) fun T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes { return serializeOutputStreamPool.run { stream -> serializeBufferPool.run { buffer -> Output(buffer).use { - it.setOutputStream(stream) + it.outputStream = stream it.writeBytes(KryoHeaderV0_1.bytes) kryo.writeClassAndObject(it, this) } diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt b/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt index a9df98637e..1a1abebdca 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt @@ -1,23 +1,28 @@ package net.corda.core.utilities +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.Semaphore /** * A lazy pool of resources [A]. * * @param clear If specified this function will be run on each borrowed instance before handing it over. + * @param shouldReturnToPool If specified this function will be run on each release to determine whether the instance + * should be returned to the pool for reuse. This may be useful for pooled resources that dynamically grow during + * usage, and we may not want to retain them forever. * @param bound If specified the pool will be bounded. Once all instances are borrowed subsequent borrows will block until an * instance is released. - * @param create The function to call to lazily create a pooled resource. + * @param newInstance The function to call to lazily newInstance a pooled resource. */ class LazyPool( private val clear: ((A) -> Unit)? = null, + private val shouldReturnToPool: ((A) -> Boolean)? = null, private val bound: Int? = null, - private val create: () -> A + private val newInstance: () -> A ) { - private val poolQueue = LinkedBlockingQueue() - private var poolSize = 0 + private val poolQueue = ConcurrentLinkedQueue() + private val poolSemaphore = Semaphore(bound ?: Int.MAX_VALUE) private enum class State { STARTED, @@ -32,23 +37,10 @@ class LazyPool( fun borrow(): A { lifeCycle.requireState(State.STARTED) + poolSemaphore.acquire() val pooled = poolQueue.poll() if (pooled == null) { - if (bound != null) { - val waitForRelease = synchronized(this) { - if (poolSize < bound) { - poolSize++ - false - } else { - true - } - } - if (waitForRelease) { - // Wait until one is released - return clearIfNeeded(poolQueue.take()) - } - } - return create() + return newInstance() } else { return clearIfNeeded(pooled) } @@ -56,7 +48,10 @@ class LazyPool( fun release(instance: A) { lifeCycle.requireState(State.STARTED) - poolQueue.add(instance) + if (shouldReturnToPool == null || shouldReturnToPool.invoke(instance)) { + poolQueue.add(instance) + } + poolSemaphore.release() } /** diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt index 298279f09f..cec52e1842 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt @@ -6,10 +6,17 @@ import java.util.concurrent.LinkedBlockingQueue /** * A [LazyStickyPool] is a lazy pool of resources where a [borrow] may "stick" the borrowed instance to an object. * Any subsequent borrows using the same object will return the same pooled instance. + * + * @param size The size of the pool. + * @param shouldReturnToPool If specified this function will be run on each release to determine whether the instance + * should be returned to the pool for reuse. This may be useful for pooled resources that dynamically grow during + * usage, and we may not want to retain them forever. + * @param newInstance The function to call to create a pooled resource. */ // TODO This could be implemented more efficiently. Currently the "non-sticky" use case is not optimised, it just chooses a random instance to wait on. class LazyStickyPool( size: Int, + private val shouldReturnToPool: ((A) -> Boolean)? = null, private val newInstance: () -> A ) { private class InstanceBox { @@ -45,7 +52,12 @@ class LazyStickyPool( fun release(stickTo: Any, instance: A) { val box = boxes[toIndex(stickTo)] - box.instance!!.add(instance) + if (shouldReturnToPool == null || shouldReturnToPool.invoke(instance)) { + box.instance!!.add(instance) + } else { + // We need to create a new instance instead of setting the queue to null to unblock potentially waiting threads. + box.instance!!.add(newInstance()) + } } inline fun run(stickToOrNull: Any? = null, withInstance: (A) -> R): R { diff --git a/core/src/main/kotlin/net/corda/core/utilities/Rate.kt b/core/src/main/kotlin/net/corda/core/utilities/Rate.kt new file mode 100644 index 0000000000..1936a27fa3 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/utilities/Rate.kt @@ -0,0 +1,29 @@ +package net.corda.core.utilities + +import java.time.Duration +import java.time.temporal.ChronoUnit +import java.util.concurrent.TimeUnit + +/** + * [Rate] holds a quantity denoting the frequency of some event e.g. 100 times per second or 2 times per day. + */ +data class Rate( + val numberOfEvents: Long, + val perTimeUnit: TimeUnit +) { + /** + * Returns the interval between two subsequent events. + */ + fun toInterval(): Duration { + return Duration.of(TimeUnit.NANOSECONDS.convert(1, perTimeUnit) / numberOfEvents, ChronoUnit.NANOS) + } + + /** + * Converts the number of events to the given unit. + */ + operator fun times(inUnit: TimeUnit): Long { + return inUnit.convert(numberOfEvents, perTimeUnit) + } +} + +operator fun Long.div(timeUnit: TimeUnit) = Rate(this, timeUnit) From 652cbb0d9f50b491804179dc5ef95e87ce89b60e Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Thu, 4 May 2017 16:28:49 +0100 Subject: [PATCH 3/6] #592: RPCServer lifecycle --- .../net/corda/node/driver/DriverTests.kt | 2 ++ .../node/services/messaging/RPCServer.kt | 22 ++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/node/src/integration-test/kotlin/net/corda/node/driver/DriverTests.kt b/node/src/integration-test/kotlin/net/corda/node/driver/DriverTests.kt index cd39098c83..54f77b30a9 100644 --- a/node/src/integration-test/kotlin/net/corda/node/driver/DriverTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/driver/DriverTests.kt @@ -75,8 +75,10 @@ class DriverTests { driver(isDebug = true, systemProperties = mapOf("log4j.configurationFile" to logConfigFile.toString())) { val baseDirectory = startNode(DUMMY_BANK_A.name).getOrThrow().configuration.baseDirectory val logFile = (baseDirectory / LOGS_DIRECTORY_NAME).list { it.sorted().findFirst().get() } + println("ASD $logFile") val debugLinesPresent = logFile.readLines { lines -> lines.anyMatch { line -> line.startsWith("[DEBUG]") } } assertThat(debugLinesPresent).isTrue() + println("hmm.") } } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index cb3f44ab28..9a0c0f70fe 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -18,6 +18,7 @@ import net.corda.core.messaging.RPCOps import net.corda.core.random63BitValue import net.corda.core.serialization.KryoPoolWithContext import net.corda.core.utilities.LazyStickyPool +import net.corda.core.utilities.LifeCycle import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor import net.corda.node.services.RPCUserService @@ -28,6 +29,7 @@ import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ClientConsumer import org.apache.activemq.artemis.api.core.client.ClientMessage +import org.apache.activemq.artemis.api.core.client.ClientSession import org.apache.activemq.artemis.api.core.client.ServerLocator import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.api.core.management.CoreNotificationType @@ -84,6 +86,12 @@ class RPCServer( val log = loggerFor() val kryoPool = KryoPool.Builder { RPCKryo(RpcServerObservableSerializer) }.build() } + private enum class State { + UNSTARTED, + STARTED, + FINISHED + } + private val lifeCycle = LifeCycle(State.UNSTARTED) // The methodname->Method map to use for dispatching. private val methodTable = ops.javaClass.declaredMethods.groupBy { it.name }.mapValues { it.value.single() } // The observable subscription mapping. @@ -134,17 +142,24 @@ class RPCServer( rpcConfiguration.reapIntervalMs, TimeUnit.MILLISECONDS ) + val sessions = ArrayList() for (i in 1 .. rpcConfiguration.consumerPoolSize) { val sessionFactory = serverLocator.createSessionFactory() val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME) consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler) - session.start() sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer)) + sessions.add(session) } clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS) clientBindingRemovalConsumer.setMessageHandler(this::bindingRemovalArtemisMessageHandler) serverControl = activeMqServerControl + lifeCycle.transition(State.UNSTARTED, State.STARTED) + // We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be + // fully initialised. + sessions.forEach { + it.start() + } } fun close() { @@ -165,9 +180,11 @@ class RPCServer( it.session.close() it.sessionFactory.close() } + lifeCycle.transition(State.STARTED, State.FINISHED) } private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) { + lifeCycle.requireState(State.STARTED) val notificationType = artemisMessage.getStringProperty(ManagementHelper.HDR_NOTIFICATION_TYPE) require(notificationType == CoreNotificationType.BINDING_REMOVED.name) val clientAddress = artemisMessage.getStringProperty(ManagementHelper.HDR_ROUTING_NAME) @@ -178,11 +195,13 @@ class RPCServer( // Note that this function operates on the *current* view of client observables. During invalidation further // Observables may be serialised and thus registered. private fun invalidateClient(clientAddress: SimpleString) { + lifeCycle.requireState(State.STARTED) val observableIds = clientAddressToObservables.removeAll(clientAddress) observableMap.invalidateAll(observableIds) } private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) { + lifeCycle.requireState(State.STARTED) val clientToServer = RPCApi.ClientToServer.fromClientMessage(kryoPool, artemisMessage) log.debug { "-> RPC -> $clientToServer" } when (clientToServer) { @@ -234,6 +253,7 @@ class RPCServer( } private fun reapSubscriptions() { + lifeCycle.requireState(State.STARTED) observableMap.cleanUp() } From 3a2afcdbb2137cde899b69985693fb4152b39ead Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Thu, 4 May 2017 17:38:59 +0100 Subject: [PATCH 4/6] #592: Address more comments --- .../net/corda/client/rpc/RPCStabilityTests.kt | 3 +-- .../rpc/internal/RPCClientProxyHandler.kt | 9 ++++----- .../net/corda/core/utilities/LazyStickyPool.kt | 11 +---------- .../src/main/kotlin/net/corda/nodeapi/RPCApi.kt | 2 +- .../corda/node/services/messaging/RPCServer.kt | 17 ++++++++++++----- 5 files changed, 19 insertions(+), 23 deletions(-) diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index 63a4dd6673..5287a0de4f 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -14,7 +14,6 @@ import net.corda.nodeapi.RPCApi import net.corda.nodeapi.RPCKryo import net.corda.testing.* import org.apache.activemq.artemis.api.core.SimpleString -import org.bouncycastle.crypto.tls.ConnectionEnd.server import org.junit.Test import rx.Observable import rx.subjects.PublishSubject @@ -80,7 +79,7 @@ class RPCStabilityTests { } val server = startRpcServer( configuration = RPCServerConfiguration.default.copy( - reapIntervalMs = 100 + reapInterval = 100.millis ), ops = trackSubscriberOpsImpl ).get() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index fb0e874b8a..da95f01b4d 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -162,7 +162,6 @@ class RPCClientProxyHandler( * Start the client. This creates the per-client queue, starts the consumer session and the reaper. */ fun start() { - lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( this::reapObservables, rpcConfiguration.reapInterval.toMillis(), @@ -176,12 +175,12 @@ class RPCClientProxyHandler( val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) val consumer = session.createConsumer(clientAddress) consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler) - session.start() sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer) + lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) + session.start() } // This is the general function that transforms a client side RPC to internal Artemis messages. - @CallerSensitive override fun invoke(proxy: Any, method: Method, arguments: Array?): Any? { lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET } checkProtocolVersion(method) @@ -269,7 +268,6 @@ class RPCClientProxyHandler( * Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors. */ fun close() { - lifeCycle.transition(State.STARTED, State.FINISHED) sessionAndConsumer.consumer.close() sessionAndConsumer.session.close() sessionAndConsumer.sessionFactory.close() @@ -287,6 +285,7 @@ class RPCClientProxyHandler( val observationExecutors = observationExecutorPool.close() observationExecutors.forEach { it.shutdownNow() } observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) } + lifeCycle.transition(State.STARTED, State.FINISHED) } /** @@ -310,12 +309,12 @@ class RPCClientProxyHandler( * RPCs already may be called with it. */ internal fun setServerProtocolVersion(version: Int) { - lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED) if (serverProtocolVersion == null) { serverProtocolVersion = version } else { throw IllegalStateException("setServerProtocolVersion called, but the protocol version was already set!") } + lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED) } private fun reapObservables() { diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt index cec52e1842..f44723b6b8 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt @@ -8,15 +8,11 @@ import java.util.concurrent.LinkedBlockingQueue * Any subsequent borrows using the same object will return the same pooled instance. * * @param size The size of the pool. - * @param shouldReturnToPool If specified this function will be run on each release to determine whether the instance - * should be returned to the pool for reuse. This may be useful for pooled resources that dynamically grow during - * usage, and we may not want to retain them forever. * @param newInstance The function to call to create a pooled resource. */ // TODO This could be implemented more efficiently. Currently the "non-sticky" use case is not optimised, it just chooses a random instance to wait on. class LazyStickyPool( size: Int, - private val shouldReturnToPool: ((A) -> Boolean)? = null, private val newInstance: () -> A ) { private class InstanceBox { @@ -52,12 +48,7 @@ class LazyStickyPool( fun release(stickTo: Any, instance: A) { val box = boxes[toIndex(stickTo)] - if (shouldReturnToPool == null || shouldReturnToPool.invoke(instance)) { - box.instance!!.add(instance) - } else { - // We need to create a new instance instead of setting the queue to null to unblock potentially waiting threads. - box.instance!!.add(newInstance()) - } + box.instance!!.add(instance) } inline fun run(stickToOrNull: Any? = null, withInstance: (A) -> R): R { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt index 39d85d90d2..b3820c154b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -54,7 +54,7 @@ import java.util.* * Note that multiple sessions like the above may interleave in an arbitrary fashion. * * Additionally the server may listen on client binding removals for cleanup using [RPC_CLIENT_BINDING_REMOVALS]. This - * requires the server to create a filter on the artemis notification address using + * requires the server to create a filter on the artemis notification address using [RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION] */ object RPCApi { private val TAG_FIELD_NAME = "tag" diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index 9a0c0f70fe..c7101e0d23 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -16,6 +16,7 @@ import net.corda.core.ErrorOr import net.corda.core.crypto.commonName import net.corda.core.messaging.RPCOps import net.corda.core.random63BitValue +import net.corda.core.seconds import net.corda.core.serialization.KryoPoolWithContext import net.corda.core.utilities.LazyStickyPool import net.corda.core.utilities.LifeCycle @@ -40,6 +41,7 @@ import rx.Observable import rx.Subscriber import rx.Subscription import java.lang.reflect.InvocationTargetException +import java.time.Duration import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.ScheduledFuture @@ -52,15 +54,15 @@ data class RPCServerConfiguration( val consumerPoolSize: Int, /** The maximum number of producers to create to handle outgoing messages */ val producerPoolBound: Int, - /** The interval of subscription reaping in milliseconds */ - val reapIntervalMs: Long + /** The interval of subscription reaping */ + val reapInterval: Duration ) { companion object { val default = RPCServerConfiguration( rpcThreadPoolSize = 4, consumerPoolSize = 2, producerPoolBound = 4, - reapIntervalMs = 1000 + reapInterval = 1.seconds ) } } @@ -138,8 +140,8 @@ class RPCServer( log.info("Starting RPC server with configuration $rpcConfiguration") reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate( this::reapSubscriptions, - rpcConfiguration.reapIntervalMs, - rpcConfiguration.reapIntervalMs, + rpcConfiguration.reapInterval.toMillis(), + rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) val sessions = ArrayList() @@ -274,6 +276,11 @@ class RPCServer( @JvmField internal val CURRENT_RPC_CONTEXT: ThreadLocal = ThreadLocal() +/** + * Returns a context specific to the current RPC call. Note that trying to call this function outside of an RPC will + * throw. If you'd like to use the context outside of the call (e.g. in another thread) then pass the returned reference + * around explicitly. + */ fun getRpcContext(): RpcContext = CURRENT_RPC_CONTEXT.get() /** From f744c4455e0f0cbd78ae45136e9d4264ab88b1e4 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 5 May 2017 15:10:56 +0100 Subject: [PATCH 5/6] #592: Fix test port allocation flakiness --- .../main/kotlin/net/corda/testing/CoreTestUtils.kt | 12 +++++------- .../src/main/kotlin/net/corda/testing/RPCDriver.kt | 11 ++++++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt index 7d9a351605..aff00f26c8 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt @@ -35,6 +35,7 @@ import java.nio.file.Path import java.security.KeyPair import java.security.PublicKey import java.util.* +import java.util.concurrent.atomic.AtomicInteger import kotlin.reflect.KClass /** @@ -90,6 +91,7 @@ val MOCK_VERSION_INFO = VersionInfo(1, "Mock release", "Mock revision", "Mock Ve fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0) +private val freePortCounter = AtomicInteger(30000) /** * Returns a free port. * @@ -97,7 +99,7 @@ fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0) * Use [getFreeLocalPorts] for getting multiple ports. */ fun freeLocalHostAndPort(): HostAndPort { - val freePort = ServerSocket(0).use { it.localPort } + val freePort = freePortCounter.getAndAccumulate(0) { prev, _ -> 30000 + (prev - 30000 + 1) % 10000 } return HostAndPort.fromParts("localhost", freePort) } @@ -108,12 +110,8 @@ fun freeLocalHostAndPort(): HostAndPort { * to the Node, some other process else could allocate the returned ports. */ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List { - // Create a bunch of sockets up front. - val sockets = Array(numberToAlloc) { ServerSocket(0) } - val result = sockets.map { HostAndPort.fromParts(hostName, it.localPort) } - // Close sockets only once we've grabbed all the ports we need. - sockets.forEach(ServerSocket::close) - return result + val freePort = freePortCounter.getAndAccumulate(0) { prev, _ -> 30000 + (prev - 30000 + numberToAlloc) % 10000 } + return (freePort .. freePort + numberToAlloc - 1).map { HostAndPort.fromParts(hostName, it) } } /** diff --git a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt index 8e3ae09eb4..983cfd848e 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt @@ -10,6 +10,7 @@ import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.div import net.corda.core.messaging.RPCOps +import net.corda.core.random63BitValue import net.corda.core.utilities.ProcessUtilities import net.corda.node.driver.* import net.corda.node.services.RPCUserService @@ -101,7 +102,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { * @param ops The server-side implementation of the RPC interface. */ fun startRpcServer( - serverName: String = "driver-rpc-server", + serverName: String = "driver-rpc-server-${random63BitValue()}", rpcUser: User = rpcTestUser, nodeLegalName: X500Name = fakeNodeLegalName, maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE, @@ -182,11 +183,14 @@ data class RpcServerHandle( val rpcTestUser = User("user1", "test", permissions = emptySet()) val fakeNodeLegalName = X500Name("not:a:valid:name") +// Use a global pool so that we can run RPC tests in parallel +private val globalPortAllocation = PortAllocation.Incremental(10000) +private val globalDebugPortAllocation = PortAllocation.Incremental(5005) fun rpcDriver( isDebug: Boolean = false, driverDirectory: Path = Paths.get("build", getTimestampAsDirectoryName()), - portAllocation: PortAllocation = PortAllocation.Incremental(10000), - debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), + portAllocation: PortAllocation = globalPortAllocation, + debugPortAllocation: PortAllocation = globalDebugPortAllocation, systemProperties: Map = emptyMap(), useTestClock: Boolean = false, automaticallyStartNetworkMap: Boolean = false, @@ -339,6 +343,7 @@ data class RPCDriverDSL( ops: I ): ListenableFuture { val hostAndPort = driverDSL.portAllocation.nextHostAndPort() + addressMustNotBeBound(driverDSL.executorService, hostAndPort) return driverDSL.executorService.submit { val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort) val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser)) From cbe15e35c7b4070d3064269a9990150e71be51ed Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 5 May 2017 17:31:24 +0100 Subject: [PATCH 6/6] Fix X500Name issue in RPCDriver --- test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt index 983cfd848e..6822c1fa76 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt @@ -181,7 +181,7 @@ data class RpcServerHandle( ) val rpcTestUser = User("user1", "test", permissions = emptySet()) -val fakeNodeLegalName = X500Name("not:a:valid:name") +val fakeNodeLegalName = X500Name("CN=not:a:valid:name") // Use a global pool so that we can run RPC tests in parallel private val globalPortAllocation = PortAllocation.Incremental(10000)