mirror of
https://github.com/corda/corda.git
synced 2025-06-21 16:49:45 +00:00
RPC muxing, multithreading, RPC driver, performance tests
This commit is contained in:
@ -1,15 +1,10 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import net.corda.core.contracts.DOLLARS
|
||||
import net.corda.core.flows.FlowInitiator
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowInitiator
|
||||
import net.corda.core.getOrThrow
|
||||
import net.corda.core.messaging.FlowHandle
|
||||
import net.corda.core.messaging.FlowProgressHandle
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
import net.corda.core.messaging.StateMachineUpdate
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.messaging.startTrackedFlow
|
||||
import net.corda.core.messaging.*
|
||||
import net.corda.core.node.services.ServiceInfo
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.OpaqueBytes
|
||||
@ -27,7 +22,9 @@ import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.util.*
|
||||
import kotlin.test.*
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class CordaRPCClientTest : NodeBasedTest() {
|
||||
private val rpcUser = User("user1", "test", permissions = setOf(
|
||||
@ -36,6 +33,11 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
))
|
||||
private lateinit var node: Node
|
||||
private lateinit var client: CordaRPCClient
|
||||
private var connection: CordaRPCConnection? = null
|
||||
|
||||
private fun login(username: String, password: String) {
|
||||
connection = client.start(username, password)
|
||||
}
|
||||
|
||||
@Before
|
||||
fun setUp() {
|
||||
@ -45,33 +47,35 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
|
||||
@After
|
||||
fun done() {
|
||||
client.close()
|
||||
connection?.close()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `log in with valid username and password`() {
|
||||
client.start(rpcUser.username, rpcUser.password)
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `log in with unknown user`() {
|
||||
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
|
||||
client.start(random63BitValue().toString(), rpcUser.password)
|
||||
login(random63BitValue().toString(), rpcUser.password)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `log in with incorrect password`() {
|
||||
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
|
||||
client.start(rpcUser.username, random63BitValue().toString())
|
||||
login(rpcUser.username, random63BitValue().toString())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `close-send deadlock and premature shutdown on empty observable`() {
|
||||
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
|
||||
println("Starting client")
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
println("Creating proxy")
|
||||
println("Starting flow")
|
||||
val flowHandle = proxy.startTrackedFlow(
|
||||
val flowHandle = connection!!.proxy.startTrackedFlow(
|
||||
::CashIssueFlow,
|
||||
20.DOLLARS, OpaqueBytes.of(0), node.info.legalIdentity, node.info.legalIdentity)
|
||||
println("Started flow, waiting on result")
|
||||
@ -83,9 +87,8 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
|
||||
@Test
|
||||
fun `FlowException thrown by flow`() {
|
||||
client.start(rpcUser.username, rpcUser.password)
|
||||
val proxy = client.proxy()
|
||||
val handle = proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
val handle = connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
|
||||
// TODO Restrict this to CashException once RPC serialisation has been fixed
|
||||
assertThatExceptionOfType(FlowException::class.java).isThrownBy {
|
||||
handle.returnValue.getOrThrow()
|
||||
@ -94,9 +97,8 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
|
||||
@Test
|
||||
fun `check basic flow has no progress`() {
|
||||
client.start(rpcUser.username, rpcUser.password)
|
||||
val proxy = client.proxy()
|
||||
proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
|
||||
assertFalse(it is FlowProgressHandle<*>)
|
||||
assertTrue(it is FlowHandle<*>)
|
||||
}
|
||||
@ -104,7 +106,8 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
|
||||
@Test
|
||||
fun `get cash balances`() {
|
||||
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
val proxy = connection!!.proxy
|
||||
val startCash = proxy.getCashBalances()
|
||||
assertTrue(startCash.isEmpty(), "Should not start with any cash")
|
||||
|
||||
@ -123,7 +126,8 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
|
||||
@Test
|
||||
fun `flow initiator via RPC`() {
|
||||
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
|
||||
login(rpcUser.username, rpcUser.password)
|
||||
val proxy = connection!!.proxy
|
||||
val smUpdates = proxy.stateMachinesAndUpdates()
|
||||
var countRpcFlows = 0
|
||||
var countShellFlows = 0
|
||||
@ -148,11 +152,4 @@ class CordaRPCClientTest : NodeBasedTest() {
|
||||
assertEquals(2, countRpcFlows)
|
||||
assertEquals(1, countShellFlows)
|
||||
}
|
||||
|
||||
private fun createRpcProxy(username: String, password: String): CordaRPCOps {
|
||||
println("Starting client")
|
||||
client.start(username, password)
|
||||
println("Creating proxy")
|
||||
return client.proxy()
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,170 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import com.esotericsoftware.kryo.pool.KryoPool
|
||||
import com.google.common.util.concurrent.Futures
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.millis
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.RPCApi
|
||||
import net.corda.nodeapi.RPCKryo
|
||||
import net.corda.testing.*
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.bouncycastle.crypto.tls.ConnectionEnd.server
|
||||
import org.junit.Test
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import rx.subjects.UnicastSubject
|
||||
import java.time.Duration
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
|
||||
class RPCStabilityTests {
|
||||
|
||||
interface LeakObservableOps: RPCOps {
|
||||
fun leakObservable(): Observable<Nothing>
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `client cleans up leaked observables`() {
|
||||
rpcDriver {
|
||||
val leakObservableOpsImpl = object : LeakObservableOps {
|
||||
val leakedUnsubscribedCount = AtomicInteger(0)
|
||||
override val protocolVersion = 0
|
||||
override fun leakObservable(): Observable<Nothing> {
|
||||
return PublishSubject.create<Nothing>().doOnUnsubscribe {
|
||||
leakedUnsubscribedCount.incrementAndGet()
|
||||
}
|
||||
}
|
||||
}
|
||||
val server = startRpcServer<LeakObservableOps>(ops = leakObservableOpsImpl)
|
||||
val proxy = startRpcClient<LeakObservableOps>(server.get().hostAndPort).get()
|
||||
// Leak many observables
|
||||
val N = 200
|
||||
(1..N).toList().parallelStream().forEach {
|
||||
proxy.leakObservable()
|
||||
}
|
||||
// In a loop force GC and check whether the server is notified
|
||||
while (true) {
|
||||
System.gc()
|
||||
if (leakObservableOpsImpl.leakedUnsubscribedCount.get() == N) break
|
||||
Thread.sleep(100)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface TrackSubscriberOps : RPCOps {
|
||||
fun subscribe(): Observable<Unit>
|
||||
}
|
||||
|
||||
/**
|
||||
* In this test we create a number of out of process RPC clients that call [TrackSubscriberOps.subscribe] in a loop.
|
||||
*/
|
||||
@Test
|
||||
fun `server cleans up queues after disconnected clients`() {
|
||||
rpcDriver {
|
||||
val trackSubscriberOpsImpl = object : TrackSubscriberOps {
|
||||
override val protocolVersion = 0
|
||||
val subscriberCount = AtomicInteger(0)
|
||||
val trackSubscriberCountObservable = UnicastSubject.create<Unit>().share().
|
||||
doOnSubscribe { subscriberCount.incrementAndGet() }.
|
||||
doOnUnsubscribe { subscriberCount.decrementAndGet() }
|
||||
override fun subscribe(): Observable<Unit> {
|
||||
return trackSubscriberCountObservable
|
||||
}
|
||||
}
|
||||
val server = startRpcServer<TrackSubscriberOps>(
|
||||
configuration = RPCServerConfiguration.default.copy(
|
||||
reapIntervalMs = 100
|
||||
),
|
||||
ops = trackSubscriberOpsImpl
|
||||
).get()
|
||||
|
||||
val numberOfClients = 4
|
||||
val clients = Futures.allAsList((1 .. numberOfClients).map {
|
||||
startRandomRpcClient<TrackSubscriberOps>(server.hostAndPort)
|
||||
}).get()
|
||||
|
||||
// Poll until all clients connect
|
||||
pollUntilClientNumber(server, numberOfClients)
|
||||
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() >= 100 }.get()
|
||||
// Kill one client
|
||||
clients[0].destroyForcibly()
|
||||
pollUntilClientNumber(server, numberOfClients - 1)
|
||||
// Kill the rest
|
||||
(1 .. numberOfClients - 1).forEach {
|
||||
clients[it].destroyForcibly()
|
||||
}
|
||||
pollUntilClientNumber(server, 0)
|
||||
// Now poll until the server detects the disconnects and unsubscribes from all obserables.
|
||||
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() == 0 }.get()
|
||||
}
|
||||
}
|
||||
|
||||
interface SlowConsumerRPCOps : RPCOps {
|
||||
fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray>
|
||||
}
|
||||
class SlowConsumerRPCOpsImpl : SlowConsumerRPCOps {
|
||||
override val protocolVersion = 0
|
||||
|
||||
override fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray> {
|
||||
val chunk = ByteArray(size)
|
||||
return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk }
|
||||
}
|
||||
}
|
||||
val dummyObservableSerialiser = object : Serializer<Observable<Any>>() {
|
||||
override fun write(kryo: Kryo?, output: Output?, `object`: Observable<Any>?) {
|
||||
}
|
||||
override fun read(kryo: Kryo?, input: Input?, type: Class<Observable<Any>>?): Observable<Any> {
|
||||
return Observable.empty()
|
||||
}
|
||||
}
|
||||
@Test
|
||||
fun `slow consumers are kicked`() {
|
||||
val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build()
|
||||
rpcDriver {
|
||||
val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get()
|
||||
|
||||
// Construct an RPC session manually so that we can hang in the message handler
|
||||
val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}"
|
||||
val session = startArtemisSession(server.hostAndPort)
|
||||
session.createTemporaryQueue(myQueue, myQueue)
|
||||
val consumer = session.createConsumer(myQueue, null, -1, -1, false)
|
||||
consumer.setMessageHandler {
|
||||
Thread.sleep(50) // 5x slower than the server producer
|
||||
it.acknowledge()
|
||||
}
|
||||
val producer = session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)
|
||||
session.start()
|
||||
|
||||
pollUntilClientNumber(server, 1)
|
||||
|
||||
val message = session.createMessage(false)
|
||||
val request = RPCApi.ClientToServer.RpcRequest(
|
||||
clientAddress = SimpleString(myQueue),
|
||||
id = RPCApi.RpcRequestId(random63BitValue()),
|
||||
methodName = SlowConsumerRPCOps::streamAtInterval.name,
|
||||
arguments = listOf(10.millis, 123456)
|
||||
)
|
||||
request.writeToClientMessage(kryoPool, message)
|
||||
producer.send(message)
|
||||
session.commit()
|
||||
|
||||
// We are consuming slower than the server is producing, so we should be kicked after a while
|
||||
pollUntilClientNumber(server, 0)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) {
|
||||
pollUntilTrue("number of RPC clients to become $expected") {
|
||||
val clientAddresses = server.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) }
|
||||
clientAddresses.size == expected
|
||||
}.get()
|
||||
}
|
@ -1,168 +1,49 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import com.google.common.net.HostAndPort
|
||||
import net.corda.core.ThreadBox
|
||||
import net.corda.core.logElapsedTime
|
||||
import net.corda.client.rpc.internal.RPCClient
|
||||
import net.corda.client.rpc.internal.RPCClientConfiguration
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
import net.corda.core.minutes
|
||||
import net.corda.core.seconds
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import net.corda.nodeapi.ArtemisMessagingComponent
|
||||
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
|
||||
import net.corda.nodeapi.ConnectionDirection
|
||||
import net.corda.nodeapi.RPCException
|
||||
import net.corda.nodeapi.config.SSLConfiguration
|
||||
import net.corda.nodeapi.rpcLog
|
||||
import org.apache.activemq.artemis.api.core.ActiveMQException
|
||||
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
|
||||
import org.apache.activemq.artemis.api.core.client.ServerLocator
|
||||
import rx.Observable
|
||||
import java.io.Closeable
|
||||
import java.time.Duration
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
/**
|
||||
* An RPC client connects to the specified server and allows you to make calls to the server that perform various
|
||||
* useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works.
|
||||
*
|
||||
* @param host The hostname and messaging port of the node.
|
||||
* @param config If specified, the SSL configuration to use. If not specified, SSL will be disabled and the node will only be authenticated on non-SSL RPC port, the RPC traffic with not be encrypted when SSL is disabled.
|
||||
*/
|
||||
@ThreadSafe
|
||||
class CordaRPCClient(val host: HostAndPort, override val config: SSLConfiguration? = null, val serviceConfigurationOverride: (ServerLocator.() -> Unit)? = null) : Closeable, ArtemisMessagingComponent() {
|
||||
private companion object {
|
||||
val log = loggerFor<CordaRPCClient>()
|
||||
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
|
||||
@JvmStatic val MAX_FILE_SIZE = 10485760
|
||||
class CordaRPCConnection internal constructor(
|
||||
connection: RPCClient.RPCConnection<CordaRPCOps>
|
||||
) : RPCClient.RPCConnection<CordaRPCOps> by connection
|
||||
|
||||
data class CordaRPCClientConfiguration(
|
||||
val connectionMaxRetryInterval: Duration
|
||||
) {
|
||||
internal fun toRpcClientConfiguration(): RPCClientConfiguration {
|
||||
return RPCClientConfiguration.default.copy(
|
||||
connectionMaxRetryInterval = connectionMaxRetryInterval
|
||||
)
|
||||
}
|
||||
companion object {
|
||||
@JvmStatic
|
||||
val default = CordaRPCClientConfiguration(
|
||||
connectionMaxRetryInterval = RPCClientConfiguration.default.connectionMaxRetryInterval
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class CordaRPCClient(
|
||||
hostAndPort: HostAndPort,
|
||||
sslConfiguration: SSLConfiguration? = null,
|
||||
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default
|
||||
) {
|
||||
private val rpcClient = RPCClient<CordaRPCOps>(
|
||||
tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration),
|
||||
configuration.toRpcClientConfiguration()
|
||||
)
|
||||
|
||||
fun start(username: String, password: String): CordaRPCConnection {
|
||||
return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password))
|
||||
}
|
||||
|
||||
// TODO: Certificate handling for clients needs more work.
|
||||
private inner class State {
|
||||
var running = false
|
||||
lateinit var sessionFactory: ClientSessionFactory
|
||||
lateinit var session: ClientSession
|
||||
lateinit var clientImpl: CordaRPCClientImpl
|
||||
}
|
||||
|
||||
private val state = ThreadBox(State())
|
||||
|
||||
/**
|
||||
* Opens the connection to the server with the given username and password, then returns itself.
|
||||
* Registers a JVM shutdown hook to cleanly disconnect.
|
||||
*/
|
||||
@Throws(ActiveMQException::class)
|
||||
fun start(username: String, password: String): CordaRPCClient {
|
||||
state.locked {
|
||||
check(!running)
|
||||
log.logElapsedTime("Startup") {
|
||||
checkStorePasswords()
|
||||
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport(ConnectionDirection.Outbound(), host, config, enableSSL = config != null)).apply {
|
||||
// TODO: Put these in config file or make it user configurable?
|
||||
threadPoolMaxSize = 1
|
||||
confirmationWindowSize = 100000 // a guess
|
||||
retryInterval = 5.seconds.toMillis()
|
||||
retryIntervalMultiplier = 1.5 // Exponential backoff
|
||||
maxRetryInterval = 3.minutes.toMillis()
|
||||
minLargeMessageSize = MAX_FILE_SIZE
|
||||
serviceConfigurationOverride?.invoke(this)
|
||||
}
|
||||
sessionFactory = serverLocator.createSessionFactory()
|
||||
session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize)
|
||||
session.start()
|
||||
clientImpl = CordaRPCClientImpl(session, state.lock, username)
|
||||
running = true
|
||||
}
|
||||
}
|
||||
|
||||
Runtime.getRuntime().addShutdownHook(Thread {
|
||||
close()
|
||||
})
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience function that opens a connection with the given credentials, executes the given code block with all
|
||||
* available RPCs in scope and shuts down the RPC connection again. It's meant for quick prototyping and demos. For
|
||||
* more control you probably want to control the lifecycle of the client and proxies independently, as well as
|
||||
* configuring a timeout and other such features via the [proxy] method.
|
||||
*
|
||||
* After this method returns the client is closed and can't be restarted.
|
||||
*/
|
||||
@Throws(ActiveMQException::class)
|
||||
fun <T> use(username: String, password: String, block: CordaRPCOps.() -> T): T {
|
||||
require(!state.locked { running })
|
||||
start(username, password)
|
||||
(this as Closeable).use {
|
||||
return proxy().block()
|
||||
}
|
||||
}
|
||||
|
||||
/** Shuts down the client and lets the server know it can free the used resources (in a nice way). */
|
||||
override fun close() {
|
||||
state.locked {
|
||||
if (!running) return
|
||||
session.close()
|
||||
sessionFactory.close()
|
||||
running = false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a fresh proxy that lets you invoke RPCs on the server. Calls on it block, and if the server throws an
|
||||
* exception then it will be rethrown on the client. Proxies are thread safe but only one RPC can be in flight at
|
||||
* once. If you'd like to perform multiple RPCs in parallel, use this function multiple times to get multiple
|
||||
* proxies.
|
||||
*
|
||||
* Creation of a proxy is a somewhat expensive operation that involves calls to the server, so if you want to do
|
||||
* calls from many threads at once you should cache one proxy per thread and reuse them. This function itself is
|
||||
* thread safe though so requires no extra synchronisation.
|
||||
*
|
||||
* RPC sends and receives are logged on the net.corda.rpc logger.
|
||||
*
|
||||
* By default there are no timeouts on calls. This is deliberate, RPCs without timeouts can survive restarts,
|
||||
* maintenance downtime and moves of the server. RPCs can survive temporary losses or changes in client connectivity,
|
||||
* like switching between wifi networks. You can specify a timeout on the level of a proxy. If a call times
|
||||
* out it will throw [RPCException.Deadline].
|
||||
*
|
||||
* The [CordaRPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the
|
||||
* object graph returned then the server-side observable is transparently linked to a messaging queue, and that
|
||||
* queue linked to another observable on the client side here. *You are expected to use it*. The server will begin
|
||||
* buffering messages immediately that it will expect you to drain by subscribing to the returned observer. You can
|
||||
* opt-out of this by simply casting the [Observable] to [Closeable] or [AutoCloseable] and then calling the close
|
||||
* method on it. You don't have to explicitly close the observable if you actually subscribe to it: it will close
|
||||
* itself and free up the server-side resources either when the client or JVM itself is shutdown, or when there are
|
||||
* no more subscribers to it. Once all the subscribers to a returned observable are unsubscribed, the observable is
|
||||
* closed and you can't then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
|
||||
*
|
||||
* The proxy and linked observables consume some small amount of resources on the server. It's OK to just exit your
|
||||
* process and let the server clean up, but in a long running process where you only need something for a short
|
||||
* amount of time it is polite to cast the objects to [Closeable] or [AutoCloseable] and close it when you are done.
|
||||
* Finalizers are in place to warn you if you lose a reference to an unclosed proxy or observable.
|
||||
*
|
||||
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
|
||||
*/
|
||||
@JvmOverloads
|
||||
@Throws(RPCException::class)
|
||||
fun proxy(timeout: Duration? = null, minVersion: Int = 0): CordaRPCOps {
|
||||
return state.locked {
|
||||
check(running) { "Client must have been started first" }
|
||||
log.logElapsedTime("Proxy build") {
|
||||
clientImpl.proxyFor(CordaRPCOps::class.java, timeout, minVersion)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED")
|
||||
private fun finalize() {
|
||||
state.locked {
|
||||
if (running) {
|
||||
rpcLog.warn("A CordaMQClient is being finalised whilst still running, did you forget to call close?")
|
||||
close()
|
||||
}
|
||||
}
|
||||
inline fun <A> use(username: String, password: String, block: (CordaRPCConnection) -> A): A {
|
||||
return start(username, password).use(block)
|
||||
}
|
||||
}
|
@ -1,418 +0,0 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.KryoException
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import com.esotericsoftware.kryo.pool.KryoPool
|
||||
import com.google.common.cache.CacheBuilder
|
||||
import net.corda.core.ErrorOr
|
||||
import net.corda.core.bufferUntilSubscribed
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.messaging.RPCReturnsObservables
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.nodeapi.*
|
||||
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
|
||||
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.activemq.artemis.api.core.client.ClientConsumer
|
||||
import org.apache.activemq.artemis.api.core.client.ClientMessage
|
||||
import org.apache.activemq.artemis.api.core.client.ClientProducer
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import java.io.Closeable
|
||||
import java.lang.ref.WeakReference
|
||||
import java.lang.reflect.InvocationHandler
|
||||
import java.lang.reflect.Method
|
||||
import java.lang.reflect.Proxy
|
||||
import java.time.Duration
|
||||
import java.util.*
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import javax.annotation.concurrent.GuardedBy
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
import kotlin.concurrent.withLock
|
||||
import kotlin.reflect.jvm.javaMethod
|
||||
|
||||
/**
|
||||
* Core RPC engine implementation, to learn how to use RPC you should be looking at [CordaRPCClient].
|
||||
*
|
||||
* # Design notes
|
||||
*
|
||||
* The way RPCs are handled is fairly standard except for the handling of observables. When an RPC might return
|
||||
* an [Observable] it is specially tagged. This causes the client to create a new transient queue for the
|
||||
* receiving of observables and their observations with a random ID in the name. This ID is sent to the server in
|
||||
* a message header. All observations are sent via this single queue.
|
||||
*
|
||||
* The reason for doing it this way and not the more obvious approach of one-queue-per-observable is that we want
|
||||
* the queues to be *transient*, meaning their lifetime in the broker is tied to the session that created them.
|
||||
* A server side observable and its associated queue is not a cost-free thing, let alone the memory and resources
|
||||
* needed to actually generate the observations themselves, therefore we want to ensure these cannot leak. A
|
||||
* transient queue will be deleted automatically if the client session terminates, which by default happens on
|
||||
* disconnect but can also be configured to happen after a short delay (this allows clients to e.g. switch IP
|
||||
* address). On the server the deletion of the observations queue triggers unsubscription from the associated
|
||||
* observables, which in turn may then be garbage collected.
|
||||
*
|
||||
* Creating a transient queue requires a roundtrip to the broker and thus doing an RPC that could return
|
||||
* observables takes two server roundtrips instead of one. That's why we require RPCs to be marked with
|
||||
* [RPCReturnsObservables] as needing this special treatment instead of always doing it.
|
||||
*
|
||||
* If the Artemis/JMS APIs allowed us to create transient queues assigned to someone else then we could
|
||||
* potentially use a different design in which the node creates new transient queues (one per observable) on the
|
||||
* fly. The client would then have to watch out for this and start consuming those queues as they were created.
|
||||
*
|
||||
* We use one queue per RPC because we don't know ahead of time how many observables the server might return and
|
||||
* often the server doesn't know either, which pushes towards a single queue design, but at the same time the
|
||||
* processing of observations returned by an RPC might be striped across multiple threads and we'd like
|
||||
* backpressure management to not be scoped per client process but with more granularity. So we end up with
|
||||
* a compromise where the unit of backpressure management is the response to a single RPC.
|
||||
*
|
||||
* TODO: Backpressure isn't propagated all the way through the MQ broker at the moment.
|
||||
*/
|
||||
class CordaRPCClientImpl(private val session: ClientSession,
|
||||
private val sessionLock: ReentrantLock,
|
||||
private val username: String) {
|
||||
companion object {
|
||||
private val closeableCloseMethod = Closeable::close.javaMethod
|
||||
private val autocloseableCloseMethod = AutoCloseable::close.javaMethod
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a proxy for the given type, which must descend from [RPCOps].
|
||||
*
|
||||
* @see CordaRPCClient.proxy for more information about how to use the proxies.
|
||||
*/
|
||||
fun <T : RPCOps> proxyFor(rpcInterface: Class<T>, timeout: Duration? = null, minVersion: Int = 0): T {
|
||||
sessionLock.withLock {
|
||||
if (producer == null)
|
||||
producer = session.createProducer()
|
||||
}
|
||||
val proxyImpl = RPCProxyHandler(timeout)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val proxy = Proxy.newProxyInstance(rpcInterface.classLoader, arrayOf(rpcInterface, Closeable::class.java), proxyImpl) as T
|
||||
proxyImpl.serverProtocolVersion = proxy.protocolVersion
|
||||
if (minVersion > proxyImpl.serverProtocolVersion)
|
||||
throw RPCException("Requested minimum protocol version $minVersion is higher than the server's supported protocol version (${proxyImpl.serverProtocolVersion})")
|
||||
return proxy
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
//region RPC engine
|
||||
//
|
||||
// You can find docs on all this in the api doc for the proxyFor method, and in the docsite.
|
||||
|
||||
// Utility to quickly suck out the contents of an Artemis message. There's probably a more efficient way to
|
||||
// do this.
|
||||
private fun <T : Any> ClientMessage.deserialize(kryo: Kryo): T = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }.deserialize(kryo)
|
||||
|
||||
// We by default use a weak reference so GC can happen, otherwise they persist for the life of the client.
|
||||
@GuardedBy("sessionLock")
|
||||
private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build<String, QueuedObservable>()
|
||||
// This is used to hold a reference counted hard reference when we know there are subscribers.
|
||||
private val hardReferencesToQueuedObservables = Collections.synchronizedSet(mutableSetOf<QueuedObservable>())
|
||||
|
||||
private var producer: ClientProducer? = null
|
||||
|
||||
class ObservableDeserializer : Serializer<Observable<Any>>() {
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
|
||||
val qName = kryo.context[RPCKryoQNameKey] as String
|
||||
val rpcName = kryo.context[RPCKryoMethodNameKey] as String
|
||||
val rpcLocation = kryo.context[RPCKryoLocationKey] as Throwable
|
||||
val rpcClient = kryo.context[RPCKryoClientKey] as CordaRPCClientImpl
|
||||
val handle = input.readInt(true)
|
||||
val ob = rpcClient.sessionLock.withLock {
|
||||
rpcClient.addressToQueuedObservables.getIfPresent(qName) ?: rpcClient.QueuedObservable(qName, rpcName, rpcLocation).apply {
|
||||
rpcClient.addressToQueuedObservables.put(qName, this)
|
||||
}
|
||||
}
|
||||
val result = ob.getForHandle(handle)
|
||||
rpcLog.debug { "Deserializing and connecting a new observable for $rpcName on $qName: $result" }
|
||||
return result
|
||||
}
|
||||
|
||||
override fun write(kryo: Kryo, output: Output, `object`: Observable<Any>) {
|
||||
throw UnsupportedOperationException("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The proxy class returned to the client is auto-generated on the fly by the java.lang.reflect Proxy
|
||||
* infrastructure. The JDK Proxy class writes bytecode into memory for a class that implements the requested
|
||||
* interfaces and then routes all method calls to the invoke method below in a conveniently reified form.
|
||||
* We can then easily take the data about the method call and turn it into an RPC. This avoids the need
|
||||
* for the compile-time code generation which is so common in RPC systems.
|
||||
*/
|
||||
@ThreadSafe
|
||||
private inner class RPCProxyHandler(private val timeout: Duration?) : InvocationHandler, Closeable {
|
||||
private val proxyId = random63BitValue()
|
||||
private val consumer: ClientConsumer
|
||||
|
||||
var serverProtocolVersion = 0
|
||||
|
||||
init {
|
||||
val proxyAddress = constructAddress(proxyId)
|
||||
consumer = sessionLock.withLock {
|
||||
session.createTemporaryQueue(proxyAddress, proxyAddress)
|
||||
session.createConsumer(proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
private fun constructAddress(addressId: Long) = "${ArtemisMessagingComponent.CLIENTS_PREFIX}$username.rpc.$addressId"
|
||||
|
||||
@Synchronized
|
||||
override fun invoke(proxy: Any, method: Method, args: Array<out Any>?): Any? {
|
||||
if (isCloseInvocation(method)) {
|
||||
close()
|
||||
return null
|
||||
}
|
||||
if (method.name == "toString" && args == null)
|
||||
return "Client RPC proxy"
|
||||
|
||||
if (consumer.isClosed)
|
||||
throw RPCException("RPC Proxy is closed")
|
||||
|
||||
// All invoked methods on the proxy end up here.
|
||||
val location = Throwable()
|
||||
rpcLog.debug {
|
||||
val argStr = args?.joinToString() ?: ""
|
||||
"-> RPC -> ${method.name}($argStr): ${method.returnType}"
|
||||
}
|
||||
|
||||
checkMethodVersion(method)
|
||||
|
||||
val msg: ClientMessage = createMessage(method)
|
||||
// We could of course also check the return type of the method to see if it's Observable, but I'd
|
||||
// rather haved the annotation be used consistently.
|
||||
val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java)
|
||||
val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else createRPCKryoForDeserialization(this@CordaRPCClientImpl)
|
||||
val next: ErrorOr<*> = try {
|
||||
sendRequest(args, msg)
|
||||
receiveResponse(kryo, method, timeout)
|
||||
} finally {
|
||||
releaseRPCKryoForDeserialization(kryo)
|
||||
}
|
||||
rpcLog.debug { "<- RPC <- ${method.name} = $next" }
|
||||
return unwrapOrThrow(next)
|
||||
}
|
||||
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
||||
private fun unwrapOrThrow(next: ErrorOr<*>): Any? {
|
||||
val ex = next.error
|
||||
if (ex != null) {
|
||||
// Replace the stack trace because that's an implementation detail of the server that isn't so
|
||||
// helpful to the user who wants to see where the error was on their side, and serialising stack
|
||||
// frame objects is a bit annoying. We slice it here to avoid the invoke() machinery being exposed.
|
||||
// The resulting exception looks like it was thrown from inside the called method.
|
||||
(ex as java.lang.Throwable).stackTrace = java.lang.Throwable().stackTrace.let { it.sliceArray(1..it.size - 1) }
|
||||
throw ex
|
||||
} else {
|
||||
return next.value
|
||||
}
|
||||
}
|
||||
|
||||
private fun receiveResponse(kryo: Kryo, method: Method, timeout: Duration?): ErrorOr<*> {
|
||||
val artemisMessage: ClientMessage =
|
||||
if (timeout == null)
|
||||
consumer.receive() ?: throw ActiveMQObjectClosedException()
|
||||
else
|
||||
consumer.receive(timeout.toMillis()) ?: throw RPCException.DeadlineExceeded(method.name)
|
||||
artemisMessage.acknowledge()
|
||||
val next = artemisMessage.deserialize<ErrorOr<*>>(kryo)
|
||||
return next
|
||||
}
|
||||
|
||||
private fun sendRequest(args: Array<out Any>?, msg: ClientMessage) {
|
||||
sessionLock.withLock {
|
||||
val argsKryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl)
|
||||
val serializedArgs = try {
|
||||
(args ?: emptyArray<Any?>()).serialize(argsKryo)
|
||||
} catch (e: KryoException) {
|
||||
throw RPCException("Could not serialize RPC arguments", e)
|
||||
} finally {
|
||||
releaseRPCKryoForDeserialization(argsKryo)
|
||||
}
|
||||
msg.writeBodyBufferBytes(serializedArgs.bytes)
|
||||
producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg)
|
||||
}
|
||||
}
|
||||
|
||||
private fun maybePrepareForObservables(location: Throwable, method: Method, msg: ClientMessage): Kryo {
|
||||
// Create a temporary queue just for the emissions on any observables that are returned.
|
||||
val observationsId = random63BitValue()
|
||||
val observationsQueueName = constructAddress(observationsId)
|
||||
session.createTemporaryQueue(observationsQueueName, observationsQueueName)
|
||||
msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId)
|
||||
// And make sure that we deserialise observable handles so that they're linked to the right
|
||||
// queue. Also record a bit of metadata for debugging purposes.
|
||||
return createRPCKryoForDeserialization(this@CordaRPCClientImpl, observationsQueueName, method.name, location)
|
||||
}
|
||||
|
||||
private fun createMessage(method: Method): ClientMessage {
|
||||
return session.createMessage(false).apply {
|
||||
putStringProperty(ClientRPCRequestMessage.METHOD_NAME, method.name)
|
||||
putLongProperty(ClientRPCRequestMessage.REPLY_TO, proxyId)
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
|
||||
}
|
||||
}
|
||||
|
||||
private fun checkMethodVersion(method: Method) {
|
||||
val methodVersion = method.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
|
||||
if (methodVersion > serverProtocolVersion)
|
||||
throw UnsupportedOperationException("Method ${method.name} was added in RPC protocol version $methodVersion but the server is running $serverProtocolVersion")
|
||||
}
|
||||
|
||||
private fun isCloseInvocation(method: Method) = method == closeableCloseMethod || method == autocloseableCloseMethod
|
||||
|
||||
override fun close() {
|
||||
consumer.close()
|
||||
sessionLock.withLock { session.deleteQueue(constructAddress(proxyId)) }
|
||||
}
|
||||
|
||||
override fun toString() = "Corda RPC Proxy listening on queue ${constructAddress(proxyId)}"
|
||||
}
|
||||
|
||||
/**
|
||||
* When subscribed to, starts consuming from the given queue name and demultiplexing the observables being
|
||||
* sent to it. The server queue is moved into in-memory buffers (one per attached server-side observable)
|
||||
* until drained through a subscription. When the subscriptions are all gone, the server-side queue is deleted.
|
||||
*/
|
||||
@ThreadSafe
|
||||
private inner class QueuedObservable(private val qName: String,
|
||||
private val rpcName: String,
|
||||
private val rpcLocation: Throwable) {
|
||||
private val root = PublishSubject.create<MarshalledObservation>()
|
||||
private val rootShared = root.doOnUnsubscribe { close() }.share()
|
||||
|
||||
// This could be made more efficient by using a specialised IntMap
|
||||
// When handling this map we don't synchronise on [this], otherwise there is a race condition between close() and deliver()
|
||||
private val observables = Collections.synchronizedMap(HashMap<Int, Observable<Any>>())
|
||||
|
||||
@GuardedBy("sessionLock")
|
||||
private var consumer: ClientConsumer? = null
|
||||
|
||||
private val referenceCount = AtomicInteger(0)
|
||||
|
||||
// We have to create a weak reference, otherwise we cannot be GC'd.
|
||||
init {
|
||||
val weakThis = WeakReference<QueuedObservable>(this)
|
||||
consumer = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { weakThis.get()?.deliver(it) }
|
||||
}
|
||||
|
||||
/**
|
||||
* We have to reference count subscriptions to the returned [Observable]s to prevent early GC because we are
|
||||
* weak referenced.
|
||||
*
|
||||
* Derived [Observables] (e.g. filtered etc) hold a strong reference to the original, but for example, if
|
||||
* the pattern as follows is used, the original passes out of scope and the direction of reference is from the
|
||||
* original to the [Observer]. We use the reference counting to allow for this pattern.
|
||||
*
|
||||
* val observationsSubject = PublishSubject.create<Observation>()
|
||||
* originalObservable.subscribe(observationsSubject)
|
||||
* return observationsSubject
|
||||
*/
|
||||
private fun refCountUp() {
|
||||
if (referenceCount.andIncrement == 0) {
|
||||
hardReferencesToQueuedObservables.add(this)
|
||||
}
|
||||
}
|
||||
|
||||
private fun refCountDown() {
|
||||
if (referenceCount.decrementAndGet() == 0) {
|
||||
hardReferencesToQueuedObservables.remove(this)
|
||||
}
|
||||
}
|
||||
|
||||
fun getForHandle(handle: Int): Observable<Any> {
|
||||
synchronized(observables) {
|
||||
return observables.getOrPut(handle) {
|
||||
/**
|
||||
* Note that the order of bufferUntilSubscribed() -> dematerialize() is very important here.
|
||||
*
|
||||
* In particular doing it the other way around may result in the following edge case:
|
||||
* The RPC returns two (or more) Observables. The first Observable unsubscribes *during serialisation*,
|
||||
* before the second one is hit, causing the [rootShared] to unsubscribe and consequently closing
|
||||
* the underlying artemis queue, even though the second Observable was not even registered.
|
||||
*
|
||||
* The buffer -> dematerialize order ensures that the Observable may not unsubscribe until the caller
|
||||
* subscribes, which must be after full deserialisation and registering of all top level Observables.
|
||||
*
|
||||
* In addition, when subscribe and unsubscribe is called on the [Observable] returned here, we
|
||||
* reference count a hard reference to this [QueuedObservable] to prevent premature GC.
|
||||
*/
|
||||
rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize<Any>().doOnSubscribe { refCountUp() }.doOnUnsubscribe { refCountDown() }.share()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun deliver(msg: ClientMessage) {
|
||||
sessionLock.withLock { msg.acknowledge() }
|
||||
val kryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl, qName, rpcName, rpcLocation)
|
||||
val received: MarshalledObservation = try {
|
||||
msg.deserialize(kryo)
|
||||
} finally {
|
||||
releaseRPCKryoForDeserialization(kryo)
|
||||
}
|
||||
rpcLog.debug { "<- Observable [$rpcName] <- Received $received" }
|
||||
synchronized(observables) {
|
||||
// Force creation of the buffer if it doesn't already exist.
|
||||
getForHandle(received.forHandle)
|
||||
root.onNext(received)
|
||||
}
|
||||
}
|
||||
|
||||
fun close() {
|
||||
sessionLock.withLock {
|
||||
if (consumer != null) {
|
||||
rpcLog.debug("Closing queue observable for call to $rpcName : $qName")
|
||||
consumer?.close()
|
||||
consumer = null
|
||||
session.deleteQueue(qName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED")
|
||||
fun finalize() {
|
||||
val closed = sessionLock.withLock {
|
||||
if (consumer != null) {
|
||||
consumer!!.close()
|
||||
consumer = null
|
||||
true
|
||||
} else
|
||||
false
|
||||
}
|
||||
if (closed) {
|
||||
rpcLog.warn("""A hot observable returned from an RPC ($rpcName) was never subscribed to.
|
||||
This wastes server-side resources because it was queueing observations for retrieval.
|
||||
It is being closed now, but please adjust your code to call .notUsed() on the observable
|
||||
to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning
|
||||
will appear less frequently in future versions of the platform and you can ignore it
|
||||
if you want to.
|
||||
""".trimIndent().replace('\n', ' '), rpcLocation)
|
||||
}
|
||||
}
|
||||
}
|
||||
//endregion
|
||||
}
|
||||
|
||||
private val rpcDesKryoPool = KryoPool.Builder { RPCKryo(CordaRPCClientImpl.ObservableDeserializer()) }.build()
|
||||
|
||||
fun createRPCKryoForDeserialization(rpcClient: CordaRPCClientImpl, qName: String? = null, rpcName: String? = null, rpcLocation: Throwable? = null): Kryo {
|
||||
val kryo = rpcDesKryoPool.borrow()
|
||||
kryo.context.put(RPCKryoClientKey, rpcClient)
|
||||
kryo.context.put(RPCKryoQNameKey, qName)
|
||||
kryo.context.put(RPCKryoMethodNameKey, rpcName)
|
||||
kryo.context.put(RPCKryoLocationKey, rpcLocation)
|
||||
return kryo
|
||||
}
|
||||
|
||||
fun releaseRPCKryoForDeserialization(kryo: Kryo) {
|
||||
rpcDesKryoPool.release(kryo)
|
||||
}
|
@ -0,0 +1,169 @@
|
||||
package net.corda.client.rpc.internal
|
||||
|
||||
import com.google.common.net.HostAndPort
|
||||
import net.corda.core.logElapsedTime
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.minutes
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.seconds
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
|
||||
import net.corda.nodeapi.ConnectionDirection
|
||||
import net.corda.nodeapi.RPCApi
|
||||
import net.corda.nodeapi.RPCException
|
||||
import net.corda.nodeapi.config.SSLConfiguration
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.activemq.artemis.api.core.TransportConfiguration
|
||||
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
|
||||
import java.io.Closeable
|
||||
import java.lang.reflect.Proxy
|
||||
import java.time.Duration
|
||||
|
||||
/**
|
||||
* This configuration may be used to tweak the internals of the RPC client.
|
||||
*/
|
||||
data class RPCClientConfiguration(
|
||||
/** The minimum protocol version required from the server */
|
||||
val minimumServerProtocolVersion: Int,
|
||||
/**
|
||||
* If set to true the client will track RPC call sites. If an error occurs subsequently during the RPC or in a
|
||||
* returned Observable stream the stack trace of the originating RPC will be shown as well. Note that
|
||||
* constructing call stacks is a moderately expensive operation.
|
||||
*/
|
||||
val trackRpcCallSites: Boolean,
|
||||
/**
|
||||
* The interval of unused observable reaping in milliseconds. Leaked Observables (unused ones) are
|
||||
* detected using weak references and are cleaned up in batches in this interval. If set too large it will waste
|
||||
* server side resources for this duration. If set too low it wastes client side cycles.
|
||||
*/
|
||||
val reapIntervalMs: Long,
|
||||
/** The number of threads to use for observations (for executing [Observable.onNext]) */
|
||||
val observationExecutorPoolSize: Int,
|
||||
/** The maximum number of producers to create to handle outgoing messages */
|
||||
val producerPoolBound: Int,
|
||||
/**
|
||||
* Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines
|
||||
* the limit on the number of leaked observables reaped because of garbage collection per reaping.
|
||||
* See the implementation of [com.google.common.cache.LocalCache] for details.
|
||||
*/
|
||||
val cacheConcurrencyLevel: Int,
|
||||
/** The retry interval of artemis connections in milliseconds */
|
||||
val connectionRetryInterval: Duration,
|
||||
/** The retry interval multiplier for exponential backoff */
|
||||
val connectionRetryIntervalMultiplier: Double,
|
||||
/** Maximum retry interval */
|
||||
val connectionMaxRetryInterval: Duration,
|
||||
/** Maximum file size */
|
||||
val maxFileSize: Int
|
||||
) {
|
||||
companion object {
|
||||
@JvmStatic
|
||||
val default = RPCClientConfiguration(
|
||||
minimumServerProtocolVersion = 0,
|
||||
trackRpcCallSites = false,
|
||||
reapIntervalMs = 1000,
|
||||
observationExecutorPoolSize = 4,
|
||||
producerPoolBound = 1,
|
||||
cacheConcurrencyLevel = 8,
|
||||
connectionRetryInterval = 5.seconds,
|
||||
connectionRetryIntervalMultiplier = 1.5,
|
||||
connectionMaxRetryInterval = 3.minutes,
|
||||
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
|
||||
maxFileSize = 10485760
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An RPC client that may be used to create connections to an RPC server.
|
||||
*
|
||||
* @param transport The Artemis transport to use to connect to the server.
|
||||
* @param rpcConfiguration Configuration used to tweak client behaviour.
|
||||
*/
|
||||
class RPCClient<I : RPCOps>(
|
||||
val transport: TransportConfiguration,
|
||||
val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default
|
||||
) {
|
||||
constructor(
|
||||
hostAndPort: HostAndPort,
|
||||
sslConfiguration: SSLConfiguration? = null,
|
||||
configuration: RPCClientConfiguration = RPCClientConfiguration.default
|
||||
) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration)
|
||||
|
||||
companion object {
|
||||
private val log = loggerFor<RPCClient<*>>()
|
||||
}
|
||||
|
||||
/**
|
||||
* Holds a proxy object implementing [I] that forwards requests to the RPC server.
|
||||
*
|
||||
* [Closeable.close] may be used to shut down the connection and release associated resources.
|
||||
*/
|
||||
interface RPCConnection<out I : RPCOps> : Closeable {
|
||||
val proxy: I
|
||||
/** The RPC protocol version reported by the server */
|
||||
val serverProtocolVersion: Int
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an [RPCConnection] containing a proxy that lets you invoke RPCs on the server. Calls on it block, and if
|
||||
* the server throws an exception then it will be rethrown on the client. Proxies are thread safe and may be used to
|
||||
* invoke multiple RPCs in parallel.
|
||||
*
|
||||
* RPC sends and receives are logged on the net.corda.rpc logger.
|
||||
*
|
||||
* The [RPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the object
|
||||
* graph returned then the server-side observable is transparently forwarded to the client side here.
|
||||
* *You are expected to use it*. The server will begin buffering messages immediately that it will expect you to
|
||||
* drain by subscribing to the returned observer. You can opt-out of this by simply calling the
|
||||
* [net.corda.client.rpc.notUsed] method on it. You don't have to explicitly close the observable if you actually
|
||||
* subscribe to it: it will close itself and free up the server-side resources either when the client or JVM itself
|
||||
* is shutdown, or when there are no more subscribers to it. Once all the subscribers to a returned observable are
|
||||
* unsubscribed or the observable completes successfully or with an error, the observable is closed and you can't
|
||||
* then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
|
||||
*
|
||||
* @param rpcOpsClass The [Class] of the RPC interface.
|
||||
* @param username The username to authenticate with.
|
||||
* @param password The password to authenticate with.
|
||||
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
|
||||
*/
|
||||
fun start(
|
||||
rpcOpsClass: Class<I>,
|
||||
username: String,
|
||||
password: String
|
||||
): RPCConnection<I> {
|
||||
return log.logElapsedTime("Startup") {
|
||||
val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}")
|
||||
|
||||
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply {
|
||||
retryInterval = rpcConfiguration.connectionRetryInterval.toMillis()
|
||||
retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier
|
||||
maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis()
|
||||
minLargeMessageSize = rpcConfiguration.maxFileSize
|
||||
}
|
||||
|
||||
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass)
|
||||
proxyHandler.start()
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
|
||||
|
||||
val serverProtocolVersion = ops.protocolVersion
|
||||
if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
|
||||
throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
|
||||
" than the server's supported protocol version ($serverProtocolVersion)")
|
||||
}
|
||||
proxyHandler.setServerProtocolVersion(serverProtocolVersion)
|
||||
|
||||
log.debug("RPC connected, returning proxy")
|
||||
object : RPCConnection<I> {
|
||||
override val proxy = ops
|
||||
override val serverProtocolVersion = serverProtocolVersion
|
||||
override fun close() {
|
||||
proxyHandler.close()
|
||||
serverLocator.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,423 @@
|
||||
package net.corda.client.rpc.internal
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import com.esotericsoftware.kryo.pool.KryoPool
|
||||
import com.google.common.cache.Cache
|
||||
import com.google.common.cache.CacheBuilder
|
||||
import com.google.common.cache.RemovalCause
|
||||
import com.google.common.cache.RemovalListener
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import com.google.common.util.concurrent.ThreadFactoryBuilder
|
||||
import net.corda.core.ThreadBox
|
||||
import net.corda.core.getOrThrow
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.KryoPoolWithContext
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.nodeapi.*
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
|
||||
import org.apache.activemq.artemis.api.core.client.ClientMessage
|
||||
import org.apache.activemq.artemis.api.core.client.ServerLocator
|
||||
import rx.Notification
|
||||
import rx.Observable
|
||||
import rx.subjects.UnicastSubject
|
||||
import sun.reflect.CallerSensitive
|
||||
import java.lang.reflect.InvocationHandler
|
||||
import java.lang.reflect.Method
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.ScheduledFuture
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import kotlin.collections.ArrayList
|
||||
import kotlin.reflect.jvm.javaMethod
|
||||
|
||||
/**
|
||||
* This class provides a proxy implementation of an RPC interface for RPC clients. It translates API calls to lower-level
|
||||
* RPC protocol messages. For this protocol see [RPCApi].
|
||||
*
|
||||
* When a method is called on the interface the arguments are serialised and the request is forwarded to the server. The
|
||||
* server then executes the code that implements the RPC and sends a reply.
|
||||
*
|
||||
* An RPC reply may contain [Observable]s, which are serialised simply as unique IDs. On the client side we create a
|
||||
* [UnicastSubject] for each such ID. Subsequently the server may send observations attached to this ID, which are
|
||||
* forwarded to the [UnicastSubject]. Note that the observations themselves may contain further [Observable]s, which are
|
||||
* handled in the same way.
|
||||
*
|
||||
* To do the above we take advantage of Kryo's datastructure traversal. When the client is deserialising a message from
|
||||
* the server that may contain Observables it is supplied with an [ObservableContext] that exposes the map used to demux
|
||||
* the observations. When an [Observable] is encountered during traversal a new [UnicastSubject] is added to the map and
|
||||
* we carry on. Each observation later contains the corresponding Observable ID, and we just forward that to the
|
||||
* associated [UnicastSubject].
|
||||
*
|
||||
* The client may signal that it no longer consumes a particular [Observable]. This may be done explicitly by
|
||||
* unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually
|
||||
* automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s.
|
||||
* The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor].
|
||||
*/
|
||||
class RPCClientProxyHandler(
|
||||
private val rpcConfiguration: RPCClientConfiguration,
|
||||
private val rpcUsername: String,
|
||||
private val rpcPassword: String,
|
||||
private val serverLocator: ServerLocator,
|
||||
private val clientAddress: SimpleString,
|
||||
private val rpcOpsClass: Class<out RPCOps>
|
||||
) : InvocationHandler {
|
||||
|
||||
private enum class State {
|
||||
UNSTARTED,
|
||||
SERVER_VERSION_NOT_SET,
|
||||
STARTED,
|
||||
FINISHED
|
||||
}
|
||||
private val lifeCycle = LifeCycle(State.UNSTARTED)
|
||||
|
||||
private companion object {
|
||||
val log = loggerFor<RPCClientProxyHandler>()
|
||||
// Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context
|
||||
// to do that. However it may still be used for serialisation of RPC requests and related messages.
|
||||
val kryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build()
|
||||
// To check whether toString() is being invoked
|
||||
val toStringMethod: Method = Object::toString.javaMethod!!
|
||||
}
|
||||
|
||||
// Used for reaping
|
||||
private val reaperExecutor = Executors.newScheduledThreadPool(
|
||||
1,
|
||||
ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
|
||||
)
|
||||
|
||||
// A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
|
||||
private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build()
|
||||
private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) {
|
||||
Executors.newFixedThreadPool(1, observationExecutorThreadFactory)
|
||||
}
|
||||
|
||||
// Holds the RPC reply futures.
|
||||
private val rpcReplyMap = RpcReplyMap()
|
||||
// Optionally holds RPC call site stack traces to be shown on errors/warnings.
|
||||
private val callSiteMap = if (rpcConfiguration.trackRpcCallSites) CallSiteMap() else null
|
||||
// Holds the Observables and a reference store to keep Observables alive when subscribed to.
|
||||
private val observableContext = ObservableContext(
|
||||
callSiteMap = callSiteMap,
|
||||
observableMap = createRpcObservableMap(),
|
||||
hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
|
||||
)
|
||||
// Holds a reference to the scheduled reaper.
|
||||
private lateinit var reaperScheduledFuture: ScheduledFuture<*>
|
||||
// The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
|
||||
private var serverProtocolVersion: Int? = null
|
||||
|
||||
// Stores the Observable IDs that are already removed from the map but are not yet sent to the server.
|
||||
private val observablesToReap = ThreadBox(object {
|
||||
var observables = ArrayList<RPCApi.ObservableId>()
|
||||
})
|
||||
// A Kryo pool that automatically adds the observable context when an instance is requested.
|
||||
private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext)
|
||||
|
||||
private fun createRpcObservableMap(): RpcObservableMap {
|
||||
val onObservableRemove = RemovalListener<RPCApi.ObservableId, UnicastSubject<Notification<Any>>> {
|
||||
val rpcCallSite = callSiteMap?.remove(it.key.toLong)
|
||||
if (it.cause == RemovalCause.COLLECTED) {
|
||||
log.warn(listOf(
|
||||
"A hot observable returned from an RPC was never subscribed to.",
|
||||
"This wastes server-side resources because it was queueing observations for retrieval.",
|
||||
"It is being closed now, but please adjust your code to call .notUsed() on the observable",
|
||||
"to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning",
|
||||
"will appear less frequently in future versions of the platform and you can ignore it",
|
||||
"if you want to.").joinToString(" "), rpcCallSite)
|
||||
}
|
||||
observablesToReap.locked { observables.add(it.key) }
|
||||
}
|
||||
return CacheBuilder.newBuilder().
|
||||
weakValues().
|
||||
removalListener(onObservableRemove).
|
||||
concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel).
|
||||
build()
|
||||
}
|
||||
|
||||
// We cannot pool consumers as we need to preserve the original muxed message order.
|
||||
// TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a
|
||||
// single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is
|
||||
// integrated properly.
|
||||
private lateinit var sessionAndConsumer: ArtemisConsumer
|
||||
// Pool producers to reduce contention on the client side.
|
||||
private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) {
|
||||
// Note how we create new sessions *and* session factories per producer.
|
||||
// We cannot simply pool producers on one session because sessions are single threaded.
|
||||
// We cannot simply pool sessions on one session factory because flow control credits are tied to factories, so
|
||||
// sessions tend to starve each other when used concurrently.
|
||||
val sessionFactory = serverLocator.createSessionFactory()
|
||||
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
|
||||
session.start()
|
||||
ArtemisProducer(sessionFactory, session, session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME))
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the client. This creates the per-client queue, starts the consumer session and the reaper.
|
||||
*/
|
||||
fun start() {
|
||||
lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET)
|
||||
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
|
||||
this::reapObservables,
|
||||
rpcConfiguration.reapIntervalMs,
|
||||
rpcConfiguration.reapIntervalMs,
|
||||
TimeUnit.MILLISECONDS
|
||||
)
|
||||
sessionAndProducerPool.run {
|
||||
it.session.createTemporaryQueue(clientAddress, clientAddress)
|
||||
}
|
||||
val sessionFactory = serverLocator.createSessionFactory()
|
||||
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
|
||||
val consumer = session.createConsumer(clientAddress)
|
||||
consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler)
|
||||
session.start()
|
||||
sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer)
|
||||
}
|
||||
|
||||
// This is the general function that transforms a client side RPC to internal Artemis messages.
|
||||
@CallerSensitive
|
||||
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
|
||||
lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET }
|
||||
checkProtocolVersion(method)
|
||||
if (method == toStringMethod) {
|
||||
return "Client RPC proxy for $rpcOpsClass"
|
||||
}
|
||||
if (sessionAndConsumer.session.isClosed) {
|
||||
throw RPCException("RPC Proxy is closed")
|
||||
}
|
||||
val rpcId = RPCApi.RpcRequestId(random63BitValue())
|
||||
callSiteMap?.set(rpcId.toLong, Throwable("<Call site of root RPC '${method.name}'>"))
|
||||
try {
|
||||
val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, arguments?.toList() ?: emptyList())
|
||||
val replyFuture = SettableFuture.create<Any>()
|
||||
sessionAndProducerPool.run {
|
||||
val message = it.session.createMessage(false)
|
||||
request.writeToClientMessage(kryoPool, message)
|
||||
|
||||
log.debug {
|
||||
val argumentsString = arguments?.joinToString() ?: ""
|
||||
"-> RPC($rpcId) -> ${method.name}($argumentsString): ${method.returnType}"
|
||||
}
|
||||
|
||||
require(rpcReplyMap.put(rpcId, replyFuture) == null) {
|
||||
"Generated several RPC requests with same ID $rpcId"
|
||||
}
|
||||
it.producer.send(message)
|
||||
it.session.commit()
|
||||
}
|
||||
return replyFuture.getOrThrow()
|
||||
} finally {
|
||||
callSiteMap?.remove(rpcId.toLong)
|
||||
}
|
||||
}
|
||||
|
||||
// The handler for Artemis messages.
|
||||
private fun artemisMessageHandler(message: ClientMessage) {
|
||||
val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message)
|
||||
log.debug { "Got message from RPC server $serverToClient" }
|
||||
when (serverToClient) {
|
||||
is RPCApi.ServerToClient.RpcReply -> {
|
||||
val replyFuture = rpcReplyMap.remove(serverToClient.id)
|
||||
if (replyFuture == null) {
|
||||
log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
|
||||
} else {
|
||||
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
|
||||
serverToClient.result.match(
|
||||
onError = {
|
||||
if (rpcCallSite != null) addRpcCallSiteToThrowable(it, rpcCallSite)
|
||||
replyFuture.setException(it)
|
||||
},
|
||||
onValue = { replyFuture.set(it) }
|
||||
)
|
||||
}
|
||||
}
|
||||
is RPCApi.ServerToClient.Observation -> {
|
||||
val observable = observableContext.observableMap.getIfPresent(serverToClient.id)
|
||||
if (observable == null) {
|
||||
log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
|
||||
"This may be due to an observation arriving before the server was " +
|
||||
"notified of observable shutdown")
|
||||
} else {
|
||||
// We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
|
||||
observationExecutorPool.run(serverToClient.id) { executor ->
|
||||
executor.submit {
|
||||
val content = serverToClient.content
|
||||
if (content.isOnCompleted || content.isOnError) {
|
||||
observableContext.observableMap.invalidate(serverToClient.id)
|
||||
}
|
||||
// Add call site information on error
|
||||
if (content.isOnError) {
|
||||
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
|
||||
if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
|
||||
}
|
||||
observable.onNext(content)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
message.acknowledge()
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors.
|
||||
*/
|
||||
fun close() {
|
||||
lifeCycle.transition(State.STARTED, State.FINISHED)
|
||||
sessionAndConsumer.consumer.close()
|
||||
sessionAndConsumer.session.close()
|
||||
sessionAndConsumer.sessionFactory.close()
|
||||
reaperScheduledFuture.cancel(false)
|
||||
observableContext.observableMap.invalidateAll()
|
||||
reapObservables()
|
||||
reaperExecutor.shutdownNow()
|
||||
sessionAndProducerPool.close().forEach {
|
||||
it.producer.close()
|
||||
it.session.close()
|
||||
it.sessionFactory.close()
|
||||
}
|
||||
// Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may
|
||||
// leak borrowed executors.
|
||||
val observationExecutors = observationExecutorPool.close()
|
||||
observationExecutors.forEach { it.shutdownNow() }
|
||||
observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Check the [RPCSinceVersion] of the passed in [calledMethod] against the server's protocol version.
|
||||
*/
|
||||
private fun checkProtocolVersion(calledMethod: Method) {
|
||||
val serverProtocolVersion = serverProtocolVersion
|
||||
if (serverProtocolVersion == null) {
|
||||
lifeCycle.requireState(State.SERVER_VERSION_NOT_SET)
|
||||
} else {
|
||||
lifeCycle.requireState(State.STARTED)
|
||||
val sinceVersion = calledMethod.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
|
||||
if (sinceVersion > serverProtocolVersion) {
|
||||
throw UnsupportedOperationException("Method $calledMethod was added in RPC protocol version $sinceVersion but the server is running $serverProtocolVersion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the server's protocol version. Note that before doing so the client is not considered fully started, although
|
||||
* RPCs already may be called with it.
|
||||
*/
|
||||
internal fun setServerProtocolVersion(version: Int) {
|
||||
lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED)
|
||||
if (serverProtocolVersion == null) {
|
||||
serverProtocolVersion = version
|
||||
} else {
|
||||
throw IllegalStateException("setServerProtocolVersion called, but the protocol version was already set!")
|
||||
}
|
||||
}
|
||||
|
||||
private fun reapObservables() {
|
||||
observableContext.observableMap.cleanUp()
|
||||
val observableIds = observablesToReap.locked {
|
||||
if (observables.isNotEmpty()) {
|
||||
val temporary = observables
|
||||
observables = ArrayList()
|
||||
temporary
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}
|
||||
if (observableIds != null) {
|
||||
log.debug { "Reaping ${observableIds.size} observables" }
|
||||
sessionAndProducerPool.run {
|
||||
val message = it.session.createMessage(false)
|
||||
RPCApi.ClientToServer.ObservablesClosed(observableIds).writeToClientMessage(message)
|
||||
it.producer.send(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private typealias RpcObservableMap = Cache<RPCApi.ObservableId, UnicastSubject<Notification<Any>>>
|
||||
private typealias RpcReplyMap = ConcurrentHashMap<RPCApi.RpcRequestId, SettableFuture<Any?>>
|
||||
private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?>
|
||||
|
||||
/**
|
||||
* Holds a context available during Kryo deserialisation of messages that are expected to contain Observables.
|
||||
*
|
||||
* @param observableMap holds the Observables that are ultimately exposed to the user.
|
||||
* @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to.
|
||||
*/
|
||||
private data class ObservableContext(
|
||||
val callSiteMap: CallSiteMap?,
|
||||
val observableMap: RpcObservableMap,
|
||||
val hardReferenceStore: MutableSet<Observable<*>>
|
||||
)
|
||||
|
||||
/**
|
||||
* A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext].
|
||||
*/
|
||||
private object RpcClientObservableSerializer : Serializer<Observable<Any>>() {
|
||||
private object RpcObservableContextKey
|
||||
fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool {
|
||||
return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext)
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
|
||||
val observableId = RPCApi.ObservableId(input.readLong(true))
|
||||
val observable = UnicastSubject.create<Notification<Any>>()
|
||||
require(observableContext.observableMap.getIfPresent(observableId) == null) {
|
||||
"Multiple Observables arrived with the same ID $observableId"
|
||||
}
|
||||
val rpcCallSite = getRpcCallSite(kryo, observableContext)
|
||||
observableContext.observableMap.put(observableId, observable)
|
||||
observableContext.callSiteMap?.put(observableId.toLong, rpcCallSite)
|
||||
// We pin all Observables into a hard reference store (rooted in the RPC proxy) on subscription so that users
|
||||
// don't need to store a reference to the Observables themselves.
|
||||
return observable.pinInSubscriptions(observableContext.hardReferenceStore).doOnUnsubscribe {
|
||||
// This causes Future completions to give warnings because the corresponding OnComplete sent from the server
|
||||
// will arrive after the client unsubscribes from the observable and consequently invalidates the mapping.
|
||||
// The unsubscribe is due to [ObservableToFuture]'s use of first().
|
||||
observableContext.observableMap.invalidate(observableId)
|
||||
}.dematerialize()
|
||||
}
|
||||
|
||||
override fun write(kryo: Kryo, output: Output, observable: Observable<Any>) {
|
||||
throw UnsupportedOperationException("Cannot serialise Observables on the client side")
|
||||
}
|
||||
|
||||
private fun getRpcCallSite(kryo: Kryo, observableContext: ObservableContext): Throwable? {
|
||||
val rpcRequestOrObservableId = kryo.context[RPCApi.RpcRequestOrObservableIdKey] as Long
|
||||
return observableContext.callSiteMap?.get(rpcRequestOrObservableId)
|
||||
}
|
||||
}
|
||||
|
||||
private fun addRpcCallSiteToThrowable(throwable: Throwable, callSite: Throwable) {
|
||||
var currentThrowable = throwable
|
||||
while (true) {
|
||||
val cause = currentThrowable.cause
|
||||
if (cause == null) {
|
||||
currentThrowable.initCause(callSite)
|
||||
break
|
||||
} else {
|
||||
currentThrowable = cause
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T> Observable<T>.pinInSubscriptions(hardReferenceStore: MutableSet<Observable<*>>): Observable<T> {
|
||||
val refCount = AtomicInteger(0)
|
||||
return this.doOnSubscribe {
|
||||
if (refCount.getAndIncrement() == 0) {
|
||||
require(hardReferenceStore.add(this)) { "Reference store already contained reference $this on add" }
|
||||
}
|
||||
}.doOnUnsubscribe {
|
||||
if (refCount.decrementAndGet() == 0) {
|
||||
require(hardReferenceStore.remove(this)) { "Reference store did not contain reference $this on remove" }
|
||||
}
|
||||
}
|
||||
}
|
@ -1,102 +0,0 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.ALICE
|
||||
import net.corda.core.utilities.LogHelper
|
||||
import net.corda.node.services.RPCUserService
|
||||
import net.corda.node.services.messaging.RPCDispatcher
|
||||
import net.corda.node.utilities.AffinityExecutor
|
||||
import net.corda.nodeapi.ArtemisMessagingComponent
|
||||
import net.corda.nodeapi.User
|
||||
import org.apache.activemq.artemis.api.core.Message
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.activemq.artemis.api.core.TransportConfiguration
|
||||
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
|
||||
import org.apache.activemq.artemis.api.core.client.ClientMessage
|
||||
import org.apache.activemq.artemis.api.core.client.ClientProducer
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||
import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl
|
||||
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory
|
||||
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory
|
||||
import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import java.util.*
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
|
||||
abstract class AbstractClientRPCTest {
|
||||
lateinit var artemis: EmbeddedActiveMQ
|
||||
lateinit var serverSession: ClientSession
|
||||
lateinit var clientSession: ClientSession
|
||||
lateinit var producer: ClientProducer
|
||||
lateinit var serverThread: AffinityExecutor.ServiceAffinityExecutor
|
||||
|
||||
@Before
|
||||
fun rpcSetup() {
|
||||
// Set up an in-memory Artemis with an RPC requests queue.
|
||||
artemis = EmbeddedActiveMQ()
|
||||
artemis.setConfiguration(ConfigurationImpl().apply {
|
||||
acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name))
|
||||
isSecurityEnabled = false
|
||||
isPersistenceEnabled = false
|
||||
})
|
||||
artemis.start()
|
||||
|
||||
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(TransportConfiguration(InVMConnectorFactory::class.java.name))
|
||||
val sessionFactory = serverLocator.createSessionFactory()
|
||||
serverSession = sessionFactory.createSession()
|
||||
serverSession.start()
|
||||
|
||||
serverSession.createTemporaryQueue(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
|
||||
producer = serverSession.createProducer()
|
||||
serverThread = AffinityExecutor.ServiceAffinityExecutor("unit-tests-rpc-dispatch-thread", 1)
|
||||
serverSession.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 'BINDING_REMOVED'")
|
||||
|
||||
clientSession = sessionFactory.createSession()
|
||||
clientSession.start()
|
||||
|
||||
LogHelper.setLevel("+net.corda.rpc")
|
||||
}
|
||||
|
||||
@After
|
||||
fun rpcShutdown() {
|
||||
safeClose(producer)
|
||||
clientSession.stop()
|
||||
serverSession.stop()
|
||||
artemis.stop()
|
||||
serverThread.shutdownNow()
|
||||
}
|
||||
|
||||
fun <T : RPCOps> rpcProxyFor(rpcUser: User, rpcImpl: T, type: Class<T>): T {
|
||||
val userService = object : RPCUserService {
|
||||
override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null
|
||||
override val users: List<User> get() = listOf(rpcUser)
|
||||
}
|
||||
|
||||
val dispatcher = object : RPCDispatcher(rpcImpl, userService, ALICE.name) {
|
||||
override fun send(data: SerializedBytes<*>, toAddress: String) {
|
||||
val msg = serverSession.createMessage(false).apply {
|
||||
writeBodyBufferBytes(data.bytes)
|
||||
// Use the magic deduplication property built into Artemis as our message identity too
|
||||
putStringProperty(Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
|
||||
}
|
||||
producer.send(toAddress, msg)
|
||||
}
|
||||
|
||||
override fun getUser(message: ClientMessage): User = rpcUser
|
||||
}
|
||||
|
||||
val serverNotifConsumer = serverSession.createConsumer("rpc.qremovals")
|
||||
val serverConsumer = serverSession.createConsumer(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
|
||||
dispatcher.start(serverConsumer, serverNotifConsumer, serverThread)
|
||||
return CordaRPCClientImpl(clientSession, ReentrantLock(), rpcUser.username).proxyFor(type)
|
||||
}
|
||||
|
||||
fun safeClose(obj: Any) {
|
||||
try {
|
||||
(obj as AutoCloseable).close()
|
||||
} catch (e: Exception) {
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import net.corda.client.rpc.internal.RPCClientConfiguration
|
||||
import net.corda.core.flatMap
|
||||
import net.corda.core.map
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.User
|
||||
import net.corda.testing.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.rpcTestUser
|
||||
import net.corda.testing.startInVmRpcClient
|
||||
import net.corda.testing.startRpcClient
|
||||
import org.apache.activemq.artemis.api.core.client.ClientSession
|
||||
import org.junit.runners.Parameterized
|
||||
|
||||
open class AbstractRPCTest {
|
||||
enum class RPCTestMode {
|
||||
InVm,
|
||||
Netty
|
||||
}
|
||||
|
||||
companion object {
|
||||
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
|
||||
fun defaultModes() = modes(RPCTestMode.InVm, RPCTestMode.Netty)
|
||||
fun modes(vararg modes: RPCTestMode) = listOf(*modes).map { arrayOf(it) }
|
||||
}
|
||||
@Parameterized.Parameter
|
||||
lateinit var mode: RPCTestMode
|
||||
|
||||
data class TestProxy<out I : RPCOps>(
|
||||
val ops: I,
|
||||
val createSession: () -> ClientSession
|
||||
)
|
||||
|
||||
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.testProxy(
|
||||
ops: I,
|
||||
rpcUser: User = rpcTestUser,
|
||||
clientConfiguration: RPCClientConfiguration = RPCClientConfiguration.default,
|
||||
serverConfiguration: RPCServerConfiguration = RPCServerConfiguration.default
|
||||
): TestProxy<I> {
|
||||
return when (mode) {
|
||||
RPCTestMode.InVm ->
|
||||
startInVmRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap {
|
||||
startInVmRpcClient<I>(rpcUser.username, rpcUser.password, clientConfiguration).map {
|
||||
TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) })
|
||||
}
|
||||
}.get()
|
||||
RPCTestMode.Netty ->
|
||||
startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server ->
|
||||
startRpcClient<I>(server.hostAndPort, rpcUser.username, rpcUser.password, clientConfiguration).map {
|
||||
TestProxy(it, { startArtemisSession(server.hostAndPort, rpcUser.username, rpcUser.password) })
|
||||
}
|
||||
}.get()
|
||||
}
|
||||
}
|
||||
}
|
@ -5,16 +5,16 @@ import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import net.corda.core.getOrThrow
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.messaging.RPCReturnsObservables
|
||||
import net.corda.core.success
|
||||
import net.corda.nodeapi.CURRENT_RPC_USER
|
||||
import net.corda.node.services.messaging.getRpcContext
|
||||
import net.corda.nodeapi.RPCSinceVersion
|
||||
import net.corda.nodeapi.User
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import net.corda.testing.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.rpcDriver
|
||||
import net.corda.testing.rpcTestUser
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import java.util.concurrent.CountDownLatch
|
||||
@ -23,22 +23,11 @@ import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
|
||||
@RunWith(Parameterized::class)
|
||||
class ClientRPCInfrastructureTests : AbstractRPCTest() {
|
||||
// TODO: Test that timeouts work
|
||||
|
||||
lateinit var proxy: TestOps
|
||||
|
||||
private val authenticatedUser = User("test", "password", permissions = setOf())
|
||||
|
||||
@Before
|
||||
fun setup() {
|
||||
proxy = rpcProxyFor(authenticatedUser, TestOpsImpl(), TestOps::class.java)
|
||||
}
|
||||
|
||||
@After
|
||||
fun shutdown() {
|
||||
safeClose(proxy)
|
||||
}
|
||||
private fun RPCDriverExposedDSLInterface.testProxy() = testProxy<TestOps>(TestOpsImpl()).ops
|
||||
|
||||
interface TestOps : RPCOps {
|
||||
@Throws(IllegalArgumentException::class)
|
||||
@ -48,16 +37,12 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
|
||||
|
||||
fun someCalculation(str: String, num: Int): String
|
||||
|
||||
@RPCReturnsObservables
|
||||
fun makeObservable(): Observable<Int>
|
||||
|
||||
@RPCReturnsObservables
|
||||
fun makeComplicatedObservable(): Observable<Pair<String, Observable<String>>>
|
||||
|
||||
@RPCReturnsObservables
|
||||
fun makeListenableFuture(): ListenableFuture<Int>
|
||||
|
||||
@RPCReturnsObservables
|
||||
fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>>
|
||||
|
||||
@RPCSinceVersion(2)
|
||||
@ -78,117 +63,130 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
|
||||
override fun makeListenableFuture(): ListenableFuture<Int> = Futures.immediateFuture(1)
|
||||
override fun makeComplicatedObservable() = complicatedObservable
|
||||
override fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>> = complicatedListenableFuturee
|
||||
override fun addedLater(): Unit = throw UnsupportedOperationException("not implemented")
|
||||
override fun captureUser(): String = CURRENT_RPC_USER.get().username
|
||||
override fun addedLater(): Unit = throw IllegalStateException()
|
||||
override fun captureUser(): String = getRpcContext().currentUser.username
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `simple RPCs`() {
|
||||
// Does nothing, doesn't throw.
|
||||
proxy.void()
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
// Does nothing, doesn't throw.
|
||||
proxy.void()
|
||||
|
||||
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
|
||||
proxy.barf()
|
||||
}.message)
|
||||
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
|
||||
proxy.barf()
|
||||
}.message)
|
||||
|
||||
assertEquals("hi 5", proxy.someCalculation("hi", 5))
|
||||
assertEquals("hi 5", proxy.someCalculation("hi", 5))
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `simple observable`() {
|
||||
// This tests that the observations are transmitted correctly, also completion is transmitted.
|
||||
val observations = proxy.makeObservable().toBlocking().toIterable().toList()
|
||||
assertEquals(listOf(1, 2, 3, 4), observations)
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
// This tests that the observations are transmitted correctly, also completion is transmitted.
|
||||
val observations = proxy.makeObservable().toBlocking().toIterable().toList()
|
||||
assertEquals(listOf(1, 2, 3, 4), observations)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `complex observables`() {
|
||||
// This checks that we can return an object graph with complex usage of observables, like an observable
|
||||
// that emits objects that contain more observables.
|
||||
val serverQuotes = PublishSubject.create<Pair<String, Observable<String>>>()
|
||||
val unsubscribeLatch = CountDownLatch(1)
|
||||
complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() }
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
// This checks that we can return an object graph with complex usage of observables, like an observable
|
||||
// that emits objects that contain more observables.
|
||||
val serverQuotes = PublishSubject.create<Pair<String, Observable<String>>>()
|
||||
val unsubscribeLatch = CountDownLatch(1)
|
||||
complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() }
|
||||
|
||||
val twainQuotes = "Mark Twain" to Observable.just(
|
||||
"I have never let my schooling interfere with my education.",
|
||||
"Clothes make the man. Naked people have little or no influence on society."
|
||||
)
|
||||
val wildeQuotes = "Oscar Wilde" to Observable.just(
|
||||
"I can resist everything except temptation.",
|
||||
"Always forgive your enemies - nothing annoys them so much."
|
||||
)
|
||||
val twainQuotes = "Mark Twain" to Observable.just(
|
||||
"I have never let my schooling interfere with my education.",
|
||||
"Clothes make the man. Naked people have little or no influence on society."
|
||||
)
|
||||
val wildeQuotes = "Oscar Wilde" to Observable.just(
|
||||
"I can resist everything except temptation.",
|
||||
"Always forgive your enemies - nothing annoys them so much."
|
||||
)
|
||||
|
||||
val clientQuotes = LinkedBlockingQueue<String>()
|
||||
val clientObs = proxy.makeComplicatedObservable()
|
||||
val clientQuotes = LinkedBlockingQueue<String>()
|
||||
val clientObs = proxy.makeComplicatedObservable()
|
||||
|
||||
val subscription = clientObs.subscribe {
|
||||
val name = it.first
|
||||
it.second.subscribe {
|
||||
clientQuotes += "Quote by $name: $it"
|
||||
val subscription = clientObs.subscribe {
|
||||
val name = it.first
|
||||
it.second.subscribe {
|
||||
clientQuotes += "Quote by $name: $it"
|
||||
}
|
||||
}
|
||||
|
||||
assertThat(clientQuotes).isEmpty()
|
||||
|
||||
serverQuotes.onNext(twainQuotes)
|
||||
assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take())
|
||||
assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take())
|
||||
|
||||
serverQuotes.onNext(wildeQuotes)
|
||||
assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take())
|
||||
assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take())
|
||||
|
||||
assertTrue(serverQuotes.hasObservers())
|
||||
subscription.unsubscribe()
|
||||
unsubscribeLatch.await()
|
||||
}
|
||||
|
||||
val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*")
|
||||
assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
|
||||
|
||||
assertThat(clientQuotes).isEmpty()
|
||||
|
||||
serverQuotes.onNext(twainQuotes)
|
||||
assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take())
|
||||
assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take())
|
||||
|
||||
serverQuotes.onNext(wildeQuotes)
|
||||
assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take())
|
||||
assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take())
|
||||
|
||||
assertTrue(serverQuotes.hasObservers())
|
||||
subscription.unsubscribe()
|
||||
unsubscribeLatch.await()
|
||||
assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `simple ListenableFuture`() {
|
||||
val value = proxy.makeListenableFuture().getOrThrow()
|
||||
assertThat(value).isEqualTo(1)
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
val value = proxy.makeListenableFuture().getOrThrow()
|
||||
assertThat(value).isEqualTo(1)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `complex ListenableFuture`() {
|
||||
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>()
|
||||
complicatedListenableFuturee = serverQuote
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>()
|
||||
complicatedListenableFuturee = serverQuote
|
||||
|
||||
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.")
|
||||
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.")
|
||||
|
||||
val clientQuotes = LinkedBlockingQueue<String>()
|
||||
val clientFuture = proxy.makeComplicatedListenableFuture()
|
||||
val clientQuotes = LinkedBlockingQueue<String>()
|
||||
val clientFuture = proxy.makeComplicatedListenableFuture()
|
||||
|
||||
clientFuture.success {
|
||||
val name = it.first
|
||||
it.second.success {
|
||||
clientQuotes += "Quote by $name: $it"
|
||||
clientFuture.success {
|
||||
val name = it.first
|
||||
it.second.success {
|
||||
clientQuotes += "Quote by $name: $it"
|
||||
}
|
||||
}
|
||||
|
||||
assertThat(clientQuotes).isEmpty()
|
||||
|
||||
serverQuote.set(twainQuote)
|
||||
assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.")
|
||||
|
||||
// TODO This final assert sometimes fails because the relevant queue hasn't been removed yet
|
||||
}
|
||||
|
||||
val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*")
|
||||
assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
|
||||
|
||||
assertThat(clientQuotes).isEmpty()
|
||||
|
||||
serverQuote.set(twainQuote)
|
||||
assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.")
|
||||
|
||||
// TODO This final assert sometimes fails because the relevant queue hasn't been removed yet
|
||||
// assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun versioning() {
|
||||
assertFailsWith<UnsupportedOperationException> { proxy.addedLater() }
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
assertFailsWith<UnsupportedOperationException> { proxy.addedLater() }
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `authenticated user is available to RPC`() {
|
||||
assertThat(proxy.captureUser()).isEqualTo(authenticatedUser.username)
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
assertThat(proxy.captureUser()).isEqualTo(rpcTestUser.username)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,194 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import com.google.common.util.concurrent.Futures
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import net.corda.client.rpc.internal.RPCClientConfiguration
|
||||
import net.corda.core.future
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.utilities.loggerFor
|
||||
import net.corda.node.driver.poll
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.RPCApi
|
||||
import net.corda.testing.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.rpcDriver
|
||||
import net.corda.testing.startRandomRpcClient
|
||||
import net.corda.testing.startRpcClient
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import rx.subjects.UnicastSubject
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
|
||||
@RunWith(Parameterized::class)
|
||||
class RPCConcurrencyTests : AbstractRPCTest() {
|
||||
|
||||
/**
|
||||
* Holds a "rose"-tree of [Observable]s which allows us to test arbitrary [Observable] nesting in RPC replies.
|
||||
*/
|
||||
@CordaSerializable
|
||||
data class ObservableRose<out A>(val value: A, val branches: Observable<out ObservableRose<A>>)
|
||||
|
||||
private interface TestOps : RPCOps {
|
||||
fun newLatch(numberOfDowns: Int): Long
|
||||
fun waitLatch(id: Long)
|
||||
fun downLatch(id: Long)
|
||||
fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
|
||||
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
|
||||
}
|
||||
|
||||
class TestOpsImpl : TestOps {
|
||||
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
|
||||
override val protocolVersion = 0
|
||||
|
||||
override fun newLatch(numberOfDowns: Int): Long {
|
||||
val id = random63BitValue()
|
||||
val latch = CountDownLatch(numberOfDowns)
|
||||
latches.put(id, latch)
|
||||
return id
|
||||
}
|
||||
|
||||
override fun waitLatch(id: Long) {
|
||||
latches[id]!!.await()
|
||||
}
|
||||
|
||||
override fun downLatch(id: Long) {
|
||||
latches[id]!!.countDown()
|
||||
}
|
||||
|
||||
override fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
|
||||
val branches = if (depth == 0) {
|
||||
Observable.empty<ObservableRose<Int>>()
|
||||
} else {
|
||||
Observable.just(getImmediateObservableTree(depth - 1, branchingFactor)).repeat(branchingFactor.toLong())
|
||||
}
|
||||
return ObservableRose(depth, branches)
|
||||
}
|
||||
|
||||
override fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
|
||||
val branches = if (depth == 0) {
|
||||
Observable.empty<ObservableRose<Int>>()
|
||||
} else {
|
||||
val publish = UnicastSubject.create<ObservableRose<Int>>()
|
||||
future {
|
||||
(1..branchingFactor).toList().parallelStream().forEach {
|
||||
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
|
||||
}
|
||||
publish.onCompleted()
|
||||
}
|
||||
publish
|
||||
}
|
||||
return ObservableRose(depth, branches)
|
||||
}
|
||||
}
|
||||
|
||||
private lateinit var testOpsImpl: TestOpsImpl
|
||||
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
|
||||
testOpsImpl = TestOpsImpl()
|
||||
return testProxy<TestOps>(
|
||||
testOpsImpl,
|
||||
clientConfiguration = RPCClientConfiguration.default.copy(
|
||||
reapIntervalMs = 100,
|
||||
cacheConcurrencyLevel = 16
|
||||
),
|
||||
serverConfiguration = RPCServerConfiguration.default.copy(
|
||||
rpcThreadPoolSize = 4
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `call multiple RPCs in parallel`() {
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
val numberOfBlockedCalls = 2
|
||||
val numberOfDownsRequired = 100
|
||||
val id = proxy.ops.newLatch(numberOfDownsRequired)
|
||||
val done = CountDownLatch(numberOfBlockedCalls)
|
||||
// Start a couple of blocking RPC calls
|
||||
(1..numberOfBlockedCalls).forEach {
|
||||
future {
|
||||
proxy.ops.waitLatch(id)
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
// Down the latch that the others are waiting for concurrently
|
||||
(1..numberOfDownsRequired).toList().parallelStream().forEach {
|
||||
proxy.ops.downLatch(id)
|
||||
}
|
||||
done.await()
|
||||
}
|
||||
}
|
||||
|
||||
private fun intPower(base: Int, power: Int): Int {
|
||||
return when (power) {
|
||||
0 -> 1
|
||||
1 -> base
|
||||
else -> {
|
||||
val a = intPower(base, power / 2)
|
||||
if (power and 1 == 0) {
|
||||
a * a
|
||||
} else {
|
||||
a * a * base
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `nested immediate observables sequence correctly`() {
|
||||
rpcDriver {
|
||||
// We construct a rose tree of immediate Observables and check that parent observations arrive before children.
|
||||
val proxy = testProxy()
|
||||
val treeDepth = 6
|
||||
val treeBranchingFactor = 3
|
||||
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
|
||||
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
|
||||
fun ObservableRose<Int>.subscribeToAll() {
|
||||
remainingLatch.countDown()
|
||||
this.branches.subscribe { tree ->
|
||||
(tree.value + 1..treeDepth - 1).forEach {
|
||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||
}
|
||||
depthsSeen.add(tree.value)
|
||||
tree.subscribeToAll()
|
||||
}
|
||||
}
|
||||
proxy.ops.getImmediateObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
|
||||
remainingLatch.await()
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `parallel nested observables`() {
|
||||
rpcDriver {
|
||||
val proxy = testProxy()
|
||||
val treeDepth = 2
|
||||
val treeBranchingFactor = 10
|
||||
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
|
||||
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
|
||||
fun ObservableRose<Int>.subscribeToAll() {
|
||||
remainingLatch.countDown()
|
||||
branches.subscribe { tree ->
|
||||
(tree.value + 1..treeDepth - 1).forEach {
|
||||
require(it in depthsSeen) { "Got ${tree.value} before $it" }
|
||||
}
|
||||
depthsSeen.add(tree.value)
|
||||
tree.subscribeToAll()
|
||||
}
|
||||
}
|
||||
proxy.ops.getParallelObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
|
||||
remainingLatch.await()
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,315 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import com.codahale.metrics.Gauge
|
||||
import com.codahale.metrics.JmxReporter
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import com.esotericsoftware.kryo.pool.KryoPool
|
||||
import com.google.common.base.Stopwatch
|
||||
import net.corda.client.rpc.internal.RPCClientConfiguration
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.core.millis
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.node.driver.ShutdownManager
|
||||
import net.corda.node.services.messaging.RPCServerConfiguration
|
||||
import net.corda.nodeapi.RPCApi
|
||||
import net.corda.nodeapi.RPCKryo
|
||||
import net.corda.testing.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.measure
|
||||
import net.corda.testing.rpcDriver
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.junit.Ignore
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import rx.Observable
|
||||
import java.time.Duration
|
||||
import java.util.*
|
||||
import java.util.concurrent.*
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
import javax.management.ObjectName
|
||||
import kotlin.concurrent.thread
|
||||
import kotlin.concurrent.withLock
|
||||
|
||||
@Ignore("Only use this locally for profiling")
|
||||
@RunWith(Parameterized::class)
|
||||
class RPCPerformanceTests : AbstractRPCTest() {
|
||||
companion object {
|
||||
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
|
||||
fun modes() = modes(RPCTestMode.Netty)
|
||||
}
|
||||
private interface TestOps : RPCOps {
|
||||
fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray
|
||||
}
|
||||
|
||||
class TestOpsImpl : TestOps {
|
||||
override val protocolVersion = 0
|
||||
override fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray {
|
||||
return ByteArray(sizeOfReply)
|
||||
}
|
||||
}
|
||||
|
||||
private fun RPCDriverExposedDSLInterface.testProxy(
|
||||
clientConfiguration: RPCClientConfiguration,
|
||||
serverConfiguration: RPCServerConfiguration
|
||||
): TestProxy<TestOps> {
|
||||
return testProxy<TestOps>(
|
||||
TestOpsImpl(),
|
||||
clientConfiguration = clientConfiguration,
|
||||
serverConfiguration = serverConfiguration
|
||||
)
|
||||
}
|
||||
|
||||
private fun warmup() {
|
||||
rpcDriver {
|
||||
val proxy = testProxy(
|
||||
RPCClientConfiguration.default,
|
||||
RPCServerConfiguration.default
|
||||
)
|
||||
val executor = Executors.newFixedThreadPool(4)
|
||||
val N = 10000
|
||||
val latch = CountDownLatch(N)
|
||||
for (i in 1 .. N) {
|
||||
executor.submit {
|
||||
proxy.ops.simpleReply(ByteArray(1024), 1024)
|
||||
latch.countDown()
|
||||
}
|
||||
}
|
||||
latch.await()
|
||||
}
|
||||
}
|
||||
|
||||
data class SimpleRPCResult(
|
||||
val requestPerSecond: Double,
|
||||
val averageIndividualMs: Double,
|
||||
val Mbps: Double
|
||||
)
|
||||
@Test
|
||||
fun `measure Megabytes per second for simple RPCs`() {
|
||||
warmup()
|
||||
val inputOutputSizes = listOf(1024, 4096, 100 * 1024)
|
||||
val overallTraffic = 512 * 1024 * 1024L
|
||||
measure(inputOutputSizes, (1..5)) { inputOutputSize, N ->
|
||||
rpcDriver {
|
||||
val proxy = testProxy(
|
||||
RPCClientConfiguration.default.copy(
|
||||
cacheConcurrencyLevel = 16,
|
||||
observationExecutorPoolSize = 2,
|
||||
producerPoolBound = 2
|
||||
),
|
||||
RPCServerConfiguration.default.copy(
|
||||
rpcThreadPoolSize = 8,
|
||||
consumerPoolSize = 2,
|
||||
producerPoolBound = 8
|
||||
)
|
||||
)
|
||||
|
||||
val numberOfRequests = overallTraffic / (2 * inputOutputSize)
|
||||
val timings = Collections.synchronizedList(ArrayList<Long>())
|
||||
val executor = Executors.newFixedThreadPool(8)
|
||||
val totalElapsed = Stopwatch.createStarted().apply {
|
||||
startInjectorWithBoundedQueue(
|
||||
executor = executor,
|
||||
numberOfInjections = numberOfRequests.toInt(),
|
||||
queueBound = 100
|
||||
) {
|
||||
val elapsed = Stopwatch.createStarted().apply {
|
||||
proxy.ops.simpleReply(ByteArray(inputOutputSize), inputOutputSize)
|
||||
}.stop().elapsed(TimeUnit.MICROSECONDS)
|
||||
timings.add(elapsed)
|
||||
}
|
||||
}.stop().elapsed(TimeUnit.MICROSECONDS)
|
||||
executor.shutdownNow()
|
||||
SimpleRPCResult(
|
||||
requestPerSecond = 1000000.0 * numberOfRequests.toDouble() / totalElapsed.toDouble(),
|
||||
averageIndividualMs = timings.average() / 1000.0,
|
||||
Mbps = (overallTraffic.toDouble() / totalElapsed.toDouble()) * (1000000.0 / (1024.0 * 1024.0))
|
||||
)
|
||||
}
|
||||
}.forEach(::println)
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs 20k RPCs per second for two minutes and publishes relevant stats to JMX.
|
||||
*/
|
||||
@Test
|
||||
fun `consumption rate`() {
|
||||
rpcDriver {
|
||||
val metricRegistry = startJmxReporter()
|
||||
val proxy = testProxy(
|
||||
RPCClientConfiguration.default.copy(
|
||||
reapIntervalMs = 100,
|
||||
cacheConcurrencyLevel = 16
|
||||
),
|
||||
RPCServerConfiguration.default.copy(
|
||||
rpcThreadPoolSize = 4,
|
||||
consumerPoolSize = 4,
|
||||
producerPoolBound = 4
|
||||
)
|
||||
)
|
||||
measurePerformancePublishMetrics(
|
||||
metricRegistry = metricRegistry,
|
||||
parallelism = 4,
|
||||
overallDurationSecond = 120.0,
|
||||
injectionRatePerSecond = 20000.0,
|
||||
queueSizeMetricName = "$mode.QueueSize",
|
||||
workDurationMetricName = "$mode.WorkDuration",
|
||||
shutdownManager = this.shutdownManager,
|
||||
work = {
|
||||
proxy.ops.simpleReply(ByteArray(4096), 4096)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
data class BigMessagesResult(
|
||||
val Mbps: Double
|
||||
)
|
||||
@Test
|
||||
fun `big messages`() {
|
||||
warmup()
|
||||
measure(listOf(1)) { clientParallelism -> // TODO this hangs with more parallelism
|
||||
rpcDriver {
|
||||
val proxy = testProxy(
|
||||
RPCClientConfiguration.default,
|
||||
RPCServerConfiguration.default.copy(
|
||||
consumerPoolSize = 1
|
||||
)
|
||||
)
|
||||
val executor = Executors.newFixedThreadPool(clientParallelism)
|
||||
val numberOfMessages = 1000
|
||||
val bigSize = 10_000_000
|
||||
val elapsed = Stopwatch.createStarted().apply {
|
||||
startInjectorWithBoundedQueue(
|
||||
executor = executor,
|
||||
numberOfInjections = numberOfMessages,
|
||||
queueBound = 4
|
||||
) {
|
||||
proxy.ops.simpleReply(ByteArray(bigSize), 0)
|
||||
}
|
||||
}.stop().elapsed(TimeUnit.MICROSECONDS)
|
||||
executor.shutdownNow()
|
||||
BigMessagesResult(
|
||||
Mbps = bigSize.toDouble() * numberOfMessages.toDouble() / elapsed * (1000000.0 / (1024.0 * 1024.0))
|
||||
)
|
||||
}
|
||||
}.forEach(::println)
|
||||
}
|
||||
}
|
||||
|
||||
fun measurePerformancePublishMetrics(
|
||||
metricRegistry: MetricRegistry,
|
||||
parallelism: Int,
|
||||
overallDurationSecond: Double,
|
||||
injectionRatePerSecond: Double,
|
||||
queueSizeMetricName: String,
|
||||
workDurationMetricName: String,
|
||||
shutdownManager: ShutdownManager,
|
||||
work: () -> Unit
|
||||
) {
|
||||
val workSemaphore = Semaphore(0)
|
||||
metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() })
|
||||
val workDurationTimer = metricRegistry.timer(workDurationMetricName)
|
||||
val executor = Executors.newSingleThreadScheduledExecutor()
|
||||
val workExecutor = Executors.newFixedThreadPool(parallelism)
|
||||
val timings = Collections.synchronizedList(ArrayList<Long>())
|
||||
for (i in 1 .. parallelism) {
|
||||
workExecutor.submit {
|
||||
try {
|
||||
while (true) {
|
||||
workSemaphore.acquire()
|
||||
workDurationTimer.time {
|
||||
timings.add(
|
||||
Stopwatch.createStarted().apply {
|
||||
work()
|
||||
}.stop().elapsed(TimeUnit.MICROSECONDS)
|
||||
)
|
||||
}
|
||||
}
|
||||
} catch (throwable: Throwable) {
|
||||
throwable.printStackTrace()
|
||||
}
|
||||
}
|
||||
}
|
||||
val injector = executor.scheduleAtFixedRate(
|
||||
{
|
||||
workSemaphore.release(injectionRatePerSecond.toInt())
|
||||
},
|
||||
0,
|
||||
1,
|
||||
TimeUnit.SECONDS
|
||||
)
|
||||
shutdownManager.registerShutdown {
|
||||
injector.cancel(true)
|
||||
workExecutor.shutdownNow()
|
||||
executor.shutdownNow()
|
||||
workExecutor.awaitTermination(1, TimeUnit.SECONDS)
|
||||
executor.awaitTermination(1, TimeUnit.SECONDS)
|
||||
}
|
||||
Thread.sleep((overallDurationSecond * 1000).toLong())
|
||||
}
|
||||
|
||||
fun startInjectorWithBoundedQueue(
|
||||
executor: ExecutorService,
|
||||
numberOfInjections: Int,
|
||||
queueBound: Int,
|
||||
work: () -> Unit
|
||||
) {
|
||||
val remainingLatch = CountDownLatch(numberOfInjections)
|
||||
val queuedCount = AtomicInteger(0)
|
||||
val lock = ReentrantLock()
|
||||
val canQueueAgain = lock.newCondition()
|
||||
val injectorShutdown = AtomicBoolean(false)
|
||||
val injector = thread(name = "injector") {
|
||||
while (true) {
|
||||
if (injectorShutdown.get()) break
|
||||
executor.submit {
|
||||
work()
|
||||
if (queuedCount.decrementAndGet() < queueBound / 2) {
|
||||
lock.withLock {
|
||||
canQueueAgain.signal()
|
||||
}
|
||||
}
|
||||
remainingLatch.countDown()
|
||||
}
|
||||
if (queuedCount.incrementAndGet() > queueBound) {
|
||||
lock.withLock {
|
||||
canQueueAgain.await()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
remainingLatch.await()
|
||||
injectorShutdown.set(true)
|
||||
injector.join()
|
||||
}
|
||||
|
||||
fun RPCDriverExposedDSLInterface.startJmxReporter(): MetricRegistry {
|
||||
val metricRegistry = MetricRegistry()
|
||||
val jmxReporter = thread {
|
||||
JmxReporter.
|
||||
forRegistry(metricRegistry).
|
||||
inDomain("net.corda").
|
||||
createsObjectNamesWith { _, domain, name ->
|
||||
// Make the JMX hierarchy a bit better organised.
|
||||
val category = name.substringBefore('.')
|
||||
val subName = name.substringAfter('.', "")
|
||||
if (subName == "")
|
||||
ObjectName("$domain:name=$category")
|
||||
else
|
||||
ObjectName("$domain:type=$category,name=$subName")
|
||||
}.
|
||||
build().
|
||||
start()
|
||||
}
|
||||
shutdownManager.registerShutdown {
|
||||
jmxReporter.interrupt()
|
||||
jmxReporter.join()
|
||||
}
|
||||
return metricRegistry
|
||||
}
|
@ -1,85 +0,0 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.node.services.messaging.requirePermission
|
||||
import net.corda.nodeapi.PermissionException
|
||||
import net.corda.nodeapi.User
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class RPCPermissionsTest : AbstractClientRPCTest() {
|
||||
companion object {
|
||||
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
|
||||
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
|
||||
const val ALL_ALLOWED = "ALL"
|
||||
}
|
||||
|
||||
lateinit var proxy: TestOps
|
||||
|
||||
@After
|
||||
fun shutdown() {
|
||||
safeClose(proxy)
|
||||
}
|
||||
|
||||
/*
|
||||
* RPC operation.
|
||||
*/
|
||||
interface TestOps : RPCOps {
|
||||
fun validatePermission(str: String)
|
||||
}
|
||||
|
||||
class TestOpsImpl : TestOps {
|
||||
override val protocolVersion = 1
|
||||
override fun validatePermission(str: String) = requirePermission(str)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an RPC proxy for the given user.
|
||||
*/
|
||||
private fun proxyFor(rpcUser: User): TestOps = rpcProxyFor(rpcUser, TestOpsImpl(), TestOps::class.java)
|
||||
|
||||
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
|
||||
|
||||
@Test
|
||||
fun `empty user cannot use any flows`() {
|
||||
val emptyUser = userOf("empty", emptySet())
|
||||
proxy = proxyFor(emptyUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
|
||||
{ proxy.validatePermission(DUMMY_FLOW) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `admin user can use any flow`() {
|
||||
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
|
||||
proxy = proxyFor(adminUser)
|
||||
proxy.validatePermission(DUMMY_FLOW)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `joe user is allowed to use DummyFlow`() {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
proxy = proxyFor(joeUser)
|
||||
proxy.validatePermission(DUMMY_FLOW)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `joe user is not allowed to use OtherFlow`() {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
proxy = proxyFor(joeUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
|
||||
{ proxy.validatePermission(OTHER_FLOW) })
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `check ALL is implemented the correct way round`() {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
proxy = proxyFor(joeUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
|
||||
{ proxy.validatePermission(ALL_ALLOWED) })
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import net.corda.core.messaging.RPCOps
|
||||
import net.corda.node.services.messaging.requirePermission
|
||||
import net.corda.node.services.messaging.getRpcContext
|
||||
import net.corda.nodeapi.PermissionException
|
||||
import net.corda.nodeapi.User
|
||||
import net.corda.testing.RPCDriverExposedDSLInterface
|
||||
import net.corda.testing.rpcDriver
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
@RunWith(Parameterized::class)
|
||||
class RPCPermissionsTests : AbstractRPCTest() {
|
||||
companion object {
|
||||
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
|
||||
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
|
||||
const val ALL_ALLOWED = "ALL"
|
||||
}
|
||||
|
||||
/*
|
||||
* RPC operation.
|
||||
*/
|
||||
interface TestOps : RPCOps {
|
||||
fun validatePermission(str: String)
|
||||
}
|
||||
|
||||
class TestOpsImpl : TestOps {
|
||||
override val protocolVersion = 1
|
||||
override fun validatePermission(str: String) = getRpcContext().requirePermission(str)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an RPC proxy for the given user.
|
||||
*/
|
||||
private fun RPCDriverExposedDSLInterface.testProxyFor(rpcUser: User) = testProxy<TestOps>(TestOpsImpl(), rpcUser).ops
|
||||
|
||||
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
|
||||
|
||||
@Test
|
||||
fun `empty user cannot use any flows`() {
|
||||
rpcDriver {
|
||||
val emptyUser = userOf("empty", emptySet())
|
||||
val proxy = testProxyFor(emptyUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
|
||||
{ proxy.validatePermission(DUMMY_FLOW) })
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `admin user can use any flow`() {
|
||||
rpcDriver {
|
||||
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
|
||||
val proxy = testProxyFor(adminUser)
|
||||
proxy.validatePermission(DUMMY_FLOW)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `joe user is allowed to use DummyFlow`() {
|
||||
rpcDriver {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
val proxy = testProxyFor(joeUser)
|
||||
proxy.validatePermission(DUMMY_FLOW)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `joe user is not allowed to use OtherFlow`() {
|
||||
rpcDriver {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
val proxy = testProxyFor(joeUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
|
||||
{ proxy.validatePermission(OTHER_FLOW) })
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `check ALL is implemented the correct way round` () {
|
||||
rpcDriver {
|
||||
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
|
||||
val proxy = testProxyFor(joeUser)
|
||||
assertFailsWith(PermissionException::class,
|
||||
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
|
||||
{ proxy.validatePermission(ALL_ALLOWED) })
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package net.corda.client.rpc
|
||||
|
||||
import java.io.InputStream
|
||||
|
||||
class RepeatingBytesInputStream(val bytesToRepeat: ByteArray, val numberOfBytes: Int) : InputStream() {
|
||||
private var bytesLeft = numberOfBytes
|
||||
override fun available() = bytesLeft
|
||||
override fun read(): Int {
|
||||
if (bytesLeft == 0) {
|
||||
return -1
|
||||
} else {
|
||||
bytesLeft--
|
||||
return bytesToRepeat[(numberOfBytes - bytesLeft) % bytesToRepeat.size].toInt()
|
||||
}
|
||||
}
|
||||
override fun read(byteArray: ByteArray, offset: Int, length: Int): Int {
|
||||
val until = Math.min(Math.min(offset + length, byteArray.size), offset + bytesLeft)
|
||||
for (i in offset .. until - 1) {
|
||||
byteArray[i] = bytesToRepeat[(numberOfBytes - bytesLeft + i - offset) % bytesToRepeat.size]
|
||||
}
|
||||
val bytesRead = until - offset
|
||||
bytesLeft -= bytesRead
|
||||
return if (bytesRead == 0 && bytesLeft == 0) -1 else bytesRead
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user