RPC muxing, multithreading, RPC driver, performance tests

Andras Slemmer
2017-03-29 17:28:02 +01:00
parent 25dbac0f07
commit de88ad4f40
63 changed files with 3223 additions and 1417 deletions

View File

@@ -1,15 +1,10 @@
package net.corda.client.rpc
import net.corda.core.contracts.DOLLARS
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInitiator
import net.corda.core.getOrThrow
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.startFlow
import net.corda.core.messaging.startTrackedFlow
import net.corda.core.messaging.*
import net.corda.core.node.services.ServiceInfo
import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes
@@ -27,7 +22,9 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.*
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class CordaRPCClientTest : NodeBasedTest() {
private val rpcUser = User("user1", "test", permissions = setOf(
@@ -36,6 +33,11 @@ class CordaRPCClientTest : NodeBasedTest() {
))
private lateinit var node: Node
private lateinit var client: CordaRPCClient
private var connection: CordaRPCConnection? = null
private fun login(username: String, password: String) {
connection = client.start(username, password)
}
@Before
fun setUp() {
@@ -45,33 +47,35 @@ class CordaRPCClientTest : NodeBasedTest() {
@After
fun done() {
client.close()
connection?.close()
}
@Test
fun `log in with valid username and password`() {
client.start(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
}
@Test
fun `log in with unknown user`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(random63BitValue().toString(), rpcUser.password)
login(random63BitValue().toString(), rpcUser.password)
}
}
@Test
fun `log in with incorrect password`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(rpcUser.username, random63BitValue().toString())
login(rpcUser.username, random63BitValue().toString())
}
}
@Test
fun `close-send deadlock and premature shutdown on empty observable`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
println("Starting client")
login(rpcUser.username, rpcUser.password)
println("Creating proxy")
println("Starting flow")
val flowHandle = proxy.startTrackedFlow(
val flowHandle = connection!!.proxy.startTrackedFlow(
::CashIssueFlow,
20.DOLLARS, OpaqueBytes.of(0), node.info.legalIdentity, node.info.legalIdentity)
println("Started flow, waiting on result")
@@ -83,9 +87,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `FlowException thrown by flow`() {
client.start(rpcUser.username, rpcUser.password)
val proxy = client.proxy()
val handle = proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
login(rpcUser.username, rpcUser.password)
val handle = connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
// TODO Restrict this to CashException once RPC serialisation has been fixed
assertThatExceptionOfType(FlowException::class.java).isThrownBy {
handle.returnValue.getOrThrow()
@@ -94,9 +97,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `check basic flow has no progress`() {
client.start(rpcUser.username, rpcUser.password)
val proxy = client.proxy()
proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
login(rpcUser.username, rpcUser.password)
connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
assertFalse(it is FlowProgressHandle<*>)
assertTrue(it is FlowHandle<*>)
}
@@ -104,7 +106,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `get cash balances`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy
val startCash = proxy.getCashBalances()
assertTrue(startCash.isEmpty(), "Should not start with any cash")
@@ -123,7 +126,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `flow initiator via RPC`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy
val smUpdates = proxy.stateMachinesAndUpdates()
var countRpcFlows = 0
var countShellFlows = 0
@@ -148,11 +152,4 @@ class CordaRPCClientTest : NodeBasedTest() {
assertEquals(2, countRpcFlows)
assertEquals(1, countShellFlows)
}
private fun createRpcProxy(username: String, password: String): CordaRPCOps {
println("Starting client")
client.start(username, password)
println("Creating proxy")
return client.proxy()
}
}

View File

@@ -0,0 +1,170 @@
package net.corda.client.rpc
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.util.concurrent.Futures
import net.corda.core.messaging.RPCOps
import net.corda.core.millis
import net.corda.core.random63BitValue
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCKryo
import net.corda.testing.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.bouncycastle.crypto.tls.ConnectionEnd.server
import org.junit.Test
import rx.Observable
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
class RPCStabilityTests {
interface LeakObservableOps : RPCOps {
fun leakObservable(): Observable<Nothing>
}
@Test
fun `client cleans up leaked observables`() {
rpcDriver {
val leakObservableOpsImpl = object : LeakObservableOps {
val leakedUnsubscribedCount = AtomicInteger(0)
override val protocolVersion = 0
override fun leakObservable(): Observable<Nothing> {
return PublishSubject.create<Nothing>().doOnUnsubscribe {
leakedUnsubscribedCount.incrementAndGet()
}
}
}
val server = startRpcServer<LeakObservableOps>(ops = leakObservableOpsImpl)
val proxy = startRpcClient<LeakObservableOps>(server.get().hostAndPort).get()
// Leak many observables
val N = 200
(1..N).toList().parallelStream().forEach {
proxy.leakObservable()
}
// In a loop force GC and check whether the server is notified
while (true) {
System.gc()
if (leakObservableOpsImpl.leakedUnsubscribedCount.get() == N) break
Thread.sleep(100)
}
}
}
interface TrackSubscriberOps : RPCOps {
fun subscribe(): Observable<Unit>
}
/**
* In this test we create a number of out of process RPC clients that call [TrackSubscriberOps.subscribe] in a loop.
*/
@Test
fun `server cleans up queues after disconnected clients`() {
rpcDriver {
val trackSubscriberOpsImpl = object : TrackSubscriberOps {
override val protocolVersion = 0
val subscriberCount = AtomicInteger(0)
val trackSubscriberCountObservable = UnicastSubject.create<Unit>().share().
doOnSubscribe { subscriberCount.incrementAndGet() }.
doOnUnsubscribe { subscriberCount.decrementAndGet() }
override fun subscribe(): Observable<Unit> {
return trackSubscriberCountObservable
}
}
val server = startRpcServer<TrackSubscriberOps>(
configuration = RPCServerConfiguration.default.copy(
reapIntervalMs = 100
),
ops = trackSubscriberOpsImpl
).get()
val numberOfClients = 4
val clients = Futures.allAsList((1 .. numberOfClients).map {
startRandomRpcClient<TrackSubscriberOps>(server.hostAndPort)
}).get()
// Poll until all clients connect
pollUntilClientNumber(server, numberOfClients)
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() >= 100 }.get()
// Kill one client
clients[0].destroyForcibly()
pollUntilClientNumber(server, numberOfClients - 1)
// Kill the rest
(1 .. numberOfClients - 1).forEach {
clients[it].destroyForcibly()
}
pollUntilClientNumber(server, 0)
// Now poll until the server detects the disconnects and unsubscribes from all observables.
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() == 0 }.get()
}
}
interface SlowConsumerRPCOps : RPCOps {
fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray>
}
class SlowConsumerRPCOpsImpl : SlowConsumerRPCOps {
override val protocolVersion = 0
override fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray> {
val chunk = ByteArray(size)
return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk }
}
}
val dummyObservableSerialiser = object : Serializer<Observable<Any>>() {
override fun write(kryo: Kryo?, output: Output?, `object`: Observable<Any>?) {
}
override fun read(kryo: Kryo?, input: Input?, type: Class<Observable<Any>>?): Observable<Any> {
return Observable.empty()
}
}
@Test
fun `slow consumers are kicked`() {
val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build()
rpcDriver {
val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get()
// Construct an RPC session manually so that we can hang in the message handler
val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}"
val session = startArtemisSession(server.hostAndPort)
session.createTemporaryQueue(myQueue, myQueue)
val consumer = session.createConsumer(myQueue, null, -1, -1, false)
consumer.setMessageHandler {
Thread.sleep(50) // 5x slower than the server producer
it.acknowledge()
}
val producer = session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)
session.start()
pollUntilClientNumber(server, 1)
val message = session.createMessage(false)
val request = RPCApi.ClientToServer.RpcRequest(
clientAddress = SimpleString(myQueue),
id = RPCApi.RpcRequestId(random63BitValue()),
methodName = SlowConsumerRPCOps::streamAtInterval.name,
arguments = listOf(10.millis, 123456)
)
request.writeToClientMessage(kryoPool, message)
producer.send(message)
session.commit()
// We are consuming slower than the server is producing, so we should be kicked after a while
pollUntilClientNumber(server, 0)
}
}
}
fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) {
pollUntilTrue("number of RPC clients to become $expected") {
val clientAddresses = server.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) }
clientAddresses.size == expected
}.get()
}

View File

@@ -1,168 +1,49 @@
package net.corda.client.rpc
import com.google.common.net.HostAndPort
import net.corda.core.ThreadBox
import net.corda.core.logElapsedTime
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.minutes
import net.corda.core.seconds
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.RPCException
import net.corda.nodeapi.config.SSLConfiguration
import net.corda.nodeapi.rpcLog
import org.apache.activemq.artemis.api.core.ActiveMQException
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Observable
import java.io.Closeable
import java.time.Duration
import javax.annotation.concurrent.ThreadSafe
/**
* An RPC client connects to the specified server and allows you to make calls to the server that perform various
* useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works.
*
* @param host The hostname and messaging port of the node.
* @param config If specified, the SSL configuration to use. If not specified, SSL will be disabled: the node will only be authenticated on the non-SSL RPC port and the RPC traffic will not be encrypted.
*/
@ThreadSafe
class CordaRPCClient(val host: HostAndPort, override val config: SSLConfiguration? = null, val serviceConfigurationOverride: (ServerLocator.() -> Unit)? = null) : Closeable, ArtemisMessagingComponent() {
private companion object {
val log = loggerFor<CordaRPCClient>()
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
@JvmStatic val MAX_FILE_SIZE = 10485760
class CordaRPCConnection internal constructor(
connection: RPCClient.RPCConnection<CordaRPCOps>
) : RPCClient.RPCConnection<CordaRPCOps> by connection
data class CordaRPCClientConfiguration(
val connectionMaxRetryInterval: Duration
) {
internal fun toRpcClientConfiguration(): RPCClientConfiguration {
return RPCClientConfiguration.default.copy(
connectionMaxRetryInterval = connectionMaxRetryInterval
)
}
companion object {
@JvmStatic
val default = CordaRPCClientConfiguration(
connectionMaxRetryInterval = RPCClientConfiguration.default.connectionMaxRetryInterval
)
}
}
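// Illustrative only, not part of this commit: the public CordaRPCClientConfiguration currently
// exposes a single knob, set here through the data class constructor and mapped onto the internal
// RPCClientConfiguration by toRpcClientConfiguration() above. The 10-minute value is an arbitrary example.
private val exampleClientConfiguration = CordaRPCClientConfiguration(
    connectionMaxRetryInterval = Duration.ofMinutes(10)
)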
class CordaRPCClient(
hostAndPort: HostAndPort,
sslConfiguration: SSLConfiguration? = null,
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default
) {
private val rpcClient = RPCClient<CordaRPCOps>(
tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration),
configuration.toRpcClientConfiguration()
)
fun start(username: String, password: String): CordaRPCConnection {
return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password))
}
// TODO: Certificate handling for clients needs more work.
private inner class State {
var running = false
lateinit var sessionFactory: ClientSessionFactory
lateinit var session: ClientSession
lateinit var clientImpl: CordaRPCClientImpl
}
private val state = ThreadBox(State())
/**
* Opens the connection to the server with the given username and password, then returns itself.
* Registers a JVM shutdown hook to cleanly disconnect.
*/
@Throws(ActiveMQException::class)
fun start(username: String, password: String): CordaRPCClient {
state.locked {
check(!running)
log.logElapsedTime("Startup") {
checkStorePasswords()
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport(ConnectionDirection.Outbound(), host, config, enableSSL = config != null)).apply {
// TODO: Put these in config file or make it user configurable?
threadPoolMaxSize = 1
confirmationWindowSize = 100000 // a guess
retryInterval = 5.seconds.toMillis()
retryIntervalMultiplier = 1.5 // Exponential backoff
maxRetryInterval = 3.minutes.toMillis()
minLargeMessageSize = MAX_FILE_SIZE
serviceConfigurationOverride?.invoke(this)
}
sessionFactory = serverLocator.createSessionFactory()
session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize)
session.start()
clientImpl = CordaRPCClientImpl(session, state.lock, username)
running = true
}
}
Runtime.getRuntime().addShutdownHook(Thread {
close()
})
return this
}
/**
* A convenience function that opens a connection with the given credentials, executes the given code block with all
* available RPCs in scope and shuts down the RPC connection again. It's meant for quick prototyping and demos. For
* more control you probably want to manage the lifecycle of the client and proxies independently, as well as
* configure a timeout and other such features via the [proxy] method.
*
* After this method returns the client is closed and can't be restarted.
*/
@Throws(ActiveMQException::class)
fun <T> use(username: String, password: String, block: CordaRPCOps.() -> T): T {
require(!state.locked { running })
start(username, password)
(this as Closeable).use {
return proxy().block()
}
}
/** Shuts down the client and lets the server know it can free the used resources (in a nice way). */
override fun close() {
state.locked {
if (!running) return
session.close()
sessionFactory.close()
running = false
}
}
/**
* Returns a fresh proxy that lets you invoke RPCs on the server. Calls on it block, and if the server throws an
* exception then it will be rethrown on the client. Proxies are thread safe but only one RPC can be in flight at
* once. If you'd like to perform multiple RPCs in parallel, use this function multiple times to get multiple
* proxies.
*
* Creation of a proxy is a somewhat expensive operation that involves calls to the server, so if you want to do
* calls from many threads at once you should cache one proxy per thread and reuse them. This function itself is
* thread safe though so requires no extra synchronisation.
*
* RPC sends and receives are logged on the net.corda.rpc logger.
*
* By default there are no timeouts on calls. This is deliberate: RPCs without timeouts can survive restarts,
* maintenance downtime and moves of the server. RPCs can survive temporary losses or changes in client connectivity,
* like switching between wifi networks. You can specify a timeout on the level of a proxy. If a call times
* out it will throw [RPCException.Deadline].
*
* The [CordaRPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the
* object graph returned then the server-side observable is transparently linked to a messaging queue, and that
* queue linked to another observable on the client side here. *You are expected to use it*. The server will
* immediately begin buffering messages that it expects you to drain by subscribing to the returned observable. You
* can opt out of this by simply casting the [Observable] to [Closeable] or [AutoCloseable] and then calling the close
* method on it. You don't have to explicitly close the observable if you actually subscribe to it: it will close
* itself and free up the server-side resources either when the client or JVM itself is shutdown, or when there are
* no more subscribers to it. Once all the subscribers to a returned observable are unsubscribed, the observable is
* closed and you can't then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
*
* The proxy and linked observables consume some small amount of resources on the server. It's OK to just exit your
* process and let the server clean up, but in a long running process where you only need something for a short
* amount of time it is polite to cast the objects to [Closeable] or [AutoCloseable] and close them when you are done.
* Finalizers are in place to warn you if you lose a reference to an unclosed proxy or observable.
*
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
*/
@JvmOverloads
@Throws(RPCException::class)
fun proxy(timeout: Duration? = null, minVersion: Int = 0): CordaRPCOps {
return state.locked {
check(running) { "Client must have been started first" }
log.logElapsedTime("Proxy build") {
clientImpl.proxyFor(CordaRPCOps::class.java, timeout, minVersion)
}
}
}
@Suppress("UNUSED")
private fun finalize() {
state.locked {
if (running) {
rpcLog.warn("A CordaRPCClient is being finalised whilst still running, did you forget to call close?")
close()
}
}
inline fun <A> use(username: String, password: String, block: (CordaRPCConnection) -> A): A {
return start(username, password).use(block)
}
}
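// A usage sketch, not part of this commit: it assumes a node listening on localhost:10000 with an
// "rpcUser"/"password" RPC account, and that nodeIdentity() is available on CordaRPCOps. start()
// yields a CordaRPCConnection whose proxy performs the calls; use {} scopes the connection.
fun connectionUsageSketch() {
    val client = CordaRPCClient(HostAndPort.fromParts("localhost", 10000))
    // Explicit lifecycle: start a connection, make RPCs through the proxy, close when done.
    val connection = client.start("rpcUser", "password")
    try {
        println("Server protocol version: ${connection.serverProtocolVersion}")
        println("Node identity: ${connection.proxy.nodeIdentity()}")
    } finally {
        connection.close()
    }
    // Scoped lifecycle: the use convenience function closes the connection when the block exits.
    client.use("rpcUser", "password") { conn ->
        println(conn.proxy.nodeIdentity())
    }
}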

View File

@@ -1,418 +0,0 @@
package net.corda.client.rpc
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.CacheBuilder
import net.corda.core.ErrorOr
import net.corda.core.bufferUntilSubscribed
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.RPCReturnsObservables
import net.corda.core.random63BitValue
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.debug
import net.corda.nodeapi.*
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import rx.Observable
import rx.subjects.PublishSubject
import java.io.Closeable
import java.lang.ref.WeakReference
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.lang.reflect.Proxy
import java.time.Duration
import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import javax.annotation.concurrent.ThreadSafe
import kotlin.concurrent.withLock
import kotlin.reflect.jvm.javaMethod
/**
* Core RPC engine implementation, to learn how to use RPC you should be looking at [CordaRPCClient].
*
* # Design notes
*
* The way RPCs are handled is fairly standard except for the handling of observables. When an RPC might return
* an [Observable] it is specially tagged. This causes the client to create a new transient queue for the
* receiving of observables and their observations with a random ID in the name. This ID is sent to the server in
* a message header. All observations are sent via this single queue.
*
* The reason for doing it this way and not the more obvious approach of one-queue-per-observable is that we want
* the queues to be *transient*, meaning their lifetime in the broker is tied to the session that created them.
* A server side observable and its associated queue is not a cost-free thing, let alone the memory and resources
* needed to actually generate the observations themselves, therefore we want to ensure these cannot leak. A
* transient queue will be deleted automatically if the client session terminates, which by default happens on
* disconnect but can also be configured to happen after a short delay (this allows clients to e.g. switch IP
* address). On the server the deletion of the observations queue triggers unsubscription from the associated
* observables, which in turn may then be garbage collected.
*
* Creating a transient queue requires a roundtrip to the broker and thus doing an RPC that could return
* observables takes two server roundtrips instead of one. That's why we require RPCs to be marked with
* [RPCReturnsObservables] as needing this special treatment instead of always doing it.
*
* If the Artemis/JMS APIs allowed us to create transient queues assigned to someone else then we could
* potentially use a different design in which the node creates new transient queues (one per observable) on the
* fly. The client would then have to watch out for this and start consuming those queues as they were created.
*
* We use one queue per RPC because we don't know ahead of time how many observables the server might return and
* often the server doesn't know either, which pushes towards a single queue design, but at the same time the
* processing of observations returned by an RPC might be striped across multiple threads and we'd like
* backpressure management to not be scoped per client process but with more granularity. So we end up with
* a compromise where the unit of backpressure management is the response to a single RPC.
*
* TODO: Backpressure isn't propagated all the way through the MQ broker at the moment.
*/
class CordaRPCClientImpl(private val session: ClientSession,
private val sessionLock: ReentrantLock,
private val username: String) {
companion object {
private val closeableCloseMethod = Closeable::close.javaMethod
private val autocloseableCloseMethod = AutoCloseable::close.javaMethod
}
/**
* Builds a proxy for the given type, which must descend from [RPCOps].
*
* @see CordaRPCClient.proxy for more information about how to use the proxies.
*/
fun <T : RPCOps> proxyFor(rpcInterface: Class<T>, timeout: Duration? = null, minVersion: Int = 0): T {
sessionLock.withLock {
if (producer == null)
producer = session.createProducer()
}
val proxyImpl = RPCProxyHandler(timeout)
@Suppress("UNCHECKED_CAST")
val proxy = Proxy.newProxyInstance(rpcInterface.classLoader, arrayOf(rpcInterface, Closeable::class.java), proxyImpl) as T
proxyImpl.serverProtocolVersion = proxy.protocolVersion
if (minVersion > proxyImpl.serverProtocolVersion)
throw RPCException("Requested minimum protocol version $minVersion is higher than the server's supported protocol version (${proxyImpl.serverProtocolVersion})")
return proxy
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//region RPC engine
//
// You can find docs on all this in the api doc for the proxyFor method, and in the docsite.
// Utility to quickly suck out the contents of an Artemis message. There's probably a more efficient way to
// do this.
private fun <T : Any> ClientMessage.deserialize(kryo: Kryo): T = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }.deserialize(kryo)
// By default we use weak references so GC can happen; otherwise they persist for the life of the client.
@GuardedBy("sessionLock")
private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build<String, QueuedObservable>()
// This is used to hold a reference counted hard reference when we know there are subscribers.
private val hardReferencesToQueuedObservables = Collections.synchronizedSet(mutableSetOf<QueuedObservable>())
private var producer: ClientProducer? = null
class ObservableDeserializer : Serializer<Observable<Any>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
val qName = kryo.context[RPCKryoQNameKey] as String
val rpcName = kryo.context[RPCKryoMethodNameKey] as String
val rpcLocation = kryo.context[RPCKryoLocationKey] as Throwable
val rpcClient = kryo.context[RPCKryoClientKey] as CordaRPCClientImpl
val handle = input.readInt(true)
val ob = rpcClient.sessionLock.withLock {
rpcClient.addressToQueuedObservables.getIfPresent(qName) ?: rpcClient.QueuedObservable(qName, rpcName, rpcLocation).apply {
rpcClient.addressToQueuedObservables.put(qName, this)
}
}
val result = ob.getForHandle(handle)
rpcLog.debug { "Deserializing and connecting a new observable for $rpcName on $qName: $result" }
return result
}
override fun write(kryo: Kryo, output: Output, `object`: Observable<Any>) {
throw UnsupportedOperationException("not implemented")
}
}
/**
* The proxy class returned to the client is auto-generated on the fly by the java.lang.reflect Proxy
* infrastructure. The JDK Proxy class writes bytecode into memory for a class that implements the requested
* interfaces and then routes all method calls to the invoke method below in a conveniently reified form.
* We can then easily take the data about the method call and turn it into an RPC. This avoids the need
* for the compile-time code generation which is so common in RPC systems.
*/
@ThreadSafe
private inner class RPCProxyHandler(private val timeout: Duration?) : InvocationHandler, Closeable {
private val proxyId = random63BitValue()
private val consumer: ClientConsumer
var serverProtocolVersion = 0
init {
val proxyAddress = constructAddress(proxyId)
consumer = sessionLock.withLock {
session.createTemporaryQueue(proxyAddress, proxyAddress)
session.createConsumer(proxyAddress)
}
}
private fun constructAddress(addressId: Long) = "${ArtemisMessagingComponent.CLIENTS_PREFIX}$username.rpc.$addressId"
@Synchronized
override fun invoke(proxy: Any, method: Method, args: Array<out Any>?): Any? {
if (isCloseInvocation(method)) {
close()
return null
}
if (method.name == "toString" && args == null)
return "Client RPC proxy"
if (consumer.isClosed)
throw RPCException("RPC Proxy is closed")
// All invoked methods on the proxy end up here.
val location = Throwable()
rpcLog.debug {
val argStr = args?.joinToString() ?: ""
"-> RPC -> ${method.name}($argStr): ${method.returnType}"
}
checkMethodVersion(method)
val msg: ClientMessage = createMessage(method)
// We could of course also check the return type of the method to see if it's Observable, but I'd
// rather have the annotation be used consistently.
val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java)
val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val next: ErrorOr<*> = try {
sendRequest(args, msg)
receiveResponse(kryo, method, timeout)
} finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- RPC <- ${method.name} = $next" }
return unwrapOrThrow(next)
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
private fun unwrapOrThrow(next: ErrorOr<*>): Any? {
val ex = next.error
if (ex != null) {
// Replace the stack trace because that's an implementation detail of the server that isn't so
// helpful to the user who wants to see where the error was on their side, and serialising stack
// frame objects is a bit annoying. We slice it here to avoid the invoke() machinery being exposed.
// The resulting exception looks like it was thrown from inside the called method.
(ex as java.lang.Throwable).stackTrace = java.lang.Throwable().stackTrace.let { it.sliceArray(1..it.size - 1) }
throw ex
} else {
return next.value
}
}
private fun receiveResponse(kryo: Kryo, method: Method, timeout: Duration?): ErrorOr<*> {
val artemisMessage: ClientMessage =
if (timeout == null)
consumer.receive() ?: throw ActiveMQObjectClosedException()
else
consumer.receive(timeout.toMillis()) ?: throw RPCException.DeadlineExceeded(method.name)
artemisMessage.acknowledge()
val next = artemisMessage.deserialize<ErrorOr<*>>(kryo)
return next
}
private fun sendRequest(args: Array<out Any>?, msg: ClientMessage) {
sessionLock.withLock {
val argsKryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val serializedArgs = try {
(args ?: emptyArray<Any?>()).serialize(argsKryo)
} catch (e: KryoException) {
throw RPCException("Could not serialize RPC arguments", e)
} finally {
releaseRPCKryoForDeserialization(argsKryo)
}
msg.writeBodyBufferBytes(serializedArgs.bytes)
producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg)
}
}
private fun maybePrepareForObservables(location: Throwable, method: Method, msg: ClientMessage): Kryo {
// Create a temporary queue just for the emissions on any observables that are returned.
val observationsId = random63BitValue()
val observationsQueueName = constructAddress(observationsId)
session.createTemporaryQueue(observationsQueueName, observationsQueueName)
msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId)
// And make sure that we deserialise observable handles so that they're linked to the right
// queue. Also record a bit of metadata for debugging purposes.
return createRPCKryoForDeserialization(this@CordaRPCClientImpl, observationsQueueName, method.name, location)
}
private fun createMessage(method: Method): ClientMessage {
return session.createMessage(false).apply {
putStringProperty(ClientRPCRequestMessage.METHOD_NAME, method.name)
putLongProperty(ClientRPCRequestMessage.REPLY_TO, proxyId)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
}
private fun checkMethodVersion(method: Method) {
val methodVersion = method.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
if (methodVersion > serverProtocolVersion)
throw UnsupportedOperationException("Method ${method.name} was added in RPC protocol version $methodVersion but the server is running $serverProtocolVersion")
}
private fun isCloseInvocation(method: Method) = method == closeableCloseMethod || method == autocloseableCloseMethod
override fun close() {
consumer.close()
sessionLock.withLock { session.deleteQueue(constructAddress(proxyId)) }
}
override fun toString() = "Corda RPC Proxy listening on queue ${constructAddress(proxyId)}"
}
/**
* When subscribed to, starts consuming from the given queue name and demultiplexing the observables being
* sent to it. The server queue is moved into in-memory buffers (one per attached server-side observable)
* until drained through a subscription. When the subscriptions are all gone, the server-side queue is deleted.
*/
@ThreadSafe
private inner class QueuedObservable(private val qName: String,
private val rpcName: String,
private val rpcLocation: Throwable) {
private val root = PublishSubject.create<MarshalledObservation>()
private val rootShared = root.doOnUnsubscribe { close() }.share()
// This could be made more efficient by using a specialised IntMap
// When handling this map we don't synchronise on [this], otherwise there is a race condition between close() and deliver()
private val observables = Collections.synchronizedMap(HashMap<Int, Observable<Any>>())
@GuardedBy("sessionLock")
private var consumer: ClientConsumer? = null
private val referenceCount = AtomicInteger(0)
// We have to create a weak reference, otherwise we cannot be GC'd.
init {
val weakThis = WeakReference<QueuedObservable>(this)
consumer = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { weakThis.get()?.deliver(it) }
}
/**
* We have to reference count subscriptions to the returned [Observable]s to prevent early GC because we are
* weak referenced.
*
* Derived [Observable]s (e.g. filtered ones) hold a strong reference to the original, but if, for example,
* the following pattern is used, the original passes out of scope and the direction of reference is from the
* original to the [Observer]. We use reference counting to allow for this pattern.
*
* val observationsSubject = PublishSubject.create<Observation>()
* originalObservable.subscribe(observationsSubject)
* return observationsSubject
*/
private fun refCountUp() {
if (referenceCount.andIncrement == 0) {
hardReferencesToQueuedObservables.add(this)
}
}
private fun refCountDown() {
if (referenceCount.decrementAndGet() == 0) {
hardReferencesToQueuedObservables.remove(this)
}
}
fun getForHandle(handle: Int): Observable<Any> {
synchronized(observables) {
return observables.getOrPut(handle) {
/**
* Note that the order of bufferUntilSubscribed() -> dematerialize() is very important here.
*
* In particular doing it the other way around may result in the following edge case:
* The RPC returns two (or more) Observables. The first Observable unsubscribes *during serialisation*,
* before the second one is hit, causing the [rootShared] to unsubscribe and consequently closing
* the underlying artemis queue, even though the second Observable was not even registered.
*
* The buffer -> dematerialize order ensures that the Observable may not unsubscribe until the caller
* subscribes, which must be after full deserialisation and registering of all top level Observables.
*
* In addition, when subscribe and unsubscribe is called on the [Observable] returned here, we
* reference count a hard reference to this [QueuedObservable] to prevent premature GC.
*/
rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize<Any>().doOnSubscribe { refCountUp() }.doOnUnsubscribe { refCountDown() }.share()
}
}
}
private fun deliver(msg: ClientMessage) {
sessionLock.withLock { msg.acknowledge() }
val kryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl, qName, rpcName, rpcLocation)
val received: MarshalledObservation = try {
msg.deserialize(kryo)
} finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- Observable [$rpcName] <- Received $received" }
synchronized(observables) {
// Force creation of the buffer if it doesn't already exist.
getForHandle(received.forHandle)
root.onNext(received)
}
}
fun close() {
sessionLock.withLock {
if (consumer != null) {
rpcLog.debug("Closing queue observable for call to $rpcName : $qName")
consumer?.close()
consumer = null
session.deleteQueue(qName)
}
}
}
@Suppress("UNUSED")
fun finalize() {
val closed = sessionLock.withLock {
if (consumer != null) {
consumer!!.close()
consumer = null
true
} else
false
}
if (closed) {
rpcLog.warn("""A hot observable returned from an RPC ($rpcName) was never subscribed to.
This wastes server-side resources because it was queueing observations for retrieval.
It is being closed now, but please adjust your code to call .notUsed() on the observable
to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning
will appear less frequently in future versions of the platform and you can ignore it
if you want to.
""".trimIndent().replace('\n', ' '), rpcLocation)
}
}
}
//endregion
}
private val rpcDesKryoPool = KryoPool.Builder { RPCKryo(CordaRPCClientImpl.ObservableDeserializer()) }.build()
fun createRPCKryoForDeserialization(rpcClient: CordaRPCClientImpl, qName: String? = null, rpcName: String? = null, rpcLocation: Throwable? = null): Kryo {
val kryo = rpcDesKryoPool.borrow()
kryo.context.put(RPCKryoClientKey, rpcClient)
kryo.context.put(RPCKryoQNameKey, qName)
kryo.context.put(RPCKryoMethodNameKey, rpcName)
kryo.context.put(RPCKryoLocationKey, rpcLocation)
return kryo
}
fun releaseRPCKryoForDeserialization(kryo: Kryo) {
rpcDesKryoPool.release(kryo)
}

View File

@@ -0,0 +1,169 @@
package net.corda.client.rpc.internal
import com.google.common.net.HostAndPort
import net.corda.core.logElapsedTime
import net.corda.core.messaging.RPCOps
import net.corda.core.minutes
import net.corda.core.random63BitValue
import net.corda.core.seconds
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCException
import net.corda.nodeapi.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import java.io.Closeable
import java.lang.reflect.Proxy
import java.time.Duration
/**
* This configuration may be used to tweak the internals of the RPC client.
*/
data class RPCClientConfiguration(
/** The minimum protocol version required from the server */
val minimumServerProtocolVersion: Int,
/**
* If set to true the client will track RPC call sites. If an error occurs subsequently during the RPC or in a
* returned Observable stream the stack trace of the originating RPC will be shown as well. Note that
* constructing call stacks is a moderately expensive operation.
*/
val trackRpcCallSites: Boolean,
/**
* The interval of unused observable reaping in milliseconds. Leaked Observables (unused ones) are
* detected using weak references and are cleaned up in batches at this interval. If set too high it wastes
* server-side resources for longer; if set too low it wastes client-side cycles.
*/
val reapIntervalMs: Long,
/** The number of threads to use for observations (for executing [Observable.onNext]) */
val observationExecutorPoolSize: Int,
/** The maximum number of producers to create to handle outgoing messages */
val producerPoolBound: Int,
/**
* Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines
* the limit on the number of leaked observables reaped because of garbage collection per reaping.
* See the implementation of [com.google.common.cache.LocalCache] for details.
*/
val cacheConcurrencyLevel: Int,
/** The retry interval of artemis connections in milliseconds */
val connectionRetryInterval: Duration,
/** The retry interval multiplier for exponential backoff */
val connectionRetryIntervalMultiplier: Double,
/** Maximum retry interval */
val connectionMaxRetryInterval: Duration,
/** Maximum file size */
val maxFileSize: Int
) {
companion object {
@JvmStatic
val default = RPCClientConfiguration(
minimumServerProtocolVersion = 0,
trackRpcCallSites = false,
reapIntervalMs = 1000,
observationExecutorPoolSize = 4,
producerPoolBound = 1,
cacheConcurrencyLevel = 8,
connectionRetryInterval = 5.seconds,
connectionRetryIntervalMultiplier = 1.5,
connectionMaxRetryInterval = 3.minutes,
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
maxFileSize = 10485760
)
}
}
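// A brief sketch, not part of this commit: since RPCClientConfiguration is a data class with a
// default instance, individual knobs are overridden with copy(), exactly as
// CordaRPCClientConfiguration.toRpcClientConfiguration() does in the public API. Values are examples.
private val tweakedConfiguration = RPCClientConfiguration.default.copy(
    trackRpcCallSites = true,         // record RPC call sites for richer error diagnostics
    reapIntervalMs = 500,             // reap leaked observables twice as often as the default
    observationExecutorPoolSize = 8   // more sticky threads for Observable.onNext() delivery
)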
/**
* An RPC client that may be used to create connections to an RPC server.
*
* @param transport The Artemis transport to use to connect to the server.
* @param rpcConfiguration Configuration used to tweak client behaviour.
*/
class RPCClient<I : RPCOps>(
val transport: TransportConfiguration,
val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default
) {
constructor(
hostAndPort: HostAndPort,
sslConfiguration: SSLConfiguration? = null,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration)
companion object {
private val log = loggerFor<RPCClient<*>>()
}
/**
* Holds a proxy object implementing [I] that forwards requests to the RPC server.
*
* [Closeable.close] may be used to shut down the connection and release associated resources.
*/
interface RPCConnection<out I : RPCOps> : Closeable {
val proxy: I
/** The RPC protocol version reported by the server */
val serverProtocolVersion: Int
}
/**
* Returns an [RPCConnection] containing a proxy that lets you invoke RPCs on the server. Calls on it block, and if
* the server throws an exception then it will be rethrown on the client. Proxies are thread safe and may be used to
* invoke multiple RPCs in parallel.
*
* RPC sends and receives are logged on the net.corda.rpc logger.
*
* The [RPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the object
* graph returned then the server-side observable is transparently forwarded to the client side here.
* *You are expected to use it*. The server will immediately begin buffering messages that it expects you to
* drain by subscribing to the returned observable. You can opt out of this by simply calling the
* [net.corda.client.rpc.notUsed] method on it. You don't have to explicitly close the observable if you actually
* subscribe to it: it will close itself and free up the server-side resources either when the client or JVM itself
* is shutdown, or when there are no more subscribers to it. Once all the subscribers to a returned observable are
* unsubscribed or the observable completes successfully or with an error, the observable is closed and you can't
* then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
*
* @param rpcOpsClass The [Class] of the RPC interface.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
*/
fun start(
rpcOpsClass: Class<I>,
username: String,
password: String
): RPCConnection<I> {
return log.logElapsedTime("Startup") {
val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}")
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply {
retryInterval = rpcConfiguration.connectionRetryInterval.toMillis()
retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier
maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis()
minLargeMessageSize = rpcConfiguration.maxFileSize
}
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass)
proxyHandler.start()
@Suppress("UNCHECKED_CAST")
val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
val serverProtocolVersion = ops.protocolVersion
if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
" than the server's supported protocol version ($serverProtocolVersion)")
}
proxyHandler.setServerProtocolVersion(serverProtocolVersion)
log.debug("RPC connected, returning proxy")
object : RPCConnection<I> {
override val proxy = ops
override val serverProtocolVersion = serverProtocolVersion
override fun close() {
proxyHandler.close()
serverLocator.close()
}
}
}
}
}
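// An illustrative sketch, not part of this commit, of driving the internal RPCClient directly with
// a custom RPCOps interface, much as RPCStabilityTests does through the RPC driver. The interface,
// host and credentials below are hypothetical.
interface ExampleOps : RPCOps {
    fun echo(value: String): String
}

fun rpcClientUsageSketch() {
    val client = RPCClient<ExampleOps>(HostAndPort.fromParts("localhost", 10000))
    // start() blocks until the connection is up and the server's protocol version has been checked.
    client.start(ExampleOps::class.java, "username", "password").use { connection ->
        println(connection.proxy.echo("hello"))   // blocks until the server replies
    }
}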

View File

@@ -0,0 +1,423 @@
package net.corda.client.rpc.internal
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import com.google.common.cache.RemovalCause
import com.google.common.cache.RemovalListener
import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.ThreadBox
import net.corda.core.getOrThrow
import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue
import net.corda.core.serialization.KryoPoolWithContext
import net.corda.core.utilities.*
import net.corda.nodeapi.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Notification
import rx.Observable
import rx.subjects.UnicastSubject
import sun.reflect.CallerSensitive
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.collections.ArrayList
import kotlin.reflect.jvm.javaMethod
/**
* This class provides a proxy implementation of an RPC interface for RPC clients. It translates API calls to lower-level
* RPC protocol messages. For this protocol see [RPCApi].
*
* When a method is called on the interface the arguments are serialised and the request is forwarded to the server. The
* server then executes the code that implements the RPC and sends a reply.
*
* An RPC reply may contain [Observable]s, which are serialised simply as unique IDs. On the client side we create a
* [UnicastSubject] for each such ID. Subsequently the server may send observations attached to this ID, which are
* forwarded to the [UnicastSubject]. Note that the observations themselves may contain further [Observable]s, which are
* handled in the same way.
*
* To do the above we take advantage of Kryo's data structure traversal. When the client is deserialising a message from
* the server that may contain Observables it is supplied with an [ObservableContext] that exposes the map used to demux
* the observations. When an [Observable] is encountered during traversal a new [UnicastSubject] is added to the map and
* we carry on. Each observation later contains the corresponding Observable ID, and we just forward that to the
* associated [UnicastSubject].
*
* The client may signal that it no longer consumes a particular [Observable]. This may be done explicitly by
* unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually
* automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s.
* The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor].
*/
class RPCClientProxyHandler(
private val rpcConfiguration: RPCClientConfiguration,
private val rpcUsername: String,
private val rpcPassword: String,
private val serverLocator: ServerLocator,
private val clientAddress: SimpleString,
private val rpcOpsClass: Class<out RPCOps>
) : InvocationHandler {
private enum class State {
UNSTARTED,
SERVER_VERSION_NOT_SET,
STARTED,
FINISHED
}
private val lifeCycle = LifeCycle(State.UNSTARTED)
private companion object {
val log = loggerFor<RPCClientProxyHandler>()
// Note that this KryoPool is not yet capable of deserialising Observables; it requires Proxy-specific context
// to do that. However it may still be used for serialisation of RPC requests and related messages.
val kryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build()
// To check whether toString() is being invoked
val toStringMethod: Method = Object::toString.javaMethod!!
}
// Used for reaping
private val reaperExecutor = Executors.newScheduledThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
)
// A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build()
private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) {
Executors.newFixedThreadPool(1, observationExecutorThreadFactory)
}
// Holds the RPC reply futures.
private val rpcReplyMap = RpcReplyMap()
// Optionally holds RPC call site stack traces to be shown on errors/warnings.
private val callSiteMap = if (rpcConfiguration.trackRpcCallSites) CallSiteMap() else null
// Holds the Observables and a reference store to keep Observables alive when subscribed to.
private val observableContext = ObservableContext(
callSiteMap = callSiteMap,
observableMap = createRpcObservableMap(),
hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
)
// Holds a reference to the scheduled reaper.
private lateinit var reaperScheduledFuture: ScheduledFuture<*>
// The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
private var serverProtocolVersion: Int? = null
// Stores the Observable IDs that are already removed from the map but are not yet sent to the server.
private val observablesToReap = ThreadBox(object {
var observables = ArrayList<RPCApi.ObservableId>()
})
// A Kryo pool that automatically adds the observable context when an instance is requested.
private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext)
private fun createRpcObservableMap(): RpcObservableMap {
val onObservableRemove = RemovalListener<RPCApi.ObservableId, UnicastSubject<Notification<Any>>> {
val rpcCallSite = callSiteMap?.remove(it.key.toLong)
if (it.cause == RemovalCause.COLLECTED) {
log.warn(listOf(
"A hot observable returned from an RPC was never subscribed to.",
"This wastes server-side resources because it was queueing observations for retrieval.",
"It is being closed now, but please adjust your code to call .notUsed() on the observable",
"to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning",
"will appear less frequently in future versions of the platform and you can ignore it",
"if you want to.").joinToString(" "), rpcCallSite)
}
observablesToReap.locked { observables.add(it.key) }
}
return CacheBuilder.newBuilder().
weakValues().
removalListener(onObservableRemove).
concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel).
build()
}
// We cannot pool consumers as we need to preserve the original muxed message order.
// TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a
// single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is
// integrated properly.
private lateinit var sessionAndConsumer: ArtemisConsumer
// Pool producers to reduce contention on the client side.
private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) {
// Note how we create new sessions *and* session factories per producer.
// We cannot simply pool producers on one session because sessions are single threaded.
// We cannot simply pool sessions on one session factory because flow control credits are tied to factories, so
// sessions tend to starve each other when used concurrently.
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
session.start()
ArtemisProducer(sessionFactory, session, session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME))
}
/**
* Start the client. This creates the per-client queue, starts the consumer session and the reaper.
*/
fun start() {
lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET)
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
this::reapObservables,
rpcConfiguration.reapIntervalMs,
rpcConfiguration.reapIntervalMs,
TimeUnit.MILLISECONDS
)
sessionAndProducerPool.run {
it.session.createTemporaryQueue(clientAddress, clientAddress)
}
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
val consumer = session.createConsumer(clientAddress)
consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler)
session.start()
sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer)
}
// This is the general function that transforms a client side RPC to internal Artemis messages.
@CallerSensitive
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET }
checkProtocolVersion(method)
if (method == toStringMethod) {
return "Client RPC proxy for $rpcOpsClass"
}
if (sessionAndConsumer.session.isClosed) {
throw RPCException("RPC Proxy is closed")
}
val rpcId = RPCApi.RpcRequestId(random63BitValue())
callSiteMap?.set(rpcId.toLong, Throwable("<Call site of root RPC '${method.name}'>"))
try {
val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, arguments?.toList() ?: emptyList())
val replyFuture = SettableFuture.create<Any>()
sessionAndProducerPool.run {
val message = it.session.createMessage(false)
request.writeToClientMessage(kryoPool, message)
log.debug {
val argumentsString = arguments?.joinToString() ?: ""
"-> RPC($rpcId) -> ${method.name}($argumentsString): ${method.returnType}"
}
require(rpcReplyMap.put(rpcId, replyFuture) == null) {
"Generated several RPC requests with same ID $rpcId"
}
it.producer.send(message)
it.session.commit()
}
return replyFuture.getOrThrow()
} finally {
callSiteMap?.remove(rpcId.toLong)
}
}
// The handler for Artemis messages.
private fun artemisMessageHandler(message: ClientMessage) {
val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message)
log.debug { "Got message from RPC server $serverToClient" }
when (serverToClient) {
is RPCApi.ServerToClient.RpcReply -> {
val replyFuture = rpcReplyMap.remove(serverToClient.id)
if (replyFuture == null) {
log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
} else {
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
serverToClient.result.match(
onError = {
if (rpcCallSite != null) addRpcCallSiteToThrowable(it, rpcCallSite)
replyFuture.setException(it)
},
onValue = { replyFuture.set(it) }
)
}
}
is RPCApi.ServerToClient.Observation -> {
val observable = observableContext.observableMap.getIfPresent(serverToClient.id)
if (observable == null) {
log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
"This may be due to an observation arriving before the server was " +
"notified of observable shutdown")
} else {
// We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
observationExecutorPool.run(serverToClient.id) { executor ->
executor.submit {
val content = serverToClient.content
if (content.isOnCompleted || content.isOnError) {
observableContext.observableMap.invalidate(serverToClient.id)
}
// Add call site information on error
if (content.isOnError) {
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
}
observable.onNext(content)
}
}
}
}
}
message.acknowledge()
}
/**
* Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors.
*/
fun close() {
lifeCycle.transition(State.STARTED, State.FINISHED)
sessionAndConsumer.consumer.close()
sessionAndConsumer.session.close()
sessionAndConsumer.sessionFactory.close()
reaperScheduledFuture.cancel(false)
observableContext.observableMap.invalidateAll()
reapObservables()
reaperExecutor.shutdownNow()
sessionAndProducerPool.close().forEach {
it.producer.close()
it.session.close()
it.sessionFactory.close()
}
// Note that the ordering is important: we shut down the consumer *before* the observation executor, otherwise
// we may leak borrowed executors.
val observationExecutors = observationExecutorPool.close()
observationExecutors.forEach { it.shutdownNow() }
observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) }
}
/**
* Checks the [RPCSinceVersion] of the passed-in [calledMethod] against the server's protocol version.
*/
private fun checkProtocolVersion(calledMethod: Method) {
val serverProtocolVersion = serverProtocolVersion
if (serverProtocolVersion == null) {
lifeCycle.requireState(State.SERVER_VERSION_NOT_SET)
} else {
lifeCycle.requireState(State.STARTED)
val sinceVersion = calledMethod.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
if (sinceVersion > serverProtocolVersion) {
throw UnsupportedOperationException("Method $calledMethod was added in RPC protocol version $sinceVersion but the server is running $serverProtocolVersion")
}
}
}
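// For example, given a method annotated as below (see the versioning tests further down), calling it
// against a server that reports a protocol version below 2 throws UnsupportedOperationException here:
//
//     @RPCSinceVersion(2)
//     fun addedLater()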
/**
* Sets the server's protocol version. Note that before this is done the client is not considered fully started,
* although RPCs may already be made through it.
*/
internal fun setServerProtocolVersion(version: Int) {
lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED)
if (serverProtocolVersion == null) {
serverProtocolVersion = version
} else {
throw IllegalStateException("setServerProtocolVersion called, but the protocol version was already set!")
}
}
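// Sends a batched ObservablesClosed message to the server for every Observable invalidated since the
// last reap, letting the server release the resources backing them.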
private fun reapObservables() {
observableContext.observableMap.cleanUp()
val observableIds = observablesToReap.locked {
if (observables.isNotEmpty()) {
val temporary = observables
observables = ArrayList()
temporary
} else {
null
}
}
if (observableIds != null) {
log.debug { "Reaping ${observableIds.size} observables" }
sessionAndProducerPool.run {
val message = it.session.createMessage(false)
RPCApi.ClientToServer.ObservablesClosed(observableIds).writeToClientMessage(message)
it.producer.send(message)
}
}
}
}
private typealias RpcObservableMap = Cache<RPCApi.ObservableId, UnicastSubject<Notification<Any>>>
private typealias RpcReplyMap = ConcurrentHashMap<RPCApi.RpcRequestId, SettableFuture<Any?>>
private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?>
/**
* Holds a context available during Kryo deserialisation of messages that are expected to contain Observables.
*
* @param observableMap holds the Observables that are ultimately exposed to the user.
* @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to.
*/
private data class ObservableContext(
val callSiteMap: CallSiteMap?,
val observableMap: RpcObservableMap,
val hardReferenceStore: MutableSet<Observable<*>>
)
/**
* A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext].
*/
private object RpcClientObservableSerializer : Serializer<Observable<Any>>() {
private object RpcObservableContextKey
fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool {
return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext)
}
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
@Suppress("UNCHECKED_CAST")
val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
val observableId = RPCApi.ObservableId(input.readLong(true))
val observable = UnicastSubject.create<Notification<Any>>()
require(observableContext.observableMap.getIfPresent(observableId) == null) {
"Multiple Observables arrived with the same ID $observableId"
}
val rpcCallSite = getRpcCallSite(kryo, observableContext)
observableContext.observableMap.put(observableId, observable)
observableContext.callSiteMap?.put(observableId.toLong, rpcCallSite)
// We pin all Observables into a hard reference store (rooted in the RPC proxy) on subscription so that users
// don't need to store a reference to the Observables themselves.
return observable.pinInSubscriptions(observableContext.hardReferenceStore).doOnUnsubscribe {
// This causes Future completions to give warnings because the corresponding OnComplete sent from the server
// will arrive after the client unsubscribes from the observable and consequently invalidates the mapping.
// The unsubscribe is due to [ObservableToFuture]'s use of first().
observableContext.observableMap.invalidate(observableId)
}.dematerialize()
}
override fun write(kryo: Kryo, output: Output, observable: Observable<Any>) {
throw UnsupportedOperationException("Cannot serialise Observables on the client side")
}
private fun getRpcCallSite(kryo: Kryo, observableContext: ObservableContext): Throwable? {
val rpcRequestOrObservableId = kryo.context[RPCApi.RpcRequestOrObservableIdKey] as Long
return observableContext.callSiteMap?.get(rpcRequestOrObservableId)
}
}
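// Walks to the end of the throwable's cause chain and attaches the call site there, so stack traces
// of RPC failures also show where the RPC was originally invoked.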
private fun addRpcCallSiteToThrowable(throwable: Throwable, callSite: Throwable) {
var currentThrowable = throwable
while (true) {
val cause = currentThrowable.cause
if (cause == null) {
currentThrowable.initCause(callSite)
break
} else {
currentThrowable = cause
}
}
}
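// Reference-counts subscriptions: the first subscriber pins the Observable into the hard reference
// store and the last unsubscription unpins it, so user code need not hold its own reference.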
private fun <T> Observable<T>.pinInSubscriptions(hardReferenceStore: MutableSet<Observable<*>>): Observable<T> {
val refCount = AtomicInteger(0)
return this.doOnSubscribe {
if (refCount.getAndIncrement() == 0) {
require(hardReferenceStore.add(this)) { "Reference store already contained reference $this on add" }
}
}.doOnUnsubscribe {
if (refCount.decrementAndGet() == 0) {
require(hardReferenceStore.remove(this)) { "Reference store did not contain reference $this on remove" }
}
}
}

View File

@@ -1,102 +0,0 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ALICE
import net.corda.core.utilities.LogHelper
import net.corda.node.services.RPCUserService
import net.corda.node.services.messaging.RPCDispatcher
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.User
import org.apache.activemq.artemis.api.core.Message
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory
import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ
import org.junit.After
import org.junit.Before
import java.util.*
import java.util.concurrent.locks.ReentrantLock
abstract class AbstractClientRPCTest {
lateinit var artemis: EmbeddedActiveMQ
lateinit var serverSession: ClientSession
lateinit var clientSession: ClientSession
lateinit var producer: ClientProducer
lateinit var serverThread: AffinityExecutor.ServiceAffinityExecutor
@Before
fun rpcSetup() {
// Set up an in-memory Artemis with an RPC requests queue.
artemis = EmbeddedActiveMQ()
artemis.setConfiguration(ConfigurationImpl().apply {
acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name))
isSecurityEnabled = false
isPersistenceEnabled = false
})
artemis.start()
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(TransportConfiguration(InVMConnectorFactory::class.java.name))
val sessionFactory = serverLocator.createSessionFactory()
serverSession = sessionFactory.createSession()
serverSession.start()
serverSession.createTemporaryQueue(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
producer = serverSession.createProducer()
serverThread = AffinityExecutor.ServiceAffinityExecutor("unit-tests-rpc-dispatch-thread", 1)
serverSession.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 'BINDING_REMOVED'")
clientSession = sessionFactory.createSession()
clientSession.start()
LogHelper.setLevel("+net.corda.rpc")
}
@After
fun rpcShutdown() {
safeClose(producer)
clientSession.stop()
serverSession.stop()
artemis.stop()
serverThread.shutdownNow()
}
fun <T : RPCOps> rpcProxyFor(rpcUser: User, rpcImpl: T, type: Class<T>): T {
val userService = object : RPCUserService {
override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null
override val users: List<User> get() = listOf(rpcUser)
}
val dispatcher = object : RPCDispatcher(rpcImpl, userService, ALICE.name) {
override fun send(data: SerializedBytes<*>, toAddress: String) {
val msg = serverSession.createMessage(false).apply {
writeBodyBufferBytes(data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
producer.send(toAddress, msg)
}
override fun getUser(message: ClientMessage): User = rpcUser
}
val serverNotifConsumer = serverSession.createConsumer("rpc.qremovals")
val serverConsumer = serverSession.createConsumer(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
dispatcher.start(serverConsumer, serverNotifConsumer, serverThread)
return CordaRPCClientImpl(clientSession, ReentrantLock(), rpcUser.username).proxyFor(type)
}
fun safeClose(obj: Any) {
try {
(obj as AutoCloseable).close()
} catch (e: Exception) {
}
}
}

View File

@@ -0,0 +1,56 @@
package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.flatMap
import net.corda.core.map
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcTestUser
import net.corda.testing.startInVmRpcClient
import net.corda.testing.startRpcClient
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.junit.runners.Parameterized
open class AbstractRPCTest {
enum class RPCTestMode {
InVm,
Netty
}
companion object {
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
fun defaultModes() = modes(RPCTestMode.InVm, RPCTestMode.Netty)
fun modes(vararg modes: RPCTestMode) = listOf(*modes).map { arrayOf(it) }
}
@Parameterized.Parameter
lateinit var mode: RPCTestMode
data class TestProxy<out I : RPCOps>(
val ops: I,
val createSession: () -> ClientSession
)
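// Starts an RPC server for [ops] and connects a matching client to it, either fully in-memory (InVm)
// or over an Artemis Netty transport, depending on the parameterised [mode].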
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.testProxy(
ops: I,
rpcUser: User = rpcTestUser,
clientConfiguration: RPCClientConfiguration = RPCClientConfiguration.default,
serverConfiguration: RPCServerConfiguration = RPCServerConfiguration.default
): TestProxy<I> {
return when (mode) {
RPCTestMode.InVm ->
startInVmRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap {
startInVmRpcClient<I>(rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) })
}
}.get()
RPCTestMode.Netty ->
startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server ->
startRpcClient<I>(server.hostAndPort, rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startArtemisSession(server.hostAndPort, rpcUser.username, rpcUser.password) })
}
}.get()
}
}
}

View File

@@ -5,16 +5,16 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.getOrThrow
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.RPCReturnsObservables
import net.corda.core.success
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.node.services.messaging.getRpcContext
import net.corda.nodeapi.RPCSinceVersion
import net.corda.nodeapi.User
import org.apache.activemq.artemis.api.core.SimpleString
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import net.corda.testing.rpcTestUser
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import rx.subjects.PublishSubject
import java.util.concurrent.CountDownLatch
@@ -23,22 +23,11 @@ import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
@RunWith(Parameterized::class)
class ClientRPCInfrastructureTests : AbstractRPCTest() {
// TODO: Test that timeouts work
private fun RPCDriverExposedDSLInterface.testProxy() = testProxy<TestOps>(TestOpsImpl()).ops
interface TestOps : RPCOps {
@Throws(IllegalArgumentException::class)
@@ -48,16 +37,12 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
fun someCalculation(str: String, num: Int): String
@RPCReturnsObservables
fun makeObservable(): Observable<Int>
@RPCReturnsObservables
fun makeComplicatedObservable(): Observable<Pair<String, Observable<String>>>
@RPCReturnsObservables
fun makeListenableFuture(): ListenableFuture<Int>
@RPCReturnsObservables
fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>>
@RPCSinceVersion(2)
@@ -78,117 +63,130 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
override fun makeListenableFuture(): ListenableFuture<Int> = Futures.immediateFuture(1)
override fun makeComplicatedObservable() = complicatedObservable
override fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>> = complicatedListenableFuturee
override fun addedLater(): Unit = throw IllegalStateException()
override fun captureUser(): String = getRpcContext().currentUser.username
}
@Test
fun `simple RPCs`() {
rpcDriver {
val proxy = testProxy()
// Does nothing, doesn't throw.
proxy.void()
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
proxy.barf()
}.message)
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
proxy.barf()
}.message)
assertEquals("hi 5", proxy.someCalculation("hi", 5))
assertEquals("hi 5", proxy.someCalculation("hi", 5))
}
}
@Test
fun `simple observable`() {
rpcDriver {
val proxy = testProxy()
// This tests that the observations are transmitted correctly, also completion is transmitted.
val observations = proxy.makeObservable().toBlocking().toIterable().toList()
assertEquals(listOf(1, 2, 3, 4), observations)
}
}
@Test
fun `complex observables`() {
rpcDriver {
val proxy = testProxy()
// This checks that we can return an object graph with complex usage of observables, like an observable
// that emits objects that contain more observables.
val serverQuotes = PublishSubject.create<Pair<String, Observable<String>>>()
val unsubscribeLatch = CountDownLatch(1)
complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() }
val twainQuotes = "Mark Twain" to Observable.just(
"I have never let my schooling interfere with my education.",
"Clothes make the man. Naked people have little or no influence on society."
)
val wildeQuotes = "Oscar Wilde" to Observable.just(
"I can resist everything except temptation.",
"Always forgive your enemies - nothing annoys them so much."
)
val clientQuotes = LinkedBlockingQueue<String>()
val clientObs = proxy.makeComplicatedObservable()
val subscription = clientObs.subscribe {
val name = it.first
it.second.subscribe {
clientQuotes += "Quote by $name: $it"
}
}
assertThat(clientQuotes).isEmpty()
serverQuotes.onNext(twainQuotes)
assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take())
assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take())
serverQuotes.onNext(wildeQuotes)
assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take())
assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take())
assertTrue(serverQuotes.hasObservers())
subscription.unsubscribe()
unsubscribeLatch.await()
}
}
@Test
fun `simple ListenableFuture`() {
rpcDriver {
val proxy = testProxy()
val value = proxy.makeListenableFuture().getOrThrow()
assertThat(value).isEqualTo(1)
}
}
@Test
fun `complex ListenableFuture`() {
rpcDriver {
val proxy = testProxy()
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>()
complicatedListenableFuturee = serverQuote
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.")
val clientQuotes = LinkedBlockingQueue<String>()
val clientFuture = proxy.makeComplicatedListenableFuture()
clientFuture.success {
val name = it.first
it.second.success {
clientQuotes += "Quote by $name: $it"
}
}
assertThat(clientQuotes).isEmpty()
serverQuote.set(twainQuote)
assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.")
// TODO This final assert sometimes fails because the relevant queue hasn't been removed yet
// assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
}
}
@Test
fun versioning() {
rpcDriver {
val proxy = testProxy()
assertFailsWith<UnsupportedOperationException> { proxy.addedLater() }
}
}
@Test
fun `authenticated user is available to RPC`() {
rpcDriver {
val proxy = testProxy()
assertThat(proxy.captureUser()).isEqualTo(rpcTestUser.username)
}
}
}

View File

@@ -0,0 +1,194 @@
package net.corda.client.rpc
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.future
import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.loggerFor
import net.corda.node.driver.poll
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import net.corda.testing.startRandomRpcClient
import net.corda.testing.startRpcClient
import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.Ignore
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
@RunWith(Parameterized::class)
class RPCConcurrencyTests : AbstractRPCTest() {
/**
* Holds a "rose"-tree of [Observable]s which allows us to test arbitrary [Observable] nesting in RPC replies.
*/
@CordaSerializable
data class ObservableRose<out A>(val value: A, val branches: Observable<out ObservableRose<A>>)
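// For example, getImmediateObservableTree(depth = 1, branchingFactor = 2) below yields
// ObservableRose(1, branches) whose branches emit two ObservableRose(0, Observable.empty()) leaves.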
private interface TestOps : RPCOps {
fun newLatch(numberOfDowns: Int): Long
fun waitLatch(id: Long)
fun downLatch(id: Long)
fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
}
class TestOpsImpl : TestOps {
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
override val protocolVersion = 0
override fun newLatch(numberOfDowns: Int): Long {
val id = random63BitValue()
val latch = CountDownLatch(numberOfDowns)
latches.put(id, latch)
return id
}
override fun waitLatch(id: Long) {
latches[id]!!.await()
}
override fun downLatch(id: Long) {
latches[id]!!.countDown()
}
override fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>()
} else {
Observable.just(getImmediateObservableTree(depth - 1, branchingFactor)).repeat(branchingFactor.toLong())
}
return ObservableRose(depth, branches)
}
override fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>()
} else {
val publish = UnicastSubject.create<ObservableRose<Int>>()
future {
(1..branchingFactor).toList().parallelStream().forEach {
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
}
publish.onCompleted()
}
publish
}
return ObservableRose(depth, branches)
}
}
private lateinit var testOpsImpl: TestOpsImpl
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
testOpsImpl = TestOpsImpl()
return testProxy<TestOps>(
testOpsImpl,
clientConfiguration = RPCClientConfiguration.default.copy(
reapIntervalMs = 100,
cacheConcurrencyLevel = 16
),
serverConfiguration = RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 4
)
)
}
@Test
fun `call multiple RPCs in parallel`() {
rpcDriver {
val proxy = testProxy()
val numberOfBlockedCalls = 2
val numberOfDownsRequired = 100
val id = proxy.ops.newLatch(numberOfDownsRequired)
val done = CountDownLatch(numberOfBlockedCalls)
// Start a couple of blocking RPC calls
(1..numberOfBlockedCalls).forEach {
future {
proxy.ops.waitLatch(id)
done.countDown()
}
}
// Down the latch that the others are waiting for concurrently
(1..numberOfDownsRequired).toList().parallelStream().forEach {
proxy.ops.downLatch(id)
}
done.await()
}
}
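// Exponentiation by squaring; the tests below use it to size their latches to the node count of a
// full rose tree, (b^(depth + 1) - 1) / (b - 1) for branching factor b.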
private fun intPower(base: Int, power: Int): Int {
return when (power) {
0 -> 1
1 -> base
else -> {
val a = intPower(base, power / 2)
if (power and 1 == 0) {
a * a
} else {
a * a * base
}
}
}
}
@Test
fun `nested immediate observables sequence correctly`() {
rpcDriver {
// We construct a rose tree of immediate Observables and check that parent observations arrive before children.
val proxy = testProxy()
val treeDepth = 6
val treeBranchingFactor = 3
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
this.branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
tree.subscribeToAll()
}
}
proxy.ops.getImmediateObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
remainingLatch.await()
}
}
@Test
fun `parallel nested observables`() {
rpcDriver {
val proxy = testProxy()
val treeDepth = 2
val treeBranchingFactor = 10
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
tree.subscribeToAll()
}
}
proxy.ops.getParallelObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
remainingLatch.await()
}
}
}

View File

@@ -0,0 +1,315 @@
package net.corda.client.rpc
import com.codahale.metrics.Gauge
import com.codahale.metrics.JmxReporter
import com.codahale.metrics.MetricRegistry
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.base.Stopwatch
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.RPCOps
import net.corda.core.millis
import net.corda.core.random63BitValue
import net.corda.node.driver.ShutdownManager
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCKryo
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.measure
import net.corda.testing.rpcDriver
import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.Ignore
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import java.time.Duration
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import javax.management.ObjectName
import kotlin.concurrent.thread
import kotlin.concurrent.withLock
@Ignore("Only use this locally for profiling")
@RunWith(Parameterized::class)
class RPCPerformanceTests : AbstractRPCTest() {
companion object {
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
fun modes() = modes(RPCTestMode.Netty)
}
private interface TestOps : RPCOps {
fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray
}
class TestOpsImpl : TestOps {
override val protocolVersion = 0
override fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray {
return ByteArray(sizeOfReply)
}
}
private fun RPCDriverExposedDSLInterface.testProxy(
clientConfiguration: RPCClientConfiguration,
serverConfiguration: RPCServerConfiguration
): TestProxy<TestOps> {
return testProxy<TestOps>(
TestOpsImpl(),
clientConfiguration = clientConfiguration,
serverConfiguration = serverConfiguration
)
}
private fun warmup() {
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default,
RPCServerConfiguration.default
)
val executor = Executors.newFixedThreadPool(4)
val N = 10000
val latch = CountDownLatch(N)
for (i in 1 .. N) {
executor.submit {
proxy.ops.simpleReply(ByteArray(1024), 1024)
latch.countDown()
}
}
latch.await()
}
}
data class SimpleRPCResult(
val requestPerSecond: Double,
val averageIndividualMs: Double,
val Mbps: Double
)
@Test
fun `measure Megabytes per second for simple RPCs`() {
warmup()
val inputOutputSizes = listOf(1024, 4096, 100 * 1024)
val overallTraffic = 512 * 1024 * 1024L
measure(inputOutputSizes, (1..5)) { inputOutputSize, N ->
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default.copy(
cacheConcurrencyLevel = 16,
observationExecutorPoolSize = 2,
producerPoolBound = 2
),
RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 8,
consumerPoolSize = 2,
producerPoolBound = 8
)
)
val numberOfRequests = overallTraffic / (2 * inputOutputSize)
val timings = Collections.synchronizedList(ArrayList<Long>())
val executor = Executors.newFixedThreadPool(8)
val totalElapsed = Stopwatch.createStarted().apply {
startInjectorWithBoundedQueue(
executor = executor,
numberOfInjections = numberOfRequests.toInt(),
queueBound = 100
) {
val elapsed = Stopwatch.createStarted().apply {
proxy.ops.simpleReply(ByteArray(inputOutputSize), inputOutputSize)
}.stop().elapsed(TimeUnit.MICROSECONDS)
timings.add(elapsed)
}
}.stop().elapsed(TimeUnit.MICROSECONDS)
executor.shutdownNow()
SimpleRPCResult(
requestPerSecond = 1000000.0 * numberOfRequests.toDouble() / totalElapsed.toDouble(),
averageIndividualMs = timings.average() / 1000.0,
Mbps = (overallTraffic.toDouble() / totalElapsed.toDouble()) * (1000000.0 / (1024.0 * 1024.0))
)
}
}.forEach(::println)
}
/**
* Runs 20k RPCs per second for two minutes and publishes relevant stats to JMX.
*/
@Test
fun `consumption rate`() {
rpcDriver {
val metricRegistry = startJmxReporter()
val proxy = testProxy(
RPCClientConfiguration.default.copy(
reapIntervalMs = 100,
cacheConcurrencyLevel = 16
),
RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 4,
consumerPoolSize = 4,
producerPoolBound = 4
)
)
measurePerformancePublishMetrics(
metricRegistry = metricRegistry,
parallelism = 4,
overallDurationSecond = 120.0,
injectionRatePerSecond = 20000.0,
queueSizeMetricName = "$mode.QueueSize",
workDurationMetricName = "$mode.WorkDuration",
shutdownManager = this.shutdownManager,
work = {
proxy.ops.simpleReply(ByteArray(4096), 4096)
}
)
}
}
data class BigMessagesResult(
val Mbps: Double
)
@Test
fun `big messages`() {
warmup()
measure(listOf(1)) { clientParallelism -> // TODO this hangs with more parallelism
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default,
RPCServerConfiguration.default.copy(
consumerPoolSize = 1
)
)
val executor = Executors.newFixedThreadPool(clientParallelism)
val numberOfMessages = 1000
val bigSize = 10_000_000
val elapsed = Stopwatch.createStarted().apply {
startInjectorWithBoundedQueue(
executor = executor,
numberOfInjections = numberOfMessages,
queueBound = 4
) {
proxy.ops.simpleReply(ByteArray(bigSize), 0)
}
}.stop().elapsed(TimeUnit.MICROSECONDS)
executor.shutdownNow()
BigMessagesResult(
Mbps = bigSize.toDouble() * numberOfMessages.toDouble() / elapsed * (1000000.0 / (1024.0 * 1024.0))
)
}
}.forEach(::println)
}
}
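/**
 * Submits [work] from [parallelism] worker threads at roughly [injectionRatePerSecond] for
 * [overallDurationSecond] seconds, publishing the outstanding-work count and the per-item duration
 * to [metricRegistry] under [queueSizeMetricName] and [workDurationMetricName].
 */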
fun measurePerformancePublishMetrics(
metricRegistry: MetricRegistry,
parallelism: Int,
overallDurationSecond: Double,
injectionRatePerSecond: Double,
queueSizeMetricName: String,
workDurationMetricName: String,
shutdownManager: ShutdownManager,
work: () -> Unit
) {
val workSemaphore = Semaphore(0)
metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() })
val workDurationTimer = metricRegistry.timer(workDurationMetricName)
val executor = Executors.newSingleThreadScheduledExecutor()
val workExecutor = Executors.newFixedThreadPool(parallelism)
val timings = Collections.synchronizedList(ArrayList<Long>())
for (i in 1 .. parallelism) {
workExecutor.submit {
try {
while (true) {
workSemaphore.acquire()
workDurationTimer.time {
timings.add(
Stopwatch.createStarted().apply {
work()
}.stop().elapsed(TimeUnit.MICROSECONDS)
)
}
}
} catch (throwable: Throwable) {
throwable.printStackTrace()
}
}
}
val injector = executor.scheduleAtFixedRate(
{
workSemaphore.release(injectionRatePerSecond.toInt())
},
0,
1,
TimeUnit.SECONDS
)
shutdownManager.registerShutdown {
injector.cancel(true)
workExecutor.shutdownNow()
executor.shutdownNow()
workExecutor.awaitTermination(1, TimeUnit.SECONDS)
executor.awaitTermination(1, TimeUnit.SECONDS)
}
Thread.sleep((overallDurationSecond * 1000).toLong())
}
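/**
 * Submits [numberOfInjections] invocations of [work] to [executor] from a dedicated injector thread,
 * blocking the injector whenever more than [queueBound] submissions are outstanding and waking it
 * once the backlog drains below half the bound. Returns when all work has completed.
 */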
fun startInjectorWithBoundedQueue(
executor: ExecutorService,
numberOfInjections: Int,
queueBound: Int,
work: () -> Unit
) {
val remainingLatch = CountDownLatch(numberOfInjections)
val queuedCount = AtomicInteger(0)
val lock = ReentrantLock()
val canQueueAgain = lock.newCondition()
val injectorShutdown = AtomicBoolean(false)
val injector = thread(name = "injector") {
while (true) {
if (injectorShutdown.get()) break
executor.submit {
work()
if (queuedCount.decrementAndGet() < queueBound / 2) {
lock.withLock {
canQueueAgain.signal()
}
}
remainingLatch.countDown()
}
if (queuedCount.incrementAndGet() > queueBound) {
lock.withLock {
canQueueAgain.await()
}
}
}
}
remainingLatch.await()
injectorShutdown.set(true)
injector.join()
}
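// Creates a fresh MetricRegistry, exposes it over JMX under the "net.corda" domain and hooks the
// reporter's teardown into the driver's shutdown manager.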
fun RPCDriverExposedDSLInterface.startJmxReporter(): MetricRegistry {
val metricRegistry = MetricRegistry()
val jmxReporter = thread {
JmxReporter.
forRegistry(metricRegistry).
inDomain("net.corda").
createsObjectNamesWith { _, domain, name ->
// Make the JMX hierarchy a bit better organised.
val category = name.substringBefore('.')
val subName = name.substringAfter('.', "")
if (subName == "")
ObjectName("$domain:name=$category")
else
ObjectName("$domain:type=$category,name=$subName")
}.
build().
start()
}
shutdownManager.registerShutdown {
jmxReporter.interrupt()
jmxReporter.join()
}
return metricRegistry
}

View File

@@ -1,85 +0,0 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.requirePermission
import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User
import org.junit.After
import org.junit.Test
import kotlin.test.assertFailsWith
class RPCPermissionsTest : AbstractClientRPCTest() {
companion object {
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
const val ALL_ALLOWED = "ALL"
}
lateinit var proxy: TestOps
@After
fun shutdown() {
safeClose(proxy)
}
/*
* RPC operation.
*/
interface TestOps : RPCOps {
fun validatePermission(str: String)
}
class TestOpsImpl : TestOps {
override val protocolVersion = 1
override fun validatePermission(str: String) = requirePermission(str)
}
/**
* Create an RPC proxy for the given user.
*/
private fun proxyFor(rpcUser: User): TestOps = rpcProxyFor(rpcUser, TestOpsImpl(), TestOps::class.java)
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
@Test
fun `empty user cannot use any flows`() {
val emptyUser = userOf("empty", emptySet())
proxy = proxyFor(emptyUser)
assertFailsWith(PermissionException::class,
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
{ proxy.validatePermission(DUMMY_FLOW) })
}
@Test
fun `admin user can use any flow`() {
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
proxy = proxyFor(adminUser)
proxy.validatePermission(DUMMY_FLOW)
}
@Test
fun `joe user is allowed to use DummyFlow`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
proxy.validatePermission(DUMMY_FLOW)
}
@Test
fun `joe user is not allowed to use OtherFlow`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
assertFailsWith(PermissionException::class,
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
{ proxy.validatePermission(OTHER_FLOW) })
}
@Test
fun `check ALL is implemented the correct way round`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
assertFailsWith(PermissionException::class,
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
{ proxy.validatePermission(ALL_ALLOWED) })
}
}

View File

@@ -0,0 +1,93 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.requirePermission
import net.corda.node.services.messaging.getRpcContext
import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RPCPermissionsTests : AbstractRPCTest() {
companion object {
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
const val ALL_ALLOWED = "ALL"
}
/*
* RPC operation.
*/
interface TestOps : RPCOps {
fun validatePermission(str: String)
}
class TestOpsImpl : TestOps {
override val protocolVersion = 1
override fun validatePermission(str: String) = getRpcContext().requirePermission(str)
}
/**
* Create an RPC proxy for the given user.
*/
private fun RPCDriverExposedDSLInterface.testProxyFor(rpcUser: User) = testProxy<TestOps>(TestOpsImpl(), rpcUser).ops
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
@Test
fun `empty user cannot use any flows`() {
rpcDriver {
val emptyUser = userOf("empty", emptySet())
val proxy = testProxyFor(emptyUser)
assertFailsWith(PermissionException::class,
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
{ proxy.validatePermission(DUMMY_FLOW) })
}
}
@Test
fun `admin user can use any flow`() {
rpcDriver {
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
val proxy = testProxyFor(adminUser)
proxy.validatePermission(DUMMY_FLOW)
}
}
@Test
fun `joe user is allowed to use DummyFlow`() {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
proxy.validatePermission(DUMMY_FLOW)
}
}
@Test
fun `joe user is not allowed to use OtherFlow`() {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
assertFailsWith(PermissionException::class,
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
{ proxy.validatePermission(OTHER_FLOW) })
}
}
@Test
fun `check ALL is implemented the correct way round`() {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
assertFailsWith(PermissionException::class,
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
{ proxy.validatePermission(ALL_ALLOWED) })
}
}
}

View File

@@ -0,0 +1,25 @@
package net.corda.client.rpc
import java.io.InputStream
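/**
 * An [InputStream] producing [numberOfBytes] bytes by cycling through the [bytesToRepeat] pattern;
 * useful for streaming large test payloads without allocating them up front.
 */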
class RepeatingBytesInputStream(val bytesToRepeat: ByteArray, val numberOfBytes: Int) : InputStream() {
private var bytesLeft = numberOfBytes
override fun available() = bytesLeft
override fun read(): Int {
    if (bytesLeft == 0) {
        return -1
    } else {
        // Index before decrementing so single-byte reads yield the same sequence as the bulk read below.
        val byte = bytesToRepeat[(numberOfBytes - bytesLeft) % bytesToRepeat.size]
        bytesLeft--
        // Mask to 0..255 as required by the InputStream contract.
        return byte.toInt() and 0xff
    }
}
override fun read(byteArray: ByteArray, offset: Int, length: Int): Int {
val until = Math.min(Math.min(offset + length, byteArray.size), offset + bytesLeft)
for (i in offset .. until - 1) {
byteArray[i] = bytesToRepeat[(numberOfBytes - bytesLeft + i - offset) % bytesToRepeat.size]
}
val bytesRead = until - offset
bytesLeft -= bytesRead
return if (bytesRead == 0 && bytesLeft == 0) -1 else bytesRead
}
}